123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455 |
- import math
- import lpips
- import torch
- from torch import autograd as autograd
- from torch import nn as nn
- from torch.nn import functional as F
- from basicsr.archs.vgg_arch import VGGFeatureExtractor
- from basicsr.utils.registry import LOSS_REGISTRY
- from .loss_util import weighted_loss
# Reduction modes named in the "Unsupported reduction mode" error messages
# raised by the loss modules in this file.
_reduction_modes = ['none', 'mean', 'sum']
@weighted_loss
def l1_loss(pred, target):
    """Element-wise absolute error between ``pred`` and ``target``.

    No reduction is applied here. Callers in this file additionally pass
    ``weight`` and ``reduction``; those are handled by the ``weighted_loss``
    decorator, not by this function body.

    Args:
        pred (Tensor): Predicted tensor.
        target (Tensor): Ground truth tensor.

    Returns:
        Tensor: Unreduced absolute differences.
    """
    elementwise = F.l1_loss(pred, target, reduction='none')
    return elementwise
@weighted_loss
def mse_loss(pred, target):
    """Element-wise squared error between ``pred`` and ``target``.

    No reduction is applied here. Callers in this file additionally pass
    ``weight`` and ``reduction``; those are handled by the ``weighted_loss``
    decorator, not by this function body.

    Args:
        pred (Tensor): Predicted tensor.
        target (Tensor): Ground truth tensor.

    Returns:
        Tensor: Unreduced squared differences.
    """
    elementwise = F.mse_loss(pred, target, reduction='none')
    return elementwise
@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    """Element-wise Charbonnier penalty ``sqrt((pred - target)^2 + eps)``.

    No reduction is applied here. Callers in this file additionally pass
    ``weight`` and ``reduction``; those are handled by the ``weighted_loss``
    decorator, not by this function body.

    Args:
        pred (Tensor): Predicted tensor.
        target (Tensor): Ground truth tensor.
        eps (float): Constant added under the square root. Default: 1e-12.

    Returns:
        Tensor: Unreduced Charbonnier penalties.
    """
    residual = pred - target
    return torch.sqrt(residual.pow(2) + eps)
@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
    """Mean absolute error (L1) loss module.

    Args:
        loss_weight (float): Scalar multiplied onto the loss. Default: 1.0.
        reduction (str): Reduction applied to the output, one of
            'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        # Reject unknown modes early so a typo fails at construction time.
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """Compute the weighted L1 loss.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor: ``loss_weight`` times the reduced L1 loss.
        """
        unscaled = l1_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * unscaled
@LOSS_REGISTRY.register()
class MSELoss(nn.Module):
    """Mean squared error (L2) loss module.

    Args:
        loss_weight (float): Scalar multiplied onto the loss. Default: 1.0.
        reduction (str): Reduction applied to the output, one of
            'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        # Reject unknown modes early so a typo fails at construction time.
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """Compute the weighted MSE loss.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor: ``loss_weight`` times the reduced MSE loss.
        """
        unscaled = mse_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * unscaled
@LOSS_REGISTRY.register()
class CharbonnierLoss(nn.Module):
    """Charbonnier loss, a differentiable variant of the L1 loss.

    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
    Super-Resolution".

    Args:
        loss_weight (float): Scalar multiplied onto the loss. Default: 1.0.
        reduction (str): Reduction applied to the output, one of
            'none' | 'mean' | 'sum'. Default: 'mean'.
        eps (float): A value used to control the curvature near zero.
            Default: 1e-12.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
        super().__init__()
        # Reject unknown modes early so a typo fails at construction time.
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.eps = eps

    def forward(self, pred, target, weight=None, **kwargs):
        """Compute the weighted Charbonnier loss.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor: ``loss_weight`` times the reduced Charbonnier loss.
        """
        unscaled = charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
        return self.loss_weight * unscaled
@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):
    """Weighted total-variation (TV) loss.

    The loss is the sum of two L1 terms (computed by the parent ``L1Loss``
    with its default reduction): one between vertically adjacent elements
    and one between horizontally adjacent elements of ``pred``.

    Args:
        loss_weight (float): Loss weight. Default: 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)

    def forward(self, pred, weight=None):
        """Compute the TV loss of ``pred``.

        Args:
            pred (Tensor): 4-D tensor; differences are taken along dims 2
                and 3.
            weight (Tensor, optional): Element-wise weights indexed like
                ``pred``. When None, the unweighted loss is computed.
                Default: None.

        Returns:
            Tensor: Sum of the horizontal and vertical difference losses.
        """
        # Only slice the weight when one is supplied: the signature
        # advertises ``weight=None``, and slicing None raises TypeError.
        if weight is None:
            y_weight = None
            x_weight = None
        else:
            # Crop so the weights line up with the shifted difference maps.
            y_weight = weight[:, :, :-1, :]
            x_weight = weight[:, :, :, :-1]

        y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
        x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)

        return x_diff + y_diff
@LOSS_REGISTRY.register()
class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'conv5_4': 1.}, which means the conv5_4
            feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will multiplied by the weight.
            Default: 0.
        criterion (str): Criterion used for perceptual loss. One of
            'l1', 'l2', 'mse' or 'fro'. 'l2' and 'mse' both use
            ``torch.nn.MSELoss`` with mean reduction; 'fro' uses the
            Frobenius norm of the feature difference. Default: 'l1'.

    Raises:
        NotImplementedError: If ``criterion`` is not a supported name.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            range_norm=range_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type in ('l2', 'mse'):
            # torch.nn has no ``L2loss``; the L2 criterion is MSELoss, whose
            # default reduction is 'mean'.
            self.criterion = torch.nn.MSELoss(reduction='mean')
        elif self.criterion_type == 'fro':
            # The Frobenius norm is computed inline via torch.norm.
            self.criterion = None
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported.')

    def _distance(self, x_feat, gt_feat):
        """Distance between two tensors under the configured criterion."""
        if self.criterion_type == 'fro':
            return torch.norm(x_feat - gt_feat, p='fro')
        return self.criterion(x_feat, gt_feat)

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            tuple: ``(percep_loss, style_loss)``. Each entry is None when
            its corresponding weight is not greater than 0.
        """
        # extract vgg features; the ground truth is detached so no gradient
        # flows through it
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        percep_loss = None
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features.keys():
                percep_loss += self._distance(x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight

        # calculate style loss on the Gram matrices of the same features
        style_loss = None
        if self.style_weight > 0:
            style_loss = 0
            for k in x_features.keys():
                style_loss += self._distance(
                    self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * self.layer_weights[k]
            style_loss *= self.style_weight

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix of shape (n, c, c), normalized by
            ``c * h * w``.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram
@LOSS_REGISTRY.register()
class LPIPSLoss(nn.Module):
    """LPIPS-based perceptual loss.

    Wraps ``lpips.LPIPS(net="vgg", spatial=False)``, switched to eval mode
    at construction, and returns the weighted mean of its output.

    Args:
        loss_weight (float): Scalar multiplied onto the loss. Default: 1.0.
        use_input_norm (bool): If True, subtract the registered ``mean``
            buffer and divide by the ``std`` buffer before calling LPIPS.
            Default: True.
        range_norm (bool): If True, map inputs with ``(x + 1) / 2`` before
            any other normalization. Default: False.
    """

    def __init__(self, loss_weight=1.0, use_input_norm=True, range_norm=False):
        super().__init__()
        self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
        self.loss_weight = loss_weight
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        if self.use_input_norm:
            # Both statistics are for images with range [0, 1]; they are
            # shaped (1, 3, 1, 1) to broadcast over 3-channel batches.
            channel_stats = (
                ('mean', [0.485, 0.456, 0.406]),
                ('std', [0.229, 0.224, 0.225]),
            )
            for buffer_name, values in channel_stats:
                self.register_buffer(buffer_name, torch.Tensor(values).view(1, 3, 1, 1))

    def forward(self, pred, target):
        """Compute the LPIPS loss.

        Args:
            pred (Tensor): Predicted images.
            target (Tensor): Ground-truth images.

        Returns:
            Tensor: ``loss_weight`` times the mean LPIPS output.
        """
        if self.range_norm:
            pred, target = (pred + 1) / 2, (target + 1) / 2
        if self.use_input_norm:
            pred, target = (pred - self.mean) / self.std, (target - self.mean) / self.std

        # The target is passed first, then the prediction.
        distance = self.perceptual(target.contiguous(), pred.contiguous())
        return self.loss_weight * distance.mean()
@LOSS_REGISTRY.register()
class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'wgan_softplus',
            'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always
            1.0 for discriminators.

    Raises:
        NotImplementedError: If ``gan_type`` is not a supported name.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        # (name, factory) pairs; only the matching factory is invoked so
        # that a single criterion is created.
        criterion_factories = (
            ('vanilla', nn.BCEWithLogitsLoss),
            ('lsgan', nn.MSELoss),
            ('wgan', lambda: self._wgan_loss),
            ('wgan_softplus', lambda: self._wgan_softplus_loss),
            ('hinge', nn.ReLU),
        )
        for type_name, make_criterion in criterion_factories:
            if self.gan_type == type_name:
                self.loss = make_criterion()
                break
        else:
            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: Negated mean of ``input`` when ``target`` is truthy,
            otherwise the mean of ``input``.
        """
        mean_score = input.mean()
        return -mean_score if target else mean_score

    def _wgan_softplus_loss(self, input, target):
        """wgan loss with soft plus. softplus is a smooth approximation to the
        ReLU function.

        In StyleGAN2, it is called:
            Logistic loss for discriminator;
            Non-saturating loss for generator.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: Mean of ``softplus(-input)`` when ``target`` is truthy,
            otherwise mean of ``softplus(input)``.
        """
        signed_input = -input if target else input
        return F.softplus(signed_input).mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan and
            wgan_softplus, otherwise a tensor shaped like ``input`` filled
            with the real or fake label value.
        """
        if self.gan_type in ['wgan', 'wgan_softplus']:
            return target_is_real
        if target_is_real:
            label_value = self.real_label_val
        else:
            label_value = self.fake_label_val
        return input.new_ones(input.size()) * label_value

    def forward(self, input, target_is_real, is_disc=False):
        """Compute the GAN loss.

        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        if self.gan_type != 'hinge':
            # Non-hinge types: compare the prediction with its target label.
            loss = self.loss(input, self.get_target_label(input, target_is_real))
        elif is_disc:
            # Hinge loss for discriminators: relu(1 - D(real)) / relu(1 + D(fake)).
            signed_input = -input if target_is_real else input
            loss = self.loss(1 + signed_input).mean()
        else:
            # Hinge loss for generators.
            loss = -input.mean()

        # loss_weight is always 1.0 for discriminators
        if is_disc:
            return loss
        return loss * self.loss_weight
def r1_penalty(real_pred, real_img):
    """R1 regularization for discriminator.

    The core idea is to penalize the gradient on real data alone: when the
    generator distribution produces the true data distribution and the
    discriminator is equal to 0 on the data manifold, the gradient penalty
    ensures that the discriminator cannot create a non-zero gradient
    orthogonal to the data manifold without suffering a loss in the GAN game.

    Ref:
        Eq. 9 in Which training methods for GANs do actually converge.

    Args:
        real_pred (Tensor): Discriminator output computed from ``real_img``.
        real_img (Tensor): Input the gradient is taken with respect to; it
            must be part of the graph that produced ``real_pred``.

    Returns:
        Tensor: Scalar; the per-sample sum of squared gradients, averaged
        over the first dimension.
    """
    # create_graph=True keeps the penalty itself differentiable.
    (grad_real,) = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
    squared_per_sample = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1)
    return squared_per_sample.mean()
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Path-length penalty computed from generated images and their latents.

    NOTE(review): this looks like the StyleGAN2 path length regularizer;
    confirm against the caller.

    Args:
        fake_img (Tensor): Generated images; dims 2 and 3 are used to scale
            the random noise. Must be computed from ``latents``.
        latents (Tensor): Latent codes the gradient is taken with respect
            to. The gradient is summed over dim 2 and averaged over dim 1.
        mean_path_length (float | Tensor): Running mean of the path length.
        decay (float): Step size of the running-mean update. Default: 0.01.

    Returns:
        tuple: ``(path_penalty, mean of detached path lengths,
        detached updated running mean)``.
    """
    height, width = fake_img.shape[2], fake_img.shape[3]
    noise = torch.randn_like(fake_img) / math.sqrt(height * width)

    # create_graph=True keeps the penalty differentiable w.r.t. the generator.
    (grad,) = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    # Moving-average update of the running mean path length.
    batch_mean = path_lengths.mean()
    path_mean = mean_path_length + decay * (batch_mean - mean_path_length)

    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
    """Calculate gradient penalty for wgan-gp.

    A random per-sample mix of ``real_data`` and ``fake_data`` is fed to the
    discriminator, and the deviation of the gradient norm from 1 is
    penalized.

    NOTE(review): the norm is taken over dim 1 only, so for 4-D inputs it is
    a per-location channel norm rather than one norm per sample - confirm
    this is intended.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data. The mixing coefficient has
            shape (batch, 1, 1, 1), so 4-D inputs are assumed.
        fake_data (Tensor): Fake input data.
        weight (Tensor): Weight tensor multiplied onto the gradients; the
            penalty is then divided by its mean. Default: None.

    Returns:
        Tensor: A tensor for gradient penalty.
    """
    batch_size = real_data.size(0)
    # Sample with the default generator/dtype, then match real_data's dtype
    # and device. Copy-constructing via ``new_tensor(tensor)`` would emit a
    # UserWarning on every call.
    alpha = torch.rand(batch_size, 1, 1, 1).to(real_data)

    # interpolate between real_data and fake_data; detach so the mix is a
    # fresh leaf (what the legacy ``autograd.Variable`` wrapper did) and
    # enable gradients with respect to it
    interpolates = (alpha * real_data + (1. - alpha) * fake_data).detach().requires_grad_(True)

    disc_interpolates = discriminator(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True)[0]

    if weight is not None:
        gradients = gradients * weight

    gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
    if weight is not None:
        gradients_penalty = gradients_penalty / torch.mean(weight)

    return gradients_penalty
|