123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176 |
import unittest

import numpy as np
import torch

from fairseq.modules import (
    ESPNETMultiHeadedAttention,
    RelPositionMultiHeadedAttention,
    RotaryPositionMultiHeadedAttention,
)

# Force deterministic kernels so the pinned expected tensors in the tests
# below stay valid across runs.
torch.use_deterministic_algorithms(True)
class TestESPNETMultiHeadedAttention(unittest.TestCase):
    """Regression tests for ``ESPNETMultiHeadedAttention``.

    Expected tensors are pinned outputs captured under
    ``torch.manual_seed(0)``; they are only valid for this exact seed and
    module initialization order — TODO confirm against a reference run if
    the module's parameter layout ever changes.
    """

    def setUp(self) -> None:
        self.T = 3  # time steps
        self.B = 1  # batch size
        self.C = 2  # embedding dim
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # TBC layout
        self.sample_scores = torch.randn(self.B, 1, self.T, self.T)
        self.MHA = ESPNETMultiHeadedAttention(self.C, 1, dropout=0)

    def _assert_close(self, actual, expected) -> None:
        """Element-wise tensor comparison with informative failure output.

        ``np.testing.assert_allclose`` pinpoints the mismatching elements,
        unlike ``assertTrue(np.allclose(...))`` which only reports
        "False is not true". Tolerances match the np.allclose defaults
        (rtol=1e-5) plus the original atol=1e-4, so pass/fail semantics
        are unchanged.
        """
        np.testing.assert_allclose(
            actual.cpu().detach().numpy(),
            expected.cpu().detach().numpy(),
            atol=1e-4,
            rtol=1e-5,
        )

    def test_forward(self):
        """Full forward pass matches the pinned reference scores."""
        expected_scores = torch.tensor(
            [[[0.1713, -0.3776]], [[0.2263, -0.4486]], [[0.2243, -0.4538]]]
        )
        scores, _ = self.MHA(self.sample, self.sample, self.sample)
        self._assert_close(scores, expected_scores)

    def test_forward_qkv(self):
        """Q/K/V projections match the pinned reference tensors."""
        expected_query = torch.tensor(
            [[[[-1.0235, 0.0409], [0.4008, 1.3077], [0.5396, 2.0698]]]]
        )
        expected_key = torch.tensor(
            [[[[0.5053, -0.4965], [-0.3730, -0.9473], [-0.7019, -0.1935]]]]
        )
        expected_val = torch.tensor(
            [[[[-0.9940, 0.5403], [0.5924, -0.7619], [0.7504, -1.0892]]]]
        )
        # forward_qkv expects batch-first input, so transpose from TBC to BTC.
        sample_t = self.sample.transpose(0, 1)
        query, key, val = self.MHA.forward_qkv(sample_t, sample_t, sample_t)
        self._assert_close(query, expected_query)
        self._assert_close(key, expected_key)
        self._assert_close(val, expected_val)

    def test_forward_attention(self):
        """Attention application matches the pinned reference scores."""
        expected_scores = torch.tensor(
            [[[0.1627, -0.6249], [-0.2547, -0.6487], [-0.0711, -0.8545]]]
        )
        scores = self.MHA.forward_attention(
            self.sample.transpose(0, 1).view(self.B, 1, self.T, self.C),
            self.sample_scores,
            mask=None,
        )
        self._assert_close(scores, expected_scores)
class TestRelPositionMultiHeadedAttention(unittest.TestCase):
    """Regression tests for ``RelPositionMultiHeadedAttention``.

    Expected tensors are pinned outputs captured under
    ``torch.manual_seed(0)``; they are only valid for this exact seed and
    module initialization order.
    """

    def setUp(self) -> None:
        self.T = 3  # time steps
        self.B = 1  # batch size
        self.C = 2  # embedding dim
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # TBC layout
        # Relative-position inputs span 2T-1 offsets (-(T-1) .. T-1).
        self.sample_x = torch.randn(self.B, 1, self.T, self.T * 2 - 1)
        self.sample_pos = torch.randn(self.B, self.T * 2 - 1, self.C)
        self.MHA = RelPositionMultiHeadedAttention(self.C, 1, dropout=0)

    def _assert_close(self, actual, expected) -> None:
        """Element-wise tensor comparison with informative failure output.

        Replaces ``assertTrue(np.allclose(...))`` — which only reports
        "False is not true" — while keeping the same tolerances
        (atol=1e-4, rtol=1e-5 = np.allclose default).
        """
        np.testing.assert_allclose(
            actual.cpu().detach().numpy(),
            expected.cpu().detach().numpy(),
            atol=1e-4,
            rtol=1e-5,
        )

    def test_rel_shift(self):
        """rel_shift output matches the pinned reference tensor."""
        expected_x = torch.tensor(
            [
                [
                    [
                        [-0.7193, -0.4033, -0.5966],
                        [-0.8567, 1.1006, -1.0712],
                        [-0.5663, 0.3731, -0.8920],
                    ]
                ]
            ]
        )
        x = self.MHA.rel_shift(self.sample_x)
        self._assert_close(x, expected_x)

    def test_forward(self):
        """Full forward pass matches the pinned reference scores."""
        expected_scores = torch.tensor(
            [
                [[-0.9609, -0.5020]],
                [[-0.9308, -0.4890]],
                [[-0.9473, -0.4948]],
                [[-0.9609, -0.5020]],
                [[-0.9308, -0.4890]],
                [[-0.9473, -0.4948]],
                [[-0.9609, -0.5020]],
                [[-0.9308, -0.4890]],
                [[-0.9473, -0.4948]],
                [[-0.9609, -0.5020]],
                [[-0.9308, -0.4890]],
                [[-0.9473, -0.4948]],
                [[-0.9609, -0.5020]],
                [[-0.9308, -0.4890]],
                [[-0.9473, -0.4948]],
            ]
        )
        scores, _ = self.MHA(self.sample, self.sample, self.sample, self.sample_pos)
        self._assert_close(scores, expected_scores)
class TestRotaryPositionMultiHeadedAttention(unittest.TestCase):
    """Regression tests for ``RotaryPositionMultiHeadedAttention``.

    Expected tensors are pinned outputs captured under
    ``torch.manual_seed(0)``; they are only valid for this exact seed and
    module initialization order.
    """

    def setUp(self) -> None:
        self.T = 3  # time steps
        self.B = 1  # batch size
        self.C = 2  # embedding dim
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # TBC layout
        self.MHA = RotaryPositionMultiHeadedAttention(
            self.C, 1, dropout=0, precision=None
        )

    def test_forward(self):
        """Full forward pass matches the pinned reference scores."""
        expected_scores = torch.tensor(
            [[[-0.3220, -0.4726]], [[-1.2813, -0.0979]], [[-0.3138, -0.4758]]]
        )
        scores, _ = self.MHA(self.sample, self.sample, self.sample)
        # assert_allclose pinpoints mismatching elements on failure, unlike
        # assertTrue(np.allclose(...)); tolerances match the original
        # (atol=1e-4 plus np.allclose's default rtol=1e-5).
        np.testing.assert_allclose(
            scores.cpu().detach().numpy(),
            expected_scores.numpy(),
            atol=1e-4,
            rtol=1e-5,
        )
# Allow running this test module directly (outside a pytest/unittest runner).
if __name__ == "__main__":
    unittest.main()
|