batch_hijack.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. import os
  2. from copy import copy
  3. from enum import Enum
  4. from typing import Tuple, List
  5. from modules import img2img, processing, shared, script_callbacks
  6. from scripts import external_code
  7. class BatchHijack:
  8. def __init__(self):
  9. self.is_batch = False
  10. self.batch_index = 0
  11. self.batch_size = 1
  12. self.init_seed = None
  13. self.init_subseed = None
  14. self.process_batch_callbacks = [self.on_process_batch]
  15. self.process_batch_each_callbacks = []
  16. self.postprocess_batch_each_callbacks = [self.on_postprocess_batch_each]
  17. self.postprocess_batch_callbacks = [self.on_postprocess_batch]
  18. def img2img_process_batch_hijack(self, p, *args, **kwargs):
  19. cn_is_batch, batches, output_dir, _ = get_cn_batches(p)
  20. if not cn_is_batch:
  21. return getattr(img2img, '__controlnet_original_process_batch')(p, *args, **kwargs)
  22. self.dispatch_callbacks(self.process_batch_callbacks, p, batches, output_dir)
  23. try:
  24. return getattr(img2img, '__controlnet_original_process_batch')(p, *args, **kwargs)
  25. finally:
  26. self.dispatch_callbacks(self.postprocess_batch_callbacks, p)
  27. def processing_process_images_hijack(self, p, *args, **kwargs):
  28. if self.is_batch:
  29. # we are in img2img batch tab, do a single batch iteration
  30. return self.process_images_cn_batch(p, *args, **kwargs)
  31. cn_is_batch, batches, output_dir, input_file_names = get_cn_batches(p)
  32. if not cn_is_batch:
  33. # we are not in batch mode, fallback to original function
  34. return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
  35. output_images = []
  36. try:
  37. self.dispatch_callbacks(self.process_batch_callbacks, p, batches, output_dir)
  38. for batch_i in range(self.batch_size):
  39. processed = self.process_images_cn_batch(p, *args, **kwargs)
  40. if shared.opts.data.get('controlnet_show_batch_images_in_ui', False):
  41. output_images.extend(processed.images[processed.index_of_first_image:])
  42. if output_dir:
  43. self.save_images(output_dir, input_file_names[batch_i], processed.images[processed.index_of_first_image:])
  44. if shared.state.interrupted:
  45. break
  46. finally:
  47. self.dispatch_callbacks(self.postprocess_batch_callbacks, p)
  48. if output_images:
  49. processed.images = output_images
  50. else:
  51. processed = processing.Processed(p, [], p.seed)
  52. return processed
  53. def process_images_cn_batch(self, p, *args, **kwargs):
  54. self.dispatch_callbacks(self.process_batch_each_callbacks, p)
  55. old_detectmap_output = shared.opts.data.get('control_net_no_detectmap', False)
  56. try:
  57. shared.opts.data.update({'control_net_no_detectmap': True})
  58. processed = getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
  59. finally:
  60. shared.opts.data.update({'control_net_no_detectmap': old_detectmap_output})
  61. self.dispatch_callbacks(self.postprocess_batch_each_callbacks, p, processed)
  62. # do not go past control net batch size
  63. if self.batch_index >= self.batch_size:
  64. shared.state.interrupted = True
  65. return processed
  66. def save_images(self, output_dir, init_image_path, output_images):
  67. os.makedirs(output_dir, exist_ok=True)
  68. for n, processed_image in enumerate(output_images):
  69. filename = os.path.basename(init_image_path)
  70. if n > 0:
  71. left, right = os.path.splitext(filename)
  72. filename = f"{left}-{n}{right}"
  73. if processed_image.mode == 'RGBA':
  74. processed_image = processed_image.convert("RGB")
  75. processed_image.save(os.path.join(output_dir, filename))
  76. def do_hijack(self):
  77. script_callbacks.on_script_unloaded(self.undo_hijack)
  78. hijack_function(
  79. module=img2img,
  80. name='process_batch',
  81. new_name='__controlnet_original_process_batch',
  82. new_value=self.img2img_process_batch_hijack,
  83. )
  84. hijack_function(
  85. module=processing,
  86. name='process_images_inner',
  87. new_name='__controlnet_original_process_images_inner',
  88. new_value=self.processing_process_images_hijack
  89. )
  90. def undo_hijack(self):
  91. unhijack_function(
  92. module=img2img,
  93. name='process_batch',
  94. new_name='__controlnet_original_process_batch',
  95. )
  96. unhijack_function(
  97. module=processing,
  98. name='process_images_inner',
  99. new_name='__controlnet_original_process_images_inner',
  100. )
  101. def adjust_job_count(self, p):
  102. if shared.state.job_count == -1:
  103. shared.state.job_count = p.n_iter
  104. shared.state.job_count *= self.batch_size
  105. def on_process_batch(self, p, batches, output_dir, *args):
  106. print('controlnet batch mode')
  107. self.is_batch = True
  108. self.batch_index = 0
  109. self.batch_size = len(batches)
  110. processing.fix_seed(p)
  111. if shared.opts.data.get('controlnet_increment_seed_during_batch', False):
  112. self.init_seed = p.seed
  113. self.init_subseed = p.subseed
  114. self.adjust_job_count(p)
  115. p.do_not_save_grid = True
  116. p.do_not_save_samples = bool(output_dir)
  117. def on_postprocess_batch_each(self, p, *args):
  118. self.batch_index += 1
  119. if shared.opts.data.get('controlnet_increment_seed_during_batch', False):
  120. p.seed = p.seed + len(p.all_prompts)
  121. p.subseed = p.subseed + len(p.all_prompts)
  122. def on_postprocess_batch(self, p, *args):
  123. self.is_batch = False
  124. self.batch_index = 0
  125. self.batch_size = 1
  126. if shared.opts.data.get('controlnet_increment_seed_during_batch', False):
  127. p.seed = self.init_seed
  128. p.all_seeds = [self.init_seed]
  129. p.subseed = self.init_subseed
  130. p.all_subseeds = [self.init_subseed]
  131. def dispatch_callbacks(self, callbacks, *args):
  132. for callback in callbacks:
  133. callback(*args)
  134. def hijack_function(module, name, new_name, new_value):
  135. # restore original function in case of reload
  136. unhijack_function(module=module, name=name, new_name=new_name)
  137. setattr(module, new_name, getattr(module, name))
  138. setattr(module, name, new_value)
  139. def unhijack_function(module, name, new_name):
  140. if hasattr(module, new_name):
  141. setattr(module, name, getattr(module, new_name))
  142. delattr(module, new_name)
  143. class InputMode(Enum):
  144. SIMPLE = "simple"
  145. BATCH = "batch"
  146. def get_cn_batches(p: processing.StableDiffusionProcessing) -> Tuple[bool, List[List[str]], str, List[str]]:
  147. units = external_code.get_all_units_in_processing(p)
  148. units = [copy(unit) for unit in units if getattr(unit, 'enabled', False)]
  149. any_unit_is_batch = False
  150. output_dir = ''
  151. input_file_names = []
  152. for unit in units:
  153. if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH:
  154. any_unit_is_batch = True
  155. output_dir = getattr(unit, 'output_dir', '')
  156. if isinstance(unit.batch_images, str):
  157. unit.batch_images = shared.listfiles(unit.batch_images)
  158. input_file_names = unit.batch_images
  159. if any_unit_is_batch:
  160. cn_batch_size = min(len(getattr(unit, 'batch_images', []))
  161. for unit in units
  162. if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH)
  163. else:
  164. cn_batch_size = 1
  165. batches = [[] for _ in range(cn_batch_size)]
  166. for i in range(cn_batch_size):
  167. for unit in units:
  168. if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.SIMPLE:
  169. batches[i].append(unit.image)
  170. else:
  171. batches[i].append(unit.batch_images[i])
  172. return any_unit_is_batch, batches, output_dir, input_file_names
  173. instance = BatchHijack()