# test_dictionary.py (4.4 KB)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import io
import os
import string
import tempfile
import unittest

import torch

from fairseq import tokenizer
from fairseq.data import Dictionary
  13. class TestDictionary(unittest.TestCase):
  14. def test_finalize(self):
  15. txt = [
  16. "A B C D",
  17. "B C D",
  18. "C D",
  19. "D",
  20. ]
  21. ref_ids1 = list(
  22. map(
  23. torch.IntTensor,
  24. [
  25. [4, 5, 6, 7, 2],
  26. [5, 6, 7, 2],
  27. [6, 7, 2],
  28. [7, 2],
  29. ],
  30. )
  31. )
  32. ref_ids2 = list(
  33. map(
  34. torch.IntTensor,
  35. [
  36. [7, 6, 5, 4, 2],
  37. [6, 5, 4, 2],
  38. [5, 4, 2],
  39. [4, 2],
  40. ],
  41. )
  42. )
  43. # build dictionary
  44. d = Dictionary()
  45. for line in txt:
  46. d.encode_line(line, add_if_not_exist=True)
  47. def get_ids(dictionary):
  48. ids = []
  49. for line in txt:
  50. ids.append(dictionary.encode_line(line, add_if_not_exist=False))
  51. return ids
  52. def assertMatch(ids, ref_ids):
  53. for toks, ref_toks in zip(ids, ref_ids):
  54. self.assertEqual(toks.size(), ref_toks.size())
  55. self.assertEqual(0, (toks != ref_toks).sum().item())
  56. ids = get_ids(d)
  57. assertMatch(ids, ref_ids1)
  58. # check finalized dictionary
  59. d.finalize()
  60. finalized_ids = get_ids(d)
  61. assertMatch(finalized_ids, ref_ids2)
  62. # write to disk and reload
  63. with tempfile.NamedTemporaryFile(mode="w") as tmp_dict:
  64. d.save(tmp_dict.name)
  65. d = Dictionary.load(tmp_dict.name)
  66. reload_ids = get_ids(d)
  67. assertMatch(reload_ids, ref_ids2)
  68. assertMatch(finalized_ids, reload_ids)
  69. def test_overwrite(self):
  70. # for example, Camembert overwrites <unk>, <s> and </s>
  71. dict_file = io.StringIO(
  72. "<unk> 999 #fairseq:overwrite\n"
  73. "<s> 999 #fairseq:overwrite\n"
  74. "</s> 999 #fairseq:overwrite\n"
  75. ", 999\n"
  76. "▁de 999\n"
  77. )
  78. d = Dictionary()
  79. d.add_from_file(dict_file)
  80. self.assertEqual(d.index("<pad>"), 1)
  81. self.assertEqual(d.index("foo"), 3)
  82. self.assertEqual(d.index("<unk>"), 4)
  83. self.assertEqual(d.index("<s>"), 5)
  84. self.assertEqual(d.index("</s>"), 6)
  85. self.assertEqual(d.index(","), 7)
  86. self.assertEqual(d.index("▁de"), 8)
  87. def test_no_overwrite(self):
  88. # for example, Camembert overwrites <unk>, <s> and </s>
  89. dict_file = io.StringIO(
  90. "<unk> 999\n" "<s> 999\n" "</s> 999\n" ", 999\n" "▁de 999\n"
  91. )
  92. d = Dictionary()
  93. with self.assertRaisesRegex(RuntimeError, "Duplicate"):
  94. d.add_from_file(dict_file)
  95. def test_space(self):
  96. # for example, character models treat space as a symbol
  97. dict_file = io.StringIO(" 999\n" "a 999\n" "b 999\n")
  98. d = Dictionary()
  99. d.add_from_file(dict_file)
  100. self.assertEqual(d.index(" "), 4)
  101. self.assertEqual(d.index("a"), 5)
  102. self.assertEqual(d.index("b"), 6)
  103. def test_add_file_to_dict(self):
  104. counts = {}
  105. num_lines = 100
  106. per_line = 10
  107. with tempfile.TemporaryDirectory("test_sampling") as data_dir:
  108. filename = os.path.join(data_dir, "dummy.txt")
  109. with open(filename, "w", encoding="utf-8") as data:
  110. for c in string.ascii_letters:
  111. line = f"{c} " * per_line
  112. for _ in range(num_lines):
  113. data.write(f"{line}\n")
  114. counts[c] = per_line * num_lines
  115. per_line += 5
  116. dict = Dictionary()
  117. Dictionary.add_file_to_dictionary(
  118. filename, dict, tokenizer.tokenize_line, 10
  119. )
  120. dict.finalize(threshold=0, nwords=-1, padding_factor=8)
  121. for c in string.ascii_letters:
  122. count = dict.get_count(dict.index(c))
  123. self.assertEqual(
  124. counts[c], count, f"{c} count is {count} but should be {counts[c]}"
  125. )
  126. if __name__ == "__main__":
  127. unittest.main()