# Source listing: test_plasma_utils.py (4.7 KB)

  1. import contextlib
  2. import tempfile
  3. import unittest
  4. from io import StringIO
  5. import numpy as np
  6. from tests.utils import create_dummy_data, preprocess_lm_data, train_language_model
  7. try:
  8. from pyarrow import plasma
  9. from fairseq.data.plasma_utils import PlasmaStore, PlasmaView
  10. PYARROW_AVAILABLE = True
  11. except ImportError:
  12. PYARROW_AVAILABLE = False
  13. dummy_path = "dummy"
  14. @unittest.skipUnless(PYARROW_AVAILABLE, "")
  15. class TestPlasmaView(unittest.TestCase):
  16. def setUp(self) -> None:
  17. self.tmp_file = tempfile.NamedTemporaryFile() # noqa: P201
  18. self.path = self.tmp_file.name
  19. self.server = PlasmaStore.start(path=self.path, nbytes=10000)
  20. self.client = plasma.connect(self.path, num_retries=10)
  21. def tearDown(self) -> None:
  22. self.client.disconnect()
  23. self.tmp_file.close()
  24. self.server.kill()
  25. def test_two_servers_do_not_share_object_id_space(self):
  26. data_server_1 = np.array([0, 1])
  27. data_server_2 = np.array([2, 3])
  28. server_2_path = self.path
  29. with tempfile.NamedTemporaryFile() as server_1_path:
  30. server = PlasmaStore.start(path=server_1_path.name, nbytes=10000)
  31. arr1 = PlasmaView(
  32. data_server_1, dummy_path, 1, plasma_path=server_1_path.name
  33. )
  34. assert len(arr1.client.list()) == 1
  35. assert (arr1.array == data_server_1).all()
  36. arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=server_2_path)
  37. assert (arr2.array == data_server_2).all()
  38. assert (arr1.array == data_server_1).all()
  39. server.kill()
  40. def test_hash_collision(self):
  41. data_server_1 = np.array([0, 1])
  42. data_server_2 = np.array([2, 3])
  43. arr1 = PlasmaView(data_server_1, dummy_path, 1, plasma_path=self.path)
  44. assert len(arr1.client.list()) == 1
  45. arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=self.path)
  46. assert len(arr1.client.list()) == 1
  47. assert len(arr2.client.list()) == 1
  48. assert (arr2.array == data_server_1).all()
  49. # New hash key based on tuples
  50. arr3 = PlasmaView(
  51. data_server_2, dummy_path, (1, 12312312312, None), plasma_path=self.path
  52. )
  53. assert (
  54. len(arr2.client.list()) == 2
  55. ), "No new object was created by using a novel hash key"
  56. assert (
  57. arr3.object_id in arr2.client.list()
  58. ), "No new object was created by using a novel hash key"
  59. assert (
  60. arr3.object_id in arr3.client.list()
  61. ), "No new object was created by using a novel hash key"
  62. del arr3, arr2, arr1
  63. @staticmethod
  64. def _assert_view_equal(pv1, pv2):
  65. np.testing.assert_array_equal(pv1.array, pv2.array)
  66. def test_putting_same_array_twice(self):
  67. data = np.array([4, 4, 4])
  68. arr1 = PlasmaView(data, dummy_path, 1, plasma_path=self.path)
  69. assert len(self.client.list()) == 1
  70. arr1b = PlasmaView(
  71. data, dummy_path, 1, plasma_path=self.path
  72. ) # should not change contents of store
  73. arr1c = PlasmaView(
  74. None, dummy_path, 1, plasma_path=self.path
  75. ) # should not change contents of store
  76. assert len(self.client.list()) == 1
  77. self._assert_view_equal(arr1, arr1b)
  78. self._assert_view_equal(arr1, arr1c)
  79. PlasmaView(
  80. data, dummy_path, 2, plasma_path=self.path
  81. ) # new object id, adds new entry
  82. assert len(self.client.list()) == 2
  83. new_client = plasma.connect(self.path)
  84. assert len(new_client.list()) == 2 # new client can access same objects
  85. assert isinstance(arr1.object_id, plasma.ObjectID)
  86. del arr1b
  87. del arr1c
  88. def test_plasma_store_full_raises(self):
  89. with tempfile.NamedTemporaryFile() as new_path:
  90. server = PlasmaStore.start(path=new_path.name, nbytes=10000)
  91. with self.assertRaises(plasma.PlasmaStoreFull):
  92. # 2000 floats is more than 2000 bytes
  93. PlasmaView(
  94. np.random.rand(10000, 1), dummy_path, 1, plasma_path=new_path.name
  95. )
  96. server.kill()
  97. def test_object_id_overflow(self):
  98. PlasmaView.get_object_id("", 2**21)
  99. def test_training_lm_plasma(self):
  100. with contextlib.redirect_stdout(StringIO()):
  101. with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
  102. create_dummy_data(data_dir)
  103. preprocess_lm_data(data_dir)
  104. train_language_model(
  105. data_dir,
  106. "transformer_lm",
  107. ["--use-plasma-view", "--plasma-path", self.path],
  108. run_validation=True,
  109. )