# tests/test_character_token_embedder.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest

import torch

from fairseq.data import Dictionary
from fairseq.modules import CharacterTokenEmbedder
  9. class TestCharacterTokenEmbedder(unittest.TestCase):
  10. def test_character_token_embedder(self):
  11. vocab = Dictionary()
  12. vocab.add_symbol("hello")
  13. vocab.add_symbol("there")
  14. embedder = CharacterTokenEmbedder(
  15. vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2
  16. )
  17. test_sents = [["hello", "unk", "there"], ["there"], ["hello", "there"]]
  18. max_len = max(len(s) for s in test_sents)
  19. input = torch.LongTensor(len(test_sents), max_len + 2).fill_(vocab.pad())
  20. for i in range(len(test_sents)):
  21. input[i][0] = vocab.eos()
  22. for j in range(len(test_sents[i])):
  23. input[i][j + 1] = vocab.index(test_sents[i][j])
  24. input[i][j + 2] = vocab.eos()
  25. embs = embedder(input)
  26. assert embs.size() == (len(test_sents), max_len + 2, 5)
  27. self.assertAlmostEqual(embs[0][0], embs[1][0])
  28. self.assertAlmostEqual(embs[0][0], embs[0][-1])
  29. self.assertAlmostEqual(embs[0][1], embs[2][1])
  30. self.assertAlmostEqual(embs[0][3], embs[1][1])
  31. embs.sum().backward()
  32. assert embedder.char_embeddings.weight.grad is not None
  33. def assertAlmostEqual(self, t1, t2):
  34. self.assertEqual(t1.size(), t2.size(), "size mismatch")
  35. self.assertLess((t1 - t2).abs().max(), 1e-6)
# Allow running this test module directly (outside a test runner):
# `python test_character_token_embedder.py`.
if __name__ == "__main__":
    unittest.main()