#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import tempfile
import unittest

import torch

from fairseq.data.dictionary import Dictionary
from fairseq.models.transformer import TransformerModel
from fairseq.modules import multihead_attention, sinusoidal_positional_embedding
from fairseq.tasks.fairseq_task import LegacyFairseqTask


DEFAULT_TEST_VOCAB_SIZE = 100


class DummyTask(LegacyFairseqTask):
    def __init__(self, args):
        super().__init__(args)
        self.dictionary = get_dummy_dictionary()
        if getattr(self.args, "ctc", False):
            self.dictionary.add_symbol("<ctc_blank>")
        self.src_dict = self.dictionary
        self.tgt_dict = self.dictionary

    @property
    def source_dictionary(self):
        return self.src_dict

    @property
    def target_dictionary(self):
        return self.dictionary


def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
    dummy_dict = Dictionary()
    # add dummy symbols ("0", "1", ...) until the vocabulary reaches vocab_size
    for sym_id in range(vocab_size):
        dummy_dict.add_symbol("{}".format(sym_id), 1000)
    return dummy_dict


def get_dummy_task_and_parser():
    """
    Return a dummy task and argument parser, which can be used to
    create a model/criterion.
    """
    parser = argparse.ArgumentParser(
        description="test_dummy_s2s_task", argument_default=argparse.SUPPRESS
    )
    DummyTask.add_args(parser)
    args = parser.parse_args([])
    task = DummyTask.setup_task(args)
    return task, parser


def _test_save_and_load(scripted_module):
    # Round-trip the scripted module through disk to verify it serializes.
    with tempfile.NamedTemporaryFile() as f:
        scripted_module.save(f.name)
        torch.jit.load(f.name)


class TestExportModels(unittest.TestCase):
    def test_export_multihead_attention(self):
        module = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
        scripted = torch.jit.script(module)
        _test_save_and_load(scripted)
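
    # Not in the original file: a minimal sanity sketch that scripting preserves
    # the module's output. Assumes fairseq's MultiheadAttention.forward accepts
    # (query, key, value) tensors shaped (seq_len, batch, embed_dim) and returns
    # an (attn, attn_weights) tuple.
    def test_scripted_multihead_attention_matches_eager(self):
        module = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
        module.eval()  # disable dropout so both paths are deterministic
        scripted = torch.jit.script(module)
        x = torch.rand(4, 2, 8)  # (seq_len, batch, embed_dim)
        with torch.no_grad():
            eager_out, _ = module(x, x, x)
            scripted_out, _ = scripted(x, x, x)
        self.assertTrue(torch.allclose(eager_out, scripted_out))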

    def test_incremental_state_multihead_attention(self):
        module1 = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
        module1 = torch.jit.script(module1)
        module2 = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
        module2 = torch.jit.script(module2)

        state = {}
        state = module1.set_incremental_state(state, "key", {"a": torch.tensor([1])})
        state = module2.set_incremental_state(state, "key", {"a": torch.tensor([2])})
        v1 = module1.get_incremental_state(state, "key")["a"]
        v2 = module2.get_incremental_state(state, "key")["a"]

        # Each module instance namespaces its entries in the shared state dict,
        # so the two writes under "key" must not collide.
        self.assertEqual(v1, 1)
        self.assertEqual(v2, 2)

    def test_positional_embedding(self):
        module = sinusoidal_positional_embedding.SinusoidalPositionalEmbedding(
            embedding_dim=8, padding_idx=1
        )
        scripted = torch.jit.script(module)
        _test_save_and_load(scripted)
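
    # Not in the original file: a minimal equivalence sketch. Assumes
    # SinusoidalPositionalEmbedding.forward takes a (batch, seq_len) tensor of
    # token indices and returns (batch, seq_len, embedding_dim) embeddings
    # computed from each token's position.
    def test_scripted_positional_embedding_matches_eager(self):
        module = sinusoidal_positional_embedding.SinusoidalPositionalEmbedding(
            embedding_dim=8, padding_idx=1
        )
        scripted = torch.jit.script(module)
        tokens = torch.tensor([[5, 6, 7], [5, 6, 1]])  # 1 is the padding index
        eager_out = module(tokens)
        scripted_out = scripted(tokens)
        self.assertTrue(torch.allclose(eager_out, scripted_out))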

    @unittest.skipIf(
        torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release"
    )
    def test_export_transformer(self):
        task, parser = get_dummy_task_and_parser()
        TransformerModel.add_args(parser)
        args = parser.parse_args([])
        model = TransformerModel.build_model(args, task)
        scripted = torch.jit.script(model)
        _test_save_and_load(scripted)

    @unittest.skipIf(
        torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release"
    )
    def test_export_transformer_no_token_pos_emb(self):
        task, parser = get_dummy_task_and_parser()
        TransformerModel.add_args(parser)
        args = parser.parse_args([])
        args.no_token_positional_embeddings = True
        model = TransformerModel.build_model(args, task)
        scripted = torch.jit.script(model)
        _test_save_and_load(scripted)


if __name__ == "__main__":
    unittest.main()