losses.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. import math
  2. import lpips
  3. import torch
  4. from torch import autograd as autograd
  5. from torch import nn as nn
  6. from torch.nn import functional as F
  7. from basicsr.archs.vgg_arch import VGGFeatureExtractor
  8. from basicsr.utils.registry import LOSS_REGISTRY
  9. from .loss_util import weighted_loss
# Reduction modes accepted by the loss classes in this file; the list is
# interpolated into their "Unsupported reduction mode" error messages.
_reduction_modes = ['none', 'mean', 'sum']
  11. @weighted_loss
  12. def l1_loss(pred, target):
  13. return F.l1_loss(pred, target, reduction='none')
  14. @weighted_loss
  15. def mse_loss(pred, target):
  16. return F.mse_loss(pred, target, reduction='none')
  17. @weighted_loss
  18. def charbonnier_loss(pred, target, eps=1e-12):
  19. return torch.sqrt((pred - target)**2 + eps)
  20. @LOSS_REGISTRY.register()
  21. class L1Loss(nn.Module):
  22. """L1 (mean absolute error, MAE) loss.
  23. Args:
  24. loss_weight (float): Loss weight for L1 loss. Default: 1.0.
  25. reduction (str): Specifies the reduction to apply to the output.
  26. Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
  27. """
  28. def __init__(self, loss_weight=1.0, reduction='mean'):
  29. super(L1Loss, self).__init__()
  30. if reduction not in ['none', 'mean', 'sum']:
  31. raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
  32. self.loss_weight = loss_weight
  33. self.reduction = reduction
  34. def forward(self, pred, target, weight=None, **kwargs):
  35. """
  36. Args:
  37. pred (Tensor): of shape (N, C, H, W). Predicted tensor.
  38. target (Tensor): of shape (N, C, H, W). Ground truth tensor.
  39. weight (Tensor, optional): of shape (N, C, H, W). Element-wise
  40. weights. Default: None.
  41. """
  42. return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
  43. @LOSS_REGISTRY.register()
  44. class MSELoss(nn.Module):
  45. """MSE (L2) loss.
  46. Args:
  47. loss_weight (float): Loss weight for MSE loss. Default: 1.0.
  48. reduction (str): Specifies the reduction to apply to the output.
  49. Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
  50. """
  51. def __init__(self, loss_weight=1.0, reduction='mean'):
  52. super(MSELoss, self).__init__()
  53. if reduction not in ['none', 'mean', 'sum']:
  54. raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
  55. self.loss_weight = loss_weight
  56. self.reduction = reduction
  57. def forward(self, pred, target, weight=None, **kwargs):
  58. """
  59. Args:
  60. pred (Tensor): of shape (N, C, H, W). Predicted tensor.
  61. target (Tensor): of shape (N, C, H, W). Ground truth tensor.
  62. weight (Tensor, optional): of shape (N, C, H, W). Element-wise
  63. weights. Default: None.
  64. """
  65. return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
  66. @LOSS_REGISTRY.register()
  67. class CharbonnierLoss(nn.Module):
  68. """Charbonnier loss (one variant of Robust L1Loss, a differentiable
  69. variant of L1Loss).
  70. Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
  71. Super-Resolution".
  72. Args:
  73. loss_weight (float): Loss weight for L1 loss. Default: 1.0.
  74. reduction (str): Specifies the reduction to apply to the output.
  75. Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
  76. eps (float): A value used to control the curvature near zero.
  77. Default: 1e-12.
  78. """
  79. def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
  80. super(CharbonnierLoss, self).__init__()
  81. if reduction not in ['none', 'mean', 'sum']:
  82. raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
  83. self.loss_weight = loss_weight
  84. self.reduction = reduction
  85. self.eps = eps
  86. def forward(self, pred, target, weight=None, **kwargs):
  87. """
  88. Args:
  89. pred (Tensor): of shape (N, C, H, W). Predicted tensor.
  90. target (Tensor): of shape (N, C, H, W). Ground truth tensor.
  91. weight (Tensor, optional): of shape (N, C, H, W). Element-wise
  92. weights. Default: None.
  93. """
  94. return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
  95. @LOSS_REGISTRY.register()
  96. class WeightedTVLoss(L1Loss):
  97. """Weighted TV loss.
  98. Args:
  99. loss_weight (float): Loss weight. Default: 1.0.
  100. """
  101. def __init__(self, loss_weight=1.0):
  102. super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
  103. def forward(self, pred, weight=None):
  104. y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=weight[:, :, :-1, :])
  105. x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=weight[:, :, :, :-1])
  106. loss = x_diff + y_diff
  107. return loss
  108. @LOSS_REGISTRY.register()
  109. class PerceptualLoss(nn.Module):
  110. """Perceptual loss with commonly used style loss.
  111. Args:
  112. layer_weights (dict): The weight for each layer of vgg feature.
  113. Here is an example: {'conv5_4': 1.}, which means the conv5_4
  114. feature layer (before relu5_4) will be extracted with weight
  115. 1.0 in calculting losses.
  116. vgg_type (str): The type of vgg network used as feature extractor.
  117. Default: 'vgg19'.
  118. use_input_norm (bool): If True, normalize the input image in vgg.
  119. Default: True.
  120. range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
  121. Default: False.
  122. perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
  123. loss will be calculated and the loss will multiplied by the
  124. weight. Default: 1.0.
  125. style_weight (float): If `style_weight > 0`, the style loss will be
  126. calculated and the loss will multiplied by the weight.
  127. Default: 0.
  128. criterion (str): Criterion used for perceptual loss. Default: 'l1'.
  129. """
  130. def __init__(self,
  131. layer_weights,
  132. vgg_type='vgg19',
  133. use_input_norm=True,
  134. range_norm=False,
  135. perceptual_weight=1.0,
  136. style_weight=0.,
  137. criterion='l1'):
  138. super(PerceptualLoss, self).__init__()
  139. self.perceptual_weight = perceptual_weight
  140. self.style_weight = style_weight
  141. self.layer_weights = layer_weights
  142. self.vgg = VGGFeatureExtractor(
  143. layer_name_list=list(layer_weights.keys()),
  144. vgg_type=vgg_type,
  145. use_input_norm=use_input_norm,
  146. range_norm=range_norm)
  147. self.criterion_type = criterion
  148. if self.criterion_type == 'l1':
  149. self.criterion = torch.nn.L1Loss()
  150. elif self.criterion_type == 'l2':
  151. self.criterion = torch.nn.L2loss()
  152. elif self.criterion_type == 'mse':
  153. self.criterion = torch.nn.MSELoss(reduction='mean')
  154. elif self.criterion_type == 'fro':
  155. self.criterion = None
  156. else:
  157. raise NotImplementedError(f'{criterion} criterion has not been supported.')
  158. def forward(self, x, gt):
  159. """Forward function.
  160. Args:
  161. x (Tensor): Input tensor with shape (n, c, h, w).
  162. gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
  163. Returns:
  164. Tensor: Forward results.
  165. """
  166. # extract vgg features
  167. x_features = self.vgg(x)
  168. gt_features = self.vgg(gt.detach())
  169. # calculate perceptual loss
  170. if self.perceptual_weight > 0:
  171. percep_loss = 0
  172. for k in x_features.keys():
  173. if self.criterion_type == 'fro':
  174. percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
  175. else:
  176. percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
  177. percep_loss *= self.perceptual_weight
  178. else:
  179. percep_loss = None
  180. # calculate style loss
  181. if self.style_weight > 0:
  182. style_loss = 0
  183. for k in x_features.keys():
  184. if self.criterion_type == 'fro':
  185. style_loss += torch.norm(
  186. self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
  187. else:
  188. style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
  189. gt_features[k])) * self.layer_weights[k]
  190. style_loss *= self.style_weight
  191. else:
  192. style_loss = None
  193. return percep_loss, style_loss
  194. def _gram_mat(self, x):
  195. """Calculate Gram matrix.
  196. Args:
  197. x (torch.Tensor): Tensor with shape of (n, c, h, w).
  198. Returns:
  199. torch.Tensor: Gram matrix.
  200. """
  201. n, c, h, w = x.size()
  202. features = x.view(n, c, w * h)
  203. features_t = features.transpose(1, 2)
  204. gram = features.bmm(features_t) / (c * h * w)
  205. return gram
  206. @LOSS_REGISTRY.register()
  207. class LPIPSLoss(nn.Module):
  208. def __init__(self,
  209. loss_weight=1.0,
  210. use_input_norm=True,
  211. range_norm=False,):
  212. super(LPIPSLoss, self).__init__()
  213. self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
  214. self.loss_weight = loss_weight
  215. self.use_input_norm = use_input_norm
  216. self.range_norm = range_norm
  217. if self.use_input_norm:
  218. # the mean is for image with range [0, 1]
  219. self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
  220. # the std is for image with range [0, 1]
  221. self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
  222. def forward(self, pred, target):
  223. if self.range_norm:
  224. pred = (pred + 1) / 2
  225. target = (target + 1) / 2
  226. if self.use_input_norm:
  227. pred = (pred - self.mean) / self.std
  228. target = (target - self.mean) / self.std
  229. lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
  230. return self.loss_weight * lpips_loss.mean()
  231. @LOSS_REGISTRY.register()
  232. class GANLoss(nn.Module):
  233. """Define GAN loss.
  234. Args:
  235. gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
  236. real_label_val (float): The value for real label. Default: 1.0.
  237. fake_label_val (float): The value for fake label. Default: 0.0.
  238. loss_weight (float): Loss weight. Default: 1.0.
  239. Note that loss_weight is only for generators; and it is always 1.0
  240. for discriminators.
  241. """
  242. def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
  243. super(GANLoss, self).__init__()
  244. self.gan_type = gan_type
  245. self.loss_weight = loss_weight
  246. self.real_label_val = real_label_val
  247. self.fake_label_val = fake_label_val
  248. if self.gan_type == 'vanilla':
  249. self.loss = nn.BCEWithLogitsLoss()
  250. elif self.gan_type == 'lsgan':
  251. self.loss = nn.MSELoss()
  252. elif self.gan_type == 'wgan':
  253. self.loss = self._wgan_loss
  254. elif self.gan_type == 'wgan_softplus':
  255. self.loss = self._wgan_softplus_loss
  256. elif self.gan_type == 'hinge':
  257. self.loss = nn.ReLU()
  258. else:
  259. raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
  260. def _wgan_loss(self, input, target):
  261. """wgan loss.
  262. Args:
  263. input (Tensor): Input tensor.
  264. target (bool): Target label.
  265. Returns:
  266. Tensor: wgan loss.
  267. """
  268. return -input.mean() if target else input.mean()
  269. def _wgan_softplus_loss(self, input, target):
  270. """wgan loss with soft plus. softplus is a smooth approximation to the
  271. ReLU function.
  272. In StyleGAN2, it is called:
  273. Logistic loss for discriminator;
  274. Non-saturating loss for generator.
  275. Args:
  276. input (Tensor): Input tensor.
  277. target (bool): Target label.
  278. Returns:
  279. Tensor: wgan loss.
  280. """
  281. return F.softplus(-input).mean() if target else F.softplus(input).mean()
  282. def get_target_label(self, input, target_is_real):
  283. """Get target label.
  284. Args:
  285. input (Tensor): Input tensor.
  286. target_is_real (bool): Whether the target is real or fake.
  287. Returns:
  288. (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
  289. return Tensor.
  290. """
  291. if self.gan_type in ['wgan', 'wgan_softplus']:
  292. return target_is_real
  293. target_val = (self.real_label_val if target_is_real else self.fake_label_val)
  294. return input.new_ones(input.size()) * target_val
  295. def forward(self, input, target_is_real, is_disc=False):
  296. """
  297. Args:
  298. input (Tensor): The input for the loss module, i.e., the network
  299. prediction.
  300. target_is_real (bool): Whether the targe is real or fake.
  301. is_disc (bool): Whether the loss for discriminators or not.
  302. Default: False.
  303. Returns:
  304. Tensor: GAN loss value.
  305. """
  306. if self.gan_type == 'hinge':
  307. if is_disc: # for discriminators in hinge-gan
  308. input = -input if target_is_real else input
  309. loss = self.loss(1 + input).mean()
  310. else: # for generators in hinge-gan
  311. loss = -input.mean()
  312. else: # other gan types
  313. target_label = self.get_target_label(input, target_is_real)
  314. loss = self.loss(input, target_label)
  315. # loss_weight is always 1.0 for discriminators
  316. return loss if is_disc else loss * self.loss_weight
  317. def r1_penalty(real_pred, real_img):
  318. """R1 regularization for discriminator. The core idea is to
  319. penalize the gradient on real data alone: when the
  320. generator distribution produces the true data distribution
  321. and the discriminator is equal to 0 on the data manifold, the
  322. gradient penalty ensures that the discriminator cannot create
  323. a non-zero gradient orthogonal to the data manifold without
  324. suffering a loss in the GAN game.
  325. Ref:
  326. Eq. 9 in Which training methods for GANs do actually converge.
  327. """
  328. grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
  329. grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
  330. return grad_penalty
  331. def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
  332. noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
  333. grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
  334. path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
  335. path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
  336. path_penalty = (path_lengths - path_mean).pow(2).mean()
  337. return path_penalty, path_lengths.detach().mean(), path_mean.detach()
  338. def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
  339. """Calculate gradient penalty for wgan-gp.
  340. Args:
  341. discriminator (nn.Module): Network for the discriminator.
  342. real_data (Tensor): Real input data.
  343. fake_data (Tensor): Fake input data.
  344. weight (Tensor): Weight tensor. Default: None.
  345. Returns:
  346. Tensor: A tensor for gradient penalty.
  347. """
  348. batch_size = real_data.size(0)
  349. alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
  350. # interpolate between real_data and fake_data
  351. interpolates = alpha * real_data + (1. - alpha) * fake_data
  352. interpolates = autograd.Variable(interpolates, requires_grad=True)
  353. disc_interpolates = discriminator(interpolates)
  354. gradients = autograd.grad(
  355. outputs=disc_interpolates,
  356. inputs=interpolates,
  357. grad_outputs=torch.ones_like(disc_interpolates),
  358. create_graph=True,
  359. retain_graph=True,
  360. only_inputs=True)[0]
  361. if weight is not None:
  362. gradients = gradients * weight
  363. gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
  364. if weight is not None:
  365. gradients_penalty /= torch.mean(weight)
  366. return gradients_penalty