test_backtranslation_dataset.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
  5. import unittest
  6. import tests.utils as test_utils
  7. import torch
  8. from fairseq.data import (
  9. BacktranslationDataset,
  10. LanguagePairDataset,
  11. TransformEosDataset,
  12. )
  13. from fairseq.sequence_generator import SequenceGenerator
  14. class TestBacktranslationDataset(unittest.TestCase):
  15. def setUp(self):
  16. (
  17. self.tgt_dict,
  18. self.w1,
  19. self.w2,
  20. self.src_tokens,
  21. self.src_lengths,
  22. self.model,
  23. ) = test_utils.sequence_generator_setup()
  24. dummy_src_samples = self.src_tokens
  25. self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)
  26. self.cuda = torch.cuda.is_available()
  27. def _backtranslation_dataset_helper(
  28. self,
  29. remove_eos_from_input_src,
  30. remove_eos_from_output_src,
  31. ):
  32. tgt_dataset = LanguagePairDataset(
  33. src=self.tgt_dataset,
  34. src_sizes=self.tgt_dataset.sizes,
  35. src_dict=self.tgt_dict,
  36. tgt=None,
  37. tgt_sizes=None,
  38. tgt_dict=None,
  39. )
  40. generator = SequenceGenerator(
  41. [self.model],
  42. tgt_dict=self.tgt_dict,
  43. max_len_a=0,
  44. max_len_b=200,
  45. beam_size=2,
  46. unk_penalty=0,
  47. )
  48. backtranslation_dataset = BacktranslationDataset(
  49. tgt_dataset=TransformEosDataset(
  50. dataset=tgt_dataset,
  51. eos=self.tgt_dict.eos(),
  52. # remove eos from the input src
  53. remove_eos_from_src=remove_eos_from_input_src,
  54. ),
  55. src_dict=self.tgt_dict,
  56. backtranslation_fn=(
  57. lambda sample: generator.generate([self.model], sample)
  58. ),
  59. output_collater=TransformEosDataset(
  60. dataset=tgt_dataset,
  61. eos=self.tgt_dict.eos(),
  62. # if we remove eos from the input src, then we need to add it
  63. # back to the output tgt
  64. append_eos_to_tgt=remove_eos_from_input_src,
  65. remove_eos_from_src=remove_eos_from_output_src,
  66. ).collater,
  67. cuda=self.cuda,
  68. )
  69. dataloader = torch.utils.data.DataLoader(
  70. backtranslation_dataset,
  71. batch_size=2,
  72. collate_fn=backtranslation_dataset.collater,
  73. )
  74. backtranslation_batch_result = next(iter(dataloader))
  75. eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2
  76. # Note that we sort by src_lengths and add left padding, so actually
  77. # ids will look like: [1, 0]
  78. expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
  79. if remove_eos_from_output_src:
  80. expected_src = expected_src[:, :-1]
  81. expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
  82. generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
  83. tgt_tokens = backtranslation_batch_result["target"]
  84. self.assertTensorEqual(expected_src, generated_src)
  85. self.assertTensorEqual(expected_tgt, tgt_tokens)
  86. def test_backtranslation_dataset_no_eos_in_output_src(self):
  87. self._backtranslation_dataset_helper(
  88. remove_eos_from_input_src=False,
  89. remove_eos_from_output_src=True,
  90. )
  91. def test_backtranslation_dataset_with_eos_in_output_src(self):
  92. self._backtranslation_dataset_helper(
  93. remove_eos_from_input_src=False,
  94. remove_eos_from_output_src=False,
  95. )
  96. def test_backtranslation_dataset_no_eos_in_input_src(self):
  97. self._backtranslation_dataset_helper(
  98. remove_eos_from_input_src=True,
  99. remove_eos_from_output_src=False,
  100. )
  101. def assertTensorEqual(self, t1, t2):
  102. self.assertEqual(t1.size(), t2.size(), "size mismatch")
  103. self.assertEqual(t1.ne(t2).long().sum(), 0)
  104. if __name__ == "__main__":
  105. unittest.main()