test_transformer.py
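"""Smoke tests for fairseq's Transformer model: build a tiny model and check
that a forward and backward pass run without error."""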

import argparse
import unittest
from typing import Any, Dict, Optional, Sequence

import torch

from fairseq.models import transformer
from tests.test_roberta import FakeTask


def mk_sample(tok: Optional[Sequence[int]] = None, batch_size: int = 2) -> Dict[str, Any]:
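    """Build a dummy batch: `batch_size` identical rows of `tok` (2 is the
    default EOS index in fairseq dictionaries), plus per-row lengths and the
    same rows minus their first token as the target."""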
    if not tok:
        tok = [10, 11, 12, 13, 14, 15, 2]
    batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size)
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor(
                [len(tok)] * batch_size, dtype=torch.long, device=batch.device
            ),
        },
        "target": batch[:, 1:],
    }
    return sample


def mk_transformer(**extra_args: Any):
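    """Build a tiny TransformerModel with all dropout disabled and a fixed
    seed for weight initialization, so test runs are deterministic."""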
    overrides = {
        # Use small, distinctive dimensions so shape mismatches surface quickly.
        "encoder_embed_dim": 12,
        "encoder_ffn_embed_dim": 14,
        "decoder_embed_dim": 12,
        "decoder_ffn_embed_dim": 14,
        # Disable dropout so that test runs are comparable.
        "dropout": 0,
        "attention_dropout": 0,
        "activation_dropout": 0,
        "encoder_layerdrop": 0,
    }
    overrides.update(extra_args)
    # Override the defaults from the parser.
    args = argparse.Namespace(**overrides)
    transformer.tiny_architecture(args)
    torch.manual_seed(0)
    task = FakeTask(args)
    return transformer.TransformerModel.build_model(args, task)


class TransformerTestCase(unittest.TestCase):
    def test_forward_backward(self):
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=12)
        sample = mk_sample()
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        loss.backward()

    def test_different_encoder_decoder_embed_dim(self):
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=16)
        sample = mk_sample()
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        loss.backward()
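    # Not part of the original suite: a minimal, hedged sanity check, assuming
    # the decoder output is laid out as (batch, tgt_len, vocab) as in fairseq's
    # TransformerDecoder; the vocabulary size comes from FakeTask's dictionary.
    def test_output_shape(self):
        model = mk_transformer()
        sample = mk_sample(batch_size=3)
        o, _ = model.forward(**sample["net_input"])
        # The first two output dimensions should match the decoder input tokens.
        self.assertEqual(o.shape[:2], sample["net_input"]["prev_output_tokens"].shape)


if __name__ == "__main__":
    unittest.main()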