outpainting_mk_2.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. import math
  2. import numpy as np
  3. import skimage
  4. import modules.scripts as scripts
  5. import gradio as gr
  6. from PIL import Image, ImageDraw
  7. from modules import images, processing, devices
  8. from modules.processing import Processed, process_images
  9. from modules.shared import opts, cmd_opts, state
  10. # this function is taken from https://github.com/parlance-zz/g-diffuser-bot
  11. def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
  12. # helper fft routines that keep ortho normalization and auto-shift before and after fft
  13. def _fft2(data):
  14. if data.ndim > 2: # has channels
  15. out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
  16. for c in range(data.shape[2]):
  17. c_data = data[:, :, c]
  18. out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho")
  19. out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
  20. else: # one channel
  21. out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
  22. out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
  23. out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
  24. return out_fft
  25. def _ifft2(data):
  26. if data.ndim > 2: # has channels
  27. out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
  28. for c in range(data.shape[2]):
  29. c_data = data[:, :, c]
  30. out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho")
  31. out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
  32. else: # one channel
  33. out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
  34. out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
  35. out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
  36. return out_ifft
  37. def _get_gaussian_window(width, height, std=3.14, mode=0):
  38. window_scale_x = float(width / min(width, height))
  39. window_scale_y = float(height / min(width, height))
  40. window = np.zeros((width, height))
  41. x = (np.arange(width) / width * 2. - 1.) * window_scale_x
  42. for y in range(height):
  43. fy = (y / height * 2. - 1.) * window_scale_y
  44. if mode == 0:
  45. window[:, y] = np.exp(-(x ** 2 + fy ** 2) * std)
  46. else:
  47. window[:, y] = (1 / ((x ** 2 + 1.) * (fy ** 2 + 1.))) ** (std / 3.14) # hey wait a minute that's not gaussian
  48. return window
  49. def _get_masked_window_rgb(np_mask_grey, hardness=1.):
  50. np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
  51. if hardness != 1.:
  52. hardened = np_mask_grey[:] ** hardness
  53. else:
  54. hardened = np_mask_grey[:]
  55. for c in range(3):
  56. np_mask_rgb[:, :, c] = hardened[:]
  57. return np_mask_rgb
  58. width = _np_src_image.shape[0]
  59. height = _np_src_image.shape[1]
  60. num_channels = _np_src_image.shape[2]
  61. np_src_image = _np_src_image[:] * (1. - np_mask_rgb)
  62. np_mask_grey = (np.sum(np_mask_rgb, axis=2) / 3.)
  63. img_mask = np_mask_grey > 1e-6
  64. ref_mask = np_mask_grey < 1e-3
  65. windowed_image = _np_src_image * (1. - _get_masked_window_rgb(np_mask_grey))
  66. windowed_image /= np.max(windowed_image)
  67. windowed_image += np.average(_np_src_image) * np_mask_rgb # / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color
  68. src_fft = _fft2(windowed_image) # get feature statistics from masked src img
  69. src_dist = np.absolute(src_fft)
  70. src_phase = src_fft / src_dist
  71. # create a generator with a static seed to make outpainting deterministic / only follow global seed
  72. rng = np.random.default_rng(0)
  73. noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise
  74. noise_rgb = rng.random((width, height, num_channels))
  75. noise_grey = (np.sum(noise_rgb, axis=2) / 3.)
  76. noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
  77. for c in range(num_channels):
  78. noise_rgb[:, :, c] += (1. - color_variation) * noise_grey
  79. noise_fft = _fft2(noise_rgb)
  80. for c in range(num_channels):
  81. noise_fft[:, :, c] *= noise_window
  82. noise_rgb = np.real(_ifft2(noise_fft))
  83. shaped_noise_fft = _fft2(noise_rgb)
  84. shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist ** noise_q) * src_phase # perform the actual shaping
  85. brightness_variation = 0. # color_variation # todo: temporarily tieing brightness variation to color variation for now
  86. contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.
  87. # scikit-image is used for histogram matching, very convenient!
  88. shaped_noise = np.real(_ifft2(shaped_noise_fft))
  89. shaped_noise -= np.min(shaped_noise)
  90. shaped_noise /= np.max(shaped_noise)
  91. shaped_noise[img_mask, :] = skimage.exposure.match_histograms(shaped_noise[img_mask, :] ** 1., contrast_adjusted_np_src[ref_mask, :], channel_axis=1)
  92. shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb
  93. matched_noise = shaped_noise[:]
  94. return np.clip(matched_noise, 0., 1.)
  95. class Script(scripts.Script):
  96. def title(self):
  97. return "Outpainting mk2"
  98. def show(self, is_img2img):
  99. return is_img2img
  100. def ui(self, is_img2img):
  101. if not is_img2img:
  102. return None
  103. info = gr.HTML("<p style=\"margin-bottom:0.75em\">Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8</p>")
  104. pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
  105. mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur"))
  106. direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction"))
  107. noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q"))
  108. color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation"))
  109. return [info, pixels, mask_blur, direction, noise_q, color_variation]
  110. def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation):
  111. initial_seed_and_info = [None, None]
  112. process_width = p.width
  113. process_height = p.height
  114. p.mask_blur = mask_blur*4
  115. p.inpaint_full_res = False
  116. p.inpainting_fill = 1
  117. p.do_not_save_samples = True
  118. p.do_not_save_grid = True
  119. left = pixels if "left" in direction else 0
  120. right = pixels if "right" in direction else 0
  121. up = pixels if "up" in direction else 0
  122. down = pixels if "down" in direction else 0
  123. init_img = p.init_images[0]
  124. target_w = math.ceil((init_img.width + left + right) / 64) * 64
  125. target_h = math.ceil((init_img.height + up + down) / 64) * 64
  126. if left > 0:
  127. left = left * (target_w - init_img.width) // (left + right)
  128. if right > 0:
  129. right = target_w - init_img.width - left
  130. if up > 0:
  131. up = up * (target_h - init_img.height) // (up + down)
  132. if down > 0:
  133. down = target_h - init_img.height - up
  134. def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
  135. is_horiz = is_left or is_right
  136. is_vert = is_top or is_bottom
  137. pixels_horiz = expand_pixels if is_horiz else 0
  138. pixels_vert = expand_pixels if is_vert else 0
  139. images_to_process = []
  140. output_images = []
  141. for n in range(count):
  142. res_w = init[n].width + pixels_horiz
  143. res_h = init[n].height + pixels_vert
  144. process_res_w = math.ceil(res_w / 64) * 64
  145. process_res_h = math.ceil(res_h / 64) * 64
  146. img = Image.new("RGB", (process_res_w, process_res_h))
  147. img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
  148. mask = Image.new("RGB", (process_res_w, process_res_h), "white")
  149. draw = ImageDraw.Draw(mask)
  150. draw.rectangle((
  151. expand_pixels + mask_blur if is_left else 0,
  152. expand_pixels + mask_blur if is_top else 0,
  153. mask.width - expand_pixels - mask_blur if is_right else res_w,
  154. mask.height - expand_pixels - mask_blur if is_bottom else res_h,
  155. ), fill="black")
  156. np_image = (np.asarray(img) / 255.0).astype(np.float64)
  157. np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
  158. noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
  159. output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB"))
  160. target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
  161. target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
  162. p.width = target_width if is_horiz else img.width
  163. p.height = target_height if is_vert else img.height
  164. crop_region = (
  165. 0 if is_left else output_images[n].width - target_width,
  166. 0 if is_top else output_images[n].height - target_height,
  167. target_width if is_left else output_images[n].width,
  168. target_height if is_top else output_images[n].height,
  169. )
  170. mask = mask.crop(crop_region)
  171. p.image_mask = mask
  172. image_to_process = output_images[n].crop(crop_region)
  173. images_to_process.append(image_to_process)
  174. p.init_images = images_to_process
  175. latent_mask = Image.new("RGB", (p.width, p.height), "white")
  176. draw = ImageDraw.Draw(latent_mask)
  177. draw.rectangle((
  178. expand_pixels + mask_blur * 2 if is_left else 0,
  179. expand_pixels + mask_blur * 2 if is_top else 0,
  180. mask.width - expand_pixels - mask_blur * 2 if is_right else res_w,
  181. mask.height - expand_pixels - mask_blur * 2 if is_bottom else res_h,
  182. ), fill="black")
  183. p.latent_mask = latent_mask
  184. proc = process_images(p)
  185. if initial_seed_and_info[0] is None:
  186. initial_seed_and_info[0] = proc.seed
  187. initial_seed_and_info[1] = proc.info
  188. for n in range(count):
  189. output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))
  190. output_images[n] = output_images[n].crop((0, 0, res_w, res_h))
  191. return output_images
  192. batch_count = p.n_iter
  193. batch_size = p.batch_size
  194. p.n_iter = 1
  195. state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0))
  196. all_processed_images = []
  197. for i in range(batch_count):
  198. imgs = [init_img] * batch_size
  199. state.job = f"Batch {i + 1} out of {batch_count}"
  200. if left > 0:
  201. imgs = expand(imgs, batch_size, left, is_left=True)
  202. if right > 0:
  203. imgs = expand(imgs, batch_size, right, is_right=True)
  204. if up > 0:
  205. imgs = expand(imgs, batch_size, up, is_top=True)
  206. if down > 0:
  207. imgs = expand(imgs, batch_size, down, is_bottom=True)
  208. all_processed_images += imgs
  209. all_images = all_processed_images
  210. combined_grid_image = images.image_grid(all_processed_images)
  211. unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple
  212. if opts.return_grid and not unwanted_grid_because_of_img_count:
  213. all_images = [combined_grid_image] + all_processed_images
  214. res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])
  215. if opts.samples_save:
  216. for img in all_processed_images:
  217. images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
  218. if opts.grid_save and not unwanted_grid_because_of_img_count:
  219. images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
  220. return res