123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
- import unittest
- from argparse import ArgumentParser
- from dataclasses import dataclass, field
- from fairseq.dataclass import FairseqDataclass
- from fairseq.dataclass.utils import gen_parser_from_dataclass
@dataclass
class A(FairseqDataclass):
    """Leaf test config: a string ``data`` field and an integer ``num_layers`` field."""

    data: str = field(
        default="test",
        metadata={"help": "the data input"},
    )
    num_layers: int = field(
        default=200,
        metadata={"help": "more layers is better?"},
    )
@dataclass
class B(FairseqDataclass):
    """Test config that nests an ``A`` under ``bar`` next to a plain ``foo`` field."""

    # default_factory builds a fresh A for every B. A literal ``A()`` default
    # is a single instance shared by all B objects. Python 3.11+ also rejects
    # it outright, because a non-frozen dataclass instance is unhashable.
    bar: A = field(default_factory=A)
    foo: int = field(default=0, metadata={"help": "not a bar"})
@dataclass
class D(FairseqDataclass):
    """Test config that nests an ``A`` under ``arch`` next to a plain ``foo`` field."""

    # default_factory avoids a shared (and, on Python 3.11+, rejected)
    # mutable ``A()`` default; each D gets its own A.
    arch: A = field(default_factory=A)
    foo: int = field(default=0, metadata={"help": "not a bar"})
@dataclass
class C(FairseqDataclass):
    """Root test config with two levels of nesting.

    ``encoder`` is a ``D``, which itself nests an ``A``; ``decoder`` is an
    ``A``. The config also has its own ``data`` and ``lr`` fields.
    """

    data: str = field(default="test", metadata={"help": "root level data input"})
    # Nested configs use default_factory so each C owns fresh sub-configs.
    # Instance defaults like ``D()`` / ``A()`` would be shared across all C
    # objects and are rejected by dataclasses on Python 3.11+.
    encoder: D = field(default_factory=D)
    decoder: A = field(default_factory=A)
    lr: int = field(default=0, metadata={"help": "learning rate"})
class TestDataclassUtils(unittest.TestCase):
    """Check that gen_parser_from_dataclass exposes dataclass fields as CLI options."""

    @staticmethod
    def _parse(config, argv, *extra_gen_args):
        """Build a fresh parser from ``config`` and parse ``argv`` with it.

        ``True`` is always passed as the third argument to
        gen_parser_from_dataclass. Any ``extra_gen_args`` are appended after
        it verbatim (the prefixing test passes ``""``).
        """
        parser = ArgumentParser()
        gen_parser_from_dataclass(parser, config, True, *extra_gen_args)
        return parser.parse_args(argv)

    def _check(self, namespace, expected):
        """Assert every ``attr -> value`` pair of ``expected`` on ``namespace``, in order."""
        for attr, value in expected.items():
            self.assertEqual(getattr(namespace, attr), value)

    def test_argparse_convert_basic(self):
        namespace = self._parse(A(), ["--num-layers", "10", "the/data/path"])
        self._check(namespace, {"num_layers": 10, "data": "the/data/path"})

    def test_argparse_recursive(self):
        # B nests an A; A's options are expected un-prefixed next to B's --foo.
        namespace = self._parse(
            B(), ["--num-layers", "10", "--foo", "10", "the/data/path"]
        )
        self._check(
            namespace,
            {"num_layers": 10, "foo": 10, "data": "the/data/path"},
        )

    def test_argparse_recursive_prefixing(self):
        self.maxDiff = None
        # With an extra "" argument, nested options are expected under their
        # field path, e.g. encoder -> arch -> data becomes --encoder-arch-data.
        flag_values = [
            ("--encoder-arch-data", "ENCODER_ARCH_DATA"),
            ("--encoder-arch-num-layers", "10"),
            ("--encoder-foo", "10"),
            ("--decoder-data", "DECODER_DATA"),
            ("--decoder-num-layers", "10"),
            ("--lr", "10"),
        ]
        argv = [token for pair in flag_values for token in pair]
        argv.append("the/data/path")
        namespace = self._parse(C(), argv, "")
        self._check(
            namespace,
            {
                "encoder_arch_data": "ENCODER_ARCH_DATA",
                "encoder_arch_num_layers": 10,
                "encoder_foo": 10,
                "decoder_data": "DECODER_DATA",
                "decoder_num_layers": 10,
                "lr": 10,
                "data": "the/data/path",
            },
        )
if __name__ == "__main__":
    # Run this module's test cases when it is executed directly as a script.
    unittest.main()
|