test_data_utils.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import unittest
  6. import numpy as np
  7. from fairseq.data.data_utils_fast import batch_by_size_fn, batch_by_size_vec
  8. class TestBatchBySize(unittest.TestCase):
  9. @classmethod
  10. def batch_by_size_baseline(
  11. cls,
  12. indices,
  13. num_tokens_vec,
  14. max_tokens,
  15. max_sentences,
  16. bsz_mult,
  17. ):
  18. """Simple, reliable and slow implementation of batch by size"""
  19. batches = []
  20. start = 0
  21. while start < len(indices):
  22. for end in range(start + 1, len(indices) + 1):
  23. max_val = max(num_tokens_vec[pos] for pos in range(start, end))
  24. sent_count = end - start
  25. num_tokens = max_val * sent_count
  26. overflow = num_tokens > max_tokens > 0 or sent_count > max_sentences > 0
  27. terminate = overflow or end == len(indices)
  28. if overflow:
  29. sent_count -= 1
  30. if terminate:
  31. if sent_count > bsz_mult:
  32. sent_count = sent_count - sent_count % bsz_mult
  33. batches.append(indices[start : start + sent_count])
  34. start = start + sent_count
  35. break
  36. return batches
  37. @classmethod
  38. def _get_error_message(
  39. cls, max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results
  40. ):
  41. return f"""Reference batch_by_size implementation should produce
  42. same output as the baseline method.
  43. Params:
  44. max_sentences={max_sentences},
  45. max_tokens={max_tokens},
  46. bsz_mult={bsz_mult},
  47. num_tokens_vec={num_tokens_vec},
  48. expected_batches={validation},
  49. returned_batches={results}"""
  50. def _compare_results(
  51. self,
  52. indices_len,
  53. batch_by_size_impl,
  54. max_sentences,
  55. max_tokens,
  56. bsz_mult,
  57. num_tokens_vec,
  58. ):
  59. indices = np.array(list(range(indices_len)))
  60. validation = self.batch_by_size_baseline(
  61. indices,
  62. num_tokens_vec,
  63. max_tokens=max_tokens,
  64. max_sentences=max_sentences,
  65. bsz_mult=bsz_mult,
  66. )
  67. results = batch_by_size_impl(
  68. indices,
  69. num_tokens_vec,
  70. max_tokens=max_tokens,
  71. max_sentences=max_sentences,
  72. bsz_mult=bsz_mult,
  73. )
  74. error_msg = self._get_error_message(
  75. max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results
  76. )
  77. self.assertEqual(len(validation), len(results), error_msg)
  78. for first, second in zip(validation, results):
  79. self.assertTrue(np.array_equal(first, second), error_msg)
  80. def _run_compare_with_baseline_sweep(self, batch_by_size_impl):
  81. """Compare reference batch_by_size implementation with batch_by_size_baseline
  82. across a dense grid of hyperparam values"""
  83. MAX_MAX_TOKENS = 10
  84. NUM_TOKENS_VECS_COUNT = 5
  85. for indices_len in [10, 11]: # try odd and even len of indices
  86. for max_sentences in range(0, indices_len + 2):
  87. for max_tokens in range(0, MAX_MAX_TOKENS):
  88. for bsz_mult in range(1, max(MAX_MAX_TOKENS, indices_len) + 2):
  89. for _ in range(NUM_TOKENS_VECS_COUNT):
  90. num_tokens_vec = np.random.randint(
  91. 0, max_tokens + 1, size=indices_len
  92. )
  93. self._compare_results(
  94. indices_len,
  95. batch_by_size_impl,
  96. max_sentences,
  97. max_tokens,
  98. bsz_mult,
  99. num_tokens_vec,
  100. )
  101. class TestBatchBySizeVec(TestBatchBySize):
  102. def test_compare_with_baseline(self):
  103. self._run_compare_with_baseline_sweep(batch_by_size_vec)
  104. class TestBatchBySizeFn(TestBatchBySize):
  105. def test_compare_with_baseline(self):
  106. def batch_by_size_fn_wrapper(
  107. indices,
  108. num_tokens_vec,
  109. max_tokens,
  110. max_sentences,
  111. bsz_mult,
  112. ):
  113. def num_tokens_fn(idx):
  114. return num_tokens_vec[idx]
  115. return batch_by_size_fn(
  116. indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult
  117. )
  118. self._run_compare_with_baseline_sweep(batch_by_size_fn_wrapper)
  119. if __name__ == "__main__":
  120. unittest.main()