detect_face_rotate.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. # このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします
  2. # (c) 2022 Kohya S. @kohya_ss
  3. # 横長の画像から顔検出して正立するように回転し、そこを中心に正方形に切り出す
  4. # v2: extract max face if multiple faces are found
  5. # v3: add crop_ratio option
  6. # v4: add multiple faces extraction and min/max size
  7. import argparse
  8. import math
  9. import cv2
  10. import glob
  11. import os
  12. from anime_face_detector import create_detector
  13. from tqdm import tqdm
  14. import numpy as np
  15. KP_REYE = 11
  16. KP_LEYE = 19
  17. SCORE_THRES = 0.90
  18. def detect_faces(detector, image, min_size):
  19. preds = detector(image) # bgr
  20. # print(len(preds))
  21. faces = []
  22. for pred in preds:
  23. bb = pred['bbox']
  24. score = bb[-1]
  25. if score < SCORE_THRES:
  26. continue
  27. left, top, right, bottom = bb[:4]
  28. cx = int((left + right) / 2)
  29. cy = int((top + bottom) / 2)
  30. fw = int(right - left)
  31. fh = int(bottom - top)
  32. lex, ley = pred['keypoints'][KP_LEYE, 0:2]
  33. rex, rey = pred['keypoints'][KP_REYE, 0:2]
  34. angle = math.atan2(ley - rey, lex - rex)
  35. angle = angle / math.pi * 180
  36. faces.append((cx, cy, fw, fh, angle))
  37. faces.sort(key=lambda x: max(x[2], x[3]), reverse=True) # 大きい順
  38. return faces
  39. def rotate_image(image, angle, cx, cy):
  40. h, w = image.shape[0:2]
  41. rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)
  42. # # 回転する分、すこし画像サイズを大きくする→とりあえず無効化
  43. # nh = max(h, int(w * math.sin(angle)))
  44. # nw = max(w, int(h * math.sin(angle)))
  45. # if nh > h or nw > w:
  46. # pad_y = nh - h
  47. # pad_t = pad_y // 2
  48. # pad_x = nw - w
  49. # pad_l = pad_x // 2
  50. # m = np.array([[0, 0, pad_l],
  51. # [0, 0, pad_t]])
  52. # rot_mat = rot_mat + m
  53. # h, w = nh, nw
  54. # cx += pad_l
  55. # cy += pad_t
  56. result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
  57. return result, cx, cy
  58. def process(args):
  59. assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't be specified both / resize_fitとresize_face_sizeはどちらか片方しか指定できません"
  60. assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません"
  61. # アニメ顔検出モデルを読み込む
  62. print("loading face detector.")
  63. detector = create_detector('yolov3')
  64. # cropの引数を解析する
  65. if args.crop_size is None:
  66. crop_width = crop_height = None
  67. else:
  68. tokens = args.crop_size.split(',')
  69. assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください"
  70. crop_width, crop_height = [int(t) for t in tokens]
  71. if args.crop_ratio is None:
  72. crop_h_ratio = crop_v_ratio = None
  73. else:
  74. tokens = args.crop_ratio.split(',')
  75. assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください"
  76. crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
  77. # 画像を処理する
  78. print("processing.")
  79. output_extension = ".png"
  80. os.makedirs(args.dst_dir, exist_ok=True)
  81. paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \
  82. glob.glob(os.path.join(args.src_dir, "*.webp"))
  83. for path in tqdm(paths):
  84. basename = os.path.splitext(os.path.basename(path))[0]
  85. # image = cv2.imread(path) # 日本語ファイル名でエラーになる
  86. image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED)
  87. if len(image.shape) == 2:
  88. image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
  89. if image.shape[2] == 4:
  90. print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
  91. image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい
  92. h, w = image.shape[:2]
  93. faces = detect_faces(detector, image, args.multiple_faces)
  94. for i, face in enumerate(faces):
  95. cx, cy, fw, fh, angle = face
  96. face_size = max(fw, fh)
  97. if args.min_size is not None and face_size < args.min_size:
  98. continue
  99. if args.max_size is not None and face_size >= args.max_size:
  100. continue
  101. face_suffix = f"_{i+1:02d}" if args.multiple_faces else ""
  102. # オプション指定があれば回転する
  103. face_img = image
  104. if args.rotate:
  105. face_img, cx, cy = rotate_image(face_img, angle, cx, cy)
  106. # オプション指定があれば顔を中心に切り出す
  107. if crop_width is not None or crop_h_ratio is not None:
  108. cur_crop_width, cur_crop_height = crop_width, crop_height
  109. if crop_h_ratio is not None:
  110. cur_crop_width = int(face_size * crop_h_ratio + .5)
  111. cur_crop_height = int(face_size * crop_v_ratio + .5)
  112. # リサイズを必要なら行う
  113. scale = 1.0
  114. if args.resize_face_size is not None:
  115. # 顔サイズを基準にリサイズする
  116. scale = args.resize_face_size / face_size
  117. if scale < cur_crop_width / w:
  118. print(
  119. f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
  120. scale = cur_crop_width / w
  121. if scale < cur_crop_height / h:
  122. print(
  123. f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
  124. scale = cur_crop_height / h
  125. elif crop_h_ratio is not None:
  126. # 倍率指定の時にはリサイズしない
  127. pass
  128. else:
  129. # 切り出しサイズ指定あり
  130. if w < cur_crop_width:
  131. print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
  132. scale = cur_crop_width / w
  133. if h < cur_crop_height:
  134. print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
  135. scale = cur_crop_height / h
  136. if args.resize_fit:
  137. scale = max(cur_crop_width / w, cur_crop_height / h)
  138. if scale != 1.0:
  139. w = int(w * scale + .5)
  140. h = int(h * scale + .5)
  141. face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4)
  142. cx = int(cx * scale + .5)
  143. cy = int(cy * scale + .5)
  144. fw = int(fw * scale + .5)
  145. fh = int(fh * scale + .5)
  146. cur_crop_width = min(cur_crop_width, face_img.shape[1])
  147. cur_crop_height = min(cur_crop_height, face_img.shape[0])
  148. x = cx - cur_crop_width // 2
  149. cx = cur_crop_width // 2
  150. if x < 0:
  151. cx = cx + x
  152. x = 0
  153. elif x + cur_crop_width > w:
  154. cx = cx + (x + cur_crop_width - w)
  155. x = w - cur_crop_width
  156. face_img = face_img[:, x:x+cur_crop_width]
  157. y = cy - cur_crop_height // 2
  158. cy = cur_crop_height // 2
  159. if y < 0:
  160. cy = cy + y
  161. y = 0
  162. elif y + cur_crop_height > h:
  163. cy = cy + (y + cur_crop_height - h)
  164. y = h - cur_crop_height
  165. face_img = face_img[y:y + cur_crop_height]
  166. # # debug
  167. # print(path, cx, cy, angle)
  168. # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
  169. # cv2.imshow("image", crp)
  170. # if cv2.waitKey() == 27:
  171. # break
  172. # cv2.destroyAllWindows()
  173. # debug
  174. if args.debug:
  175. cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20)
  176. _, buf = cv2.imencode(output_extension, face_img)
  177. with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f:
  178. buf.tofile(f)
  179. def setup_parser() -> argparse.ArgumentParser:
  180. parser = argparse.ArgumentParser()
  181. parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ")
  182. parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ")
  183. parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する")
  184. parser.add_argument("--resize_fit", action="store_true",
  185. help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする")
  186. parser.add_argument("--resize_face_size", type=int, default=None,
  187. help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする")
  188. parser.add_argument("--crop_size", type=str, default=None,
  189. help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す")
  190. parser.add_argument("--crop_ratio", type=str, default=None,
  191. help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す")
  192. parser.add_argument("--min_size", type=int, default=None,
  193. help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)")
  194. parser.add_argument("--max_size", type=int, default=None,
  195. help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)")
  196. parser.add_argument("--multiple_faces", action="store_true",
  197. help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す")
  198. parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
  199. return parser
  200. if __name__ == '__main__':
  201. parser = setup_parser()
  202. args = parser.parse_args()
  203. process(args)