# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import tempfile
import unittest

import torch

from . import test_binaries
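
# Each test trains a small translation model twice on the same dummy data:
# once straight through, and once interrupted and resumed from a saved
# checkpoint. The JSON training/validation stats logged at the end of both
# runs are then compared.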


class TestReproducibility(unittest.TestCase):
    def _test_reproducibility(
        self,
        name,
        extra_flags=None,
        delta=0.0001,
        resume_checkpoint="checkpoint1.pt",
        max_epoch=3,
    ):
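        # scan the captured log records from newest to oldest and return the
        # first JSON stats message that contains `search_string`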
        def get_last_log_stats_containing_string(log_records, search_string):
            for log_record in log_records[::-1]:
                if isinstance(log_record.msg, str) and search_string in log_record.msg:
                    return json.loads(log_record.msg)

        if extra_flags is None:
            extra_flags = []

        with tempfile.TemporaryDirectory(name) as data_dir:
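            # generate and binarize a small dummy translation dataset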
            with self.assertLogs() as logs:
                test_binaries.create_dummy_data(data_dir)
                test_binaries.preprocess_translation_data(data_dir)

            # first run: train from scratch up to max_epoch without interruption
            with self.assertLogs() as logs:
                test_binaries.train_translation_model(
                    data_dir,
                    "fconv_iwslt_de_en",
                    [
                        "--dropout",
                        "0.0",
                        "--log-format",
                        "json",
                        "--log-interval",
                        "1",
                        "--max-epoch",
                        str(max_epoch),
                    ]
                    + extra_flags,
                )
            train_log = get_last_log_stats_containing_string(logs.records, "train_loss")
            valid_log = get_last_log_stats_containing_string(logs.records, "valid_loss")

            # second run: rename the intermediate checkpoint to checkpoint_last.pt
            # and resume training from it up to max_epoch
            os.rename(
                os.path.join(data_dir, resume_checkpoint),
                os.path.join(data_dir, "checkpoint_last.pt"),
            )
            with self.assertLogs() as logs:
                test_binaries.train_translation_model(
                    data_dir,
                    "fconv_iwslt_de_en",
                    [
                        "--dropout",
                        "0.0",
                        "--log-format",
                        "json",
                        "--log-interval",
                        "1",
                        "--max-epoch",
                        str(max_epoch),
                    ]
                    + extra_flags,
                )
            train_res_log = get_last_log_stats_containing_string(
                logs.records, "train_loss"
            )
            valid_res_log = get_last_log_stats_containing_string(
                logs.records, "valid_loss"
            )
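
            # the resumed run should reproduce the uninterrupted run's final
            # train/valid statistics to within `delta`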
- for k in ["train_loss", "train_ppl", "train_num_updates", "train_gnorm"]:
- self.assertAlmostEqual(
- float(train_log[k]), float(train_res_log[k]), delta=delta
- )
- for k in [
- "valid_loss",
- "valid_ppl",
- "valid_num_updates",
- "valid_best_loss",
- ]:
- self.assertAlmostEqual(
- float(valid_log[k]), float(valid_res_log[k]), delta=delta
- )

    def test_reproducibility(self):
        self._test_reproducibility("test_reproducibility")
    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_reproducibility_fp16(self):
        self._test_reproducibility(
            "test_reproducibility_fp16",
            [
                "--fp16",
                "--fp16-init-scale",
                "4096",
            ],
            delta=0.011,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_reproducibility_memory_efficient_fp16(self):
        self._test_reproducibility(
            "test_reproducibility_memory_efficient_fp16",
            [
                "--memory-efficient-fp16",
                "--fp16-init-scale",
                "4096",
            ],
        )

    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_reproducibility_amp(self):
        self._test_reproducibility(
            "test_reproducibility_amp",
            [
                "--amp",
                "--fp16-init-scale",
                "4096",
            ],
            delta=0.011,
        )
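
    # resume from a checkpoint saved mid-epoch (after 3 updates) rather than
    # from an end-of-epoch checkpoint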
    def test_mid_epoch_reproducibility(self):
        self._test_reproducibility(
            "test_mid_epoch_reproducibility",
            ["--save-interval-updates", "3"],
            resume_checkpoint="checkpoint_1_3.pt",
            max_epoch=1,
        )


if __name__ == "__main__":
    unittest.main()