xyz_grid_support.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
import functools
import re

import numpy as np

from modules import scripts, shared

try:
    from scripts.global_state import update_cn_models, cn_models_names, cn_preprocessor_modules
    from scripts.external_code import ResizeMode, ControlMode
except (ImportError, NameError):
    import_error = True
else:
    import_error = False
  11. DEBUG_MODE = False
  12. def debug_info(func):
  13. def debug_info_(*args, **kwargs):
  14. if DEBUG_MODE:
  15. print(f"Debug info: {func.__name__}, {args}")
  16. return func(*args, **kwargs)
  17. return debug_info_
  18. def find_dict(dict_list, keyword, search_key="name", stop=False):
  19. result = next((d for d in dict_list if d[search_key] == keyword), None)
  20. if result or not stop:
  21. return result
  22. else:
  23. raise ValueError(f"Dictionary with value '{keyword}' in key '{search_key}' not found.")
  24. def flatten(lst):
  25. result = []
  26. for element in lst:
  27. if isinstance(element, list):
  28. result.extend(flatten(element))
  29. else:
  30. result.append(element)
  31. return result
  32. def is_all_included(target_list, check_list, allow_blank=False, stop=False):
  33. for element in flatten(target_list):
  34. if allow_blank and str(element) in ["None", ""]:
  35. continue
  36. elif element not in check_list:
  37. if not stop:
  38. return False
  39. else:
  40. raise ValueError(f"'{element}' is not included in check list.")
  41. return True
  42. class ListParser():
  43. """This class restores a broken list caused by the following process
  44. in the xyz_grid module.
  45. -> valslist = [x.strip() for x in chain.from_iterable(
  46. csv.reader(StringIO(vals)))]
  47. It also performs type conversion,
  48. adjusts the number of elements in the list, and other operations.
  49. This class directly modifies the received list.
  50. """
  51. numeric_pattern = {
  52. int: {
  53. "range": r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*",
  54. "count": r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*"
  55. },
  56. float: {
  57. "range": r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\(([+-]\d+(?:\.\d*)?)\s*\))?\s*",
  58. "count": r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\[(\d+(?:\.\d*)?)\s*\])?\s*"
  59. }
  60. }
  61. ################################################
  62. #
  63. # Initialization method from here.
  64. #
  65. ################################################
  66. def __init__(self, my_list, converter=None, allow_blank=True, exclude_list=None, run=True):
  67. self.my_list = my_list
  68. self.converter = converter
  69. self.allow_blank = allow_blank
  70. self.exclude_list = exclude_list
  71. self.re_bracket_start = None
  72. self.re_bracket_start_precheck = None
  73. self.re_bracket_end = None
  74. self.re_bracket_end_precheck = None
  75. self.re_range = None
  76. self.re_count = None
  77. self.compile_regex()
  78. if run:
  79. self.auto_normalize()
  80. def compile_regex(self):
  81. exclude_pattern = "|".join(self.exclude_list) if self.exclude_list else None
  82. if exclude_pattern is None:
  83. self.re_bracket_start = re.compile(r"^\[")
  84. self.re_bracket_end = re.compile(r"\]$")
  85. else:
  86. self.re_bracket_start = re.compile(fr"^\[(?!(?:{exclude_pattern})\])")
  87. self.re_bracket_end = re.compile(fr"(?<!\[(?:{exclude_pattern}))\]$")
  88. if self.converter not in self.numeric_pattern:
  89. return self
  90. # If the converter is either int or float.
  91. self.re_range = re.compile(self.numeric_pattern[self.converter]["range"])
  92. self.re_count = re.compile(self.numeric_pattern[self.converter]["count"])
  93. self.re_bracket_start_precheck = None
  94. self.re_bracket_end_precheck = self.re_count
  95. return self
  96. ################################################
  97. #
  98. # Public method from here.
  99. #
  100. ################################################
  101. ################################################
  102. # This method is executed at the time of initialization.
  103. #
  104. def auto_normalize(self):
  105. if not self.has_list_notation():
  106. self.numeric_range_parser()
  107. self.type_convert()
  108. return self
  109. else:
  110. self.fix_structure()
  111. self.numeric_range_parser()
  112. self.type_convert()
  113. self.fill_to_longest()
  114. return self
  115. def has_list_notation(self):
  116. return any(self._search_bracket(s) for s in self.my_list)
  117. def numeric_range_parser(self, my_list=None, depth=0):
  118. if self.converter not in self.numeric_pattern:
  119. return self
  120. my_list = self.my_list if my_list is None else my_list
  121. result = []
  122. is_matched = False
  123. for s in my_list:
  124. if isinstance(s, list):
  125. result.extend(self.numeric_range_parser(s, depth+1))
  126. continue
  127. match = self._numeric_range_to_list(s)
  128. if s != match:
  129. is_matched = True
  130. result.extend(match if not depth else [match])
  131. continue
  132. else:
  133. result.append(s)
  134. continue
  135. if depth:
  136. return self._transpose(result) if is_matched else [result]
  137. else:
  138. my_list[:] = result
  139. return self
  140. def type_convert(self, my_list=None):
  141. my_list = self.my_list if my_list is None else my_list
  142. for i, s in enumerate(my_list):
  143. if isinstance(s, list):
  144. self.type_convert(s)
  145. elif self.allow_blank and (str(s) in ["None", ""]):
  146. my_list[i] = None
  147. elif self.converter:
  148. my_list[i] = self.converter(s)
  149. else:
  150. my_list[i] = s
  151. return self
  152. def fix_structure(self):
  153. def is_same_length(list1, list2):
  154. return len(list1) == len(list2)
  155. start_indices, end_indices = [], []
  156. for i, s in enumerate(self.my_list):
  157. if is_same_length(start_indices, end_indices):
  158. replace_string = self._search_bracket(s, "[", replace="")
  159. if s != replace_string:
  160. s = replace_string
  161. start_indices.append(i)
  162. if not is_same_length(start_indices, end_indices):
  163. replace_string = self._search_bracket(s, "]", replace="")
  164. if s != replace_string:
  165. s = replace_string
  166. end_indices.append(i + 1)
  167. self.my_list[i] = s
  168. if not is_same_length(start_indices, end_indices):
  169. raise ValueError(f"Lengths of {start_indices} and {end_indices} are different.")
  170. # Restore the structure of a list.
  171. for i, j in zip(reversed(start_indices), reversed(end_indices)):
  172. self.my_list[i:j] = [self.my_list[i:j]]
  173. return self
  174. def fill_to_longest(self, my_list=None, value=None, index=None):
  175. my_list = self.my_list if my_list is None else my_list
  176. if not self.sublist_exists(my_list):
  177. return self
  178. max_length = max(len(sub_list) for sub_list in my_list if isinstance(sub_list, list))
  179. for i, sub_list in enumerate(my_list):
  180. if isinstance(sub_list, list):
  181. fill_value = value if index is None else sub_list[index]
  182. my_list[i] = sub_list + [fill_value] * (max_length-len(sub_list))
  183. return self
  184. def sublist_exists(self, my_list=None):
  185. my_list = self.my_list if my_list is None else my_list
  186. return any(isinstance(item, list) for item in my_list)
  187. def all_sublists(self, my_list=None): # Unused method
  188. my_list = self.my_list if my_list is None else my_list
  189. return all(isinstance(item, list) for item in my_list)
  190. def get_list(self): # Unused method
  191. return self.my_list
  192. ################################################
  193. #
  194. # Private method from here.
  195. #
  196. ################################################
  197. def _search_bracket(self, string, bracket="[", replace=None):
  198. if bracket == "[":
  199. pattern = self.re_bracket_start
  200. precheck = self.re_bracket_start_precheck # None
  201. elif bracket == "]":
  202. pattern = self.re_bracket_end
  203. precheck = self.re_bracket_end_precheck
  204. else:
  205. raise ValueError(f"Invalid argument provided. (bracket: {bracket})")
  206. if precheck and precheck.fullmatch(string):
  207. return None if replace is None else string
  208. elif replace is None:
  209. return pattern.search(string)
  210. else:
  211. return pattern.sub(replace, string)
  212. def _numeric_range_to_list(self, string):
  213. match = self.re_range.fullmatch(string)
  214. if match is not None:
  215. if self.converter == int:
  216. start = int(match.group(1))
  217. end = int(match.group(2)) + 1
  218. step = int(match.group(3)) if match.group(3) is not None else 1
  219. return list(range(start, end, step))
  220. else: # float
  221. start = float(match.group(1))
  222. end = float(match.group(2))
  223. step = float(match.group(3)) if match.group(3) is not None else 1
  224. return np.arange(start, end + step, step).tolist()
  225. match = self.re_count.fullmatch(string)
  226. if match is not None:
  227. if self.converter == int:
  228. start = int(match.group(1))
  229. end = int(match.group(2))
  230. num = int(match.group(3)) if match.group(3) is not None else 1
  231. return [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
  232. else: # float
  233. start = float(match.group(1))
  234. end = float(match.group(2))
  235. num = int(match.group(3)) if match.group(3) is not None else 1
  236. return np.linspace(start=start, stop=end, num=num).tolist()
  237. return string
  238. def _transpose(self, my_list=None):
  239. my_list = self.my_list if my_list is None else my_list
  240. my_list = [item if isinstance(item, list) else [item] for item in my_list]
  241. self.fill_to_longest(my_list, index=-1)
  242. return np.array(my_list, dtype=object).T.tolist()
  243. ################################################
  244. #
  245. # The methods of ListParser class end here.
  246. #
  247. ################################################
  248. ################################################################
  249. ################################################################
  250. #
  251. # Starting the main process of this module.
  252. #
# Definitions below, in order:
#   find_module
#   add_axis_options, which defines these helpers and data:
#     identity
#     enable_script_control
#     apply_field
#     confirm
#     bool_
#     choices_bool, choices_model, choices_control_mode,
#     choices_resize_mode, choices_preprocessor
#     make_excluded_list
#     config lists for AxisOptions:
#       validation_data
#       extra_axis_options
#   run (called at import time when the ControlNet imports succeed)
  266. ################################################################
  267. ################################################################
  268. def find_module(module_names):
  269. if isinstance(module_names, str):
  270. module_names = [s.strip() for s in module_names.split(",")]
  271. for data in scripts.scripts_data:
  272. if data.script_class.__module__ in module_names and hasattr(data, "module"):
  273. return data.module
  274. return None
  275. def add_axis_options(xyz_grid):
  276. ################################################
  277. #
  278. # Define a function to pass to the AxisOption class from here.
  279. #
  280. ################################################
  281. ################################################
  282. # Set this function as the type attribute of the AxisOption class.
  283. # To skip the following processing of xyz_grid module.
  284. # -> valslist = [opt.type(x) for x in valslist]
  285. # Perform type conversion using the function
  286. # set to the confirm attribute instead.
  287. #
  288. def identity(x):
  289. return x
  290. def enable_script_control():
  291. shared.opts.data["control_net_allow_script_control"] = True
  292. def apply_field(field):
  293. @debug_info
  294. def apply_field_(p, x, xs):
  295. enable_script_control()
  296. setattr(p, field, x)
  297. return apply_field_
  298. ################################################
  299. # The confirm function defined in this module
  300. # enables list notation and performs type conversion.
  301. #
  302. # Example:
  303. # any = [any, any, any, ...]
  304. # [any] = [any, None, None, ...]
  305. # [None, None, any] = [None, None, any]
  306. # [,,any] = [None, None, any]
  307. # any, [,any,] = [any, any, any, ...], [None, any, None]
  308. #
  309. # Enabled Only:
  310. # any = [any] = [any, None, None, ...]
  311. # (any and [any] are considered equivalent)
  312. #
  313. def confirm(func_or_str):
  314. @debug_info
  315. def confirm_(p, xs):
  316. if callable(func_or_str): # func_or_str is converter
  317. ListParser(xs, func_or_str, allow_blank=True)
  318. return
  319. elif isinstance(func_or_str, str): # func_or_str is keyword
  320. valid_data = find_dict(validation_data, func_or_str, stop=True)
  321. converter = valid_data["type"]
  322. exclude_list = valid_data["exclude"]() if valid_data["exclude"] else None
  323. check_list = valid_data["check"]()
  324. ListParser(xs, converter, allow_blank=True, exclude_list=exclude_list)
  325. is_all_included(xs, check_list, allow_blank=True, stop=True)
  326. return
  327. else:
  328. raise TypeError(f"Argument must be callable or str, not {type(func_or_str).__name__}.")
  329. return confirm_
  330. def bool_(string):
  331. string = str(string)
  332. if string in ["None", ""]:
  333. return None
  334. elif string.lower() in ["true", "1"]:
  335. return True
  336. elif string.lower() in ["false", "0"]:
  337. return False
  338. else:
  339. raise ValueError(f"Could not convert string to boolean: {string}")
  340. def choices_bool():
  341. return ["False", "True"]
  342. def choices_model():
  343. update_cn_models()
  344. return list(cn_models_names.values())
  345. def choices_control_mode():
  346. return [e.value for e in ControlMode]
  347. def choices_resize_mode():
  348. return [e.value for e in ResizeMode]
  349. def choices_preprocessor():
  350. return list(cn_preprocessor_modules)
  351. def make_excluded_list():
  352. pattern = re.compile(r"\[(\w+)\]")
  353. return [match.group(1) for s in choices_model()
  354. for match in pattern.finditer(s)]
  355. validation_data = [
  356. {"name": "model", "type": str, "check": choices_model, "exclude": make_excluded_list},
  357. {"name": "control_mode", "type": str, "check": choices_control_mode, "exclude": None},
  358. {"name": "resize_mode", "type": str, "check": choices_resize_mode, "exclude": None},
  359. {"name": "preprocessor", "type": str, "check": choices_preprocessor, "exclude": None},
  360. ]
  361. extra_axis_options = [
  362. xyz_grid.AxisOption("[ControlNet] Enabled", identity, apply_field("control_net_enabled"), confirm=confirm(bool_), choices=choices_bool),
  363. xyz_grid.AxisOption("[ControlNet] Model", identity, apply_field("control_net_model"), confirm=confirm("model"), choices=choices_model, cost=0.9),
  364. xyz_grid.AxisOption("[ControlNet] Weight", identity, apply_field("control_net_weight"), confirm=confirm(float)),
  365. xyz_grid.AxisOption("[ControlNet] Guidance Start", identity, apply_field("control_net_guidance_start"), confirm=confirm(float)),
  366. xyz_grid.AxisOption("[ControlNet] Guidance End", identity, apply_field("control_net_guidance_end"), confirm=confirm(float)),
  367. xyz_grid.AxisOption("[ControlNet] Control Mode", identity, apply_field("control_net_control_mode"), confirm=confirm("control_mode"), choices=choices_control_mode),
  368. xyz_grid.AxisOption("[ControlNet] Resize Mode", identity, apply_field("control_net_resize_mode"), confirm=confirm("resize_mode"), choices=choices_resize_mode),
  369. xyz_grid.AxisOption("[ControlNet] Preprocessor", identity, apply_field("control_net_module"), confirm=confirm("preprocessor"), choices=choices_preprocessor),
  370. xyz_grid.AxisOption("[ControlNet] Pre Resolution", identity, apply_field("control_net_pres"), confirm=confirm(int)),
  371. xyz_grid.AxisOption("[ControlNet] Pre Threshold A", identity, apply_field("control_net_pthr_a"), confirm=confirm(float)),
  372. xyz_grid.AxisOption("[ControlNet] Pre Threshold B", identity, apply_field("control_net_pthr_b"), confirm=confirm(float)),
  373. ]
  374. xyz_grid.axis_options.extend(extra_axis_options)
  375. def run():
  376. xyz_grid = find_module("xyz_grid.py, xy_grid.py")
  377. if xyz_grid:
  378. add_axis_options(xyz_grid)
  379. if not import_error:
  380. run()