# Script for calling the upscaler easily from external code.
# The model definition is included so that this file works standalone.
import argparse
import glob
import os
from typing import List

import cv2
import numpy as np
import torch
from diffusers import AutoencoderKL
from PIL import Image
from torch import nn
from tqdm import tqdm


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
        super(ResidualBlock, self).__init__()

        if out_channels is None:
            out_channels = in_channels

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)  # this ReLU might be better applied before adding the residual

        # initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += residual
        out = self.relu2(out)

        return out


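# Shape sanity check (illustrative only, not executed by this script):
# a ResidualBlock preserves both the channel count and the spatial size.
#   blk = ResidualBlock(128)
#   out = blk(torch.randn(1, 128, 64, 64))  # -> torch.Size([1, 128, 64, 64])

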
class Upscaler(nn.Module):
    def __init__(self):
        super(Upscaler, self).__init__()

        # define layers
        # the latent has 4 channels
        self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu1 = nn.ReLU(inplace=True)

        # resblocks
        # brute force with 20 of them: adding more blocks should widen the receptive field more than increasing the channel count would
        self.resblock1 = ResidualBlock(128)
        self.resblock2 = ResidualBlock(128)
        self.resblock3 = ResidualBlock(128)
        self.resblock4 = ResidualBlock(128)
        self.resblock5 = ResidualBlock(128)
        self.resblock6 = ResidualBlock(128)
        self.resblock7 = ResidualBlock(128)
        self.resblock8 = ResidualBlock(128)
        self.resblock9 = ResidualBlock(128)
        self.resblock10 = ResidualBlock(128)
        self.resblock11 = ResidualBlock(128)
        self.resblock12 = ResidualBlock(128)
        self.resblock13 = ResidualBlock(128)
        self.resblock14 = ResidualBlock(128)
        self.resblock15 = ResidualBlock(128)
        self.resblock16 = ResidualBlock(128)
        self.resblock17 = ResidualBlock(128)
        self.resblock18 = ResidualBlock(128)
        self.resblock19 = ResidualBlock(128)
        self.resblock20 = ResidualBlock(128)

        # last convs
        self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU(inplace=True)

        # final conv: output 4 channels
        self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

        # initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        # initialize the final conv weights to 0: the now-popular "zero conv"
        nn.init.constant_(self.conv_final.weight, 0)

    def forward(self, x):
        inp = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        # adding the residual back after each group of resblocks should improve both accuracy and training speed
        residual = x
        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.resblock3(x)
        x = self.resblock4(x)
        x = x + residual

        residual = x
        x = self.resblock5(x)
        x = self.resblock6(x)
        x = self.resblock7(x)
        x = self.resblock8(x)
        x = x + residual

        residual = x
        x = self.resblock9(x)
        x = self.resblock10(x)
        x = self.resblock11(x)
        x = self.resblock12(x)
        x = x + residual

        residual = x
        x = self.resblock13(x)
        x = self.resblock14(x)
        x = self.resblock15(x)
        x = self.resblock16(x)
        x = x + residual

        residual = x
        x = self.resblock17(x)
        x = self.resblock18(x)
        x = self.resblock19(x)
        x = self.resblock20(x)
        x = x + residual

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.conv3(x)
        x = self.bn3(x)
        # it seems better not to apply a ReLU here

        x = self.conv_final(x)

        # the network estimates the difference between the input and the output, so add the input back
        x = x + inp

        return x

    def support_latents(self) -> bool:
        # this upscaler works from low-resolution PIL images (encoded after Lanczos upsampling), not from raw latents
        return False

    def upscale(
        self,
        vae: AutoencoderKL,
        lowreso_images: List[Image.Image],
        lowreso_latents: torch.Tensor,
        dtype: torch.dtype,
        width: int,
        height: int,
        batch_size: int = 1,
        vae_batch_size: int = 1,
    ):
        # assertion
        assert lowreso_images is not None, "Upscaler requires low-resolution images"

        # upsample the images to the target size with Lanczos
        upsampled_images = []
        for lowreso_image in lowreso_images:
            upsampled_image = np.array(lowreso_image.resize((width, height), Image.LANCZOS))
            upsampled_images.append(upsampled_image)

        # convert to tensor: keep it on the CPU, because the full batch may be too large to move to CUDA
        upsampled_images = [torch.from_numpy(upsampled_image).permute(2, 0, 1).float() for upsampled_image in upsampled_images]
        upsampled_images = torch.stack(upsampled_images, dim=0)
        upsampled_images = upsampled_images.to(dtype)

        # normalize to [-1, 1]
        upsampled_images = upsampled_images / 127.5 - 1.0

        # encode the upsampled images to latents with the VAE, in batches
        # print("Encoding upsampled (LANCZOS4) images...")
        upsampled_latents = []
        for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
            batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
            with torch.no_grad():
                batch = vae.encode(batch).latent_dist.sample()
            upsampled_latents.append(batch)
        upsampled_latents = torch.cat(upsampled_latents, dim=0)

        # upscale (refine) the latents with this model, in batches
        print("Upscaling latents...")
        upscaled_latents = []
        for i in range(0, upsampled_latents.shape[0], batch_size):
            with torch.no_grad():
                upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size]))
        upscaled_latents = torch.cat(upscaled_latents, dim=0)

        # apply the Stable Diffusion latent scaling factor
        return upscaled_latents * 0.18215


# external interface: returns a model
def create_upscaler(**kwargs):
    weights = kwargs["weights"]

    model = Upscaler()

    print(f"Loading weights from {weights}...")
    if os.path.splitext(weights)[1] == ".safetensors":
        from safetensors.torch import load_file

        sd = load_file(weights)
    else:
        sd = torch.load(weights, map_location=torch.device("cpu"))
    model.load_state_dict(sd)
    return model


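# Illustrative programmatic usage (a sketch, not executed by this script; the weights file name is an assumption):
#   upscaler = create_upscaler(weights="upscaler_weights.safetensors")
#   upscaler.to("cuda", dtype=torch.float16)
#   upscaler.eval()
#   latents = upscaler.upscale(vae, lowreso_images, None, torch.float16, width, height, batch_size=1, vae_batch_size=1)

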
# another interface: upscale images given on the command line with the model
def upscale_images(args: argparse.Namespace):
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    us_dtype = torch.float16  # TODO: support fp32/bf16
    os.makedirs(args.output_dir, exist_ok=True)

    # load the VAE with Diffusers
    assert args.vae_path is not None, "VAE path is required"
    print(f"Loading VAE from {args.vae_path}...")
    vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
    vae.to(DEVICE, dtype=us_dtype)

    # prepare the model
    print("Preparing model...")
    upscaler: Upscaler = create_upscaler(weights=args.weights)
    # print("Loading weights from", args.weights)
    # upscaler.load_state_dict(torch.load(args.weights))
    upscaler.eval()
    upscaler.to(DEVICE, dtype=us_dtype)

    # load images
    image_paths = glob.glob(args.image_pattern)
    images = []
    for image_path in image_paths:
        image = Image.open(image_path)
        image = image.convert("RGB")

        # crop so that width and height are divisible by 8
        width = image.width
        height = image.height
        if width % 8 != 0:
            width = width - (width % 8)
        if height % 8 != 0:
            height = height - (height % 8)
        if width != image.width or height != image.height:
            image = image.crop((0, 0, width, height))
        images.append(image)

    # debug output: also save the plain Lanczos x2 upsample for comparison
    if args.debug:
        for image, image_path in zip(images, image_paths):
            image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)

            basename = os.path.basename(image_path)
            basename_wo_ext, ext = os.path.splitext(basename)
            dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}")
            image_debug.save(dest_file_name)

    # upscale: width/height come from the last loaded image, so all input images are assumed to share the same size
    print("Upscaling...")
    upscaled_latents = upscaler.upscale(
        vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
    )
    upscaled_latents /= 0.18215

    # decode with the VAE, in batches
    print("Decoding...")
    upscaled_images = []
    for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
        with torch.no_grad():
            batch = vae.decode(upscaled_latents[i : i + args.vae_batch_size]).sample
        batch = batch.to("cpu")
        upscaled_images.append(batch)
    upscaled_images = torch.cat(upscaled_images, dim=0)

    # tensor to numpy
    upscaled_images = upscaled_images.permute(0, 2, 3, 1).numpy()
    upscaled_images = (upscaled_images + 1.0) * 127.5
    upscaled_images = upscaled_images.clip(0, 255).astype(np.uint8)

    upscaled_images = upscaled_images[..., ::-1]  # RGB -> BGR for OpenCV

    # save images
    for i, image in enumerate(upscaled_images):
        basename = os.path.basename(image_paths[i])
        basename_wo_ext, ext = os.path.splitext(basename)
        dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
        cv2.imwrite(dest_file_name, image)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--vae_path", type=str, default=None, help="Diffusers model path or hub id containing a 'vae' subfolder")
    parser.add_argument("--weights", type=str, default=None, help="Upscaler weights file (.safetensors or torch checkpoint)")
    parser.add_argument("--image_pattern", type=str, default=None, help="Glob pattern of input images")
    parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size for the upscaler model")
    parser.add_argument("--vae_batch_size", type=int, default=1, help="Batch size for VAE encoding/decoding")
    parser.add_argument("--debug", action="store_true", help="Debug mode: also save the Lanczos-only upsampled images")

    args = parser.parse_args()
    upscale_images(args)

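# Example invocation (the script name, paths, and file names below are illustrative only):
#   python latent_upscaler.py --vae_path /path/to/diffusers-sd-model --weights upscaler_weights.safetensors \
#       --image_pattern "inputs/*.png" --output_dir outputs --batch_size 4 --vae_batch_size 1
# --vae_path must point to a Diffusers model directory or hub id that contains a "vae" subfolder,
# since the VAE is loaded with subfolder="vae".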