123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685 |
- from collections import namedtuple
- from copy import copy
- from itertools import permutations, chain
- import random
- import csv
- from io import StringIO
- from PIL import Image
- import numpy as np
- import modules.scripts as scripts
- import gradio as gr
- from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
- from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
- from modules.shared import opts, cmd_opts, state
- import modules.shared as shared
- import modules.sd_samplers
- import modules.sd_models
- import modules.sd_vae
- import glob
- import os
- import re
- from modules.ui_components import ToolButton
# Glyph shown on the per-axis "fill values" tool buttons.
fill_values_symbol = "\U0001f4d2"  # 📒

# Pairs an axis option with the list of values selected for that axis.
AxisInfo = namedtuple("AxisInfo", ("axis", "values"))
def apply_field(field):
    """Return an axis ``apply`` callback that stores the axis value in ``p.<field>``.

    The returned callable takes ``(p, x, xs)``: the processing object, the
    value for the current cell and the full value list (unused here).
    """
    def set_field(p, x, xs):
        setattr(p, field, x)

    return set_field
def apply_prompt(p, x, xs):
    """Prompt search/replace: swap the first axis value for ``x``.

    ``xs[0]`` is the search term; every occurrence of it in both the prompt
    and the negative prompt is replaced by ``x``.

    Raises:
        RuntimeError: if the search term occurs in neither prompt.
    """
    needle = xs[0]
    found = needle in p.prompt or needle in p.negative_prompt
    if not found:
        raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")

    p.prompt = p.prompt.replace(needle, x)
    p.negative_prompt = p.negative_prompt.replace(needle, x)
def apply_order(p, x, xs):
    """Rearrange tokens inside ``p.prompt`` so they appear in the order given by ``x``.

    The positions currently occupied by the tokens are kept; only which token
    sits in which position changes.

    Args:
        p: processing object whose ``prompt`` attribute is rewritten.
        x: sequence of tokens in the desired order.
        xs: all orderings of the axis (unused).

    Raises:
        RuntimeError: if a token cannot be located in the prompt. Previously
            ``str.find`` returning -1 was used as a slice index, which
            silently produced a garbled prompt.
    """
    prompt = p.prompt

    # Locate every token first so they can be cut out in order of appearance.
    token_order = [(prompt.find(token), token) for token in x]
    for position, token in token_order:
        if position < 0:
            raise RuntimeError(f"Prompt order did not find {token} in prompt.")
    token_order.sort(key=lambda t: t[0])

    # Split the prompt into the text pieces that precede each token.
    # Work on a local copy so p.prompt is untouched if we have to bail out.
    remaining = prompt
    prompt_parts = []
    for _, token in token_order:
        n = remaining.find(token)
        if n < 0:
            # Can happen when tokens overlap and an earlier cut consumed this one.
            raise RuntimeError(f"Prompt order did not find {token} in prompt.")
        prompt_parts.append(remaining[:n])
        remaining = remaining[n + len(token):]

    # Rebuild: the i-th slot now receives the i-th token of the requested order.
    rebuilt = "".join(part + token for part, token in zip(prompt_parts, x))
    p.prompt = rebuilt + remaining
def apply_sampler(p, x, xs):
    """Select the sampler named ``x`` (case-insensitive) on ``p``.

    Raises:
        RuntimeError: if ``x`` is not a key of ``sd_samplers.samplers_map``.
    """
    resolved = sd_samplers.samplers_map.get(x.lower())
    if resolved is None:
        raise RuntimeError(f"Unknown sampler: {x}")

    p.sampler_name = resolved
def confirm_samplers(p, xs):
    """Validate up front that every sampler name in ``xs`` is known.

    Raises:
        RuntimeError: on the first name missing from ``sd_samplers.samplers_map``.
    """
    known = sd_samplers.samplers_map
    for x in xs:
        if x.lower() in known:
            continue
        raise RuntimeError(f"Unknown sampler: {x}")
def apply_checkpoint(p, x, xs):
    """Reload the shared model with the checkpoint that best matches ``x``.

    Raises:
        RuntimeError: if no checkpoint matches ``x``.
    """
    checkpoint = modules.sd_models.get_closet_checkpoint_match(x)
    if checkpoint is None:
        raise RuntimeError(f"Unknown checkpoint: {x}")

    modules.sd_models.reload_model_weights(shared.sd_model, checkpoint)
def confirm_checkpoints(p, xs):
    """Validate up front that every checkpoint name in ``xs`` can be resolved.

    Raises:
        RuntimeError: on the first name with no matching checkpoint.
    """
    find_match = modules.sd_models.get_closet_checkpoint_match
    for x in xs:
        if find_match(x) is None:
            raise RuntimeError(f"Unknown checkpoint: {x}")
def apply_clip_skip(p, x, xs):
    """Write ``x`` into the shared ``CLIP_stop_at_last_layers`` option.

    This changes global options rather than ``p``.
    """
    clip_skip = x
    opts.data["CLIP_stop_at_last_layers"] = clip_skip
def apply_upscale_latent_space(p, x, xs):
    """Toggle the shared ``use_scale_latent_for_hires_fix`` option.

    Any value other than the string ``'0'`` (after lower-casing and
    stripping) enables it.
    """
    enabled = x.lower().strip() != '0'
    opts.data["use_scale_latent_for_hires_fix"] = enabled
def find_vae(name: str):
    """Resolve a user-supplied VAE name to an entry usable by ``reload_vae_weights``.

    * ``auto`` / ``automatic`` -> ``modules.sd_vae.unspecified``
    * ``none`` -> ``None``
    * otherwise the shortest key of ``modules.sd_vae.vae_dict`` that contains
      the name (case-insensitive); if none matches, a message is printed and
      ``modules.sd_vae.unspecified`` is returned.
    """
    lowered = name.lower()
    if lowered in ['auto', 'automatic']:
        return modules.sd_vae.unspecified
    if lowered == 'none':
        return None

    needle = lowered.strip()
    # Shortest names first, so the closest match wins.
    candidates = sorted(modules.sd_vae.vae_dict, key=len)
    choices = [candidate for candidate in candidates if needle in candidate.lower()]
    if not choices:
        print(f"No VAE found for {name}; using automatic")
        return modules.sd_vae.unspecified

    return modules.sd_vae.vae_dict[choices[0]]
def apply_vae(p, x, xs):
    """Reload the shared model's VAE weights with the VAE resolved from ``x``."""
    vae_file = find_vae(x)
    modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=vae_file)
def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
    """Append the comma-separated style names in ``x`` to ``p.styles``."""
    extra_styles = x.split(',')
    p.styles.extend(extra_styles)
def apply_uni_pc_order(p, x, xs):
    """Set the shared ``uni_pc_order`` option to ``x``, capped at ``p.steps - 1``."""
    highest_allowed = p.steps - 1
    opts.data["uni_pc_order"] = min(x, highest_allowed)
def apply_face_restore(p, opt, x):
    """Configure face restoration on ``p`` from the axis value ``opt``.

    ``codeformer`` / ``gfpgan`` (any case) select that model and enable
    restoration. Any other value only toggles ``p.restore_faces``: it is
    enabled for ``true``, ``yes``, ``y`` or ``1`` and disabled otherwise.
    """
    model_names = {'codeformer': 'CodeFormer', 'gfpgan': 'GFPGAN'}
    choice = opt.lower()

    if choice in model_names:
        p.face_restoration_model = model_names[choice]
        p.restore_faces = True
        return

    p.restore_faces = choice in ('true', 'yes', 'y', '1')
def format_value_add_label(p, opt, x):
    """Return the grid label ``"<axis label>: <value>"`` for one axis value.

    Floats (including float subclasses) are rounded to 8 decimal places so
    accumulated floating-point noise does not show up in the legend.
    """
    if isinstance(x, float):
        x = round(x, 8)

    return f"{opt.label}: {x}"
def format_value(p, opt, x):
    """Return the axis value itself as the grid label.

    Floats (including float subclasses) are rounded to 8 decimal places;
    every other value is returned unchanged.
    """
    if isinstance(x, float):
        x = round(x, 8)

    return x
def format_value_join_list(p, opt, x):
    """Return the strings in ``x`` joined with ``", "`` as the grid label."""
    separator = ", "
    return separator.join(x)
def do_nothing(p, x, xs):
    """Axis ``apply`` callback that leaves ``p`` untouched."""
    return None
def format_nothing(p, opt, x):
    """Axis label formatter that always yields an empty label."""
    empty_label = ""
    return empty_label
def str_permutations(x):
    """Marker used as an AxisOption ``type`` to request permutations of the values.

    It is an identity function: the argument is returned unchanged.
    """
    unchanged = x
    return unchanged
class AxisOption:
    """One selectable axis type of the X/Y/Z plot.

    Attributes:
        label: name shown in the axis-type dropdown.
        type: callable used to convert each parsed value (``int``, ``float``,
            ``str`` or ``str_permutations``).
        apply: ``apply(p, x, xs)`` callback that applies one value.
        format_value: ``format_value(p, opt, x)`` callback producing the grid label.
        confirm: optional ``confirm(p, xs)`` validator run before processing starts.
        cost: relative cost of switching between values; used to choose loop order.
        choices: optional zero-argument callable listing the possible values.
    """

    def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
        self.label, self.type = label, type
        self.apply, self.format_value = apply, format_value
        self.confirm = confirm
        self.cost = cost
        self.choices = choices
class AxisOptionImg2Img(AxisOption):
    """Axis option tagged with ``is_img2img = True``."""

    def __init__(self, *args, **kwargs):
        AxisOption.__init__(self, *args, **kwargs)
        self.is_img2img = True
class AxisOptionTxt2Img(AxisOption):
    """Axis option tagged with ``is_img2img = False``."""

    def __init__(self, *args, **kwargs):
        AxisOption.__init__(self, *args, **kwargs)
        self.is_img2img = False
# Every axis type the script offers. The order matters: Script.ui uses
# entry 0 ("Nothing") and entry 1 as the default dropdown selections.
axis_options = [
    AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
    AxisOption("Seed", int, apply_field("seed")),
    AxisOption("Var. seed", int, apply_field("subseed")),
    AxisOption("Var. strength", float, apply_field("subseed_strength")),
    AxisOption("Steps", int, apply_field("steps")),
    AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
    AxisOption("CFG Scale", float, apply_field("cfg_scale")),
    AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
    AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
    AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
    AxisOptionTxt2Img(
        "Sampler",
        str,
        apply_sampler,
        format_value=format_value,
        confirm=confirm_samplers,
        choices=lambda: [x.name for x in sd_samplers.samplers],
    ),
    AxisOptionImg2Img(
        "Sampler",
        str,
        apply_sampler,
        format_value=format_value,
        confirm=confirm_samplers,
        choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img],
    ),
    AxisOption(
        "Checkpoint name",
        str,
        apply_checkpoint,
        format_value=format_value,
        confirm=confirm_checkpoints,
        cost=1.0,
        choices=lambda: list(sd_models.checkpoints_list),
    ),
    AxisOption("Sigma Churn", float, apply_field("s_churn")),
    AxisOption("Sigma min", float, apply_field("s_tmin")),
    AxisOption("Sigma max", float, apply_field("s_tmax")),
    AxisOption("Sigma noise", float, apply_field("s_noise")),
    AxisOption("Eta", float, apply_field("eta")),
    AxisOption("Clip skip", int, apply_clip_skip),
    AxisOption("Denoising", float, apply_field("denoising_strength")),
    AxisOptionTxt2Img(
        "Hires upscaler",
        str,
        apply_field("hr_upscaler"),
        choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]],
    ),
    AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
    AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
    AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
    AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
    AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
]
def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size):
    """Render every (x, y, z) cell and assemble the results into grids.

    Args:
        p: processing object; used for ``n_iter`` and for empty fallback results.
        xs, ys, zs: values of the three axes.
        x_labels, y_labels, z_labels: legend texts matching ``xs``/``ys``/``zs``.
        cell: callable ``cell(x, y, z, ix, iy, iz)`` returning a ``Processed``.
        draw_legend: whether grids are annotated with the labels.
        include_lone_images, include_sub_grids: accepted but not used here.
        first_axes_processed: ``'x'``, ``'y'`` or ``'z'``, the outermost loop axis.
        second_axes_processed: the middle loop axis; any value other than the
            one recognised for the chosen first axis selects the alternative.
        margin_size: margin passed to the sub-grid legend drawing.

    Returns:
        A ``Processed`` whose ``images`` are ``[main grid, one sub-grid per z,
        every cell image]``. ``infotexts`` has the same layout, while
        ``all_prompts`` / ``all_seeds`` have no entry for the main grid.
        If nothing could be processed, an empty ``Processed(p, [])``.
    """
    hor_texts = [[images.GridAnnotation(label)] for label in x_labels]
    ver_texts = [[images.GridAnnotation(label)] for label in y_labels]
    title_texts = [[images.GridAnnotation(label)] for label in z_labels]

    x_count, y_count = len(xs), len(ys)
    list_size = (x_count * y_count * len(zs))

    processed_result = None

    state.job_count = list_size * p.n_iter

    def process_cell(x, y, z, ix, iy, iz):
        nonlocal processed_result

        # Flat slot of this cell: x varies fastest, then y, then z.
        idx = ix + iy * len(xs) + iz * len(xs) * len(ys)

        state.job = f"{idx + 1} out of {list_size}"

        processed: Processed = cell(x, y, z, ix, iy, iz)

        if processed_result is None:
            # The first result doubles as the container for the combined output.
            processed_result = copy(processed)
            processed_result.images = [None] * list_size
            processed_result.all_prompts = [None] * list_size
            processed_result.all_seeds = [None] * list_size
            processed_result.infotexts = [None] * list_size
            processed_result.index_of_first_image = 1

        if not processed.images:
            # Nothing was produced for this cell: fill the slot with a blank image.
            template = processed_result.images[0]
            if template is not None:
                # Using an existing image corrects the size in case of batches.
                cell_mode, cell_size = template.mode, template.size
            else:
                cell_mode = "P"
                cell_size = (processed_result.width, processed_result.height)
            processed_result.images[idx] = Image.new(cell_mode, cell_size)
            return

        # A non-empty list indicates some degree of success.
        processed_result.images[idx] = processed.images[0]
        processed_result.all_prompts[idx] = processed.prompt
        processed_result.all_seeds[idx] = processed.seed
        processed_result.infotexts[idx] = processed.infotexts[0]

    # Loop nesting (outer, middle, inner) for each first axis; the second axis
    # picks between the two possible nestings.
    loop_orders = {
        'x': 'xyz' if second_axes_processed == 'y' else 'xzy',
        'y': 'yxz' if second_axes_processed == 'x' else 'yzx',
        'z': 'zxy' if second_axes_processed == 'x' else 'zyx',
    }
    loop_order = loop_orders.get(first_axes_processed)

    if loop_order is not None:
        axis_values = {'x': xs, 'y': ys, 'z': zs}
        outer, middle, inner = loop_order
        for i_outer, v_outer in enumerate(axis_values[outer]):
            for i_middle, v_middle in enumerate(axis_values[middle]):
                for i_inner, v_inner in enumerate(axis_values[inner]):
                    slot = {
                        outer: (i_outer, v_outer),
                        middle: (i_middle, v_middle),
                        inner: (i_inner, v_inner),
                    }
                    (ix, x), (iy, y), (iz, z) = slot['x'], slot['y'], slot['z']
                    process_cell(x, y, z, ix, iy, iz)

    if not processed_result:
        # Should never happen, I've only seen it on one of four open tabs and it needed to refresh.
        print("Unexpected error: Processing could not begin, you may need to refresh the tab or restart the service.")
        return Processed(p, [])
    elif not any(processed_result.images):
        print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
        return Processed(p, [])

    z_count = len(zs)

    for i in range(z_count):
        # i sub-grids already sit at the front of the lists, hence the "+ i".
        start_index = (i * len(xs) * len(ys)) + i
        end_index = start_index + len(xs) * len(ys)

        grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys))
        if draw_legend:
            first_cell = processed_result.images[start_index]
            grid = images.draw_grid_annotations(grid, first_cell.size[0], first_cell.size[1], hor_texts, ver_texts, margin_size)

        processed_result.images.insert(i, grid)
        processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index])
        processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index])
        processed_result.infotexts.insert(i, processed_result.infotexts[start_index])

    # Combine the sub-grids side by side into the main grid.
    sub_grid_size = processed_result.images[0].size
    z_grid = images.image_grid(processed_result.images[:z_count], rows=1)
    if draw_legend:
        z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])

    processed_result.images.insert(0, z_grid)
    #TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal.
    #processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
    #processed_result.all_seeds.insert(0, processed_result.all_seeds[0])
    processed_result.infotexts.insert(0, processed_result.infotexts[0])

    return processed_result
class SharedSettingsStackHelper:
    """Context manager that snapshots shared options and restores them on exit.

    On entry it remembers ``CLIP_stop_at_last_layers``, ``sd_vae`` and
    ``uni_pc_order``. On exit it writes them back and reloads the model and
    VAE weights.
    """

    def __enter__(self):
        # Snapshot the options that axis callbacks may overwrite.
        self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
        self.vae = opts.sd_vae
        self.uni_pc_order = opts.uni_pc_order

    def __exit__(self, exc_type, exc_value, tb):
        saved_vae, saved_order = self.vae, self.uni_pc_order
        opts.data["sd_vae"] = saved_vae
        opts.data["uni_pc_order"] = saved_order

        modules.sd_models.reload_model_weights()
        modules.sd_vae.reload_vae_weights()

        opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
# Range shorthands accepted in numeric axis value lists.
# The decimal point is escaped (r"\.") in the float patterns. Unescaped, "."
# matched any character, so e.g. "1--5(-1)" captured "1-" as the start value
# and the later float() conversion failed.

# "start-end" with an optional signed step: "1-10(+2)"
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
# Same, allowing decimals: "0.5-1.5(+0.25)"
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\(([+-]\d+(?:\.\d*)?)\s*\))?\s*")

# "start-end" with an optional number of samples: "1-10[5]"
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
# Same, allowing decimals: "0.5-1.5[3]"
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\[(\d+(?:\.\d*)?)\s*\])?\s*")
class Script(scripts.Script):
    """X/Y/Z plot script: renders a grid of images while varying up to three settings."""

    def title(self):
        """Return the script's display name."""
        return "X/Y/Z plot"

    def ui(self, is_img2img):
        """Build the script's Gradio controls and wire up their events.

        Returns:
            The components whose values are passed to ``run``, in this order:
            x_type, x_values, y_type, y_values, z_type, z_values, draw_legend,
            include_lone_images, include_sub_grids, no_fixed_seeds, margin_size.
        """
        # Exact-type check on purpose: plain AxisOption entries are offered on
        # both tabs, the Img2Img/Txt2Img subclasses only on the matching tab.
        self.current_axis_options = [x for x in axis_options if type(x) is AxisOption or x.is_img2img == is_img2img]

        def axis_row(axis, default_option):
            # One row of controls for a single axis: type dropdown, values box, fill button.
            upper = axis.upper()
            with gr.Row():
                type_dropdown = gr.Dropdown(label=f"{upper} type", choices=[x.label for x in self.current_axis_options], value=default_option.label, type="index", elem_id=self.elem_id(f"{axis}_type"))
                values_box = gr.Textbox(label=f"{upper} values", lines=1, elem_id=self.elem_id(f"{axis}_values"))
                fill_button = ToolButton(value=fill_values_symbol, elem_id=f"xyz_grid_fill_{axis}_tool_button", visible=False)
            return type_dropdown, values_box, fill_button

        with gr.Row():
            with gr.Column(scale=19):
                x_type, x_values, fill_x_button = axis_row("x", self.current_axis_options[1])
                y_type, y_values, fill_y_button = axis_row("y", self.current_axis_options[0])
                z_type, z_values, fill_z_button = axis_row("z", self.current_axis_options[0])

        with gr.Row(variant="compact", elem_id="axis_options"):
            with gr.Column():
                draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
                no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
            with gr.Column():
                include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
                include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
            with gr.Column():
                margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))

        with gr.Row(variant="compact", elem_id="swap_axes"):
            swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
            swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
            swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")

        def swap_axes(axis1_type, axis1_values, axis2_type, axis2_values):
            # Dropdowns deliver indices (type="index") but are updated by label.
            return self.current_axis_options[axis2_type].label, axis2_values, self.current_axis_options[axis1_type].label, axis1_values

        xy_swap_args = [x_type, x_values, y_type, y_values]
        swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
        yz_swap_args = [y_type, y_values, z_type, z_values]
        swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
        xz_swap_args = [x_type, x_values, z_type, z_values]
        swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)

        def fill(x_type):
            # Fill the values box with every available choice for the selected axis type.
            axis = self.current_axis_options[x_type]
            return ", ".join(axis.choices()) if axis.choices else gr.update()

        fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
        fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
        fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values])

        def select_axis(x_type):
            # Only show the fill button for axis types that can list their choices.
            return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None)

        x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
        y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
        z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button])

        self.infotext_fields = (
            (x_type, "X Type"),
            (x_values, "X Values"),
            (y_type, "Y Type"),
            (y_values, "Y Values"),
            (z_type, "Z Type"),
            (z_values, "Z Values"),
        )

        return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]

    def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
        """Parse the axis values, render every cell and return the combined result.

        Args:
            p: processing object used as the template for every cell.
            x_type, y_type, z_type: indices into ``self.current_axis_options``.
            x_values, y_values, z_values: raw comma-separated value strings.
            draw_legend: annotate the grids with axis labels.
            include_lone_images: keep the individual cell images in the result.
            include_sub_grids: keep the per-Z sub-grids in the result.
            no_fixed_seeds: leave seeds of -1 unresolved instead of fixing them.
            margin_size: grid margin passed on to ``draw_xyz_grid``.

        Returns:
            The ``Processed`` object produced by ``draw_xyz_grid``, trimmed
            according to the include flags.
        """
        if not no_fixed_seeds:
            modules.processing.fix_seed(p)

        if not opts.return_grid:
            p.batch_size = 1

        def process_axis(opt, vals):
            # Turn the raw text of one axis into a list of typed values.
            if opt.label == 'Nothing':
                return [0]

            valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]

            if opt.type == int:
                valslist_ext = []

                for val in valslist:
                    m = re_range.fullmatch(val)
                    mc = re_range_count.fullmatch(val)
                    if m is not None:
                        # "start-end(step)": inclusive integer range.
                        start = int(m.group(1))
                        end = int(m.group(2))+1
                        step = int(m.group(3)) if m.group(3) is not None else 1

                        valslist_ext += list(range(start, end, step))
                    elif mc is not None:
                        # "start-end[count]": evenly spaced samples, truncated to int.
                        start = int(mc.group(1))
                        end = int(mc.group(2))
                        num = int(mc.group(3)) if mc.group(3) is not None else 1

                        valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
                    else:
                        valslist_ext.append(val)

                valslist = valslist_ext
            elif opt.type == float:
                valslist_ext = []

                for val in valslist:
                    m = re_range_float.fullmatch(val)
                    mc = re_range_count_float.fullmatch(val)
                    if m is not None:
                        start = float(m.group(1))
                        end = float(m.group(2))
                        step = float(m.group(3)) if m.group(3) is not None else 1

                        # "end + step" so the end value itself is included.
                        valslist_ext += np.arange(start, end + step, step).tolist()
                    elif mc is not None:
                        start = float(mc.group(1))
                        end = float(mc.group(2))
                        num = int(mc.group(3)) if mc.group(3) is not None else 1

                        valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
                    else:
                        valslist_ext.append(val)

                valslist = valslist_ext
            elif opt.type == str_permutations:
                valslist = list(permutations(valslist))

            valslist = [opt.type(x) for x in valslist]

            # Confirm options are valid before starting
            if opt.confirm:
                opt.confirm(p, valslist)

            return valslist

        x_opt = self.current_axis_options[x_type]
        xs = process_axis(x_opt, x_values)

        y_opt = self.current_axis_options[y_type]
        ys = process_axis(y_opt, y_values)

        z_opt = self.current_axis_options[z_type]
        zs = process_axis(z_opt, z_values)

        # this could be moved to common code, but unlikely to be ever triggered anywhere else
        Image.MAX_IMAGE_PIXELS = None  # disable check in Pillow and rely on check below to allow large custom image sizes
        grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000)
        assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)'

        def fix_axis_seeds(axis_opt, axis_list):
            # Replace unset seeds on a seed axis with concrete random seeds.
            if axis_opt.label in ['Seed', 'Var. seed']:
                return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
            else:
                return axis_list

        if not no_fixed_seeds:
            xs = fix_axis_seeds(x_opt, xs)
            ys = fix_axis_seeds(y_opt, ys)
            zs = fix_axis_seeds(z_opt, zs)

        # Estimate the total number of sampling steps for the progress bar.
        if x_opt.label == 'Steps':
            total_steps = sum(xs) * len(ys) * len(zs)
        elif y_opt.label == 'Steps':
            total_steps = sum(ys) * len(xs) * len(zs)
        elif z_opt.label == 'Steps':
            total_steps = sum(zs) * len(xs) * len(ys)
        else:
            total_steps = p.steps * len(xs) * len(ys) * len(zs)

        if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
            if x_opt.label == "Hires steps":
                total_steps += sum(xs) * len(ys) * len(zs)
            elif y_opt.label == "Hires steps":
                total_steps += sum(ys) * len(xs) * len(zs)
            elif z_opt.label == "Hires steps":
                total_steps += sum(zs) * len(xs) * len(ys)
            elif p.hr_second_pass_steps:
                total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs)
            else:
                total_steps *= 2

        total_steps *= p.n_iter

        image_cell_count = p.n_iter * p.batch_size
        cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else ""
        plural_s = 's' if len(zs) > 1 else ''
        print(f"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})")
        shared.total_tqdm.updateTotal(total_steps)

        state.xyz_plot_x = AxisInfo(x_opt, xs)
        state.xyz_plot_y = AxisInfo(y_opt, ys)
        state.xyz_plot_z = AxisInfo(z_opt, zs)

        # If one of the axes is very slow to change between (like SD model
        # checkpoint), then make sure it is in the outer iteration of the nested
        # `for` loop. The sort is stable, so axes of equal cost keep the
        # z, y, x preference. Unlike a chain of strict ">" comparisons, this
        # also keeps expensive axes outermost when two of them tie for the
        # highest cost.
        axis_costs = {'x': x_opt.cost, 'y': y_opt.cost, 'z': z_opt.cost}
        first_axes_processed, second_axes_processed, _ = sorted(('z', 'y', 'x'), key=axis_costs.get, reverse=True)

        # Slot 0 is the main grid's infotext, slots 1.. belong to the Z sub-grids.
        grid_infotext = [None] * (1 + len(zs))

        def cell(x, y, z, ix, iy, iz):
            # Render a single cell on a copy of p with the three axis values applied.
            if shared.state.interrupted:
                return Processed(p, [], p.seed, "")

            pc = copy(p)
            pc.styles = pc.styles[:]
            x_opt.apply(pc, x, xs)
            y_opt.apply(pc, y, ys)
            z_opt.apply(pc, z, zs)

            res = process_images(pc)

            # Sets subgrid infotexts
            subgrid_index = 1 + iz
            if grid_infotext[subgrid_index] is None and ix == 0 and iy == 0:
                pc.extra_generation_params = copy(pc.extra_generation_params)
                pc.extra_generation_params['Script'] = self.title()

                if x_opt.label != 'Nothing':
                    pc.extra_generation_params["X Type"] = x_opt.label
                    pc.extra_generation_params["X Values"] = x_values
                    if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])

                if y_opt.label != 'Nothing':
                    pc.extra_generation_params["Y Type"] = y_opt.label
                    pc.extra_generation_params["Y Values"] = y_values
                    if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])

                grid_infotext[subgrid_index] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)

            # Sets main grid infotext
            if grid_infotext[0] is None and ix == 0 and iy == 0 and iz == 0:
                pc.extra_generation_params = copy(pc.extra_generation_params)

                if z_opt.label != 'Nothing':
                    pc.extra_generation_params["Z Type"] = z_opt.label
                    pc.extra_generation_params["Z Values"] = z_values
                    if z_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs])

                grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)

            return res

        # Axis callbacks may change shared options; the helper restores them afterwards.
        with SharedSettingsStackHelper():
            processed = draw_xyz_grid(
                p,
                xs=xs,
                ys=ys,
                zs=zs,
                x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
                y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
                z_labels=[z_opt.format_value(p, z_opt, z) for z in zs],
                cell=cell,
                draw_legend=draw_legend,
                include_lone_images=include_lone_images,
                include_sub_grids=include_sub_grids,
                first_axes_processed=first_axes_processed,
                second_axes_processed=second_axes_processed,
                margin_size=margin_size
            )

        if not processed.images:
            # It broke, no further handling needed.
            return processed

        z_count = len(zs)

        # Set the grid infotexts to the real ones with extra_generation_params (1 main grid + z_count sub-grids)
        processed.infotexts[:1+z_count] = grid_infotext[:1+z_count]

        if not include_lone_images:
            # Don't need sub-images anymore, drop from list:
            processed.images = processed.images[:z_count+1]

        if opts.grid_save:
            # Auto-save main and sub-grids:
            grid_count = z_count + 1 if z_count > 1 else 1
            for g in range(grid_count):
                #TODO: See previous comment about intentional data misalignment.
                adj_g = g-1 if g > 0 else g
                images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)

        if not include_sub_grids:
            # Done with sub-grids, drop all related information:
            for _ in range(z_count):
                del processed.images[1]
                del processed.all_prompts[1]
                del processed.all_seeds[1]
                del processed.infotexts[1]

        return processed
|