# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention
class TestSparseMultiheadAttention(unittest.TestCase):
    """Tests for the fixed sparse masks produced by SparseMultiheadAttention.

    Compares ``buffered_sparse_mask`` output against hand-computed 8x8
    expected masks for both the bidirectional and the unidirectional
    (causal) variants, with ``stride=4`` and ``expressivity=1``.
    """

    def test_sparse_multihead_attention(self):
        # Shorthand so the expected-mask literals below stay readable.
        neg_inf = float("-inf")
        # Dummy attention weights; buffered_sparse_mask only uses them to
        # pick up device/dtype, not their values.
        attn_weights = torch.randn(1, 8, 8)

        # Expected bidirectional mask: 0 where attention is allowed,
        # -inf where it is blocked.
        bidirectional_sparse_mask = torch.tensor(
            [
                [0, 0, 0, 0, 0, neg_inf, neg_inf, 0],
                [0, 0, 0, 0, 0, neg_inf, neg_inf, 0],
                [0, 0, 0, 0, 0, neg_inf, neg_inf, 0],
                [0, 0, 0, 0, 0, neg_inf, neg_inf, 0],
                [neg_inf, neg_inf, neg_inf, 0, 0, 0, 0, 0],
                [neg_inf, neg_inf, neg_inf, 0, 0, 0, 0, 0],
                [neg_inf, neg_inf, neg_inf, 0, 0, 0, 0, 0],
                [neg_inf, neg_inf, neg_inf, 0, 0, 0, 0, 0],
            ]
        )
        bidirectional_attention = SparseMultiheadAttention(
            16, 1, stride=4, expressivity=1, is_bidirectional=True
        )
        bidirectional_attention_sparse_mask = (
            bidirectional_attention.buffered_sparse_mask(attn_weights, 8, 8)
        )
        # BUG FIX: the original computed torch.all(torch.eq(...)) and
        # discarded the result, so this test could never fail. Assert it.
        self.assertTrue(
            torch.all(
                torch.eq(
                    bidirectional_attention_sparse_mask, bidirectional_sparse_mask
                )
            )
        )

        # Expected unidirectional (causal) mask: lower-triangular within the
        # first stride block, then the strided sparse pattern afterwards.
        sparse_mask = torch.tensor(
            [
                [0, neg_inf, neg_inf, neg_inf, neg_inf, neg_inf, neg_inf, neg_inf],
                [0, 0, neg_inf, neg_inf, neg_inf, neg_inf, neg_inf, neg_inf],
                [0, 0, 0, neg_inf, neg_inf, neg_inf, neg_inf, neg_inf],
                [0, 0, 0, 0, neg_inf, neg_inf, neg_inf, neg_inf],
                [0, 0, 0, 0, 0, neg_inf, neg_inf, neg_inf],
                [neg_inf, neg_inf, neg_inf, 0, 0, 0, neg_inf, neg_inf],
                [neg_inf, neg_inf, neg_inf, 0, 0, 0, 0, neg_inf],
                [neg_inf, neg_inf, neg_inf, 0, 0, 0, 0, 0],
            ]
        )
        attention = SparseMultiheadAttention(
            16, 1, stride=4, expressivity=1, is_bidirectional=False
        )
        attention_sparse_mask = attention.buffered_sparse_mask(attn_weights, 8, 8)
        # BUG FIX: same silent-pass issue as above — assert the comparison.
        self.assertTrue(torch.all(torch.eq(attention_sparse_mask, sparse_mask)))
# Allow running this test module directly: `python test_sparse_multihead_attention.py`.
if __name__ == "__main__":
    unittest.main()
|