test_cross_entropy.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. #!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. #
  4. # This source code is licensed under the MIT license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from examples.speech_recognition.criterions.cross_entropy_acc import (
  7. CrossEntropyWithAccCriterion,
  8. )
  9. from .asr_test_base import CrossEntropyCriterionTestBase
  class CrossEntropyWithAccCriterionTest(CrossEntropyCriterionTestBase):
      """Tests for the accuracy bookkeeping of ``CrossEntropyWithAccCriterion``.

      Each test builds a hard-target sample with the base class's
      ``get_test_sample``, runs the criterion with ``"sum"`` reduction on
      log-probabilities, and checks the counters the criterion reports in its
      logging output.
      """

      # Number of tokens the tests expect in the sample produced by
      # ``get_test_sample(..., aggregate=False)``. Presumably fixed by the base
      # class fixture -- TODO confirm against asr_test_base if that changes.
      _NUM_TOKENS = 20

      def setUp(self):
          # criterion_cls is assigned before delegating, presumably because the
          # base setUp builds ``self.criterion`` from it -- confirm in base class.
          self.criterion_cls = CrossEntropyWithAccCriterion
          super().setUp()

      def _logging_output_for(self, correct):
          """Run the criterion on a hard-target test sample; return its logging output.

          Args:
              correct: forwarded to ``get_test_sample``; presumably selects whether
                  the sample's predictions match its targets.

          Returns:
              The third element of the criterion's return value (the logging
              output mapping).
          """
          sample = self.get_test_sample(
              correct=correct, soft_target=False, aggregate=False
          )
          # The criterion returns (loss, sample_size, logging_output); only the
          # logging output is inspected by these tests.
          _, _, logging_output = self.criterion(
              self.model, sample, "sum", log_probs=True
          )
          return logging_output

      def _check_counts(self, logging_output, expected_correct):
          """Assert every accuracy counter in *logging_output*.

          Keys are checked in the same order as before (correct, total,
          sample_size, ntokens), so the first reported failure is unchanged.
          The message names the offending key and both values, which a bare
          ``assert`` does not show when run outside pytest.
          """
          n = self._NUM_TOKENS
          expected = (
              ("correct", expected_correct),
              ("total", n),
              ("sample_size", n),
              ("ntokens", n),
          )
          for key, want in expected:
              got = logging_output[key]
              assert got == want, "logging_output[{!r}] = {!r}, expected {!r}".format(
                  key, got, want
              )

      def test_cross_entropy_all_correct(self):
          logging_output = self._logging_output_for(correct=True)
          # Every token counts as correct.
          self._check_counts(logging_output, expected_correct=self._NUM_TOKENS)

      def test_cross_entropy_all_wrong(self):
          logging_output = self._logging_output_for(correct=False)
          # No token counts as correct, but totals are unaffected.
          self._check_counts(logging_output, expected_correct=0)