import argparse
import unittest
from typing import Any, Dict, Optional, Sequence

import torch

from fairseq.models import transformer
from tests.test_roberta import FakeTask


def mk_sample(tok: Optional[Sequence[int]] = None, batch_size: int = 2) -> Dict[str, Any]:
    """Build a minimal seq2seq sample with identical source and target tokens."""
    if not tok:
        tok = [10, 11, 12, 13, 14, 15, 2]
    batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size)
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor(
                [len(tok)] * batch_size, dtype=torch.long, device=batch.device
            ),
        },
        "target": batch[:, 1:],
    }
    return sample


def mk_transformer(**extra_args: Any):
    """Build a tiny TransformerModel with dropout disabled for deterministic tests."""
    overrides = {
        # Use small, characteristic dimensions so shape mix-ups are easy to spot.
        "encoder_embed_dim": 12,
        "encoder_ffn_embed_dim": 14,
        "decoder_embed_dim": 12,
        "decoder_ffn_embed_dim": 14,
        # Disable dropout so we have comparable tests.
        "dropout": 0,
        "attention_dropout": 0,
        "activation_dropout": 0,
        "encoder_layerdrop": 0,
    }
    overrides.update(extra_args)
    # Override the defaults from the parser.
    args = argparse.Namespace(**overrides)
    transformer.tiny_architecture(args)

    torch.manual_seed(0)
    task = FakeTask(args)
    return transformer.TransformerModel.build_model(args, task)


class TransformerTestCase(unittest.TestCase):
    def test_forward_backward(self):
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=12)
        sample = mk_sample()
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        loss.backward()

    def test_different_encoder_decoder_embed_dim(self):
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=16)
        sample = mk_sample()
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        loss.backward()
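

if __name__ == "__main__":
    # Optional entry point (an addition, not part of the original file): lets the
    # tests run directly via `python tests/test_transformer.py`, in addition to the
    # usual `pytest` / `python -m unittest` invocations.
    unittest.main()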