# controlnet_ui_group.py (37 KB)


  1. import gradio as gr
  2. import functools
  3. from typing import List, Optional, Union, Dict, Callable
  4. import numpy as np
  5. import base64
  6. from scripts.utils import svg_preprocess
  7. from scripts import (
  8. global_state,
  9. external_code,
  10. processor,
  11. batch_hijack,
  12. )
  13. from scripts.processor import (
  14. preprocessor_sliders_config,
  15. flag_preprocessor_resolution,
  16. model_free_preprocessors,
  17. preprocessor_filters,
  18. HWC3,
  19. )
  20. from scripts.logging import logger
  21. from scripts.controlnet_ui.openpose_editor import OpenposeEditor
  22. from modules import shared
  23. from modules.ui_components import FormRow
  24. class ToolButton(gr.Button, gr.components.FormComponent):
  25. """Small button with single emoji as text, fits inside gradio forms"""
  26. def __init__(self, **kwargs):
  27. super().__init__(variant="tool",
  28. elem_classes=kwargs.pop('elem_classes', []) + ["cnet-toolbutton"],
  29. **kwargs)
  30. def get_block_name(self):
  31. return "button"
  32. class UiControlNetUnit(external_code.ControlNetUnit):
  33. """The data class that stores all states of a ControlNetUnit."""
  34. def __init__(
  35. self,
  36. input_mode: batch_hijack.InputMode = batch_hijack.InputMode.SIMPLE,
  37. batch_images: Optional[Union[str, List[external_code.InputImage]]] = None,
  38. output_dir: str = "",
  39. loopback: bool = False,
  40. use_preview_as_input: bool = False,
  41. generated_image: Optional[np.ndarray] = None,
  42. enabled: bool = True,
  43. module: Optional[str] = None,
  44. model: Optional[str] = None,
  45. weight: float = 1.0,
  46. image: Optional[np.ndarray] = None,
  47. *args,
  48. **kwargs,
  49. ):
  50. if use_preview_as_input and generated_image is not None:
  51. input_image = generated_image
  52. module = "none"
  53. else:
  54. input_image = image
  55. super().__init__(enabled, module, model, weight, input_image, *args, **kwargs)
  56. self.is_ui = True
  57. self.input_mode = input_mode
  58. self.batch_images = batch_images
  59. self.output_dir = output_dir
  60. self.loopback = loopback
  61. class ControlNetUiGroup(object):
  62. # Note: Change symbol hints mapping in `javascript/hints.js` when you change the symbol values.
  63. refresh_symbol = "\U0001f504" # 🔄
  64. switch_values_symbol = "\U000021C5" # ⇅
  65. camera_symbol = "\U0001F4F7" # 📷
  66. reverse_symbol = "\U000021C4" # ⇄
  67. tossup_symbol = "\u2934"
  68. trigger_symbol = "\U0001F4A5" # 💥
  69. open_symbol = "\U0001F4DD" # 📝
  70. global_batch_input_dir = gr.Textbox(
  71. label="Controlnet input directory",
  72. placeholder="Leave empty to use input directory",
  73. **shared.hide_dirs,
  74. elem_id="controlnet_batch_input_dir",
  75. )
  76. img2img_batch_input_dir = None
  77. img2img_batch_input_dir_callbacks = []
  78. img2img_batch_output_dir = None
  79. img2img_batch_output_dir_callbacks = []
  80. txt2img_submit_button = None
  81. img2img_submit_button = None
  82. # Slider controls from A1111 WebUI.
  83. txt2img_w_slider = None
  84. txt2img_h_slider = None
  85. img2img_w_slider = None
  86. img2img_h_slider = None
  87. def __init__(
  88. self,
  89. gradio_compat: bool,
  90. infotext_fields: List[str],
  91. default_unit: external_code.ControlNetUnit,
  92. preprocessors: List[Callable],
  93. ):
  94. self.gradio_compat = gradio_compat
  95. self.infotext_fields = infotext_fields
  96. self.default_unit = default_unit
  97. self.preprocessors = preprocessors
  98. self.webcam_enabled = False
  99. self.webcam_mirrored = False
  100. # Note: All gradio elements declared in `render` will be defined as member variable.
  101. self.upload_tab = None
  102. self.input_image = None
  103. self.generated_image_group = None
  104. self.generated_image = None
  105. self.batch_tab = None
  106. self.batch_image_dir = None
  107. self.create_canvas = None
  108. self.canvas_width = None
  109. self.canvas_height = None
  110. self.canvas_create_button = None
  111. self.canvas_cancel_button = None
  112. self.open_new_canvas_button = None
  113. self.webcam_enable = None
  114. self.webcam_mirror = None
  115. self.send_dimen_button = None
  116. self.enabled = None
  117. self.lowvram = None
  118. self.pixel_perfect = None
  119. self.preprocessor_preview = None
  120. self.type_filter = None
  121. self.module = None
  122. self.trigger_preprocessor = None
  123. self.model = None
  124. self.refresh_models = None
  125. self.weight = None
  126. self.guidance_start = None
  127. self.guidance_end = None
  128. self.advanced = None
  129. self.processor_res = None
  130. self.threshold_a = None
  131. self.threshold_b = None
  132. self.control_mode = None
  133. self.resize_mode = None
  134. self.loopback = None
  135. self.use_preview_as_input = None
  136. self.openpose_editor = None
    def render(self, tabname: str, elem_id_tabname: str) -> None:
        """The pure HTML structure of a single ControlNetUnit. Calling this
        function will populate `self` with all gradio element declared
        in local scope.

        Layout, top to bottom:
        - input tabs: a single image with its preprocessor preview panel, or a
          batch directory;
        - a hidden "Open New Canvas" accordion;
        - the image tool buttons;
        - the main option checkboxes;
        - an optional control-type filter;
        - the preprocessor and model pickers;
        - the weight and start/end step sliders;
        - the (initially hidden) preprocessor parameter sliders;
        - control mode, resize mode and loopback.

        Statement order matters: gradio places components according to the
        order and nesting of these `with` blocks.

        Args:
            tabname: Unit-specific name placed after `elem_id_tabname` in
                every elem_id created here (presumably e.g. "ControlNet-0";
                TODO confirm against the caller).
            elem_id_tabname: Leading part of every elem_id created here
                (presumably identifies the txt2img/img2img host; confirm).
        Returns:
            None
        """
        with gr.Tabs():
            with gr.Tab(label="Single Image") as self.upload_tab:
                with gr.Row(elem_classes=["cnet-image-row"]).style(equal_height=True):
                    with gr.Group(elem_classes=["cnet-input-image-group"]):
                        # tool="sketch" lets the user paint a mask on the image;
                        # register_run_annotator reads value["image"] / value["mask"].
                        self.input_image = gr.Image(
                            source="upload",
                            brush_radius=20,
                            mirror_webcam=False,
                            type="numpy",
                            tool="sketch",
                            elem_id=f"{elem_id_tabname}_{tabname}_input_image",
                            elem_classes=["cnet-image"],
                        )
                    # Preview panel; hidden until "Allow Preview" is ticked
                    # (visibility is toggled in register_shift_preview).
                    with gr.Group(
                        visible=False, elem_classes=["cnet-generated-image-group"]
                    ) as self.generated_image_group:
                        self.generated_image = gr.Image(
                            value=None,
                            label="Preprocessor Preview",
                            elem_id=f"{elem_id_tabname}_{tabname}_generated_image",
                            elem_classes=["cnet-image"], interactive=False
                        ).style(
                            height=242
                        ) # Gradio's magic number. Only 242 works.
                        with gr.Group(
                            elem_classes=["cnet-generated-image-control-group"]
                        ):
                            self.openpose_editor = OpenposeEditor()
                            # Also used below as the elem_id of the "Allow Preview"
                            # checkbox; the "Close" link clicks that checkbox via JS.
                            preview_check_elem_id = f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_preview_checkbox"
                            preview_close_button_js = f"document.querySelector('#{preview_check_elem_id} input[type=\\'checkbox\\']').click();"
                            gr.HTML(
                                value=f"""<a title="Close Preview" onclick="{preview_close_button_js}">Close</a>""",
                                visible=True,
                                elem_classes=["cnet-close-preview"],
                            )
            with gr.Tab(label="Batch") as self.batch_tab:
                self.batch_image_dir = gr.Textbox(
                    label="Input Directory",
                    placeholder="Leave empty to use img2img batch controlnet input directory",
                    elem_id=f"{elem_id_tabname}_{tabname}_batch_image_dir",
                )

        # Hidden until the open-canvas tool button is clicked
        # (wired in register_create_canvas).
        with gr.Accordion(label="Open New Canvas", visible=False) as self.create_canvas:
            self.canvas_width = gr.Slider(
                label="New Canvas Width",
                minimum=256,
                maximum=1024,
                value=512,
                step=64,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_width",
            )
            self.canvas_height = gr.Slider(
                label="New Canvas Height",
                minimum=256,
                maximum=1024,
                value=512,
                step=64,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_height",
            )
            with gr.Row():
                self.canvas_create_button = gr.Button(
                    value="Create New Canvas",
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_create_button",
                )
                self.canvas_cancel_button = gr.Button(
                    value="Cancel",
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_cancel_button",
                )

        # Image tool buttons: new canvas, webcam on/off, webcam mirror,
        # send image dimensions to the A1111 width/height sliders.
        with gr.Row(elem_classes="controlnet_image_controls"):
            gr.HTML(
                value="<p>Set the preprocessor to [invert] If your image has white background and black lines.</p>",
                elem_classes="controlnet_invert_warning",
            )
            self.open_new_canvas_button = ToolButton(
                value=ControlNetUiGroup.open_symbol,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_open_new_canvas_button",
            )
            self.webcam_enable = ToolButton(
                value=ControlNetUiGroup.camera_symbol,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_webcam_enable",
            )
            self.webcam_mirror = ToolButton(
                value=ControlNetUiGroup.reverse_symbol,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_webcam_mirror",
            )
            self.send_dimen_button = ToolButton(
                value=ControlNetUiGroup.tossup_symbol,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_send_dimen_button",
            )

        with FormRow(elem_classes=["controlnet_main_options"]):
            self.enabled = gr.Checkbox(
                label="Enable",
                value=self.default_unit.enabled,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_enable_checkbox",
                elem_classes=['cnet-unit-enabled'],
            )
            self.lowvram = gr.Checkbox(
                label="Low VRAM",
                value=self.default_unit.low_vram,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_low_vram_checkbox",
            )
            self.pixel_perfect = gr.Checkbox(
                label="Pixel Perfect",
                value=self.default_unit.pixel_perfect,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_pixel_perfect_checkbox",
            )
            self.preprocessor_preview = gr.Checkbox(
                label="Allow Preview", value=False, elem_id=preview_check_elem_id
            )
            # Never shown; its visibility is forced off by register_shift_preview.
            self.use_preview_as_input = gr.Checkbox(
                label="Preview as Input",
                value=False,
                elem_classes=["cnet-preview-as-input"],
                visible=False,
            )

        # Optional control-type filter; when disabled via the
        # "controlnet_disable_control_type" option, self.type_filter stays None.
        if not shared.opts.data.get("controlnet_disable_control_type", False):
            with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]):
                self.type_filter = gr.Radio(
                    list(preprocessor_filters.keys()),
                    label=f"Control Type",
                    value="All",
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_type_filter_radio",
                    elem_classes="controlnet_control_type_filter_group",
                )

        with gr.Row(elem_classes=["controlnet_preprocessor_model", "controlnet_row"]):
            self.module = gr.Dropdown(
                global_state.ui_preprocessor_keys,
                label=f"Preprocessor",
                value=self.default_unit.module,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_dropdown",
            )
            self.trigger_preprocessor = ToolButton(
                value=ControlNetUiGroup.trigger_symbol,
                visible=True,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_trigger_preprocessor",
                elem_classes=['cnet-run-preprocessor'],
            )
            self.model = gr.Dropdown(
                list(global_state.cn_models.keys()),
                label=f"Model",
                value=self.default_unit.model,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_model_dropdown",
            )
            self.refresh_models = ToolButton(
                value=ControlNetUiGroup.refresh_symbol,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_refresh_models",
            )

        with gr.Row(elem_classes=["controlnet_weight_steps", "controlnet_row"]):
            self.weight = gr.Slider(
                label=f"Control Weight",
                value=self.default_unit.weight,
                minimum=0.0,
                maximum=2.0,
                step=0.05,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_control_weight_slider",
                elem_classes="controlnet_control_weight_slider",
            )
            self.guidance_start = gr.Slider(
                label="Starting Control Step",
                value=self.default_unit.guidance_start,
                minimum=0.0,
                maximum=1.0,
                interactive=True,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_start_control_step_slider",
                elem_classes="controlnet_start_control_step_slider",
            )
            self.guidance_end = gr.Slider(
                label="Ending Control Step",
                value=self.default_unit.guidance_end,
                minimum=0.0,
                maximum=1.0,
                interactive=True,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_ending_control_step_slider",
                elem_classes="controlnet_ending_control_step_slider",
            )

        # advanced options
        # Preprocessor parameter sliders: created hidden and non-interactive;
        # build_sliders (register_build_sliders) relabels/reveals them for the
        # selected preprocessor.
        with gr.Column(visible=False) as self.advanced:
            self.processor_res = gr.Slider(
                label="Preprocessor resolution",
                value=self.default_unit.processor_res,
                minimum=64,
                maximum=2048,
                visible=False,
                interactive=False,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_resolution_slider",
            )
            self.threshold_a = gr.Slider(
                label="Threshold A",
                value=self.default_unit.threshold_a,
                minimum=64,
                maximum=1024,
                visible=False,
                interactive=False,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_threshold_A_slider",
            )
            self.threshold_b = gr.Slider(
                label="Threshold B",
                value=self.default_unit.threshold_b,
                minimum=64,
                maximum=1024,
                visible=False,
                interactive=False,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_threshold_B_slider",
            )

        self.control_mode = gr.Radio(
            choices=[e.value for e in external_code.ControlMode],
            value=self.default_unit.control_mode.value,
            label="Control Mode",
            elem_id=f"{elem_id_tabname}_{tabname}_controlnet_control_mode_radio",
            elem_classes="controlnet_control_mode_radio",
        )

        self.resize_mode = gr.Radio(
            choices=[e.value for e in external_code.ResizeMode],
            value=self.default_unit.resize_mode.value,
            label="Resize Mode",
            elem_id=f"{elem_id_tabname}_{tabname}_controlnet_resize_mode_radio",
            elem_classes="controlnet_resize_mode_radio",
        )

        self.loopback = gr.Checkbox(
            label="[Loopback] Automatically send generated images to this ControlNet unit",
            value=self.default_unit.loopback,
            elem_id=f"{elem_id_tabname}_{tabname}_controlnet_automatically_send_generated_images_checkbox",
            elem_classes="controlnet_loopback_checkbox",
        )
  370. def register_send_dimensions(self, is_img2img: bool):
  371. """Register event handler for send dimension button."""
  372. def send_dimensions(image):
  373. def closesteight(num):
  374. rem = num % 8
  375. if rem <= 4:
  376. return round(num - rem)
  377. else:
  378. return round(num + (8 - rem))
  379. if image:
  380. interm = np.asarray(image.get("image"))
  381. return closesteight(interm.shape[1]), closesteight(interm.shape[0])
  382. else:
  383. return gr.Slider.update(), gr.Slider.update()
  384. outputs = (
  385. [
  386. ControlNetUiGroup.img2img_w_slider,
  387. ControlNetUiGroup.img2img_h_slider,
  388. ]
  389. if is_img2img
  390. else [
  391. ControlNetUiGroup.txt2img_w_slider,
  392. ControlNetUiGroup.txt2img_h_slider,
  393. ]
  394. )
  395. self.send_dimen_button.click(
  396. fn=send_dimensions,
  397. inputs=[self.input_image],
  398. outputs=outputs,
  399. )
  400. def register_webcam_toggle(self):
  401. def webcam_toggle():
  402. self.webcam_enabled = not self.webcam_enabled
  403. return {
  404. "value": None,
  405. "source": "webcam" if self.webcam_enabled else "upload",
  406. "__type__": "update",
  407. }
  408. self.webcam_enable.click(webcam_toggle, inputs=None, outputs=self.input_image)
  409. def register_webcam_mirror_toggle(self):
  410. def webcam_mirror_toggle():
  411. self.webcam_mirrored = not self.webcam_mirrored
  412. return {"mirror_webcam": self.webcam_mirrored, "__type__": "update"}
  413. self.webcam_mirror.click(
  414. webcam_mirror_toggle, inputs=None, outputs=self.input_image
  415. )
  416. def register_refresh_all_models(self):
  417. def refresh_all_models(*inputs):
  418. global_state.update_cn_models()
  419. dd = inputs[0]
  420. selected = dd if dd in global_state.cn_models else "None"
  421. return gr.Dropdown.update(
  422. value=selected, choices=list(global_state.cn_models.keys())
  423. )
  424. self.refresh_models.click(refresh_all_models, self.model, self.model)
  425. def register_build_sliders(self):
  426. if not self.gradio_compat:
  427. return
  428. def build_sliders(module, pp):
  429. grs = []
  430. module = global_state.get_module_basename(module)
  431. if module not in preprocessor_sliders_config:
  432. grs += [
  433. gr.update(
  434. label=flag_preprocessor_resolution,
  435. value=512,
  436. minimum=64,
  437. maximum=2048,
  438. step=1,
  439. visible=not pp,
  440. interactive=not pp,
  441. ),
  442. gr.update(visible=False, interactive=False),
  443. gr.update(visible=False, interactive=False),
  444. gr.update(visible=True),
  445. ]
  446. else:
  447. for slider_config in preprocessor_sliders_config[module]:
  448. if isinstance(slider_config, dict):
  449. visible = True
  450. if slider_config["name"] == flag_preprocessor_resolution:
  451. visible = not pp
  452. grs.append(
  453. gr.update(
  454. label=slider_config["name"],
  455. value=slider_config["value"],
  456. minimum=slider_config["min"],
  457. maximum=slider_config["max"],
  458. step=slider_config["step"]
  459. if "step" in slider_config
  460. else 1,
  461. visible=visible,
  462. interactive=visible,
  463. )
  464. )
  465. else:
  466. grs.append(gr.update(visible=False, interactive=False))
  467. while len(grs) < 3:
  468. grs.append(gr.update(visible=False, interactive=False))
  469. grs.append(gr.update(visible=True))
  470. if module in model_free_preprocessors:
  471. grs += [
  472. gr.update(visible=False, value="None"),
  473. gr.update(visible=False),
  474. ]
  475. else:
  476. grs += [gr.update(visible=True), gr.update(visible=True)]
  477. return grs
  478. inputs = [self.module, self.pixel_perfect]
  479. outputs = [
  480. self.processor_res,
  481. self.threshold_a,
  482. self.threshold_b,
  483. self.advanced,
  484. self.model,
  485. self.refresh_models,
  486. ]
  487. self.module.change(build_sliders, inputs=inputs, outputs=outputs)
  488. self.pixel_perfect.change(build_sliders, inputs=inputs, outputs=outputs)
  489. if self.type_filter is not None:
  490. def filter_selected(k, pp):
  491. (
  492. filtered_preprocessor_list,
  493. filtered_model_list,
  494. default_option,
  495. default_model
  496. ) = global_state.select_control_type(k)
  497. return [
  498. gr.Dropdown.update(value=default_option, choices=filtered_preprocessor_list),
  499. gr.Dropdown.update(value=default_model, choices=filtered_model_list),
  500. ] + build_sliders(default_option, pp)
  501. self.type_filter.change(
  502. filter_selected,
  503. inputs=[self.type_filter, self.pixel_perfect],
  504. outputs=[self.module, self.model, *outputs],
  505. )
  506. def register_run_annotator(self, is_img2img: bool):
  507. def run_annotator(image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm):
  508. if image is None:
  509. return (
  510. gr.update(value=None, visible=True),
  511. gr.update(),
  512. *self.openpose_editor.update(''),
  513. )
  514. img = HWC3(image["image"])
  515. has_mask = not (
  516. (image["mask"][:, :, 0] <= 5).all()
  517. or (image["mask"][:, :, 0] >= 250).all()
  518. )
  519. if "inpaint" in module:
  520. color = HWC3(image["image"])
  521. alpha = image["mask"][:, :, 0:1]
  522. img = np.concatenate([color, alpha], axis=2)
  523. elif has_mask and not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False):
  524. img = HWC3(image["mask"][:, :, 0])
  525. module = global_state.get_module_basename(module)
  526. preprocessor = self.preprocessors[module]
  527. if pp:
  528. pres = external_code.pixel_perfect_resolution(
  529. img,
  530. target_H=t2i_h,
  531. target_W=t2i_w,
  532. resize_mode=external_code.resize_mode_from_value(rm),
  533. )
  534. class JsonAcceptor:
  535. def __init__(self) -> None:
  536. self.value = ""
  537. def accept(self, json_string: str) -> None:
  538. self.value = json_string
  539. json_acceptor = JsonAcceptor()
  540. logger.info(f"Preview Resolution = {pres}")
  541. def is_openpose(module: str):
  542. return "openpose" in module
  543. # Only openpose preprocessor returns a JSON output, pass json_acceptor
  544. # only when a JSON output is expected. This will make preprocessor cache
  545. # work for all other preprocessors other than openpose ones. JSON acceptor
  546. # instance are different every call, which means cache will never take
  547. # effect.
  548. # TODO: Maybe we should let `preprocessor` return a Dict to alleviate this issue?
  549. # This requires changing all callsites though.
  550. result, is_image = preprocessor(
  551. img,
  552. res=pres,
  553. thr_a=pthr_a,
  554. thr_b=pthr_b,
  555. json_pose_callback=json_acceptor.accept
  556. if is_openpose(module)
  557. else None,
  558. )
  559. if "clip" in module:
  560. result = processor.clip_vision_visualization(result)
  561. is_image = True
  562. if is_image:
  563. result = external_code.visualize_inpaint_mask(result)
  564. return (
  565. # Update to `generated_image`
  566. gr.update(value=result, visible=True, interactive=False),
  567. # preprocessor_preview
  568. gr.update(value=True),
  569. # openpose editor
  570. *self.openpose_editor.update(json_acceptor.value),
  571. )
  572. return (
  573. # Update to `generated_image`
  574. gr.update(value=None, visible=True),
  575. # preprocessor_preview
  576. gr.update(value=True),
  577. # openpose editor
  578. *self.openpose_editor.update(json_acceptor.value),
  579. )
  580. self.trigger_preprocessor.click(
  581. fn=run_annotator,
  582. inputs=[
  583. self.input_image,
  584. self.module,
  585. self.processor_res,
  586. self.threshold_a,
  587. self.threshold_b,
  588. ControlNetUiGroup.img2img_w_slider
  589. if is_img2img
  590. else ControlNetUiGroup.txt2img_w_slider,
  591. ControlNetUiGroup.img2img_h_slider
  592. if is_img2img
  593. else ControlNetUiGroup.txt2img_h_slider,
  594. self.pixel_perfect,
  595. self.resize_mode,
  596. ],
  597. outputs=[
  598. self.generated_image,
  599. self.preprocessor_preview,
  600. *self.openpose_editor.outputs(),
  601. ],
  602. )
  603. def register_shift_preview(self):
  604. def shift_preview(is_on):
  605. return (
  606. # generated_image
  607. gr.update() if is_on else gr.update(value=None),
  608. # generated_image_group
  609. gr.update(visible=is_on),
  610. # use_preview_as_input,
  611. gr.update(visible=False), # Now this is automatically managed
  612. # download_pose_link
  613. gr.update() if is_on else gr.update(value=None),
  614. # modal edit button
  615. gr.update() if is_on else gr.update(visible=False),
  616. )
  617. self.preprocessor_preview.change(
  618. fn=shift_preview,
  619. inputs=[self.preprocessor_preview],
  620. outputs=[
  621. self.generated_image,
  622. self.generated_image_group,
  623. self.use_preview_as_input,
  624. self.openpose_editor.download_link,
  625. self.openpose_editor.modal,
  626. ],
  627. )
  628. def register_create_canvas(self):
  629. self.open_new_canvas_button.click(
  630. lambda: gr.Accordion.update(visible=True),
  631. inputs=None,
  632. outputs=self.create_canvas,
  633. )
  634. self.canvas_cancel_button.click(
  635. lambda: gr.Accordion.update(visible=False),
  636. inputs=None,
  637. outputs=self.create_canvas,
  638. )
  639. def fn_canvas(h, w):
  640. return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255, gr.Accordion.update(
  641. visible=False
  642. )
  643. self.canvas_create_button.click(
  644. fn=fn_canvas,
  645. inputs=[self.canvas_height, self.canvas_width],
  646. outputs=[self.input_image, self.create_canvas],
  647. )
  648. def register_callbacks(self, is_img2img: bool):
  649. """Register callbacks on the UI elements.
  650. Args:
  651. is_img2img: Whether ControlNet is under img2img. False when in txt2img mode.
  652. Returns:
  653. None
  654. """
  655. self.register_send_dimensions(is_img2img)
  656. self.register_webcam_toggle()
  657. self.register_webcam_mirror_toggle()
  658. self.register_refresh_all_models()
  659. self.register_build_sliders()
  660. self.register_run_annotator(is_img2img)
  661. self.register_shift_preview()
  662. self.register_create_canvas()
  663. self.openpose_editor.register_callbacks(
  664. self.generated_image, self.use_preview_as_input
  665. )
  666. def register_modules(
  667. self, tabname: str, enabled, module, model, weight, guidance_start, guidance_end
  668. ):
  669. self.infotext_fields.extend(
  670. [
  671. (enabled, f"{tabname} Enabled"),
  672. (module, f"{tabname} Preprocessor"),
  673. (model, f"{tabname} Model"),
  674. (weight, f"{tabname} Weight"),
  675. (guidance_start, f"{tabname} Guidance Start"),
  676. (guidance_end, f"{tabname} Guidance End"),
  677. ]
  678. )
  679. def render_and_register_unit(self, tabname: str, is_img2img: bool):
  680. """Render the invisible states elements for misc persistent
  681. purposes. Register callbacks on loading/unloading the controlnet
  682. unit and handle batch processes.
  683. Args:
  684. tabname:
  685. is_img2img:
  686. Returns:
  687. The data class "ControlNetUnit" representing this ControlNetUnit.
  688. """
  689. input_mode = gr.State(batch_hijack.InputMode.SIMPLE)
  690. batch_image_dir_state = gr.State("")
  691. output_dir_state = gr.State("")
  692. unit_args = (
  693. input_mode,
  694. batch_image_dir_state,
  695. output_dir_state,
  696. self.loopback,
  697. # Non-persistent fields.
  698. # Following inputs will not be persistent on `ControlNetUnit`.
  699. # They are only used during object construction.
  700. self.use_preview_as_input,
  701. self.generated_image,
  702. # End of Non-persistent fields.
  703. self.enabled,
  704. self.module,
  705. self.model,
  706. self.weight,
  707. self.input_image,
  708. self.resize_mode,
  709. self.lowvram,
  710. self.processor_res,
  711. self.threshold_a,
  712. self.threshold_b,
  713. self.guidance_start,
  714. self.guidance_end,
  715. self.pixel_perfect,
  716. self.control_mode,
  717. )
  718. self.register_modules(
  719. tabname,
  720. self.enabled,
  721. self.module,
  722. self.model,
  723. self.weight,
  724. self.guidance_start,
  725. self.guidance_end,
  726. )
  727. self.input_image.preprocess = functools.partial(
  728. svg_preprocess, preprocess=self.input_image.preprocess
  729. )
  730. unit = gr.State(self.default_unit)
  731. for comp in unit_args:
  732. event_subscribers = []
  733. if hasattr(comp, "edit"):
  734. event_subscribers.append(comp.edit)
  735. elif hasattr(comp, "click"):
  736. event_subscribers.append(comp.click)
  737. elif isinstance(comp, gr.Slider) and hasattr(comp, "release"):
  738. event_subscribers.append(comp.release)
  739. elif hasattr(comp, "change"):
  740. event_subscribers.append(comp.change)
  741. if hasattr(comp, "clear"):
  742. event_subscribers.append(comp.clear)
  743. for event_subscriber in event_subscribers:
  744. event_subscriber(
  745. fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit
  746. )
  747. def clear_preview(x):
  748. if x:
  749. logger.info('Preview as input is cancelled.')
  750. return gr.update(value=False), gr.update(value=None)
  751. for comp in (
  752. self.pixel_perfect,
  753. self.module,
  754. self.input_image,
  755. self.processor_res,
  756. self.threshold_a,
  757. self.threshold_b,
  758. ):
  759. event_subscribers = []
  760. if hasattr(comp, "edit"):
  761. event_subscribers.append(comp.edit)
  762. elif hasattr(comp, "click"):
  763. event_subscribers.append(comp.click)
  764. elif isinstance(comp, gr.Slider) and hasattr(comp, "release"):
  765. event_subscribers.append(comp.release)
  766. elif hasattr(comp, "change"):
  767. event_subscribers.append(comp.change)
  768. if hasattr(comp, "clear"):
  769. event_subscribers.append(comp.clear)
  770. for event_subscriber in event_subscribers:
  771. event_subscriber(
  772. fn=clear_preview, inputs=self.use_preview_as_input, outputs=[self.use_preview_as_input,
  773. self.generated_image]
  774. )
  775. # keep input_mode in sync
  776. def ui_controlnet_unit_for_input_mode(input_mode, *args):
  777. args = list(args)
  778. args[0] = input_mode
  779. return input_mode, UiControlNetUnit(*args)
  780. for input_tab in (
  781. (self.upload_tab, batch_hijack.InputMode.SIMPLE),
  782. (self.batch_tab, batch_hijack.InputMode.BATCH),
  783. ):
  784. input_tab[0].select(
  785. fn=ui_controlnet_unit_for_input_mode,
  786. inputs=[gr.State(input_tab[1])] + list(unit_args),
  787. outputs=[input_mode, unit],
  788. )
  789. def determine_batch_dir(batch_dir, fallback_dir, fallback_fallback_dir):
  790. if batch_dir:
  791. return batch_dir
  792. elif fallback_dir:
  793. return fallback_dir
  794. else:
  795. return fallback_fallback_dir
  796. # keep batch_dir in sync with global batch input textboxes
  797. def subscribe_for_batch_dir():
  798. batch_dirs = [
  799. self.batch_image_dir,
  800. ControlNetUiGroup.global_batch_input_dir,
  801. ControlNetUiGroup.img2img_batch_input_dir,
  802. ]
  803. for batch_dir_comp in batch_dirs:
  804. subscriber = getattr(batch_dir_comp, "blur", None)
  805. if subscriber is None:
  806. continue
  807. subscriber(
  808. fn=determine_batch_dir,
  809. inputs=batch_dirs,
  810. outputs=[batch_image_dir_state],
  811. queue=False,
  812. )
  813. if ControlNetUiGroup.img2img_batch_input_dir is None:
  814. # we are too soon, subscribe later when available
  815. ControlNetUiGroup.img2img_batch_input_dir_callbacks.append(
  816. subscribe_for_batch_dir
  817. )
  818. else:
  819. subscribe_for_batch_dir()
  820. # keep output_dir in sync with global batch output textbox
  821. def subscribe_for_output_dir():
  822. ControlNetUiGroup.img2img_batch_output_dir.blur(
  823. fn=lambda a: a,
  824. inputs=[ControlNetUiGroup.img2img_batch_output_dir],
  825. outputs=[output_dir_state],
  826. queue=False,
  827. )
  828. if ControlNetUiGroup.img2img_batch_input_dir is None:
  829. # we are too soon, subscribe later when available
  830. ControlNetUiGroup.img2img_batch_output_dir_callbacks.append(
  831. subscribe_for_output_dir
  832. )
  833. else:
  834. subscribe_for_output_dir()
  835. (
  836. ControlNetUiGroup.img2img_submit_button
  837. if is_img2img
  838. else ControlNetUiGroup.txt2img_submit_button
  839. ).click(
  840. fn=UiControlNetUnit,
  841. inputs=list(unit_args),
  842. outputs=unit,
  843. queue=False,
  844. )
  845. return unit
  846. @staticmethod
  847. def on_after_component(component, **_kwargs):
  848. elem_id = getattr(component, "elem_id", None)
  849. if elem_id == "txt2img_generate":
  850. ControlNetUiGroup.txt2img_submit_button = component
  851. return
  852. if elem_id == "img2img_generate":
  853. ControlNetUiGroup.img2img_submit_button = component
  854. return
  855. if elem_id == "img2img_batch_input_dir":
  856. ControlNetUiGroup.img2img_batch_input_dir = component
  857. for callback in ControlNetUiGroup.img2img_batch_input_dir_callbacks:
  858. callback()
  859. return
  860. if elem_id == "img2img_batch_output_dir":
  861. ControlNetUiGroup.img2img_batch_output_dir = component
  862. for callback in ControlNetUiGroup.img2img_batch_output_dir_callbacks:
  863. callback()
  864. return
  865. if elem_id == "img2img_batch_inpaint_mask_dir":
  866. ControlNetUiGroup.global_batch_input_dir.render()
  867. return
  868. if elem_id == "txt2img_width":
  869. ControlNetUiGroup.txt2img_w_slider = component
  870. return
  871. if elem_id == "txt2img_height":
  872. ControlNetUiGroup.txt2img_h_slider = component
  873. return
  874. if elem_id == "img2img_width":
  875. ControlNetUiGroup.img2img_w_slider = component
  876. return
  877. if elem_id == "img2img_height":
  878. ControlNetUiGroup.img2img_h_slider = component
  879. return