test_reproducibility.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import json
  6. import os
  7. import tempfile
  8. import unittest
  9. import torch
  10. from . import test_binaries
  11. class TestReproducibility(unittest.TestCase):
  12. def _test_reproducibility(
  13. self,
  14. name,
  15. extra_flags=None,
  16. delta=0.0001,
  17. resume_checkpoint="checkpoint1.pt",
  18. max_epoch=3,
  19. ):
  20. def get_last_log_stats_containing_string(log_records, search_string):
  21. for log_record in logs.records[::-1]:
  22. if isinstance(log_record.msg, str) and search_string in log_record.msg:
  23. return json.loads(log_record.msg)
  24. if extra_flags is None:
  25. extra_flags = []
  26. with tempfile.TemporaryDirectory(name) as data_dir:
  27. with self.assertLogs() as logs:
  28. test_binaries.create_dummy_data(data_dir)
  29. test_binaries.preprocess_translation_data(data_dir)
  30. # train epochs 1 and 2 together
  31. with self.assertLogs() as logs:
  32. test_binaries.train_translation_model(
  33. data_dir,
  34. "fconv_iwslt_de_en",
  35. [
  36. "--dropout",
  37. "0.0",
  38. "--log-format",
  39. "json",
  40. "--log-interval",
  41. "1",
  42. "--max-epoch",
  43. str(max_epoch),
  44. ]
  45. + extra_flags,
  46. )
  47. train_log = get_last_log_stats_containing_string(logs.records, "train_loss")
  48. valid_log = get_last_log_stats_containing_string(logs.records, "valid_loss")
  49. # train epoch 2, resuming from previous checkpoint 1
  50. os.rename(
  51. os.path.join(data_dir, resume_checkpoint),
  52. os.path.join(data_dir, "checkpoint_last.pt"),
  53. )
  54. with self.assertLogs() as logs:
  55. test_binaries.train_translation_model(
  56. data_dir,
  57. "fconv_iwslt_de_en",
  58. [
  59. "--dropout",
  60. "0.0",
  61. "--log-format",
  62. "json",
  63. "--log-interval",
  64. "1",
  65. "--max-epoch",
  66. str(max_epoch),
  67. ]
  68. + extra_flags,
  69. )
  70. train_res_log = get_last_log_stats_containing_string(
  71. logs.records, "train_loss"
  72. )
  73. valid_res_log = get_last_log_stats_containing_string(
  74. logs.records, "valid_loss"
  75. )
  76. for k in ["train_loss", "train_ppl", "train_num_updates", "train_gnorm"]:
  77. self.assertAlmostEqual(
  78. float(train_log[k]), float(train_res_log[k]), delta=delta
  79. )
  80. for k in [
  81. "valid_loss",
  82. "valid_ppl",
  83. "valid_num_updates",
  84. "valid_best_loss",
  85. ]:
  86. self.assertAlmostEqual(
  87. float(valid_log[k]), float(valid_res_log[k]), delta=delta
  88. )
  89. def test_reproducibility(self):
  90. self._test_reproducibility("test_reproducibility")
  91. @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
  92. def test_reproducibility_fp16(self):
  93. self._test_reproducibility(
  94. "test_reproducibility_fp16",
  95. [
  96. "--fp16",
  97. "--fp16-init-scale",
  98. "4096",
  99. ],
  100. delta=0.011,
  101. )
  102. @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
  103. def test_reproducibility_memory_efficient_fp16(self):
  104. self._test_reproducibility(
  105. "test_reproducibility_memory_efficient_fp16",
  106. [
  107. "--memory-efficient-fp16",
  108. "--fp16-init-scale",
  109. "4096",
  110. ],
  111. )
  112. @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
  113. def test_reproducibility_amp(self):
  114. self._test_reproducibility(
  115. "test_reproducibility_amp",
  116. [
  117. "--amp",
  118. "--fp16-init-scale",
  119. "4096",
  120. ],
  121. delta=0.011,
  122. )
  123. def test_mid_epoch_reproducibility(self):
  124. self._test_reproducibility(
  125. "test_mid_epoch_reproducibility",
  126. ["--save-interval-updates", "3"],
  127. resume_checkpoint="checkpoint_1_3.pt",
  128. max_epoch=1,
  129. )
  130. if __name__ == "__main__":
  131. unittest.main()