hook.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807
  1. import torch
  2. import einops
  3. import hashlib
  4. import numpy as np
  5. import torch.nn as nn
  6. import modules.processing
  7. from enum import Enum
  8. from scripts.logging import logger
  9. from modules import devices, lowvram, shared, scripts
  10. cond_cast_unet = getattr(devices, 'cond_cast_unet', lambda x: x)
  11. from ldm.modules.diffusionmodules.util import timestep_embedding
  12. from ldm.modules.diffusionmodules.openaimodel import UNetModel
  13. from ldm.modules.attention import BasicTransformerBlock
  14. from ldm.models.diffusion.ddpm import extract_into_tensor
  15. from modules.prompt_parser import MulticondLearnedConditioning, ComposableScheduledPromptConditioning, ScheduledPromptConditioning
  16. from modules.processing import StableDiffusionProcessing
# Value written into an extra leading token of a prompt context so ControlNet can
# tell cond (positive, +1024) from uncond (negative, -1024) rows.
POSITIVE_MARK_TOKEN = 1024
NEGATIVE_MARK_TOKEN = - POSITIVE_MARK_TOKEN
# Tolerance used when comparing a token against the mark values.
MARK_EPS = 1e-3
  20. def prompt_context_is_marked(x):
  21. t = x[..., 0, :]
  22. m = torch.abs(t) - POSITIVE_MARK_TOKEN
  23. m = torch.mean(torch.abs(m)).detach().cpu().float().numpy()
  24. return float(m) < MARK_EPS
  25. def mark_prompt_context(x, positive):
  26. if isinstance(x, list):
  27. for i in range(len(x)):
  28. x[i] = mark_prompt_context(x[i], positive)
  29. return x
  30. if isinstance(x, MulticondLearnedConditioning):
  31. x.batch = mark_prompt_context(x.batch, positive)
  32. return x
  33. if isinstance(x, ComposableScheduledPromptConditioning):
  34. x.schedules = mark_prompt_context(x.schedules, positive)
  35. return x
  36. if isinstance(x, ScheduledPromptConditioning):
  37. cond = x.cond
  38. if prompt_context_is_marked(cond):
  39. return x
  40. mark = POSITIVE_MARK_TOKEN if positive else NEGATIVE_MARK_TOKEN
  41. cond = torch.cat([torch.zeros_like(cond)[:1] + mark, cond], dim=0)
  42. return ScheduledPromptConditioning(end_at_step=x.end_at_step, cond=cond)
  43. return x
  44. disable_controlnet_prompt_warning = True
  45. # You can disable this warning using disable_controlnet_prompt_warning.
  46. def unmark_prompt_context(x):
  47. if not prompt_context_is_marked(x):
  48. # ControlNet must know whether a prompt is conditional prompt (positive prompt) or unconditional conditioning prompt (negative prompt).
  49. # You can use the hook.py's `mark_prompt_context` to mark the prompts that will be seen by ControlNet.
  50. # Let us say XXX is a MulticondLearnedConditioning or a ComposableScheduledPromptConditioning or a ScheduledPromptConditioning or a list of these components,
  51. # if XXX is a positive prompt, you should call mark_prompt_context(XXX, positive=True)
  52. # if XXX is a negative prompt, you should call mark_prompt_context(XXX, positive=False)
  53. # After you mark the prompts, the ControlNet will know which prompt is cond/uncond and works as expected.
  54. # After you mark the prompts, the mismatch errors will disappear.
  55. if not disable_controlnet_prompt_warning:
  56. logger.warning('ControlNet Error: Failed to detect whether an instance is cond or uncond!')
  57. logger.warning('ControlNet Error: This is mainly because other extension(s) blocked A1111\'s \"process.sample()\" and deleted ControlNet\'s sample function.')
  58. logger.warning('ControlNet Error: ControlNet will shift to a backup backend but the results will be worse than expectation.')
  59. logger.warning('Solution (For extension developers): Take a look at ControlNet\' hook.py '
  60. 'UnetHook.hook.process_sample and manually call mark_prompt_context to mark cond/uncond prompts.')
  61. mark_batch = torch.ones(size=(x.shape[0], 1, 1, 1), dtype=x.dtype, device=x.device)
  62. uc_indices = []
  63. context = x
  64. return mark_batch, uc_indices, context
  65. mark = x[:, 0, :]
  66. context = x[:, 1:, :]
  67. mark = torch.mean(torch.abs(mark - NEGATIVE_MARK_TOKEN), dim=1)
  68. mark = (mark > MARK_EPS).float()
  69. mark_batch = mark[:, None, None, None].to(x.dtype).to(x.device)
  70. uc_indices = mark.detach().cpu().numpy().tolist()
  71. uc_indices = [i for i, item in enumerate(uc_indices) if item < 0.5]
  72. StableDiffusionProcessing.cached_c = [None, None]
  73. StableDiffusionProcessing.cached_uc = [None, None]
  74. return mark_batch, uc_indices, context
  75. def create_random_tensors_hacked(*args, **kwargs):
  76. result = modules.processing.create_random_tensors_original(*args, **kwargs)
  77. p = kwargs.get('p', None)
  78. if p is None:
  79. return result
  80. controlnet_initial_noise_modifier = getattr(p, 'controlnet_initial_noise_modifier', None)
  81. if controlnet_initial_noise_modifier is not None:
  82. x0 = controlnet_initial_noise_modifier
  83. if result.shape[2] != x0.shape[2] or result.shape[3] != x0.shape[3]:
  84. return result
  85. x0 = x0.to(result.dtype).to(result.device)
  86. ts = torch.tensor([p.sd_model.num_timesteps - 1] * result.shape[0]).long().to(result.device)
  87. result = p.sd_model.q_sample(x0, ts, result)
  88. logger.info(f'[ControlNet] Initial noise hack applied to {result.shape}.')
  89. return result
  90. if getattr(modules.processing, 'create_random_tensors_original', None) is None:
  91. modules.processing.create_random_tensors_original = modules.processing.create_random_tensors
  92. modules.processing.create_random_tensors = create_random_tensors_hacked
class ControlModelType(Enum):
    """
    The type of Control Models (supported or not).

    Each value is "<model family>, <author>". In the code visible here, UnetHook's
    forward dispatches on ControlNet, T2I_Adapter, T2I_StyleAdapter and
    AttentionInjection; the other members are only declared.
    """

    ControlNet = "ControlNet, Lvmin Zhang"
    T2I_Adapter = "T2I_Adapter, Chong Mou"
    T2I_StyleAdapter = "T2I_StyleAdapter, Chong Mou"
    T2I_CoAdapter = "T2I_CoAdapter, Chong Mou"
    MasaCtrl = "MasaCtrl, Mingdeng Cao"
    GLIGEN = "GLIGEN, Yuheng Li"
    AttentionInjection = "AttentionInjection, Lvmin Zhang"  # A simple attention injection written by Lvmin
    StableSR = "StableSR, Jianyi Wang"
    PromptDiffusion = "PromptDiffusion, Zhendong Wang"
    ControlLoRA = "ControlLoRA, Wu Hecong"
# Written by Lvmin
class AutoMachine(Enum):
    """
    Lvmin's algorithm for Attention/AdaIn AutoMachine States.

    Write: the hooked attention / group-norm modules append reference features
    (attention inputs, or group-norm mean/var) to their banks.
    Read: the banks are consumed to modify the current pass and then cleared.
    """

    Read = "Read"
    Write = "Write"
  114. class TorchHijackForUnet:
  115. """
  116. This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
  117. this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
  118. """
  119. def __getattr__(self, item):
  120. if item == 'cat':
  121. return self.cat
  122. if hasattr(torch, item):
  123. return getattr(torch, item)
  124. raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
  125. def cat(self, tensors, *args, **kwargs):
  126. if len(tensors) == 2:
  127. a, b = tensors
  128. if a.shape[-2:] != b.shape[-2:]:
  129. a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
  130. tensors = (a, b)
  131. return torch.cat(tensors, *args, **kwargs)
  132. th = TorchHijackForUnet()
  133. class ControlParams:
  134. def __init__(
  135. self,
  136. control_model,
  137. preprocessor,
  138. hint_cond,
  139. weight,
  140. guidance_stopped,
  141. start_guidance_percent,
  142. stop_guidance_percent,
  143. advanced_weighting,
  144. control_model_type,
  145. hr_hint_cond,
  146. global_average_pooling,
  147. soft_injection,
  148. cfg_injection,
  149. **kwargs # To avoid errors
  150. ):
  151. self.control_model = control_model
  152. self.preprocessor = preprocessor
  153. self._hint_cond = hint_cond
  154. self.weight = weight
  155. self.guidance_stopped = guidance_stopped
  156. self.start_guidance_percent = start_guidance_percent
  157. self.stop_guidance_percent = stop_guidance_percent
  158. self.advanced_weighting = advanced_weighting
  159. self.control_model_type = control_model_type
  160. self.global_average_pooling = global_average_pooling
  161. self.hr_hint_cond = hr_hint_cond
  162. self.used_hint_cond = None
  163. self.used_hint_cond_latent = None
  164. self.used_hint_inpaint_hijack = None
  165. self.soft_injection = soft_injection
  166. self.cfg_injection = cfg_injection
  167. @property
  168. def hint_cond(self):
  169. return self._hint_cond
  170. # fix for all the extensions that modify hint_cond,
  171. # by forcing used_hint_cond to update on the next timestep
  172. # hr_hint_cond can stay the same, since most extensions dont modify the hires pass
  173. # but if they do, it will cause problems
  174. @hint_cond.setter
  175. def hint_cond(self, new_hint_cond):
  176. self._hint_cond = new_hint_cond
  177. self.used_hint_cond = None
  178. self.used_hint_cond_latent = None
  179. self.used_hint_inpaint_hijack = None
  180. def aligned_adding(base, x, require_channel_alignment):
  181. if isinstance(x, float):
  182. if x == 0.0:
  183. return base
  184. return base + x
  185. if require_channel_alignment:
  186. zeros = torch.zeros_like(base)
  187. zeros[:, :x.shape[1], ...] = x
  188. x = zeros
  189. # resize to sample resolution
  190. base_h, base_w = base.shape[-2:]
  191. xh, xw = x.shape[-2:]
  192. if xh > 1 or xw > 1:
  193. if base_h != xh or base_w != xw:
  194. # logger.info('[Warning] ControlNet finds unexpected mis-alignment in tensor shape.')
  195. x = th.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest")
  196. return base + x
  197. # DFS Search for Torch.nn.Module, Written by Lvmin
  198. def torch_dfs(model: torch.nn.Module):
  199. result = [model]
  200. for child in model.children():
  201. result += torch_dfs(child)
  202. return result
  203. def predict_start_from_noise(ldm, x_t, t, noise):
  204. return extract_into_tensor(ldm.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract_into_tensor(ldm.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
  205. def predict_noise_from_start(ldm, x_t, t, x0):
  206. return (extract_into_tensor(ldm.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / extract_into_tensor(ldm.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
  207. def blur(x, k):
  208. y = torch.nn.functional.pad(x, (k, k, k, k), mode='replicate')
  209. y = torch.nn.functional.avg_pool2d(y, (k*2+1, k*2+1), stride=(1, 1))
  210. return y
  211. class TorchCache:
  212. def __init__(self):
  213. self.cache = {}
  214. def hash(self, key):
  215. v = key.detach().cpu().numpy().astype(np.float32)
  216. v = (v * 1000.0).astype(np.int32)
  217. v = np.ascontiguousarray(v.copy())
  218. sha = hashlib.sha1(v).hexdigest()
  219. return sha
  220. def get(self, key):
  221. key = self.hash(key)
  222. return self.cache.get(key, None)
  223. def set(self, key, value):
  224. self.cache[self.hash(key)] = value
  225. class UnetHook(nn.Module):
  226. def __init__(self, lowvram=False) -> None:
  227. super().__init__()
  228. self.lowvram = lowvram
  229. self.model = None
  230. self.sd_ldm = None
  231. self.control_params = None
  232. self.attention_auto_machine = AutoMachine.Read
  233. self.attention_auto_machine_weight = 1.0
  234. self.gn_auto_machine = AutoMachine.Read
  235. self.gn_auto_machine_weight = 1.0
  236. self.current_style_fidelity = 0.0
  237. self.current_uc_indices = None
  238. @staticmethod
  239. def call_vae_using_process(p, x, batch_size=None, mask=None):
  240. vae_cache = getattr(p, 'controlnet_vae_cache', None)
  241. if vae_cache is None:
  242. vae_cache = TorchCache()
  243. setattr(p, 'controlnet_vae_cache', vae_cache)
  244. try:
  245. if x.shape[1] > 3:
  246. x = x[:, 0:3, :, :]
  247. x = x * 2.0 - 1.0
  248. if mask is not None:
  249. x = x * (1.0 - mask)
  250. x = x.type(devices.dtype_vae)
  251. vae_output = vae_cache.get(x)
  252. if vae_output is None:
  253. with devices.autocast():
  254. vae_output = p.sd_model.encode_first_stage(x)
  255. vae_output = p.sd_model.get_first_stage_encoding(vae_output)
  256. vae_cache.set(x, vae_output)
  257. logger.info(f'ControlNet used {str(devices.dtype_vae)} VAE to encode {vae_output.shape}.')
  258. latent = vae_output
  259. if batch_size is not None and latent.shape[0] != batch_size:
  260. latent = torch.cat([latent.clone() for _ in range(batch_size)], dim=0)
  261. latent = latent.type(devices.dtype_unet)
  262. return latent
  263. except Exception as e:
  264. logger.error(e)
  265. raise ValueError('ControlNet failed to use VAE. Please try to add `--no-half-vae`, `--no-half` and remove `--precision full` in launch cmd.')
  266. def guidance_schedule_handler(self, x):
  267. for param in self.control_params:
  268. current_sampling_percent = (x.sampling_step / x.total_sampling_steps)
  269. param.guidance_stopped = current_sampling_percent < param.start_guidance_percent or current_sampling_percent > param.stop_guidance_percent
    def hook(self, model, sd_ldm, control_params, process):
        """Install ControlNet into ``model`` (the UNet) for one generation.

        Args:
            model: the UNet; its ``forward`` is replaced further below.
            sd_ldm: the latent diffusion model. It is used by the hooked forward for
                ``q_sample``, the x0/noise conversions and the medvram trigger.
            control_params: list of ControlParams units applied on every forward.
            process: webui processing object. Its ``sample`` is wrapped and it is
                passed to call_vae_using_process.
        """
        self.model = model
        self.sd_ldm = sd_ldm
        self.control_params = control_params

        # The nested functions below take their own `self` (the UNet or one of its
        # submodules), so they reach this UnetHook through `outer`.
        outer = self
  275. def process_sample(*args, **kwargs):
  276. # ControlNet must know whether a prompt is conditional prompt (positive prompt) or unconditional conditioning prompt (negative prompt).
  277. # You can use the hook.py's `mark_prompt_context` to mark the prompts that will be seen by ControlNet.
  278. # Let us say XXX is a MulticondLearnedConditioning or a ComposableScheduledPromptConditioning or a ScheduledPromptConditioning or a list of these components,
  279. # if XXX is a positive prompt, you should call mark_prompt_context(XXX, positive=True)
  280. # if XXX is a negative prompt, you should call mark_prompt_context(XXX, positive=False)
  281. # After you mark the prompts, the ControlNet will know which prompt is cond/uncond and works as expected.
  282. # After you mark the prompts, the mismatch errors will disappear.
  283. mark_prompt_context(kwargs.get('conditioning', []), positive=True)
  284. mark_prompt_context(kwargs.get('unconditional_conditioning', []), positive=False)
  285. mark_prompt_context(getattr(process, 'hr_c', []), positive=True)
  286. mark_prompt_context(getattr(process, 'hr_uc', []), positive=False)
  287. return process.sample_before_CN_hack(*args, **kwargs)
  288. def forward(self, x, timesteps=None, context=None, **kwargs):
  289. total_controlnet_embedding = [0.0] * 13
  290. total_t2i_adapter_embedding = [0.0] * 4
  291. require_inpaint_hijack = False
  292. is_in_high_res_fix = False
  293. batch_size = int(x.shape[0])
  294. # Handle cond-uncond marker
  295. cond_mark, outer.current_uc_indices, context = unmark_prompt_context(context)
  296. # logger.info(str(cond_mark[:, 0, 0, 0].detach().cpu().numpy().tolist()) + ' - ' + str(outer.current_uc_indices))
  297. # High-res fix
  298. for param in outer.control_params:
  299. # select which hint_cond to use
  300. if param.used_hint_cond is None:
  301. param.used_hint_cond = param.hint_cond
  302. param.used_hint_cond_latent = None
  303. param.used_hint_inpaint_hijack = None
  304. # has high-res fix
  305. if param.hr_hint_cond is not None and x.ndim == 4 and param.hint_cond.ndim == 4 and param.hr_hint_cond.ndim == 4:
  306. _, _, h_lr, w_lr = param.hint_cond.shape
  307. _, _, h_hr, w_hr = param.hr_hint_cond.shape
  308. _, _, h, w = x.shape
  309. h, w = h * 8, w * 8
  310. if abs(h - h_lr) < abs(h - h_hr):
  311. is_in_high_res_fix = False
  312. if param.used_hint_cond is not param.hint_cond:
  313. param.used_hint_cond = param.hint_cond
  314. param.used_hint_cond_latent = None
  315. param.used_hint_inpaint_hijack = None
  316. else:
  317. is_in_high_res_fix = True
  318. if param.used_hint_cond is not param.hr_hint_cond:
  319. param.used_hint_cond = param.hr_hint_cond
  320. param.used_hint_cond_latent = None
  321. param.used_hint_inpaint_hijack = None
  322. no_high_res_control = is_in_high_res_fix and shared.opts.data.get("control_net_no_high_res_fix", False)
  323. # Convert control image to latent
  324. for param in outer.control_params:
  325. if param.used_hint_cond_latent is not None:
  326. continue
  327. if param.control_model_type not in [ControlModelType.AttentionInjection] \
  328. and 'colorfix' not in param.preprocessor['name'] \
  329. and 'inpaint_only' not in param.preprocessor['name']:
  330. continue
  331. param.used_hint_cond_latent = outer.call_vae_using_process(process, param.used_hint_cond, batch_size=batch_size)
  332. # handle prompt token control
  333. for param in outer.control_params:
  334. if no_high_res_control:
  335. continue
  336. if param.guidance_stopped:
  337. continue
  338. if param.control_model_type not in [ControlModelType.T2I_StyleAdapter]:
  339. continue
  340. param.control_model.to(devices.get_device_for("controlnet"))
  341. control = param.control_model(x=x, hint=param.used_hint_cond, timesteps=timesteps, context=context)
  342. control = torch.cat([control.clone() for _ in range(batch_size)], dim=0)
  343. control *= param.weight
  344. control *= cond_mark[:, :, :, 0]
  345. context = torch.cat([context, control.clone()], dim=1)
  346. # handle ControlNet / T2I_Adapter
  347. for param in outer.control_params:
  348. if no_high_res_control:
  349. continue
  350. if param.guidance_stopped:
  351. continue
  352. if param.control_model_type not in [ControlModelType.ControlNet, ControlModelType.T2I_Adapter]:
  353. continue
  354. param.control_model.to(devices.get_device_for("controlnet"))
  355. # inpaint model workaround
  356. x_in = x
  357. control_model = param.control_model.control_model
  358. if param.control_model_type == ControlModelType.ControlNet:
  359. if x.shape[1] != control_model.input_blocks[0][0].in_channels and x.shape[1] == 9:
  360. # inpaint_model: 4 data + 4 downscaled image + 1 mask
  361. x_in = x[:, :4, ...]
  362. require_inpaint_hijack = True
  363. assert param.used_hint_cond is not None, f"Controlnet is enabled but no input image is given"
  364. hint = param.used_hint_cond
  365. # ControlNet inpaint protocol
  366. if hint.shape[1] == 4:
  367. c = hint[:, 0:3, :, :]
  368. m = hint[:, 3:4, :, :]
  369. m = (m > 0.5).float()
  370. hint = c * (1 - m) - m
  371. control = param.control_model(x=x_in, hint=hint, timesteps=timesteps, context=context)
  372. control_scales = ([param.weight] * 13)
  373. if outer.lowvram:
  374. param.control_model.to("cpu")
  375. if param.cfg_injection or param.global_average_pooling:
  376. if param.control_model_type == ControlModelType.T2I_Adapter:
  377. control = [torch.cat([c.clone() for _ in range(batch_size)], dim=0) for c in control]
  378. control = [c * cond_mark for c in control]
  379. high_res_fix_forced_soft_injection = False
  380. if is_in_high_res_fix:
  381. if 'canny' in param.preprocessor['name']:
  382. high_res_fix_forced_soft_injection = True
  383. if 'mlsd' in param.preprocessor['name']:
  384. high_res_fix_forced_soft_injection = True
  385. # if high_res_fix_forced_soft_injection:
  386. # logger.info('[ControlNet] Forced soft_injection in high_res_fix in enabled.')
  387. if param.soft_injection or high_res_fix_forced_soft_injection:
  388. # important! use the soft weights with high-res fix can significantly reduce artifacts.
  389. if param.control_model_type == ControlModelType.T2I_Adapter:
  390. control_scales = [param.weight * x for x in (0.25, 0.62, 0.825, 1.0)]
  391. elif param.control_model_type == ControlModelType.ControlNet:
  392. control_scales = [param.weight * (0.825 ** float(12 - i)) for i in range(13)]
  393. if param.advanced_weighting is not None:
  394. control_scales = param.advanced_weighting
  395. control = [c * scale for c, scale in zip(control, control_scales)]
  396. if param.global_average_pooling:
  397. control = [torch.mean(c, dim=(2, 3), keepdim=True) for c in control]
  398. for idx, item in enumerate(control):
  399. target = None
  400. if param.control_model_type == ControlModelType.ControlNet:
  401. target = total_controlnet_embedding
  402. if param.control_model_type == ControlModelType.T2I_Adapter:
  403. target = total_t2i_adapter_embedding
  404. if target is not None:
  405. target[idx] = item + target[idx]
  406. # Replace x_t to support inpaint models
  407. for param in outer.control_params:
  408. if param.used_hint_cond.shape[1] != 4:
  409. continue
  410. if x.shape[1] != 9:
  411. continue
  412. if param.used_hint_inpaint_hijack is None:
  413. mask_pixel = param.used_hint_cond[:, 3:4, :, :]
  414. image_pixel = param.used_hint_cond[:, 0:3, :, :]
  415. mask_pixel = (mask_pixel > 0.5).to(mask_pixel.dtype)
  416. masked_latent = outer.call_vae_using_process(process, image_pixel, batch_size, mask=mask_pixel)
  417. mask_latent = torch.nn.functional.max_pool2d(mask_pixel, (8, 8))
  418. if mask_latent.shape[0] != batch_size:
  419. mask_latent = torch.cat([mask_latent.clone() for _ in range(batch_size)], dim=0)
  420. param.used_hint_inpaint_hijack = torch.cat([mask_latent, masked_latent], dim=1)
  421. param.used_hint_inpaint_hijack.to(x.dtype).to(x.device)
  422. x = torch.cat([x[:, :4, :, :], param.used_hint_inpaint_hijack], dim=1)
  423. # A1111 fix for medvram.
  424. if shared.cmd_opts.medvram:
  425. try:
  426. # Trigger the register_forward_pre_hook
  427. outer.sd_ldm.model()
  428. except:
  429. pass
  430. # Clear attention and AdaIn cache
  431. for module in outer.attn_module_list:
  432. module.bank = []
  433. module.style_cfgs = []
  434. for module in outer.gn_module_list:
  435. module.mean_bank = []
  436. module.var_bank = []
  437. module.style_cfgs = []
  438. # Handle attention and AdaIn control
  439. for param in outer.control_params:
  440. if no_high_res_control:
  441. continue
  442. if param.guidance_stopped:
  443. continue
  444. if param.used_hint_cond_latent is None:
  445. continue
  446. if param.control_model_type not in [ControlModelType.AttentionInjection]:
  447. continue
  448. ref_xt = outer.sd_ldm.q_sample(param.used_hint_cond_latent, torch.round(timesteps.float()).long())
  449. # Inpaint Hijack
  450. if x.shape[1] == 9:
  451. ref_xt = torch.cat([
  452. ref_xt,
  453. torch.zeros_like(ref_xt)[:, 0:1, :, :],
  454. param.used_hint_cond_latent
  455. ], dim=1)
  456. outer.current_style_fidelity = float(param.preprocessor['threshold_a'])
  457. outer.current_style_fidelity = max(0.0, min(1.0, outer.current_style_fidelity))
  458. if param.cfg_injection:
  459. outer.current_style_fidelity = 1.0
  460. elif param.soft_injection or is_in_high_res_fix:
  461. outer.current_style_fidelity = 0.0
  462. control_name = param.preprocessor['name']
  463. if control_name in ['reference_only', 'reference_adain+attn']:
  464. outer.attention_auto_machine = AutoMachine.Write
  465. outer.attention_auto_machine_weight = param.weight
  466. if control_name in ['reference_adain', 'reference_adain+attn']:
  467. outer.gn_auto_machine = AutoMachine.Write
  468. outer.gn_auto_machine_weight = param.weight
  469. outer.original_forward(
  470. x=ref_xt.to(devices.dtype_unet),
  471. timesteps=timesteps.to(devices.dtype_unet),
  472. context=context.to(devices.dtype_unet)
  473. )
  474. outer.attention_auto_machine = AutoMachine.Read
  475. outer.gn_auto_machine = AutoMachine.Read
  476. # U-Net Encoder
  477. hs = []
  478. with th.no_grad():
  479. t_emb = cond_cast_unet(timestep_embedding(timesteps, self.model_channels, repeat_only=False))
  480. emb = self.time_embed(t_emb)
  481. h = x.type(self.dtype)
  482. for i, module in enumerate(self.input_blocks):
  483. h = module(h, emb, context)
  484. if (i + 1) % 3 == 0:
  485. h = aligned_adding(h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack)
  486. hs.append(h)
  487. h = self.middle_block(h, emb, context)
  488. # U-Net Middle Block
  489. h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack)
  490. # U-Net Decoder
  491. for i, module in enumerate(self.output_blocks):
  492. h = th.cat([h, aligned_adding(hs.pop(), total_controlnet_embedding.pop(), require_inpaint_hijack)], dim=1)
  493. h = module(h, emb, context)
  494. # U-Net Output
  495. h = h.type(x.dtype)
  496. h = self.out(h)
  497. # Post-processing for color fix
  498. for param in outer.control_params:
  499. if param.used_hint_cond_latent is None:
  500. continue
  501. if 'colorfix' not in param.preprocessor['name']:
  502. continue
  503. k = int(param.preprocessor['threshold_a'])
  504. if is_in_high_res_fix and not no_high_res_control:
  505. k *= 2
  506. # Inpaint hijack
  507. xt = x[:, :4, :, :]
  508. x0_origin = param.used_hint_cond_latent
  509. t = torch.round(timesteps.float()).long()
  510. x0_prd = predict_start_from_noise(outer.sd_ldm, xt, t, h)
  511. x0 = x0_prd - blur(x0_prd, k) + blur(x0_origin, k)
  512. if '+sharp' in param.preprocessor['name']:
  513. detail_weight = float(param.preprocessor['threshold_b']) * 0.01
  514. neg = detail_weight * blur(x0, k) + (1 - detail_weight) * x0
  515. x0 = cond_mark * x0 + (1 - cond_mark) * neg
  516. eps_prd = predict_noise_from_start(outer.sd_ldm, xt, t, x0)
  517. w = max(0.0, min(1.0, float(param.weight)))
  518. h = eps_prd * w + h * (1 - w)
  519. # Post-processing for restore
  520. for param in outer.control_params:
  521. if param.used_hint_cond_latent is None:
  522. continue
  523. if 'inpaint_only' not in param.preprocessor['name']:
  524. continue
  525. if param.used_hint_cond.shape[1] != 4:
  526. continue
  527. # Inpaint hijack
  528. xt = x[:, :4, :, :]
  529. mask = param.used_hint_cond[:, 3:4, :, :]
  530. mask = torch.nn.functional.max_pool2d(mask, (10, 10), stride=(8, 8), padding=1)
  531. x0_origin = param.used_hint_cond_latent
  532. t = torch.round(timesteps.float()).long()
  533. x0_prd = predict_start_from_noise(outer.sd_ldm, xt, t, h)
  534. x0 = x0_prd * mask + x0_origin * (1 - mask)
  535. eps_prd = predict_noise_from_start(outer.sd_ldm, xt, t, x0)
  536. w = max(0.0, min(1.0, float(param.weight)))
  537. h = eps_prd * w + h * (1 - w)
  538. return h
  539. def forward_webui(*args, **kwargs):
  540. # webui will handle other compoments
  541. try:
  542. if shared.cmd_opts.lowvram:
  543. lowvram.send_everything_to_cpu()
  544. return forward(*args, **kwargs)
  545. finally:
  546. if self.lowvram:
  547. for param in self.control_params:
  548. if isinstance(param.control_model, torch.nn.Module):
  549. param.control_model.to("cpu")
        def hacked_basic_transformer_inner_forward(self, x, context=None):
            # Replacement for BasicTransformerBlock._forward, bound to the block; `outer`
            # is the UnetHook.
            # Write mode: if the unit's weight exceeds this block's attn_weight, the
            #   normalized input is stored in self.bank (with the current style fidelity).
            # Read mode: banked features are concatenated to the self-attention context
            #   along dim 1. Rows listed in outer.current_uc_indices get a second result
            #   computed without the bank, and the two results are blended by the averaged
            #   style fidelity. The bank is then cleared.
            x_norm1 = self.norm1(x)
            self_attn1 = None
            if self.disable_self_attn:
                # Do not use self-attention
                self_attn1 = self.attn1(x_norm1, context=context)
            else:
                # Use self-attention
                self_attention_context = x_norm1
                if outer.attention_auto_machine == AutoMachine.Write:
                    if outer.attention_auto_machine_weight > self.attn_weight:
                        self.bank.append(self_attention_context.detach().clone())
                        self.style_cfgs.append(outer.current_style_fidelity)
                if outer.attention_auto_machine == AutoMachine.Read:
                    if len(self.bank) > 0:
                        style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
                        self_attn1_uc = self.attn1(x_norm1, context=torch.cat([self_attention_context] + self.bank, dim=1))
                        self_attn1_c = self_attn1_uc.clone()
                        if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
                            # Recompute the listed rows without the reference bank.
                            self_attn1_c[outer.current_uc_indices] = self.attn1(
                                x_norm1[outer.current_uc_indices],
                                context=self_attention_context[outer.current_uc_indices])
                        self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc
                    self.bank = []
                    self.style_cfgs = []
                if self_attn1 is None:
                    self_attn1 = self.attn1(x_norm1, context=self_attention_context)

            x = self_attn1.to(x.dtype) + x
            x = self.attn2(self.norm2(x), context=context) + x
            x = self.ff(self.norm3(x)) + x
            return x
        def hacked_group_norm_forward(self, *args, **kwargs):
            # Replacement forward for selected UNet blocks (AdaIN), bound to the block.
            # The block's original forward runs first.
            # Write mode: if the unit's weight exceeds this block's gn_weight, the
            #   per-channel mean/var over dims (2, 3) are stored in the banks.
            # Read mode: the output is re-normalized to the averaged banked statistics.
            #   Rows in outer.current_uc_indices keep the original output in the
            #   blended branch, which is mixed by the averaged style fidelity. Banks are
            #   then cleared.
            eps = 1e-6
            x = self.original_forward(*args, **kwargs)
            y = None
            if outer.gn_auto_machine == AutoMachine.Write:
                if outer.gn_auto_machine_weight > self.gn_weight:
                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
                    self.mean_bank.append(mean)
                    self.var_bank.append(var)
                    self.style_cfgs.append(outer.current_style_fidelity)
            if outer.gn_auto_machine == AutoMachine.Read:
                if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
                    style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
                    # Clamp variances to eps before sqrt to avoid division by ~0.
                    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                    mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
                    var_acc = sum(self.var_bank) / float(len(self.var_bank))
                    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                    y_uc = (((x - mean) / std) * std_acc) + mean_acc
                    y_c = y_uc.clone()
                    if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
                        y_c[outer.current_uc_indices] = x.to(y_c.dtype)[outer.current_uc_indices]
                    y = style_cfg * y_c + (1.0 - style_cfg) * y_uc
                self.mean_bank = []
                self.var_bank = []
                self.style_cfgs = []
            if y is None:
                y = x
            return y.to(x.dtype)
        # Wrap process.sample once; the original is kept as sample_before_CN_hack.
        if getattr(process, 'sample_before_CN_hack', None) is None:
            process.sample_before_CN_hack = process.sample
        process.sample = process_sample

        # NOTE(review): _original_forward is overwritten unconditionally, so calling
        # hook() again before restore() would capture the already-hooked forward —
        # confirm callers always restore first.
        model._original_forward = model.forward
        outer.original_forward = model.forward
        model.forward = forward_webui.__get__(model, UNetModel)

        all_modules = torch_dfs(model)

        # Attention blocks sorted by descending norm1 width; attn_weight grows from 0
        # along that order and gates which blocks write to their bank.
        attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock)]
        attn_modules = sorted(attn_modules, key=lambda x: - x.norm1.normalized_shape[0])

        for i, module in enumerate(attn_modules):
            if getattr(module, '_original_inner_forward', None) is None:
                module._original_inner_forward = module._forward
            module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
            module.bank = []
            module.style_cfgs = []
            module.attn_weight = float(i) / float(len(attn_modules))

        # AdaIN targets: the middle block plus selected input/output blocks, each
        # with a gn_weight threshold (doubled below).
        gn_modules = [model.middle_block]
        model.middle_block.gn_weight = 0

        input_block_indices = [4, 5, 7, 8, 10, 11]
        for w, i in enumerate(input_block_indices):
            module = model.input_blocks[i]
            module.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
            gn_modules.append(module)

        output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
        for w, i in enumerate(output_block_indices):
            module = model.output_blocks[i]
            module.gn_weight = float(w) / float(len(output_block_indices))
            gn_modules.append(module)

        for i, module in enumerate(gn_modules):
            if getattr(module, 'original_forward', None) is None:
                module.original_forward = module.forward
            module.forward = hacked_group_norm_forward.__get__(module, torch.nn.Module)
            module.mean_bank = []
            module.var_bank = []
            module.style_cfgs = []
            module.gn_weight *= 2

        outer.attn_module_list = attn_modules
        outer.gn_module_list = gn_modules

        # Toggle guidance_stopped per step (see guidance_schedule_handler).
        scripts.script_callbacks.on_cfg_denoiser(self.guidance_schedule_handler)
  649. def restore(self, model):
  650. scripts.script_callbacks.remove_callbacks_for_function(self.guidance_schedule_handler)
  651. if hasattr(self, "control_params"):
  652. del self.control_params
  653. if not hasattr(model, "_original_forward"):
  654. # no such handle, ignore
  655. return
  656. model.forward = model._original_forward
  657. del model._original_forward