# test_rotary_positional_embedding.py (3.1 KB)

import unittest

import numpy as np
import torch

from fairseq.modules import RotaryPositionalEmbedding
from fairseq.modules.rotary_positional_embedding import apply_rotary_pos_emb


  6. class TestRotaryPositionalEmbedding(unittest.TestCase):
  7. def setUp(self) -> None:
  8. self.T = 3
  9. self.B = 1
  10. self.C = 2
  11. torch.manual_seed(0)
  12. self.sample = torch.randn(self.T, self.B, self.C) # TBC
  13. self.rope_pos_emd = RotaryPositionalEmbedding(dim=self.C)
  14. def test_forward(self):
  15. expected_cos = torch.tensor(
  16. [[[[1.0000, 1.0000]]], [[[0.5403, 0.5403]]], [[[-0.4161, -0.4161]]]]
  17. )
  18. expected_sin = torch.tensor(
  19. [[[[0.0000, 0.0000]]], [[[0.8415, 0.8415]]], [[[0.9093, 0.9093]]]]
  20. )
  21. cos, sin = self.rope_pos_emd(self.sample, self.T)
  22. self.assertTrue(
  23. np.allclose(
  24. expected_cos.cpu().detach().numpy(),
  25. cos.cpu().detach().numpy(),
  26. atol=1e-4,
  27. )
  28. )
  29. self.assertTrue(
  30. np.allclose(
  31. expected_sin.cpu().detach().numpy(),
  32. sin.cpu().detach().numpy(),
  33. atol=1e-4,
  34. )
  35. )
  36. def test_apply_rotary_pos_emb(self):
  37. cos, sin = self.rope_pos_emd(self.sample, self.T)
  38. query = self.sample.view(self.T, self.B, 1, self.C)
  39. expected_query = torch.tensor(
  40. [[[[1.5410, -0.2934]]], [[[-1.6555, -1.5263]]], [[[1.7231, -0.4041]]]]
  41. )
  42. new_query, new_key = apply_rotary_pos_emb(query, query, cos, sin)
  43. self.assertTrue(
  44. np.allclose(
  45. expected_query.cpu().detach().numpy(),
  46. new_query.cpu().detach().numpy(),
  47. atol=1e-4,
  48. )
  49. )
  50. self.assertTrue(
  51. np.allclose(
  52. expected_query.cpu().detach().numpy(),
  53. new_key.cpu().detach().numpy(),
  54. atol=1e-4,
  55. )
  56. )
  57. def test_jit_compile_rope_module(self):
  58. module_scripted = torch.jit.script(self.rope_pos_emd)
  59. apply_rotary_scripted = torch.jit.script(apply_rotary_pos_emb)
  60. # Test several different lengths
  61. for T in [3, 5, 10]:
  62. sample = torch.randn(T, self.B, self.C)
  63. # Run forward pass with the original module
  64. cos_original, sin_original = self.rope_pos_emd(sample, T)
  65. query = sample.view(T, self.B, 1, self.C)
  66. new_query, new_key = apply_rotary_pos_emb(query, query, cos_original, sin_original)
  67. # Run forward pass with the scripted module
  68. cos_scripted, sin_scripted = module_scripted(sample, T)
  69. new_query_scripted, new_key_scripted = apply_rotary_scripted(query, query, cos_scripted, sin_scripted)
  70. # Ensure the outputs are the same
  71. self.assertTrue(torch.allclose(cos_original, cos_scripted))
  72. self.assertTrue(torch.allclose(sin_original, sin_scripted))
  73. self.assertTrue(torch.allclose(new_query, new_query_scripted))
  74. self.assertTrue(torch.allclose(new_key, new_key_scripted))
  75. if __name__ == "__main__":
  76. unittest.main()