# test_espnet_multihead_attention.py
# Tests for fairseq's ESPNET-style multi-headed attention modules.
import torch
import numpy as np
import unittest
from fairseq.modules import (
    ESPNETMultiHeadedAttention,
    RelPositionMultiHeadedAttention,
    RotaryPositionMultiHeadedAttention,
)

# Force deterministic torch kernels so the hard-coded expected tensors in the
# tests below are reproducible across runs.
torch.use_deterministic_algorithms(True)
  10. class TestESPNETMultiHeadedAttention(unittest.TestCase):
  11. def setUp(self) -> None:
  12. self.T = 3
  13. self.B = 1
  14. self.C = 2
  15. torch.manual_seed(0)
  16. self.sample = torch.randn(self.T, self.B, self.C) # TBC
  17. self.sample_scores = torch.randn(self.B, 1, self.T, self.T)
  18. self.MHA = ESPNETMultiHeadedAttention(self.C, 1, dropout=0)
  19. def test_forward(self):
  20. expected_scores = torch.tensor(
  21. [[[0.1713, -0.3776]], [[0.2263, -0.4486]], [[0.2243, -0.4538]]]
  22. )
  23. scores, _ = self.MHA(self.sample, self.sample, self.sample)
  24. self.assertTrue(
  25. np.allclose(
  26. expected_scores.cpu().detach().numpy(),
  27. scores.cpu().detach().numpy(),
  28. atol=1e-4,
  29. )
  30. )
  31. def test_forward_qkv(self):
  32. expected_query = torch.tensor(
  33. [[[[-1.0235, 0.0409], [0.4008, 1.3077], [0.5396, 2.0698]]]]
  34. )
  35. expected_key = torch.tensor(
  36. [[[[0.5053, -0.4965], [-0.3730, -0.9473], [-0.7019, -0.1935]]]]
  37. )
  38. expected_val = torch.tensor(
  39. [[[[-0.9940, 0.5403], [0.5924, -0.7619], [0.7504, -1.0892]]]]
  40. )
  41. sample_t = self.sample.transpose(0, 1)
  42. query, key, val = self.MHA.forward_qkv(sample_t, sample_t, sample_t)
  43. self.assertTrue(
  44. np.allclose(
  45. expected_query.cpu().detach().numpy(),
  46. query.cpu().detach().numpy(),
  47. atol=1e-4,
  48. )
  49. )
  50. self.assertTrue(
  51. np.allclose(
  52. expected_key.cpu().detach().numpy(),
  53. key.cpu().detach().numpy(),
  54. atol=1e-4,
  55. )
  56. )
  57. self.assertTrue(
  58. np.allclose(
  59. expected_val.cpu().detach().numpy(),
  60. val.cpu().detach().numpy(),
  61. atol=1e-4,
  62. )
  63. )
  64. def test_forward_attention(self):
  65. expected_scores = torch.tensor(
  66. [[[0.1627, -0.6249], [-0.2547, -0.6487], [-0.0711, -0.8545]]]
  67. )
  68. scores = self.MHA.forward_attention(
  69. self.sample.transpose(0, 1).view(self.B, 1, self.T, self.C),
  70. self.sample_scores,
  71. mask=None,
  72. )
  73. self.assertTrue(
  74. np.allclose(
  75. expected_scores.cpu().detach().numpy(),
  76. scores.cpu().detach().numpy(),
  77. atol=1e-4,
  78. )
  79. )
  80. class TestRelPositionMultiHeadedAttention(unittest.TestCase):
  81. def setUp(self) -> None:
  82. self.T = 3
  83. self.B = 1
  84. self.C = 2
  85. torch.manual_seed(0)
  86. self.sample = torch.randn(self.T, self.B, self.C) # TBC
  87. self.sample_x = torch.randn(self.B, 1, self.T, self.T * 2 - 1)
  88. self.sample_pos = torch.randn(self.B, self.T * 2 - 1, self.C)
  89. self.MHA = RelPositionMultiHeadedAttention(self.C, 1, dropout=0)
  90. def test_rel_shift(self):
  91. expected_x = torch.tensor(
  92. [
  93. [
  94. [
  95. [-0.7193, -0.4033, -0.5966],
  96. [-0.8567, 1.1006, -1.0712],
  97. [-0.5663, 0.3731, -0.8920],
  98. ]
  99. ]
  100. ]
  101. )
  102. x = self.MHA.rel_shift(self.sample_x)
  103. self.assertTrue(
  104. np.allclose(
  105. expected_x.cpu().detach().numpy(),
  106. x.cpu().detach().numpy(),
  107. atol=1e-4,
  108. )
  109. )
  110. def test_forward(self):
  111. expected_scores = torch.tensor(
  112. [
  113. [[-0.9609, -0.5020]],
  114. [[-0.9308, -0.4890]],
  115. [[-0.9473, -0.4948]],
  116. [[-0.9609, -0.5020]],
  117. [[-0.9308, -0.4890]],
  118. [[-0.9473, -0.4948]],
  119. [[-0.9609, -0.5020]],
  120. [[-0.9308, -0.4890]],
  121. [[-0.9473, -0.4948]],
  122. [[-0.9609, -0.5020]],
  123. [[-0.9308, -0.4890]],
  124. [[-0.9473, -0.4948]],
  125. [[-0.9609, -0.5020]],
  126. [[-0.9308, -0.4890]],
  127. [[-0.9473, -0.4948]],
  128. ]
  129. )
  130. scores, _ = self.MHA(self.sample, self.sample, self.sample, self.sample_pos)
  131. self.assertTrue(
  132. np.allclose(
  133. expected_scores.cpu().detach().numpy(),
  134. scores.cpu().detach().numpy(),
  135. atol=1e-4,
  136. )
  137. )
  138. class TestRotaryPositionMultiHeadedAttention(unittest.TestCase):
  139. def setUp(self) -> None:
  140. self.T = 3
  141. self.B = 1
  142. self.C = 2
  143. torch.manual_seed(0)
  144. self.sample = torch.randn(self.T, self.B, self.C) # TBC
  145. self.MHA = RotaryPositionMultiHeadedAttention(
  146. self.C, 1, dropout=0, precision=None
  147. )
  148. def test_forward(self):
  149. expected_scores = torch.tensor(
  150. [[[-0.3220, -0.4726]], [[-1.2813, -0.0979]], [[-0.3138, -0.4758]]]
  151. )
  152. scores, _ = self.MHA(self.sample, self.sample, self.sample)
  153. self.assertTrue(
  154. np.allclose(
  155. expected_scores.cpu().detach().numpy(),
  156. scores.cpu().detach().numpy(),
  157. atol=1e-4,
  158. )
  159. )
# Allow running this test module directly (e.g. `python test_espnet_multihead_attention.py`).
if __name__ == "__main__":
    unittest.main()