1234567891011121314151617181920212223242526 |
- from copy import deepcopy
- from basicsr.utils import get_root_logger
- from basicsr.utils.registry import LOSS_REGISTRY
- from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
- gradient_penalty_loss, r1_penalty)
# Public names re-exported by this package.
__all__ = [
    'L1Loss',
    'MSELoss',
    'CharbonnierLoss',
    'WeightedTVLoss',
    'PerceptualLoss',
    'GANLoss',
    'gradient_penalty_loss',
    'r1_penalty',
    'g_path_regularize',
]
def build_loss(opt):
    """Instantiate a loss object described by an options dict.

    Args:
        opt (dict): Configuration. It must contain:
            type (str): Key used to look the loss up in ``LOSS_REGISTRY``.
            All remaining entries are passed to the looked-up callable as
            keyword arguments.

    Returns:
        The object returned by calling the registry entry with the
        remaining options.
    """
    # Work on a private copy so that removing 'type' never alters the
    # caller's dict.
    kwargs = deepcopy(opt)
    registry_key = kwargs.pop('type')
    loss_cls = LOSS_REGISTRY.get(registry_key)
    loss = loss_cls(**kwargs)
    get_root_logger().info(f'Loss [{loss.__class__.__name__}] is created.')
    return loss
|