# test_checkpoint_utils.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import logging
import os
import tempfile
import unittest
from io import StringIO
from unittest.mock import patch

import torch
from omegaconf import OmegaConf

from fairseq import checkpoint_utils
from tests.utils import (
    create_dummy_data,
    preprocess_translation_data,
    train_translation_model,
)
  20. class TestCheckpointUtils(unittest.TestCase):
  21. def setUp(self):
  22. logging.disable(logging.CRITICAL)
  23. def tearDown(self):
  24. logging.disable(logging.NOTSET)
  25. @contextlib.contextmanager
  26. def _train_transformer(self, seed, extra_args=None):
  27. if extra_args is None:
  28. extra_args = []
  29. with tempfile.TemporaryDirectory(f"_train_transformer_seed{seed}") as data_dir:
  30. create_dummy_data(data_dir)
  31. preprocess_translation_data(data_dir)
  32. train_translation_model(
  33. data_dir,
  34. "transformer_iwslt_de_en",
  35. [
  36. "--encoder-layers",
  37. "3",
  38. "--decoder-layers",
  39. "3",
  40. "--encoder-embed-dim",
  41. "8",
  42. "--decoder-embed-dim",
  43. "8",
  44. "--seed",
  45. str(seed),
  46. ]
  47. + extra_args,
  48. )
  49. yield os.path.join(data_dir, "checkpoint_last.pt")
  50. def test_load_model_ensemble_and_task(self):
  51. # with contextlib.redirect_stdout(StringIO()):
  52. with self._train_transformer(seed=123) as model1:
  53. with self._train_transformer(seed=456) as model2:
  54. ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
  55. filenames=[model1, model2]
  56. )
  57. self.assertEqual(len(ensemble), 2)
  58. # after Transformer has been migrated to Hydra, this will probably
  59. # become cfg.common.seed
  60. self.assertEqual(ensemble[0].args.seed, 123)
  61. self.assertEqual(ensemble[1].args.seed, 456)
  62. # the task from the first model should be returned
  63. self.assertTrue("seed123" in task.cfg.data)
  64. # last cfg is saved
  65. self.assertEqual(cfg.common.seed, 456)
  66. def test_prune_state_dict(self):
  67. with contextlib.redirect_stdout(StringIO()):
  68. extra_args = ["--encoder-layerdrop", "0.01", "--decoder-layerdrop", "0.01"]
  69. with self._train_transformer(seed=1, extra_args=extra_args) as model:
  70. ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
  71. filenames=[model],
  72. arg_overrides={
  73. "encoder_layers_to_keep": "0,2",
  74. "decoder_layers_to_keep": "1",
  75. },
  76. )
  77. self.assertEqual(len(ensemble), 1)
  78. self.assertEqual(len(ensemble[0].encoder.layers), 2)
  79. self.assertEqual(len(ensemble[0].decoder.layers), 1)
  80. def test_torch_persistent_save_async(self):
  81. state_dict = {}
  82. filename = "async_checkpoint.pt"
  83. with patch(f"{checkpoint_utils.__name__}.PathManager.opena") as mock_opena:
  84. with patch(
  85. f"{checkpoint_utils.__name__}._torch_persistent_save"
  86. ) as mock_save:
  87. checkpoint_utils.torch_persistent_save(
  88. state_dict, filename, async_write=True
  89. )
  90. mock_opena.assert_called_with(filename, "wb")
  91. mock_save.assert_called()
  92. def test_load_ema_from_checkpoint(self):
  93. dummy_state = {"a": torch.tensor([1]), "b": torch.tensor([0.1])}
  94. with patch(f"{checkpoint_utils.__name__}.PathManager.open") as mock_open, patch(
  95. f"{checkpoint_utils.__name__}.torch.load"
  96. ) as mock_load:
  97. mock_load.return_value = {"extra_state": {"ema": dummy_state}}
  98. filename = "ema_checkpoint.pt"
  99. state = checkpoint_utils.load_ema_from_checkpoint(filename)
  100. mock_open.assert_called_with(filename, "rb")
  101. mock_load.assert_called()
  102. self.assertIn("a", state["model"])
  103. self.assertIn("b", state["model"])
  104. self.assertTrue(torch.allclose(dummy_state["a"], state["model"]["a"]))
  105. self.assertTrue(torch.allclose(dummy_state["b"], state["model"]["b"]))
  106. if __name__ == "__main__":
  107. unittest.main()