Tutorial: Simple LSTM
=====================

In this tutorial we will extend fairseq by adding a new
:class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source
sentence with an LSTM and then passes the final hidden state to a second LSTM
that decodes the target sentence (without attention).

This tutorial covers:

1. **Writing an Encoder and Decoder** to encode/decode the source/target
   sentence, respectively.
2. **Registering a new Model** so that it can be used with the existing
   :ref:`Command-line Tools`.
3. **Training the Model** using the existing command-line tools.
4. **Making generation faster** by modifying the Decoder to use
   :ref:`Incremental decoding`.

1. Building an Encoder and Decoder
----------------------------------

In this section we'll define a simple LSTM Encoder and Decoder. All Encoders
should implement the :class:`~fairseq.models.FairseqEncoder` interface and
Decoders should implement the :class:`~fairseq.models.FairseqDecoder`
interface. These interfaces themselves extend :class:`torch.nn.Module`, so
FairseqEncoders and FairseqDecoders can be written and used in the same ways
as ordinary PyTorch Modules.
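
Concretely, each interface boils down to just a couple of required methods.
Here is a minimal sketch of the shape of the two subclasses we're about to
write; the method names come from fairseq, but the class names and bodies
below are placeholders rather than part of the tutorial file::

    from fairseq.models import FairseqEncoder, FairseqDecoder

    class SketchEncoder(FairseqEncoder):
        def forward(self, src_tokens, src_lengths):
            # Encode a batch of source tokens; may return any object,
            # which fairseq passes straight through to the decoder.
            ...

        def reorder_encoder_out(self, encoder_out, new_order):
            # Rearrange the batch dimension of the encoder output,
            # which beam search needs at inference time.
            ...

    class SketchDecoder(FairseqDecoder):
        def forward(self, prev_output_tokens, encoder_out):
            # Return a `(logits, attention_weights_or_None)` tuple.
            ...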

Encoder
~~~~~~~

Our Encoder will embed the tokens in the source sentence, feed them to a
:class:`torch.nn.LSTM` and return the final hidden state. To create our
encoder save the following in a new file named
:file:`fairseq/models/simple_lstm.py`::

    import torch.nn as nn
    from fairseq import utils
    from fairseq.models import FairseqEncoder

    class SimpleLSTMEncoder(FairseqEncoder):

        def __init__(
            self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
        ):
            super().__init__(dictionary)
            self.args = args

            # Our encoder will embed the inputs before feeding them to the LSTM.
            self.embed_tokens = nn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=embed_dim,
                padding_idx=dictionary.pad(),
            )
            self.dropout = nn.Dropout(p=dropout)

            # We'll use a single-layer, unidirectional LSTM for simplicity.
            self.lstm = nn.LSTM(
                input_size=embed_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=False,
                batch_first=True,
            )

        def forward(self, src_tokens, src_lengths):
            # The inputs to the ``forward()`` function are determined by the
            # Task, and in particular the ``'net_input'`` key in each
            # mini-batch. We discuss Tasks in the next tutorial, but for now
            # just know that *src_tokens* has shape `(batch, src_len)` and
            # *src_lengths* has shape `(batch)`.

            # Note that the source is typically padded on the left. This can be
            # configured by adding the `--left-pad-source "False"` command-line
            # argument, but here we'll make the Encoder handle either kind of
            # padding by converting everything to be right-padded.
            if self.args.left_pad_source:
                # Convert left-padding to right-padding.
                src_tokens = utils.convert_padding_direction(
                    src_tokens,
                    padding_idx=self.dictionary.pad(),
                    left_to_right=True
                )

            # Embed the source.
            x = self.embed_tokens(src_tokens)

            # Apply dropout.
            x = self.dropout(x)

            # Pack the sequence into a PackedSequence object to feed to the
            # LSTM. Note that recent versions of PyTorch require the lengths
            # tensor to live on the CPU.
            x = nn.utils.rnn.pack_padded_sequence(
                x, src_lengths.cpu(), batch_first=True,
            )

            # Get the output from the LSTM.
            _outputs, (final_hidden, _final_cell) = self.lstm(x)

            # Return the Encoder's output. This can be any object and will be
            # passed directly to the Decoder.
            return {
                # this will have shape `(bsz, hidden_dim)`
                'final_hidden': final_hidden.squeeze(0),
            }

        # Encoders are required to implement this method so that we can
        # rearrange the order of the batch elements during inference (e.g.,
        # beam search).
        def reorder_encoder_out(self, encoder_out, new_order):
            """
            Reorder encoder output according to `new_order`.

            Args:
                encoder_out: output from the ``forward()`` method
                new_order (LongTensor): desired order

            Returns:
                `encoder_out` rearranged according to `new_order`
            """
            final_hidden = encoder_out['final_hidden']
            return {
                'final_hidden': final_hidden.index_select(0, new_order),
            }
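
As a quick sanity check, the Encoder can be exercised on its own. The snippet
below is a minimal sketch, not part of the tutorial file: the tiny dictionary
and the ``args`` namespace are stand-ins for what fairseq would normally
supply, and ``SimpleLSTMEncoder`` is assumed to be in scope::

    import argparse
    import torch
    from fairseq.data import Dictionary

    # Build a tiny dictionary; fairseq adds the special symbols itself.
    dictionary = Dictionary()
    for word in ['hello', 'world']:
        dictionary.add_symbol(word)

    args = argparse.Namespace(left_pad_source=False)
    encoder = SimpleLSTMEncoder(args, dictionary, embed_dim=16, hidden_dim=16)

    # A batch of two right-padded sentences, sorted by decreasing length,
    # as ``pack_padded_sequence()`` requires.
    src_tokens = torch.tensor([
        [dictionary.index('hello'), dictionary.index('world'), dictionary.eos()],
        [dictionary.index('hello'), dictionary.eos(), dictionary.pad()],
    ])
    src_lengths = torch.tensor([3, 2])

    encoder_out = encoder(src_tokens, src_lengths)
    print(encoder_out['final_hidden'].shape)  # torch.Size([2, 16])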

Decoder
~~~~~~~

Our Decoder will predict the next word, conditioned on the Encoder's final
hidden state and an embedded representation of the previous target word.
During training we feed it the ground-truth previous word, a technique
sometimes called *teacher forcing*. More specifically, we'll use a
:class:`torch.nn.LSTM` to produce a sequence of hidden states that we'll
project to the size of the output vocabulary to predict each target word.

::

    import torch
    import torch.nn as nn

    from fairseq.models import FairseqDecoder

    class SimpleLSTMDecoder(FairseqDecoder):

        def __init__(
            self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
            dropout=0.1,
        ):
            super().__init__(dictionary)

            # Our decoder will embed the inputs before feeding them to the LSTM.
            self.embed_tokens = nn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=embed_dim,
                padding_idx=dictionary.pad(),
            )
            self.dropout = nn.Dropout(p=dropout)

            # We'll use a single-layer, unidirectional LSTM for simplicity.
            self.lstm = nn.LSTM(
                # For the first layer we'll concatenate the Encoder's final
                # hidden state with the embedded target tokens.
                input_size=encoder_hidden_dim + embed_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=False,
            )

            # Define the output projection.
            self.output_projection = nn.Linear(hidden_dim, len(dictionary))

        # During training Decoders are expected to take the entire target
        # sequence (shifted right by one position) and produce logits over the
        # vocabulary. The *prev_output_tokens* tensor begins with the
        # end-of-sentence symbol, ``dictionary.eos()``, followed by the target
        # sequence.
        def forward(self, prev_output_tokens, encoder_out):
            """
            Args:
                prev_output_tokens (LongTensor): previous decoder outputs of
                    shape `(batch, tgt_len)`, for teacher forcing
                encoder_out (Tensor, optional): output from the encoder, used
                    for encoder-side attention

            Returns:
                tuple:
                    - the last decoder layer's output of shape
                      `(batch, tgt_len, vocab)`
                    - the last decoder layer's attention weights of shape
                      `(batch, tgt_len, src_len)`
            """
            bsz, tgt_len = prev_output_tokens.size()

            # Extract the final hidden state from the Encoder.
            final_encoder_hidden = encoder_out['final_hidden']

            # Embed the target sequence, which has been shifted right by one
            # position and now starts with the end-of-sentence symbol.
            x = self.embed_tokens(prev_output_tokens)

            # Apply dropout.
            x = self.dropout(x)

            # Concatenate the Encoder's final hidden state to *every* embedded
            # target token.
            x = torch.cat(
                [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
                dim=2,
            )

            # Using PackedSequence objects in the Decoder is harder than in the
            # Encoder, since the targets are not sorted in descending length
            # order, which is a requirement of ``pack_padded_sequence()``.
            # Instead we'll feed nn.LSTM directly. Note that using the
            # Encoder's final hidden state as the initial hidden state assumes
            # that encoder_hidden_dim == hidden_dim.
            initial_state = (
                final_encoder_hidden.unsqueeze(0),  # hidden
                torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
            )
            output, _ = self.lstm(
                x.transpose(0, 1),  # convert to shape `(tgt_len, bsz, dim)`
                initial_state,
            )
            x = output.transpose(0, 1)  # convert to shape `(bsz, tgt_len, hidden)`

            # Project the outputs to the size of the vocabulary.
            x = self.output_projection(x)

            # Return the logits and ``None`` for the attention weights.
            return x, None
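
To make the tensor shapes concrete, here is a minimal sketch of one decoder
forward pass, reusing the hypothetical ``dictionary`` and ``encoder_out``
built in the encoder sanity check above::

    decoder = SimpleLSTMDecoder(dictionary, encoder_hidden_dim=16,
                                embed_dim=16, hidden_dim=16)

    # During training, *prev_output_tokens* is the target sequence shifted
    # right by one position, so it starts with EOS.
    prev_output_tokens = torch.tensor([
        [dictionary.eos(), dictionary.index('world'), dictionary.index('hello')],
        [dictionary.eos(), dictionary.index('hello'), dictionary.pad()],
    ])

    logits, _ = decoder(prev_output_tokens, encoder_out)
    print(logits.shape)  # torch.Size([2, 3, len(dictionary)])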

2. Registering the Model
------------------------

Now that we've defined our Encoder and Decoder we must *register* our model
with fairseq using the :func:`~fairseq.models.register_model` function
decorator. Once the model is registered we'll be able to use it with the
existing :ref:`Command-line Tools`.

All registered models must implement the
:class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
models (i.e., any model with a single Encoder and Decoder), we can instead
implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface.

Create a small wrapper class in the same file and register it in fairseq with
the name ``'simple_lstm'``::

    from fairseq.models import FairseqEncoderDecoderModel, register_model

    # Note: the register_model "decorator" should immediately precede the
    # definition of the Model class.
    @register_model('simple_lstm')
    class SimpleLSTMModel(FairseqEncoderDecoderModel):

        @staticmethod
        def add_args(parser):
            # Models can override this method to add new command-line
            # arguments. Here we'll add some new command-line arguments to
            # configure dropout and the dimensionality of the embeddings and
            # hidden states.
            parser.add_argument(
                '--encoder-embed-dim', type=int, metavar='N',
                help='dimensionality of the encoder embeddings',
            )
            parser.add_argument(
                '--encoder-hidden-dim', type=int, metavar='N',
                help='dimensionality of the encoder hidden state',
            )
            parser.add_argument(
                '--encoder-dropout', type=float, default=0.1,
                help='encoder dropout probability',
            )
            parser.add_argument(
                '--decoder-embed-dim', type=int, metavar='N',
                help='dimensionality of the decoder embeddings',
            )
            parser.add_argument(
                '--decoder-hidden-dim', type=int, metavar='N',
                help='dimensionality of the decoder hidden state',
            )
            parser.add_argument(
                '--decoder-dropout', type=float, default=0.1,
                help='decoder dropout probability',
            )

        @classmethod
        def build_model(cls, args, task):
            # Fairseq initializes models by calling the ``build_model()``
            # function. This provides more flexibility, since the returned
            # model instance can be of a different type than the one that was
            # called. In this case we'll just return a SimpleLSTMModel
            # instance.

            # Initialize our Encoder and Decoder.
            encoder = SimpleLSTMEncoder(
                args=args,
                dictionary=task.source_dictionary,
                embed_dim=args.encoder_embed_dim,
                hidden_dim=args.encoder_hidden_dim,
                dropout=args.encoder_dropout,
            )
            decoder = SimpleLSTMDecoder(
                dictionary=task.target_dictionary,
                encoder_hidden_dim=args.encoder_hidden_dim,
                embed_dim=args.decoder_embed_dim,
                hidden_dim=args.decoder_hidden_dim,
                dropout=args.decoder_dropout,
            )
            model = SimpleLSTMModel(encoder, decoder)

            # Print the model architecture.
            print(model)

            return model

        # We could override the ``forward()`` if we wanted more control over
        # how the encoder and decoder interact, but it's not necessary for
        # this tutorial since we can inherit the default implementation
        # provided by the FairseqEncoderDecoderModel base class, which looks
        # like:
        #
        # def forward(self, src_tokens, src_lengths, prev_output_tokens):
        #     encoder_out = self.encoder(src_tokens, src_lengths)
        #     decoder_out = self.decoder(prev_output_tokens, encoder_out)
        #     return decoder_out
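
For illustration, ``build_model()`` can also be driven outside of
``fairseq-train``. The ``StubTask`` below is hypothetical (a real fairseq
Task provides the dictionaries); it only exposes the two attributes that
``build_model()`` reads, along with the ``args`` it expects::

    import argparse

    class StubTask:
        # Hypothetical stand-in exposing only what build_model() needs.
        def __init__(self, dictionary):
            self.source_dictionary = dictionary
            self.target_dictionary = dictionary

    args = argparse.Namespace(
        encoder_embed_dim=16, encoder_hidden_dim=16, encoder_dropout=0.1,
        decoder_embed_dim=16, decoder_hidden_dim=16, decoder_dropout=0.1,
        left_pad_source=False,
    )
    model = SimpleLSTMModel.build_model(args, StubTask(dictionary))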

Finally let's define a *named architecture* with the configuration for our
model. This is done with the
:func:`~fairseq.models.register_model_architecture` function decorator.
Thereafter this named architecture can be used with the ``--arch``
command-line argument, e.g., ``--arch tutorial_simple_lstm``::

    from fairseq.models import register_model_architecture

    # The first argument to ``register_model_architecture()`` should be the
    # name of the model we registered above (i.e., 'simple_lstm'). The
    # function we register here should take a single argument *args* and
    # modify it in-place to match the desired architecture.
    @register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
    def tutorial_simple_lstm(args):
        # We use ``getattr()`` to prioritize arguments that are explicitly
        # given on the command-line, so that the defaults defined below are
        # only used when no other value has been specified.
        args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
        args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
        args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
        args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)
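
The ``getattr()`` pattern is worth seeing in isolation. In this sketch (with
a hypothetical namespace, not fairseq code), an attribute that is absent
falls back to the architecture default, while an explicitly set one is
preserved; fairseq arranges for unspecified model arguments to be absent from
*args*, which is what makes the fallback work::

    import argparse

    # Simulate a user who only overrides the encoder embedding size.
    args = argparse.Namespace(encoder_embed_dim=512)
    tutorial_simple_lstm(args)

    print(args.encoder_embed_dim)   # 512  (explicit value wins)
    print(args.decoder_hidden_dim)  # 256  (architecture default)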

3. Training the Model
---------------------

Now we're ready to train the model. We can use the existing
:ref:`fairseq-train` command-line tool for this, making sure to specify our
new Model architecture (``--arch tutorial_simple_lstm``).

.. note::

    Make sure you've already preprocessed the data from the IWSLT example in
    the :file:`examples/translation/` directory.

.. code-block:: console

    > fairseq-train data-bin/iwslt14.tokenized.de-en \
      --arch tutorial_simple_lstm \
      --encoder-dropout 0.2 --decoder-dropout 0.2 \
      --optimizer adam --lr 0.005 --lr-shrink 0.5 \
      --max-tokens 12000
    (...)
    | epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
    | epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954

The model files should appear in the :file:`checkpoints/` directory. While
this model architecture is not very good, we can use the
:ref:`fairseq-generate` script to generate translations and compute our BLEU
score over the test set:

.. code-block:: console

    > fairseq-generate data-bin/iwslt14.tokenized.de-en \
      --path checkpoints/checkpoint_best.pt \
      --beam 5 \
      --remove-bpe
    (...)
    | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
    | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

4. Making generation faster
---------------------------

While autoregressive generation from sequence-to-sequence models is
inherently slow, our implementation above is especially slow because it
recomputes the entire sequence of Decoder hidden states for every output
token (i.e., it is ``O(n^2)``). We can make this significantly faster by
instead caching the previous hidden states.

In fairseq this is called :ref:`Incremental decoding`. Incremental decoding
is a special mode at inference time where the Model only receives a single
timestep of input corresponding to the immediately previous output token (for
teacher forcing) and must produce the next output incrementally. Thus the
model must cache any long-term state that is needed about the sequence, e.g.,
hidden states, convolutional states, etc.

To implement incremental decoding we will modify our model to implement the
:class:`~fairseq.models.FairseqIncrementalDecoder` interface. Compared to the
standard :class:`~fairseq.models.FairseqDecoder` interface, the incremental
decoder interface allows ``forward()`` methods to take an extra keyword
argument (*incremental_state*) that can be used to cache state across
time-steps.

Let's replace our ``SimpleLSTMDecoder`` with an incremental one::

    import torch
    import torch.nn as nn

    from fairseq import utils
    from fairseq.models import FairseqIncrementalDecoder

    class SimpleLSTMDecoder(FairseqIncrementalDecoder):

        def __init__(
            self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
            dropout=0.1,
        ):
            # This remains the same as before.
            super().__init__(dictionary)
            self.embed_tokens = nn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=embed_dim,
                padding_idx=dictionary.pad(),
            )
            self.dropout = nn.Dropout(p=dropout)
            self.lstm = nn.LSTM(
                input_size=encoder_hidden_dim + embed_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=False,
            )
            self.output_projection = nn.Linear(hidden_dim, len(dictionary))

        # We now take an additional kwarg (*incremental_state*) for caching
        # the previous hidden and cell states.
        def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
            if incremental_state is not None:
                # If the *incremental_state* argument is not ``None`` then we
                # are in incremental inference mode. While *prev_output_tokens*
                # will still contain the entire decoded prefix, we will only
                # use the last step and assume that the rest of the state is
                # cached.
                prev_output_tokens = prev_output_tokens[:, -1:]

            # This remains the same as before.
            bsz, tgt_len = prev_output_tokens.size()
            final_encoder_hidden = encoder_out['final_hidden']
            x = self.embed_tokens(prev_output_tokens)
            x = self.dropout(x)
            x = torch.cat(
                [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
                dim=2,
            )

            # We will now check the cache and load the cached previous hidden
            # and cell states, if they exist, otherwise we will initialize
            # them to zeros (as before). We will use the
            # ``utils.get_incremental_state()`` and
            # ``utils.set_incremental_state()`` helpers.
            initial_state = utils.get_incremental_state(
                self, incremental_state, 'prev_state',
            )
            if initial_state is None:
                # first time initialization, same as the original version
                initial_state = (
                    final_encoder_hidden.unsqueeze(0),  # hidden
                    torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
                )

            # Run one step of our LSTM.
            output, latest_state = self.lstm(x.transpose(0, 1), initial_state)

            # Update the cache with the latest hidden and cell states.
            utils.set_incremental_state(
                self, incremental_state, 'prev_state', latest_state,
            )

            # This remains the same as before.
            x = output.transpose(0, 1)
            x = self.output_projection(x)
            return x, None

        # The ``FairseqIncrementalDecoder`` interface also requires
        # implementing a ``reorder_incremental_state()`` method, which is used
        # during beam search to select and reorder the incremental state.
        def reorder_incremental_state(self, incremental_state, new_order):
            # Load the cached state.
            prev_state = utils.get_incremental_state(
                self, incremental_state, 'prev_state',
            )

            # Reorder batches according to *new_order*.
            reordered_state = (
                prev_state[0].index_select(1, new_order),  # hidden
                prev_state[1].index_select(1, new_order),  # cell
            )

            # Update the cached state.
            utils.set_incremental_state(
                self, incremental_state, 'prev_state', reordered_state,
            )
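
To see the cache in action outside of :ref:`fairseq-generate`, here is a
minimal sketch of a greedy incremental loop, reusing the hypothetical
``dictionary`` and ``encoder_out`` from the earlier sanity checks. Real
generation in fairseq goes through its sequence generator rather than a hand
written loop like this::

    decoder = SimpleLSTMDecoder(dictionary, encoder_hidden_dim=16,
                                embed_dim=16, hidden_dim=16)
    decoder.eval()

    incremental_state = {}  # a dict that the helpers key into
    tokens = torch.full((2, 1), dictionary.eos(), dtype=torch.long)

    for _ in range(5):
        with torch.no_grad():
            logits, _ = decoder(tokens, encoder_out, incremental_state)
        # Only one timestep of logits is produced per call.
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)

    print(tokens.shape)  # torch.Size([2, 6])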

Finally, we can rerun generation and observe the speedup:

.. code-block:: console

    # Before

    > fairseq-generate data-bin/iwslt14.tokenized.de-en \
      --path checkpoints/checkpoint_best.pt \
      --beam 5 \
      --remove-bpe
    (...)
    | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
    | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

    # After

    > fairseq-generate data-bin/iwslt14.tokenized.de-en \
      --path checkpoints/checkpoint_best.pt \
      --beam 5 \
      --remove-bpe
    (...)
    | Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
    | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)