test_file_chunker_utils.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # This source code is licensed under the MIT license found in the
  2. # LICENSE file in the root directory of this source tree.
  3. import os
  4. import shutil
  5. import tempfile
  6. import unittest
  7. from typing import Optional
  8. class TestFileChunker(unittest.TestCase):
  9. _tmpdir: Optional[str] = None
  10. _tmpfile: Optional[str] = None
  11. _line_content = "Hello, World\n"
  12. _num_bytes = None
  13. _num_lines = 200
  14. _num_splits = 20
  15. @classmethod
  16. def setUpClass(cls) -> None:
  17. cls._num_bytes = len(cls._line_content.encode("utf-8"))
  18. cls._tmpdir = tempfile.mkdtemp()
  19. with open(os.path.join(cls._tmpdir, "test.txt"), "w") as f:
  20. cls._tmpfile = f.name
  21. for _i in range(cls._num_lines):
  22. f.write(cls._line_content)
  23. f.flush()
  24. @classmethod
  25. def tearDownClass(cls) -> None:
  26. # Cleanup temp working dir.
  27. if cls._tmpdir is not None:
  28. shutil.rmtree(cls._tmpdir) # type: ignore
  29. def test_find_offsets(self):
  30. from fairseq.file_chunker_utils import find_offsets
  31. offsets = find_offsets(self._tmpfile, self._num_splits)
  32. self.assertEqual(len(offsets), self._num_splits + 1)
  33. (zero, *real_offsets, last) = offsets
  34. self.assertEqual(zero, 0)
  35. for i, o in enumerate(real_offsets):
  36. self.assertEqual(
  37. o,
  38. self._num_bytes
  39. + ((i + 1) * self._num_bytes * self._num_lines / self._num_splits),
  40. )
  41. self.assertEqual(last, self._num_bytes * self._num_lines)
  42. def test_readchunks(self):
  43. from fairseq.file_chunker_utils import Chunker, find_offsets
  44. offsets = find_offsets(self._tmpfile, self._num_splits)
  45. for start, end in zip(offsets, offsets[1:]):
  46. with Chunker(self._tmpfile, start, end) as lines:
  47. all_lines = list(lines)
  48. num_lines = self._num_lines / self._num_splits
  49. self.assertAlmostEqual(
  50. len(all_lines), num_lines, delta=1
  51. ) # because we split on the bites, we might end up with one more/less line in a chunk
  52. self.assertListEqual(
  53. all_lines, [self._line_content for _ in range(len(all_lines))]
  54. )