# test_sequence_scorer.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import unittest

import tests.utils as test_utils
import torch

from fairseq.sequence_scorer import SequenceScorer
  10. class TestSequenceScorer(unittest.TestCase):
  11. def test_sequence_scorer(self):
  12. # construct dummy dictionary
  13. d = test_utils.dummy_dictionary(vocab_size=2)
  14. self.assertEqual(d.pad(), 1)
  15. self.assertEqual(d.eos(), 2)
  16. self.assertEqual(d.unk(), 3)
  17. eos = d.eos()
  18. w1 = 4
  19. w2 = 5
  20. # construct dataloader
  21. data = [
  22. {
  23. "source": torch.LongTensor([w1, w2, eos]),
  24. "target": torch.LongTensor([w1, w2, w1, eos]),
  25. },
  26. {
  27. "source": torch.LongTensor([w2, eos]),
  28. "target": torch.LongTensor([w2, w1, eos]),
  29. },
  30. {
  31. "source": torch.LongTensor([w2, eos]),
  32. "target": torch.LongTensor([w2, eos]),
  33. },
  34. ]
  35. data_itr = test_utils.dummy_dataloader(data)
  36. # specify expected output probabilities
  37. args = argparse.Namespace()
  38. unk = 0.0
  39. args.beam_probs = [
  40. # step 0:
  41. torch.FloatTensor(
  42. [
  43. # eos w1 w2
  44. [0.0, unk, 0.6, 0.4], # sentence 1
  45. [0.0, unk, 0.4, 0.6], # sentence 2
  46. [0.0, unk, 0.7, 0.3], # sentence 3
  47. ]
  48. ),
  49. # step 1:
  50. torch.FloatTensor(
  51. [
  52. # eos w1 w2
  53. [0.0, unk, 0.2, 0.7], # sentence 1
  54. [0.0, unk, 0.8, 0.2], # sentence 2
  55. [0.7, unk, 0.1, 0.2], # sentence 3
  56. ]
  57. ),
  58. # step 2:
  59. torch.FloatTensor(
  60. [
  61. # eos w1 w2
  62. [0.10, unk, 0.50, 0.4], # sentence 1
  63. [0.15, unk, 0.15, 0.7], # sentence 2
  64. [0.00, unk, 0.00, 0.0], # sentence 3
  65. ]
  66. ),
  67. # step 3:
  68. torch.FloatTensor(
  69. [
  70. # eos w1 w2
  71. [0.9, unk, 0.05, 0.05], # sentence 1
  72. [0.0, unk, 0.00, 0.0], # sentence 2
  73. [0.0, unk, 0.00, 0.0], # sentence 3
  74. ]
  75. ),
  76. ]
  77. expected_scores = [
  78. [0.6, 0.7, 0.5, 0.9], # sentence 1
  79. [0.6, 0.8, 0.15], # sentence 2
  80. [0.3, 0.7], # sentence 3
  81. ]
  82. task = test_utils.TestTranslationTask.setup_task(args, d, d)
  83. model = task.build_model(args)
  84. scorer = SequenceScorer(task.target_dictionary)
  85. for sample in data_itr:
  86. hypos = task.inference_step(scorer, [model], sample)
  87. for id, hypos_id in zip(sample["id"].tolist(), hypos):
  88. self.assertHypoTokens(hypos_id[0], data[id]["target"])
  89. self.assertHypoScore(hypos_id[0], expected_scores[id])
  90. def assertHypoTokens(self, hypo, tokens):
  91. self.assertTensorEqual(hypo["tokens"], torch.LongTensor(tokens))
  92. def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.0):
  93. pos_scores = torch.FloatTensor(pos_probs).log()
  94. self.assertAlmostEqual(hypo["positional_scores"], pos_scores)
  95. self.assertEqual(pos_scores.numel(), hypo["tokens"].numel())
  96. score = pos_scores.sum()
  97. if normalized:
  98. score /= pos_scores.numel() ** lenpen
  99. self.assertLess(abs(score - hypo["score"]), 1e-6)
  100. def assertAlmostEqual(self, t1, t2):
  101. self.assertEqual(t1.size(), t2.size(), "size mismatch")
  102. self.assertLess((t1 - t2).abs().max(), 1e-4)
  103. def assertTensorEqual(self, t1, t2):
  104. self.assertEqual(t1.size(), t2.size(), "size mismatch")
  105. self.assertEqual(t1.ne(t2).long().sum(), 0)
# Allow running this module directly, e.g. `python test_sequence_scorer.py`.
if __name__ == "__main__":
    unittest.main()