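"""Tests for fairseq's handling of validation subsets: configurations that
would silently skip validation data should raise, and combined or multiple
valid subsets should be loaded and logged as expected."""
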
import os
import shutil
import tempfile
import unittest

from fairseq import options
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.data.data_utils import raise_if_valid_subsets_unintentionally_ignored

from .utils import create_dummy_data, preprocess_lm_data, train_language_model


def make_lm_config(
    data_dir=None,
    extra_flags=None,
    task="language_modeling",
    arch="transformer_lm_gpt2_tiny",
):
    """Build an OmegaConf training config for a tiny LM via the argparse path."""
    task_args = [task]
    if data_dir is not None:
        task_args += [data_dir]
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task",
            *task_args,
            "--arch",
            arch,
            "--optimizer",
            "adam",
            "--lr",
            "0.0001",
            "--max-tokens",
            "500",
            "--tokens-per-sample",
            "500",
            "--save-dir",
            data_dir,
            "--max-epoch",
            "1",
        ]
        + (extra_flags or []),
    )
    cfg = convert_namespace_to_omegaconf(train_args)
    return cfg
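

# A minimal usage sketch (for illustration only, not part of the test suite;
# assumes `data-bin` holds preprocessed data). The returned `cfg` is an
# OmegaConf DictConfig, so fields such as `cfg.dataset.valid_subset` should be
# attribute-accessible:
#
#     cfg = make_lm_config("data-bin", extra_flags=["--valid-subset", "valid"])
#     assert cfg.dataset.valid_subset == "valid"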


def write_empty_file(path):
    # Touch `path` so the subset-detection logic sees it on disk.
    with open(path, "w"):
        pass
    assert os.path.exists(path)


class TestValidSubsetsErrors(unittest.TestCase):
    """Test various filesystem and command-line-argument combinations and
    ensure that errors are raised as expected."""

    def _test_case(self, paths, extra_flags):
        with tempfile.TemporaryDirectory() as data_dir:
            # Empty .bin files suffice: the check inspects which subset files
            # exist on disk, not their contents.
            for p in paths + ["train"]:
                write_empty_file(os.path.join(data_dir, f"{p}.bin"))
            cfg = make_lm_config(data_dir, extra_flags=extra_flags)
            raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_default_raises(self):
        with self.assertRaises(ValueError):
            self._test_case(["valid", "valid1"], [])
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )

    def test_partially_specified_valid_subsets(self):
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )
        # Fixed by explicitly ignoring the unused subset.
        self._test_case(
            ["valid", "valid1", "valid2"],
            ["--valid-subset", "valid,valid1", "--ignore-unused-valid-subsets"],
        )

    def test_legal_configs(self):
        self._test_case(["valid"], [])
        self._test_case(["valid", "valid1"], ["--ignore-unused-valid-subsets"])
        # --combine-val is the short alias for --combine-valid-subsets.
        self._test_case(["valid", "valid1"], ["--combine-val"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid,valid1"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid1"])
        self._test_case(
            ["valid", "valid1"], ["--combine-val", "--ignore-unused-valid-subsets"]
        )
        self._test_case(
            ["valid1"], ["--valid-subset", "valid1"]
        )  # valid.bin doesn't need to be ignored

    def test_disable_validation(self):
        self._test_case([], ["--disable-validation"])
        self._test_case(["valid", "valid1"], ["--disable-validation"])

    def test_dummy_task(self):
        cfg = make_lm_config(task="dummy_lm")
        raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_masked_dummy_task(self):
        cfg = make_lm_config(task="dummy_masked_lm")
        raise_if_valid_subsets_unintentionally_ignored(cfg)


class TestCombineValidSubsets(unittest.TestCase):
    def _train(self, extra_flags):
        with self.assertLogs() as logs:
            with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
                create_dummy_data(data_dir, num_examples=20)
                preprocess_lm_data(data_dir)
                # Duplicate the valid set so two validation subsets exist.
                shutil.copyfile(f"{data_dir}/valid.bin", f"{data_dir}/valid1.bin")
                shutil.copyfile(f"{data_dir}/valid.idx", f"{data_dir}/valid1.idx")
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    ["--max-update", "0", "--log-format", "json"] + extra_flags,
                    run_validation=False,
                )
        return [x.message for x in logs.records]

    def test_combined(self):
        flags = ["--combine-valid-subsets", "--required-batch-size-multiple", "1"]
        logs = self._train(flags)
        assert any("valid1" in x for x in logs)  # examples from valid1 were loaded
        assert not any("valid1_ppl" in x for x in logs)  # metrics are combined

    def test_subsets(self):
        flags = [
            "--valid-subset",
            "valid,valid1",
            "--required-batch-size-multiple",
            "1",
        ]
        logs = self._train(flags)
        assert any("valid_ppl" in x for x in logs)  # each subset reports its own metrics
        assert any("valid1_ppl" in x for x in logs)
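

# Standard unittest entry point so the file can be run directly.
if __name__ == "__main__":
    unittest.main()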