
Tutorial: Simple LSTM
=====================

In this tutorial we will extend fairseq by adding a new
:class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source
sentence with an LSTM and then passes the final hidden state to a second LSTM
that decodes the target sentence (without attention).

This tutorial covers:

1. **Writing an Encoder and Decoder** to encode/decode the source/target
   sentence, respectively.
2. **Registering a new Model** so that it can be used with the existing
   :ref:`Command-line tools`.
3. **Training the Model** using the existing command-line tools.
4. **Making generation faster** by modifying the Decoder to use
   :ref:`Incremental decoding`.


1. Building an Encoder and Decoder
----------------------------------

In this section we'll define a simple LSTM Encoder and Decoder. All Encoders
should implement the :class:`~fairseq.models.FairseqEncoder` interface and
Decoders should implement the :class:`~fairseq.models.FairseqDecoder` interface.
These interfaces themselves extend :class:`torch.nn.Module`, so FairseqEncoders
and FairseqDecoders can be written and used in the same ways as ordinary PyTorch
Modules.


Encoder
~~~~~~~

Our Encoder will embed the tokens in the source sentence, feed them to a
:class:`torch.nn.LSTM` and return the final hidden state. To create our encoder
save the following in a new file named :file:`fairseq/models/simple_lstm.py`::

    import torch.nn as nn
    from fairseq import utils
    from fairseq.models import FairseqEncoder

    class SimpleLSTMEncoder(FairseqEncoder):

        def __init__(
            self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
        ):
            super().__init__(dictionary)
            self.args = args

            # Our encoder will embed the inputs before feeding them to the LSTM.
            self.embed_tokens = nn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=embed_dim,
                padding_idx=dictionary.pad(),
            )
            self.dropout = nn.Dropout(p=dropout)

            # We'll use a single-layer, unidirectional LSTM for simplicity.
            self.lstm = nn.LSTM(
                input_size=embed_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=False,
                batch_first=True,
            )

        def forward(self, src_tokens, src_lengths):
            # The inputs to the ``forward()`` function are determined by the
            # Task, and in particular the ``'net_input'`` key in each
            # mini-batch. We discuss Tasks in the next tutorial, but for now just
            # know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*
            # has shape `(batch)`.

            # Note that the source is typically padded on the left. This can be
            # configured by adding the `--left-pad-source "False"` command-line
            # argument, but here we'll make the Encoder handle either kind of
            # padding by converting everything to be right-padded.
            if self.args.left_pad_source:
                # Convert left-padding to right-padding.
                src_tokens = utils.convert_padding_direction(
                    src_tokens,
                    padding_idx=self.dictionary.pad(),
                    left_to_right=True
                )

            # Embed the source.
            x = self.embed_tokens(src_tokens)

            # Apply dropout.
            x = self.dropout(x)

            # Pack the sequence into a PackedSequence object to feed to the LSTM.
            x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True)

            # Get the output from the LSTM.
            _outputs, (final_hidden, _final_cell) = self.lstm(x)

            # Return the Encoder's output. This can be any object and will be
            # passed directly to the Decoder.
            return {
                # this will have shape `(bsz, hidden_dim)`
                'final_hidden': final_hidden.squeeze(0),
            }

        # Encoders are required to implement this method so that we can rearrange
        # the order of the batch elements during inference (e.g., beam search).
        def reorder_encoder_out(self, encoder_out, new_order):
            """
            Reorder encoder output according to `new_order`.

            Args:
                encoder_out: output from the ``forward()`` method
                new_order (LongTensor): desired order

            Returns:
                `encoder_out` rearranged according to `new_order`
            """
            final_hidden = encoder_out['final_hidden']
            return {
                'final_hidden': final_hidden.index_select(0, new_order),
            }
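
Before moving on, it can be handy to sanity-check the Encoder on a couple of toy
sentences. The following is a minimal sketch and is not part of
:file:`simple_lstm.py`; the toy :class:`~fairseq.data.Dictionary` and the bare
``argparse.Namespace`` standing in for the parsed command-line arguments are
assumptions made just for this example::

    import torch
    from argparse import Namespace
    from fairseq.data import Dictionary

    # Build a toy dictionary; the symbols here are placeholders.
    dictionary = Dictionary()
    for word in ['hallo', 'welt']:
        dictionary.add_symbol(word)

    # Stand-in for the parsed command-line args; only the attribute our
    # Encoder actually reads is provided.
    args = Namespace(left_pad_source=False)

    encoder = SimpleLSTMEncoder(args, dictionary, embed_dim=16, hidden_dim=16)

    # Two right-padded source sentences of lengths 3 and 2 (descending, as
    # required by ``pack_padded_sequence()``).
    src_tokens = torch.LongTensor([
        [dictionary.index('hallo'), dictionary.index('welt'), dictionary.eos()],
        [dictionary.index('hallo'), dictionary.eos(), dictionary.pad()],
    ])
    src_lengths = torch.LongTensor([3, 2])

    encoder_out = encoder(src_tokens, src_lengths)
    print(encoder_out['final_hidden'].shape)  # torch.Size([2, 16])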

Decoder
~~~~~~~

Our Decoder will predict the next word, conditioned on the Encoder's final
hidden state and an embedded representation of the previous target word -- which
is sometimes called *teacher forcing*. More specifically, we'll use a
:class:`torch.nn.LSTM` to produce a sequence of hidden states that we'll project
to the size of the output vocabulary to predict each target word.

::

    import torch
    from fairseq.models import FairseqDecoder

    class SimpleLSTMDecoder(FairseqDecoder):

        def __init__(
            self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
            dropout=0.1,
        ):
            super().__init__(dictionary)

            # Our decoder will embed the inputs before feeding them to the LSTM.
            self.embed_tokens = nn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=embed_dim,
                padding_idx=dictionary.pad(),
            )
            self.dropout = nn.Dropout(p=dropout)

            # We'll use a single-layer, unidirectional LSTM for simplicity.
            self.lstm = nn.LSTM(
                # For the first layer we'll concatenate the Encoder's final hidden
                # state with the embedded target tokens.
                input_size=encoder_hidden_dim + embed_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=False,
            )

            # Define the output projection.
            self.output_projection = nn.Linear(hidden_dim, len(dictionary))

        # During training Decoders are expected to take the entire target sequence
        # (shifted right by one position) and produce logits over the vocabulary.
        # The *prev_output_tokens* tensor begins with the end-of-sentence symbol,
        # ``dictionary.eos()``, followed by the target sequence.
        def forward(self, prev_output_tokens, encoder_out):
            """
            Args:
                prev_output_tokens (LongTensor): previous decoder outputs of shape
                    `(batch, tgt_len)`, for teacher forcing
                encoder_out (Tensor, optional): output from the encoder, used for
                    encoder-side attention

            Returns:
                tuple:
                    - the last decoder layer's output of shape
                      `(batch, tgt_len, vocab)`
                    - the last decoder layer's attention weights of shape
                      `(batch, tgt_len, src_len)`
            """
            bsz, tgt_len = prev_output_tokens.size()

            # Extract the final hidden state from the Encoder.
            final_encoder_hidden = encoder_out['final_hidden']

            # Embed the target sequence, which has been shifted right by one
            # position and now starts with the end-of-sentence symbol.
            x = self.embed_tokens(prev_output_tokens)

            # Apply dropout.
            x = self.dropout(x)

            # Concatenate the Encoder's final hidden state to *every* embedded
            # target token.
            x = torch.cat(
                [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
                dim=2,
            )

            # Using PackedSequence objects in the Decoder is harder than in the
            # Encoder, since the targets are not sorted in descending length order,
            # which is a requirement of ``pack_padded_sequence()``. Instead we'll
            # feed nn.LSTM directly.
            initial_state = (
                final_encoder_hidden.unsqueeze(0),  # hidden
                torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
            )
            output, _ = self.lstm(
                x.transpose(0, 1),  # convert to shape `(tgt_len, bsz, dim)`
                initial_state,
            )
            x = output.transpose(0, 1)  # convert to shape `(bsz, tgt_len, hidden)`

            # Project the outputs to the size of the vocabulary.
            x = self.output_projection(x)

            # Return the logits and ``None`` for the attention weights
            return x, None
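
Continuing the sanity-check sketch from the Encoder section (again, not part of
the tutorial file), we can drive the Decoder with a hand-built
``prev_output_tokens`` tensor that begins with ``dictionary.eos()``. Reusing the
same toy dictionary for the target side is purely a convenience for this
example::

    # For simplicity we reuse the toy dictionary from the Encoder sketch and
    # add a couple of "target" words to it.
    for word in ['hello', 'world']:
        dictionary.add_symbol(word)

    decoder = SimpleLSTMDecoder(
        dictionary, encoder_hidden_dim=16, embed_dim=16, hidden_dim=16,
    )

    # Teacher forcing: the previous-output tokens start with the
    # end-of-sentence symbol, followed by the target tokens.
    prev_output_tokens = torch.LongTensor([
        [dictionary.eos(), dictionary.index('hello'), dictionary.index('world')],
        [dictionary.eos(), dictionary.index('hello'), dictionary.pad()],
    ])

    # `encoder_out` comes from the Encoder sketch above.
    logits, _ = decoder(prev_output_tokens, encoder_out)
    print(logits.shape)  # torch.Size([2, 3, len(dictionary)])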

2. Registering the Model
------------------------

Now that we've defined our Encoder and Decoder we must *register* our model with
fairseq using the :func:`~fairseq.models.register_model` function decorator.
Once the model is registered we'll be able to use it with the existing
:ref:`Command-line Tools`.

All registered models must implement the
:class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
models (i.e., any model with a single Encoder and Decoder), we can instead
implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface.

Create a small wrapper class in the same file and register it in fairseq with
the name ``'simple_lstm'``::

    from fairseq.models import FairseqEncoderDecoderModel, register_model

    # Note: the register_model "decorator" should immediately precede the
    # definition of the Model class.

    @register_model('simple_lstm')
    class SimpleLSTMModel(FairseqEncoderDecoderModel):

        @staticmethod
        def add_args(parser):
            # Models can override this method to add new command-line arguments.
            # Here we'll add some new command-line arguments to configure dropout
            # and the dimensionality of the embeddings and hidden states.
            parser.add_argument(
                '--encoder-embed-dim', type=int, metavar='N',
                help='dimensionality of the encoder embeddings',
            )
            parser.add_argument(
                '--encoder-hidden-dim', type=int, metavar='N',
                help='dimensionality of the encoder hidden state',
            )
            parser.add_argument(
                '--encoder-dropout', type=float, default=0.1,
                help='encoder dropout probability',
            )
            parser.add_argument(
                '--decoder-embed-dim', type=int, metavar='N',
                help='dimensionality of the decoder embeddings',
            )
            parser.add_argument(
                '--decoder-hidden-dim', type=int, metavar='N',
                help='dimensionality of the decoder hidden state',
            )
            parser.add_argument(
                '--decoder-dropout', type=float, default=0.1,
                help='decoder dropout probability',
            )

        @classmethod
        def build_model(cls, args, task):
            # Fairseq initializes models by calling the ``build_model()``
            # function. This provides more flexibility, since the returned model
            # instance can be of a different type than the one that was called.
            # In this case we'll just return a SimpleLSTMModel instance.

            # Initialize our Encoder and Decoder.
            encoder = SimpleLSTMEncoder(
                args=args,
                dictionary=task.source_dictionary,
                embed_dim=args.encoder_embed_dim,
                hidden_dim=args.encoder_hidden_dim,
                dropout=args.encoder_dropout,
            )
            decoder = SimpleLSTMDecoder(
                dictionary=task.target_dictionary,
                encoder_hidden_dim=args.encoder_hidden_dim,
                embed_dim=args.decoder_embed_dim,
                hidden_dim=args.decoder_hidden_dim,
                dropout=args.decoder_dropout,
            )
            model = SimpleLSTMModel(encoder, decoder)

            # Print the model architecture.
            print(model)

            return model

        # We could override the ``forward()`` if we wanted more control over how
        # the encoder and decoder interact, but it's not necessary for this
        # tutorial since we can inherit the default implementation provided by
        # the FairseqEncoderDecoderModel base class, which looks like:
        #
        # def forward(self, src_tokens, src_lengths, prev_output_tokens):
        #     encoder_out = self.encoder(src_tokens, src_lengths)
        #     decoder_out = self.decoder(prev_output_tokens, encoder_out)
        #     return decoder_out

Finally let's define a *named architecture* with the configuration for our
model. This is done with the :func:`~fairseq.models.register_model_architecture`
function decorator. Thereafter this named architecture can be used with the
``--arch`` command-line argument, e.g., ``--arch tutorial_simple_lstm``::

    from fairseq.models import register_model_architecture

    # The first argument to ``register_model_architecture()`` should be the name
    # of the model we registered above (i.e., 'simple_lstm'). The function we
    # register here should take a single argument *args* and modify it in-place
    # to match the desired architecture.

    @register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
    def tutorial_simple_lstm(args):
        # We use ``getattr()`` to prioritize arguments that are explicitly given
        # on the command-line, so that the defaults defined below are only used
        # when no other value has been specified.
        args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
        args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
        args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
        args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)
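
To see why the ``getattr()`` pattern lets explicitly given values win over the
architecture defaults, here is a tiny standalone simulation. The bare
``Namespace`` below is only a stand-in for the parsed arguments, where an option
that was not specified is simply a missing attribute::

    from argparse import Namespace

    def tutorial_simple_lstm(args):
        # Architecture defaults only fill in attributes that are missing.
        args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
        args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)

    # Simulate an explicit encoder embedding dim, with the hidden dim left
    # unspecified.
    args = Namespace(encoder_embed_dim=512)
    tutorial_simple_lstm(args)
    print(args.encoder_embed_dim)   # 512 (the explicit value wins)
    print(args.encoder_hidden_dim)  # 256 (the architecture default is used)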

3. Training the Model
---------------------

Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
command-line tool for this, making sure to specify our new Model architecture
(``--arch tutorial_simple_lstm``).

.. note::

    Make sure you've already preprocessed the data from the IWSLT example in the
    :file:`examples/translation/` directory.

.. code-block:: console

    > fairseq-train data-bin/iwslt14.tokenized.de-en \
      --arch tutorial_simple_lstm \
      --encoder-dropout 0.2 --decoder-dropout 0.2 \
      --optimizer adam --lr 0.005 --lr-shrink 0.5 \
      --max-tokens 12000
    (...)
    | epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
    | epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954

The model files should appear in the :file:`checkpoints/` directory. While this
model architecture is not very good, we can use the :ref:`fairseq-generate` script to
generate translations and compute our BLEU score over the test set:

.. code-block:: console

    > fairseq-generate data-bin/iwslt14.tokenized.de-en \
      --path checkpoints/checkpoint_best.pt \
      --beam 5 \
      --remove-bpe
    (...)
    | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
    | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

4. Making generation faster
---------------------------

While autoregressive generation from sequence-to-sequence models is inherently
slow, our implementation above is especially slow because it recomputes the
entire sequence of Decoder hidden states for every output token (i.e., it is
``O(n^2)``). We can make this significantly faster by instead caching the
previous hidden states.
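
To make the difference concrete, here is a small illustration on a bare
:class:`torch.nn.LSTM`. This is not fairseq's generation code, just a sketch of
the two strategies::

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=8, hidden_size=8)
    embeddings = torch.randn(10, 1, 8)  # a 10-step sequence, batch of 1

    # Naive: at step t, re-run the LSTM over the entire prefix (O(n^2) overall).
    for t in range(1, embeddings.size(0) + 1):
        outputs, _ = lstm(embeddings[:t])
        features_for_next_token = outputs[-1]

    # Cached: carry the (hidden, cell) state forward and feed only the newest
    # timestep (O(n) overall).
    state = None
    for t in range(embeddings.size(0)):
        output, state = lstm(embeddings[t:t + 1], state)
        features_for_next_token = output[-1]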

In fairseq this is called :ref:`Incremental decoding`. Incremental decoding is a
special mode at inference time where the Model only receives a single timestep
of input corresponding to the immediately previous output token (for teacher
forcing) and must produce the next output incrementally. Thus the model must
cache any long-term state that is needed about the sequence, e.g., hidden
states, convolutional states, etc.

To implement incremental decoding we will modify our model to implement the
:class:`~fairseq.models.FairseqIncrementalDecoder` interface. Compared to the
standard :class:`~fairseq.models.FairseqDecoder` interface, the incremental
decoder interface allows ``forward()`` methods to take an extra keyword argument
(*incremental_state*) that can be used to cache state across time-steps.

Let's replace our ``SimpleLSTMDecoder`` with an incremental one::

    import torch
    from fairseq.models import FairseqIncrementalDecoder

    class SimpleLSTMDecoder(FairseqIncrementalDecoder):

        def __init__(
            self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
            dropout=0.1,
        ):
            # This remains the same as before.
            super().__init__(dictionary)
            self.embed_tokens = nn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=embed_dim,
                padding_idx=dictionary.pad(),
            )
            self.dropout = nn.Dropout(p=dropout)
            self.lstm = nn.LSTM(
                input_size=encoder_hidden_dim + embed_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=False,
            )
            self.output_projection = nn.Linear(hidden_dim, len(dictionary))

        # We now take an additional kwarg (*incremental_state*) for caching the
        # previous hidden and cell states.
        def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
            if incremental_state is not None:
                # If the *incremental_state* argument is not ``None`` then we are
                # in incremental inference mode. While *prev_output_tokens* will
                # still contain the entire decoded prefix, we will only use the
                # last step and assume that the rest of the state is cached.
                prev_output_tokens = prev_output_tokens[:, -1:]

            # This remains the same as before.
            bsz, tgt_len = prev_output_tokens.size()
            final_encoder_hidden = encoder_out['final_hidden']
            x = self.embed_tokens(prev_output_tokens)
            x = self.dropout(x)
            x = torch.cat(
                [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
                dim=2,
            )

            # We will now check the cache and load the cached previous hidden and
            # cell states, if they exist, otherwise we will initialize them to
            # zeros (as before). We will use the ``utils.get_incremental_state()``
            # and ``utils.set_incremental_state()`` helpers.
            initial_state = utils.get_incremental_state(
                self, incremental_state, 'prev_state',
            )
            if initial_state is None:
                # first time initialization, same as the original version
                initial_state = (
                    final_encoder_hidden.unsqueeze(0),  # hidden
                    torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
                )

            # Run one step of our LSTM.
            output, latest_state = self.lstm(x.transpose(0, 1), initial_state)

            # Update the cache with the latest hidden and cell states.
            utils.set_incremental_state(
                self, incremental_state, 'prev_state', latest_state,
            )

            # This remains the same as before
            x = output.transpose(0, 1)
            x = self.output_projection(x)
            return x, None

        # The ``FairseqIncrementalDecoder`` interface also requires implementing a
        # ``reorder_incremental_state()`` method, which is used during beam search
        # to select and reorder the incremental state.
        def reorder_incremental_state(self, incremental_state, new_order):
            # Load the cached state.
            prev_state = utils.get_incremental_state(
                self, incremental_state, 'prev_state',
            )

            # Reorder batches according to *new_order*.
            reordered_state = (
                prev_state[0].index_select(1, new_order),  # hidden
                prev_state[1].index_select(1, new_order),  # cell
            )

            # Update the cached state.
            utils.set_incremental_state(
                self, incremental_state, 'prev_state', reordered_state,
            )
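
To see how the cache is threaded through successive calls, here is a rough
greedy-decoding loop that drives the new Decoder by hand. In practice
:ref:`fairseq-generate` does this for you; the loop below is only illustrative
and assumes the toy ``decoder``, ``encoder_out`` and ``dictionary`` objects from
the earlier sketches, with ``decoder`` rebuilt from the incremental class
above::

    # Illustrative only -- fairseq's generators normally drive this loop.
    incremental_state = {}  # the cache that the utils helpers key into

    # Start both toy sentences with the end-of-sentence symbol.
    tokens = torch.LongTensor([[dictionary.eos()], [dictionary.eos()]])
    for _ in range(5):
        # Only the last timestep is consumed; the LSTM state comes from the cache.
        logits, _ = decoder(tokens, encoder_out, incremental_state=incremental_state)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)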

Finally, we can rerun generation and observe the speedup:

.. code-block:: console

    # Before

    > fairseq-generate data-bin/iwslt14.tokenized.de-en \
      --path checkpoints/checkpoint_best.pt \
      --beam 5 \
      --remove-bpe
    (...)
    | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
    | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

    # After

    > fairseq-generate data-bin/iwslt14.tokenized.de-en \
      --path checkpoints/checkpoint_best.pt \
      --beam 5 \
      --remove-bpe
    (...)
    | Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
    | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)