esrgan_model.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. import os
  2. import numpy as np
  3. import torch
  4. from PIL import Image
  5. from basicsr.utils.download_util import load_file_from_url
  6. import modules.esrgan_model_arch as arch
  7. from modules import shared, modelloader, images, devices
  8. from modules.upscaler import Upscaler, UpscalerData
  9. from modules.shared import opts
  10. def mod2normal(state_dict):
  11. # this code is copied from https://github.com/victorca25/iNNfer
  12. if 'conv_first.weight' in state_dict:
  13. crt_net = {}
  14. items = []
  15. for k, v in state_dict.items():
  16. items.append(k)
  17. crt_net['model.0.weight'] = state_dict['conv_first.weight']
  18. crt_net['model.0.bias'] = state_dict['conv_first.bias']
  19. for k in items.copy():
  20. if 'RDB' in k:
  21. ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
  22. if '.weight' in k:
  23. ori_k = ori_k.replace('.weight', '.0.weight')
  24. elif '.bias' in k:
  25. ori_k = ori_k.replace('.bias', '.0.bias')
  26. crt_net[ori_k] = state_dict[k]
  27. items.remove(k)
  28. crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
  29. crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
  30. crt_net['model.3.weight'] = state_dict['upconv1.weight']
  31. crt_net['model.3.bias'] = state_dict['upconv1.bias']
  32. crt_net['model.6.weight'] = state_dict['upconv2.weight']
  33. crt_net['model.6.bias'] = state_dict['upconv2.bias']
  34. crt_net['model.8.weight'] = state_dict['HRconv.weight']
  35. crt_net['model.8.bias'] = state_dict['HRconv.bias']
  36. crt_net['model.10.weight'] = state_dict['conv_last.weight']
  37. crt_net['model.10.bias'] = state_dict['conv_last.bias']
  38. state_dict = crt_net
  39. return state_dict
  40. def resrgan2normal(state_dict, nb=23):
  41. # this code is copied from https://github.com/victorca25/iNNfer
  42. if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
  43. re8x = 0
  44. crt_net = {}
  45. items = []
  46. for k, v in state_dict.items():
  47. items.append(k)
  48. crt_net['model.0.weight'] = state_dict['conv_first.weight']
  49. crt_net['model.0.bias'] = state_dict['conv_first.bias']
  50. for k in items.copy():
  51. if "rdb" in k:
  52. ori_k = k.replace('body.', 'model.1.sub.')
  53. ori_k = ori_k.replace('.rdb', '.RDB')
  54. if '.weight' in k:
  55. ori_k = ori_k.replace('.weight', '.0.weight')
  56. elif '.bias' in k:
  57. ori_k = ori_k.replace('.bias', '.0.bias')
  58. crt_net[ori_k] = state_dict[k]
  59. items.remove(k)
  60. crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
  61. crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
  62. crt_net['model.3.weight'] = state_dict['conv_up1.weight']
  63. crt_net['model.3.bias'] = state_dict['conv_up1.bias']
  64. crt_net['model.6.weight'] = state_dict['conv_up2.weight']
  65. crt_net['model.6.bias'] = state_dict['conv_up2.bias']
  66. if 'conv_up3.weight' in state_dict:
  67. # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
  68. re8x = 3
  69. crt_net['model.9.weight'] = state_dict['conv_up3.weight']
  70. crt_net['model.9.bias'] = state_dict['conv_up3.bias']
  71. crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
  72. crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
  73. crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
  74. crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
  75. state_dict = crt_net
  76. return state_dict
  77. def infer_params(state_dict):
  78. # this code is copied from https://github.com/victorca25/iNNfer
  79. scale2x = 0
  80. scalemin = 6
  81. n_uplayer = 0
  82. plus = False
  83. for block in list(state_dict):
  84. parts = block.split(".")
  85. n_parts = len(parts)
  86. if n_parts == 5 and parts[2] == "sub":
  87. nb = int(parts[3])
  88. elif n_parts == 3:
  89. part_num = int(parts[1])
  90. if (part_num > scalemin
  91. and parts[0] == "model"
  92. and parts[2] == "weight"):
  93. scale2x += 1
  94. if part_num > n_uplayer:
  95. n_uplayer = part_num
  96. out_nc = state_dict[block].shape[0]
  97. if not plus and "conv1x1" in block:
  98. plus = True
  99. nf = state_dict["model.0.weight"].shape[0]
  100. in_nc = state_dict["model.0.weight"].shape[1]
  101. out_nc = out_nc
  102. scale = 2 ** scale2x
  103. return in_nc, out_nc, nf, nb, plus, scale
  104. class UpscalerESRGAN(Upscaler):
  105. def __init__(self, dirname):
  106. self.name = "ESRGAN"
  107. self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
  108. self.model_name = "ESRGAN_4x"
  109. self.scalers = []
  110. self.user_path = dirname
  111. super().__init__()
  112. model_paths = self.find_models(ext_filter=[".pt", ".pth"])
  113. scalers = []
  114. if len(model_paths) == 0:
  115. scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
  116. scalers.append(scaler_data)
  117. for file in model_paths:
  118. if "http" in file:
  119. name = self.model_name
  120. else:
  121. name = modelloader.friendly_name(file)
  122. scaler_data = UpscalerData(name, file, self, 4)
  123. self.scalers.append(scaler_data)
  124. def do_upscale(self, img, selected_model):
  125. model = self.load_model(selected_model)
  126. if model is None:
  127. return img
  128. model.to(devices.device_esrgan)
  129. img = esrgan_upscale(model, img)
  130. return img
  131. def load_model(self, path: str):
  132. if "http" in path:
  133. filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
  134. file_name="%s.pth" % self.model_name,
  135. progress=True)
  136. else:
  137. filename = path
  138. if not os.path.exists(filename) or filename is None:
  139. print("Unable to load %s from %s" % (self.model_path, filename))
  140. return None
  141. state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
  142. if "params_ema" in state_dict:
  143. state_dict = state_dict["params_ema"]
  144. elif "params" in state_dict:
  145. state_dict = state_dict["params"]
  146. num_conv = 16 if "realesr-animevideov3" in filename else 32
  147. model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
  148. model.load_state_dict(state_dict)
  149. model.eval()
  150. return model
  151. if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
  152. nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
  153. state_dict = resrgan2normal(state_dict, nb)
  154. elif "conv_first.weight" in state_dict:
  155. state_dict = mod2normal(state_dict)
  156. elif "model.0.weight" not in state_dict:
  157. raise Exception("The file is not a recognized ESRGAN model.")
  158. in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
  159. model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
  160. model.load_state_dict(state_dict)
  161. model.eval()
  162. return model
  163. def upscale_without_tiling(model, img):
  164. img = np.array(img)
  165. img = img[:, :, ::-1]
  166. img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
  167. img = torch.from_numpy(img).float()
  168. img = img.unsqueeze(0).to(devices.device_esrgan)
  169. with torch.no_grad():
  170. output = model(img)
  171. output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
  172. output = 255. * np.moveaxis(output, 0, 2)
  173. output = output.astype(np.uint8)
  174. output = output[:, :, ::-1]
  175. return Image.fromarray(output, 'RGB')
  176. def esrgan_upscale(model, img):
  177. if opts.ESRGAN_tile == 0:
  178. return upscale_without_tiling(model, img)
  179. grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
  180. newtiles = []
  181. scale_factor = 1
  182. for y, h, row in grid.tiles:
  183. newrow = []
  184. for tiledata in row:
  185. x, w, tile = tiledata
  186. output = upscale_without_tiling(model, tile)
  187. scale_factor = output.width // tile.width
  188. newrow.append([x * scale_factor, w * scale_factor, output])
  189. newtiles.append([y * scale_factor, h * scale_factor, newrow])
  190. newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
  191. output = images.combine_grid(newgrid)
  192. return output