test_constraints.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import unittest
  6. from typing import List
  7. import torch
  8. from fairseq.token_generation_constraints import (
  9. ConstraintNode,
  10. OrderedConstraintState,
  11. UnorderedConstraintState,
  12. pack_constraints,
  13. )
  14. def tensorize(constraints: List[List[int]]) -> torch.Tensor:
  15. return [torch.tensor(x) for x in constraints]
  16. class TestHelperRoutines(unittest.TestCase):
  17. def setUp(self):
  18. self.examples = [
  19. ([[]], torch.tensor([[0]])),
  20. ([[], []], torch.tensor([[0], [0]])),
  21. ([[torch.tensor([1, 2])], []], torch.tensor([[1, 1, 2, 0], [0, 0, 0, 0]])),
  22. (
  23. [
  24. [
  25. torch.tensor([3, 1, 2]),
  26. torch.tensor([3]),
  27. torch.tensor([4, 5, 6, 7]),
  28. ],
  29. [],
  30. [torch.tensor([1, 8, 9, 10, 1, 4, 11, 12])],
  31. ],
  32. torch.tensor(
  33. [
  34. [3, 3, 1, 2, 0, 3, 0, 4, 5, 6, 7, 0],
  35. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  36. [1, 1, 8, 9, 10, 1, 4, 11, 12, 0, 0, 0],
  37. ]
  38. ),
  39. ),
  40. ]
  41. def test_packing(self):
  42. """Ensures the list of lists of tensors gets packed correctly."""
  43. for batch_constraints, expected_tensor in self.examples:
  44. packed = pack_constraints(batch_constraints)
  45. assert torch.equal(packed, expected_tensor)
  46. class TestUnorderedConstraintState(unittest.TestCase):
  47. def setUp(self):
  48. # Tuples of (contraint set, expected printed graph, token counts per node)
  49. self.examples = [
  50. (
  51. tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
  52. "([None].False#6 ([1].True#4 ([2].False#1 [3].True#1) [3].True#1 [4].True#1) ([4].False#2 ([5].True#2 ([6].False#1 [7].True#1))))", # noqa
  53. {1: 4, 2: 1, 3: 2, 4: 3, 5: 2, 6: 1, 7: 1},
  54. ),
  55. ([], "[None].False#0", {}),
  56. (tensorize([[0]]), "([None].False#1 [0].True#1)", {0: 1}),
  57. (
  58. tensorize([[100000, 1, 2, 3, 4, 5]]),
  59. "([None].False#1 ([100000].False#1 ([1].False#1 ([2].False#1 ([3].False#1 ([4].False#1 [5].True#1))))))",
  60. {100000: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1},
  61. ),
  62. (
  63. tensorize([[1, 2], [1, 2]]),
  64. "([None].False#2 ([1].False#2 [2].True#2))",
  65. {1: 2, 2: 2},
  66. ),
  67. (
  68. tensorize([[1, 2], [3, 4]]),
  69. "([None].False#2 ([1].False#1 [2].True#1) ([3].False#1 [4].True#1))",
  70. {1: 1, 2: 1, 3: 1, 4: 1},
  71. ),
  72. ]
  73. self.sequences = [
  74. (
  75. self.examples[0][0],
  76. [],
  77. {"bank": 0, "num_completed": 0, "finished": False, "is_root": True},
  78. ),
  79. (
  80. self.examples[0][0],
  81. [1, 2],
  82. {"bank": 2, "num_completed": 0, "finished": False, "is_root": False},
  83. ),
  84. (
  85. self.examples[0][0],
  86. [1, 2, 94],
  87. {"bank": 1, "num_completed": 1, "finished": False, "is_root": True},
  88. ),
  89. (
  90. self.examples[0][0],
  91. [1, 3, 999, 1, 4],
  92. {"bank": 4, "num_completed": 2, "finished": False, "is_root": False},
  93. ),
  94. (
  95. self.examples[0][0],
  96. [1, 3, 999, 1, 4, 999],
  97. {"bank": 4, "num_completed": 2, "finished": False, "is_root": True},
  98. ),
  99. (
  100. self.examples[0][0],
  101. [4, 5, 6, 8],
  102. {"bank": 2, "num_completed": 1, "finished": False, "is_root": True},
  103. ),
  104. (
  105. self.examples[0][0],
  106. # Tricky, because in last three, goes down [1->4] branch, could miss [1] and [4->5]
  107. # [[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]],
  108. [1, 2, 3, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5],
  109. {"bank": 14, "num_completed": 6, "finished": True, "is_root": False},
  110. ),
  111. (
  112. self.examples[0][0],
  113. [1, 2, 3, 999, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5, 117],
  114. {"bank": 14, "num_completed": 6, "finished": True, "is_root": True},
  115. ),
  116. (
  117. tensorize([[1], [2, 3]]),
  118. # Should not be able to get credit for entering 1 a second time
  119. [1, 1],
  120. {"bank": 1, "num_completed": 1, "finished": False, "is_root": True},
  121. ),
  122. (
  123. self.examples[4][0],
  124. [1, 2, 1, 2],
  125. {"bank": 4, "num_completed": 2, "finished": True, "is_root": False},
  126. ),
  127. (
  128. self.examples[4][0],
  129. [1, 2, 1, 2, 1],
  130. {"bank": 4, "num_completed": 2, "finished": True, "is_root": True},
  131. ),
  132. (
  133. self.examples[5][0],
  134. [1, 2, 3, 4, 5],
  135. {"bank": 4, "num_completed": 2, "finished": True, "is_root": True},
  136. ),
  137. ]
  138. def test_graphs(self):
  139. """
  140. Test whether unordered graph systems are created correctly.
  141. """
  142. for example in self.examples:
  143. constraints, expected, gold_counts = example
  144. c = ConstraintNode.create(constraints)
  145. assert (
  146. ConstraintNode.print_graph(c) == expected
  147. ), f"got {ConstraintNode.print_graph(c)}, expected {expected}"
  148. assert (
  149. c.token_counts() == gold_counts
  150. ), f"{c} got {c.token_counts()} wanted {gold_counts}"
  151. def test_next_tokens(self):
  152. """
  153. Tests that the set of next tokens is correct.
  154. """
  155. for example in self.examples:
  156. constraints, expected, gold_counts = example
  157. root = ConstraintNode.create(constraints)
  158. root_tokens = set(root.children.keys())
  159. for sequence in constraints:
  160. state = UnorderedConstraintState(root)
  161. for token in sequence:
  162. all_tokens = root_tokens.union(state.node.children.keys())
  163. assert (
  164. all_tokens == state.next_tokens()
  165. ), f"ALL {all_tokens} NEXT {state.next_tokens()}"
  166. state = state.advance(token)
  167. def test_sequences(self):
  168. for constraints, tokens, expected in self.sequences:
  169. state = UnorderedConstraintState.create(pack_constraints([constraints])[0])
  170. for token in tokens:
  171. state = state.advance(token)
  172. result = {}
  173. for attr in expected.keys():
  174. result[attr] = getattr(state, attr)
  175. assert (
  176. result == expected
  177. ), f"TEST({tokens}) GOT: {result} WANTED: {expected}"
  178. class TestOrderedConstraintState(unittest.TestCase):
  179. def setUp(self):
  180. self.sequences = [
  181. (
  182. tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
  183. [],
  184. {"bank": 0, "num_completed": 0, "finished": False, "is_root": True},
  185. ),
  186. (
  187. tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
  188. [1, 2],
  189. {"bank": 2, "num_completed": 0, "finished": False, "is_root": False},
  190. ),
  191. (
  192. tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
  193. [1, 2, 94],
  194. {"bank": 0, "num_completed": 0, "finished": False, "is_root": True},
  195. ),
  196. (
  197. tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
  198. [1, 3, 999, 1, 4],
  199. {"bank": 0, "num_completed": 0, "finished": False, "is_root": True},
  200. ),
  201. (
  202. tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
  203. [1, 2, 3, 999, 999],
  204. {"bank": 3, "num_completed": 1, "finished": False, "is_root": False},
  205. ),
  206. (
  207. tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
  208. [1, 2, 3, 77, 1, 3, 1],
  209. {"bank": 6, "num_completed": 2, "finished": False, "is_root": False},
  210. ),
  211. (
  212. tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
  213. [1, 2, 3, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5],
  214. {"bank": 14, "num_completed": 6, "finished": True, "is_root": False},
  215. ),
  216. (
  217. tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
  218. [1, 2, 999, 1, 2, 3, 999, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5, 117],
  219. {"bank": 14, "num_completed": 6, "finished": True, "is_root": False},
  220. ),
  221. (
  222. tensorize([[1], [2, 3]]),
  223. [1, 1],
  224. {"bank": 1, "num_completed": 1, "finished": False, "is_root": False},
  225. ),
  226. (
  227. tensorize([[1, 2], [1, 2]]),
  228. [1, 2, 1, 2],
  229. {"bank": 4, "num_completed": 2, "finished": True, "is_root": False},
  230. ),
  231. (
  232. tensorize([[1, 2], [1, 2]]),
  233. [1, 2, 1, 2, 1],
  234. {"bank": 4, "num_completed": 2, "finished": True, "is_root": False},
  235. ),
  236. (
  237. tensorize([[1, 2], [3, 4]]),
  238. [1, 2, 3, 4, 5],
  239. {"bank": 4, "num_completed": 2, "finished": True, "is_root": False},
  240. ),
  241. ]
  242. def test_sequences(self):
  243. for i, (constraints, tokens, expected) in enumerate(self.sequences):
  244. state = OrderedConstraintState.create(pack_constraints([constraints])[0])
  245. for token in tokens:
  246. state = state.advance(token)
  247. result = {}
  248. for attr in expected.keys():
  249. result[attr] = getattr(state, attr)
  250. assert (
  251. result == expected
  252. ), f"TEST({tokens}) GOT: {result} WANTED: {expected}"
# When this file is executed directly (not imported), run its test cases via unittest.
if __name__ == "__main__":
    unittest.main()