# test_valid_subset_checks.py
  1. import os
  2. import shutil
  3. import tempfile
  4. import unittest
  5. from fairseq import options
  6. from fairseq.dataclass.utils import convert_namespace_to_omegaconf
  7. from fairseq.data.data_utils import raise_if_valid_subsets_unintentionally_ignored
  8. from .utils import create_dummy_data, preprocess_lm_data, train_language_model
  9. def make_lm_config(
  10. data_dir=None,
  11. extra_flags=None,
  12. task="language_modeling",
  13. arch="transformer_lm_gpt2_tiny",
  14. ):
  15. task_args = [task]
  16. if data_dir is not None:
  17. task_args += [data_dir]
  18. train_parser = options.get_training_parser()
  19. train_args = options.parse_args_and_arch(
  20. train_parser,
  21. [
  22. "--task",
  23. *task_args,
  24. "--arch",
  25. arch,
  26. "--optimizer",
  27. "adam",
  28. "--lr",
  29. "0.0001",
  30. "--max-tokens",
  31. "500",
  32. "--tokens-per-sample",
  33. "500",
  34. "--save-dir",
  35. data_dir,
  36. "--max-epoch",
  37. "1",
  38. ]
  39. + (extra_flags or []),
  40. )
  41. cfg = convert_namespace_to_omegaconf(train_args)
  42. return cfg
  43. def write_empty_file(path):
  44. with open(path, "w"):
  45. pass
  46. assert os.path.exists(path)
  47. class TestValidSubsetsErrors(unittest.TestCase):
  48. """Test various filesystem, clarg combinations and ensure that error raising happens as expected"""
  49. def _test_case(self, paths, extra_flags):
  50. with tempfile.TemporaryDirectory() as data_dir:
  51. [
  52. write_empty_file(os.path.join(data_dir, f"{p}.bin"))
  53. for p in paths + ["train"]
  54. ]
  55. cfg = make_lm_config(data_dir, extra_flags=extra_flags)
  56. raise_if_valid_subsets_unintentionally_ignored(cfg)
  57. def test_default_raises(self):
  58. with self.assertRaises(ValueError):
  59. self._test_case(["valid", "valid1"], [])
  60. with self.assertRaises(ValueError):
  61. self._test_case(
  62. ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
  63. )
  64. def partially_specified_valid_subsets(self):
  65. with self.assertRaises(ValueError):
  66. self._test_case(
  67. ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
  68. )
  69. # Fix with ignore unused
  70. self._test_case(
  71. ["valid", "valid1", "valid2"],
  72. ["--valid-subset", "valid,valid1", "--ignore-unused-valid-subsets"],
  73. )
  74. def test_legal_configs(self):
  75. self._test_case(["valid"], [])
  76. self._test_case(["valid", "valid1"], ["--ignore-unused-valid-subsets"])
  77. self._test_case(["valid", "valid1"], ["--combine-val"])
  78. self._test_case(["valid", "valid1"], ["--valid-subset", "valid,valid1"])
  79. self._test_case(["valid", "valid1"], ["--valid-subset", "valid1"])
  80. self._test_case(
  81. ["valid", "valid1"], ["--combine-val", "--ignore-unused-valid-subsets"]
  82. )
  83. self._test_case(
  84. ["valid1"], ["--valid-subset", "valid1"]
  85. ) # valid.bin doesn't need to be ignored.
  86. def test_disable_validation(self):
  87. self._test_case([], ["--disable-validation"])
  88. self._test_case(["valid", "valid1"], ["--disable-validation"])
  89. def test_dummy_task(self):
  90. cfg = make_lm_config(task="dummy_lm")
  91. raise_if_valid_subsets_unintentionally_ignored(cfg)
  92. def test_masked_dummy_task(self):
  93. cfg = make_lm_config(task="dummy_masked_lm")
  94. raise_if_valid_subsets_unintentionally_ignored(cfg)
  95. class TestCombineValidSubsets(unittest.TestCase):
  96. def _train(self, extra_flags):
  97. with self.assertLogs() as logs:
  98. with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
  99. create_dummy_data(data_dir, num_examples=20)
  100. preprocess_lm_data(data_dir)
  101. shutil.copyfile(f"{data_dir}/valid.bin", f"{data_dir}/valid1.bin")
  102. shutil.copyfile(f"{data_dir}/valid.idx", f"{data_dir}/valid1.idx")
  103. train_language_model(
  104. data_dir,
  105. "transformer_lm",
  106. ["--max-update", "0", "--log-format", "json"] + extra_flags,
  107. run_validation=False,
  108. )
  109. return [x.message for x in logs.records]
  110. def test_combined(self):
  111. flags = ["--combine-valid-subsets", "--required-batch-size-multiple", "1"]
  112. logs = self._train(flags)
  113. assert any(["valid1" in x for x in logs]) # loaded 100 examples from valid1
  114. assert not any(["valid1_ppl" in x for x in logs]) # metrics are combined
  115. def test_subsets(self):
  116. flags = [
  117. "--valid-subset",
  118. "valid,valid1",
  119. "--required-batch-size-multiple",
  120. "1",
  121. ]
  122. logs = self._train(flags)
  123. assert any(["valid_ppl" in x for x in logs]) # loaded 100 examples from valid1
  124. assert any(["valid1_ppl" in x for x in logs]) # metrics are combined