# test_multihead_attention.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import random
import unittest

import pytest
import torch

from fairseq.modules.multihead_attention import MultiheadAttention, _mask_for_xformers

BATCH = [20, 41, 97]
SEQ = [64]
EMB = [48]
HEADS = [4]
DROP = 0.1
DEVICE = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
ATTN_MASK_DTYPE = [None, torch.uint8, torch.bool, torch.float]
KEY_PADDING_MASK_DTYPE = [None, torch.uint8, torch.bool]
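# BATCH, SEQ, EMB, HEADS, ATTN_MASK_DTYPE and KEY_PADDING_MASK_DTYPE feed the
# @pytest.mark.parametrize grids of the tests below; DEVICE only includes
# "cuda" when a GPU is visible to torch.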


# FIXME: some tests fail when decimal=2, fix this and set decimal to 2
def assert_almost_equal(x, y, decimal=1, err_msg=""):
    import numpy.testing as npt

    if isinstance(x, torch.Tensor):
        x = x.cpu().detach().numpy()
    if isinstance(y, torch.Tensor):
        y = y.cpu().detach().numpy()
    npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)
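
# Note on assert_almost_equal above: numpy's assert_array_almost_equal(decimal=d)
# checks abs(x - y) < 1.5 * 10 ** -d elementwise, so the default decimal=1
# tolerates per-element differences of up to 0.15.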


def _reset_seeds():
    torch.manual_seed(0)
    torch.random.manual_seed(0)
    random.seed(0)
    torch.cuda.manual_seed_all(0)


def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
    if to_dtype == torch.float:
        mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
        return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
    return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)
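
# A sketch of what the assertions in test_mask_for_xformers (below) encode:
# additive float/float16 masks (-inf = masked, 0 = keep) pass through unchanged,
# while bool masks come back inverted (~m) and uint8 masks come back flipped
# (1 <-> 0) when converted for xformers. For example, the additive mask
# [-inf, 0] maps to the bool mask [False, True] and to the uint8 mask [0, 1].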


def test_mask_for_xformers():
    # Additive Mask
    m_float_add = torch.tensor([float("-inf"), 0]).to(torch.float)
    m_float_add_flipped = torch.tensor([0, float("-inf")]).to(torch.float)
    m_float16_add = torch.tensor([float("-inf"), 0]).to(torch.float16)
    m_float16_add_flipped = torch.tensor([0, float("-inf")]).to(torch.float16)
    m_uint = torch.tensor([1, 0]).to(torch.uint8)
    m_uint_flipped = torch.tensor([0, 1]).to(torch.uint8)
    m_bool = torch.tensor([False, True])

    assert torch.equal(_mask_for_xformers(m_float_add), m_float_add)
    assert torch.equal(_mask_for_xformers(m_float16_add), m_float16_add)
    assert torch.equal(_mask_for_xformers(m_uint), m_uint_flipped)
    assert torch.equal(_mask_for_xformers(m_bool), ~m_bool)

    assert torch.equal(
        _mask_for_xformers(m_float_add, to_dtype=torch.float16), m_float16_add
    )
    assert torch.equal(
        _mask_for_xformers(m_float_add, to_dtype=torch.float), m_float_add
    )
    assert torch.equal(_mask_for_xformers(m_float_add, to_dtype=torch.bool), m_bool)
    assert torch.equal(
        _mask_for_xformers(m_float_add, to_dtype=torch.uint8), m_uint_flipped
    )

    assert torch.equal(
        _mask_for_xformers(m_float16_add, to_dtype=torch.float16), m_float16_add
    )
    assert torch.equal(
        _mask_for_xformers(m_float16_add, to_dtype=torch.float), m_float_add
    )
    assert torch.equal(_mask_for_xformers(m_float16_add, to_dtype=torch.bool), m_bool)
    assert torch.equal(
        _mask_for_xformers(m_float16_add, to_dtype=torch.uint8), m_uint_flipped
    )

    assert torch.equal(
        _mask_for_xformers(m_bool, to_dtype=torch.float16), m_float16_add_flipped
    )
    assert torch.equal(
        _mask_for_xformers(m_bool, to_dtype=torch.float), m_float_add_flipped
    )
    assert torch.equal(_mask_for_xformers(m_bool, to_dtype=torch.bool), ~m_bool)
    assert torch.equal(_mask_for_xformers(m_bool, to_dtype=torch.uint8), m_uint)

    assert torch.equal(
        _mask_for_xformers(m_uint, to_dtype=torch.float16), m_float16_add
    )
    assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.float), m_float_add)
    assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.bool), m_bool)
    assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.uint8), m_uint_flipped)
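
# The next test compares the blocksparse attention path against the dense
# scaled_dot_product path. With an all-ones block layout the blocksparse kernel
# should attend everywhere, so forward outputs and input gradients are expected
# to match the dense result. It needs a GPU and is currently skipped because
# blocksparse attention is not part of the latest xformers release.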


@pytest.mark.skipif(not torch.cuda.is_available(), reason="blocksparse requires gpu")
@pytest.mark.skip(reason="not part of latest xformers")
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("add_zero_attn", [False])
@pytest.mark.parametrize("batch_size", [20])
@pytest.mark.parametrize("embedding", [64])
@pytest.mark.parametrize("seq_len", [64])
@pytest.mark.parametrize("num_heads", [4])
def test_xformers_blocksparse_parity(
    device,
    add_zero_attn,
    batch_size,
    embedding,
    seq_len,
    num_heads,
):
    xformers_att_config = '{"name": "scaled_dot_product"}'
    xformers_blocksparse_blocksize = 16
    xformers_blocksparse_layout = torch.ones(
        seq_len // xformers_blocksparse_blocksize,
        seq_len // xformers_blocksparse_blocksize,
        dtype=torch.int32,
    )

    q = torch.rand(seq_len, batch_size, embedding).to(device).half()
    q.requires_grad = True
    k = torch.rand(seq_len, batch_size, embedding).to(device).half()
    k.requires_grad = True
    v = torch.rand(seq_len, batch_size, embedding).to(device).half()
    v.requires_grad = True

    q_ = q.detach().clone().half()
    q_.requires_grad = True
    k_ = k.detach().clone().half()
    k_.requires_grad = True
    v_ = v.detach().clone().half()
    v_.requires_grad = True

    _reset_seeds()
    xf_blocksparse_mha = (
        MultiheadAttention(
            embedding,
            num_heads,
            dropout=0.0,
            add_zero_attn=add_zero_attn,
            xformers_att_config=xformers_att_config,
            xformers_blocksparse_layout=xformers_blocksparse_layout,
            xformers_blocksparse_blocksize=xformers_blocksparse_blocksize,
        )
        .to(device)
        .half()
    )
    xf_blocksparse_output, _ = xf_blocksparse_mha(
        q,
        k,
        v,
    )

    _reset_seeds()
    xformers_mha = (
        MultiheadAttention(
            embedding,
            num_heads,
            dropout=0.0,
            add_zero_attn=add_zero_attn,
            xformers_att_config=xformers_att_config,
            xformers_blocksparse_layout=None,
        )
        .to(device)
        .half()
    )
    xformers_output, _ = xformers_mha(
        q_,
        k_,
        v_,
    )

    # account for when nan != nan
    rand = random.uniform(0, 1)
    xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand)
    xf_blocksparse_output = xf_blocksparse_output.masked_fill(
        xf_blocksparse_output.isnan(), rand
    )

    assert_almost_equal(xformers_output, xf_blocksparse_output)

    loss_blocksparse = torch.norm(xf_blocksparse_output)
    loss_original = torch.norm(xformers_output)
    loss_blocksparse.backward()
    loss_original.backward()

    q.masked_fill(q.isnan(), rand)
    q_.masked_fill(q_.isnan(), rand)
    k.masked_fill(k.isnan(), rand)
    k_.masked_fill(k_.isnan(), rand)
    v.masked_fill(v.isnan(), rand)
    v_.masked_fill(v_.isnan(), rand)

    assert_almost_equal(q.grad, q_.grad)
    assert_almost_equal(k.grad, k_.grad)
    assert_almost_equal(v.grad, v_.grad)
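
# The parity test below runs identical inputs through an xformers-backed
# MultiheadAttention and the native fairseq implementation and checks that
# forward outputs and input gradients agree; dropout is disabled because the
# two implementations drop different entries for the same seed. As a hedged
# usage sketch (shapes and sizes here are illustrative, not taken from the
# tests), fairseq's MultiheadAttention expects time-first inputs:
#
#     mha = MultiheadAttention(embed_dim=16, num_heads=4, dropout=0.0)
#     x = torch.rand(10, 2, 16)  # (seq_len, batch, embed_dim)
#     out, attn_weights = mha(query=x, key=x, value=x)
#     # out has shape (10, 2, 16)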


@pytest.mark.parametrize("device", DEVICE)
@pytest.mark.parametrize("attn_dtype", ATTN_MASK_DTYPE)
@pytest.mark.parametrize("key_padding_dtype", KEY_PADDING_MASK_DTYPE)
@pytest.mark.parametrize("add_bias_kv", [True, False])
@pytest.mark.parametrize("add_zero_attn", [True, False])
# TODO: test with static_kv True
@pytest.mark.parametrize("static_kv", [False])
@pytest.mark.parametrize("batch_size", BATCH)
@pytest.mark.parametrize("embedding", EMB)
@pytest.mark.parametrize("seq_len", SEQ)
@pytest.mark.parametrize("num_heads", HEADS)
def test_xformers_single_forward_parity(
    device,
    attn_dtype,
    key_padding_dtype,
    add_bias_kv,
    add_zero_attn,
    static_kv,
    batch_size,
    embedding,
    seq_len,
    num_heads,
):
    xformers_att_config = '{"name": "scaled_dot_product"}'

    attn_mask = (
        None
        if attn_dtype is None
        else _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len).to(device)
    )
    key_padding_mask = (
        None
        if key_padding_dtype is None
        else _get_mask(to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len).to(
            device
        )
    )

    q = torch.rand(seq_len, batch_size, embedding).to(device)
    q.requires_grad = True
    k = torch.rand(seq_len, batch_size, embedding).to(device)
    k.requires_grad = True
    v = torch.rand(seq_len, batch_size, embedding).to(device)
    v.requires_grad = True

    q_ = q.detach().clone()
    q_.requires_grad = True
    k_ = k.detach().clone()
    k_.requires_grad = True
    v_ = v.detach().clone()
    v_.requires_grad = True

    # TODO: dropouts in the two implementations lead to different entries dropped.
    _reset_seeds()
    xformers_mha = MultiheadAttention(
        embedding,
        num_heads,
        dropout=0.0,
        xformers_att_config=xformers_att_config,
        add_bias_kv=add_bias_kv,
        add_zero_attn=add_zero_attn,
    ).to(device)
    xformers_output, _ = xformers_mha(
        q,
        k,
        v,
        key_padding_mask=key_padding_mask,
        attn_mask=attn_mask,
        static_kv=static_kv,
    )

    _reset_seeds()
    original_mha = MultiheadAttention(
        embedding,
        num_heads,
        dropout=0.0,
        xformers_att_config=None,
        add_bias_kv=add_bias_kv,
        add_zero_attn=add_zero_attn,
    ).to(device)
    original_output, _ = original_mha(
        q_,
        k_,
        v_,
        key_padding_mask=key_padding_mask,
        attn_mask=attn_mask,
        static_kv=static_kv,
    )

    # account for when nan != nan
    if xformers_output.isnan().any() or original_output.isnan().any():
        rand = random.uniform(0, 1)
        xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand)
        original_output = original_output.masked_fill(original_output.isnan(), rand)

    # torch.equal works for cpu, on cuda allclose is needed.
    assert torch.allclose(
        xformers_output, original_output, atol=1e-06
    ), f"max diff is {torch.max(torch.abs(xformers_output - original_output))}"

    loss_xformers = torch.norm(xformers_output)
    loss_original = torch.norm(original_output)
    loss_xformers.backward()
    loss_original.backward()

    # torch.equal works for cpu, on cuda allclose is needed.
    assert torch.allclose(
        q.grad, q_.grad
    ), f"max diff is {torch.max(torch.abs(q.grad - q_.grad))}"
    assert torch.allclose(
        k.grad, k_.grad
    ), f"max diff is {torch.max(torch.abs(k.grad - k_.grad))}"
    assert torch.allclose(
        v.grad, v_.grad
    ), f"max diff is {torch.max(torch.abs(v.grad - v_.grad))}"


def test_mask_padding_parity():
    def old_padding_code(key_padding_mask, attn_mask):
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
            )
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
                ],
                dim=1,
            )
        return key_padding_mask, attn_mask

    # values don't matter for this test.
    mha = MultiheadAttention(
        embed_dim=8,
        num_heads=2,
        dropout=0.0,
        add_bias_kv=True,
        add_zero_attn=True,
    )

    key_padding_mask = torch.rand((8, 64))
    attn_mask = torch.rand((64, 64))

    kp_mask_orig, a_mask_orig = old_padding_code(key_padding_mask, attn_mask)
    kp_mask_new, a_mask_new = mha._pad_masks(key_padding_mask, attn_mask)

    assert kp_mask_orig.size() == kp_mask_new.size()
    assert a_mask_orig.size() == a_mask_new.size()
    assert torch.equal(kp_mask_orig, kp_mask_new)
    assert torch.equal(a_mask_orig, a_mask_new)


def test_add_bias_parity():
    # values don't matter for this test.
    mha = MultiheadAttention(
        embed_dim=8,
        num_heads=2,
        dropout=0.0,
        add_bias_kv=True,
        add_zero_attn=True,
    )

    def old_bias_code(k, v, key_padding_mask, attn_mask, bsz):
        k = torch.cat([k, mha.bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, mha.bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
            )
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                ],
                dim=1,
            )
        return k, v, key_padding_mask, attn_mask

    seq_len = 64
    bsz = 8
    embedding = 8
    key_padding_mask = torch.rand((bsz, seq_len))
    attn_mask = torch.rand((seq_len, seq_len))
    k = torch.rand((seq_len, bsz, embedding))
    v = torch.rand((seq_len, bsz, embedding))

    k_orig, v_orig, kp_mask_orig, a_mask_orig = old_bias_code(
        k, v, key_padding_mask, attn_mask, bsz
    )
    k_new, v_new, kp_mask_new, a_mask_new = mha._add_bias(
        k, v, key_padding_mask, attn_mask, bsz
    )

    assert torch.equal(k_orig, k_new)
    assert torch.equal(v_orig, v_new)
    assert torch.equal(kp_mask_orig, kp_mask_new)
    assert torch.equal(a_mask_orig, a_mask_new)
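
# The unittest-style checks below cover two utilities:
# _append_prev_key_padding_mask, which merges the current step's key padding
# mask with a cached one (filling missing positions with False so the result
# always has src_len columns), and head pruning, which should keep
# num_heads_to_keep heads while head_dim stays embed_dim / (original) num_heads.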


class TestMultiheadAttention(unittest.TestCase):
    def test_append_prev_key_padding_mask(self):
        bsz = 1
        src_len = 4

        cases = [
            # no padding mask
            (None, None, None),
            # current padding mask only
            (
                torch.tensor([[1]]).bool(),
                None,
                torch.tensor([[0, 0, 0, 1]]).bool(),
            ),
            # previous padding mask only
            (
                None,
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 0]]).bool(),
            ),
            # both padding masks
            (
                torch.tensor([[1]]).bool(),
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
            # prev_key_padding_mask already full
            (
                torch.tensor([[0, 1, 0, 1]]).bool(),
                None,
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
            # key_padding_mask already full
            (
                None,
                torch.tensor([[0, 1, 0, 1]]).bool(),
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
        ]

        for c in cases:
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                c[0],
                c[1],
                batch_size=bsz,
                src_len=src_len,
                static_kv=False,
            )

            if key_padding_mask is not None:
                self.assertTrue(
                    torch.all(torch.eq(key_padding_mask, c[2])),
                    f"Unexpected resultant key padding mask: {key_padding_mask}"
                    f" given current: {c[0]} and previous: {c[1]}",
                )
                self.assertEqual(key_padding_mask.size(0), bsz)
                self.assertEqual(key_padding_mask.size(1), src_len)
            else:
                self.assertIsNone(c[2])

    def test_pruning_heads(self):
        embed_dim = 768
        num_heads = 12
        num_heads_to_keep = 8
        dummy_input = torch.randn(32, 2, embed_dim)
        mha = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
        reserve_head_index = mha._get_reserve_head_index(
            num_heads_to_keep=num_heads_to_keep
        )
        mha._adaptive_prune_heads(reserve_head_index=reserve_head_index)
        mha._set_skip_embed_dim_check()
        mha(query=dummy_input, key=dummy_input, value=dummy_input)
        self.assertEqual(mha.head_dim, embed_dim / num_heads)
        self.assertEqual(mha.num_heads, num_heads_to_keep)


if __name__ == "__main__":
    unittest.main()