test_resampling_dataset.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import collections
import unittest

import numpy as np

from fairseq.data import ListDataset, ResamplingDataset


class TestResamplingDataset(unittest.TestCase):
    def setUp(self):
        self.strings = ["ab", "c", "def", "ghij"]
        self.weights = [4.0, 2.0, 7.0, 1.5]
        self.size_ratio = 2
        self.dataset = ListDataset(
            self.strings, np.array([len(s) for s in self.strings])
        )
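        # With these weights, the target sampling distribution checked in
        # _test_common below is weight / sum(weights); sum(weights) == 14.5, so:
        #   "ab"   -> 4.0 / 14.5 ≈ 0.276
        #   "c"    -> 2.0 / 14.5 ≈ 0.138
        #   "def"  -> 7.0 / 14.5 ≈ 0.483
        #   "ghij" -> 1.5 / 14.5 ≈ 0.103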

    def _test_common(self, resampling_dataset, iters):
        assert len(self.dataset) == len(self.strings) == len(self.weights)
        assert len(resampling_dataset) == self.size_ratio * len(self.strings)
        results = {"ordered_by_size": True, "max_distribution_diff": 0.0}
        totalfreqs = 0
        freqs = collections.defaultdict(int)
        for epoch_num in range(iters):
            resampling_dataset.set_epoch(epoch_num)
            indices = resampling_dataset.ordered_indices()
            assert len(indices) == len(resampling_dataset)
            prev_size = -1
            for i in indices:
                cur_size = resampling_dataset.size(i)
                # Make sure indices map to same sequences within an epoch
                assert resampling_dataset[i] == resampling_dataset[i]
                # Make sure length of sequence is correct
                assert cur_size == len(resampling_dataset[i])
                freqs[resampling_dataset[i]] += 1
                totalfreqs += 1
                if prev_size > cur_size:
                    results["ordered_by_size"] = False
                prev_size = cur_size
        assert set(freqs.keys()) == set(self.strings)
        for s, weight in zip(self.strings, self.weights):
            freq = freqs[s] / totalfreqs
            expected_freq = weight / sum(self.weights)
            results["max_distribution_diff"] = max(
                results["max_distribution_diff"], abs(expected_freq - freq)
            )
        return results
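
    # The two tests below run _test_common with batch_by_size off and on.
    # In the returned dict, "ordered_by_size" records whether every epoch's
    # ordered_indices() came back in non-decreasing size order, and
    # "max_distribution_diff" is the largest absolute gap between any string's
    # empirical sampling frequency and its weight / sum(weights) target.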

    def test_resampling_dataset_batch_by_size_false(self):
        resampling_dataset = ResamplingDataset(
            self.dataset,
            self.weights,
            size_ratio=self.size_ratio,
            batch_by_size=False,
            seed=0,
        )

        results = self._test_common(resampling_dataset, iters=1000)

        # For batch_by_size = False, the batches should be returned in
        # arbitrary order of size.
        assert not results["ordered_by_size"]

        # Allow tolerance in distribution error of 2%.
        assert results["max_distribution_diff"] < 0.02

    def test_resampling_dataset_batch_by_size_true(self):
        resampling_dataset = ResamplingDataset(
            self.dataset,
            self.weights,
            size_ratio=self.size_ratio,
            batch_by_size=True,
            seed=0,
        )

        results = self._test_common(resampling_dataset, iters=1000)

        # For batch_by_size = True, the batches should be returned in
        # increasing order of size.
        assert results["ordered_by_size"]

        # Allow tolerance in distribution error of 2%.
        assert results["max_distribution_diff"] < 0.02
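

# A minimal illustrative sketch (not part of the original test file): it reuses
# only the ResamplingDataset APIs already exercised above (set_epoch,
# ordered_indices, indexing) to estimate the empirical sampling distribution.
# The helper name `_example_usage` and the epoch count are our own choices.
def _example_usage():
    strings = ["ab", "c", "def", "ghij"]
    weights = [4.0, 2.0, 7.0, 1.5]
    base = ListDataset(strings, np.array([len(s) for s in strings]))
    resampled = ResamplingDataset(
        base,
        weights,
        size_ratio=2,
        batch_by_size=False,
        seed=0,
    )
    # Each epoch re-draws size_ratio * len(base) samples according to the
    # (normalized) weights; set_epoch makes the draw deterministic per epoch.
    counts = collections.Counter()
    for epoch in range(100):
        resampled.set_epoch(epoch)
        for i in resampled.ordered_indices():
            counts[resampled[i]] += 1
    total = sum(counts.values())
    # Empirical frequencies should be close to weight / sum(weights).
    return {s: counts[s] / total for s in strings}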


if __name__ == "__main__":
    unittest.main()