123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275 |
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
- import unittest
- from typing import List
- import torch
- from fairseq.token_generation_constraints import (
- ConstraintNode,
- OrderedConstraintState,
- UnorderedConstraintState,
- pack_constraints,
- )
def tensorize(constraints: List[List[int]]) -> List[torch.Tensor]:
    """Convert each constraint (a list of token ids) into its own tensor.

    Args:
        constraints: One list of integer token ids per constraint.

    Returns:
        A list with one ``torch.tensor`` per input constraint, in the same
        order as the input. An empty input yields an empty list.
    """
    # The result is a Python list of tensors, not a single stacked tensor:
    # constraints may have different lengths.
    return [torch.tensor(x) for x in constraints]
class TestHelperRoutines(unittest.TestCase):
    """Checks that ``pack_constraints`` produces the expected packed tensor."""

    def setUp(self):
        # Each example is (batch of constraints, expected packed tensor).
        # A batch holds one list of constraint tensors per sentence.
        # In the expected tensors below, each row starts with the number of
        # constraints for that sentence, every constraint is followed by a 0,
        # and shorter rows are right-padded with 0.
        self.examples = [
            ([[]], torch.tensor([[0]])),
            ([[], []], torch.tensor([[0], [0]])),
            ([[torch.tensor([1, 2])], []], torch.tensor([[1, 1, 2, 0], [0, 0, 0, 0]])),
            (
                [
                    [
                        torch.tensor([3, 1, 2]),
                        torch.tensor([3]),
                        torch.tensor([4, 5, 6, 7]),
                    ],
                    [],
                    [torch.tensor([1, 8, 9, 10, 1, 4, 11, 12])],
                ],
                torch.tensor(
                    [
                        [3, 3, 1, 2, 0, 3, 0, 4, 5, 6, 7, 0],
                        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                        [1, 1, 8, 9, 10, 1, 4, 11, 12, 0, 0, 0],
                    ]
                ),
            ),
        ]

    def test_packing(self):
        """Ensures the list of lists of tensors gets packed correctly."""
        for index, (batch_constraints, expected_tensor) in enumerate(self.examples):
            # subTest reports which example failed; assertTrue (unlike a bare
            # assert) is not removed when Python runs with -O.
            with self.subTest(example=index):
                packed = pack_constraints(batch_constraints)
                self.assertTrue(
                    torch.equal(packed, expected_tensor),
                    f"got {packed}, expected {expected_tensor}",
                )
class TestUnorderedConstraintState(unittest.TestCase):
    """Tests ``ConstraintNode`` graph building and ``UnorderedConstraintState``."""

    def setUp(self):
        # Tuples of (constraint set, expected printed graph, token counts per node)
        self.examples = [
            (
                tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
                "([None].False#6 ([1].True#4 ([2].False#1 [3].True#1) [3].True#1 [4].True#1) ([4].False#2 ([5].True#2 ([6].False#1 [7].True#1))))",  # noqa
                {1: 4, 2: 1, 3: 2, 4: 3, 5: 2, 6: 1, 7: 1},
            ),
            ([], "[None].False#0", {}),
            (tensorize([[0]]), "([None].False#1 [0].True#1)", {0: 1}),
            (
                tensorize([[100000, 1, 2, 3, 4, 5]]),
                "([None].False#1 ([100000].False#1 ([1].False#1 ([2].False#1 ([3].False#1 ([4].False#1 [5].True#1))))))",  # noqa
                {100000: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1},
            ),
            (
                tensorize([[1, 2], [1, 2]]),
                "([None].False#2 ([1].False#2 [2].True#2))",
                {1: 2, 2: 2},
            ),
            (
                tensorize([[1, 2], [3, 4]]),
                "([None].False#2 ([1].False#1 [2].True#1) ([3].False#1 [4].True#1))",
                {1: 1, 2: 1, 3: 1, 4: 1},
            ),
        ]

        # Tuples of (constraint set, tokens fed to the state one at a time,
        # expected attribute values on the state after the last token).
        self.sequences = [
            (
                self.examples[0][0],
                [],
                {"bank": 0, "num_completed": 0, "finished": False, "is_root": True},
            ),
            (
                self.examples[0][0],
                [1, 2],
                {"bank": 2, "num_completed": 0, "finished": False, "is_root": False},
            ),
            (
                self.examples[0][0],
                [1, 2, 94],
                {"bank": 1, "num_completed": 1, "finished": False, "is_root": True},
            ),
            (
                self.examples[0][0],
                [1, 3, 999, 1, 4],
                {"bank": 4, "num_completed": 2, "finished": False, "is_root": False},
            ),
            (
                self.examples[0][0],
                [1, 3, 999, 1, 4, 999],
                {"bank": 4, "num_completed": 2, "finished": False, "is_root": True},
            ),
            (
                self.examples[0][0],
                [4, 5, 6, 8],
                {"bank": 2, "num_completed": 1, "finished": False, "is_root": True},
            ),
            (
                self.examples[0][0],
                # Tricky, because in last three, goes down [1->4] branch, could miss [1] and [4->5]
                # [[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]],
                [1, 2, 3, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5],
                {"bank": 14, "num_completed": 6, "finished": True, "is_root": False},
            ),
            (
                self.examples[0][0],
                [1, 2, 3, 999, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5, 117],
                {"bank": 14, "num_completed": 6, "finished": True, "is_root": True},
            ),
            (
                tensorize([[1], [2, 3]]),
                # Should not be able to get credit for entering 1 a second time
                [1, 1],
                {"bank": 1, "num_completed": 1, "finished": False, "is_root": True},
            ),
            (
                self.examples[4][0],
                [1, 2, 1, 2],
                {"bank": 4, "num_completed": 2, "finished": True, "is_root": False},
            ),
            (
                self.examples[4][0],
                [1, 2, 1, 2, 1],
                {"bank": 4, "num_completed": 2, "finished": True, "is_root": True},
            ),
            (
                self.examples[5][0],
                [1, 2, 3, 4, 5],
                {"bank": 4, "num_completed": 2, "finished": True, "is_root": True},
            ),
        ]

    def test_graphs(self):
        """
        Test whether unordered graph systems are created correctly.
        """
        for index, (constraints, expected, gold_counts) in enumerate(self.examples):
            # unittest assertions are used instead of bare asserts so the
            # checks still run under ``python -O``.
            with self.subTest(example=index):
                c = ConstraintNode.create(constraints)
                printed = ConstraintNode.print_graph(c)
                self.assertEqual(
                    printed, expected, f"got {printed}, expected {expected}"
                )
                counts = c.token_counts()
                self.assertEqual(
                    counts, gold_counts, f"{c} got {counts} wanted {gold_counts}"
                )

    def test_next_tokens(self):
        """
        Tests that the set of next tokens is correct.
        """
        # Only the constraint set of each example is needed here.
        for index, (constraints, _, _) in enumerate(self.examples):
            with self.subTest(example=index):
                root = ConstraintNode.create(constraints)
                root_tokens = set(root.children.keys())
                for sequence in constraints:
                    state = UnorderedConstraintState(root)
                    for token in sequence:
                        # Expected candidates: children of the root plus
                        # children of the node the state currently sits on.
                        all_tokens = root_tokens.union(state.node.children.keys())
                        next_tokens = state.next_tokens()
                        self.assertEqual(
                            all_tokens,
                            next_tokens,
                            f"ALL {all_tokens} NEXT {next_tokens}",
                        )
                        state = state.advance(token)

    def test_sequences(self):
        """Advance a fresh state through each token list and check its attributes."""
        for index, (constraints, tokens, expected) in enumerate(self.sequences):
            with self.subTest(sequence=index):
                # pack_constraints takes a batch; wrap the single constraint
                # set and take row 0 of the result.
                state = UnorderedConstraintState.create(
                    pack_constraints([constraints])[0]
                )
                for token in tokens:
                    state = state.advance(token)
                # Read back only the attributes named in the expectation.
                result = {attr: getattr(state, attr) for attr in expected}
                self.assertEqual(
                    result,
                    expected,
                    f"TEST({tokens}) GOT: {result} WANTED: {expected}",
                )
class TestOrderedConstraintState(unittest.TestCase):
    """Tests attribute values of ``OrderedConstraintState`` after token sequences."""

    def setUp(self):
        # Constraint set shared by the first eight cases; tensorized afresh
        # for every case so no tensors are shared between cases.
        shared = [[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]

        # Tuples of (constraint set, tokens fed to the state one at a time,
        # expected attribute values on the state after the last token).
        self.sequences = [
            (
                tensorize(shared),
                [],
                {"bank": 0, "num_completed": 0, "finished": False, "is_root": True},
            ),
            (
                tensorize(shared),
                [1, 2],
                {"bank": 2, "num_completed": 0, "finished": False, "is_root": False},
            ),
            (
                tensorize(shared),
                [1, 2, 94],
                {"bank": 0, "num_completed": 0, "finished": False, "is_root": True},
            ),
            (
                tensorize(shared),
                [1, 3, 999, 1, 4],
                {"bank": 0, "num_completed": 0, "finished": False, "is_root": True},
            ),
            (
                tensorize(shared),
                [1, 2, 3, 999, 999],
                {"bank": 3, "num_completed": 1, "finished": False, "is_root": False},
            ),
            (
                tensorize(shared),
                [1, 2, 3, 77, 1, 3, 1],
                {"bank": 6, "num_completed": 2, "finished": False, "is_root": False},
            ),
            (
                tensorize(shared),
                [1, 2, 3, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5],
                {"bank": 14, "num_completed": 6, "finished": True, "is_root": False},
            ),
            (
                tensorize(shared),
                [1, 2, 999, 1, 2, 3, 999, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5, 117],
                {"bank": 14, "num_completed": 6, "finished": True, "is_root": False},
            ),
            (
                tensorize([[1], [2, 3]]),
                [1, 1],
                {"bank": 1, "num_completed": 1, "finished": False, "is_root": False},
            ),
            (
                tensorize([[1, 2], [1, 2]]),
                [1, 2, 1, 2],
                {"bank": 4, "num_completed": 2, "finished": True, "is_root": False},
            ),
            (
                tensorize([[1, 2], [1, 2]]),
                [1, 2, 1, 2, 1],
                {"bank": 4, "num_completed": 2, "finished": True, "is_root": False},
            ),
            (
                tensorize([[1, 2], [3, 4]]),
                [1, 2, 3, 4, 5],
                {"bank": 4, "num_completed": 2, "finished": True, "is_root": False},
            ),
        ]

    def test_sequences(self):
        """Advance a fresh state through each token list and check its attributes."""
        for i, (constraints, tokens, expected) in enumerate(self.sequences):
            # subTest labels a failure with the case index; assertEqual (unlike
            # a bare assert) is not removed when Python runs with -O.
            with self.subTest(sequence=i):
                # pack_constraints takes a batch; wrap the single constraint
                # set and take row 0 of the result.
                state = OrderedConstraintState.create(
                    pack_constraints([constraints])[0]
                )
                for token in tokens:
                    state = state.advance(token)
                # Read back only the attributes named in the expectation.
                result = {attr: getattr(state, attr) for attr in expected}
                self.assertEqual(
                    result,
                    expected,
                    f"TEST({tokens}) GOT: {result} WANTED: {expected}",
                )
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|