# img2imgalt.py
# Noise-inversion helpers and the "img2img alternative test" script for
# stable-diffusion-webui: reconstruct the noise that would have produced the
# init image, then resample from that noise with an edited prompt.
from collections import namedtuple
import numpy as np
from tqdm import trange
import modules.scripts as scripts
import gradio as gr
from modules import processing, shared, sd_samplers, sd_samplers_common
import torch
import k_diffusion as K
def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
    """Recover a noise latent that reproduces ``p.init_latent`` when sampled.

    Walks the k-diffusion sigma schedule in reverse (``.flip(0)``), taking one
    Euler step per sigma against the CFG-combined model prediction, and
    reports progress through ``shared.state`` / ``store_latent``.

    :param p: processing object; only ``init_latent`` and
        ``image_conditioning`` are read here.
    :param cond: conditioning for the original prompt.
    :param uncond: conditioning for the original negative prompt.
    :param cfg_scale: guidance weight mixing cond/uncond predictions.
    :param steps: number of sigmas requested from the denoiser wrapper.
    :return: reconstructed noise latent, normalized by its own std.
    """
    x = p.init_latent

    # Per-batch-item sigma broadcast helper (shape [batch]).
    s_in = x.new_ones([x.shape[0]])
    if shared.sd_model.parameterization == "v":
        dnw = K.external.CompVisVDenoiser(shared.sd_model)
        # v-prediction wrapper returns one extra scaling first; skip it so the
        # unpack below is always (c_out, c_in).
        skip = 1
    else:
        dnw = K.external.CompVisDenoiser(shared.sd_model)
        skip = 0
    # Reversed schedule: we integrate from low sigma to high sigma.
    sigmas = dnw.get_sigmas(steps).flip(0)

    shared.state.sampling_steps = steps

    for i in trange(1, len(sigmas)):
        shared.state.sampling_step += 1

        # Batch uncond and cond together so one model call serves both.
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigmas[i] * s_in] * 2)
        cond_in = torch.cat([uncond, cond])

        image_conditioning = torch.cat([p.image_conditioning] * 2)
        cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}

        c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
        t = dnw.sigma_to_t(sigma_in)

        eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
        # chunk order matches the [uncond, cond] concatenation above.
        denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)

        # Classifier-free guidance combination.
        denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale

        # Euler step, run in reverse: dt is positive because sigmas ascend.
        d = (x - denoised) / sigmas[i]
        dt = sigmas[i] - sigmas[i - 1]

        x = x + d * dt

        sd_samplers_common.store_latent(x)

        # This shouldn't be necessary, but solved some VRAM issues
        del x_in, sigma_in, cond_in, c_out, c_in, t,
        del eps, denoised_uncond, denoised_cond, denoised, d, dt

    shared.state.nextjob()

    # Normalize to unit std so the result behaves like sampler input noise.
    return x / x.std()
# Record of the most recent noise reconstruction, so repeated runs with
# identical decode parameters and (near-)identical init latent can reuse it.
Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt", "original_negative_prompt", "sigma_adjustment"])
# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736
def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
    """Variant of ``find_noise_for_image`` with adjusted sigma indexing.

    Differences from the plain version: the model is evaluated at
    ``sigmas[i - 1]`` (the step's start sigma) instead of ``sigmas[i]``, the
    very first step is special-cased, and the result is scaled by the final
    sigma rather than normalized by std.

    :param p: processing object; reads ``init_latent`` and
        ``image_conditioning``.
    :param cond: conditioning for the original prompt.
    :param uncond: conditioning for the original negative prompt.
    :param cfg_scale: guidance weight mixing cond/uncond predictions.
    :param steps: number of sigmas requested from the denoiser wrapper.
    :return: reconstructed noise latent, divided by ``sigmas[-1]``.
    """
    x = p.init_latent

    # Per-batch-item sigma broadcast helper (shape [batch]).
    s_in = x.new_ones([x.shape[0]])
    if shared.sd_model.parameterization == "v":
        dnw = K.external.CompVisVDenoiser(shared.sd_model)
        # v-prediction wrapper returns one extra scaling first; skip it so the
        # unpack below is always (c_out, c_in).
        skip = 1
    else:
        dnw = K.external.CompVisDenoiser(shared.sd_model)
        skip = 0
    # Reversed schedule: we integrate from low sigma to high sigma.
    sigmas = dnw.get_sigmas(steps).flip(0)

    shared.state.sampling_steps = steps

    for i in trange(1, len(sigmas)):
        shared.state.sampling_step += 1

        # Batch uncond and cond together so one model call serves both.
        x_in = torch.cat([x] * 2)
        # Evaluate at the step's starting sigma (i - 1), not the target one.
        sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
        cond_in = torch.cat([uncond, cond])

        image_conditioning = torch.cat([p.image_conditioning] * 2)
        cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}

        c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]

        if i == 1:
            # First step: sigmas[0] is degenerate on the flipped schedule, so
            # derive t from sigmas[1] instead.
            t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
        else:
            t = dnw.sigma_to_t(sigma_in)

        eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
        # chunk order matches the [uncond, cond] concatenation above.
        denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)

        # Classifier-free guidance combination.
        denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale

        if i == 1:
            # Halved slope on the special-cased first step (see issue #736).
            d = (x - denoised) / (2 * sigmas[i])
        else:
            d = (x - denoised) / sigmas[i - 1]

        dt = sigmas[i] - sigmas[i - 1]
        x = x + d * dt

        sd_samplers_common.store_latent(x)

        # This shouldn't be necessary, but solved some VRAM issues
        del x_in, sigma_in, cond_in, c_out, c_in, t,
        del eps, denoised_uncond, denoised_cond, denoised, d, dt

    shared.state.nextjob()

    # Scale by the largest sigma so the result is in sampler-noise units.
    return x / sigmas[-1]
class Script(scripts.Script):
    """img2img script: decode the init image back to noise, then re-sample it
    with the (possibly edited) prompt, allowing prompt-guided image edits."""

    def __init__(self):
        # Holds the last Cached noise reconstruction (or None).
        self.cache = None

    def title(self):
        return "img2img alternative test"

    def show(self, is_img2img):
        # Only offered in the img2img tab.
        return is_img2img

    def ui(self, is_img2img):
        """Build the Gradio controls; returned list maps onto run()'s args."""
        info = gr.Markdown('''
* `CFG Scale` should be 2 or lower.
''')

        override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True, elem_id=self.elem_id("override_sampler"))

        override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True, elem_id=self.elem_id("override_prompt"))
        original_prompt = gr.Textbox(label="Original prompt", lines=1, elem_id=self.elem_id("original_prompt"))
        original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1, elem_id=self.elem_id("original_negative_prompt"))

        override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True, elem_id=self.elem_id("override_steps"))
        st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50, elem_id=self.elem_id("st"))

        override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True, elem_id=self.elem_id("override_strength"))

        cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=self.elem_id("cfg"))
        randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=self.elem_id("randomness"))
        sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment"))

        return [
            info,
            override_sampler,
            override_prompt, original_prompt, original_negative_prompt,
            override_steps, st,
            override_strength,
            cfg, randomness, sigma_adjustment,
        ]

    def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
        """Apply UI overrides, install the custom sampler, and process.

        ``_`` receives the info Markdown component's value and is unused.
        """
        # Override
        if override_sampler:
            p.sampler_name = "Euler"
        if override_prompt:
            p.prompt = original_prompt
            p.negative_prompt = original_negative_prompt
        if override_steps:
            p.steps = st
        if override_strength:
            p.denoising_strength = 1.0

        def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
            # Quantized latent used only as a cache fingerprint.
            lat = (p.init_latent.cpu().numpy() * 10).astype(int)

            same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \
                and self.cache.original_prompt == original_prompt \
                and self.cache.original_negative_prompt == original_negative_prompt \
                and self.cache.sigma_adjustment == sigma_adjustment
            # Shape check guards the elementwise diff; the <100 threshold
            # tolerates tiny latent differences between runs.
            same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100

            if same_everything:
                rec_noise = self.cache.noise
            else:
                # Noise reconstruction is an extra job for the progress bar.
                shared.state.job_count += 1
                cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt])
                uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt])
                if sigma_adjustment:
                    rec_noise = find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg, st)
                else:
                    rec_noise = find_noise_for_image(p, cond, uncond, cfg, st)
                self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)

            rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)

            # Blend recovered and random noise; denominator keeps variance
            # constant across the randomness range.
            combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)

            sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)

            sigmas = sampler.model_wrap.get_sigmas(p.steps)

            noise_dt = combined_noise - (p.init_latent / sigmas[0])

            p.seed = p.seed + 1

            return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)

        # Replace the default sampling with the noise-reconstruction version.
        p.sample = sample_extra

        p.extra_generation_params["Decode prompt"] = original_prompt
        p.extra_generation_params["Decode negative prompt"] = original_negative_prompt
        p.extra_generation_params["Decode CFG scale"] = cfg
        p.extra_generation_params["Decode steps"] = st
        p.extra_generation_params["Randomness"] = randomness
        p.extra_generation_params["Sigma Adjustment"] = sigma_adjustment

        processed = processing.process_images(p)

        return processed