test_denoising.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from tempfile import TemporaryDirectory

from fairseq import options
from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.tasks.denoising import DenoisingTask
from tests.utils import build_vocab, make_data
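

# This test exercises the fairseq "denoising" task end to end: it writes a small
# synthetic corpus, binarizes it, loads the split through DenoisingTask so that
# noising is applied, and then checks that the target tokens at every <mask>
# position match the tokens of the clean, un-noised dataset.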
class TestDenoising(unittest.TestCase):
    def test_denoising(self):
        with TemporaryDirectory() as dirname:
            # prep input file
            raw_file = os.path.join(dirname, "raw")
            data = make_data(out_file=raw_file)
            vocab = build_vocab(data)
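
            # (make_data and build_vocab are helpers from tests/utils; they are assumed
            # here to write a small synthetic corpus and build a Dictionary over it.)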
            # binarize
            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
            split = "train"
            bin_file = os.path.join(dirname, split)
            dataset_impl = "mmap"
            FileBinarizer.multiprocess_dataset(
                input_file=raw_file,
                binarizer=binarizer,
                dataset_impl=dataset_impl,
                vocab_size=len(vocab),
                output_prefix=bin_file,
            )
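
            # Build training args through the regular fairseq CLI parser for a
            # denoising task with a bart_base architecture. The noise flags are read
            # here as: mask 20% of the words, permute every sentence, no document
            # rotation, and --replace-length -1 keeping sequence length when masking
            # (flag semantics inferred from the flag names, not verified).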
            # setup task
            train_args = options.parse_args_and_arch(
                options.get_training_parser(),
                [
                    "--task",
                    "denoising",
                    "--arch",
                    "bart_base",
                    "--seed",
                    "42",
                    "--mask-length",
                    "word",
                    "--permute-sentences",
                    "1",
                    "--rotate",
                    "0",
                    "--replace-length",
                    "-1",
                    "--mask",
                    "0.2",
                    dirname,
                ],
            )
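
            # Convert the argparse namespace to an omegaconf config and build the task
            # directly, reusing the dictionary produced by the binarizer.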
            cfg = convert_namespace_to_omegaconf(train_args)
            task = DenoisingTask(cfg.task, binarizer.dict)

            # load datasets
            original_dataset = task._load_dataset_split(bin_file, 1, False)
            task.load_dataset(split)
            masked_dataset = task.dataset(split)
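
            # original_dataset holds the clean binarized tokens, while masked_dataset
            # is the same split loaded through the task, so batches drawn from it
            # contain noised (masked) source tokens.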
            iterator = task.get_batch_iterator(
                dataset=masked_dataset,
                max_tokens=65_536,
                max_positions=4_096,
            ).next_epoch_itr(shuffle=False)
            mask_index = task.source_dictionary.index("<mask>")
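
            # For every sample, select the positions where the noised source is <mask>;
            # the target tokens at those positions must match the clean original tokens.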
            for batch in iterator:
                for sample in range(len(batch)):
                    net_input = batch["net_input"]
                    masked_src_tokens = net_input["src_tokens"][sample]
                    masked_src_length = net_input["src_lengths"][sample]
                    masked_tgt_tokens = batch["target"][sample]

                    sample_id = batch["id"][sample]
                    original_tokens = original_dataset[sample_id]
                    original_tokens = original_tokens.masked_select(
                        masked_src_tokens[:masked_src_length] == mask_index
                    )
                    masked_tokens = masked_tgt_tokens.masked_select(
                        masked_src_tokens == mask_index
                    )

                    assert masked_tokens.equal(original_tokens)


if __name__ == "__main__":
    unittest.main()