# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import io
import os
import string
import tempfile
import unittest

import torch

from fairseq import tokenizer
from fairseq.data import Dictionary


class TestDictionary(unittest.TestCase):
    def test_finalize(self):
        txt = [
            "A B C D",
            "B C D",
            "C D",
            "D",
        ]
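        # indices 0-3 are reserved for the default special symbols
        # (<s>, <pad>, </s>, <unk>), so A-D get ids 4-7; encode_line()
        # also appends the eos id (2) to every encoded line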
        ref_ids1 = list(
            map(
                torch.IntTensor,
                [
                    [4, 5, 6, 7, 2],
                    [5, 6, 7, 2],
                    [6, 7, 2],
                    [7, 2],
                ],
            )
        )
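        # finalize() re-sorts the regular symbols by descending count
        # (D is the most frequent word in txt), which reverses those ids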
        ref_ids2 = list(
            map(
                torch.IntTensor,
                [
                    [7, 6, 5, 4, 2],
                    [6, 5, 4, 2],
                    [5, 4, 2],
                    [4, 2],
                ],
            )
        )

        # build dictionary
        d = Dictionary()
        for line in txt:
            d.encode_line(line, add_if_not_exist=True)

        def get_ids(dictionary):
            ids = []
            for line in txt:
                ids.append(dictionary.encode_line(line, add_if_not_exist=False))
            return ids

        def assertMatch(ids, ref_ids):
            for toks, ref_toks in zip(ids, ref_ids):
                self.assertEqual(toks.size(), ref_toks.size())
                self.assertEqual(0, (toks != ref_toks).sum().item())

        ids = get_ids(d)
        assertMatch(ids, ref_ids1)

        # check finalized dictionary
        d.finalize()
        finalized_ids = get_ids(d)
        assertMatch(finalized_ids, ref_ids2)

        # write to disk and reload
        with tempfile.NamedTemporaryFile(mode="w") as tmp_dict:
            d.save(tmp_dict.name)
            d = Dictionary.load(tmp_dict.name)
            reload_ids = get_ids(d)
            assertMatch(reload_ids, ref_ids2)
            assertMatch(finalized_ids, reload_ids)

    def test_overwrite(self):
        # for example, Camembert overwrites <unk>, <s> and </s>
        dict_file = io.StringIO(
            "<unk> 999 #fairseq:overwrite\n"
            "<s> 999 #fairseq:overwrite\n"
            "</s> 999 #fairseq:overwrite\n"
            ", 999\n"
            "▁de 999\n"
        )
        d = Dictionary()
        d.add_from_file(dict_file)
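        # <pad> keeps its reserved slot; an out-of-vocabulary word ("foo")
        # resolves to the unk index 3; the overwritten specials and the new
        # symbols are appended after the reserved ids, starting at 4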
        self.assertEqual(d.index("<pad>"), 1)
        self.assertEqual(d.index("foo"), 3)
        self.assertEqual(d.index("<unk>"), 4)
        self.assertEqual(d.index("<s>"), 5)
        self.assertEqual(d.index("</s>"), 6)
        self.assertEqual(d.index(","), 7)
        self.assertEqual(d.index("▁de"), 8)

    def test_no_overwrite(self):
        # same symbols as in test_overwrite, but without the
        # "#fairseq:overwrite" flag, so re-defining <unk>, <s> and </s> fails
        dict_file = io.StringIO(
            "<unk> 999\n" "<s> 999\n" "</s> 999\n" ", 999\n" "▁de 999\n"
        )
        d = Dictionary()
        with self.assertRaisesRegex(RuntimeError, "Duplicate"):
            d.add_from_file(dict_file)

    def test_space(self):
        # for example, character models treat space as a symbol
        dict_file = io.StringIO(" 999\n" "a 999\n" "b 999\n")
        d = Dictionary()
        d.add_from_file(dict_file)
        self.assertEqual(d.index(" "), 4)
        self.assertEqual(d.index("a"), 5)
        self.assertEqual(d.index("b"), 6)

    def test_add_file_to_dict(self):
        counts = {}
        num_lines = 100
        per_line = 10
        with tempfile.TemporaryDirectory("test_sampling") as data_dir:
            filename = os.path.join(data_dir, "dummy.txt")
            with open(filename, "w", encoding="utf-8") as data:
                for c in string.ascii_letters:
                    line = f"{c} " * per_line
                    for _ in range(num_lines):
                        data.write(f"{line}\n")
                    counts[c] = per_line * num_lines
                    per_line += 5
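
            # build and finalize a dictionary from the file; the trailing 10 is
            # the worker count used to count tokens in parallel, and
            # padding_factor=8 pads the final vocabulary size to a multiple of 8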
            dict = Dictionary()
            Dictionary.add_file_to_dictionary(
                filename, dict, tokenizer.tokenize_line, 10
            )
            dict.finalize(threshold=0, nwords=-1, padding_factor=8)

            for c in string.ascii_letters:
                count = dict.get_count(dict.index(c))
                self.assertEqual(
                    counts[c], count, f"{c} count is {count} but should be {counts[c]}"
                )


if __name__ == "__main__":
    unittest.main()