ui_extra_networks.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. import glob
  2. import os.path
  3. import urllib.parse
  4. from pathlib import Path
  5. from PIL import PngImagePlugin
  6. from modules import shared
  7. from modules.images import read_info_from_image
  8. import gradio as gr
  9. import json
  10. import html
  11. from modules.generation_parameters_copypaste import image_from_url_text
  12. extra_pages = []
  13. allowed_dirs = set()
  14. def register_page(page):
  15. """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
  16. extra_pages.append(page)
  17. allowed_dirs.clear()
  18. allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
  19. def fetch_file(filename: str = ""):
  20. from starlette.responses import FileResponse
  21. if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
  22. raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
  23. ext = os.path.splitext(filename)[1].lower()
  24. if ext not in (".png", ".jpg", ".webp"):
  25. raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
  26. # would profit from returning 304
  27. return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
  28. def get_metadata(page: str = "", item: str = ""):
  29. from starlette.responses import JSONResponse
  30. page = next(iter([x for x in extra_pages if x.name == page]), None)
  31. if page is None:
  32. return JSONResponse({})
  33. metadata = page.metadata.get(item)
  34. if metadata is None:
  35. return JSONResponse({})
  36. return JSONResponse({"metadata": metadata})
  37. def add_pages_to_demo(app):
  38. app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
  39. app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
  40. class ExtraNetworksPage:
  41. def __init__(self, title):
  42. self.title = title
  43. self.name = title.lower()
  44. self.card_page = shared.html("extra-networks-card.html")
  45. self.allow_negative_prompt = False
  46. self.metadata = {}
  47. def refresh(self):
  48. pass
  49. def link_preview(self, filename):
  50. return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))
  51. def search_terms_from_path(self, filename, possible_directories=None):
  52. abspath = os.path.abspath(filename)
  53. for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
  54. parentdir = os.path.abspath(parentdir)
  55. if abspath.startswith(parentdir):
  56. return abspath[len(parentdir):].replace('\\', '/')
  57. return ""
  58. def create_html(self, tabname):
  59. view = shared.opts.extra_networks_default_view
  60. items_html = ''
  61. self.metadata = {}
  62. subdirs = {}
  63. for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
  64. for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
  65. if not os.path.isdir(x):
  66. continue
  67. subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
  68. while subdir.startswith("/"):
  69. subdir = subdir[1:]
  70. is_empty = len(os.listdir(x)) == 0
  71. if not is_empty and not subdir.endswith("/"):
  72. subdir = subdir + "/"
  73. subdirs[subdir] = 1
  74. if subdirs:
  75. subdirs = {"": 1, **subdirs}
  76. subdirs_html = "".join([f"""
  77. <button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
  78. {html.escape(subdir if subdir!="" else "all")}
  79. </button>
  80. """ for subdir in subdirs])
  81. for item in self.list_items():
  82. metadata = item.get("metadata")
  83. if metadata:
  84. self.metadata[item["name"]] = metadata
  85. items_html += self.create_html_for_item(item, tabname)
  86. if items_html == '':
  87. dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
  88. items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
  89. self_name_id = self.name.replace(" ", "_")
  90. res = f"""
  91. <div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
  92. {subdirs_html}
  93. </div>
  94. <div id='{tabname}_{self_name_id}_cards' class='extra-network-{view}'>
  95. {items_html}
  96. </div>
  97. """
  98. return res
  99. def list_items(self):
  100. raise NotImplementedError()
  101. def allowed_directories_for_previews(self):
  102. return []
  103. def create_html_for_item(self, item, tabname):
  104. preview = item.get("preview", None)
  105. onclick = item.get("onclick", None)
  106. if onclick is None:
  107. onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
  108. height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
  109. width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
  110. background_image = f"background-image: url(\"{html.escape(preview)}\");" if preview else ''
  111. metadata_button = ""
  112. metadata = item.get("metadata")
  113. if metadata:
  114. metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
  115. args = {
  116. "style": f"'{height}{width}{background_image}'",
  117. "prompt": item.get("prompt", None),
  118. "tabname": json.dumps(tabname),
  119. "local_preview": json.dumps(item["local_preview"]),
  120. "name": item["name"],
  121. "description": (item.get("description") or ""),
  122. "card_clicked": onclick,
  123. "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
  124. "search_term": item.get("search_term", ""),
  125. "metadata_button": metadata_button,
  126. }
  127. return self.card_page.format(**args)
  128. def find_preview(self, path):
  129. """
  130. Find a preview PNG for a given path (without extension) and call link_preview on it.
  131. """
  132. preview_extensions = ["png", "jpg", "webp"]
  133. if shared.opts.samples_format not in preview_extensions:
  134. preview_extensions.append(shared.opts.samples_format)
  135. potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], [])
  136. for file in potential_files:
  137. if os.path.isfile(file):
  138. return self.link_preview(file)
  139. return None
  140. def find_description(self, path):
  141. """
  142. Find and read a description file for a given path (without extension).
  143. """
  144. for file in [f"{path}.txt", f"{path}.description.txt"]:
  145. try:
  146. with open(file, "r", encoding="utf-8", errors="replace") as f:
  147. return f.read()
  148. except OSError:
  149. pass
  150. return None
  151. def intialize():
  152. extra_pages.clear()
  153. class ExtraNetworksUi:
  154. def __init__(self):
  155. self.pages = None
  156. self.stored_extra_pages = None
  157. self.button_save_preview = None
  158. self.preview_target_filename = None
  159. self.tabname = None
  160. def pages_in_preferred_order(pages):
  161. tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")]
  162. def tab_name_score(name):
  163. name = name.lower()
  164. for i, possible_match in enumerate(tab_order):
  165. if possible_match in name:
  166. return i
  167. return len(pages)
  168. tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)}
  169. return sorted(pages, key=lambda x: tab_scores[x.name])
def create_ui(container, button, tabname):
    """Build one Gradio tab per registered extra-networks page and return the resulting ExtraNetworksUi.

    Args:
        container: component shown/hidden when *button* is clicked.
        button: toggle button; its variant switches between "secondary-down" (shown)
            and "secondary" (hidden).
        tabname: prefix for the created elem_ids, also passed to each page's create_html().

    Returns:
        ExtraNetworksUi with pages, stored_extra_pages, button_save_preview,
        preview_target_filename and tabname filled in.
    """
    ui = ExtraNetworksUi()
    ui.pages = []
    # Snapshot of the registered pages, ordered by the user's tab-order preference.
    ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
    ui.tabname = tabname

    # One tab per page, each holding that page's rendered HTML.
    with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
        for page in ui.stored_extra_pages:
            with gr.Tab(page.title):
                page_elem = gr.HTML(page.create_html(ui.tabname))
                ui.pages.append(page_elem)

    # Hidden search box and refresh button, addressed by elem_id — presumably wired up by
    # front-end JS not visible here.
    # NOTE(review): `filter` shadows the builtin and is never read; `tabs` is unused too.
    filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
    button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")

    # Hidden controls used by setup_ui(): the button triggers saving a preview and the
    # textbox supplies the destination filename.
    ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
    ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)

    def toggle_visibility(is_visible):
        # Flip the stored state, show/hide the container, and restyle the toggle button to match.
        is_visible = not is_visible
        return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))

    state_visible = gr.State(value=False)
    button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button])

    def refresh():
        # Call each page's refresh() hook, then re-render it; output order matches ui.pages.
        res = []

        for pg in ui.stored_extra_pages:
            pg.refresh()
            res.append(pg.create_html(ui.tabname))

        return res

    button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)

    return ui
  197. def path_is_parent(parent_path, child_path):
  198. parent_path = os.path.abspath(parent_path)
  199. child_path = os.path.abspath(child_path)
  200. return child_path.startswith(parent_path)
  201. def setup_ui(ui, gallery):
  202. def save_preview(index, images, filename):
  203. if len(images) == 0:
  204. print("There is no image in gallery to save as a preview.")
  205. return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
  206. index = int(index)
  207. index = 0 if index < 0 else index
  208. index = len(images) - 1 if index >= len(images) else index
  209. img_info = images[index if index >= 0 else 0]
  210. image = image_from_url_text(img_info)
  211. geninfo, items = read_info_from_image(image)
  212. is_allowed = False
  213. for extra_page in ui.stored_extra_pages:
  214. if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
  215. is_allowed = True
  216. break
  217. assert is_allowed, f'writing to {filename} is not allowed'
  218. if geninfo:
  219. pnginfo_data = PngImagePlugin.PngInfo()
  220. pnginfo_data.add_text('parameters', geninfo)
  221. image.save(filename, pnginfo=pnginfo_data)
  222. else:
  223. image.save(filename)
  224. return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
  225. ui.button_save_preview.click(
  226. fn=save_preview,
  227. _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
  228. inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
  229. outputs=[*ui.pages]
  230. )