multidiffusion.py

import torch

from modules.shared import state
from methods.abstractdiffusion import TiledDiffusion


class MultiDiffusion(TiledDiffusion):
    """
    MultiDiffusion Implementation
    Hijack the sampler for latent image tiling and fusion
    """

    def __init__(self, sampler, sampler_name: str, *args, **kwargs):
        super().__init__("MultiDiffusion", sampler, sampler_name, *args, **kwargs)
        # record the steps for the progress bar
        # hook the sampler
        assert sampler_name != 'UniPC', \
            'MultiDiffusion is not compatible with UniPC; please use another sampler instead.'
        if self.is_kdiff:
            # For K-Diffusion samplers with a uniform prompt, we hijack the inner model for simplicity.
            # Otherwise, the masked redraw will break due to the init_latent.
            self.sampler_func = self.sampler.inner_model.forward
            self.sampler.inner_model.forward = self.kdiff_repeat
        else:
            self.sampler_func = sampler.orig_p_sample_ddim
            self.sampler.orig_p_sample_ddim = self.ddim_repeat
        # For the DDIM sampler we need to cache pred_x0
        self.x_buffer_pred = None
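
    # Shape sketch for repeat_cond_dict (the concrete sizes are illustrative
    # assumptions, not taken from this file): with 2 bboxes, a text condition of
    # shape [1, 77, 768] is repeated along dim 0 to [2, 77, 768]; an image
    # condition whose spatial size matches the full latent (self.h x self.w) is
    # instead cropped per bbox and concatenated, and any other shape is repeated
    # along the batch dim just like the text condition.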
    def repeat_cond_dict(self, cond_input, bboxes):
        cond = cond_input['c_crossattn'][0]
        # repeat the condition on its first dim
        cond_shape = cond.shape
        cond = cond.repeat((len(bboxes),) + (1,) * (len(cond_shape) - 1))
        image_cond = cond_input['c_concat'][0]
        if image_cond.shape[2] == self.h and image_cond.shape[3] == self.w:
            image_cond_list = []
            for bbox in bboxes:
                image_cond_list.append(image_cond[:, :, bbox[1]:bbox[3], bbox[0]:bbox[2]])
            image_cond_tile = torch.cat(image_cond_list, dim=0)
        else:
            image_cond_shape = image_cond.shape
            image_cond_tile = image_cond.repeat((len(bboxes),) + (1,) * (len(image_cond_shape) - 1))
        return {"c_crossattn": [cond], "c_concat": [image_cond_tile]}

    def get_global_weights(self):
        return 1.0
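
    # Note: a constant global weight of 1.0 makes the overlap fusion a plain
    # average of tile outputs; custom regions additionally add their own
    # multiplier m into the weight map (below), so they blend proportionally
    # with the global grid during normalization.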
    def init_custom_bbox(self, global_multiplier, bbox_control_states, *args, **kwargs):
        super().init_custom_bbox(global_multiplier, bbox_control_states, *args, **kwargs)
        for bbox, _, _, m in self.custom_bboxes:
            self.weights[:, :, bbox[1]:bbox[3], bbox[0]:bbox[2]] += m

    @torch.no_grad()
    def kdiff_repeat(self, x_in, sigma_in, cond):
        '''
        This function replaces the original forward function in ldm/diffusionmodels/kdiffusion.py,
        so its signature must match that of the original function;
        in particular, the "cond" argument must keep exactly the same name.
        '''
        def repeat_func(x_tile, bboxes):
            # For the kdiff sampler, dim 0 of the input x_in is:
            #   = batch_size * (num_AND + 1) if not an edit model
            #   = batch_size * (num_AND + 2) otherwise
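            # (Illustrative assumption: reading num_AND as the number of
            # AND-separated subprompts, a plain prompt with batch_size=1 gives
            # dim 0 == 2, i.e. one conditional latent stacked with one
            # unconditional latent.)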
            sigma_in_tile = sigma_in.repeat(len(bboxes))
            new_cond = self.repeat_cond_dict(cond, bboxes)
            x_tile_out = self.sampler_func(x_tile, sigma_in_tile, cond=new_cond)
            return x_tile_out

        def custom_func(x, custom_cond, uncond, bbox_id, bbox):
            return self.kdiff_custom_forward(x, cond, custom_cond, uncond, bbox_id, bbox,
                                             sigma_in, self.sampler_func)

        return self.compute_x_tile(x_in, repeat_func, custom_func)

    @torch.no_grad()
    def ddim_repeat(self, x_in, cond_in, ts, unconditional_conditioning, *args, **kwargs):
        '''
        This function replaces the original p_sample_ddim function in ldm/diffusionmodels/ddim.py,
        so its signature must match that of the original function;
        in particular, the "unconditional_conditioning" argument must keep exactly the same name.
        '''
        def repeat_func(x_tile, bboxes):
            if isinstance(cond_in, dict):
                ts_tile = ts.repeat(len(bboxes))
                cond_tile = self.repeat_cond_dict(cond_in, bboxes)
                ucond_tile = self.repeat_cond_dict(unconditional_conditioning, bboxes)
            else:
                ts_tile = ts.repeat(len(bboxes))
                cond_shape = cond_in.shape
                cond_tile = cond_in.repeat((len(bboxes),) + (1,) * (len(cond_shape) - 1))
                ucond_shape = unconditional_conditioning.shape
                ucond_tile = unconditional_conditioning.repeat((len(bboxes),) + (1,) * (len(ucond_shape) - 1))
            x_tile_out, x_pred = self.sampler_func(
                x_tile, cond_tile,
                ts_tile,
                unconditional_conditioning=ucond_tile,
                *args, **kwargs)
            return x_tile_out, x_pred

        def custom_func(x, cond, uncond, bbox_id, bbox):
            # before the final forward, we can set the control tensor
            def forward_func(x, *args, **kwargs):
                self.set_control_tensor(bbox_id, 2 * x.shape[0])
                return self.sampler_func(x, *args, **kwargs)
            return self.ddim_custom_forward(x, cond_in, cond, uncond, bbox, ts,
                                            forward_func, *args, **kwargs)

        return self.compute_x_tile(x_in, repeat_func, custom_func)
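
    # compute_x_tile is the fusion core. An outline of the code below:
    #   1. zero the accumulation buffers (plus an extra pred_x0 buffer for DDIM);
    #   2. for each batch of grid bboxes, crop the matching latent tiles, run one
    #      batched denoising step, and add each result back into x_buffer;
    #   3. do the same per custom region, scaled by that region's multiplier;
    #   4. divide by the accumulated weights wherever they are positive.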
    def compute_x_tile(self, x_in, func, custom_func):
        N, C, H, W = x_in.shape
        assert H == self.h and W == self.w
        self.init(x_in)

        if not self.is_kdiff:
            if self.x_buffer_pred is None:
                self.x_buffer_pred = torch.zeros_like(x_in, device=x_in.device)
            else:
                self.x_buffer_pred.zero_()

        # Global sampling
        if self.global_multiplier > 0:
            for batch_id, bboxes in enumerate(self.batched_bboxes):
                if state.interrupted: return x_in
                x_tile_list = []
                for bbox in bboxes:
                    x_tile_list.append(x_in[:, :, bbox[1]:bbox[3], bbox[0]:bbox[2]])
                x_tile = torch.cat(x_tile_list, dim=0)
                # ControlNet tiling
                self.switch_controlnet_tensors(batch_id, N, len(bboxes))
                # compute tiles
                if self.is_kdiff:
                    x_tile_out = func(x_tile, bboxes)
                    for i, bbox in enumerate(bboxes):
                        x = x_tile_out[i*N:(i+1)*N, :, :, :]
                        self.x_buffer[:, :, bbox[1]:bbox[3], bbox[0]:bbox[2]] += x
                else:
                    x_tile_out, x_tile_pred = func(x_tile, bboxes)
                    for i, bbox in enumerate(bboxes):
                        x_o = x_tile_out [i*N:(i+1)*N, :, :, :]
                        x_p = x_tile_pred[i*N:(i+1)*N, :, :, :]
                        self.x_buffer     [:, :, bbox[1]:bbox[3], bbox[0]:bbox[2]] += x_o
                        self.x_buffer_pred[:, :, bbox[1]:bbox[3], bbox[0]:bbox[2]] += x_p
                # update the progress bar
                self.update_pbar()

        # Custom region sampling
        if len(self.custom_bboxes) > 0:
            if self.global_multiplier > 0 and abs(self.global_multiplier - 1.0) > 1e-6:
                self.x_buffer *= self.global_multiplier
                if not self.is_kdiff:
                    self.x_buffer_pred *= self.global_multiplier
            for index, (bbox, cond, uncond, multiplier) in enumerate(self.custom_bboxes):
                x_tile = x_in[:, :, bbox[1]:bbox[3], bbox[0]:bbox[2]]
                if self.is_kdiff:
                    # retrieve the original x_in from the constructed input;
                    # the last kdiff batch is always the correct original input
                    x_tile_out = custom_func(x_tile, cond, uncond, index, bbox)
                    x_tile_out *= multiplier
                    self.x_buffer[:, :, bbox[1]:bbox[3], bbox[0]:bbox[2]] += x_tile_out
                else:
                    x_tile_out, x_tile_pred = custom_func(x_tile, cond, uncond, index, bbox)
                    x_tile_out *= multiplier
                    x_tile_pred *= multiplier
                    self.x_buffer[:, :, bbox[1]:bbox[3], bbox[0]:bbox[2]] += x_tile_out
                    self.x_buffer_pred[:, :, bbox[1]:bbox[3], bbox[0]:bbox[2]] += x_tile_pred
                # update the progress bar
                self.update_pbar()

        # Normalize: only divide where the weights are greater than 0
        x_out = torch.where(self.weights > 0, self.x_buffer / self.weights, self.x_buffer)
        if not self.is_kdiff:
            x_pred = torch.where(self.weights > 0, self.x_buffer_pred / self.weights, self.x_buffer_pred)
            return x_out, x_pred
        return x_out
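
# Hypothetical usage sketch (not part of this file): the surrounding extension is
# assumed to construct the wrapper around the WebUI sampler, after which sampling
# proceeds as usual and the hijacked forward fuses the tiles transparently, e.g.:
#
#     md = MultiDiffusion(sampler, sampler_name='Euler a')
#     # the sampler's own loop now calls md.kdiff_repeat (or md.ddim_repeat)
#     # internally, so no further changes to the sampling call site are needed.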