inference_colorization.py 4.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import os
  2. import cv2
  3. import argparse
  4. import glob
  5. import torch
  6. from torchvision.transforms.functional import normalize
  7. from basicsr.utils import imwrite, img2tensor, tensor2img
  8. from basicsr.utils.download_util import load_file_from_url
  9. from basicsr.utils.misc import get_device
  10. from basicsr.utils.registry import ARCH_REGISTRY
  11. pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer_colorization.pth'
  12. if __name__ == '__main__':
  13. # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  14. device = get_device()
  15. parser = argparse.ArgumentParser()
  16. parser.add_argument('-i', '--input_path', type=str, default='./inputs/gray_faces',
  17. help='Input image or folder. Default: inputs/gray_faces')
  18. parser.add_argument('-o', '--output_path', type=str, default=None,
  19. help='Output folder. Default: results/<input_name>')
  20. parser.add_argument('--suffix', type=str, default=None,
  21. help='Suffix of the restored faces. Default: None')
  22. args = parser.parse_args()
  23. # ------------------------ input & output ------------------------
  24. print('[NOTE] The input face images should be aligned and cropped to a resolution of 512x512.')
  25. if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path
  26. input_img_list = [args.input_path]
  27. result_root = f'results/test_colorization_img'
  28. else: # input img folder
  29. if args.input_path.endswith('/'): # solve when path ends with /
  30. args.input_path = args.input_path[:-1]
  31. # scan all the jpg and png images
  32. input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]')))
  33. result_root = f'results/{os.path.basename(args.input_path)}'
  34. if not args.output_path is None: # set output path
  35. result_root = args.output_path
  36. test_img_num = len(input_img_list)
  37. # ------------------ set up CodeFormer restorer -------------------
  38. net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
  39. connect_list=['32', '64', '128']).to(device)
  40. # ckpt_path = 'weights/CodeFormer/codeformer.pth'
  41. ckpt_path = load_file_from_url(url=pretrain_model_url,
  42. model_dir='weights/CodeFormer', progress=True, file_name=None)
  43. checkpoint = torch.load(ckpt_path)['params_ema']
  44. net.load_state_dict(checkpoint)
  45. net.eval()
  46. # -------------------- start to processing ---------------------
  47. for i, img_path in enumerate(input_img_list):
  48. img_name = os.path.basename(img_path)
  49. basename, ext = os.path.splitext(img_name)
  50. print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
  51. input_face = cv2.imread(img_path)
  52. assert input_face.shape[:2] == (512, 512), 'Input resolution must be 512x512 for colorization.'
  53. # input_face = cv2.resize(input_face, (512, 512), interpolation=cv2.INTER_LINEAR)
  54. input_face = img2tensor(input_face / 255., bgr2rgb=True, float32=True)
  55. normalize(input_face, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
  56. input_face = input_face.unsqueeze(0).to(device)
  57. try:
  58. with torch.no_grad():
  59. # w is fixed to 0 since we didn't train the Stage III for colorization
  60. output_face = net(input_face, w=0, adain=True)[0]
  61. save_face = tensor2img(output_face, rgb2bgr=True, min_max=(-1, 1))
  62. del output_face
  63. torch.cuda.empty_cache()
  64. except Exception as error:
  65. print(f'\tFailed inference for CodeFormer: {error}')
  66. save_face = tensor2img(input_face, rgb2bgr=True, min_max=(-1, 1))
  67. save_face = save_face.astype('uint8')
  68. # save face
  69. if args.suffix is not None:
  70. basename = f'{basename}_{args.suffix}'
  71. save_restore_path = os.path.join(result_root, f'{basename}.png')
  72. imwrite(save_face, save_restore_path)
  73. print(f'\nAll results are saved in {result_root}')