controlnet_ui_group.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977
  1. import gradio as gr
  2. import functools
  3. from typing import List, Optional, Union, Dict, Callable
  4. import numpy as np
  5. import base64
  6. from scripts.utils import svg_preprocess
  7. from scripts import (
  8. global_state,
  9. external_code,
  10. processor,
  11. batch_hijack,
  12. )
  13. from scripts.processor import (
  14. preprocessor_sliders_config,
  15. flag_preprocessor_resolution,
  16. model_free_preprocessors,
  17. preprocessor_filters,
  18. HWC3,
  19. )
  20. from scripts.logging import logger
  21. from scripts.controlnet_ui.openpose_editor import OpenposeEditor
  22. from modules import shared
  23. from modules.ui_components import FormRow
  24. class ToolButton(gr.Button, gr.components.FormComponent):
  25. """Small button with single emoji as text, fits inside gradio forms"""
  26. def __init__(self, **kwargs):
  27. super().__init__(variant="tool",
  28. elem_classes=kwargs.pop('elem_classes', []) + ["cnet-toolbutton"],
  29. **kwargs)
  30. def get_block_name(self):
  31. return "button"
  32. class UiControlNetUnit(external_code.ControlNetUnit):
  33. """The data class that stores all states of a ControlNetUnit."""
  34. def __init__(
  35. self,
  36. input_mode: batch_hijack.InputMode = batch_hijack.InputMode.SIMPLE,
  37. batch_images: Optional[Union[str, List[external_code.InputImage]]] = None,
  38. output_dir: str = "",
  39. loopback: bool = False,
  40. use_preview_as_input: bool = False,
  41. generated_image: Optional[np.ndarray] = None,
  42. enabled: bool = True,
  43. module: Optional[str] = None,
  44. model: Optional[str] = None,
  45. weight: float = 1.0,
  46. image: Optional[np.ndarray] = None,
  47. *args,
  48. **kwargs,
  49. ):
  50. if use_preview_as_input and generated_image is not None:
  51. input_image = generated_image
  52. module = "none"
  53. else:
  54. input_image = image
  55. super().__init__(enabled, module, model, weight, input_image, *args, **kwargs)
  56. self.is_ui = True
  57. self.input_mode = input_mode
  58. self.batch_images = batch_images
  59. self.output_dir = output_dir
  60. self.loopback = loopback
class ControlNetUiGroup(object):
    """Gradio UI for a single ControlNet unit plus its event wiring.

    The class-level attributes below are shared by every instance. Those
    initialised to ``None`` are filled in by `on_after_component` as the
    matching host components (identified by ``elem_id``) are reported.
    """

    # Note: Change symbol hints mapping in `javascript/hints.js` when you change the symbol values.
    refresh_symbol = "\U0001f504"  # 🔄
    switch_values_symbol = "\U000021C5"  # ⇅
    camera_symbol = "\U0001F4F7"  # 📷
    reverse_symbol = "\U000021C4"  # ⇄
    tossup_symbol = "\u2934"  # ⤴
    trigger_symbol = "\U0001F4A5"  # 💥
    open_symbol = "\U0001F4DD"  # 📝

    # Created once, when the class body is executed; `on_after_component`
    # calls `.render()` on it when it sees "img2img_batch_inpaint_mask_dir".
    global_batch_input_dir = gr.Textbox(
        label="Controlnet input directory",
        placeholder="Leave empty to use input directory",
        **shared.hide_dirs,
        elem_id="controlnet_batch_input_dir",
    )
    # img2img batch directory textboxes, plus the callbacks that units queue
    # while the textbox is not available yet (run by `on_after_component`).
    img2img_batch_input_dir = None
    img2img_batch_input_dir_callbacks = []
    img2img_batch_output_dir = None
    img2img_batch_output_dir_callbacks = []
    # Generate buttons of the two tabs.
    txt2img_submit_button = None
    img2img_submit_button = None
    # Slider controls from A1111 WebUI.
    txt2img_w_slider = None
    txt2img_h_slider = None
    img2img_w_slider = None
    img2img_h_slider = None
  87. def __init__(
  88. self,
  89. gradio_compat: bool,
  90. infotext_fields: List[str],
  91. default_unit: external_code.ControlNetUnit,
  92. preprocessors: List[Callable],
  93. ):
  94. self.gradio_compat = gradio_compat
  95. self.infotext_fields = infotext_fields
  96. self.default_unit = default_unit
  97. self.preprocessors = preprocessors
  98. self.webcam_enabled = False
  99. self.webcam_mirrored = False
  100. # Note: All gradio elements declared in `render` will be defined as member variable.
  101. self.upload_tab = None
  102. self.input_image = None
  103. self.generated_image_group = None
  104. self.generated_image = None
  105. self.batch_tab = None
  106. self.batch_image_dir = None
  107. self.create_canvas = None
  108. self.canvas_width = None
  109. self.canvas_height = None
  110. self.canvas_create_button = None
  111. self.canvas_cancel_button = None
  112. self.open_new_canvas_button = None
  113. self.webcam_enable = None
  114. self.webcam_mirror = None
  115. self.send_dimen_button = None
  116. self.enabled = None
  117. self.lowvram = None
  118. self.pixel_perfect = None
  119. self.preprocessor_preview = None
  120. self.type_filter = None
  121. self.module = None
  122. self.trigger_preprocessor = None
  123. self.model = None
  124. self.refresh_models = None
  125. self.weight = None
  126. self.guidance_start = None
  127. self.guidance_end = None
  128. self.advanced = None
  129. self.processor_res = None
  130. self.threshold_a = None
  131. self.threshold_b = None
  132. self.control_mode = None
  133. self.resize_mode = None
  134. self.loopback = None
  135. self.use_preview_as_input = None
  136. self.openpose_editor = None
    def render(self, tabname: str, elem_id_tabname: str) -> None:
        """The pure HTML structure of a single ControlNetUnit. Calling this
        function will populate `self` with all gradio element declared
        in local scope.

        NOTE(review): the nesting of the `with` blocks below was reconstructed
        from a source that had lost its indentation -- confirm against the
        rendered layout.

        Args:
            tabname: Second part of every ``elem_id`` built here
                (``f"{elem_id_tabname}_{tabname}_..."``).
            elem_id_tabname: Leading part of every ``elem_id`` built here.

        Returns:
            None
        """
        # --- Image input: "Single Image" / "Batch" tabs ---------------------
        with gr.Tabs():
            with gr.Tab(label="Single Image") as self.upload_tab:
                with gr.Row(elem_classes=["cnet-image-row"]).style(equal_height=True):
                    with gr.Group(elem_classes=["cnet-input-image-group"]):
                        self.input_image = gr.Image(
                            source="upload",
                            brush_radius=20,
                            mirror_webcam=False,
                            type="numpy",
                            tool="sketch",
                            elem_id=f"{elem_id_tabname}_{tabname}_input_image",
                            elem_classes=["cnet-image"],
                        )
                    # Preview group starts hidden; `register_shift_preview`
                    # toggles its visibility.
                    with gr.Group(
                        visible=False, elem_classes=["cnet-generated-image-group"]
                    ) as self.generated_image_group:
                        self.generated_image = gr.Image(
                            value=None,
                            label="Preprocessor Preview",
                            elem_id=f"{elem_id_tabname}_{tabname}_generated_image",
                            elem_classes=["cnet-image"], interactive=False
                        ).style(
                            height=242
                        )  # Gradio's magic number. Only 242 works.
                        with gr.Group(
                            elem_classes=["cnet-generated-image-control-group"]
                        ):
                            self.openpose_editor = OpenposeEditor()
                            # The "Close" link clicks the "Allow Preview"
                            # checkbox (created further down with this same
                            # elem_id) from the browser side.
                            preview_check_elem_id = f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_preview_checkbox"
                            preview_close_button_js = f"document.querySelector('#{preview_check_elem_id} input[type=\\'checkbox\\']').click();"
                            gr.HTML(
                                value=f"""<a title="Close Preview" onclick="{preview_close_button_js}">Close</a>""",
                                visible=True,
                                elem_classes=["cnet-close-preview"],
                            )
            with gr.Tab(label="Batch") as self.batch_tab:
                self.batch_image_dir = gr.Textbox(
                    label="Input Directory",
                    placeholder="Leave empty to use img2img batch controlnet input directory",
                    elem_id=f"{elem_id_tabname}_{tabname}_batch_image_dir",
                )

        # --- "Open New Canvas" accordion (hidden until requested) -----------
        with gr.Accordion(label="Open New Canvas", visible=False) as self.create_canvas:
            self.canvas_width = gr.Slider(
                label="New Canvas Width",
                minimum=256,
                maximum=1024,
                value=512,
                step=64,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_width",
            )
            self.canvas_height = gr.Slider(
                label="New Canvas Height",
                minimum=256,
                maximum=1024,
                value=512,
                step=64,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_height",
            )
            with gr.Row():
                self.canvas_create_button = gr.Button(
                    value="Create New Canvas",
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_create_button",
                )
                self.canvas_cancel_button = gr.Button(
                    value="Cancel",
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_cancel_button",
                )

        # --- Tool buttons under the image -----------------------------------
        with gr.Row(elem_classes="controlnet_image_controls"):
            gr.HTML(
                value="<p>Set the preprocessor to [invert] If your image has white background and black lines.</p>",
                elem_classes="controlnet_invert_warning",
            )
            self.open_new_canvas_button = ToolButton(
                value=ControlNetUiGroup.open_symbol,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_open_new_canvas_button",
            )
            self.webcam_enable = ToolButton(
                value=ControlNetUiGroup.camera_symbol,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_webcam_enable",
            )
            self.webcam_mirror = ToolButton(
                value=ControlNetUiGroup.reverse_symbol,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_webcam_mirror",
            )
            self.send_dimen_button = ToolButton(
                value=ControlNetUiGroup.tossup_symbol,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_send_dimen_button",
            )

        # --- Main option checkboxes (defaults from `self.default_unit`) -----
        with FormRow(elem_classes=["controlnet_main_options"]):
            self.enabled = gr.Checkbox(
                label="Enable",
                value=self.default_unit.enabled,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_enable_checkbox",
                elem_classes=['cnet-unit-enabled'],
            )
            self.lowvram = gr.Checkbox(
                label="Low VRAM",
                value=self.default_unit.low_vram,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_low_vram_checkbox",
            )
            self.pixel_perfect = gr.Checkbox(
                label="Pixel Perfect",
                value=self.default_unit.pixel_perfect,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_pixel_perfect_checkbox",
            )
            self.preprocessor_preview = gr.Checkbox(
                label="Allow Preview", value=False, elem_id=preview_check_elem_id
            )
            # Created hidden; `shift_preview` also keeps it hidden.
            self.use_preview_as_input = gr.Checkbox(
                label="Preview as Input",
                value=False,
                elem_classes=["cnet-preview-as-input"],
                visible=False,
            )

        # --- Control type filter --------------------------------------------
        # When the option is set, `self.type_filter` keeps the None assigned
        # in `__init__`; `register_build_sliders` checks for that.
        if not shared.opts.data.get("controlnet_disable_control_type", False):
            with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]):
                self.type_filter = gr.Radio(
                    list(preprocessor_filters.keys()),
                    label=f"Control Type",
                    value="All",
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_type_filter_radio",
                    elem_classes="controlnet_control_type_filter_group",
                )

        # --- Preprocessor and model selection -------------------------------
        with gr.Row(elem_classes=["controlnet_preprocessor_model", "controlnet_row"]):
            self.module = gr.Dropdown(
                global_state.ui_preprocessor_keys,
                label=f"Preprocessor",
                value=self.default_unit.module,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_dropdown",
            )
            self.trigger_preprocessor = ToolButton(
                value=ControlNetUiGroup.trigger_symbol,
                visible=True,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_trigger_preprocessor",
                elem_classes=['cnet-run-preprocessor'],
            )
            self.model = gr.Dropdown(
                list(global_state.cn_models.keys()),
                label=f"Model",
                value=self.default_unit.model,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_model_dropdown",
            )
            self.refresh_models = ToolButton(
                value=ControlNetUiGroup.refresh_symbol,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_refresh_models",
            )

        # --- Weight and guidance range --------------------------------------
        with gr.Row(elem_classes=["controlnet_weight_steps", "controlnet_row"]):
            self.weight = gr.Slider(
                label=f"Control Weight",
                value=self.default_unit.weight,
                minimum=0.0,
                maximum=2.0,
                step=0.05,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_control_weight_slider",
                elem_classes="controlnet_control_weight_slider",
            )
            self.guidance_start = gr.Slider(
                label="Starting Control Step",
                value=self.default_unit.guidance_start,
                minimum=0.0,
                maximum=1.0,
                interactive=True,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_start_control_step_slider",
                elem_classes="controlnet_start_control_step_slider",
            )
            self.guidance_end = gr.Slider(
                label="Ending Control Step",
                value=self.default_unit.guidance_end,
                minimum=0.0,
                maximum=1.0,
                interactive=True,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_ending_control_step_slider",
                elem_classes="controlnet_ending_control_step_slider",
            )

        # advanced options
        # All three sliders start hidden and non-interactive; the updates
        # returned by `build_sliders` relabel and reveal them.
        with gr.Column(visible=False) as self.advanced:
            self.processor_res = gr.Slider(
                label="Preprocessor resolution",
                value=self.default_unit.processor_res,
                minimum=64,
                maximum=2048,
                visible=False,
                interactive=False,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_resolution_slider",
            )
            self.threshold_a = gr.Slider(
                label="Threshold A",
                value=self.default_unit.threshold_a,
                minimum=64,
                maximum=1024,
                visible=False,
                interactive=False,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_threshold_A_slider",
            )
            self.threshold_b = gr.Slider(
                label="Threshold B",
                value=self.default_unit.threshold_b,
                minimum=64,
                maximum=1024,
                visible=False,
                interactive=False,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_threshold_B_slider",
            )

        # --- Control / resize mode and loopback -----------------------------
        self.control_mode = gr.Radio(
            choices=[e.value for e in external_code.ControlMode],
            value=self.default_unit.control_mode.value,
            label="Control Mode",
            elem_id=f"{elem_id_tabname}_{tabname}_controlnet_control_mode_radio",
            elem_classes="controlnet_control_mode_radio",
        )
        self.resize_mode = gr.Radio(
            choices=[e.value for e in external_code.ResizeMode],
            value=self.default_unit.resize_mode.value,
            label="Resize Mode",
            elem_id=f"{elem_id_tabname}_{tabname}_controlnet_resize_mode_radio",
            elem_classes="controlnet_resize_mode_radio",
        )
        self.loopback = gr.Checkbox(
            label="[Loopback] Automatically send generated images to this ControlNet unit",
            value=self.default_unit.loopback,
            elem_id=f"{elem_id_tabname}_{tabname}_controlnet_automatically_send_generated_images_checkbox",
            elem_classes="controlnet_loopback_checkbox",
        )
  370. def register_send_dimensions(self, is_img2img: bool):
  371. """Register event handler for send dimension button."""
  372. def send_dimensions(image):
  373. def closesteight(num):
  374. rem = num % 8
  375. if rem <= 4:
  376. return round(num - rem)
  377. else:
  378. return round(num + (8 - rem))
  379. if image:
  380. interm = np.asarray(image.get("image"))
  381. return closesteight(interm.shape[1]), closesteight(interm.shape[0])
  382. else:
  383. return gr.Slider.update(), gr.Slider.update()
  384. outputs = (
  385. [
  386. ControlNetUiGroup.img2img_w_slider,
  387. ControlNetUiGroup.img2img_h_slider,
  388. ]
  389. if is_img2img
  390. else [
  391. ControlNetUiGroup.txt2img_w_slider,
  392. ControlNetUiGroup.txt2img_h_slider,
  393. ]
  394. )
  395. self.send_dimen_button.click(
  396. fn=send_dimensions,
  397. inputs=[self.input_image],
  398. outputs=outputs,
  399. )
  400. def register_webcam_toggle(self):
  401. def webcam_toggle():
  402. self.webcam_enabled = not self.webcam_enabled
  403. return {
  404. "value": None,
  405. "source": "webcam" if self.webcam_enabled else "upload",
  406. "__type__": "update",
  407. }
  408. self.webcam_enable.click(webcam_toggle, inputs=None, outputs=self.input_image)
  409. def register_webcam_mirror_toggle(self):
  410. def webcam_mirror_toggle():
  411. self.webcam_mirrored = not self.webcam_mirrored
  412. return {"mirror_webcam": self.webcam_mirrored, "__type__": "update"}
  413. self.webcam_mirror.click(
  414. webcam_mirror_toggle, inputs=None, outputs=self.input_image
  415. )
  416. def register_refresh_all_models(self):
  417. def refresh_all_models(*inputs):
  418. global_state.update_cn_models()
  419. dd = inputs[0]
  420. selected = dd if dd in global_state.cn_models else "None"
  421. return gr.Dropdown.update(
  422. value=selected, choices=list(global_state.cn_models.keys())
  423. )
  424. self.refresh_models.click(refresh_all_models, self.model, self.model)
  425. def register_build_sliders(self):
  426. if not self.gradio_compat:
  427. return
  428. def build_sliders(module, pp):
  429. grs = []
  430. module = global_state.get_module_basename(module)
  431. if module not in preprocessor_sliders_config:
  432. grs += [
  433. gr.update(
  434. label=flag_preprocessor_resolution,
  435. value=512,
  436. minimum=64,
  437. maximum=2048,
  438. step=1,
  439. visible=not pp,
  440. interactive=not pp,
  441. ),
  442. gr.update(visible=False, interactive=False),
  443. gr.update(visible=False, interactive=False),
  444. gr.update(visible=True),
  445. ]
  446. else:
  447. for slider_config in preprocessor_sliders_config[module]:
  448. if isinstance(slider_config, dict):
  449. visible = True
  450. if slider_config["name"] == flag_preprocessor_resolution:
  451. visible = not pp
  452. grs.append(
  453. gr.update(
  454. label=slider_config["name"],
  455. value=slider_config["value"],
  456. minimum=slider_config["min"],
  457. maximum=slider_config["max"],
  458. step=slider_config["step"]
  459. if "step" in slider_config
  460. else 1,
  461. visible=visible,
  462. interactive=visible,
  463. )
  464. )
  465. else:
  466. grs.append(gr.update(visible=False, interactive=False))
  467. while len(grs) < 3:
  468. grs.append(gr.update(visible=False, interactive=False))
  469. grs.append(gr.update(visible=True))
  470. if module in model_free_preprocessors:
  471. grs += [
  472. gr.update(visible=False, value="None"),
  473. gr.update(visible=False),
  474. ]
  475. else:
  476. grs += [gr.update(visible=True), gr.update(visible=True)]
  477. return grs
  478. inputs = [self.module, self.pixel_perfect]
  479. outputs = [
  480. self.processor_res,
  481. self.threshold_a,
  482. self.threshold_b,
  483. self.advanced,
  484. self.model,
  485. self.refresh_models,
  486. ]
  487. self.module.change(build_sliders, inputs=inputs, outputs=outputs)
  488. self.pixel_perfect.change(build_sliders, inputs=inputs, outputs=outputs)
  489. if self.type_filter is not None:
  490. def filter_selected(k, pp):
  491. (
  492. filtered_preprocessor_list,
  493. filtered_model_list,
  494. default_option,
  495. default_model
  496. ) = global_state.select_control_type(k)
  497. return [
  498. gr.Dropdown.update(value=default_option, choices=filtered_preprocessor_list),
  499. gr.Dropdown.update(value=default_model, choices=filtered_model_list),
  500. ] + build_sliders(default_option, pp)
  501. self.type_filter.change(
  502. filter_selected,
  503. inputs=[self.type_filter, self.pixel_perfect],
  504. outputs=[self.module, self.model, *outputs],
  505. )
  506. def register_run_annotator(self, is_img2img: bool):
  507. def run_annotator(image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm):
  508. if image is None:
  509. return (
  510. gr.update(value=None, visible=True),
  511. gr.update(),
  512. *self.openpose_editor.update(''),
  513. )
  514. img = HWC3(image["image"])
  515. has_mask = not (
  516. (image["mask"][:, :, 0] <= 5).all()
  517. or (image["mask"][:, :, 0] >= 250).all()
  518. )
  519. if "inpaint" in module:
  520. color = HWC3(image["image"])
  521. alpha = image["mask"][:, :, 0:1]
  522. img = np.concatenate([color, alpha], axis=2)
  523. elif has_mask and not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False):
  524. img = HWC3(image["mask"][:, :, 0])
  525. module = global_state.get_module_basename(module)
  526. preprocessor = self.preprocessors[module]
  527. if pp:
  528. pres = external_code.pixel_perfect_resolution(
  529. img,
  530. target_H=t2i_h,
  531. target_W=t2i_w,
  532. resize_mode=external_code.resize_mode_from_value(rm),
  533. )
  534. class JsonAcceptor:
  535. def __init__(self) -> None:
  536. self.value = ""
  537. def accept(self, json_string: str) -> None:
  538. self.value = json_string
  539. json_acceptor = JsonAcceptor()
  540. logger.info(f"Preview Resolution = {pres}")
  541. def is_openpose(module: str):
  542. return "openpose" in module
  543. # Only openpose preprocessor returns a JSON output, pass json_acceptor
  544. # only when a JSON output is expected. This will make preprocessor cache
  545. # work for all other preprocessors other than openpose ones. JSON acceptor
  546. # instance are different every call, which means cache will never take
  547. # effect.
  548. # TODO: Maybe we should let `preprocessor` return a Dict to alleviate this issue?
  549. # This requires changing all callsites though.
  550. result, is_image = preprocessor(
  551. img,
  552. res=pres,
  553. thr_a=pthr_a,
  554. thr_b=pthr_b,
  555. json_pose_callback=json_acceptor.accept
  556. if is_openpose(module)
  557. else None,
  558. )
  559. if "clip" in module:
  560. result = processor.clip_vision_visualization(result)
  561. is_image = True
  562. if is_image:
  563. result = external_code.visualize_inpaint_mask(result)
  564. return (
  565. # Update to `generated_image`
  566. gr.update(value=result, visible=True, interactive=False),
  567. # preprocessor_preview
  568. gr.update(value=True),
  569. # openpose editor
  570. *self.openpose_editor.update(json_acceptor.value),
  571. )
  572. return (
  573. # Update to `generated_image`
  574. gr.update(value=None, visible=True),
  575. # preprocessor_preview
  576. gr.update(value=True),
  577. # openpose editor
  578. *self.openpose_editor.update(json_acceptor.value),
  579. )
  580. self.trigger_preprocessor.click(
  581. fn=run_annotator,
  582. inputs=[
  583. self.input_image,
  584. self.module,
  585. self.processor_res,
  586. self.threshold_a,
  587. self.threshold_b,
  588. ControlNetUiGroup.img2img_w_slider
  589. if is_img2img
  590. else ControlNetUiGroup.txt2img_w_slider,
  591. ControlNetUiGroup.img2img_h_slider
  592. if is_img2img
  593. else ControlNetUiGroup.txt2img_h_slider,
  594. self.pixel_perfect,
  595. self.resize_mode,
  596. ],
  597. outputs=[
  598. self.generated_image,
  599. self.preprocessor_preview,
  600. *self.openpose_editor.outputs(),
  601. ],
  602. )
  603. def register_shift_preview(self):
  604. def shift_preview(is_on):
  605. return (
  606. # generated_image
  607. gr.update() if is_on else gr.update(value=None),
  608. # generated_image_group
  609. gr.update(visible=is_on),
  610. # use_preview_as_input,
  611. gr.update(visible=False), # Now this is automatically managed
  612. # download_pose_link
  613. gr.update() if is_on else gr.update(value=None),
  614. # modal edit button
  615. gr.update() if is_on else gr.update(visible=False),
  616. )
  617. self.preprocessor_preview.change(
  618. fn=shift_preview,
  619. inputs=[self.preprocessor_preview],
  620. outputs=[
  621. self.generated_image,
  622. self.generated_image_group,
  623. self.use_preview_as_input,
  624. self.openpose_editor.download_link,
  625. self.openpose_editor.modal,
  626. ],
  627. )
  628. def register_create_canvas(self):
  629. self.open_new_canvas_button.click(
  630. lambda: gr.Accordion.update(visible=True),
  631. inputs=None,
  632. outputs=self.create_canvas,
  633. )
  634. self.canvas_cancel_button.click(
  635. lambda: gr.Accordion.update(visible=False),
  636. inputs=None,
  637. outputs=self.create_canvas,
  638. )
  639. def fn_canvas(h, w):
  640. return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255, gr.Accordion.update(
  641. visible=False
  642. )
  643. self.canvas_create_button.click(
  644. fn=fn_canvas,
  645. inputs=[self.canvas_height, self.canvas_width],
  646. outputs=[self.input_image, self.create_canvas],
  647. )
  648. def register_callbacks(self, is_img2img: bool):
  649. """Register callbacks on the UI elements.
  650. Args:
  651. is_img2img: Whether ControlNet is under img2img. False when in txt2img mode.
  652. Returns:
  653. None
  654. """
  655. self.register_send_dimensions(is_img2img)
  656. self.register_webcam_toggle()
  657. self.register_webcam_mirror_toggle()
  658. self.register_refresh_all_models()
  659. self.register_build_sliders()
  660. self.register_run_annotator(is_img2img)
  661. self.register_shift_preview()
  662. self.register_create_canvas()
  663. self.openpose_editor.register_callbacks(
  664. self.generated_image, self.use_preview_as_input
  665. )
  666. def register_modules(
  667. self, tabname: str, enabled, module, model, weight, guidance_start, guidance_end
  668. ):
  669. self.infotext_fields.extend(
  670. [
  671. (enabled, f"{tabname} Enabled"),
  672. (module, f"{tabname} Preprocessor"),
  673. (model, f"{tabname} Model"),
  674. (weight, f"{tabname} Weight"),
  675. (guidance_start, f"{tabname} Guidance Start"),
  676. (guidance_end, f"{tabname} Guidance End"),
  677. ]
  678. )
  679. def render_and_register_unit(self, tabname: str, is_img2img: bool):
  680. """Render the invisible states elements for misc persistent
  681. purposes. Register callbacks on loading/unloading the controlnet
  682. unit and handle batch processes.
  683. Args:
  684. tabname:
  685. is_img2img:
  686. Returns:
  687. The data class "ControlNetUnit" representing this ControlNetUnit.
  688. """
  689. input_mode = gr.State(batch_hijack.InputMode.SIMPLE)
  690. batch_image_dir_state = gr.State("")
  691. output_dir_state = gr.State("")
  692. unit_args = (
  693. input_mode,
  694. batch_image_dir_state,
  695. output_dir_state,
  696. self.loopback,
  697. # Non-persistent fields.
  698. # Following inputs will not be persistent on `ControlNetUnit`.
  699. # They are only used during object construction.
  700. self.use_preview_as_input,
  701. self.generated_image,
  702. # End of Non-persistent fields.
  703. self.enabled,
  704. self.module,
  705. self.model,
  706. self.weight,
  707. self.input_image,
  708. self.resize_mode,
  709. self.lowvram,
  710. self.processor_res,
  711. self.threshold_a,
  712. self.threshold_b,
  713. self.guidance_start,
  714. self.guidance_end,
  715. self.pixel_perfect,
  716. self.control_mode,
  717. )
  718. self.register_modules(
  719. tabname,
  720. self.enabled,
  721. self.module,
  722. self.model,
  723. self.weight,
  724. self.guidance_start,
  725. self.guidance_end,
  726. )
  727. self.input_image.preprocess = functools.partial(
  728. svg_preprocess, preprocess=self.input_image.preprocess
  729. )
  730. unit = gr.State(self.default_unit)
  731. for comp in unit_args:
  732. event_subscribers = []
  733. if hasattr(comp, "edit"):
  734. event_subscribers.append(comp.edit)
  735. elif hasattr(comp, "click"):
  736. event_subscribers.append(comp.click)
  737. elif isinstance(comp, gr.Slider) and hasattr(comp, "release"):
  738. event_subscribers.append(comp.release)
  739. elif hasattr(comp, "change"):
  740. event_subscribers.append(comp.change)
  741. if hasattr(comp, "clear"):
  742. event_subscribers.append(comp.clear)
  743. for event_subscriber in event_subscribers:
  744. event_subscriber(
  745. fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit
  746. )
  747. def clear_preview(x):
  748. if x:
  749. logger.info('Preview as input is cancelled.')
  750. return gr.update(value=False), gr.update(value=None)
  751. for comp in (
  752. self.pixel_perfect,
  753. self.module,
  754. self.input_image,
  755. self.processor_res,
  756. self.threshold_a,
  757. self.threshold_b,
  758. ):
  759. event_subscribers = []
  760. if hasattr(comp, "edit"):
  761. event_subscribers.append(comp.edit)
  762. elif hasattr(comp, "click"):
  763. event_subscribers.append(comp.click)
  764. elif isinstance(comp, gr.Slider) and hasattr(comp, "release"):
  765. event_subscribers.append(comp.release)
  766. elif hasattr(comp, "change"):
  767. event_subscribers.append(comp.change)
  768. if hasattr(comp, "clear"):
  769. event_subscribers.append(comp.clear)
  770. for event_subscriber in event_subscribers:
  771. event_subscriber(
  772. fn=clear_preview, inputs=self.use_preview_as_input, outputs=[self.use_preview_as_input,
  773. self.generated_image]
  774. )
  775. # keep input_mode in sync
  776. def ui_controlnet_unit_for_input_mode(input_mode, *args):
  777. args = list(args)
  778. args[0] = input_mode
  779. return input_mode, UiControlNetUnit(*args)
  780. for input_tab in (
  781. (self.upload_tab, batch_hijack.InputMode.SIMPLE),
  782. (self.batch_tab, batch_hijack.InputMode.BATCH),
  783. ):
  784. input_tab[0].select(
  785. fn=ui_controlnet_unit_for_input_mode,
  786. inputs=[gr.State(input_tab[1])] + list(unit_args),
  787. outputs=[input_mode, unit],
  788. )
  789. def determine_batch_dir(batch_dir, fallback_dir, fallback_fallback_dir):
  790. if batch_dir:
  791. return batch_dir
  792. elif fallback_dir:
  793. return fallback_dir
  794. else:
  795. return fallback_fallback_dir
  796. # keep batch_dir in sync with global batch input textboxes
  797. def subscribe_for_batch_dir():
  798. batch_dirs = [
  799. self.batch_image_dir,
  800. ControlNetUiGroup.global_batch_input_dir,
  801. ControlNetUiGroup.img2img_batch_input_dir,
  802. ]
  803. for batch_dir_comp in batch_dirs:
  804. subscriber = getattr(batch_dir_comp, "blur", None)
  805. if subscriber is None:
  806. continue
  807. subscriber(
  808. fn=determine_batch_dir,
  809. inputs=batch_dirs,
  810. outputs=[batch_image_dir_state],
  811. queue=False,
  812. )
  813. if ControlNetUiGroup.img2img_batch_input_dir is None:
  814. # we are too soon, subscribe later when available
  815. ControlNetUiGroup.img2img_batch_input_dir_callbacks.append(
  816. subscribe_for_batch_dir
  817. )
  818. else:
  819. subscribe_for_batch_dir()
  820. # keep output_dir in sync with global batch output textbox
  821. def subscribe_for_output_dir():
  822. ControlNetUiGroup.img2img_batch_output_dir.blur(
  823. fn=lambda a: a,
  824. inputs=[ControlNetUiGroup.img2img_batch_output_dir],
  825. outputs=[output_dir_state],
  826. queue=False,
  827. )
  828. if ControlNetUiGroup.img2img_batch_input_dir is None:
  829. # we are too soon, subscribe later when available
  830. ControlNetUiGroup.img2img_batch_output_dir_callbacks.append(
  831. subscribe_for_output_dir
  832. )
  833. else:
  834. subscribe_for_output_dir()
  835. (
  836. ControlNetUiGroup.img2img_submit_button
  837. if is_img2img
  838. else ControlNetUiGroup.txt2img_submit_button
  839. ).click(
  840. fn=UiControlNetUnit,
  841. inputs=list(unit_args),
  842. outputs=unit,
  843. queue=False,
  844. )
  845. return unit
  846. @staticmethod
  847. def on_after_component(component, **_kwargs):
  848. elem_id = getattr(component, "elem_id", None)
  849. if elem_id == "txt2img_generate":
  850. ControlNetUiGroup.txt2img_submit_button = component
  851. return
  852. if elem_id == "img2img_generate":
  853. ControlNetUiGroup.img2img_submit_button = component
  854. return
  855. if elem_id == "img2img_batch_input_dir":
  856. ControlNetUiGroup.img2img_batch_input_dir = component
  857. for callback in ControlNetUiGroup.img2img_batch_input_dir_callbacks:
  858. callback()
  859. return
  860. if elem_id == "img2img_batch_output_dir":
  861. ControlNetUiGroup.img2img_batch_output_dir = component
  862. for callback in ControlNetUiGroup.img2img_batch_output_dir_callbacks:
  863. callback()
  864. return
  865. if elem_id == "img2img_batch_inpaint_mask_dir":
  866. ControlNetUiGroup.global_batch_input_dir.render()
  867. return
  868. if elem_id == "txt2img_width":
  869. ControlNetUiGroup.txt2img_w_slider = component
  870. return
  871. if elem_id == "txt2img_height":
  872. ControlNetUiGroup.txt2img_h_slider = component
  873. return
  874. if elem_id == "img2img_width":
  875. ControlNetUiGroup.img2img_w_slider = component
  876. return
  877. if elem_id == "img2img_height":
  878. ControlNetUiGroup.img2img_h_slider = component
  879. return