# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import random
import unittest

import pytest
import torch

from fairseq.modules.multihead_attention import MultiheadAttention, _mask_for_xformers

BATCH = [20, 41, 97]
SEQ = [64]
EMB = [48]
HEADS = [4]
DROP = 0.1
DEVICE = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
ATTN_MASK_DTYPE = [None, torch.uint8, torch.bool, torch.float]
KEY_PADDING_MASK_DTYPE = [None, torch.uint8, torch.bool]


# FIXME: some tests fail when decimal=2, fix this and set decimal to 2
def assert_almost_equal(x, y, decimal=1, err_msg=""):
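    """Compare two tensors (or arrays) element-wise via numpy.testing, to `decimal` places."""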
    import numpy.testing as npt

    if isinstance(x, torch.Tensor):
        x = x.cpu().detach().numpy()
    if isinstance(y, torch.Tensor):
        y = y.cpu().detach().numpy()
    npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)


def _reset_seeds():
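    """Reset all relevant RNGs so both attention implementations are initialized identically."""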
    torch.manual_seed(0)
    torch.random.manual_seed(0)
    random.seed(0)
    torch.cuda.manual_seed_all(0)


def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
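    """Build a random (dim0, dim1) mask: additive (-inf at masked positions) for float,
    otherwise a random 0/1 mask of the requested dtype."""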
    if to_dtype == torch.float:
        mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
        return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
    return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)


def test_mask_for_xformers():
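    """_mask_for_xformers converts between the fairseq and xformers mask conventions.

    Judging from the assertions below: additive float masks keep -inf at masked
    positions, while bool/uint8 masks are inverted so that True / 1 means "keep"
    on the xformers side.
    """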
    # Additive mask: -inf marks a masked position, 0 keeps it.
    m_float_add = torch.tensor([float("-inf"), 0]).to(torch.float)
    m_float_add_flipped = torch.tensor([0, float("-inf")]).to(torch.float)
    m_float16_add = torch.tensor([float("-inf"), 0]).to(torch.float16)
    m_float16_add_flipped = torch.tensor([0, float("-inf")]).to(torch.float16)
    m_uint = torch.tensor([1, 0]).to(torch.uint8)
    m_uint_flipped = torch.tensor([0, 1]).to(torch.uint8)
    m_bool = torch.tensor([False, True])

    # No target dtype: additive masks pass through, bool/uint8 masks are inverted.
    assert torch.equal(_mask_for_xformers(m_float_add), m_float_add)
    assert torch.equal(_mask_for_xformers(m_float16_add), m_float16_add)
    assert torch.equal(_mask_for_xformers(m_uint), m_uint_flipped)
    assert torch.equal(_mask_for_xformers(m_bool), ~m_bool)

    # float source
    assert torch.equal(
        _mask_for_xformers(m_float_add, to_dtype=torch.float16), m_float16_add
    )
    assert torch.equal(
        _mask_for_xformers(m_float_add, to_dtype=torch.float), m_float_add
    )
    assert torch.equal(_mask_for_xformers(m_float_add, to_dtype=torch.bool), m_bool)
    assert torch.equal(
        _mask_for_xformers(m_float_add, to_dtype=torch.uint8), m_uint_flipped
    )

    # float16 source
    assert torch.equal(
        _mask_for_xformers(m_float16_add, to_dtype=torch.float16), m_float16_add
    )
    assert torch.equal(
        _mask_for_xformers(m_float16_add, to_dtype=torch.float), m_float_add
    )
    assert torch.equal(_mask_for_xformers(m_float16_add, to_dtype=torch.bool), m_bool)
    assert torch.equal(
        _mask_for_xformers(m_float16_add, to_dtype=torch.uint8), m_uint_flipped
    )

    # bool source
    assert torch.equal(
        _mask_for_xformers(m_bool, to_dtype=torch.float16), m_float16_add_flipped
    )
    assert torch.equal(
        _mask_for_xformers(m_bool, to_dtype=torch.float), m_float_add_flipped
    )
    assert torch.equal(_mask_for_xformers(m_bool, to_dtype=torch.bool), ~m_bool)
    assert torch.equal(_mask_for_xformers(m_bool, to_dtype=torch.uint8), m_uint)

    # uint8 source
    assert torch.equal(
        _mask_for_xformers(m_uint, to_dtype=torch.float16), m_float16_add
    )
    assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.float), m_float_add)
    assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.bool), m_bool)
    assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.uint8), m_uint_flipped)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="blocksparse requires gpu")
@pytest.mark.skip(reason="not part of latest xformers")
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("add_zero_attn", [False])
@pytest.mark.parametrize("batch_size", [20])
@pytest.mark.parametrize("embedding", [64])
@pytest.mark.parametrize("seq_len", [64])
@pytest.mark.parametrize("num_heads", [4])
def test_xformers_blocksparse_parity(
    device,
    add_zero_attn,
    batch_size,
    embedding,
    seq_len,
    num_heads,
):
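    """Forward/backward parity between blocksparse and dense xformers attention.

    The blocksparse layout is all ones, so both paths should compute the same
    attention. Currently skipped because blocksparse is not part of the latest
    xformers release.
    """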
    xformers_att_config = '{"name": "scaled_dot_product"}'
    xformers_blocksparse_blocksize = 16
    xformers_blocksparse_layout = torch.ones(
        seq_len // xformers_blocksparse_blocksize,
        seq_len // xformers_blocksparse_blocksize,
        dtype=torch.int32,
    )

    q = torch.rand(seq_len, batch_size, embedding).to(device).half()
    q.requires_grad = True
    k = torch.rand(seq_len, batch_size, embedding).to(device).half()
    k.requires_grad = True
    v = torch.rand(seq_len, batch_size, embedding).to(device).half()
    v.requires_grad = True

    q_ = q.detach().clone().half()
    q_.requires_grad = True
    k_ = k.detach().clone().half()
    k_.requires_grad = True
    v_ = v.detach().clone().half()
    v_.requires_grad = True

    _reset_seeds()
    xf_blocksparse_mha = (
        MultiheadAttention(
            embedding,
            num_heads,
            dropout=0.0,
            add_zero_attn=add_zero_attn,
            xformers_att_config=xformers_att_config,
            xformers_blocksparse_layout=xformers_blocksparse_layout,
            xformers_blocksparse_blocksize=xformers_blocksparse_blocksize,
        )
        .to(device)
        .half()
    )

    xf_blocksparse_output, _ = xf_blocksparse_mha(
        q,
        k,
        v,
    )

    _reset_seeds()
    xformers_mha = (
        MultiheadAttention(
            embedding,
            num_heads,
            dropout=0.0,
            add_zero_attn=add_zero_attn,
            xformers_att_config=xformers_att_config,
            xformers_blocksparse_layout=None,
        )
        .to(device)
        .half()
    )

    xformers_output, _ = xformers_mha(
        q_,
        k_,
        v_,
    )

    # Account for nan != nan: replace NaNs in both outputs with the same random
    # value before comparing.
    rand = random.uniform(0, 1)
    xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand)
    xf_blocksparse_output = xf_blocksparse_output.masked_fill(
        xf_blocksparse_output.isnan(), rand
    )

    assert_almost_equal(xformers_output, xf_blocksparse_output)

    loss_dense = torch.norm(xformers_output)
    loss_blocksparse = torch.norm(xf_blocksparse_output)
    loss_dense.backward()
    loss_blocksparse.backward()

    # Same NaN handling for the gradients before comparing them.
    for t, t_ in ((q, q_), (k, k_), (v, v_)):
        t.grad = t.grad.masked_fill(t.grad.isnan(), rand)
        t_.grad = t_.grad.masked_fill(t_.grad.isnan(), rand)

    assert_almost_equal(q.grad, q_.grad)
    assert_almost_equal(k.grad, k_.grad)
    assert_almost_equal(v.grad, v_.grad)


@pytest.mark.parametrize("device", DEVICE)
@pytest.mark.parametrize("attn_dtype", ATTN_MASK_DTYPE)
@pytest.mark.parametrize("key_padding_dtype", KEY_PADDING_MASK_DTYPE)
@pytest.mark.parametrize("add_bias_kv", [True, False])
@pytest.mark.parametrize("add_zero_attn", [True, False])
# TODO: test with static_kv True
@pytest.mark.parametrize("static_kv", [False])
@pytest.mark.parametrize("batch_size", BATCH)
@pytest.mark.parametrize("embedding", EMB)
@pytest.mark.parametrize("seq_len", SEQ)
@pytest.mark.parametrize("num_heads", HEADS)
def test_xformers_single_forward_parity(
    device,
    attn_dtype,
    key_padding_dtype,
    add_bias_kv,
    add_zero_attn,
    static_kv,
    batch_size,
    embedding,
    seq_len,
    num_heads,
):
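    """Forward/backward parity between xformers-backed and original MultiheadAttention.

    Both modules are built from the same seed with dropout disabled, run on
    cloned inputs with identical masks, and their outputs and input gradients
    are compared.
    """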
    xformers_att_config = '{"name": "scaled_dot_product"}'

    attn_mask = (
        None
        if attn_dtype is None
        else _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len).to(device)
    )
    key_padding_mask = (
        None
        if key_padding_dtype is None
        else _get_mask(to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len).to(
            device
        )
    )

    q = torch.rand(seq_len, batch_size, embedding).to(device)
    q.requires_grad = True
    k = torch.rand(seq_len, batch_size, embedding).to(device)
    k.requires_grad = True
    v = torch.rand(seq_len, batch_size, embedding).to(device)
    v.requires_grad = True

    q_ = q.detach().clone()
    q_.requires_grad = True
    k_ = k.detach().clone()
    k_.requires_grad = True
    v_ = v.detach().clone()
    v_.requires_grad = True

    # TODO: dropouts in the two implementations lead to different entries dropped.
    _reset_seeds()
    xformers_mha = MultiheadAttention(
        embedding,
        num_heads,
        dropout=0.0,
        xformers_att_config=xformers_att_config,
        add_bias_kv=add_bias_kv,
        add_zero_attn=add_zero_attn,
    ).to(device)
    xformers_output, _ = xformers_mha(
        q,
        k,
        v,
        key_padding_mask=key_padding_mask,
        attn_mask=attn_mask,
        static_kv=static_kv,
    )

    _reset_seeds()
    original_mha = MultiheadAttention(
        embedding,
        num_heads,
        dropout=0.0,
        xformers_att_config=None,
        add_bias_kv=add_bias_kv,
        add_zero_attn=add_zero_attn,
    ).to(device)
    original_output, _ = original_mha(
        q_,
        k_,
        v_,
        key_padding_mask=key_padding_mask,
        attn_mask=attn_mask,
        static_kv=static_kv,
    )

    # Account for nan != nan: replace NaNs in both outputs with the same random
    # value before comparing.
    if xformers_output.isnan().any() or original_output.isnan().any():
        rand = random.uniform(0, 1)
        xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand)
        original_output = original_output.masked_fill(original_output.isnan(), rand)

    # torch.equal works on cpu; on cuda, allclose is needed.
    assert torch.allclose(
        xformers_output, original_output, atol=1e-06
    ), f"max diff is {torch.max(torch.abs(xformers_output - original_output))}"

    loss_xformers = torch.norm(xformers_output)
    loss_original = torch.norm(original_output)
    loss_xformers.backward()
    loss_original.backward()

    # torch.equal works on cpu; on cuda, allclose is needed.
    assert torch.allclose(
        q.grad, q_.grad
    ), f"max diff is {torch.max(torch.abs(q.grad - q_.grad))}"
    assert torch.allclose(
        k.grad, k_.grad
    ), f"max diff is {torch.max(torch.abs(k.grad - k_.grad))}"
    assert torch.allclose(
        v.grad, v_.grad
    ), f"max diff is {torch.max(torch.abs(v.grad - v_.grad))}"


def test_mask_padding_parity():
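    """MultiheadAttention._pad_masks should match the inline padding code it replaced."""
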
    def old_padding_code(key_padding_mask, attn_mask):
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
            )
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
                ],
                dim=1,
            )
        return key_padding_mask, attn_mask

    # values don't matter for this test.
    mha = MultiheadAttention(
        embed_dim=8,
        num_heads=2,
        dropout=0.0,
        add_bias_kv=True,
        add_zero_attn=True,
    )

    key_padding_mask = torch.rand((8, 64))
    attn_mask = torch.rand((64, 64))

    kp_mask_orig, a_mask_orig = old_padding_code(key_padding_mask, attn_mask)
    kp_mask_new, a_mask_new = mha._pad_masks(key_padding_mask, attn_mask)

    assert kp_mask_orig.size() == kp_mask_new.size()
    assert a_mask_orig.size() == a_mask_new.size()
    assert torch.equal(kp_mask_orig, kp_mask_new)
    assert torch.equal(a_mask_orig, a_mask_new)


def test_add_bias_parity():
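    """MultiheadAttention._add_bias should match the inline bias_k/bias_v code it replaced."""
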
    # values don't matter for this test.
    mha = MultiheadAttention(
        embed_dim=8,
        num_heads=2,
        dropout=0.0,
        add_bias_kv=True,
        add_zero_attn=True,
    )

    def old_bias_code(k, v, key_padding_mask, attn_mask, bsz):
        k = torch.cat([k, mha.bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, mha.bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
            )
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                ],
                dim=1,
            )
        return k, v, key_padding_mask, attn_mask

    seq_len = 64
    bsz = 8
    embedding = 8
    key_padding_mask = torch.rand((bsz, seq_len))
    attn_mask = torch.rand((seq_len, seq_len))
    k = torch.rand((seq_len, bsz, embedding))
    v = torch.rand((seq_len, bsz, embedding))

    k_orig, v_orig, kp_mask_orig, a_mask_orig = old_bias_code(
        k, v, key_padding_mask, attn_mask, bsz
    )
    k_new, v_new, kp_mask_new, a_mask_new = mha._add_bias(
        k, v, key_padding_mask, attn_mask, bsz
    )

    assert torch.equal(k_orig, k_new)
    assert torch.equal(v_orig, v_new)
    assert torch.equal(kp_mask_orig, kp_mask_new)
    assert torch.equal(a_mask_orig, a_mask_new)


class TestMultiheadAttention(unittest.TestCase):
    def test_append_prev_key_padding_mask(self):
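        """_append_prev_key_padding_mask combines the current and cached key padding masks.

        Each case is (current mask, previous mask, expected combined mask) for a
        batch of 1 and a source length of 4; missing spans are padded with zeros.
        """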
        bsz = 1
        src_len = 4

        cases = [
            # no padding mask
            (None, None, None),
            # current padding mask only
            (
                torch.tensor([[1]]).bool(),
                None,
                torch.tensor([[0, 0, 0, 1]]).bool(),
            ),
            # previous padding mask only
            (
                None,
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 0]]).bool(),
            ),
            # both padding masks
            (
                torch.tensor([[1]]).bool(),
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
            # current key_padding_mask already full
            (
                torch.tensor([[0, 1, 0, 1]]).bool(),
                None,
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
            # prev_key_padding_mask already full
            (
                None,
                torch.tensor([[0, 1, 0, 1]]).bool(),
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
        ]

        for c in cases:
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                c[0],
                c[1],
                batch_size=bsz,
                src_len=src_len,
                static_kv=False,
            )

            if key_padding_mask is not None:
                self.assertTrue(
                    torch.all(torch.eq(key_padding_mask, c[2])),
                    f"Unexpected resultant key padding mask: {key_padding_mask}"
                    f" given current: {c[0]} and previous: {c[1]}",
                )
                self.assertEqual(key_padding_mask.size(0), bsz)
                self.assertEqual(key_padding_mask.size(1), src_len)
            else:
                self.assertIsNone(c[2])

    def test_pruning_heads(self):
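        """Pruning to 8 of 12 heads shrinks num_heads while keeping the per-head dimension."""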
        embed_dim = 768
        num_heads = 12
        num_heads_to_keep = 8
        dummy_input = torch.randn(32, 2, embed_dim)
        mha = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
        reserve_head_index = mha._get_reserve_head_index(
            num_heads_to_keep=num_heads_to_keep
        )
        mha._adaptive_prune_heads(reserve_head_index=reserve_head_index)
        mha._set_skip_embed_dim_check()
        mha(query=dummy_input, key=dummy_input, value=dummy_input)

        self.assertEqual(mha.head_dim, embed_dim / num_heads)
        self.assertEqual(mha.num_heads, num_heads_to_keep)


if __name__ == "__main__":
    unittest.main()