test_token_block_dataset.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import unittest
  6. import tests.utils as test_utils
  7. import torch
  8. from fairseq.data import TokenBlockDataset
  9. class TestTokenBlockDataset(unittest.TestCase):
  10. def _build_dataset(self, data, **kwargs):
  11. sizes = [len(x) for x in data]
  12. underlying_ds = test_utils.TestDataset(data)
  13. return TokenBlockDataset(underlying_ds, sizes, **kwargs)
  14. def test_eos_break_mode(self):
  15. data = [
  16. torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
  17. torch.tensor([1], dtype=torch.long),
  18. torch.tensor([8, 7, 6, 1], dtype=torch.long),
  19. ]
  20. ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode="eos")
  21. self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
  22. self.assertEqual(ds[1].tolist(), [1])
  23. self.assertEqual(ds[2].tolist(), [8, 7, 6, 1])
  24. data = [
  25. torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
  26. torch.tensor([8, 7, 6, 1], dtype=torch.long),
  27. torch.tensor([1], dtype=torch.long),
  28. ]
  29. ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode="eos")
  30. self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
  31. self.assertEqual(ds[1].tolist(), [8, 7, 6, 1])
  32. self.assertEqual(ds[2].tolist(), [1])
  33. def test_block_break_mode(self):
  34. data = [
  35. torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
  36. torch.tensor([8, 7, 6, 1], dtype=torch.long),
  37. torch.tensor([9, 1], dtype=torch.long),
  38. ]
  39. ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode="none")
  40. self.assertEqual(ds[0].tolist(), [5, 4, 3])
  41. self.assertEqual(ds[1].tolist(), [2, 1, 8])
  42. self.assertEqual(ds[2].tolist(), [7, 6, 1])
  43. self.assertEqual(ds[3].tolist(), [9, 1])
  44. def test_complete_break_mode(self):
  45. data = [
  46. torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
  47. torch.tensor([8, 7, 6, 1], dtype=torch.long),
  48. torch.tensor([9, 1], dtype=torch.long),
  49. ]
  50. ds = self._build_dataset(
  51. data, block_size=6, pad=0, eos=1, break_mode="complete"
  52. )
  53. self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
  54. self.assertEqual(ds[1].tolist(), [8, 7, 6, 1, 9, 1])
  55. data = [
  56. torch.tensor([4, 3, 2, 1], dtype=torch.long),
  57. torch.tensor([5, 1], dtype=torch.long),
  58. torch.tensor([1], dtype=torch.long),
  59. torch.tensor([6, 1], dtype=torch.long),
  60. ]
  61. ds = self._build_dataset(
  62. data, block_size=3, pad=0, eos=1, break_mode="complete"
  63. )
  64. self.assertEqual(ds[0].tolist(), [4, 3, 2, 1])
  65. self.assertEqual(ds[1].tolist(), [5, 1, 1])
  66. self.assertEqual(ds[2].tolist(), [6, 1])
  67. def test_4billion_tokens(self):
  68. """Regression test for numpy type promotion issue https://github.com/numpy/numpy/issues/5745"""
  69. data = [torch.tensor(list(range(10000)), dtype=torch.long)] * 430000
  70. ds = self._build_dataset(
  71. data, block_size=6, pad=0, eos=1, break_mode="complete"
  72. )
  73. ds[-1] # __getitem__ works
  74. start, end = ds.slice_indices[-1]
  75. assert end > 4294967295 # data must be sufficiently large to overflow uint32
  76. assert not isinstance(
  77. end + 1, float
  78. ) # this would also raise, since np.uint64(1) + 1 => 2.0
  79. if __name__ == "__main__":
  80. unittest.main()