# additional_networks.py

import os
import torch
import numpy as np
import modules.scripts as scripts
from modules import shared, script_callbacks
import gradio as gr

import modules.ui
from modules.ui_components import ToolButton, FormRow

from scripts import addnet_xyz_grid_support, lora_compvis, model_util, metadata_editor
from scripts.model_util import lora_models, MAX_MODEL_COUNT


memo_symbol = "\U0001F4DD"  # 📝
addnet_paste_params = {"txt2img": [], "img2img": []}
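# addnet_paste_params maps each tab name ("txt2img" / "img2img") to the list of
# {"module": Dropdown, "model": Dropdown} pairs built in Script.ui(); the metadata
# editor tab reads it to paste a model selection back into this script's UI.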


class Script(scripts.Script):
    def __init__(self) -> None:
        super().__init__()
        self.latest_params = [(None, None, None, None)] * MAX_MODEL_COUNT
        self.latest_networks = []
        self.latest_model_hash = ""

    def title(self):
        return "Additional networks for generating"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        global addnet_paste_params
        # NOTE: Changing the contents of `ctrls` means the XY Grid support may need
        # to be updated; see addnet_xyz_grid_support.py
        ctrls = []
        weight_sliders = []
        model_dropdowns = []

        tabname = "txt2img"
        if is_img2img:
            tabname = "img2img"

        paste_params = addnet_paste_params[tabname]
        paste_params.clear()

        self.infotext_fields = []
        self.paste_field_names = []

        with gr.Group():
            with gr.Accordion("Additional Networks", open=False):
                with gr.Row():
                    enabled = gr.Checkbox(label="Enable", value=False)
                    ctrls.append(enabled)
                    self.infotext_fields.append((enabled, "AddNet Enabled"))

                    separate_weights = gr.Checkbox(label="Separate UNet/Text Encoder weights", value=False)
                    ctrls.append(separate_weights)
                    self.infotext_fields.append((separate_weights, "AddNet Separate Weights"))

                for i in range(MAX_MODEL_COUNT):
                    with FormRow(variant="compact"):
                        module = gr.Dropdown(["LoRA"], label=f"Network module {i+1}", value="LoRA")
                        model = gr.Dropdown(list(lora_models.keys()), label=f"Model {i+1}", value="None")

                        with gr.Row(visible=False):
                            model_path = gr.Textbox(value="None", interactive=False, visible=False)
                        model.change(
                            lambda module, model, i=i: model_util.lora_models.get(model, "None"),
                            inputs=[module, model],
                            outputs=[model_path],
                        )

                        # Sending from the script UI to the metadata editor has to bypass
                        # gradio since this button will exit the gr.Blocks context by the
                        # time the metadata editor tab is created, so event handlers can't
                        # be registered on it by then.
                        model_info = ToolButton(value=memo_symbol, elem_id=f"additional_networks_send_to_metadata_editor_{i}")
                        model_info.click(fn=None, _js="addnet_send_to_metadata_editor", inputs=[module, model_path], outputs=[])

                        module.change(
                            lambda module, model, i=i: addnet_xyz_grid_support.update_axis_params(i, module, model),
                            inputs=[module, model],
                            outputs=[],
                        )
                        model.change(
                            lambda module, model, i=i: addnet_xyz_grid_support.update_axis_params(i, module, model),
                            inputs=[module, model],
                            outputs=[],
                        )

                        # Probably nobody trains only the Text Encoder, so "Weight A" corresponds to the U-Net.
                        # The labels may be renamed to "Weight A" / "Weight B" in the future; "UNet" and "TEnc"
                        # are kept for now because they are easier to understand.
                        with gr.Column() as col:
                            weight = gr.Slider(label=f"Weight {i+1}", value=1.0, minimum=-1.0, maximum=2.0, step=0.05, visible=True)
                            weight_unet = gr.Slider(
                                label=f"UNet Weight {i+1}", value=1.0, minimum=-1.0, maximum=2.0, step=0.05, visible=False
                            )
                            weight_tenc = gr.Slider(
                                label=f"TEnc Weight {i+1}", value=1.0, minimum=-1.0, maximum=2.0, step=0.05, visible=False
                            )

                            weight.change(lambda w: (w, w), inputs=[weight], outputs=[weight_unet, weight_tenc])

                        paste_params.append({"module": module, "model": model})

                    ctrls.extend((module, model, weight_unet, weight_tenc))
                    weight_sliders.extend((weight, weight_unet, weight_tenc))
                    model_dropdowns.append(model)

                    self.infotext_fields.extend(
                        [
                            (module, f"AddNet Module {i+1}"),
                            (model, f"AddNet Model {i+1}"),
                            (weight, f"AddNet Weight {i+1}"),
                            (weight_unet, f"AddNet Weight A {i+1}"),
                            (weight_tenc, f"AddNet Weight B {i+1}"),
                        ]
                    )

                for _, field_name in self.infotext_fields:
                    self.paste_field_names.append(field_name)

                def update_weight_sliders(separate, *sliders):
                    updates = []
                    for w, w_unet, w_tenc in zip(*(iter(sliders),) * 3):
                        if not separate:
                            w_unet = w
                            w_tenc = w
                        updates.append(gr.Slider.update(visible=not separate))  # Combined
                        updates.append(gr.Slider.update(visible=separate, value=w_unet))  # UNet
                        updates.append(gr.Slider.update(visible=separate, value=w_tenc))  # TEnc
                    return updates

                separate_weights.change(update_weight_sliders, inputs=[separate_weights] + weight_sliders, outputs=weight_sliders)
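                # `zip(*(iter(sliders),) * 3)` walks the flat slider-value list in consecutive
                # triples (combined, UNet, TEnc), matching how `weight_sliders` is built above.
                # Illustration (values only): sliders = (1.0, 1.0, 1.0, 0.8, 0.8, 0.8)
                # -> groups (1.0, 1.0, 1.0) and (0.8, 0.8, 0.8), one per model slot.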

                def refresh_all_models(*dropdowns):
                    model_util.update_models()
                    updates = []
                    for dd in dropdowns:
                        if dd in lora_models:
                            selected = dd
                        else:
                            selected = "None"
                        update = gr.Dropdown.update(value=selected, choices=list(lora_models.keys()))
                        updates.append(update)
                    return updates

                # mask for regions
                with gr.Accordion("Extra args", open=False):
                    with gr.Row():
                        mask_image = gr.Image(label="mask image:")
                        ctrls.append(mask_image)

                refresh_models = gr.Button(value="Refresh models")
                refresh_models.click(refresh_all_models, inputs=model_dropdowns, outputs=model_dropdowns)
                ctrls.append(refresh_models)

        return ctrls
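
    # Layout of `ctrls` as received by process_batch(*args):
    #   args[0] = Enable checkbox, args[1] = separate-weights checkbox,
    #   args[2:2 + 4 * MAX_MODEL_COUNT] = (module, model, weight_unet, weight_tenc) per slot,
    #   args[-2] = mask image, args[-1] = the "Refresh models" button.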

    def set_infotext_fields(self, p, params):
        for i, t in enumerate(params):
            module, model, weight_unet, weight_tenc = t
            if model is None or model == "None" or len(model) == 0 or (weight_unet == 0 and weight_tenc == 0):
                continue
            p.extra_generation_params.update(
                {
                    "AddNet Enabled": True,
                    f"AddNet Module {i+1}": module,
                    f"AddNet Model {i+1}": model,
                    f"AddNet Weight A {i+1}": weight_unet,
                    f"AddNet Weight B {i+1}": weight_tenc,
                }
            )
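
    # These keys end up in the generation parameters text, e.g. (illustrative values):
    #   AddNet Enabled: True, AddNet Module 1: LoRA, AddNet Model 1: my_lora(aa11bb22),
    #   AddNet Weight A 1: 0.8, AddNet Weight B 1: 0.8
    # on_infotext_pasted() below normalizes these same keys when generation info is pasted back.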

    def restore_networks(self, sd_model):
        unet = sd_model.model.diffusion_model
        text_encoder = sd_model.cond_stage_model

        if len(self.latest_networks) > 0:
            print("restoring last networks")
            for network, _ in self.latest_networks[::-1]:
                network.restore(text_encoder, unet)
            self.latest_networks.clear()

    def process_batch(self, p, *args, **kwargs):
        unet = p.sd_model.model.diffusion_model
        text_encoder = p.sd_model.cond_stage_model

        if not args[0]:
            self.restore_networks(p.sd_model)
            return

        params = []
        for i, ctrl in enumerate(args[2:]):
            if i % 4 == 0:
                param = [ctrl]
            else:
                param.append(ctrl)
            if i % 4 == 3:
                params.append(param)
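        # After grouping, each entry in `params` is [module, model, weight_unet, weight_tenc],
        # e.g. ["LoRA", "None", 1.0, 1.0] for an unused slot (illustrative values).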

        models_changed = len(self.latest_networks) == 0  # no latest networks (cleared when the extension was disabled)
        models_changed = models_changed or self.latest_model_hash != p.sd_model.sd_model_hash
        if not models_changed:
            for (l_module, l_model, l_weight_unet, l_weight_tenc), (module, model, weight_unet, weight_tenc) in zip(
                self.latest_params, params
            ):
                if l_module != module or l_model != model or l_weight_unet != weight_unet or l_weight_tenc != weight_tenc:
                    models_changed = True
                    break

        if models_changed:
            self.restore_networks(p.sd_model)
            self.latest_params = params
            self.latest_model_hash = p.sd_model.sd_model_hash

            for module, model, weight_unet, weight_tenc in self.latest_params:
                if model is None or model == "None" or len(model) == 0:
                    continue
                if weight_unet == 0 and weight_tenc == 0:
                    print(f"ignore because weight is 0: {model}")
                    continue

                model_path = lora_models.get(model, None)
                if model_path is None:
                    raise RuntimeError(f"model not found: {model}")

                if model_path.startswith('"') and model_path.endswith('"'):  # trim '"' at start/end
                    model_path = model_path[1:-1]
                if not os.path.exists(model_path):
                    print(f"file not found: {model_path}")
                    continue

                print(f"{module} weight_unet: {weight_unet}, weight_tenc: {weight_tenc}, model: {model}")
                if module == "LoRA":
                    if os.path.splitext(model_path)[1] == ".safetensors":
                        from safetensors.torch import load_file

                        du_state_dict = load_file(model_path)
                    else:
                        du_state_dict = torch.load(model_path, map_location="cpu")

                    network, info = lora_compvis.create_network_and_apply_compvis(
                        du_state_dict, weight_tenc, weight_unet, text_encoder, unet
                    )
                    # with --medvram the U-Net and sd_model can sit on different devices, so use sd_model's device/dtype
                    network.to(p.sd_model.device, dtype=p.sd_model.dtype)

                    print(f"LoRA model {model} loaded: {info}")
                    self.latest_networks.append((network, model))

            if len(self.latest_networks) > 0:
                print("setting (or sd model) changed. new networks created.")

        # apply mask: currently only the top 3 networks are supported
        if len(self.latest_networks) > 0:
            mask_image = args[-2]
            if mask_image is not None:
                mask_image = mask_image.astype(np.float32) / 255.0
                print("use mask image to control LoRA regions.")
                for i, (network, model) in enumerate(self.latest_networks[:3]):
                    if not hasattr(network, "set_mask"):
                        continue
                    mask = mask_image[:, :, i]  # R,G,B
                    if mask.max() <= 0:
                        continue
                    mask = torch.tensor(mask, dtype=p.sd_model.dtype, device=p.sd_model.device)

                    network.set_mask(mask, height=p.height, width=p.width, hr_height=p.hr_upscale_to_y, hr_width=p.hr_upscale_to_x)
                    print(f"apply mask. channel: {i}, model: {model}")
            else:
                for network, _ in self.latest_networks:
                    if hasattr(network, "set_mask"):
                        network.set_mask(None)
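
        # Mask convention: the red channel gates the first network, green the second, blue
        # the third; an all-zero channel is skipped, and with no mask image every mask is cleared.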

        self.set_infotext_fields(p, self.latest_params)


def on_script_unloaded():
    if shared.sd_model:
        for s in scripts.scripts_txt2img.alwayson_scripts:
            if isinstance(s, Script):
                s.restore_networks(shared.sd_model)
                break


def on_ui_tabs():
    global addnet_paste_params
    with gr.Blocks(analytics_enabled=False) as additional_networks_interface:
        metadata_editor.setup_ui(addnet_paste_params)

    return [(additional_networks_interface, "Additional Networks", "additional_networks")]


def on_ui_settings():
    section = ("additional_networks", "Additional Networks")
    shared.opts.add_option(
        "additional_networks_extra_lora_path",
        shared.OptionInfo(
            "",
            """Extra paths to scan for LoRA models, comma-separated. Paths containing commas must be enclosed in double quotes. In the path, " (one quote) must be replaced by "" (two quotes).""",
            section=section,
        ),
    )
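    # Illustrative value for the option above (hypothetical paths):
    #   /models/extra_loras,"D:\loras, shared","C:\folder with ""quote"""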
    shared.opts.add_option(
        "additional_networks_sort_models_by",
        shared.OptionInfo(
            "name",
            "Sort LoRA models by",
            gr.Radio,
            {"choices": ["name", "date", "path name", "rating", "has user metadata"]},
            section=section,
        ),
    )
    shared.opts.add_option(
        "additional_networks_reverse_sort_order", shared.OptionInfo(False, "Reverse model sort order", section=section)
    )
    shared.opts.add_option(
        "additional_networks_model_name_filter", shared.OptionInfo("", "LoRA model name filter", section=section)
    )
    shared.opts.add_option(
        "additional_networks_xy_grid_model_metadata",
        shared.OptionInfo(
            "",
            'Metadata to show in XY-Grid label for Model axes, comma-separated (example: "ss_learning_rate, ss_num_epochs")',
            section=section,
        ),
    )
    shared.opts.add_option(
        "additional_networks_hash_thread_count",
        shared.OptionInfo(1, "# of threads to use for hash calculation (increase if using an SSD)", section=section),
    )
    shared.opts.add_option(
        "additional_networks_back_up_model_when_saving",
        shared.OptionInfo(True, "Make a backup copy of the model being edited when saving its metadata.", section=section),
    )
    shared.opts.add_option(
        "additional_networks_show_only_safetensors",
        shared.OptionInfo(False, "Only show .safetensors format models", section=section),
    )
    shared.opts.add_option(
        "additional_networks_show_only_models_with_metadata",
        shared.OptionInfo(
            "disabled",
            "Only show models that have/don't have user-added metadata",
            gr.Radio,
            {"choices": ["disabled", "has metadata", "missing metadata"]},
            section=section,
        ),
    )
    shared.opts.add_option(
        "additional_networks_max_top_tags", shared.OptionInfo(20, "Max number of top tags to show", section=section)
    )
    shared.opts.add_option(
        "additional_networks_max_dataset_folders", shared.OptionInfo(20, "Max number of dataset folders to show", section=section)
    )


def on_infotext_pasted(infotext, params):
    if "AddNet Enabled" not in params:
        params["AddNet Enabled"] = "False"

    # TODO changing "AddNet Separate Weights" does not seem to work
    if "AddNet Separate Weights" not in params:
        params["AddNet Separate Weights"] = "False"

    for i in range(MAX_MODEL_COUNT):
        # Convert combined weight into new format
        if f"AddNet Weight {i+1}" in params:
            params[f"AddNet Weight A {i+1}"] = params[f"AddNet Weight {i+1}"]
            params[f"AddNet Weight B {i+1}"] = params[f"AddNet Weight {i+1}"]

        if f"AddNet Module {i+1}" not in params:
            params[f"AddNet Module {i+1}"] = "LoRA"
        if f"AddNet Model {i+1}" not in params:
            params[f"AddNet Model {i+1}"] = "None"
        if f"AddNet Weight A {i+1}" not in params:
            params[f"AddNet Weight A {i+1}"] = "0"
        if f"AddNet Weight B {i+1}" not in params:
            params[f"AddNet Weight B {i+1}"] = "0"

        params[f"AddNet Weight {i+1}"] = params[f"AddNet Weight A {i+1}"]

        if params[f"AddNet Weight A {i+1}"] != params[f"AddNet Weight B {i+1}"]:
            params["AddNet Separate Weights"] = "True"

        # Convert potential legacy name/hash to new format
        params[f"AddNet Model {i+1}"] = str(model_util.find_closest_lora_model_name(params[f"AddNet Model {i+1}"]))

        addnet_xyz_grid_support.update_axis_params(i, params[f"AddNet Module {i+1}"], params[f"AddNet Model {i+1}"])
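
    # Example (illustrative): pasting an old-style "AddNet Weight 1: 0.7" fills both
    # "AddNet Weight A 1" and "AddNet Weight B 1" with "0.7"; if A and B differ,
    # "AddNet Separate Weights" is switched to "True" so the UI shows separate sliders.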


addnet_xyz_grid_support.initialize(Script)


script_callbacks.on_script_unloaded(on_script_unloaded)
script_callbacks.on_ui_tabs(on_ui_tabs)
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_infotext_pasted(on_infotext_pasted)