12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
- import logging
- import unittest
- from fairseq.dataclass.utils import convert_namespace_to_omegaconf
- from fairseq.models.transformer import TransformerModel
- from tests.test_sequence_generator import get_dummy_task_and_parser
class TestInferenceDropout(unittest.TestCase):
    """Check how ``retain_dropout`` settings reach a Transformer's dropout modules.

    Every test builds a small ``TransformerModel`` (2 encoder layers, 1 decoder
    layer) on the dummy task from ``tests.test_sequence_generator``. It then
    inspects the ``dropout_module`` attribute of the encoder, the decoder and
    each of their layers.
    """

    def setUp(self):
        self.task, self.parser = get_dummy_task_and_parser()
        TransformerModel.add_args(self.parser)
        self.args = self.parser.parse_args([])
        self.args.encoder_layers = 2
        self.args.decoder_layers = 1
        # Silence all logging while the test runs; tearDown re-enables it.
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        # Undo the logging.disable() call made in setUp.
        logging.disable(logging.NOTSET)

    def _build_model(self):
        """Build a TransformerModel from ``self.args``.

        The model is also kept on ``self.transformer_model``, as the original
        tests did.

        Returns:
            The freshly built model.
        """
        self.transformer_model = TransformerModel.build_model(self.args, self.task)
        return self.transformer_model

    def _build_model_for_inference(self):
        """Build a model and run ``prepare_for_inference_`` on it.

        The config passed to ``prepare_for_inference_`` is converted from the
        current ``self.args``.

        Returns:
            The prepared model.
        """
        model = self._build_model()
        cfg = convert_namespace_to_omegaconf(self.args)
        model.prepare_for_inference_(cfg)
        return model

    @staticmethod
    def _named_dropout_modules(model):
        """Return ``(name, dropout_module)`` pairs for the inspected modules.

        The pairs cover the encoder, the decoder and every encoder/decoder
        layer of ``model``. The names are only used in assertion messages.
        """
        named = [
            ("encoder", model.encoder.dropout_module),
            ("decoder", model.decoder.dropout_module),
        ]
        named += [
            ("encoder.layers.%d" % i, layer.dropout_module)
            for i, layer in enumerate(model.encoder.layers)
        ]
        named += [
            ("decoder.layers.%d" % i, layer.dropout_module)
            for i, layer in enumerate(model.decoder.layers)
        ]
        return named

    def test_sets_inference_dropout_to_true(self):
        # retain_dropout with no module filter: every inspected dropout module,
        # including the decoder layers, is expected to stay active at inference.
        self.args.retain_dropout = True
        model = self._build_model_for_inference()
        for name, module in self._named_dropout_modules(model):
            self.assertTrue(module.apply_during_inference, name)

    def test_inference_dropout_false_by_default(self):
        # Without retain_dropout no inspected module keeps dropout at inference.
        model = self._build_model_for_inference()
        for name, module in self._named_dropout_modules(model):
            self.assertFalse(module.apply_during_inference, name)

    def test_applies_training_mode(self):
        # A freshly built model is in training mode. eval() must switch the
        # encoder, the decoder and all of their layers out of it. Both sides
        # are checked symmetrically.
        model = self._build_model()
        for name, module in self._named_dropout_modules(model):
            self.assertTrue(module.training, name)

        model.eval()
        for name, module in self._named_dropout_modules(model):
            self.assertFalse(module.training, name)

    def test_retain_modules(self):
        # Only the module classes listed here should keep dropout at inference.
        self.args.retain_dropout = True
        self.args.retain_dropout_modules = [
            "TransformerEncoder",
            "TransformerEncoderLayer",
        ]
        model = self._build_model_for_inference()

        # Listed classes: the encoder and its layers retain dropout.
        self.assertTrue(model.encoder.dropout_module.apply_during_inference, "encoder")
        for i, layer in enumerate(model.encoder.layers):
            self.assertTrue(
                layer.dropout_module.apply_during_inference, "encoder.layers.%d" % i
            )

        # Unlisted classes: the decoder and its layers do not.
        self.assertFalse(model.decoder.dropout_module.apply_during_inference, "decoder")
        for i, layer in enumerate(model.decoder.layers):
            self.assertFalse(
                layer.dropout_module.apply_during_inference, "decoder.layers.%d" % i
            )
|