test_roberta.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import unittest
from typing import Any, Dict, Optional, Sequence

import fairseq
import fairseq.options
import fairseq.tasks
import torch

from tests.utils import dummy_dictionary

VOCAB_SIZE = 100
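

# FakeTask provides just enough of the task API (a shared source/target
# dictionary) for fairseq to build a model without real data.
# dummy_dictionary() adds 4 special symbols, hence VOCAB_SIZE - 4 below.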
@fairseq.tasks.register_task("fake_task")
class FakeTask(fairseq.tasks.LegacyFairseqTask):
    def __init__(self, args):
        super().__init__(args)
        self.dictionary = dummy_dictionary(VOCAB_SIZE - 4)
        assert len(self.dictionary) == VOCAB_SIZE

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary
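

# Build (and memoize) a tiny model for the fake task; functools.lru_cache
# lets the tests share one model instance per configuration.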
@functools.lru_cache()
def get_toy_model(
    device: str,
    architecture: str = "roberta_enc_dec",
    **extra_args: Any,
):
    assert device in ("gpu", "cpu")
    kwargs = {
        "arch": architecture,
        # Use small, distinct dimensions so shape mismatches are easy to catch.
        "encoder_layers": 3,
        "encoder_embed_dim": 12,
        "encoder_ffn_embed_dim": 14,
        "encoder_attention_heads": 4,
        "decoder_layers": 3,
        "decoder_embed_dim": 12,
        "decoder_ffn_embed_dim": 14,
        "decoder_attention_heads": 4,
        # Disable dropout so we have comparable tests.
        "dropout": 0,
        "attention_dropout": 0,
        "activation_dropout": 0,
        "encoder_layerdrop": 0,
        # required args
        "tokens_per_sample": 256,
        "data": "/tmp/test_roberta",
    }
    kwargs.update(extra_args)
    fake_task = FakeTask(kwargs)
    args = fairseq.options.get_args(
        task="online_backtranslation",
        mono_langs="en,ro",
        valid_lang_pairs="en-ro",
        **kwargs,
    )
    torch.manual_seed(0)
    model = fake_task.build_model(args)
    if device == "gpu":
        model.cuda()
    return fake_task, model
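

# Build a toy batch for `lang`: `batch_size` identical rows of hand-picked
# token ids, terminated by 2 (the default EOS index in fairseq dictionaries).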
def mk_sample(
    lang: str, device: str, tok: Optional[Sequence[int]] = None, batch_size: int = 2
) -> Dict[str, Any]:
    assert device in ("gpu", "cpu")
    if not tok:
        if lang == "en":
            tok = [10, 11, 12, 13, 14, 15, 2]
        else:
            tok = [20, 21, 22, 23, 24, 25, 26, 27, 2]

    batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size)
    if device == "gpu":
        batch = batch.cuda()
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor(
                [len(tok)] * batch_size, dtype=torch.long, device=batch.device
            ),
        },
        "target": batch[:, 1:],
    }
    return sample
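

# Test decorators: cpu_gpu runs a test on CPU and, when CUDA is available,
# again on GPU; architectures repeats a test for each listed architecture.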
def cpu_gpu(fn):
    def helper(self):
        fn(self, "cpu")
        if torch.cuda.is_available():
            fn(self, "gpu")

    return helper


def architectures(fn):
    def helper(self):
        for arch in ["roberta_enc_dec", "transformer"]:
            fn(self, arch)

    return helper
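

# The tests below cover parameter sharing, max-position wiring, batching
# invariance, incremental decoding, and adaprune regularization/pruning.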
class RobertaTest(unittest.TestCase):
    def assertTensorEqual(self, t1, t2, delta: float = 1e-6):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        if delta == 0.0:
            self.assertEqual(t1.ne(t2).long().sum(), 0)
        else:
            self.assertEqual(((t2 - t1).abs() > delta).long().sum(), 0)

    def assertSharing(self, model, link_groups: Sequence[Sequence[str]]):
        # Parameters within a group must be the same tensor object;
        # parameters across groups must be distinct.
        ids = {}
        for group in link_groups:
            group_ids = {name: id(params(model, name)) for name in group}
            shared_id = group_ids[group[0]]
            self.assertEqual(group_ids, {name: shared_id for name in group})
            self.assertNotIn(shared_id, ids)
            ids[shared_id] = group

    def test_roberta_shared_params(self):
        _, roberta = get_toy_model("cpu", architecture="roberta")
        self.assertSharing(
            roberta,
            [
                [
                    "encoder.sentence_encoder.embed_tokens.weight",
                    "encoder.lm_head.weight",
                ]
            ],
        )
        _, roberta = get_toy_model(
            "cpu", architecture="roberta", untie_weights_roberta=True
        )
        self.assertSharing(
            roberta,
            [
                ["encoder.sentence_encoder.embed_tokens.weight"],
                ["encoder.lm_head.weight"],
            ],
        )

    def test_roberta_enc_dec_shared_params(self):
        # 3 distinct embeddings
        _, enc_dec = get_toy_model("cpu", architecture="roberta_enc_dec")
        self.assertSharing(
            enc_dec,
            [
                ["encoder.embed_tokens.weight"],
                ["decoder.embed_tokens.weight"],
                ["decoder.output_projection.weight"],
            ],
        )
        # 2 distinct embeddings, one for encoder, one for decoder
        _, enc_dec = get_toy_model(
            "cpu", architecture="roberta_enc_dec", share_decoder_input_output_embed=True
        )
        self.assertSharing(
            enc_dec,
            [
                ["encoder.embed_tokens.weight"],
                [
                    "decoder.embed_tokens.weight",
                    "decoder.output_projection.weight",
                ],
            ],
        )
        # shared embeddings
        _, enc_dec = get_toy_model(
            "cpu", architecture="roberta_enc_dec", share_all_embeddings=True
        )
        self.assertSharing(
            enc_dec,
            [
                [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                    "decoder.output_projection.weight",
                ]
            ],
        )

    def test_roberta_max_positions_is_correctly_set(self):
        device = "cpu"
        task, model = get_toy_model(device)
        max_pos = model.max_decoder_positions()

        self.assertEqual(max_pos, 256)
        self.assertEqual(max_pos, model.decoder.max_positions())
        self.assertEqual(max_pos, model.encoder.max_positions())
        self.assertEqual(max_pos, model.encoder.embed_positions.max_positions)

        sentence = [31 for _ in range(max_pos)]
        sample = mk_sample("en", device, sentence, batch_size=1)
        self.assertEqual(list(sample["net_input"]["src_lengths"]), [max_pos])
        self.assertEqual(len(sample["net_input"]["src_tokens"][0]), max_pos)
        x, _ = model.forward(**sample["net_input"])
        self.assertEqual(x.shape, (1, max_pos, VOCAB_SIZE))

    @cpu_gpu
    def test_roberta_forward_backward(self, device: str):
        _, model = get_toy_model(device)
        sample = mk_sample("en", device)
        en_tokens = sample["net_input"]["src_tokens"]
        (bs, l) = en_tokens.shape
        # Forward
        logits, _ = model(**sample["net_input"])
        self.assertEqual(logits.shape, (bs, l, VOCAB_SIZE))
        # Backward
        loss = logits.sum()
        loss.backward()

    @cpu_gpu
    def test_roberta_forward_backward_bs1(self, device: str):
        _, model = get_toy_model(device)
        sample = mk_sample("en", device, batch_size=1)
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        sample2 = mk_sample("ro", device, batch_size=1)
        o, _ = model.forward(**sample2["net_input"])
        loss += o.sum()
        loss.backward()

    @cpu_gpu
    def test_roberta_batching(self, device: str):
        """
        Checks that a batch of size 2 yields the same results, twice over, as a batch of size 1.
        """
        _, model = get_toy_model(device)
        sample = mk_sample("en", device, batch_size=1)
        slen = sample["net_input"]["src_lengths"][0]
        sample2 = mk_sample("en", device, batch_size=2)
        with torch.no_grad():
            z = model.encoder.forward(
                sample["net_input"]["src_tokens"], sample["net_input"]["src_lengths"]
            )
            z = z["encoder_out"][-1]
            logits, _ = model.forward(**sample["net_input"])

            z2 = model.encoder.forward(
                sample2["net_input"]["src_tokens"], sample2["net_input"]["src_lengths"]
            )
            z2 = z2["encoder_out"][-1]
            logits2, _ = model.forward(**sample2["net_input"])

        self.assertEqual(z.shape, (slen, 1, 12))
        self.assertEqual(z2.shape, (slen, 2, 12))
        self.assertTensorEqual(logits2[0], logits2[1])
        self.assertTensorEqual(logits[0], logits2[0])

    @cpu_gpu
    def test_roberta_incremental_decoder(self, device: str):
        """
        Checks that incremental decoding yields the same results as non-incremental decoding.
        """
        task, model = get_toy_model(device)

        en_sample = mk_sample("en", device)
        en_tokens = en_sample["net_input"]["src_tokens"]
        ro_sample = mk_sample("ro", device)
        ro_tokens = ro_sample["net_input"]["src_tokens"]

        en_enc = model.encoder.forward(
            en_tokens, src_lengths=en_sample["net_input"]["src_lengths"]
        )
        (bs, tgt_len) = ro_tokens.shape

        # Decode without incremental state
        ro_dec, _ = model.decoder.forward(ro_tokens, encoder_out=en_enc)
        self.assertEqual(ro_dec.shape, (bs, tgt_len, VOCAB_SIZE))
        self.assertTensorEqual(ro_dec[0], ro_dec[1])

        # Decode with incremental state
        inc_state = {}
        ro_dec_inc = []
        for i in range(tgt_len):
            ro, _ = model.decoder.forward(
                ro_tokens[:, : i + 1], encoder_out=en_enc, incremental_state=inc_state
            )
            self.assertEqual(ro.shape, (bs, 1, VOCAB_SIZE))
            ro_dec_inc.append(ro)

        for i in range(tgt_len):
            # Intra-batch
            self.assertTensorEqual(ro_dec_inc[i][0], ro_dec_inc[i][1])
            # Incremental vs non-incremental
            self.assertTensorEqual(ro_dec_inc[i][:, 0], ro_dec[:, i])

    @cpu_gpu
    def test_regularize_for_adaprune_in_roberta(self, device: str):
        _, model = get_toy_model(
            device=device,
            architecture="roberta_base",
            mha_reg_scale_factor=0.000375,
            ffn_reg_scale_factor=0.000375,
        )
        sample = mk_sample("en", device, batch_size=1)
        task_loss, _ = model.forward(**sample["net_input"])
        head_loss = model._get_adaptive_head_loss()
        ffn_loss = model._get_adaptive_ffn_loss()
        loss = task_loss.sum() + head_loss + ffn_loss
        loss.backward()

    @cpu_gpu
    def test_ffn_prune_for_adaprune_in_roberta(self, device: str):
        _, model = get_toy_model(
            device=device,
            architecture="roberta_base",
        )
        sample = mk_sample("en", device, batch_size=1)
        for layer in model.encoder.sentence_encoder.layers:
            fc1_original_size = layer.fc1.out_features
            remove_index = layer._get_fc_rank(remove_num=2)
            layer._prune_fc_layer(remove_index=remove_index)
            self.assertEqual(layer.fc1.out_features, fc1_original_size - 2)
        task_loss, _ = model.forward(**sample["net_input"])
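

# Resolve a dotted parameter path (e.g. "encoder.lm_head.weight") by walking
# the module hierarchy recursively; used by assertSharing above.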
def params(model, name):
    if "." not in name:
        return getattr(model, name)

    prefix, name = name.split(".", 1)
    return params(getattr(model, prefix), name)