# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn as nn
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from torch.utils.checkpoint import checkpoint
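

# Toy model whose feed-forward block can optionally be wrapped with
# PyTorch's native torch.utils.checkpoint or with fairseq's
# checkpoint_wrapper, so all three variants can be compared.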
class Model(nn.Module):
    def __init__(
        self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False, **kwargs
    ):
        super().__init__()
        torch.manual_seed(0)
        self.use_pytorch_checkpoint = use_pytorch_checkpoint
        self.ffn = nn.Sequential(
            nn.Linear(32, 128),
            # add a Dropout layer to test RNG save/restore
            nn.Dropout(p=0.5),
            nn.Linear(128, 32),
        )
        if use_fairseq_checkpoint:
            self.ffn = checkpoint_wrapper(self.ffn, **kwargs)
        self.out = nn.Linear(32, 1)

    def forward(self, x):
        if self.use_pytorch_checkpoint:
            x = checkpoint(self.ffn, x)
        else:
            x = self.ffn(x)
        return self.out(x)
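

# Runs the same forward/backward pass with and without activation
# checkpointing and checks that the loss and gradient norm of each
# checkpointed variant match the non-checkpointed baseline.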
class TestActivationCheckpointing(unittest.TestCase):
    def _test_checkpoint_wrapper(self, device, log_memory_usage=False):
        def get_loss_and_gnorm(model):
            torch.manual_seed(1)
            input = torch.rand(2, 16, 32).requires_grad_(True).to(device)
            model.zero_grad()
            loss = model(input).sum()
            loss.backward()
            gnorm = torch.norm(
                torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()])
            )
            return {"loss": loss, "gnorm": gnorm}
        model = Model().to(device)
        no_cpt = get_loss_and_gnorm(model)
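
        # PyTorch's native checkpoint should match the baseline: it saves and
        # restores the RNG state during recomputation, so the Dropout mask is
        # identical in both passes.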
        model = Model(use_pytorch_checkpoint=True).to(device)
        pyt_cpt = get_loss_and_gnorm(model)
        torch.testing.assert_allclose(no_cpt["loss"], pyt_cpt["loss"])
        torch.testing.assert_allclose(no_cpt["gnorm"], pyt_cpt["gnorm"])
        model = Model(use_fairseq_checkpoint=True).to(device)
        fairseq_cpt = get_loss_and_gnorm(model)
        torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt["loss"])
        torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt["gnorm"])
        model = Model(use_fairseq_checkpoint=True, offload_to_cpu=True).to(device)
        fairseq_cpt_offload = get_loss_and_gnorm(model)
        torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt_offload["loss"])
        torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt_offload["gnorm"])

    def test_checkpoint_wrapper_cpu(self):
        self._test_checkpoint_wrapper(device=torch.device("cpu"))

    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_checkpoint_wrapper_cuda(self):
        self._test_checkpoint_wrapper(device=torch.device("cuda"))


if __name__ == "__main__":
    unittest.main()