# test_ema.py (8.6 KB)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional
from unittest.mock import patch

import pytest
import torch

from fairseq.models.ema import EMA
  13. class DummyModule(torch.nn.Module):
  14. def __init__(self) -> None:
  15. """LightningModule for testing purposes
  16. Args:
  17. epoch_min_loss_override (int, optional): Pass in an epoch that will be set to the minimum
  18. validation loss for testing purposes (zero based). If None this is ignored. Defaults to None.
  19. """
  20. super().__init__()
  21. self.layer = torch.nn.Linear(in_features=32, out_features=2)
  22. self.another_layer = torch.nn.Linear(in_features=2, out_features=2)
  23. def forward(self, x: torch.Tensor) -> torch.Tensor:
  24. x = self.layer(x)
  25. return self.another_layer(x)
  26. @dataclass
  27. class EMAConfig(object):
  28. ema_decay: float = 0.99
  29. ema_start_update: int = 0
  30. ema_fp32: bool = False
  31. ema_seed_model: Optional[str] = None
  32. ema_update_freq: int = 1
  33. class TestEMA(unittest.TestCase):
  34. def assertTorchAllClose(self, x, y, atol=1e-8, rtol=1e-5, msg=None):
  35. diff = x.float() - y.float()
  36. diff_norm = torch.norm(diff)
  37. other_norm = torch.norm(y.float())
  38. if msg is None:
  39. msg = "|input - other| > {} + {} * |other|".format(atol, rtol)
  40. self.assertLessEqual(
  41. diff_norm,
  42. atol + rtol * other_norm,
  43. msg=msg,
  44. )
  45. def test_ema(self):
  46. model = DummyModule()
  47. optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
  48. state = deepcopy(model.state_dict())
  49. config = EMAConfig()
  50. ema = EMA(model, config)
  51. # set decay
  52. ema._set_decay(config.ema_decay)
  53. self.assertEqual(ema.get_decay(), config.ema_decay)
  54. # get model
  55. self.assertEqual(ema.get_model(), ema.model)
  56. # Since fp32 params is not used, it should be of size 0
  57. self.assertEqual(len(ema.fp32_params), 0)
  58. # EMA step
  59. x = torch.randn(32)
  60. y = model(x)
  61. loss = y.sum()
  62. loss.backward()
  63. optimizer.step()
  64. ema.step(model)
  65. ema_state_dict = ema.get_model().state_dict()
  66. for key, param in model.state_dict().items():
  67. prev_param = state[key]
  68. ema_param = ema_state_dict[key]
  69. if "version" in key:
  70. # Do not decay a model.version pytorch param
  71. continue
  72. self.assertTorchAllClose(
  73. ema_param,
  74. config.ema_decay * prev_param + (1 - config.ema_decay) * param,
  75. )
  76. # Since fp32 params is not used, it should be of size 0
  77. self.assertEqual(len(ema.fp32_params), 0)
  78. # Load EMA into model
  79. model2 = DummyModule()
  80. ema.reverse(model2)
  81. for key, param in model2.state_dict().items():
  82. ema_param = ema_state_dict[key]
  83. self.assertTrue(torch.allclose(ema_param, param))
  84. # Check that step_internal is called once
  85. with patch.object(ema, "_step_internal", return_value=None) as mock_method:
  86. ema.step(model)
  87. mock_method.assert_called_once_with(model, None)
  88. def _test_ema_start_update(self, updates):
  89. model = DummyModule()
  90. optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
  91. state = deepcopy(model.state_dict())
  92. config = EMAConfig(ema_start_update=1)
  93. ema = EMA(model, config)
  94. # EMA step
  95. x = torch.randn(32)
  96. y = model(x)
  97. loss = y.sum()
  98. loss.backward()
  99. optimizer.step()
  100. ema.step(model, updates=updates)
  101. ema_state_dict = ema.get_model().state_dict()
  102. self.assertEqual(ema.get_decay(), 0 if updates == 0 else config.ema_decay)
  103. for key, param in model.state_dict().items():
  104. ema_param = ema_state_dict[key]
  105. prev_param = state[key]
  106. if "version" in key:
  107. # Do not decay a model.version pytorch param
  108. continue
  109. if updates == 0:
  110. self.assertTorchAllClose(
  111. ema_param,
  112. param,
  113. )
  114. else:
  115. self.assertTorchAllClose(
  116. ema_param,
  117. config.ema_decay * prev_param + (1 - config.ema_decay) * param,
  118. )
  119. # Check that step_internal is called once
  120. with patch.object(ema, "_step_internal", return_value=None) as mock_method:
  121. ema.step(model, updates=updates)
  122. mock_method.assert_called_once_with(model, updates)
  123. def test_ema_before_start_update(self):
  124. self._test_ema_start_update(updates=0)
  125. def test_ema_after_start_update(self):
  126. self._test_ema_start_update(updates=1)
  127. def test_ema_fp32(self):
  128. dtype = torch.float
  129. model = DummyModule().to(dtype)
  130. optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
  131. state = deepcopy(model.state_dict())
  132. config = EMAConfig(ema_fp32=True)
  133. ema = EMA(model, config)
  134. x = torch.randn(32)
  135. y = model(x.to(dtype))
  136. loss = y.sum()
  137. loss.backward()
  138. optimizer.step()
  139. ema.step(model)
  140. for key, param in model.state_dict().items():
  141. prev_param = state[key]
  142. ema_param = ema.get_model().state_dict()[key]
  143. if "version" in key:
  144. # Do not decay a model.version pytorch param
  145. continue
  146. self.assertIn(key, ema.fp32_params)
  147. # EMA update is done in fp32, and hence the EMA param must be
  148. # closer to the EMA update done in fp32 than in fp16.
  149. self.assertLessEqual(
  150. torch.norm(
  151. ema_param.float()
  152. - (
  153. config.ema_decay * prev_param.float()
  154. + (1 - config.ema_decay) * param.float()
  155. )
  156. .to(dtype)
  157. .float()
  158. ),
  159. torch.norm(
  160. ema_param.float()
  161. - (
  162. config.ema_decay * prev_param + (1 - config.ema_decay) * param
  163. ).float()
  164. ),
  165. )
  166. self.assertTorchAllClose(
  167. ema_param,
  168. (
  169. config.ema_decay * prev_param.float()
  170. + (1 - config.ema_decay) * param.float()
  171. ).to(dtype),
  172. )
  173. @pytest.mark.skipif(
  174. not torch.cuda.is_available(),
  175. reason="CPU no longer supports Linear in half precision",
  176. )
  177. def test_ema_fp16(self):
  178. model = DummyModule().cuda().half()
  179. optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
  180. state = deepcopy(model.state_dict())
  181. config = EMAConfig(ema_fp32=False)
  182. ema = EMA(model, config)
  183. # Since fp32 params is not used, it should be of size 0
  184. self.assertEqual(len(ema.fp32_params), 0)
  185. x = torch.randn(32).cuda()
  186. y = model(x.half())
  187. loss = y.sum()
  188. loss.backward()
  189. optimizer.step()
  190. ema.step(model)
  191. for key, param in model.state_dict().items():
  192. prev_param = state[key]
  193. ema_param = ema.get_model().state_dict()[key]
  194. if "version" in key:
  195. # Do not decay a model.version pytorch param
  196. continue
  197. # EMA update is done in fp16, and hence the EMA param must be
  198. # closer to the EMA update done in fp16 than in fp32.
  199. self.assertLessEqual(
  200. torch.norm(
  201. ema_param.float()
  202. - (
  203. config.ema_decay * prev_param + (1 - config.ema_decay) * param
  204. ).float()
  205. ),
  206. torch.norm(
  207. ema_param.float()
  208. - (
  209. config.ema_decay * prev_param.float()
  210. + (1 - config.ema_decay) * param.float()
  211. )
  212. .half()
  213. .float()
  214. ),
  215. )
  216. self.assertTorchAllClose(
  217. ema_param,
  218. config.ema_decay * prev_param + (1 - config.ema_decay) * param,
  219. )
  220. # Since fp32 params is not used, it should be of size 0
  221. self.assertEqual(len(ema.fp32_params), 0)
  222. if __name__ == "__main__":
  223. unittest.main()