processor.py 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922
from typing import Callable, Optional, Tuple

import cv2
import numpy as np

from annotator.util import HWC3
  5. def pad64(x):
  6. return int(np.ceil(float(x) / 64.0) * 64 - x)
  7. def safer_memory(x):
  8. # Fix many MAC/AMD problems
  9. return np.ascontiguousarray(x.copy()).copy()
  10. def resize_image_with_pad(input_image, resolution, skip_hwc3=False):
  11. if skip_hwc3:
  12. img = input_image
  13. else:
  14. img = HWC3(input_image)
  15. H_raw, W_raw, _ = img.shape
  16. k = float(resolution) / float(min(H_raw, W_raw))
  17. interpolation = cv2.INTER_CUBIC if k > 1 else cv2.INTER_AREA
  18. H_target = int(np.round(float(H_raw) * k))
  19. W_target = int(np.round(float(W_raw) * k))
  20. img = cv2.resize(img, (W_target, H_target), interpolation=interpolation)
  21. H_pad, W_pad = pad64(H_target), pad64(W_target)
  22. img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode='edge')
  23. def remove_pad(x):
  24. return safer_memory(x[:H_target, :W_target])
  25. return safer_memory(img_padded), remove_pad
  26. model_canny = None
  27. def canny(img, res=512, thr_a=100, thr_b=200, **kwargs):
  28. l, h = thr_a, thr_b
  29. img, remove_pad = resize_image_with_pad(img, res)
  30. global model_canny
  31. if model_canny is None:
  32. from annotator.canny import apply_canny
  33. model_canny = apply_canny
  34. result = model_canny(img, l, h)
  35. return remove_pad(result), True
  36. def scribble_thr(img, res=512, **kwargs):
  37. img, remove_pad = resize_image_with_pad(img, res)
  38. result = np.zeros_like(img, dtype=np.uint8)
  39. result[np.min(img, axis=2) < 127] = 255
  40. return remove_pad(result), True
  41. def scribble_xdog(img, res=512, thr_a=32, **kwargs):
  42. img, remove_pad = resize_image_with_pad(img, res)
  43. g1 = cv2.GaussianBlur(img.astype(np.float32), (0, 0), 0.5)
  44. g2 = cv2.GaussianBlur(img.astype(np.float32), (0, 0), 5.0)
  45. dog = (255 - np.min(g2 - g1, axis=2)).clip(0, 255).astype(np.uint8)
  46. result = np.zeros_like(img, dtype=np.uint8)
  47. result[2 * (255 - dog) > thr_a] = 255
  48. return remove_pad(result), True
  49. def tile_resample(img, res=512, thr_a=1.0, **kwargs):
  50. img = HWC3(img)
  51. if thr_a < 1.1:
  52. return img, True
  53. H, W, C = img.shape
  54. H = int(float(H) / float(thr_a))
  55. W = int(float(W) / float(thr_a))
  56. img = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
  57. return img, True
  58. def threshold(img, res=512, thr_a=127, **kwargs):
  59. img, remove_pad = resize_image_with_pad(img, res)
  60. result = np.zeros_like(img, dtype=np.uint8)
  61. result[np.min(img, axis=2) > thr_a] = 255
  62. return remove_pad(result), True
  63. def identity(img, **kwargs):
  64. return img, True
  65. def invert(img, res=512, **kwargs):
  66. return 255 - HWC3(img), True
  67. model_hed = None
  68. def hed(img, res=512, **kwargs):
  69. img, remove_pad = resize_image_with_pad(img, res)
  70. global model_hed
  71. if model_hed is None:
  72. from annotator.hed import apply_hed
  73. model_hed = apply_hed
  74. result = model_hed(img)
  75. return remove_pad(result), True
  76. def hed_safe(img, res=512, **kwargs):
  77. img, remove_pad = resize_image_with_pad(img, res)
  78. global model_hed
  79. if model_hed is None:
  80. from annotator.hed import apply_hed
  81. model_hed = apply_hed
  82. result = model_hed(img, is_safe=True)
  83. return remove_pad(result), True
  84. def unload_hed():
  85. global model_hed
  86. if model_hed is not None:
  87. from annotator.hed import unload_hed_model
  88. unload_hed_model()
  89. def scribble_hed(img, res=512, **kwargs):
  90. result, _ = hed(img, res)
  91. import cv2
  92. from annotator.util import nms
  93. result = nms(result, 127, 3.0)
  94. result = cv2.GaussianBlur(result, (0, 0), 3.0)
  95. result[result > 4] = 255
  96. result[result < 255] = 0
  97. return result, True
  98. model_mediapipe_face = None
  99. def mediapipe_face(img, res=512, thr_a: int = 10, thr_b: float = 0.5, **kwargs):
  100. max_faces = int(thr_a)
  101. min_confidence = thr_b
  102. img, remove_pad = resize_image_with_pad(img, res)
  103. global model_mediapipe_face
  104. if model_mediapipe_face is None:
  105. from annotator.mediapipe_face import apply_mediapipe_face
  106. model_mediapipe_face = apply_mediapipe_face
  107. result = model_mediapipe_face(img, max_faces=max_faces, min_confidence=min_confidence)
  108. return remove_pad(result), True
  109. model_mlsd = None
  110. def mlsd(img, res=512, thr_a=0.1, thr_b=0.1, **kwargs):
  111. thr_v, thr_d = thr_a, thr_b
  112. img, remove_pad = resize_image_with_pad(img, res)
  113. global model_mlsd
  114. if model_mlsd is None:
  115. from annotator.mlsd import apply_mlsd
  116. model_mlsd = apply_mlsd
  117. result = model_mlsd(img, thr_v, thr_d)
  118. return remove_pad(result), True
  119. def unload_mlsd():
  120. global model_mlsd
  121. if model_mlsd is not None:
  122. from annotator.mlsd import unload_mlsd_model
  123. unload_mlsd_model()
  124. model_midas = None
  125. def midas(img, res=512, a=np.pi * 2.0, **kwargs):
  126. img, remove_pad = resize_image_with_pad(img, res)
  127. global model_midas
  128. if model_midas is None:
  129. from annotator.midas import apply_midas
  130. model_midas = apply_midas
  131. result, _ = model_midas(img, a)
  132. return remove_pad(result), True
  133. def midas_normal(img, res=512, a=np.pi * 2.0, thr_a=0.4, **kwargs): # bg_th -> thr_a
  134. bg_th = thr_a
  135. img, remove_pad = resize_image_with_pad(img, res)
  136. global model_midas
  137. if model_midas is None:
  138. from annotator.midas import apply_midas
  139. model_midas = apply_midas
  140. _, result = model_midas(img, a, bg_th)
  141. return remove_pad(result), True
  142. def unload_midas():
  143. global model_midas
  144. if model_midas is not None:
  145. from annotator.midas import unload_midas_model
  146. unload_midas_model()
  147. model_leres = None
  148. def leres(img, res=512, a=np.pi * 2.0, thr_a=0, thr_b=0, boost=False, **kwargs):
  149. img, remove_pad = resize_image_with_pad(img, res)
  150. global model_leres
  151. if model_leres is None:
  152. from annotator.leres import apply_leres
  153. model_leres = apply_leres
  154. result = model_leres(img, thr_a, thr_b, boost=boost)
  155. return remove_pad(result), True
  156. def unload_leres():
  157. global model_leres
  158. if model_leres is not None:
  159. from annotator.leres import unload_leres_model
  160. unload_leres_model()
  161. class OpenposeModel(object):
  162. def __init__(self) -> None:
  163. self.model_openpose = None
  164. def run_model(
  165. self,
  166. img: np.ndarray,
  167. include_body: bool,
  168. include_hand: bool,
  169. include_face: bool,
  170. use_dw_pose: bool = False,
  171. json_pose_callback: Callable[[str], None] = None,
  172. res: int = 512,
  173. **kwargs # Ignore rest of kwargs
  174. ) -> Tuple[np.ndarray, bool]:
  175. """Run the openpose model. Returns a tuple of
  176. - result image
  177. - is_image flag
  178. The JSON format pose string is passed to `json_pose_callback`.
  179. """
  180. if json_pose_callback is None:
  181. json_pose_callback = lambda x: None
  182. img, remove_pad = resize_image_with_pad(img, res)
  183. if self.model_openpose is None:
  184. from annotator.openpose import OpenposeDetector
  185. self.model_openpose = OpenposeDetector()
  186. return remove_pad(self.model_openpose(
  187. img,
  188. include_body=include_body,
  189. include_hand=include_hand,
  190. include_face=include_face,
  191. use_dw_pose=use_dw_pose,
  192. json_pose_callback=json_pose_callback
  193. )), True
  194. def unload(self):
  195. if self.model_openpose is not None:
  196. self.model_openpose.unload_model()
  197. g_openpose_model = OpenposeModel()
  198. model_uniformer = None
  199. def uniformer(img, res=512, **kwargs):
  200. img, remove_pad = resize_image_with_pad(img, res)
  201. global model_uniformer
  202. if model_uniformer is None:
  203. from annotator.uniformer import apply_uniformer
  204. model_uniformer = apply_uniformer
  205. result = model_uniformer(img)
  206. return remove_pad(result), True
  207. def unload_uniformer():
  208. global model_uniformer
  209. if model_uniformer is not None:
  210. from annotator.uniformer import unload_uniformer_model
  211. unload_uniformer_model()
  212. model_pidinet = None
  213. def pidinet(img, res=512, **kwargs):
  214. img, remove_pad = resize_image_with_pad(img, res)
  215. global model_pidinet
  216. if model_pidinet is None:
  217. from annotator.pidinet import apply_pidinet
  218. model_pidinet = apply_pidinet
  219. result = model_pidinet(img)
  220. return remove_pad(result), True
  221. def pidinet_ts(img, res=512, **kwargs):
  222. img, remove_pad = resize_image_with_pad(img, res)
  223. global model_pidinet
  224. if model_pidinet is None:
  225. from annotator.pidinet import apply_pidinet
  226. model_pidinet = apply_pidinet
  227. result = model_pidinet(img, apply_fliter=True)
  228. return remove_pad(result), True
  229. def pidinet_safe(img, res=512, **kwargs):
  230. img, remove_pad = resize_image_with_pad(img, res)
  231. global model_pidinet
  232. if model_pidinet is None:
  233. from annotator.pidinet import apply_pidinet
  234. model_pidinet = apply_pidinet
  235. result = model_pidinet(img, is_safe=True)
  236. return remove_pad(result), True
  237. def scribble_pidinet(img, res=512, **kwargs):
  238. result, _ = pidinet(img, res)
  239. import cv2
  240. from annotator.util import nms
  241. result = nms(result, 127, 3.0)
  242. result = cv2.GaussianBlur(result, (0, 0), 3.0)
  243. result[result > 4] = 255
  244. result[result < 255] = 0
  245. return result, True
  246. def unload_pidinet():
  247. global model_pidinet
  248. if model_pidinet is not None:
  249. from annotator.pidinet import unload_pid_model
  250. unload_pid_model()
  251. clip_encoder = None
  252. def clip(img, res=512, **kwargs):
  253. img = HWC3(img)
  254. global clip_encoder
  255. if clip_encoder is None:
  256. from annotator.clip import apply_clip
  257. clip_encoder = apply_clip
  258. result = clip_encoder(img)
  259. return result, False
  260. def clip_vision_visualization(x):
  261. x = x.detach().cpu().numpy()[0]
  262. x = np.ascontiguousarray(x).copy()
  263. return np.ndarray((x.shape[0] * 4, x.shape[1]), dtype="uint8", buffer=x.tobytes())
  264. def unload_clip():
  265. global clip_encoder
  266. if clip_encoder is not None:
  267. from annotator.clip import unload_clip_model
  268. unload_clip_model()
  269. model_color = None
  270. def color(img, res=512, **kwargs):
  271. img = HWC3(img)
  272. global model_color
  273. if model_color is None:
  274. from annotator.color import apply_color
  275. model_color = apply_color
  276. result = model_color(img, res=res)
  277. return result, True
  278. def lineart_standard(img, res=512, **kwargs):
  279. img, remove_pad = resize_image_with_pad(img, res)
  280. x = img.astype(np.float32)
  281. g = cv2.GaussianBlur(x, (0, 0), 6.0)
  282. intensity = np.min(g - x, axis=2).clip(0, 255)
  283. intensity /= max(16, np.median(intensity[intensity > 8]))
  284. intensity *= 127
  285. result = intensity.clip(0, 255).astype(np.uint8)
  286. return remove_pad(result), True
  287. model_lineart = None
  288. def lineart(img, res=512, **kwargs):
  289. img, remove_pad = resize_image_with_pad(img, res)
  290. global model_lineart
  291. if model_lineart is None:
  292. from annotator.lineart import LineartDetector
  293. model_lineart = LineartDetector(LineartDetector.model_default)
  294. # applied auto inversion
  295. result = 255 - model_lineart(img)
  296. return remove_pad(result), True
  297. def unload_lineart():
  298. global model_lineart
  299. if model_lineart is not None:
  300. model_lineart.unload_model()
  301. model_lineart_coarse = None
  302. def lineart_coarse(img, res=512, **kwargs):
  303. img, remove_pad = resize_image_with_pad(img, res)
  304. global model_lineart_coarse
  305. if model_lineart_coarse is None:
  306. from annotator.lineart import LineartDetector
  307. model_lineart_coarse = LineartDetector(LineartDetector.model_coarse)
  308. # applied auto inversion
  309. result = 255 - model_lineart_coarse(img)
  310. return remove_pad(result), True
  311. def unload_lineart_coarse():
  312. global model_lineart_coarse
  313. if model_lineart_coarse is not None:
  314. model_lineart_coarse.unload_model()
  315. model_lineart_anime = None
  316. def lineart_anime(img, res=512, **kwargs):
  317. img, remove_pad = resize_image_with_pad(img, res)
  318. global model_lineart_anime
  319. if model_lineart_anime is None:
  320. from annotator.lineart_anime import LineartAnimeDetector
  321. model_lineart_anime = LineartAnimeDetector()
  322. # applied auto inversion
  323. result = 255 - model_lineart_anime(img)
  324. return remove_pad(result), True
  325. def unload_lineart_anime():
  326. global model_lineart_anime
  327. if model_lineart_anime is not None:
  328. model_lineart_anime.unload_model()
  329. model_manga_line = None
  330. def lineart_anime_denoise(img, res=512, **kwargs):
  331. img, remove_pad = resize_image_with_pad(img, res)
  332. global model_manga_line
  333. if model_manga_line is None:
  334. from annotator.manga_line import MangaLineExtration
  335. model_manga_line = MangaLineExtration()
  336. # applied auto inversion
  337. result = model_manga_line(img)
  338. return remove_pad(result), True
  339. def unload_lineart_anime_denoise():
  340. global model_manga_line
  341. if model_manga_line is not None:
  342. model_manga_line.unload_model()
  343. model_lama = None
  344. def lama_inpaint(img, res=512, **kwargs):
  345. H, W, C = img.shape
  346. raw_color = img[:, :, 0:3].copy()
  347. raw_mask = img[:, :, 3:4].copy()
  348. res = 256 # Always use 256 since lama is trained on 256
  349. img_res, remove_pad = resize_image_with_pad(img, res, skip_hwc3=True)
  350. global model_lama
  351. if model_lama is None:
  352. from annotator.lama import LamaInpainting
  353. model_lama = LamaInpainting()
  354. # applied auto inversion
  355. prd_color = model_lama(img_res)
  356. prd_color = remove_pad(prd_color)
  357. prd_color = cv2.resize(prd_color, (W, H))
  358. alpha = raw_mask.astype(np.float32) / 255.0
  359. fin_color = prd_color.astype(np.float32) * alpha + raw_color.astype(np.float32) * (1 - alpha)
  360. fin_color = fin_color.clip(0, 255).astype(np.uint8)
  361. result = np.concatenate([fin_color, raw_mask], axis=2)
  362. return result, True
  363. def unload_lama_inpaint():
  364. global model_lama
  365. if model_lama is not None:
  366. model_lama.unload_model()
  367. model_zoe_depth = None
  368. def zoe_depth(img, res=512, **kwargs):
  369. img, remove_pad = resize_image_with_pad(img, res)
  370. global model_zoe_depth
  371. if model_zoe_depth is None:
  372. from annotator.zoe import ZoeDetector
  373. model_zoe_depth = ZoeDetector()
  374. result = model_zoe_depth(img)
  375. return remove_pad(result), True
  376. def unload_zoe_depth():
  377. global model_zoe_depth
  378. if model_zoe_depth is not None:
  379. model_zoe_depth.unload_model()
  380. model_normal_bae = None
  381. def normal_bae(img, res=512, **kwargs):
  382. img, remove_pad = resize_image_with_pad(img, res)
  383. global model_normal_bae
  384. if model_normal_bae is None:
  385. from annotator.normalbae import NormalBaeDetector
  386. model_normal_bae = NormalBaeDetector()
  387. result = model_normal_bae(img)
  388. return remove_pad(result), True
  389. def unload_normal_bae():
  390. global model_normal_bae
  391. if model_normal_bae is not None:
  392. model_normal_bae.unload_model()
  393. model_oneformer_coco = None
  394. def oneformer_coco(img, res=512, **kwargs):
  395. img, remove_pad = resize_image_with_pad(img, res)
  396. global model_oneformer_coco
  397. if model_oneformer_coco is None:
  398. from annotator.oneformer import OneformerDetector
  399. model_oneformer_coco = OneformerDetector(OneformerDetector.configs["coco"])
  400. result = model_oneformer_coco(img)
  401. return remove_pad(result), True
  402. def unload_oneformer_coco():
  403. global model_oneformer_coco
  404. if model_oneformer_coco is not None:
  405. model_oneformer_coco.unload_model()
  406. model_oneformer_ade20k = None
  407. def oneformer_ade20k(img, res=512, **kwargs):
  408. img, remove_pad = resize_image_with_pad(img, res)
  409. global model_oneformer_ade20k
  410. if model_oneformer_ade20k is None:
  411. from annotator.oneformer import OneformerDetector
  412. model_oneformer_ade20k = OneformerDetector(OneformerDetector.configs["ade20k"])
  413. result = model_oneformer_ade20k(img)
  414. return remove_pad(result), True
  415. def unload_oneformer_ade20k():
  416. global model_oneformer_ade20k
  417. if model_oneformer_ade20k is not None:
  418. model_oneformer_ade20k.unload_model()
  419. model_shuffle = None
  420. def shuffle(img, res=512, **kwargs):
  421. img, remove_pad = resize_image_with_pad(img, res)
  422. img = remove_pad(img)
  423. global model_shuffle
  424. if model_shuffle is None:
  425. from annotator.shuffle import ContentShuffleDetector
  426. model_shuffle = ContentShuffleDetector()
  427. result = model_shuffle(img)
  428. return result, True
  429. model_free_preprocessors = [
  430. "reference_only",
  431. "reference_adain",
  432. "reference_adain+attn"
  433. ]
  434. flag_preprocessor_resolution = "Preprocessor Resolution"
  435. preprocessor_sliders_config = {
  436. "none": [],
  437. "inpaint": [],
  438. "inpaint_only": [],
  439. "canny": [
  440. {
  441. "name": flag_preprocessor_resolution,
  442. "value": 512,
  443. "min": 64,
  444. "max": 2048
  445. },
  446. {
  447. "name": "Canny Low Threshold",
  448. "value": 100,
  449. "min": 1,
  450. "max": 255
  451. },
  452. {
  453. "name": "Canny High Threshold",
  454. "value": 200,
  455. "min": 1,
  456. "max": 255
  457. },
  458. ],
  459. "mlsd": [
  460. {
  461. "name": flag_preprocessor_resolution,
  462. "min": 64,
  463. "max": 2048,
  464. "value": 512
  465. },
  466. {
  467. "name": "MLSD Value Threshold",
  468. "min": 0.01,
  469. "max": 2.0,
  470. "value": 0.1,
  471. "step": 0.01
  472. },
  473. {
  474. "name": "MLSD Distance Threshold",
  475. "min": 0.01,
  476. "max": 20.0,
  477. "value": 0.1,
  478. "step": 0.01
  479. }
  480. ],
  481. "hed": [
  482. {
  483. "name": flag_preprocessor_resolution,
  484. "min": 64,
  485. "max": 2048,
  486. "value": 512
  487. }
  488. ],
  489. "scribble_hed": [
  490. {
  491. "name": flag_preprocessor_resolution,
  492. "min": 64,
  493. "max": 2048,
  494. "value": 512
  495. }
  496. ],
  497. "hed_safe": [
  498. {
  499. "name": flag_preprocessor_resolution,
  500. "min": 64,
  501. "max": 2048,
  502. "value": 512
  503. }
  504. ],
  505. "openpose": [
  506. {
  507. "name": flag_preprocessor_resolution,
  508. "min": 64,
  509. "max": 2048,
  510. "value": 512
  511. }
  512. ],
  513. "openpose_full": [
  514. {
  515. "name": flag_preprocessor_resolution,
  516. "min": 64,
  517. "max": 2048,
  518. "value": 512
  519. }
  520. ],
  521. "dw_openpose_full": [
  522. {
  523. "name": flag_preprocessor_resolution,
  524. "min": 64,
  525. "max": 2048,
  526. "value": 512
  527. }
  528. ],
  529. "segmentation": [
  530. {
  531. "name": flag_preprocessor_resolution,
  532. "min": 64,
  533. "max": 2048,
  534. "value": 512
  535. }
  536. ],
  537. "depth": [
  538. {
  539. "name": flag_preprocessor_resolution,
  540. "min": 64,
  541. "max": 2048,
  542. "value": 512
  543. }
  544. ],
  545. "depth_leres": [
  546. {
  547. "name": flag_preprocessor_resolution,
  548. "min": 64,
  549. "max": 2048,
  550. "value": 512
  551. },
  552. {
  553. "name": "Remove Near %",
  554. "min": 0,
  555. "max": 100,
  556. "value": 0,
  557. "step": 0.1,
  558. },
  559. {
  560. "name": "Remove Background %",
  561. "min": 0,
  562. "max": 100,
  563. "value": 0,
  564. "step": 0.1,
  565. }
  566. ],
  567. "depth_leres++": [
  568. {
  569. "name": flag_preprocessor_resolution,
  570. "min": 64,
  571. "max": 2048,
  572. "value": 512
  573. },
  574. {
  575. "name": "Remove Near %",
  576. "min": 0,
  577. "max": 100,
  578. "value": 0,
  579. "step": 0.1,
  580. },
  581. {
  582. "name": "Remove Background %",
  583. "min": 0,
  584. "max": 100,
  585. "value": 0,
  586. "step": 0.1,
  587. }
  588. ],
  589. "normal_map": [
  590. {
  591. "name": flag_preprocessor_resolution,
  592. "min": 64,
  593. "max": 2048,
  594. "value": 512
  595. },
  596. {
  597. "name": "Normal Background Threshold",
  598. "min": 0.0,
  599. "max": 1.0,
  600. "value": 0.4,
  601. "step": 0.01
  602. }
  603. ],
  604. "threshold": [
  605. {
  606. "name": flag_preprocessor_resolution,
  607. "value": 512,
  608. "min": 64,
  609. "max": 2048
  610. },
  611. {
  612. "name": "Binarization Threshold",
  613. "min": 0,
  614. "max": 255,
  615. "value": 127
  616. }
  617. ],
  618. "scribble_xdog": [
  619. {
  620. "name": flag_preprocessor_resolution,
  621. "value": 512,
  622. "min": 64,
  623. "max": 2048
  624. },
  625. {
  626. "name": "XDoG Threshold",
  627. "min": 1,
  628. "max": 64,
  629. "value": 32,
  630. }
  631. ],
  632. "tile_resample": [
  633. None,
  634. {
  635. "name": "Down Sampling Rate",
  636. "value": 1.0,
  637. "min": 1.0,
  638. "max": 8.0,
  639. "step": 0.01
  640. }
  641. ],
  642. "tile_colorfix": [
  643. None,
  644. {
  645. "name": "Variation",
  646. "value": 8.0,
  647. "min": 3.0,
  648. "max": 32.0,
  649. "step": 1.0
  650. }
  651. ],
  652. "tile_colorfix+sharp": [
  653. None,
  654. {
  655. "name": "Variation",
  656. "value": 8.0,
  657. "min": 3.0,
  658. "max": 32.0,
  659. "step": 1.0
  660. },
  661. {
  662. "name": "Sharpness",
  663. "value": 1.0,
  664. "min": 0.0,
  665. "max": 2.0,
  666. "step": 0.01
  667. }
  668. ],
  669. "reference_only": [
  670. None,
  671. {
  672. "name": r'Style Fidelity (only for "Balanced" mode)',
  673. "value": 0.5,
  674. "min": 0.0,
  675. "max": 1.0,
  676. "step": 0.01
  677. }
  678. ],
  679. "reference_adain": [
  680. None,
  681. {
  682. "name": r'Style Fidelity (only for "Balanced" mode)',
  683. "value": 0.5,
  684. "min": 0.0,
  685. "max": 1.0,
  686. "step": 0.01
  687. }
  688. ],
  689. "reference_adain+attn": [
  690. None,
  691. {
  692. "name": r'Style Fidelity (only for "Balanced" mode)',
  693. "value": 0.5,
  694. "min": 0.0,
  695. "max": 1.0,
  696. "step": 0.01
  697. }
  698. ],
  699. "inpaint_only+lama": [],
  700. "color": [
  701. {
  702. "name": flag_preprocessor_resolution,
  703. "value": 512,
  704. "min": 64,
  705. "max": 2048,
  706. }
  707. ],
  708. "mediapipe_face": [
  709. {
  710. "name": flag_preprocessor_resolution,
  711. "value": 512,
  712. "min": 64,
  713. "max": 2048,
  714. },
  715. {
  716. "name": "Max Faces",
  717. "value": 1,
  718. "min": 1,
  719. "max": 10,
  720. "step": 1
  721. },
  722. {
  723. "name": "Min Face Confidence",
  724. "value": 0.5,
  725. "min": 0.01,
  726. "max": 1.0,
  727. "step": 0.01
  728. }
  729. ],
  730. }
  731. preprocessor_filters = {
  732. "All": "none",
  733. "Canny": "canny",
  734. "Depth": "depth_midas",
  735. "Normal": "normal_bae",
  736. "OpenPose": "openpose_full",
  737. "MLSD": "mlsd",
  738. "Lineart": "lineart_standard (from white bg & black line)",
  739. "SoftEdge": "softedge_pidinet",
  740. "Scribble": "scribble_pidinet",
  741. "Seg": "seg_ofade20k",
  742. "Shuffle": "shuffle",
  743. "Tile": "tile_resample",
  744. "Inpaint": "inpaint_only",
  745. "IP2P": "none",
  746. "Reference": "reference_only",
  747. "T2IA": "none",
  748. }