# global_state.py
  1. import os.path
  2. import stat
  3. import functools
  4. from collections import OrderedDict
  5. from modules import shared, scripts, sd_models
  6. from modules.paths import models_path
  7. from scripts.processor import *
  8. from scripts.utils import ndarray_lru_cache
  9. from scripts.logging import logger
  10. from typing import Dict, Callable, Optional, Tuple, List
# File extensions recognized as ControlNet model checkpoints.
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors"]
# Primary model directory under the webui models path.
cn_models_dir = os.path.join(models_path, "ControlNet")
# Legacy model directory shipped inside this extension.
cn_models_dir_old = os.path.join(scripts.basedir(), "models")
cn_models = OrderedDict()  # "My_Lora(abcd1234)" -> C:/path/to/model.safetensors
cn_models_names = {}  # "my_lora" -> "My_Lora(abcd1234)"
  16. def cache_preprocessors(preprocessor_modules: Dict[str, Callable]) -> Dict[str, Callable]:
  17. """ We want to share the preprocessor results in a single big cache, instead of a small
  18. cache for each preprocessor function. """
  19. CACHE_SIZE = getattr(shared.cmd_opts, "controlnet_preprocessor_cache_size", 0)
  20. # Set CACHE_SIZE = 0 will completely remove the caching layer. This can be
  21. # helpful when debugging preprocessor code.
  22. if CACHE_SIZE == 0:
  23. return preprocessor_modules
  24. logger.debug(f'Create LRU cache (max_size={CACHE_SIZE}) for preprocessor results.')
  25. @ndarray_lru_cache(max_size=CACHE_SIZE)
  26. def unified_preprocessor(preprocessor_name: str, *args, **kwargs):
  27. logger.debug(f'Calling preprocessor {preprocessor_name} outside of cache.')
  28. return preprocessor_modules[preprocessor_name](*args, **kwargs)
  29. # TODO: Introduce a seed parameter for shuffle preprocessor?
  30. uncacheable_preprocessors = ['shuffle']
  31. return {
  32. k: (
  33. v if k in uncacheable_preprocessors
  34. else functools.partial(unified_preprocessor, k)
  35. )
  36. for k, v
  37. in preprocessor_modules.items()
  38. }
# Internal preprocessor name -> callable producing the control map.
# All callables come from scripts.processor (star import above).
# NOTE(review): the "none" entry returns (x, True); the other processors
# presumably follow the same (result, is_image) return contract — confirm
# against scripts/processor.py.
cn_preprocessor_modules = {
    "none": lambda x, *args, **kwargs: (x, True),
    # Edge / line detection.
    "canny": canny,
    # Depth estimation variants.
    "depth": midas,
    "depth_leres": functools.partial(leres, boost=False),
    "depth_leres++": functools.partial(leres, boost=True),
    # Soft-edge (HED / PiDiNet family).
    "hed": hed,
    "hed_safe": hed_safe,
    "mediapipe_face": mediapipe_face,
    "mlsd": mlsd,
    "normal_map": midas_normal,
    # OpenPose variants share one model object; flags select body/hand/face.
    "openpose": functools.partial(g_openpose_model.run_model, include_body=True, include_hand=False, include_face=False),
    "openpose_hand": functools.partial(g_openpose_model.run_model, include_body=True, include_hand=True, include_face=False),
    "openpose_face": functools.partial(g_openpose_model.run_model, include_body=True, include_hand=False, include_face=True),
    "openpose_faceonly": functools.partial(g_openpose_model.run_model, include_body=False, include_hand=False, include_face=True),
    "openpose_full": functools.partial(g_openpose_model.run_model, include_body=True, include_hand=True, include_face=True),
    "dw_openpose_full": functools.partial(g_openpose_model.run_model, include_body=True, include_hand=True, include_face=True, use_dw_pose=True),
    "clip_vision": clip,
    "color": color,
    "pidinet": pidinet,
    "pidinet_safe": pidinet_safe,
    "pidinet_sketch": pidinet_ts,
    "pidinet_scribble": scribble_pidinet,
    "scribble_xdog": scribble_xdog,
    "scribble_hed": scribble_hed,
    # Semantic segmentation backends.
    "segmentation": uniformer,
    "threshold": threshold,
    "depth_zoe": zoe_depth,
    "normal_bae": normal_bae,
    "oneformer_coco": oneformer_coco,
    "oneformer_ade20k": oneformer_ade20k,
    # Line-art extraction variants.
    "lineart": lineart,
    "lineart_coarse": lineart_coarse,
    "lineart_anime": lineart_anime,
    "lineart_standard": lineart_standard,
    "shuffle": shuffle,
    "tile_resample": tile_resample,
    "invert": invert,
    "lineart_anime_denoise": lineart_anime_denoise,
    # These modes need no preprocessing — the image passes through unchanged.
    "reference_only": identity,
    "reference_adain": identity,
    "reference_adain+attn": identity,
    "inpaint": identity,
    "inpaint_only": identity,
    "inpaint_only+lama": lama_inpaint,
    "tile_colorfix": identity,
    "tile_colorfix+sharp": identity,
}
# Preprocessor name -> callable that releases its underlying model from
# memory.  Preprocessors without an entry here hold no unloadable state.
# NOTE(review): the keys "fake_scribble" and "clip" do not appear in
# cn_preprocessor_modules above ("clip_vision" does) — verify whether these
# are stale names or resolved elsewhere.
cn_preprocessor_unloadable = {
    "hed": unload_hed,
    "fake_scribble": unload_hed,
    "mlsd": unload_mlsd,
    "clip": unload_clip,
    "depth": unload_midas,
    "depth_leres": unload_leres,
    "normal_map": unload_midas,
    "pidinet": unload_pidinet,
    "openpose": g_openpose_model.unload,
    "openpose_hand": g_openpose_model.unload,
    "openpose_face": g_openpose_model.unload,
    "openpose_full": g_openpose_model.unload,
    "dw_openpose_full": g_openpose_model.unload,
    "segmentation": unload_uniformer,
    "depth_zoe": unload_zoe_depth,
    "normal_bae": unload_normal_bae,
    "oneformer_coco": unload_oneformer_coco,
    "oneformer_ade20k": unload_oneformer_ade20k,
    "lineart": unload_lineart,
    "lineart_coarse": unload_lineart_coarse,
    "lineart_anime": unload_lineart_anime,
    "lineart_anime_denoise": unload_lineart_anime_denoise,
    "inpaint_only+lama": unload_lama_inpaint
}
# Internal preprocessor name -> display name shown in the UI.
# Names absent from this mapping are displayed as-is.
preprocessor_aliases = {
    "invert": "invert (from white bg & black line)",
    "lineart_standard": "lineart_standard (from white bg & black line)",
    "lineart": "lineart_realistic",
    "color": "t2ia_color_grid",
    "clip_vision": "t2ia_style_clipvision",
    "pidinet_sketch": "t2ia_sketch_pidi",
    "depth": "depth_midas",
    "normal_map": "normal_midas",
    "hed": "softedge_hed",
    "hed_safe": "softedge_hedsafe",
    "pidinet": "softedge_pidinet",
    "pidinet_safe": "softedge_pidisafe",
    "segmentation": "seg_ufade20k",
    "oneformer_coco": "seg_ofcoco",
    "oneformer_ade20k": "seg_ofade20k",
    "pidinet_scribble": "scribble_pidinet",
    "inpaint": "inpaint_global_harmonious",
}
  131. ui_preprocessor_keys = ['none', preprocessor_aliases['invert']]
  132. ui_preprocessor_keys += sorted([preprocessor_aliases.get(k, k)
  133. for k in cn_preprocessor_modules.keys()
  134. if preprocessor_aliases.get(k, k) not in ui_preprocessor_keys])
  135. reverse_preprocessor_aliases = {preprocessor_aliases[k]: k for k in preprocessor_aliases.keys()}
  136. def get_module_basename(module: Optional[str]) -> str:
  137. if module is None:
  138. module = 'none'
  139. return reverse_preprocessor_aliases.get(module, module)
# Default network config used for SD1.5-style ControlNet checkpoints.
default_conf = os.path.join("models", "cldm_v15.yaml")
# Default config used for T2I-Adapter style checkpoints.
default_conf_adapter = os.path.join("models", "t2iadapter_sketch_sd14v1.yaml")
# Relative directory where generated detection maps are stored.
default_detectedmap_dir = os.path.join("detected_maps")
script_dir = scripts.basedir()  # root directory of this extension
# Import-time side effect: make sure the primary model directory exists.
os.makedirs(cn_models_dir, exist_ok=True)
  145. def traverse_all_files(curr_path, model_list):
  146. f_list = [
  147. (os.path.join(curr_path, entry.name), entry.stat())
  148. for entry in os.scandir(curr_path)
  149. if os.path.isdir(curr_path)
  150. ]
  151. for f_info in f_list:
  152. fname, fstat = f_info
  153. if os.path.splitext(fname)[1] in CN_MODEL_EXTS:
  154. model_list.append(f_info)
  155. elif stat.S_ISDIR(fstat.st_mode):
  156. model_list = traverse_all_files(fname, model_list)
  157. return model_list
  158. def get_all_models(sort_by, filter_by, path):
  159. res = OrderedDict()
  160. fileinfos = traverse_all_files(path, [])
  161. filter_by = filter_by.strip(" ")
  162. if len(filter_by) != 0:
  163. fileinfos = [x for x in fileinfos if filter_by.lower()
  164. in os.path.basename(x[0]).lower()]
  165. if sort_by == "name":
  166. fileinfos = sorted(fileinfos, key=lambda x: os.path.basename(x[0]))
  167. elif sort_by == "date":
  168. fileinfos = sorted(fileinfos, key=lambda x: -x[1].st_mtime)
  169. elif sort_by == "path name":
  170. fileinfos = sorted(fileinfos)
  171. for finfo in fileinfos:
  172. filename = finfo[0]
  173. name = os.path.splitext(os.path.basename(filename))[0]
  174. # Prevent a hypothetical "None.pt" from being listed.
  175. if name != "None":
  176. res[name + f" [{sd_models.model_hash(filename)}]"] = filename
  177. return res
  178. def update_cn_models():
  179. cn_models.clear()
  180. ext_dirs = (shared.opts.data.get("control_net_models_path", None), getattr(shared.cmd_opts, 'controlnet_dir', None))
  181. extra_lora_paths = (extra_lora_path for extra_lora_path in ext_dirs
  182. if extra_lora_path is not None and os.path.exists(extra_lora_path))
  183. paths = [cn_models_dir, cn_models_dir_old, *extra_lora_paths]
  184. for path in paths:
  185. sort_by = shared.opts.data.get(
  186. "control_net_models_sort_models_by", "name")
  187. filter_by = shared.opts.data.get("control_net_models_name_filter", "")
  188. found = get_all_models(sort_by, filter_by, path)
  189. cn_models.update({**found, **cn_models})
  190. # insert "None" at the beginning of `cn_models` in-place
  191. cn_models_copy = OrderedDict(cn_models)
  192. cn_models.clear()
  193. cn_models.update({**{"None": None}, **cn_models_copy})
  194. cn_models_names.clear()
  195. for name_and_hash, filename in cn_models.items():
  196. if filename is None:
  197. continue
  198. name = os.path.splitext(os.path.basename(filename))[0].lower()
  199. cn_models_names[name] = name_and_hash
def select_control_type(control_type: str) -> Tuple[List[str], List[str], str, str]:
    """Compute UI dropdown contents for a control-type filter tab.

    Returns (preprocessor choices, model choices, default preprocessor,
    default model) for the given ``control_type``.

    NOTE(review): relies on the module-level ``preprocessor_filters`` mapping,
    which is not visible in this chunk — confirm it maps control type to its
    preferred default preprocessor option.
    """
    default_option = preprocessor_filters[control_type]
    pattern = control_type.lower()
    preprocessor_list = ui_preprocessor_keys
    model_list = list(cn_models.keys())
    if pattern == "all":
        # No filtering: offer everything with generic defaults.
        return [
            preprocessor_list,
            model_list,
            'none', #default option
            "None" #default model
        ]
    # Keep preprocessors whose name contains the pattern; "none" is always kept.
    filtered_preprocessor_list = [
        x
        for x in preprocessor_list
        if pattern in x.lower() or x.lower() == "none"
    ]
    # Line-like control types can also consume inverted (white-bg) inputs.
    if pattern in ["canny", "lineart", "scribble", "mlsd"]:
        filtered_preprocessor_list += [
            x for x in preprocessor_list if "invert" in x.lower()
        ]
    filtered_model_list = [
        x for x in model_list if pattern in x.lower() or x.lower() == "none"
    ]
    if default_option not in filtered_preprocessor_list:
        default_option = filtered_preprocessor_list[0]
    if len(filtered_model_list) == 1:
        # Only the "None" placeholder matched: fall back to the full model
        # list so the user still has something to pick from.
        default_model = "None"
        filtered_model_list = model_list
    else:
        # Index 1 skips the leading "None" entry.
        # NOTE(review): presumes cn_models always starts with "None" (see
        # update_cn_models) and that at least one real model matched.
        default_model = filtered_model_list[1]
        # Prefer a v1.1 model ("11" in the name before the hash bracket).
        for x in filtered_model_list:
            if "11" in x.split("[")[0]:
                default_model = x
                break
    return (
        filtered_preprocessor_list,
        filtered_model_list,
        default_option,
        default_model
    )