# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest

import torch

from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention
  8. class TestSparseMultiheadAttention(unittest.TestCase):
  9. def test_sparse_multihead_attention(self):
  10. attn_weights = torch.randn(1, 8, 8)
  11. bidirectional_sparse_mask = torch.tensor(
  12. [
  13. [0, 0, 0, 0, 0, float("-inf"), float("-inf"), 0],
  14. [0, 0, 0, 0, 0, float("-inf"), float("-inf"), 0],
  15. [0, 0, 0, 0, 0, float("-inf"), float("-inf"), 0],
  16. [0, 0, 0, 0, 0, float("-inf"), float("-inf"), 0],
  17. [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0],
  18. [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0],
  19. [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0],
  20. [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0],
  21. ]
  22. )
  23. bidirectional_attention = SparseMultiheadAttention(
  24. 16, 1, stride=4, expressivity=1, is_bidirectional=True
  25. )
  26. bidirectional_attention_sparse_mask = (
  27. bidirectional_attention.buffered_sparse_mask(attn_weights, 8, 8)
  28. )
  29. torch.all(
  30. torch.eq(bidirectional_attention_sparse_mask, bidirectional_sparse_mask)
  31. )
  32. sparse_mask = torch.tensor(
  33. [
  34. [
  35. 0,
  36. float("-inf"),
  37. float("-inf"),
  38. float("-inf"),
  39. float("-inf"),
  40. float("-inf"),
  41. float("-inf"),
  42. float("-inf"),
  43. ],
  44. [
  45. 0,
  46. 0,
  47. float("-inf"),
  48. float("-inf"),
  49. float("-inf"),
  50. float("-inf"),
  51. float("-inf"),
  52. float("-inf"),
  53. ],
  54. [
  55. 0,
  56. 0,
  57. 0,
  58. float("-inf"),
  59. float("-inf"),
  60. float("-inf"),
  61. float("-inf"),
  62. float("-inf"),
  63. ],
  64. [
  65. 0,
  66. 0,
  67. 0,
  68. 0,
  69. float("-inf"),
  70. float("-inf"),
  71. float("-inf"),
  72. float("-inf"),
  73. ],
  74. [0, 0, 0, 0, 0, float("-inf"), float("-inf"), float("-inf")],
  75. [
  76. float("-inf"),
  77. float("-inf"),
  78. float("-inf"),
  79. 0,
  80. 0,
  81. 0,
  82. float("-inf"),
  83. float("-inf"),
  84. ],
  85. [
  86. float("-inf"),
  87. float("-inf"),
  88. float("-inf"),
  89. 0,
  90. 0,
  91. 0,
  92. 0,
  93. float("-inf"),
  94. ],
  95. [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0],
  96. ]
  97. )
  98. attention = SparseMultiheadAttention(
  99. 16, 1, stride=4, expressivity=1, is_bidirectional=False
  100. )
  101. attention_sparse_mask = attention.buffered_sparse_mask(attn_weights, 8, 8)
  102. torch.all(torch.eq(attention_sparse_mask, sparse_mask))
  103. if __name__ == "__main__":
  104. unittest.main()