# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from collections import OrderedDict

import torch

from fairseq.data import LanguagePairDataset, TokenBlockDataset
from fairseq.data.multi_corpus_dataset import MultiCorpusDataset
from tests.test_train import mock_dict
  11. class TestMultiCorpusDataset(unittest.TestCase):
  12. def setUp(self):
  13. d = mock_dict()
  14. tokens_1 = torch.LongTensor([i for i in range(1, 5000, 2)]).view(1, -1)
  15. tokens_ds1 = TokenBlockDataset(
  16. tokens_1,
  17. sizes=[tokens_1.size(-1)],
  18. block_size=1,
  19. pad=0,
  20. eos=1,
  21. include_targets=False,
  22. )
  23. self.dataset_1 = LanguagePairDataset(
  24. tokens_ds1, tokens_ds1.sizes, d, shuffle=False
  25. )
  26. tokens_2 = torch.LongTensor([i for i in range(0, 5000, 2)]).view(1, -1)
  27. tokens_ds2 = TokenBlockDataset(
  28. tokens_2,
  29. sizes=[tokens_2.size(-1)],
  30. block_size=1,
  31. pad=0,
  32. eos=1,
  33. include_targets=False,
  34. )
  35. self.dataset_2 = LanguagePairDataset(
  36. tokens_ds2, tokens_ds2.sizes, d, shuffle=False
  37. )
  38. def _test_sample_helper(
  39. self,
  40. distribution,
  41. ):
  42. m = MultiCorpusDataset(
  43. OrderedDict({0: self.dataset_1, 1: self.dataset_2}),
  44. distribution=distribution,
  45. seed=0,
  46. sort_indices=True,
  47. )
  48. m.set_epoch(1)
  49. indices = m.ordered_indices()
  50. count_sample_from_first_dataset = 0
  51. items = set()
  52. for i in indices:
  53. item = m[i]["source"].item()
  54. if item % 2 == 1:
  55. count_sample_from_first_dataset += 1
  56. items.add(item)
  57. sample_from_first_ds_percentage = (
  58. 1.0 * count_sample_from_first_dataset / len(indices)
  59. )
  60. self.assertLess(
  61. abs(sample_from_first_ds_percentage - distribution[0]),
  62. 0.01,
  63. )
  64. self.assertEqual(
  65. len(items),
  66. int(
  67. min(len(self.dataset_1), len(indices) * distribution[0])
  68. + min(len(self.dataset_1), len(indices) * distribution[1])
  69. ),
  70. )
  71. print(distribution)
  72. def test_multi_corpus_dataset(self):
  73. for distribution in [[0.5, 0.5], [0.1, 0.9], [0.9, 0.1], [0.0, 1.0]]:
  74. self._test_sample_helper(distribution=distribution)