latent_upscaler.py

# Script for calling the upscaler easily from external code.
# The model definition is included here so that this file works standalone.
import argparse
import glob
import os
from typing import List

import cv2
import numpy as np
import torch
from diffusers import AutoencoderKL
from PIL import Image
from torch import nn
from tqdm import tqdm


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
        super(ResidualBlock, self).__init__()

        if out_channels is None:
            out_channels = in_channels

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)  # it might be better to apply this ReLU before adding the residual

        # initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += residual
        out = self.relu2(out)

        return out
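
# A quick shape check (a hedged sketch, not part of the original script): with
# 3x3 kernels, stride 1 and padding 1, a ResidualBlock preserves both the
# channel count and the spatial size, so any number of blocks can be chained
# without reshaping:
#   block = ResidualBlock(128)
#   y = block(torch.randn(1, 128, 64, 64))  # -> torch.Size([1, 128, 64, 64])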


class Upscaler(nn.Module):
    def __init__(self):
        super(Upscaler, self).__init__()

        # define layers
        # latent has 4 channels
        self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu1 = nn.ReLU(inplace=True)

        # resblocks
        # brute force, 20 of them: adding more blocks should widen the receptive field more than increasing the channel count would
        self.resblock1 = ResidualBlock(128)
        self.resblock2 = ResidualBlock(128)
        self.resblock3 = ResidualBlock(128)
        self.resblock4 = ResidualBlock(128)
        self.resblock5 = ResidualBlock(128)
        self.resblock6 = ResidualBlock(128)
        self.resblock7 = ResidualBlock(128)
        self.resblock8 = ResidualBlock(128)
        self.resblock9 = ResidualBlock(128)
        self.resblock10 = ResidualBlock(128)
        self.resblock11 = ResidualBlock(128)
        self.resblock12 = ResidualBlock(128)
        self.resblock13 = ResidualBlock(128)
        self.resblock14 = ResidualBlock(128)
        self.resblock15 = ResidualBlock(128)
        self.resblock16 = ResidualBlock(128)
        self.resblock17 = ResidualBlock(128)
        self.resblock18 = ResidualBlock(128)
        self.resblock19 = ResidualBlock(128)
        self.resblock20 = ResidualBlock(128)

        # last convs
        self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU(inplace=True)

        # final conv: output 4 channels
        self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

        # initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        # initialize final conv weights to 0: the currently popular "zero conv"
        nn.init.constant_(self.conv_final.weight, 0)
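
    # Note: with conv_final's weight set to 0 (its bias is already zeroed by the
    # loop above), the convolution stack initially outputs zeros, so forward()
    # starts out as an identity mapping on the latents and the network only has
    # to learn the residual correction.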
    def forward(self, x):
        inp = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        # adding the residual back after passing through several resblocks should improve accuracy and speed up training
        residual = x
        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.resblock3(x)
        x = self.resblock4(x)
        x = x + residual

        residual = x
        x = self.resblock5(x)
        x = self.resblock6(x)
        x = self.resblock7(x)
        x = self.resblock8(x)
        x = x + residual

        residual = x
        x = self.resblock9(x)
        x = self.resblock10(x)
        x = self.resblock11(x)
        x = self.resblock12(x)
        x = x + residual

        residual = x
        x = self.resblock13(x)
        x = self.resblock14(x)
        x = self.resblock15(x)
        x = self.resblock16(x)
        x = x + residual

        residual = x
        x = self.resblock17(x)
        x = self.resblock18(x)
        x = self.resblock19(x)
        x = self.resblock20(x)
        x = x + residual

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.conv3(x)
        x = self.bn3(x)
        # it seems better not to apply a ReLU here

        x = self.conv_final(x)

        # the network estimates the difference between the input and the output
        x = x + inp
        return x
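
    # Note that forward() does not change the spatial size: the actual 2x
    # upsampling happens on the RGB images via Lanczos in upscale() below, and
    # this network only refines the latents re-encoded from those images.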
    def support_latents(self) -> bool:
        # this model needs the low-resolution images themselves, not just their latents
        return False

    def upscale(
        self,
        vae: AutoencoderKL,
        lowreso_images: List[Image.Image],
        lowreso_latents: torch.Tensor,  # unused, since support_latents() is False
        dtype: torch.dtype,
        width: int,
        height: int,
        batch_size: int = 1,
        vae_batch_size: int = 1,
    ):
        # assertion
        assert lowreso_images is not None, "Upscaler requires lowreso image"

        # upsample the images to the target size with Lanczos
        upsampled_images = []
        for lowreso_image in lowreso_images:
            upsampled_image = np.array(lowreso_image.resize((width, height), Image.LANCZOS))
            upsampled_images.append(upsampled_image)

        # convert to tensor: this tensor may be too large to move to CUDA all at once, so keep it on the CPU
        upsampled_images = [torch.from_numpy(upsampled_image).permute(2, 0, 1).float() for upsampled_image in upsampled_images]
        upsampled_images = torch.stack(upsampled_images, dim=0)
        upsampled_images = upsampled_images.to(dtype)

        # normalize to [-1, 1]
        upsampled_images = upsampled_images / 127.5 - 1.0

        # convert the upsampled images to latents, batch by batch
        # print("Encoding upsampled (LANCZOS4) images...")
        upsampled_latents = []
        for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
            batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
            with torch.no_grad():
                batch = vae.encode(batch).latent_dist.sample()
            upsampled_latents.append(batch)
        upsampled_latents = torch.cat(upsampled_latents, dim=0)

        # upscale (refine) the latents with this model, batch by batch
        print("Upscaling latents...")
        upscaled_latents = []
        for i in range(0, upsampled_latents.shape[0], batch_size):
            with torch.no_grad():
                upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size]))
        upscaled_latents = torch.cat(upscaled_latents, dim=0)

        # 0.18215 is the latent scaling factor of the Stable Diffusion v1.x VAE
        return upscaled_latents * 0.18215


# external interface: returns a model
def create_upscaler(**kwargs):
    weights = kwargs["weights"]

    model = Upscaler()

    print(f"Loading weights from {weights}...")
    if os.path.splitext(weights)[1] == ".safetensors":
        from safetensors.torch import load_file

        sd = load_file(weights)
    else:
        sd = torch.load(weights, map_location=torch.device("cpu"))
    model.load_state_dict(sd)
    return model
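
# A minimal usage sketch (hedged; the paths below are hypothetical, and an
# SD v1.x VAE is assumed so that the 0.18215 scaling factor above is correct):
#   vae = AutoencoderKL.from_pretrained("/path/to/model", subfolder="vae").to("cuda", dtype=torch.float16)
#   upscaler = create_upscaler(weights="/path/to/upscaler.safetensors")
#   upscaler.eval().to("cuda", dtype=torch.float16)
#   latents = upscaler.upscale(vae, images, None, torch.float16, 1024, 1024)
#   rgb = vae.decode(latents / 0.18215).sample  # back to image space, in [-1, 1]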


# another interface: upscale images given on the command line with a model
def upscale_images(args: argparse.Namespace):
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    us_dtype = torch.float16  # TODO: support fp32/bf16
    os.makedirs(args.output_dir, exist_ok=True)

    # load VAE with Diffusers
    assert args.vae_path is not None, "VAE path is required"
    print(f"Loading VAE from {args.vae_path}...")
    vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
    vae.to(DEVICE, dtype=us_dtype)

    # prepare model
    print("Preparing model...")
    upscaler: Upscaler = create_upscaler(weights=args.weights)
    # print("Loading weights from", args.weights)
    # upscaler.load_state_dict(torch.load(args.weights))
    upscaler.eval()
    upscaler.to(DEVICE, dtype=us_dtype)

    # load images
    image_paths = glob.glob(args.image_pattern)
    images = []
    for image_path in image_paths:
        image = Image.open(image_path)
        image = image.convert("RGB")

        # make width and height divisible by 8 by cropping
        width = image.width
        height = image.height
        if width % 8 != 0:
            width = width - (width % 8)
        if height % 8 != 0:
            height = height - (height % 8)
        if width != image.width or height != image.height:
            image = image.crop((0, 0, width, height))

        images.append(image)

    # debug output: save the plain Lanczos-upsampled images for comparison
    if args.debug:
        for image, image_path in zip(images, image_paths):
            image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)

            basename = os.path.basename(image_path)
            basename_wo_ext, ext = os.path.splitext(basename)
            dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}")
            image_debug.save(dest_file_name)

    # upscale
    # NOTE: width and height here come from the last image in the loop above,
    # so all matched images are assumed to share the same (cropped) size
    print("Upscaling...")
    upscaled_latents = upscaler.upscale(
        vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
    )
    upscaled_latents /= 0.18215  # undo the scaling factor before decoding

    # decode with batch
    print("Decoding...")
    upscaled_images = []
    for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
        with torch.no_grad():
            batch = vae.decode(upscaled_latents[i : i + args.vae_batch_size]).sample
            batch = batch.to("cpu")
        upscaled_images.append(batch)
    upscaled_images = torch.cat(upscaled_images, dim=0)

    # tensor to numpy, then denormalize from [-1, 1] to [0, 255]
    upscaled_images = upscaled_images.permute(0, 2, 3, 1).numpy()
    upscaled_images = (upscaled_images + 1.0) * 127.5
    upscaled_images = upscaled_images.clip(0, 255).astype(np.uint8)
    upscaled_images = upscaled_images[..., ::-1]  # RGB -> BGR, because cv2.imwrite expects BGR

    # save images
    for i, image in enumerate(upscaled_images):
        basename = os.path.basename(image_paths[i])
        basename_wo_ext, ext = os.path.splitext(basename)
        dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
        cv2.imwrite(dest_file_name, image)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--vae_path", type=str, default=None, help="VAE path")
    parser.add_argument("--weights", type=str, default=None, help="Weights path")
    parser.add_argument("--image_pattern", type=str, default=None, help="Image pattern")
    parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size")
    parser.add_argument("--debug", action="store_true", help="Debug mode")

    args = parser.parse_args()
    upscale_images(args)
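
# Example invocation (hedged; the paths and weight file name are hypothetical):
#   python latent_upscaler.py --vae_path /path/to/sd-model --weights upscaler.safetensors \
#       --image_pattern "inputs/*.png" --output_dir outputs --batch_size 4 --vae_batch_size 1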