123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123 |
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
- import unittest
- import tests.utils as test_utils
- import torch
- from fairseq.data import (
- BacktranslationDataset,
- LanguagePairDataset,
- TransformEosDataset,
- )
- from fairseq.sequence_generator import SequenceGenerator
class TestBacktranslationDataset(unittest.TestCase):
    """Batch-level checks for BacktranslationDataset under several EOS settings.

    Each test case builds a BacktranslationDataset on top of the fixtures
    returned by ``test_utils.sequence_generator_setup()``, pulls the first
    batch through a ``DataLoader`` and compares the resulting source and
    target tensors against hand-written expectations.
    """

    def setUp(self):
        """Unpack the shared fixtures and wrap the source tokens in a dataset."""
        fixtures = test_utils.sequence_generator_setup()
        (
            self.tgt_dict,
            self.w1,
            self.w2,
            self.src_tokens,
            self.src_lengths,
            self.model,
        ) = fixtures

        # The fixture's source tokens are reused as the data of the
        # dataset that will be back-translated.
        self.tgt_dataset = test_utils.TestDataset(data=self.src_tokens)
        self.cuda = torch.cuda.is_available()

    def _backtranslation_dataset_helper(
        self,
        remove_eos_from_input_src,
        remove_eos_from_output_src,
    ):
        """Build a BacktranslationDataset and verify its first batch.

        Args:
            remove_eos_from_input_src: passed as ``remove_eos_from_src`` to
                the TransformEosDataset that feeds the backtranslation
                function, and as ``append_eos_to_tgt`` to the one that
                provides the output collater.
            remove_eos_from_output_src: passed as ``remove_eos_from_src`` to
                the TransformEosDataset that provides the output collater;
                when true, the expected source loses its last column.
        """
        # Source-only language pair dataset: no target side is supplied.
        monolingual = LanguagePairDataset(
            src=self.tgt_dataset,
            src_sizes=self.tgt_dataset.sizes,
            src_dict=self.tgt_dict,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )

        generator = SequenceGenerator(
            [self.model],
            tgt_dict=self.tgt_dict,
            max_len_a=0,
            max_len_b=200,
            beam_size=2,
            unk_penalty=0,
        )

        def backtranslate(sample):
            return generator.generate([self.model], sample)

        # remove eos from the input src
        generation_input = TransformEosDataset(
            dataset=monolingual,
            eos=self.tgt_dict.eos(),
            remove_eos_from_src=remove_eos_from_input_src,
        )

        # if we remove eos from the input src, then we need to add it
        # back to the output tgt
        output_view = TransformEosDataset(
            dataset=monolingual,
            eos=self.tgt_dict.eos(),
            append_eos_to_tgt=remove_eos_from_input_src,
            remove_eos_from_src=remove_eos_from_output_src,
        )

        backtranslation_dataset = BacktranslationDataset(
            tgt_dataset=generation_input,
            src_dict=self.tgt_dict,
            backtranslation_fn=backtranslate,
            output_collater=output_view.collater,
            cuda=self.cuda,
        )

        loader = torch.utils.data.DataLoader(
            backtranslation_dataset,
            batch_size=2,
            collate_fn=backtranslation_dataset.collater,
        )
        first_batch = next(iter(loader))

        eos, pad = self.tgt_dict.eos(), self.tgt_dict.pad()
        w1, w2 = self.w1, self.w2

        # Note that we sort by src_lengths and add left padding, so actually
        # ids will look like: [1, 0]
        longer_row = [w1, w2, w1, eos]
        padded_row = [pad, pad, w1, eos]
        expected_src = torch.LongTensor([longer_row, padded_row])
        if remove_eos_from_output_src:
            # Drop the trailing eos column from the expectation.
            expected_src = expected_src[:, :-1]

        target_row = [w1, w2, eos]
        expected_tgt = torch.LongTensor([target_row, target_row])

        generated_src = first_batch["net_input"]["src_tokens"]
        tgt_tokens = first_batch["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

    # eos kept on the generation input, stripped from the output source.
    def test_backtranslation_dataset_no_eos_in_output_src(self):
        self._backtranslation_dataset_helper(
            remove_eos_from_input_src=False,
            remove_eos_from_output_src=True,
        )

    # eos kept on both the generation input and the output source.
    def test_backtranslation_dataset_with_eos_in_output_src(self):
        self._backtranslation_dataset_helper(
            remove_eos_from_input_src=False,
            remove_eos_from_output_src=False,
        )

    # eos stripped from the generation input, kept on the output source.
    def test_backtranslation_dataset_no_eos_in_input_src(self):
        self._backtranslation_dataset_helper(
            remove_eos_from_input_src=True,
            remove_eos_from_output_src=False,
        )

    def assertTensorEqual(self, t1, t2):
        """Assert that ``t1`` and ``t2`` have equal sizes and no differing element."""
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        mismatch_count = t1.ne(t2).long().sum()
        self.assertEqual(mismatch_count, 0)
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|