# inference_codeformer.py -- CodeFormer face restoration inference script.

  1. import os
  2. import cv2
  3. import argparse
  4. import glob
  5. import torch
  6. from torchvision.transforms.functional import normalize
  7. from basicsr.utils import imwrite, img2tensor, tensor2img
  8. from basicsr.utils.download_util import load_file_from_url
  9. from basicsr.utils.misc import gpu_is_available, get_device
  10. from facelib.utils.face_restoration_helper import FaceRestoreHelper
  11. from facelib.utils.misc import is_gray
  12. from basicsr.utils.registry import ARCH_REGISTRY
  13. pretrain_model_url = {
  14. 'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
  15. }
  16. def set_realesrgan():
  17. from basicsr.archs.rrdbnet_arch import RRDBNet
  18. from basicsr.utils.realesrgan_utils import RealESRGANer
  19. use_half = False
  20. if torch.cuda.is_available(): # set False in CPU/MPS mode
  21. no_half_gpu_list = ['1650', '1660'] # set False for GPUs that don't support f16
  22. if not True in [gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list]:
  23. use_half = True
  24. model = RRDBNet(
  25. num_in_ch=3,
  26. num_out_ch=3,
  27. num_feat=64,
  28. num_block=23,
  29. num_grow_ch=32,
  30. scale=2,
  31. )
  32. upsampler = RealESRGANer(
  33. scale=2,
  34. model_path="https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth",
  35. model=model,
  36. tile=args.bg_tile,
  37. tile_pad=40,
  38. pre_pad=0,
  39. half=use_half
  40. )
  41. if not gpu_is_available(): # CPU
  42. import warnings
  43. warnings.warn('Running on CPU now! Make sure your PyTorch version matches your CUDA.'
  44. 'The unoptimized RealESRGAN is slow on CPU. '
  45. 'If you want to disable it, please remove `--bg_upsampler` and `--face_upsample` in command.',
  46. category=RuntimeWarning)
  47. return upsampler
  48. if __name__ == '__main__':
  49. # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  50. device = get_device()
  51. parser = argparse.ArgumentParser()
  52. parser.add_argument('-i', '--input_path', type=str, default='./inputs/whole_imgs',
  53. help='Input image, video or folder. Default: inputs/whole_imgs')
  54. parser.add_argument('-o', '--output_path', type=str, default=None,
  55. help='Output folder. Default: results/<input_name>_<w>')
  56. parser.add_argument('-w', '--fidelity_weight', type=float, default=0.5,
  57. help='Balance the quality and fidelity. Default: 0.5')
  58. parser.add_argument('-s', '--upscale', type=int, default=2,
  59. help='The final upsampling scale of the image. Default: 2')
  60. parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces. Default: False')
  61. parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face. Default: False')
  62. parser.add_argument('--draw_box', action='store_true', help='Draw the bounding box for the detected faces. Default: False')
  63. # large det_model: 'YOLOv5l', 'retinaface_resnet50'
  64. # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
  65. parser.add_argument('--detection_model', type=str, default='retinaface_resnet50',
  66. help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib. \
  67. Default: retinaface_resnet50')
  68. parser.add_argument('--bg_upsampler', type=str, default='None', help='Background upsampler. Optional: realesrgan')
  69. parser.add_argument('--face_upsample', action='store_true', help='Face upsampler after enhancement. Default: False')
  70. parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler. Default: 400')
  71. parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces. Default: None')
  72. parser.add_argument('--save_video_fps', type=float, default=None, help='Frame rate for saving video. Default: None')
  73. args = parser.parse_args()
  74. # ------------------------ input & output ------------------------
  75. w = args.fidelity_weight
  76. input_video = False
  77. if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path
  78. input_img_list = [args.input_path]
  79. result_root = f'results/test_img_{w}'
  80. elif args.input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path
  81. from basicsr.utils.video_util import VideoReader, VideoWriter
  82. input_img_list = []
  83. vidreader = VideoReader(args.input_path)
  84. image = vidreader.get_frame()
  85. while image is not None:
  86. input_img_list.append(image)
  87. image = vidreader.get_frame()
  88. audio = vidreader.get_audio()
  89. fps = vidreader.get_fps() if args.save_video_fps is None else args.save_video_fps
  90. video_name = os.path.basename(args.input_path)[:-4]
  91. result_root = f'results/{video_name}_{w}'
  92. input_video = True
  93. vidreader.close()
  94. else: # input img folder
  95. if args.input_path.endswith('/'): # solve when path ends with /
  96. args.input_path = args.input_path[:-1]
  97. # scan all the jpg and png images
  98. input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]')))
  99. result_root = f'results/{os.path.basename(args.input_path)}_{w}'
  100. if not args.output_path is None: # set output path
  101. result_root = args.output_path
  102. test_img_num = len(input_img_list)
  103. if test_img_num == 0:
  104. raise FileNotFoundError('No input image/video is found...\n'
  105. '\tNote that --input_path for video should end with .mp4|.mov|.avi')
  106. # ------------------ set up background upsampler ------------------
  107. if args.bg_upsampler == 'realesrgan':
  108. bg_upsampler = set_realesrgan()
  109. else:
  110. bg_upsampler = None
  111. # ------------------ set up face upsampler ------------------
  112. if args.face_upsample:
  113. if bg_upsampler is not None:
  114. face_upsampler = bg_upsampler
  115. else:
  116. face_upsampler = set_realesrgan()
  117. else:
  118. face_upsampler = None
  119. # ------------------ set up CodeFormer restorer -------------------
  120. net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
  121. connect_list=['32', '64', '128', '256']).to(device)
  122. # ckpt_path = 'weights/CodeFormer/codeformer.pth'
  123. ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
  124. model_dir='weights/CodeFormer', progress=True, file_name=None)
  125. checkpoint = torch.load(ckpt_path)['params_ema']
  126. net.load_state_dict(checkpoint)
  127. net.eval()
  128. # ------------------ set up FaceRestoreHelper -------------------
  129. # large det_model: 'YOLOv5l', 'retinaface_resnet50'
  130. # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
  131. if not args.has_aligned:
  132. print(f'Face detection model: {args.detection_model}')
  133. if bg_upsampler is not None:
  134. print(f'Background upsampling: True, Face upsampling: {args.face_upsample}')
  135. else:
  136. print(f'Background upsampling: False, Face upsampling: {args.face_upsample}')
  137. face_helper = FaceRestoreHelper(
  138. args.upscale,
  139. face_size=512,
  140. crop_ratio=(1, 1),
  141. det_model = args.detection_model,
  142. save_ext='png',
  143. use_parse=True,
  144. device=device)
  145. # -------------------- start to processing ---------------------
  146. for i, img_path in enumerate(input_img_list):
  147. # clean all the intermediate results to process the next image
  148. face_helper.clean_all()
  149. if isinstance(img_path, str):
  150. img_name = os.path.basename(img_path)
  151. basename, ext = os.path.splitext(img_name)
  152. print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
  153. img = cv2.imread(img_path, cv2.IMREAD_COLOR)
  154. else: # for video processing
  155. basename = str(i).zfill(6)
  156. img_name = f'{video_name}_{basename}' if input_video else basename
  157. print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
  158. img = img_path
  159. if args.has_aligned:
  160. # the input faces are already cropped and aligned
  161. img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
  162. face_helper.is_gray = is_gray(img, threshold=10)
  163. if face_helper.is_gray:
  164. print('Grayscale input: True')
  165. face_helper.cropped_faces = [img]
  166. else:
  167. face_helper.read_image(img)
  168. # get face landmarks for each face
  169. num_det_faces = face_helper.get_face_landmarks_5(
  170. only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
  171. print(f'\tdetect {num_det_faces} faces')
  172. # align and warp each face
  173. face_helper.align_warp_face()
  174. # face restoration for each cropped face
  175. for idx, cropped_face in enumerate(face_helper.cropped_faces):
  176. # prepare data
  177. cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
  178. normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
  179. cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
  180. try:
  181. with torch.no_grad():
  182. output = net(cropped_face_t, w=w, adain=True)[0]
  183. restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
  184. del output
  185. torch.cuda.empty_cache()
  186. except Exception as error:
  187. print(f'\tFailed inference for CodeFormer: {error}')
  188. restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
  189. restored_face = restored_face.astype('uint8')
  190. face_helper.add_restored_face(restored_face, cropped_face)
  191. # paste_back
  192. if not args.has_aligned:
  193. # upsample the background
  194. if bg_upsampler is not None:
  195. # Now only support RealESRGAN for upsampling background
  196. bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0]
  197. else:
  198. bg_img = None
  199. face_helper.get_inverse_affine(None)
  200. # paste each restored face to the input image
  201. if args.face_upsample and face_upsampler is not None:
  202. restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box, face_upsampler=face_upsampler)
  203. else:
  204. restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box)
  205. # save faces
  206. for idx, (cropped_face, restored_face) in enumerate(zip(face_helper.cropped_faces, face_helper.restored_faces)):
  207. # save cropped face
  208. if not args.has_aligned:
  209. save_crop_path = os.path.join(result_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
  210. imwrite(cropped_face, save_crop_path)
  211. # save restored face
  212. if args.has_aligned:
  213. save_face_name = f'{basename}.png'
  214. else:
  215. save_face_name = f'{basename}_{idx:02d}.png'
  216. if args.suffix is not None:
  217. save_face_name = f'{save_face_name[:-4]}_{args.suffix}.png'
  218. save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name)
  219. imwrite(restored_face, save_restore_path)
  220. # save restored img
  221. if not args.has_aligned and restored_img is not None:
  222. if args.suffix is not None:
  223. basename = f'{basename}_{args.suffix}'
  224. save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
  225. imwrite(restored_img, save_restore_path)
  226. # save enhanced video
  227. if input_video:
  228. print('Video Saving...')
  229. # load images
  230. video_frames = []
  231. img_list = sorted(glob.glob(os.path.join(result_root, 'final_results', '*.[jp][pn]g')))
  232. for img_path in img_list:
  233. img = cv2.imread(img_path)
  234. video_frames.append(img)
  235. # write images to video
  236. height, width = video_frames[0].shape[:2]
  237. if args.suffix is not None:
  238. video_name = f'{video_name}_{args.suffix}.png'
  239. save_restore_path = os.path.join(result_root, f'{video_name}.mp4')
  240. vidwriter = VideoWriter(save_restore_path, height, width, fps, audio)
  241. for f in video_frames:
  242. vidwriter.write_frame(f)
  243. vidwriter.close()
  244. print(f'\nAll results are saved in {result_root}')