__init__.py 836 B

1234567891011121314151617181920212223242526
  1. from copy import deepcopy
  2. from basicsr.utils import get_root_logger
  3. from basicsr.utils.registry import LOSS_REGISTRY
  4. from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
  5. gradient_penalty_loss, r1_penalty)
  6. __all__ = [
  7. 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
  8. 'r1_penalty', 'g_path_regularize'
  9. ]
  10. def build_loss(opt):
  11. """Build loss from options.
  12. Args:
  13. opt (dict): Configuration. It must constain:
  14. type (str): Model type.
  15. """
  16. opt = deepcopy(opt)
  17. loss_type = opt.pop('type')
  18. loss = LOSS_REGISTRY.get(loss_type)(**opt)
  19. logger = get_root_logger()
  20. logger.info(f'Loss [{loss.__class__.__name__}] is created.')
  21. return loss