test_activation_checkpointing.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn as nn
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from torch.utils.checkpoint import checkpoint
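
# checkpoint_wrapper trades compute for memory: it discards the wrapped module's
# intermediate activations after the forward pass and recomputes them during
# backward. A rough usage sketch (illustrative only, not part of this test):
#
#     ffn = checkpoint_wrapper(nn.Linear(32, 128), offload_to_cpu=True)
#
# With offload_to_cpu=True, the tensors saved for backward are additionally
# moved to CPU between the forward and backward passes.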


class Model(nn.Module):
    def __init__(
        self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False, **kwargs
    ):
        super().__init__()
        torch.manual_seed(0)
        self.use_pytorch_checkpoint = use_pytorch_checkpoint
        self.ffn = nn.Sequential(
            nn.Linear(32, 128),
            # add a Dropout layer to test RNG save/restore
            nn.Dropout(p=0.5),
            nn.Linear(128, 32),
        )
        if use_fairseq_checkpoint:
            self.ffn = checkpoint_wrapper(self.ffn, **kwargs)
        self.out = nn.Linear(32, 1)

    def forward(self, x):
        if self.use_pytorch_checkpoint:
            x = checkpoint(self.ffn, x)
        else:
            x = self.ffn(x)
        return self.out(x)
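
# The model above can run its FFN unwrapped, through torch.utils.checkpoint, or
# through fairseq's checkpoint_wrapper (selected via the constructor flags).
# Since the parameters are built under a fixed seed and the checkpointing
# implementations save/restore the RNG state around recomputation, all variants
# are expected to produce identical losses and gradient norms.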


class TestActivationCheckpointing(unittest.TestCase):
    def _test_checkpoint_wrapper(self, device, log_memory_usage=False):
        def get_loss_and_gnorm(model):
            torch.manual_seed(1)
            input = torch.rand(2, 16, 32).requires_grad_(True).to(device)
            model.zero_grad()
            loss = model(input).sum()
            loss.backward()
            gnorm = torch.norm(
                torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()])
            )
            return {"loss": loss, "gnorm": gnorm}
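
        # Run an identical forward/backward under each checkpointing strategy and
        # check that loss and gradient norm match the no-checkpointing baseline.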

        model = Model().to(device)
        no_cpt = get_loss_and_gnorm(model)

        model = Model(use_pytorch_checkpoint=True).to(device)
        pyt_cpt = get_loss_and_gnorm(model)
        torch.testing.assert_allclose(no_cpt["loss"], pyt_cpt["loss"])
        torch.testing.assert_allclose(no_cpt["gnorm"], pyt_cpt["gnorm"])

        model = Model(use_fairseq_checkpoint=True).to(device)
        fairseq_cpt = get_loss_and_gnorm(model)
        torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt["loss"])
        torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt["gnorm"])

        model = Model(use_fairseq_checkpoint=True, offload_to_cpu=True).to(device)
        fairseq_cpt_offload = get_loss_and_gnorm(model)
        torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt_offload["loss"])
        torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt_offload["gnorm"])

    def test_checkpoint_wrapper_cpu(self):
        self._test_checkpoint_wrapper(device=torch.device("cpu"))

    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_checkpoint_wrapper_cuda(self):
        self._test_checkpoint_wrapper(device=torch.device("cuda"))


if __name__ == "__main__":
    unittest.main()