123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449 |
- import re
- import numpy as np
- from modules import scripts, shared
- try:
- from scripts.global_state import update_cn_models, cn_models_names, cn_preprocessor_modules
- from scripts.external_code import ResizeMode, ControlMode
- except (ImportError, NameError):
- import_error = True
- else:
- import_error = False
- DEBUG_MODE = False
- def debug_info(func):
- def debug_info_(*args, **kwargs):
- if DEBUG_MODE:
- print(f"Debug info: {func.__name__}, {args}")
- return func(*args, **kwargs)
- return debug_info_
- def find_dict(dict_list, keyword, search_key="name", stop=False):
- result = next((d for d in dict_list if d[search_key] == keyword), None)
- if result or not stop:
- return result
- else:
- raise ValueError(f"Dictionary with value '{keyword}' in key '{search_key}' not found.")
- def flatten(lst):
- result = []
- for element in lst:
- if isinstance(element, list):
- result.extend(flatten(element))
- else:
- result.append(element)
- return result
- def is_all_included(target_list, check_list, allow_blank=False, stop=False):
- for element in flatten(target_list):
- if allow_blank and str(element) in ["None", ""]:
- continue
- elif element not in check_list:
- if not stop:
- return False
- else:
- raise ValueError(f"'{element}' is not included in check list.")
- return True
- class ListParser():
- """This class restores a broken list caused by the following process
- in the xyz_grid module.
- -> valslist = [x.strip() for x in chain.from_iterable(
- csv.reader(StringIO(vals)))]
- It also performs type conversion,
- adjusts the number of elements in the list, and other operations.
- This class directly modifies the received list.
- """
- numeric_pattern = {
- int: {
- "range": r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*",
- "count": r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*"
- },
- float: {
- "range": r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\(([+-]\d+(?:\.\d*)?)\s*\))?\s*",
- "count": r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\[(\d+(?:\.\d*)?)\s*\])?\s*"
- }
- }
- ################################################
- #
- # Initialization method from here.
- #
- ################################################
- def __init__(self, my_list, converter=None, allow_blank=True, exclude_list=None, run=True):
- self.my_list = my_list
- self.converter = converter
- self.allow_blank = allow_blank
- self.exclude_list = exclude_list
- self.re_bracket_start = None
- self.re_bracket_start_precheck = None
- self.re_bracket_end = None
- self.re_bracket_end_precheck = None
- self.re_range = None
- self.re_count = None
- self.compile_regex()
- if run:
- self.auto_normalize()
- def compile_regex(self):
- exclude_pattern = "|".join(self.exclude_list) if self.exclude_list else None
- if exclude_pattern is None:
- self.re_bracket_start = re.compile(r"^\[")
- self.re_bracket_end = re.compile(r"\]$")
- else:
- self.re_bracket_start = re.compile(fr"^\[(?!(?:{exclude_pattern})\])")
- self.re_bracket_end = re.compile(fr"(?<!\[(?:{exclude_pattern}))\]$")
- if self.converter not in self.numeric_pattern:
- return self
- # If the converter is either int or float.
- self.re_range = re.compile(self.numeric_pattern[self.converter]["range"])
- self.re_count = re.compile(self.numeric_pattern[self.converter]["count"])
- self.re_bracket_start_precheck = None
- self.re_bracket_end_precheck = self.re_count
- return self
- ################################################
- #
- # Public method from here.
- #
- ################################################
- ################################################
- # This method is executed at the time of initialization.
- #
- def auto_normalize(self):
- if not self.has_list_notation():
- self.numeric_range_parser()
- self.type_convert()
- return self
- else:
- self.fix_structure()
- self.numeric_range_parser()
- self.type_convert()
- self.fill_to_longest()
- return self
- def has_list_notation(self):
- return any(self._search_bracket(s) for s in self.my_list)
- def numeric_range_parser(self, my_list=None, depth=0):
- if self.converter not in self.numeric_pattern:
- return self
- my_list = self.my_list if my_list is None else my_list
- result = []
- is_matched = False
- for s in my_list:
- if isinstance(s, list):
- result.extend(self.numeric_range_parser(s, depth+1))
- continue
- match = self._numeric_range_to_list(s)
- if s != match:
- is_matched = True
- result.extend(match if not depth else [match])
- continue
- else:
- result.append(s)
- continue
- if depth:
- return self._transpose(result) if is_matched else [result]
- else:
- my_list[:] = result
- return self
- def type_convert(self, my_list=None):
- my_list = self.my_list if my_list is None else my_list
- for i, s in enumerate(my_list):
- if isinstance(s, list):
- self.type_convert(s)
- elif self.allow_blank and (str(s) in ["None", ""]):
- my_list[i] = None
- elif self.converter:
- my_list[i] = self.converter(s)
- else:
- my_list[i] = s
- return self
- def fix_structure(self):
- def is_same_length(list1, list2):
- return len(list1) == len(list2)
- start_indices, end_indices = [], []
- for i, s in enumerate(self.my_list):
- if is_same_length(start_indices, end_indices):
- replace_string = self._search_bracket(s, "[", replace="")
- if s != replace_string:
- s = replace_string
- start_indices.append(i)
- if not is_same_length(start_indices, end_indices):
- replace_string = self._search_bracket(s, "]", replace="")
- if s != replace_string:
- s = replace_string
- end_indices.append(i + 1)
- self.my_list[i] = s
- if not is_same_length(start_indices, end_indices):
- raise ValueError(f"Lengths of {start_indices} and {end_indices} are different.")
- # Restore the structure of a list.
- for i, j in zip(reversed(start_indices), reversed(end_indices)):
- self.my_list[i:j] = [self.my_list[i:j]]
- return self
- def fill_to_longest(self, my_list=None, value=None, index=None):
- my_list = self.my_list if my_list is None else my_list
- if not self.sublist_exists(my_list):
- return self
- max_length = max(len(sub_list) for sub_list in my_list if isinstance(sub_list, list))
- for i, sub_list in enumerate(my_list):
- if isinstance(sub_list, list):
- fill_value = value if index is None else sub_list[index]
- my_list[i] = sub_list + [fill_value] * (max_length-len(sub_list))
- return self
- def sublist_exists(self, my_list=None):
- my_list = self.my_list if my_list is None else my_list
- return any(isinstance(item, list) for item in my_list)
- def all_sublists(self, my_list=None): # Unused method
- my_list = self.my_list if my_list is None else my_list
- return all(isinstance(item, list) for item in my_list)
- def get_list(self): # Unused method
- return self.my_list
- ################################################
- #
- # Private method from here.
- #
- ################################################
- def _search_bracket(self, string, bracket="[", replace=None):
- if bracket == "[":
- pattern = self.re_bracket_start
- precheck = self.re_bracket_start_precheck # None
- elif bracket == "]":
- pattern = self.re_bracket_end
- precheck = self.re_bracket_end_precheck
- else:
- raise ValueError(f"Invalid argument provided. (bracket: {bracket})")
- if precheck and precheck.fullmatch(string):
- return None if replace is None else string
- elif replace is None:
- return pattern.search(string)
- else:
- return pattern.sub(replace, string)
- def _numeric_range_to_list(self, string):
- match = self.re_range.fullmatch(string)
- if match is not None:
- if self.converter == int:
- start = int(match.group(1))
- end = int(match.group(2)) + 1
- step = int(match.group(3)) if match.group(3) is not None else 1
- return list(range(start, end, step))
- else: # float
- start = float(match.group(1))
- end = float(match.group(2))
- step = float(match.group(3)) if match.group(3) is not None else 1
- return np.arange(start, end + step, step).tolist()
- match = self.re_count.fullmatch(string)
- if match is not None:
- if self.converter == int:
- start = int(match.group(1))
- end = int(match.group(2))
- num = int(match.group(3)) if match.group(3) is not None else 1
- return [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
- else: # float
- start = float(match.group(1))
- end = float(match.group(2))
- num = int(match.group(3)) if match.group(3) is not None else 1
- return np.linspace(start=start, stop=end, num=num).tolist()
- return string
- def _transpose(self, my_list=None):
- my_list = self.my_list if my_list is None else my_list
- my_list = [item if isinstance(item, list) else [item] for item in my_list]
- self.fill_to_longest(my_list, index=-1)
- return np.array(my_list, dtype=object).T.tolist()
- ################################################
- #
- # The methods of ListParser class end here.
- #
- ################################################
- ################################################################
- ################################################################
- #
- # Starting the main process of this module.
- #
- # functions are executed in this order:
- # find_module
- # add_axis_options
- # identity
- # enable_script_control
- # apply_field
- # confirm
- # bool_
- # choices_for
- # make_excluded_list
- # config lists for AxisOptions:
- # validation_data
- # extra_axis_options
- ################################################################
- ################################################################
- def find_module(module_names):
- if isinstance(module_names, str):
- module_names = [s.strip() for s in module_names.split(",")]
- for data in scripts.scripts_data:
- if data.script_class.__module__ in module_names and hasattr(data, "module"):
- return data.module
- return None
- def add_axis_options(xyz_grid):
- ################################################
- #
- # Define a function to pass to the AxisOption class from here.
- #
- ################################################
-
- ################################################
- # Set this function as the type attribute of the AxisOption class.
- # To skip the following processing of xyz_grid module.
- # -> valslist = [opt.type(x) for x in valslist]
- # Perform type conversion using the function
- # set to the confirm attribute instead.
- #
- def identity(x):
- return x
-
- def enable_script_control():
- shared.opts.data["control_net_allow_script_control"] = True
- def apply_field(field):
- @debug_info
- def apply_field_(p, x, xs):
- enable_script_control()
- setattr(p, field, x)
- return apply_field_
- ################################################
- # The confirm function defined in this module
- # enables list notation and performs type conversion.
- #
- # Example:
- # any = [any, any, any, ...]
- # [any] = [any, None, None, ...]
- # [None, None, any] = [None, None, any]
- # [,,any] = [None, None, any]
- # any, [,any,] = [any, any, any, ...], [None, any, None]
- #
- # Enabled Only:
- # any = [any] = [any, None, None, ...]
- # (any and [any] are considered equivalent)
- #
- def confirm(func_or_str):
- @debug_info
- def confirm_(p, xs):
- if callable(func_or_str): # func_or_str is converter
- ListParser(xs, func_or_str, allow_blank=True)
- return
- elif isinstance(func_or_str, str): # func_or_str is keyword
- valid_data = find_dict(validation_data, func_or_str, stop=True)
- converter = valid_data["type"]
- exclude_list = valid_data["exclude"]() if valid_data["exclude"] else None
- check_list = valid_data["check"]()
- ListParser(xs, converter, allow_blank=True, exclude_list=exclude_list)
- is_all_included(xs, check_list, allow_blank=True, stop=True)
- return
- else:
- raise TypeError(f"Argument must be callable or str, not {type(func_or_str).__name__}.")
- return confirm_
- def bool_(string):
- string = str(string)
- if string in ["None", ""]:
- return None
- elif string.lower() in ["true", "1"]:
- return True
- elif string.lower() in ["false", "0"]:
- return False
- else:
- raise ValueError(f"Could not convert string to boolean: {string}")
- def choices_bool():
- return ["False", "True"]
- def choices_model():
- update_cn_models()
- return list(cn_models_names.values())
- def choices_control_mode():
- return [e.value for e in ControlMode]
- def choices_resize_mode():
- return [e.value for e in ResizeMode]
- def choices_preprocessor():
- return list(cn_preprocessor_modules)
- def make_excluded_list():
- pattern = re.compile(r"\[(\w+)\]")
- return [match.group(1) for s in choices_model()
- for match in pattern.finditer(s)]
- validation_data = [
- {"name": "model", "type": str, "check": choices_model, "exclude": make_excluded_list},
- {"name": "control_mode", "type": str, "check": choices_control_mode, "exclude": None},
- {"name": "resize_mode", "type": str, "check": choices_resize_mode, "exclude": None},
- {"name": "preprocessor", "type": str, "check": choices_preprocessor, "exclude": None},
- ]
- extra_axis_options = [
- xyz_grid.AxisOption("[ControlNet] Enabled", identity, apply_field("control_net_enabled"), confirm=confirm(bool_), choices=choices_bool),
- xyz_grid.AxisOption("[ControlNet] Model", identity, apply_field("control_net_model"), confirm=confirm("model"), choices=choices_model, cost=0.9),
- xyz_grid.AxisOption("[ControlNet] Weight", identity, apply_field("control_net_weight"), confirm=confirm(float)),
- xyz_grid.AxisOption("[ControlNet] Guidance Start", identity, apply_field("control_net_guidance_start"), confirm=confirm(float)),
- xyz_grid.AxisOption("[ControlNet] Guidance End", identity, apply_field("control_net_guidance_end"), confirm=confirm(float)),
- xyz_grid.AxisOption("[ControlNet] Control Mode", identity, apply_field("control_net_control_mode"), confirm=confirm("control_mode"), choices=choices_control_mode),
- xyz_grid.AxisOption("[ControlNet] Resize Mode", identity, apply_field("control_net_resize_mode"), confirm=confirm("resize_mode"), choices=choices_resize_mode),
- xyz_grid.AxisOption("[ControlNet] Preprocessor", identity, apply_field("control_net_module"), confirm=confirm("preprocessor"), choices=choices_preprocessor),
- xyz_grid.AxisOption("[ControlNet] Pre Resolution", identity, apply_field("control_net_pres"), confirm=confirm(int)),
- xyz_grid.AxisOption("[ControlNet] Pre Threshold A", identity, apply_field("control_net_pthr_a"), confirm=confirm(float)),
- xyz_grid.AxisOption("[ControlNet] Pre Threshold B", identity, apply_field("control_net_pthr_b"), confirm=confirm(float)),
- ]
- xyz_grid.axis_options.extend(extra_axis_options)
- def run():
- xyz_grid = find_module("xyz_grid.py, xy_grid.py")
- if xyz_grid:
- add_axis_options(xyz_grid)
- if not import_error:
- run()
|