test_iterators.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import unittest
  6. from fairseq.data import iterators, ListDataset
  7. class TestIterators(unittest.TestCase):
  8. def test_counting_iterator_index(self, ref=None, itr=None):
  9. # Test the indexing functionality of CountingIterator
  10. if ref is None:
  11. assert itr is None
  12. ref = list(range(10))
  13. itr = iterators.CountingIterator(ref)
  14. else:
  15. assert len(ref) == 10
  16. assert itr is not None
  17. self.assertTrue(itr.has_next())
  18. self.assertEqual(itr.n, 0)
  19. self.assertEqual(next(itr), ref[0])
  20. self.assertEqual(itr.n, 1)
  21. self.assertEqual(next(itr), ref[1])
  22. self.assertEqual(itr.n, 2)
  23. itr.skip(3)
  24. self.assertEqual(itr.n, 5)
  25. self.assertEqual(next(itr), ref[5])
  26. itr.skip(2)
  27. self.assertEqual(itr.n, 8)
  28. self.assertEqual(list(itr), [ref[8], ref[9]])
  29. self.assertFalse(itr.has_next())
  30. def test_counting_iterator_length_mismatch(self):
  31. ref = list(range(10))
  32. # When the underlying iterable is longer than the CountingIterator,
  33. # the remaining items in the iterable should be ignored
  34. itr = iterators.CountingIterator(ref, total=8)
  35. self.assertEqual(list(itr), ref[:8])
  36. # When the underlying iterable is shorter than the CountingIterator,
  37. # raise an IndexError when the underlying iterable is exhausted
  38. itr = iterators.CountingIterator(ref, total=12)
  39. self.assertRaises(IndexError, list, itr)
  40. def test_counting_iterator_take(self):
  41. # Test the "take" method of CountingIterator
  42. ref = list(range(10))
  43. itr = iterators.CountingIterator(ref)
  44. itr.take(5)
  45. self.assertEqual(len(itr), len(list(iter(itr))))
  46. self.assertEqual(len(itr), 5)
  47. itr = iterators.CountingIterator(ref)
  48. itr.take(5)
  49. self.assertEqual(next(itr), ref[0])
  50. self.assertEqual(next(itr), ref[1])
  51. itr.skip(2)
  52. self.assertEqual(next(itr), ref[4])
  53. self.assertFalse(itr.has_next())
  54. def test_grouped_iterator(self):
  55. # test correctness
  56. x = list(range(10))
  57. itr = iterators.GroupedIterator(x, 1)
  58. self.assertEqual(list(itr), [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]])
  59. itr = iterators.GroupedIterator(x, 4)
  60. self.assertEqual(list(itr), [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]])
  61. itr = iterators.GroupedIterator(x, 5)
  62. self.assertEqual(list(itr), [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
  63. # test the GroupIterator also works correctly as a CountingIterator
  64. x = list(range(30))
  65. ref = list(iterators.GroupedIterator(x, 3))
  66. itr = iterators.GroupedIterator(x, 3)
  67. self.test_counting_iterator_index(ref, itr)
  68. def test_sharded_iterator(self):
  69. # test correctness
  70. x = list(range(10))
  71. itr = iterators.ShardedIterator(x, num_shards=1, shard_id=0)
  72. self.assertEqual(list(itr), x)
  73. itr = iterators.ShardedIterator(x, num_shards=2, shard_id=0)
  74. self.assertEqual(list(itr), [0, 2, 4, 6, 8])
  75. itr = iterators.ShardedIterator(x, num_shards=2, shard_id=1)
  76. self.assertEqual(list(itr), [1, 3, 5, 7, 9])
  77. itr = iterators.ShardedIterator(x, num_shards=3, shard_id=0)
  78. self.assertEqual(list(itr), [0, 3, 6, 9])
  79. itr = iterators.ShardedIterator(x, num_shards=3, shard_id=1)
  80. self.assertEqual(list(itr), [1, 4, 7, None])
  81. itr = iterators.ShardedIterator(x, num_shards=3, shard_id=2)
  82. self.assertEqual(list(itr), [2, 5, 8, None])
  83. # test CountingIterator functionality
  84. x = list(range(30))
  85. ref = list(iterators.ShardedIterator(x, num_shards=3, shard_id=0))
  86. itr = iterators.ShardedIterator(x, num_shards=3, shard_id=0)
  87. self.test_counting_iterator_index(ref, itr)
  88. def test_counting_iterator_buffered_iterator_take(self):
  89. ref = list(range(10))
  90. buffered_itr = iterators.BufferedIterator(2, ref)
  91. itr = iterators.CountingIterator(buffered_itr)
  92. itr.take(5)
  93. self.assertEqual(len(itr), len(list(iter(itr))))
  94. self.assertEqual(len(itr), 5)
  95. buffered_itr = iterators.BufferedIterator(2, ref)
  96. itr = iterators.CountingIterator(buffered_itr)
  97. itr.take(5)
  98. self.assertEqual(len(buffered_itr), 5)
  99. self.assertEqual(len(list(iter(buffered_itr))), 5)
  100. buffered_itr = iterators.BufferedIterator(2, ref)
  101. itr = iterators.CountingIterator(buffered_itr)
  102. itr.take(5)
  103. self.assertEqual(next(itr), ref[0])
  104. self.assertEqual(next(itr), ref[1])
  105. itr.skip(2)
  106. self.assertEqual(next(itr), ref[4])
  107. self.assertFalse(itr.has_next())
  108. self.assertRaises(StopIteration, next, buffered_itr)
  109. ref = list(range(4, 10))
  110. buffered_itr = iterators.BufferedIterator(2, ref)
  111. itr = iterators.CountingIterator(buffered_itr, start=4)
  112. itr.take(5)
  113. self.assertEqual(len(itr), 5)
  114. self.assertEqual(len(buffered_itr), 1)
  115. self.assertEqual(next(itr), ref[0])
  116. self.assertFalse(itr.has_next())
  117. self.assertRaises(StopIteration, next, buffered_itr)
  118. def test_epoch_batch_iterator_skip_remainder_batch(self):
  119. reference = [1, 2, 3]
  120. itr1 = _get_epoch_batch_itr(reference, 2, True)
  121. self.assertEqual(len(itr1), 1)
  122. itr2 = _get_epoch_batch_itr(reference, 2, False)
  123. self.assertEqual(len(itr2), 2)
  124. itr3 = _get_epoch_batch_itr(reference, 1, True)
  125. self.assertEqual(len(itr3), 2)
  126. itr4 = _get_epoch_batch_itr(reference, 1, False)
  127. self.assertEqual(len(itr4), 3)
  128. itr5 = _get_epoch_batch_itr(reference, 4, True)
  129. self.assertEqual(len(itr5), 0)
  130. self.assertFalse(itr5.has_next())
  131. itr6 = _get_epoch_batch_itr(reference, 4, False)
  132. self.assertEqual(len(itr6), 1)
  133. def test_grouped_iterator_skip_remainder_batch(self):
  134. reference = [1, 2, 3, 4, 5, 6, 7, 8, 9]
  135. itr1 = _get_epoch_batch_itr(reference, 3, False)
  136. grouped_itr1 = iterators.GroupedIterator(itr1, 2, True)
  137. self.assertEqual(len(grouped_itr1), 1)
  138. itr2 = _get_epoch_batch_itr(reference, 3, False)
  139. grouped_itr2 = iterators.GroupedIterator(itr2, 2, False)
  140. self.assertEqual(len(grouped_itr2), 2)
  141. itr3 = _get_epoch_batch_itr(reference, 3, True)
  142. grouped_itr3 = iterators.GroupedIterator(itr3, 2, True)
  143. self.assertEqual(len(grouped_itr3), 1)
  144. itr4 = _get_epoch_batch_itr(reference, 3, True)
  145. grouped_itr4 = iterators.GroupedIterator(itr4, 2, False)
  146. self.assertEqual(len(grouped_itr4), 1)
  147. itr5 = _get_epoch_batch_itr(reference, 5, True)
  148. grouped_itr5 = iterators.GroupedIterator(itr5, 2, True)
  149. self.assertEqual(len(grouped_itr5), 0)
  150. itr6 = _get_epoch_batch_itr(reference, 5, True)
  151. grouped_itr6 = iterators.GroupedIterator(itr6, 2, False)
  152. self.assertEqual(len(grouped_itr6), 1)
  153. def _get_epoch_batch_itr(ref, bsz, skip_remainder_batch):
  154. dsz = len(ref)
  155. indices = range(dsz)
  156. starts = indices[::bsz]
  157. batch_sampler = [indices[s : s + bsz] for s in starts]
  158. dataset = ListDataset(ref)
  159. itr = iterators.EpochBatchIterator(
  160. dataset=dataset,
  161. collate_fn=dataset.collater,
  162. batch_sampler=batch_sampler,
  163. skip_remainder_batch=skip_remainder_batch,
  164. )
  165. return itr.next_epoch_itr()
  166. if __name__ == "__main__":
  167. unittest.main()