test_dataclass_utils.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import unittest
  6. from argparse import ArgumentParser
  7. from dataclasses import dataclass, field
  8. from fairseq.dataclass import FairseqDataclass
  9. from fairseq.dataclass.utils import gen_parser_from_dataclass
  10. @dataclass
  11. class A(FairseqDataclass):
  12. data: str = field(default="test", metadata={"help": "the data input"})
  13. num_layers: int = field(default=200, metadata={"help": "more layers is better?"})
  14. @dataclass
  15. class B(FairseqDataclass):
  16. bar: A = field(default=A())
  17. foo: int = field(default=0, metadata={"help": "not a bar"})
  18. @dataclass
  19. class D(FairseqDataclass):
  20. arch: A = field(default=A())
  21. foo: int = field(default=0, metadata={"help": "not a bar"})
  22. @dataclass
  23. class C(FairseqDataclass):
  24. data: str = field(default="test", metadata={"help": "root level data input"})
  25. encoder: D = field(default=D())
  26. decoder: A = field(default=A())
  27. lr: int = field(default=0, metadata={"help": "learning rate"})
  28. class TestDataclassUtils(unittest.TestCase):
  29. def test_argparse_convert_basic(self):
  30. parser = ArgumentParser()
  31. gen_parser_from_dataclass(parser, A(), True)
  32. args = parser.parse_args(["--num-layers", "10", "the/data/path"])
  33. self.assertEqual(args.num_layers, 10)
  34. self.assertEqual(args.data, "the/data/path")
  35. def test_argparse_recursive(self):
  36. parser = ArgumentParser()
  37. gen_parser_from_dataclass(parser, B(), True)
  38. args = parser.parse_args(["--num-layers", "10", "--foo", "10", "the/data/path"])
  39. self.assertEqual(args.num_layers, 10)
  40. self.assertEqual(args.foo, 10)
  41. self.assertEqual(args.data, "the/data/path")
  42. def test_argparse_recursive_prefixing(self):
  43. self.maxDiff = None
  44. parser = ArgumentParser()
  45. gen_parser_from_dataclass(parser, C(), True, "")
  46. args = parser.parse_args(
  47. [
  48. "--encoder-arch-data",
  49. "ENCODER_ARCH_DATA",
  50. "--encoder-arch-num-layers",
  51. "10",
  52. "--encoder-foo",
  53. "10",
  54. "--decoder-data",
  55. "DECODER_DATA",
  56. "--decoder-num-layers",
  57. "10",
  58. "--lr",
  59. "10",
  60. "the/data/path",
  61. ]
  62. )
  63. self.assertEqual(args.encoder_arch_data, "ENCODER_ARCH_DATA")
  64. self.assertEqual(args.encoder_arch_num_layers, 10)
  65. self.assertEqual(args.encoder_foo, 10)
  66. self.assertEqual(args.decoder_data, "DECODER_DATA")
  67. self.assertEqual(args.decoder_num_layers, 10)
  68. self.assertEqual(args.lr, 10)
  69. self.assertEqual(args.data, "the/data/path")
  70. if __name__ == "__main__":
  71. unittest.main()