# codeformer_arch.py

import math
import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional, List

from basicsr.archs.vqgan_arch import *
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY

def calc_mean_std(feat, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    feat_var = feat.view(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std

def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization.

    Adjust the reference features to have similar color and illumination
    to those of the degraded features.

    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degraded feature.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)

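# Usage sketch (illustrative only): re-color a reference feature map so its
# per-channel statistics match those of a degraded feature map. The tensor
# shapes below are example values, not taken from this file.
#
#   ref = torch.randn(1, 256, 16, 16)   # "content": reference features
#   deg = torch.randn(1, 256, 16, 16)   # "style": degraded features
#   out = adaptive_instance_normalization(ref, deg)
#   # `out` keeps the spatial structure of `ref` but the mean/std of `deg`.
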
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the "Attention Is All You Need" paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

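# Usage sketch (illustrative only): the embedding is built from the feature
# map's spatial grid, so the result has 2 * num_pos_feats channels. The shapes
# below are example values, not taken from this file.
#
#   pe = PositionEmbeddingSine(num_pos_feats=64, normalize=True)
#   feat = torch.randn(1, 256, 16, 16)
#   pos = pe(feat)                       # -> shape (1, 128, 16, 16)
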
def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")

class TransformerSALayer(nn.Module):
    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # implementation of the feedforward model (MLP)
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        # self-attention (pre-norm)
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # feedforward network (pre-norm)
        tgt2 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout2(tgt2)
        return tgt

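# Usage sketch (illustrative only): nn.MultiheadAttention defaults to
# batch_first=False, so inputs are laid out as (sequence, batch, embed_dim).
# The shapes below are example values, not taken from this file.
#
#   layer = TransformerSALayer(embed_dim=512, nhead=8, dim_mlp=1024)
#   tokens = torch.randn(256, 1, 512)    # e.g. a 16x16 latent grid, flattened
#   pos = torch.randn(256, 1, 512)       # positional embedding, same layout
#   out = layer(tokens, query_pos=pos)   # -> shape (256, 1, 512)
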
class Fuse_sft_block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.encode_enc = ResBlock(2 * in_ch, out_ch)

        self.scale = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

        self.shift = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

    def forward(self, enc_feat, dec_feat, w=1):
        enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
        scale = self.scale(enc_feat)
        shift = self.shift(enc_feat)
        residual = w * (dec_feat * scale + shift)
        out = dec_feat + residual
        return out

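# Usage sketch (illustrative only): `w` controls how strongly the encoder
# (fidelity) features modulate the decoder (quality) features; with w=0 the
# decoder features pass through unchanged. ResBlock comes from the star import
# of basicsr.archs.vqgan_arch above; the shapes below are example values.
#
#   fuse = Fuse_sft_block(in_ch=256, out_ch=256)
#   enc = torch.randn(1, 256, 64, 64)
#   dec = torch.randn(1, 256, 64, 64)
#   fused = fuse(enc, dec, w=0.5)        # -> shape (1, 256, 64, 64)
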
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
    """Transformer-based codebook prediction network built on the VQGAN
    autoencoder (VQAutoEncoder) for blind face restoration."""

    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                 codebook_size=1024, latent_size=256,
                 connect_list=['32', '64', '128', '256'],
                 fix_modules=['quantize', 'generator'], vqgan_path=None):
        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 2, [16], codebook_size)

        if vqgan_path is not None:
            self.load_state_dict(
                torch.load(vqgan_path, map_location='cpu')['params_ema'])

        if fix_modules is not None:
            # freeze the pretrained modules (e.g. quantizer and generator)
            for module in fix_modules:
                for param in getattr(self, module).parameters():
                    param.requires_grad = False

        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embd = dim_embd
        self.dim_mlp = dim_embd * 2

        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
                                         for _ in range(self.n_layers)])

        # logits_predict head
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd),
            nn.Linear(dim_embd, codebook_size, bias=False))

        self.channels = {
            '16': 512,
            '32': 256,
            '64': 256,
            '128': 128,
            '256': 128,
            '512': 64,
        }

        # after the second residual block for > 16, before the attn layer for == 16
        self.fuse_encoder_block = {'512': 2, '256': 5, '128': 8, '64': 11, '32': 14, '16': 18}
        # after the first residual block for > 16, before the attn layer for == 16
        self.fuse_generator_block = {'16': 6, '32': 9, '64': 12, '128': 15, '256': 18, '512': 21}

        # fuse_convs_dict
        self.fuse_convs_dict = nn.ModuleDict()
        for f_size in self.connect_list:
            in_ch = self.channels[f_size]
            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
        # ################### Encoder #####################
        enc_feat_dict = {}
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x)
            if i in out_list:
                enc_feat_dict[str(x.shape[-1])] = x.clone()

        lq_feat = x
        # ################# Transformer ###################
        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
        pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
        # BCHW -> BC(HW) -> (HW)BC
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
        query_emb = feat_emb
        # transformer encoder
        for layer in self.ft_layers:
            query_emb = layer(query_emb, query_pos=pos_emb)

        # output logits
        logits = self.idx_pred_layer(query_emb)  # (hw)bn
        logits = logits.permute(1, 0, 2)  # (hw)bn -> b(hw)n

        if code_only:  # for training stage II
            # logits do not need a softmax before the cross-entropy loss
            return logits, lq_feat

        # ################# Quantization ###################
        # if self.training:
        #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
        #     # b(hw)c -> bc(hw) -> bchw
        #     quant_feat = quant_feat.permute(0, 2, 1).view(lq_feat.shape)
        # ------------
        soft_one_hot = F.softmax(logits, dim=2)
        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0], 16, 16, 256])
        # preserve gradients
        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()

        if detach_16:
            quant_feat = quant_feat.detach()  # for training stage III
        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

        # ################## Generator ####################
        x = quant_feat
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.generator.blocks):
            x = block(x)
            if i in fuse_list:  # fuse after the i-th block
                f_size = str(x.shape[-1])
                if w > 0:
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
        out = x
        # logits do not need a softmax before the cross-entropy loss
        return out, logits, lq_feat
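

if __name__ == '__main__':
    # Smoke-test sketch (illustrative only). It assumes the `basicsr` package
    # that provides VQAutoEncoder/ResBlock above is importable; the 512x512
    # input size and the fidelity weight w=0.5 are example values, not taken
    # from this file.
    model = CodeFormer(dim_embd=512, n_head=8, n_layers=9,
                       codebook_size=1024, vqgan_path=None)
    model.eval()
    with torch.no_grad():
        img = torch.randn(1, 3, 512, 512)
        out, logits, lq_feat = model(img, w=0.5, adain=True)
    # under these assumptions the shapes should be:
    # out: (1, 3, 512, 512), logits: (1, 256, 1024), lq_feat: (1, 256, 16, 16)
    print(out.shape, logits.shape, lq_feat.shape)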