tilediffusion.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. '''
  2. # ------------------------------------------------------------------------
  3. #
  4. # Tiled Diffusion for Automatic1111 WebUI
  5. #
  6. # Introducing revolutionary large image drawing methods:
  7. # MultiDiffusion and Mixture of Diffusers!
  8. #
  9. # These techniques were not originally proposed by me; please refer to:
  10. #
  11. # MultiDiffusion: https://multidiffusion.github.io
  12. # Mixture of Diffusers: https://github.com/albarji/mixture-of-diffusers
  13. #
  14. # The script contains a few optimizations including:
  15. # - symmetric tiling bboxes
  16. # - cached tiling weights
  17. # - batched denoising
  18. # - advanced prompt control for each tile
  19. #
  20. # ------------------------------------------------------------------------
  21. #
  22. # This script hooks into the original sampler, decomposes the latent
  23. # image into tiles, samples them separately, and merges them back by weighted averaging.
  24. #
  25. # Advantages:
  26. # - Allows for super large resolutions (2k~8k) for both txt2img and img2img.
  27. # - The merged output is completely seamless without any post-processing.
  28. # - Training free. No need to train a new model, and you can control the
  29. # text prompt for each tile.
  30. #
  31. # Drawbacks:
  32. # - Depending on your parameter settings, the process can be very slow,
  33. # especially when overlap is relatively large.
  34. # - Gradient calculation is not compatible with this hack. It
  35. # will break any backward() or torch.autograd.grad() that passes through the UNet.
  36. #
  37. # How it works (insanely simple!)
  38. # 1) The latent image x_t is split into tiles
  39. # 2) The tiles are denoised by original sampler to get x_t-1
  40. # 3) The tiles are added back together, and each pixel is divided by the
  41. # number of tiles that cover it.
  42. #
  43. # Enjoy!
  44. #
  45. # @author: LI YI @ Nanyang Technological University - Singapore
  46. # @date: 2023-03-03
  47. # @license: MIT License
  48. #
  49. # Please give me a star if you like this project!
  50. #
  51. # ------------------------------------------------------------------------
  52. '''
  53. import torch
  54. import numpy as np
  55. from enum import Enum
  56. import gradio as gr
  57. from modules import sd_samplers, images, shared, scripts
  58. from modules.shared import opts
  59. from modules.ui import gr_show
  60. from modules.processing import StableDiffusionProcessing
  61. from methods import TiledDiffusion, MultiDiffusion, MixtureOfDiffusers, splitable
  62. BBOX_MAX_NUM = min(shared.cmd_opts.md_max_regions if hasattr(
  63. shared.cmd_opts, "md_max_regions") else 8, 16)
  64. class Method(Enum):
  65. MULTI_DIFF = 'MultiDiffusion'
  66. MIX_DIFF = 'Mixture of Diffusers'
  67. class Script(scripts.Script):
  68. def title(self):
  69. return "Tiled Diffusion"
  70. def show(self, is_img2img):
  71. return scripts.AlwaysVisible
  72. def ui(self, is_img2img):
  73. tab = 't2i' if not is_img2img else 'i2i'
  74. is_t2i = 'true' if not is_img2img else 'false'
  75. with gr.Accordion('Tiled Diffusion', open=False):
  76. with gr.Row(variant='compact'):
  77. enabled = gr.Checkbox(label='Enable', value=False)
  78. method = gr.Dropdown(label='Method', choices=[e.value for e in Method], value=Method.MULTI_DIFF.value)
  79. with gr.Row(variant='compact', visible=False) as tab_size:
  80. image_width = gr.Slider(minimum=256, maximum=16384, step=16, label='Image width', value=1024,
  81. elem_id=f'MD-overwrite-width-{tab}')
  82. image_height = gr.Slider(minimum=256, maximum=16384, step=16, label='Image height', value=1024,
  83. elem_id=f'MD-overwrite-height-{tab}')
  84. with gr.Group():
  85. with gr.Row(variant='compact'):
  86. tile_width = gr.Slider(minimum=16, maximum=256, step=16, label='Latent tile width', value=96,
  87. elem_id=self.elem_id("latent_tile_width"))
  88. tile_height = gr.Slider(minimum=16, maximum=256, step=16, label='Latent tile height', value=96,
  89. elem_id=self.elem_id("latent_tile_height"))
  90. with gr.Row(variant='compact'):
  91. overlap = gr.Slider(minimum=0, maximum=256, step=4, label='Latent tile overlap', value=48,
  92. elem_id=self.elem_id("latent_overlap"))
  93. batch_size = gr.Slider(
  94. minimum=1, maximum=8, step=1, label='Latent tile batch size', value=1)
  95. with gr.Row(variant='compact', visible=is_img2img):
  96. upscaler_index = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value="None",
  97. elem_id='MD-upscaler-index')
  98. scale_factor = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label='Scale Factor', value=2.0,
  99. elem_id='MD-upscaler-factor')
  100. with gr.Row(variant='compact'):
  101. overwrite_image_size = gr.Checkbox(
  102. label='Overwrite image size', value=False, visible=(not is_img2img))
  103. keep_input_size = gr.Checkbox(
  104. label='Keep input image size', value=True, visible=(is_img2img))
  105. if not is_img2img:
  106. overwrite_image_size.change(fn=lambda x: gr_show(
  107. x), inputs=overwrite_image_size, outputs=tab_size)
  108. control_tensor_cpu = gr.Checkbox(
  109. label='Move ControlNet images to CPU (if applicable)', value=False)
  110. # The control includes txt2img and img2img, we use t2i and i2i to distinguish them
  111. with gr.Group(variant='panel', elem_id=f'MD-bbox-control-{tab}'):
  112. with gr.Accordion('Region Prompt Control', open=False):
  113. with gr.Row(variant='compact'):
  114. enable_bbox_control = gr.Checkbox(
  115. label='Enable', value=False)
  116. global_multiplier = gr.Slider(
  117. minimum=0, maximum=10, step=0.1, label='Background Multiplier', value=1, interactive=True)
  118. with gr.Row(variant='compact'):
  119. create_button = gr.Button(
  120. value="Create txt2img canvas" if not is_img2img else "From img2img")
  121. bbox_controls = [] # control set for each bbox
  122. with gr.Row(variant='compact'):
  123. ref_image = gr.Image(label='Ref image (for conviently locate regions)', image_mode=None,
  124. elem_id=f'MD-bbox-ref-{tab}', interactive=True)
  125. if not is_img2img:
  126. # gradio has a serious bug: it cannot accept multiple inputs when you use both js and fn.
  127. # to workaround this, we concat the inputs into a single string and parse it in js
  128. def create_t2i_ref(string):
  129. w, h = [int(x) for x in string.split('x')]
  130. w = max(w, 8)
  131. h = max(h, 8)
  132. return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
  133. create_button.click(
  134. fn=create_t2i_ref,
  135. inputs=overwrite_image_size,
  136. outputs=ref_image,
  137. _js='onCreateT2IRefClick')
  138. else:
  139. create_button.click(fn=None, outputs=ref_image, _js='onCreateI2IRefClick')
  140. for i in range(BBOX_MAX_NUM):
  141. with gr.Accordion(f'Region {i+1}', open=False):
  142. with gr.Row(variant='compact'):
  143. e = gr.Checkbox(label='Enable', value=False, elem_id=f'MD-enable-{i}')
  144. e.change(fn=None, inputs=e, outputs=e, _js=f'e => onBoxEnableClick({is_t2i}, {i}, e)')
  145. m = gr.Slider(label='Multiplier', value=1, minimum=0, maximum=10, step=0.1,
  146. interactive=True, elem_id=f'MD-mt-{i}')
  147. with gr.Row(variant='compact'):
  148. x = gr.Slider(label='x', value=0.4, minimum=0.0, maximum=1.0, step=0.01,
  149. interactive=True, elem_id=f'MD-{tab}-{i}-x')
  150. y = gr.Slider(label='y', value=0.4, minimum=0.0, maximum=1.0, step=0.01,
  151. interactive=True, elem_id=f'MD-{tab}-{i}-y')
  152. w = gr.Slider(label='w', value=0.2, minimum=0.0, maximum=1.0, step=0.01,
  153. interactive=True, elem_id=f'MD-{tab}-{i}-w')
  154. h = gr.Slider(label='h', value=0.2, minimum=0.0, maximum=1.0, step=0.01,
  155. interactive=True, elem_id=f'MD-{tab}-{i}-h')
  156. x.change(fn=None, inputs=x, outputs=x, _js=f'(v) => onBoxChange({is_t2i}, {i}, "x", v)')
  157. y.change(fn=None, inputs=y, outputs=y, _js=f'(v) => onBoxChange({is_t2i}, {i}, "y", v)')
  158. w.change(fn=None, inputs=w, outputs=w, _js=f'(v) => onBoxChange({is_t2i}, {i}, "w", v)')
  159. h.change(fn=None, inputs=h, outputs=h, _js=f'(v) => onBoxChange({is_t2i}, {i}, "h", v)')
  160. p = gr.Text(show_label=False, placeholder=f'Prompt, will append to your {tab} prompt',
  161. max_lines=2, elem_id=f'MD-p-{i}')
  162. neg = gr.Text(show_label=False, placeholder=f'Negative Prompt, will also be appended',
  163. max_lines=1, elem_id=f'MD-n-{i}')
  164. bbox_controls.append((e, m, x, y, w, h, p, neg))
  165. controls = [
  166. enabled, method,
  167. overwrite_image_size, keep_input_size, image_width, image_height,
  168. tile_width, tile_height, overlap, batch_size,
  169. upscaler_index, scale_factor,
  170. control_tensor_cpu,
  171. enable_bbox_control,
  172. global_multiplier
  173. ]
  174. for i in range(BBOX_MAX_NUM): controls.extend(bbox_controls[i])
  175. return controls
  176. def process(self, p: StableDiffusionProcessing,
  177. enabled: bool, method: str,
  178. overwrite_image_size: bool, keep_input_size: bool, image_width: int, image_height: int,
  179. tile_width: int, tile_height: int, overlap: int, tile_batch_size: int,
  180. upscaler_index: str, scale_factor: float,
  181. control_tensor_cpu: bool, enable_bbox_control: bool, global_multiplier: float,
  182. *bbox_control_states
  183. ):
  184. if hasattr(sd_samplers, "md_org_create_sampler"):
  185. sd_samplers.create_sampler = sd_samplers.md_org_create_sampler
  186. del sd_samplers.md_org_create_sampler
  187. MixtureOfDiffusers.unhook()
  188. if not enabled: return
  189. ''' upscale '''
  190. if hasattr(p, "init_images") and len(p.init_images) > 0: # img2img
  191. upscaler_name = [x.name for x in shared.sd_upscalers].index(upscaler_index)
  192. init_img = p.init_images[0]
  193. init_img = images.flatten(init_img, opts.img2img_background_color)
  194. upscaler = shared.sd_upscalers[upscaler_name]
  195. if upscaler.name != "None":
  196. print(f"[Tiled Diffusion] upscaling image with {upscaler.name}...")
  197. image = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path)
  198. p.extra_generation_params["Tiled Diffusion upscaler"] = upscaler.name
  199. p.extra_generation_params["Tiled Diffusion scale factor"] = scale_factor
  200. else:
  201. image = init_img
  202. p.init_images[0] = image
  203. if keep_input_size:
  204. p.width = image.width
  205. p.height = image.height
  206. elif upscaler.name != "None":
  207. p.width *= scale_factor
  208. p.height *= scale_factor
  209. elif overwrite_image_size: # txt2img
  210. p.width = image_width
  211. p.height = image_height
  212. ''' sanitiy check '''
  213. if not splitable(p.width, p.height, tile_width, tile_height, overlap):
  214. print("[Tiled Diffusion] ignore due to image too small or tile size too large.")
  215. return
  216. p.extra_generation_params["Tiled Diffusion method"] = method
  217. p.extra_generation_params["Tiled Diffusion tile width"] = tile_width
  218. p.extra_generation_params["Tiled Diffusion tile height"] = tile_height
  219. p.extra_generation_params["Tiled Diffusion overlap"] = overlap
  220. p.extra_generation_params["Tiled Diffusion batch size"] = tile_batch_size
  221. def process_batch(self, p: StableDiffusionProcessing,
  222. enabled: bool, method: str,
  223. overwrite_image_size: bool, keep_input_size: bool, image_width: int, image_height: int,
  224. tile_width: int, tile_height: int, overlap: int, tile_batch_size: int,
  225. upscaler_index: str, scale_factor: float,
  226. control_tensor_cpu: bool, enable_bbox_control: bool, global_multiplier: float,
  227. *bbox_control_states, batch_number, prompts, seeds, subseeds):
  228. '''
  229. compatible with the webui batch processing
  230. '''
  231. if not enabled: return
  232. method: Method = Method(method)
  233. n = batch_number
  234. ''' ControlNet hackin '''
  235. # try to hook into controlnet tensors
  236. controlnet_script = None
  237. try:
  238. from scripts.cldm import ControlNet
  239. # fix controlnet multi-batch issue
  240. def align(self, hint, h, w):
  241. if (len(hint.shape) == 3):
  242. hint = hint.unsqueeze(0)
  243. _, _, h1, w1 = hint.shape
  244. if h != h1 or w != w1:
  245. hint = torch.nn.functional.interpolate(hint, size=(h, w), mode="nearest")
  246. return hint
  247. ControlNet.align = align
  248. for script in p.scripts.scripts + p.scripts.alwayson_scripts:
  249. if hasattr(script, "latest_network") and script.title().lower() == "controlnet":
  250. controlnet_script = script
  251. print("[Tiled Diffusion] ControlNet found, MultiDiffusion-ControlNet support is enabled.")
  252. break
  253. except ImportError:
  254. pass
  255. ''' sampler hijack '''
  256. # custom sampler
  257. def create_sampler(name, model):
  258. # create the sampler with the original function
  259. sampler = sd_samplers.md_org_create_sampler(name, model)
  260. # unhook the create_sampler function
  261. if method == Method.MULTI_DIFF:
  262. delegate = MultiDiffusion(
  263. sampler, p.sampler_name,
  264. p.batch_size, p.steps, p.width, p.height,
  265. tile_width, tile_height, overlap, tile_batch_size,
  266. controlnet_script=controlnet_script,
  267. control_tensor_cpu=control_tensor_cpu
  268. )
  269. elif method == Method.MIX_DIFF:
  270. delegate = MixtureOfDiffusers(
  271. sampler, p.sampler_name,
  272. p.batch_size, p.steps, p.width, p.height,
  273. tile_width, tile_height, overlap, tile_batch_size,
  274. controlnet_script=controlnet_script,
  275. control_tensor_cpu=control_tensor_cpu
  276. )
  277. delegate.hook()
  278. else:
  279. raise NotImplementedError(f"Method {method} not implemented.")
  280. if enable_bbox_control:
  281. neg_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
  282. delegate.init_custom_bbox(global_multiplier, bbox_control_states, prompts, neg_prompts)
  283. print(f"{method.value} hooked into {p.sampler_name} sampler. " +
  284. f"Tile size: {tile_width}x{tile_height}, " +
  285. f"Tile batches: {len(delegate.batched_bboxes)}, " +
  286. f"Batch size:", tile_batch_size)
  287. return sampler
  288. # hack the create_sampler function to get the created sampler
  289. if not hasattr(sd_samplers, "md_org_create_sampler"):
  290. setattr(sd_samplers, "md_org_create_sampler", sd_samplers.create_sampler)
  291. sd_samplers.create_sampler = create_sampler