test_inference_dropout.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import logging
  6. import unittest
  7. from fairseq.dataclass.utils import convert_namespace_to_omegaconf
  8. from fairseq.models.transformer import TransformerModel
  9. from tests.test_sequence_generator import get_dummy_task_and_parser
  10. class TestInferenceDropout(unittest.TestCase):
  11. def setUp(self):
  12. self.task, self.parser = get_dummy_task_and_parser()
  13. TransformerModel.add_args(self.parser)
  14. self.args = self.parser.parse_args([])
  15. self.args.encoder_layers = 2
  16. self.args.decoder_layers = 1
  17. logging.disable(logging.CRITICAL)
  18. def tearDown(self):
  19. logging.disable(logging.NOTSET)
  20. def test_sets_inference_dropout_to_true(self):
  21. self.args.retain_dropout = True
  22. self.transformer_model = TransformerModel.build_model(self.args, self.task)
  23. cfg = convert_namespace_to_omegaconf(self.args)
  24. self.transformer_model.prepare_for_inference_(cfg)
  25. assert self.transformer_model.encoder.dropout_module.apply_during_inference
  26. assert self.transformer_model.decoder.dropout_module.apply_during_inference
  27. for layer in self.transformer_model.encoder.layers:
  28. assert layer.dropout_module.apply_during_inference
  29. def test_inference_dropout_false_by_default(self):
  30. self.transformer_model = TransformerModel.build_model(self.args, self.task)
  31. cfg = convert_namespace_to_omegaconf(self.args)
  32. self.transformer_model.prepare_for_inference_(cfg)
  33. assert not self.transformer_model.encoder.dropout_module.apply_during_inference
  34. assert not self.transformer_model.decoder.dropout_module.apply_during_inference
  35. for layer in self.transformer_model.encoder.layers:
  36. assert not layer.dropout_module.apply_during_inference
  37. for layer in self.transformer_model.decoder.layers:
  38. assert not layer.dropout_module.apply_during_inference
  39. def test_applies_training_mode(self):
  40. self.transformer_model = TransformerModel.build_model(self.args, self.task)
  41. assert self.transformer_model.encoder.dropout_module.training
  42. for layer in self.transformer_model.encoder.layers:
  43. assert layer.dropout_module.training
  44. self.transformer_model.eval()
  45. assert not self.transformer_model.decoder.dropout_module.training
  46. for layer in self.transformer_model.encoder.layers:
  47. assert not layer.dropout_module.training
  48. def test_retain_modules(self):
  49. self.args.retain_dropout = True
  50. self.args.retain_dropout_modules = [
  51. "TransformerEncoder",
  52. "TransformerEncoderLayer",
  53. ]
  54. self.transformer_model = TransformerModel.build_model(self.args, self.task)
  55. cfg = convert_namespace_to_omegaconf(self.args)
  56. self.transformer_model.prepare_for_inference_(cfg)
  57. assert self.transformer_model.encoder.dropout_module.apply_during_inference
  58. assert not self.transformer_model.decoder.dropout_module.apply_during_inference
  59. for layer in self.transformer_model.decoder.layers:
  60. assert not layer.dropout_module.apply_during_inference