prompt_matrix.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import math
  2. from collections import namedtuple
  3. from copy import copy
  4. import random
  5. import modules.scripts as scripts
  6. import gradio as gr
  7. from modules import images
  8. from modules.processing import process_images, Processed
  9. from modules.shared import opts, cmd_opts, state
  10. import modules.sd_samplers
  def draw_xy_grid(xs, ys, x_label, y_label, cell):
      """Run ``cell(x, y)`` for every pair of values and tile the first image of each result into a grid.

      Rows follow ``ys`` and columns follow ``xs``; ``y_label`` / ``x_label`` map each value to the text
      of its row / column header. Progress is published through ``state.job_count`` and ``state.job``.

      Returns the first object produced by ``cell``, with its ``images`` replaced by a one-element list
      holding the annotated grid. Empty ``xs`` or ``ys`` is not handled and will raise (there is no
      first image to size the annotations from).
      """
      total = len(xs) * len(ys)

      # Every y_label call happens before any x_label call.
      row_headers = [[images.GridAnnotation(y_label(y))] for y in ys]
      column_headers = [[images.GridAnnotation(x_label(x))] for x in xs]

      tiles = []
      template = None
      state.job_count = total

      job_number = 0
      for y in ys:
          for x in xs:
              job_number += 1
              state.job = f"{job_number} out of {total}"

              outcome = cell(x, y)
              if template is None:
                  # The first result is the one returned, with its images swapped for the grid.
                  template = outcome
              tiles.append(outcome.images[0])

      grid = images.image_grid(tiles, rows=len(ys))
      first_tile = tiles[0]
      grid = images.draw_grid_annotations(grid, first_tile.width, first_tile.height, column_headers, row_headers)

      template.images = [grid]
      return template
  class Script(scripts.Script):
      """Prompt matrix script: render every on/off combination of the ``|``-separated prompt parts.

      The text before the first ``|`` is always kept; each later ``|`` part is an optional fragment.
      All ``2 ** (number of fragments)`` combinations are generated in one run and tiled into a grid
      annotated by ``images.draw_prompt_matrix``.
      """

      def title(self):
          """Return the display title of this script."""
          return "Prompt matrix"

      def ui(self, is_img2img):
          """Build the option widgets and return them in the order ``run`` receives them.

          ``is_img2img`` is ignored; the controls do not depend on it.
          """
          gr.HTML('<br />')
          with gr.Row():
              with gr.Column():
                  put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start"))
                  different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds"))
              with gr.Column():
                  prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", elem_id=self.elem_id("prompt_type"), value="positive")
                  variations_delimiter = gr.Radio(["comma", "space"], label="Select joining char", elem_id=self.elem_id("variations_delimiter"), value="comma")
              with gr.Column():
                  margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))

          return [put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size]

      @staticmethod
      def _combine_prompt_parts(prompt_matrix_parts, put_at_start, delimiter):
          """Return one prompt per subset of the optional parts, in bit-mask order.

          ``prompt_matrix_parts[0]`` is the base text and is used verbatim; every later entry is an
          optional fragment, trimmed of surrounding whitespace and then of commas. Combination ``k``
          includes fragment ``n`` when bit ``n`` of ``k`` is set, so combination 0 is the bare base
          text and there are ``2 ** (len(prompt_matrix_parts) - 1)`` prompts in total.
          ``put_at_start`` puts the selected fragments before the base instead of after it.
          """
          base, fragments = prompt_matrix_parts[0], prompt_matrix_parts[1:]
          prompts = []
          for combination_num in range(2 ** len(fragments)):
              selected = [text.strip().strip(',') for n, text in enumerate(fragments) if combination_num & (1 << n)]
              selected = (selected + [base]) if put_at_start else ([base] + selected)
              prompts.append(delimiter.join(selected))
          return prompts

      def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size):
          """Generate one image per prompt combination and prepend an annotated grid.

          Args:
              p: processing parameters. Its ``prompt`` or ``negative_prompt`` (chosen by
                  ``prompt_type``) holds the ``|``-separated matrix. Mutated in place: the chosen
                  prompt, ``seed``, ``n_iter``, ``do_not_save_grid`` and ``prompt_for_display``
                  are overwritten.
              put_at_start: place the selected fragments before the base text instead of after it.
              different_seeds: give image ``i`` seed ``p.seed + i`` instead of one shared seed.
              prompt_type: ``"positive"`` or ``"negative"`` — which prompt holds the matrix.
              variations_delimiter: ``"comma"`` (join with ``", "``) or ``"space"`` (join with ``" "``).
              margin_size: grid margin in pixels, forwarded to ``images.draw_prompt_matrix``.

          Returns:
              The result of ``process_images`` with the grid inserted at ``images[0]`` (and the first
              infotext duplicated to line up with it). If no images were produced, the result is
              returned unchanged.

          Raises:
              ValueError: for an unrecognised ``prompt_type`` or ``variations_delimiter``.
          """
          modules.processing.fix_seed(p)

          # Reject values the UI radio buttons do not offer.
          if prompt_type not in ("positive", "negative"):
              raise ValueError(f"Unknown prompt type {prompt_type}")
          if variations_delimiter not in ("comma", "space"):
              raise ValueError(f"Unknown variations delimiter {variations_delimiter}")

          prompt = p.prompt if prompt_type == "positive" else p.negative_prompt
          original_prompt = prompt[0] if isinstance(prompt, list) else prompt
          positive_prompt = p.prompt[0] if isinstance(p.prompt, list) else p.prompt

          delimiter = ", " if variations_delimiter == "comma" else " "

          prompt_matrix_parts = original_prompt.split("|")
          all_prompts = self._combine_prompt_parts(prompt_matrix_parts, put_at_start, delimiter)

          p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
          # This script saves its own grid (see below), so suppress the default one.
          p.do_not_save_grid = True

          print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")

          if prompt_type == "positive":
              p.prompt = all_prompts
          else:
              p.negative_prompt = all_prompts
          # One seed per prompt: all equal to p.seed, or counting up from it when different_seeds is set.
          p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))]
          p.prompt_for_display = positive_prompt
          processed = process_images(p)

          # Nothing to tile (presumably possible when generation is interrupted early);
          # indexing images[0] below would otherwise raise IndexError.
          if not processed.images:
              return processed

          # 2**k images: use floor(k/2) of the option bits for rows so the grid is roughly square.
          grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
          grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size)
          processed.images.insert(0, grid)
          processed.index_of_first_image = 1
          processed.infotexts.insert(0, processed.infotexts[0])

          if opts.grid_save:
              images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", extension=opts.grid_format, prompt=original_prompt, seed=processed.seed, grid=True, p=p)

          return processed