# vae_optimize.py
# ------------------------------------------------------------------------
#
#   Ultimate VAE Tile Optimization
#
#   Introducing a revolutionary new optimization designed to make
#   the VAE work with giant images on limited VRAM!
#   Say goodbye to the frustration of OOM and hello to seamless output!
#
# ------------------------------------------------------------------------
#
#   This script is a wild hack that splits the image into tiles,
#   encodes each tile separately, and merges the result back together.
#
#   Advantages:
#   - The VAE can now work with giant images on limited VRAM
#       (~10 GB for 8K images!)
#   - The merged output is completely seamless without any post-processing.
#
#   Drawbacks:
#   - Giant RAM needed. To store the intermediate results for a 4096x4096
#       image, you need a 32 GB RAM machine (it consumes ~20 GB); for 8192x8192
#       you need a 128 GB RAM machine (it consumes ~100 GB).
#   - NaNs always appear for 8k images when you use the fp16 (half) VAE.
#       You must use --no-half-vae to disable the half VAE for such giant images.
#   - Slow speed. With the default tile size, it takes around 50/200 seconds
#       to encode/decode a 4096x4096 image, and 200/900 seconds to encode/decode
#       an 8192x8192 image. (The speed is limited by both the GPU and the CPU.)
#   - Gradient calculation is not compatible with this hack. It
#       will break any backward() or torch.autograd.grad() that passes through the VAE.
#       (But you can still use the VAE to generate training data.)
#
#   How it works:
#   1) The image is split into tiles.
#       - To ensure perfect results, each tile is padded with 32 pixels
#           on each side.
#       - Then the conv2d/silu/upsample/downsample can produce identical
#           results to the original image without splitting.
#   2) The original forward is decomposed into a task queue and a task worker.
#       - The task queue is a list of functions that will be executed in order.
#       - The task worker is a loop that executes the tasks in the queue.
#   3) The task queue is executed for each tile.
#       - The current tile is sent to the GPU.
#       - Local operations are executed directly.
#       - Group norm calculation is temporarily suspended until the mean
#           and var of all tiles are calculated.
#       - The residual is pre-calculated, stored, and added back later.
#       - When we need to move to the next tile, the current tile is sent to the CPU.
#   4) After all tiles are processed, the tiles are merged and returned.
#
#   Enjoy!
#
#   @author: LI YI @ Nanyang Technological University - Singapore
#   @date: 2023-03-02
#   @license: MIT License
#
#   Please give me a star if you like this project!
#
# -------------------------------------------------------------------------
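#
#   A rough usage sketch, for orientation only (the Script class at the
#   bottom of this file does the real hooking; `sd_model` and
#   `huge_image_tensor` below are placeholder names):
#
#       vae = sd_model.first_stage_model
#       vae.encoder.original_forward = vae.encoder.forward
#       vae.encoder.forward = VAEHook(
#           vae.encoder, tile_size=1536, is_decoder=False,
#           fast_decoder=True, fast_encoder=True, color_fix=False)
#       latent = vae.encoder(huge_image_tensor)  # now runs tile by tile
#
# -------------------------------------------------------------------------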

import gc
from time import time
import math
from tqdm import tqdm

import torch
import torch.version
import torch.nn.functional as F
from einops import rearrange
import gradio as gr

import modules.scripts as scripts
import modules.devices as devices
from modules.shared import state
from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock

try:
    import xformers
    import xformers.ops
except ImportError:
    pass


def get_recommend_encoder_tile_size():
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(
            devices.device).total_memory // 2**20
        if total_memory > 16*1000:
            ENCODER_TILE_SIZE = 3072
        elif total_memory > 12*1000:
            ENCODER_TILE_SIZE = 2048
        elif total_memory > 8*1000:
            ENCODER_TILE_SIZE = 1536
        else:
            ENCODER_TILE_SIZE = 960
    else:
        ENCODER_TILE_SIZE = 512
    return ENCODER_TILE_SIZE


def get_recommend_decoder_tile_size():
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(
            devices.device).total_memory // 2**20
        if total_memory > 30*1000:
            DECODER_TILE_SIZE = 256
        elif total_memory > 16*1000:
            DECODER_TILE_SIZE = 192
        elif total_memory > 12*1000:
            DECODER_TILE_SIZE = 128
        elif total_memory > 8*1000:
            DECODER_TILE_SIZE = 96
        else:
            DECODER_TILE_SIZE = 64
    else:
        DECODER_TILE_SIZE = 64
    return DECODER_TILE_SIZE
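
# Note: the encoder tile size is measured in input-image pixels, while the
# decoder tile size is measured in latent units (1 latent unit = 8 output
# pixels, so a decoder tile of 64 covers a 512x512 region of the output).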

if 'global const':
    DEFAULT_ENABLED = False
    DEFAULT_MOVE_TO_GPU = False
    DEFAULT_FAST_ENCODER = True
    DEFAULT_FAST_DECODER = True
    DEFAULT_COLOR_FIX = 0
    DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()
    DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()


# inplace version of silu
def inplace_nonlinearity(x):
    # Test: fix for NaNs
    return F.silu(x, inplace=True)


# extracted from ldm.modules.diffusionmodules.model
def attn_forward(self, h_):
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q.shape
    q = q.reshape(b, c, h*w)
    q = q.permute(0, 2, 1)      # b,hw,c
    k = k.reshape(b, c, h*w)    # b,c,hw
    w_ = torch.bmm(q, k)        # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
    w_ = w_ * (int(c)**(-0.5))
    w_ = torch.nn.functional.softmax(w_, dim=2)

    # attend to values
    v = v.reshape(b, c, h*w)
    w_ = w_.permute(0, 2, 1)    # b,hw,hw (first hw of k, second of q)
    # b,c,hw (hw of q)    h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
    h_ = torch.bmm(v, w_)
    h_ = h_.reshape(b, c, h, w)

    h_ = self.proj_out(h_)
    return h_
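
# Note: w_ above is a (b, h*w, h*w) matrix, so attention memory grows with
# the fourth power of the spatial edge length at this resolution — one reason
# tiles must be kept small.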


def xformer_attn_forward(self, h_):
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    # compute attention
    B, C, H, W = q.shape
    q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(B, t.shape[1], 1, C)
        .permute(0, 2, 1, 3)
        .reshape(B * 1, t.shape[1], C)
        .contiguous(),
        (q, k, v),
    )
    out = xformers.ops.memory_efficient_attention(
        q, k, v, attn_bias=None, op=self.attention_op)
    out = (
        out.unsqueeze(0)
        .reshape(B, 1, out.shape[1], C)
        .permute(0, 2, 1, 3)
        .reshape(B, out.shape[1], C)
    )
    out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
    out = self.proj_out(out)
    return out


def attn2task(task_queue, net):
    if isinstance(net, AttnBlock):
        task_queue.append(('store_res', lambda x: x))
        task_queue.append(('pre_norm', net.norm))
        task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))
        task_queue.append(['add_res', None])
    elif isinstance(net, MemoryEfficientAttnBlock):
        task_queue.append(('store_res', lambda x: x))
        task_queue.append(('pre_norm', net.norm))
        task_queue.append(
            ('attn', lambda x, net=net: xformer_attn_forward(net, x)))
        task_queue.append(['add_res', None])


def resblock2task(queue, block):
    """
    Turn a ResNetBlock into a sequence of tasks and append them to the task queue
    @param queue: the target task queue
    @param block: ResNetBlock
    """
    if block.in_channels != block.out_channels:
        if block.use_conv_shortcut:
            queue.append(('store_res', block.conv_shortcut))
        else:
            queue.append(('store_res', block.nin_shortcut))
    else:
        queue.append(('store_res', lambda x: x))
    queue.append(('pre_norm', block.norm1))
    queue.append(('silu', inplace_nonlinearity))
    queue.append(('conv1', block.conv1))
    queue.append(('pre_norm', block.norm2))
    queue.append(('silu', inplace_nonlinearity))
    queue.append(('conv2', block.conv2))
    queue.append(['add_res', None])


def build_sampling(task_queue, net, is_decoder):
    """
    Build the sampling part of a task queue
    @param task_queue: the target task queue
    @param net: the network
    @param is_decoder: whether we are currently building a decoder or an encoder
    """
    if is_decoder:
        resblock2task(task_queue, net.mid.block_1)
        attn2task(task_queue, net.mid.attn_1)
        resblock2task(task_queue, net.mid.block_2)
        resolution_iter = reversed(range(net.num_resolutions))
        block_ids = net.num_res_blocks + 1
        condition = 0
        module = net.up
        func_name = 'upsample'
    else:
        resolution_iter = range(net.num_resolutions)
        block_ids = net.num_res_blocks
        condition = net.num_resolutions - 1
        module = net.down
        func_name = 'downsample'

    for i_level in resolution_iter:
        for i_block in range(block_ids):
            resblock2task(task_queue, module[i_level].block[i_block])
        if i_level != condition:
            task_queue.append((func_name, getattr(module[i_level], func_name)))

    if not is_decoder:
        resblock2task(task_queue, net.mid.block_1)
        attn2task(task_queue, net.mid.attn_1)
        resblock2task(task_queue, net.mid.block_2)


def build_task_queue(net, is_decoder):
    """
    Build a single task queue for the encoder or decoder
    @param net: the VAE decoder or encoder network
    @param is_decoder: whether we are currently building a decoder or an encoder
    @return: the task queue
    """
    task_queue = []
    task_queue.append(('conv_in', net.conv_in))

    # construct the sampling part of the task queue
    # because the encoder and decoder share the same architecture, we extract the sampling part
    build_sampling(task_queue, net, is_decoder)

    if not is_decoder or not net.give_pre_end:
        task_queue.append(('pre_norm', net.norm_out))
        task_queue.append(('silu', inplace_nonlinearity))
        task_queue.append(('conv_out', net.conv_out))
        if is_decoder and net.tanh_out:
            task_queue.append(('tanh', torch.tanh))

    return task_queue
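
# For reference, a decoder task queue built above begins roughly like this
# (a sketch, not captured output):
#   [('conv_in', net.conv_in),
#    # mid.block_1 (a ResNetBlock):
#    ('store_res', <shortcut or identity>), ('pre_norm', norm1), ('silu', ...),
#    ('conv1', ...), ('pre_norm', norm2), ('silu', ...), ('conv2', ...),
#    ['add_res', None],
#    # mid.attn_1:
#    ('store_res', <identity>), ('pre_norm', ...), ('attn', ...), ['add_res', None],
#    ...]
# 'add_res' entries are lists (not tuples) so the task worker can write the
# stored residual into them in place.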


def clone_task_queue(task_queue):
    """
    Clone a task queue
    @param task_queue: the task queue to be cloned
    @return: the cloned task queue
    """
    return [[item for item in task] for task in task_queue]


def get_var_mean(input, num_groups, eps=1e-6):
    """
    Get the mean and var for group norm
    """
    b, c = input.size(0), input.size(1)
    channel_in_group = int(c/num_groups)
    input_reshaped = input.contiguous().view(
        1, int(b * num_groups), channel_in_group, *input.size()[2:])
    var, mean = torch.var_mean(
        input_reshaped, dim=[0, 2, 3, 4], unbiased=False)
    return var, mean


def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
    """
    Custom group norm with fixed mean and var
    @param input: input tensor
    @param num_groups: number of groups. by default, num_groups = 32
    @param mean: mean, must be pre-calculated by get_var_mean
    @param var: var, must be pre-calculated by get_var_mean
    @param weight: weight, should be fetched from the original group norm
    @param bias: bias, should be fetched from the original group norm
    @param eps: epsilon, by default, eps = 1e-6 to match the original group norm
    @return: normalized tensor
    """
    b, c = input.size(0), input.size(1)
    channel_in_group = int(c/num_groups)
    input_reshaped = input.contiguous().view(
        1, int(b * num_groups), channel_in_group, *input.size()[2:])

    out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None,
                       training=False, momentum=0, eps=eps)
    out = out.view(b, c, *input.size()[2:])

    # post affine transform
    if weight is not None:
        out *= weight.view(1, -1, 1, 1)
    if bias is not None:
        out += bias.view(1, -1, 1, 1)
    return out
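
# A minimal sanity-check sketch (not part of the pipeline): on a single
# tensor, normalizing with the statistics from get_var_mean should match
# PyTorch's own group norm:
#
#   x = torch.randn(1, 64, 16, 16)
#   var, mean = get_var_mean(x, 32)
#   out = custom_group_norm(x, 32, mean, var)
#   ref = F.group_norm(x, 32, eps=1e-6)
#   assert torch.allclose(out, ref, atol=1e-5)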


def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
    """
    Crop the valid region from the tile
    @param x: input tile
    @param input_bbox: original input bounding box
    @param target_bbox: output bounding box
    @param is_decoder: whether the tile comes from the decoder (bboxes are
        scaled up by 8) or the encoder (bboxes are scaled down by 8)
    @return: cropped tile
    """
    padded_bbox = [i * 8 if is_decoder else i // 8 for i in input_bbox]
    margin = [target_bbox[i] - padded_bbox[i] for i in range(4)]
    return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]]
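
# Note: margin[1] and margin[3] are <= 0 wherever the tile was padded on the
# right/bottom, so x.size(...) + margin[...] trims exactly that padding; on
# the left/top, the positive margins skip the padded rows/columns instead.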


# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓
def perfcount(fn):
    def wrapper(*args, **kwargs):
        ts = time()

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(devices.device)
        devices.torch_gc()
        gc.collect()

        ret = fn(*args, **kwargs)

        devices.torch_gc()
        gc.collect()
        if torch.cuda.is_available():
            vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
            torch.cuda.reset_peak_memory_stats(devices.device)
            print(
                f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
        else:
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')

        return ret
    return wrapper
# copy end :)


class GroupNormParam:

    def __init__(self):
        self.var_list = []
        self.mean_list = []
        self.pixel_list = []
        self.weight = None
        self.bias = None

    def add_tile(self, tile, layer):
        var, mean = get_var_mean(tile, 32)
        # For giant images, the variance can be larger than max float16.
        # In this case we make a float32 copy.
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
        # ============= DEBUG: test for infinity =============
        # if torch.isinf(var).any():
        #     print('var: ', var)
        # ====================================================
        self.var_list.append(var)
        self.mean_list.append(mean)
        self.pixel_list.append(tile.shape[2]*tile.shape[3])
        if hasattr(layer, 'weight'):
            self.weight = layer.weight
            self.bias = layer.bias
        else:
            self.weight = None
            self.bias = None

    def summary(self):
        """
        Summarize the mean and var and return a function
        that applies group norm on each tile
        """
        if len(self.var_list) == 0:
            return None
        var = torch.vstack(self.var_list)
        mean = torch.vstack(self.mean_list)
        # weight each tile's statistics by its pixel count
        max_value = max(self.pixel_list)
        pixels = torch.tensor(
            self.pixel_list, dtype=torch.float32, device=devices.device) / max_value
        sum_pixels = torch.sum(pixels)
        pixels = pixels.unsqueeze(1) / sum_pixels
        var = torch.sum(var * pixels, dim=0)
        mean = torch.sum(mean * pixels, dim=0)
        return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias)

    @staticmethod
    def from_tile(tile, norm):
        """
        Create a group norm function from a single tile without summary
        """
        var, mean = get_var_mean(tile, 32)
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
            # on a MacBook (MPS device), we need to convert back to float16
            if var.device.type == 'mps':
                # clamp to avoid overflow
                var = torch.clamp(var, 0, 60000)
                var = var.half()
                mean = mean.half()
        if hasattr(norm, 'weight'):
            weight = norm.weight
            bias = norm.bias
        else:
            weight = None
            bias = None

        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
        return group_norm_func
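
# GroupNormParam.summary() combines the per-tile statistics as a
# pixel-weighted average (a sketch of the math, with p_i the pixel count of
# tile i):
#   mean = sum_i (p_i / sum_j p_j) * mean_i
#   var  = sum_i (p_i / sum_j p_j) * var_i
# This is exact for the mean; for the variance it drops the between-tile
# mean-spread term, so it is an approximation of the global group variance.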


class VAEHook:

    def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False):
        self.net = net                  # encoder | decoder
        self.tile_size = tile_size
        self.is_decoder = is_decoder
        self.fast_mode = (fast_encoder and not is_decoder) or (
            fast_decoder and is_decoder)
        self.color_fix = color_fix and not is_decoder
        self.to_gpu = to_gpu
        self.pad = 11 if is_decoder else 32

    def __call__(self, x):
        B, C, H, W = x.shape
        original_device = next(self.net.parameters()).device
        try:
            if self.to_gpu:
                self.net.to(devices.get_optimal_device())
            if max(H, W) <= self.pad * 2 + self.tile_size:
                print("[Tiled VAE]: the input size is tiny and unnecessary to tile.")
                return self.net.original_forward(x)
            else:
                return self.vae_tile_forward(x)
        finally:
            self.net.to(original_device)

    def get_best_tile_size(self, lowerbound, upperbound):
        """
        Get the best tile size for GPU memory
        """
        divider = 32
        while divider >= 2:
            remainder = lowerbound % divider
            if remainder == 0:
                return lowerbound
            candidate = lowerbound - remainder + divider
            if candidate <= upperbound:
                return candidate
            divider //= 2
        return lowerbound
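
    # Examples (hypothetical inputs): get_best_tile_size(1000, 1024) returns
    # 1024 (rounded up to a multiple of 32 that still fits the upper bound),
    # while get_best_tile_size(1000, 1010) falls back to a multiple of 16 and
    # returns 1008.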

    def split_tiles(self, h, w):
        """
        Tool function to split the image into tiles
        @param h: height of the image
        @param w: width of the image
        @return: tile_input_bboxes, tile_output_bboxes
        """
        tile_input_bboxes, tile_output_bboxes = [], []
        tile_size = self.tile_size
        pad = self.pad
        num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
        num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
        # If any of the numbers are 0, we let it be 1
        # This is to deal with long and thin images
        num_height_tiles = max(num_height_tiles, 1)
        num_width_tiles = max(num_width_tiles, 1)

        # Suggestion from https://github.com/Kahsolt: auto shrink the tile size
        real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
        real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
        real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
        real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)

        print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
              f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')

        for i in range(num_height_tiles):
            for j in range(num_width_tiles):
                # bbox: [x1, x2, y1, y2]
                # the padding is unnecessary for image borders, so we directly start from (pad, pad)
                input_bbox = [
                    pad + j * real_tile_width,
                    min(pad + (j + 1) * real_tile_width, w),
                    pad + i * real_tile_height,
                    min(pad + (i + 1) * real_tile_height, h),
                ]
                # if the output bbox is close to the image boundary, we extend it to the image boundary
                output_bbox = [
                    input_bbox[0] if input_bbox[0] > pad else 0,
                    input_bbox[1] if input_bbox[1] < w - pad else w,
                    input_bbox[2] if input_bbox[2] > pad else 0,
                    input_bbox[3] if input_bbox[3] < h - pad else h,
                ]
                # scale to get the final output bbox
                output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
                tile_output_bboxes.append(output_bbox)
                # expand the input bbox by pad pixels on each side; the extra
                # border is cropped away later, so the seams stay invisible
                tile_input_bboxes.append([
                    max(0, input_bbox[0] - pad),
                    min(w, input_bbox[1] + pad),
                    max(0, input_bbox[2] - pad),
                    min(h, input_bbox[3] + pad),
                ])
        return tile_input_bboxes, tile_output_bboxes
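
    # Worked example (decoder: pad=11, tile_size=128) on a 256x256 latent:
    #   tiles per axis = ceil((256 - 22) / 128) = 2
    #   real tile size = ceil(234 / 2) = 117, rounded up to 128 by
    #                    get_best_tile_size(117, 128)
    # giving a 2x2 grid whose padded input bboxes overlap by up to 2*pad.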

    @torch.no_grad()
    def estimate_group_norm(self, z, task_queue, color_fix):
        device = z.device
        tile = z
        last_id = len(task_queue) - 1
        while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
            last_id -= 1
        if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
            raise ValueError('No group norm found in the task queue')
        # estimate until the last group norm
        for i in range(last_id + 1):
            task = task_queue[i]
            if task[0] == 'pre_norm':
                group_norm_func = GroupNormParam.from_tile(tile, task[1])
                task_queue[i] = ('apply_norm', group_norm_func)
                if i == last_id:
                    return True
                tile = group_norm_func(tile)
            elif task[0] == 'store_res':
                task_id = i + 1
                while task_id < last_id and task_queue[task_id][0] != 'add_res':
                    task_id += 1
                if task_id >= last_id:
                    continue
                task_queue[task_id][1] = task[1](tile)
            elif task[0] == 'add_res':
                tile += task[1].to(device)
                task[1] = None
            elif color_fix and task[0] == 'downsample':
                for j in range(i, last_id + 1):
                    if task_queue[j][0] == 'store_res':
                        task_queue[j] = ('store_res_cpu', task_queue[j][1])
                return True
            else:
                tile = task[1](tile)
            try:
                devices.test_for_nans(tile, "vae")
            except Exception:
                print('NaN detected in fast mode estimation. Fast mode disabled.')
                return False
        raise IndexError('Should not reach here')

    @perfcount
    @torch.no_grad()
    def vae_tile_forward(self, z):
        """
        Decode a latent vector z into an image in a tiled manner.
        @param z: latent vector
        @return: image
        """
        device = next(self.net.parameters()).device
        net = self.net
        tile_size = self.tile_size
        is_decoder = self.is_decoder

        z = z.detach()  # detach the input to avoid backprop

        N, height, width = z.shape[0], z.shape[2], z.shape[3]
        net.last_z_shape = z.shape

        # Split the input into tiles and build a task queue for each tile
        print(
            f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')
        in_bboxes, out_bboxes = self.split_tiles(height, width)

        # Prepare tiles by splitting the input latents
        tiles = []
        for input_bbox in in_bboxes:
            tile = z[:, :, input_bbox[2]:input_bbox[3],
                     input_bbox[0]:input_bbox[1]].cpu()
            tiles.append(tile)
        num_tiles = len(tiles)
        num_completed = 0

        single_task_queue = build_task_queue(net, is_decoder)
        if self.fast_mode:
            # Fast mode: downsample the input image to the tile size,
            # then estimate the group norm parameters on the downsampled image
            scale_factor = tile_size / max(height, width)
            z = z.to(device)
            downsampled_z = F.interpolate(
                z, scale_factor=scale_factor, mode='nearest-exact')
            # use nearest-exact to keep the statistics as close as possible
            print(
                f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')

            # ======= Special thanks to @Kahsolt for the distribution shift issue ======= #
            # Downsampling heavily distorts the mean and std, so we need to recover them.
            std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
            std_new, mean_new = torch.std_mean(
                downsampled_z, dim=[0, 2, 3], keepdim=True)
            downsampled_z = (downsampled_z - mean_new) / \
                std_new * std_old + mean_old
            # occasionally std_new is too small or too large, which exceeds the range of float16,
            # so we clamp the result to z's range.
            downsampled_z = torch.clamp_(
                downsampled_z, min=z.min(), max=z.max())
            estimate_task_queue = clone_task_queue(single_task_queue)
            if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
                single_task_queue = estimate_task_queue

        # Build task queues
        task_queues = [clone_task_queue(single_task_queue)
                       for _ in range(num_tiles)]
        # Free memory of the input latent tensor
        del z
        result = None

        # Task queue execution
        desc = f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: "
        pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=desc)

        # Execute the tasks back and forth when switching tiles, so that we
        # always keep one tile on the GPU to reduce unnecessary data transfer
        forward = True
        while True:
            group_norm_param = GroupNormParam()
            for i in range(num_tiles) if forward else reversed(range(num_tiles)):
                if state.interrupted:
                    return
                tile = tiles[i].to(device)
                input_bbox = in_bboxes[i]
                task_queue = task_queues[i]
                while len(task_queue) > 0:
                    if state.interrupted:
                        return
                    # DEBUG: current task
                    # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape)
                    task = task_queue.pop(0)
                    if task[0] == 'pre_norm':
                        group_norm_param.add_tile(tile, task[1])
                        break
                    elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
                        task_id = 0
                        res = task[1](tile)
                        if not self.fast_mode or task[0] == 'store_res_cpu':
                            res = res.cpu()
                        while task_queue[task_id][0] != 'add_res':
                            task_id += 1
                        task_queue[task_id][1] = res
                    elif task[0] == 'add_res':
                        tile += task[1].to(device)
                        task[1] = None
                    else:
                        tile = task[1](tile)
                    pbar.update(1)
                # check for NaNs in the tile;
                # if there are NaNs, we abort the process to save the user's time
                devices.test_for_nans(tile, "vae")
                if len(task_queue) == 0:
                    tiles[i] = None
                    num_completed += 1
                    if result is None:
                        result = torch.zeros(
                            (N, tile.shape[1],
                             height * 8 if is_decoder else height // 8,
                             width * 8 if is_decoder else width // 8),
                            device=device, requires_grad=False)
                    result[:, :, out_bboxes[i][2]:out_bboxes[i][3],
                           out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(
                               tile, in_bboxes[i], out_bboxes[i], is_decoder)
                    del tile
                elif i == num_tiles - 1 and forward:
                    forward = False
                    tiles[i] = tile
                elif i == 0 and not forward:
                    forward = True
                    tiles[i] = tile
                else:
                    tiles[i] = tile.cpu()
                    del tile
            if num_completed == num_tiles:
                break
            # insert the group norm task at the head of each task queue
            group_norm_func = group_norm_param.summary()
            if group_norm_func is not None:
                for i in range(num_tiles):
                    task_queue = task_queues[i]
                    task_queue.insert(0, ('apply_norm', group_norm_func))
        # Done!
        pbar.close()
        return result


class Script(scripts.Script):

    def title(self):
        return "Tiled VAE"

    def show(self, is_img2img):
        if devices.get_optimal_device_name() == 'mps':
            print('[Tiled VAE]: Tiled VAE is not needed on Mac. Skip loading...')
            return False
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        with gr.Accordion('Tiled VAE', open=False):
            with gr.Row():
                enabled = gr.Checkbox(
                    label='Enable', value=lambda: DEFAULT_ENABLED)
                vae_to_gpu = gr.Checkbox(
                    label='Move VAE to GPU', value=lambda: DEFAULT_MOVE_TO_GPU)
            encoder_size_tips = gr.HTML(
                '<p style="margin-bottom:0.8em">Please use a smaller tile size when you see CUDA error: out of memory.</p>')
            with gr.Row():
                encoder_tile_size = gr.Slider(
                    label='Encoder Tile Size', minimum=256, maximum=4096, step=16, value=lambda: DEFAULT_ENCODER_TILE_SIZE)
                decoder_tile_size = gr.Slider(
                    label='Decoder Tile Size', minimum=48, maximum=512, step=16, value=lambda: DEFAULT_DECODER_TILE_SIZE)
                reset = gr.Button(value="Reset Tile Size")
                reset.click(fn=lambda: [DEFAULT_ENCODER_TILE_SIZE, DEFAULT_DECODER_TILE_SIZE],
                            outputs=[encoder_tile_size, decoder_tile_size])
            with gr.Row():
                fast_encoder = gr.Checkbox(
                    label='Fast Encoder', value=lambda: DEFAULT_FAST_ENCODER)
                fast_decoder = gr.Checkbox(
                    label='Fast Decoder', value=lambda: DEFAULT_FAST_DECODER)
            with gr.Row():
                fast_encoder_tips = gr.HTML(
                    '<p style="margin-bottom:0.8em">Fast Encoder may change colors; this can be fixed at the cost of more RAM and lower speed.</p>')
                color_fix = gr.Checkbox(
                    label='Encoder Color Fix', value=lambda: DEFAULT_COLOR_FIX)

            def on_fast_encoder(value):
                if value:
                    return gr.update(visible=True, interactive=True), gr.update(visible=True)
                else:
                    return gr.update(visible=False, interactive=False), gr.update(visible=False)

            fast_encoder.change(fn=on_fast_encoder, inputs=[fast_encoder],
                                outputs=[color_fix, fast_encoder_tips])

        return [enabled, vae_to_gpu, fast_decoder, fast_encoder, color_fix, encoder_tile_size, decoder_tile_size]

    def process(self, p, enabled, vae_to_gpu, fast_decoder, fast_encoder, color_fix, encoder_tile_size, decoder_tile_size):
        vae = p.sd_model.first_stage_model
        # for shorthand
        encoder = vae.encoder
        decoder = vae.decoder

        # save the original forward (only once)
        if not hasattr(encoder, 'original_forward'):
            setattr(encoder, 'original_forward', encoder.forward)
        if not hasattr(decoder, 'original_forward'):
            setattr(decoder, 'original_forward', decoder.forward)

        # undo the hijack if disabled
        if not enabled:
            if isinstance(encoder.forward, VAEHook):
                encoder.forward = encoder.original_forward
            if isinstance(decoder.forward, VAEHook):
                decoder.forward = decoder.original_forward
            return

        if devices.get_optimal_device() == torch.device('cpu'):
            print("[Tiled VAE] Tiled VAE is not needed as your device has no GPU VRAM.")
            return
        if vae.device == torch.device('cpu') and not vae_to_gpu:
            print(
                "[Tiled VAE] VAE is on CPU. Please enable 'Move VAE to GPU' to use Tiled VAE.")
            return

        # do the hijack
        encoder.forward = VAEHook(
            encoder, encoder_tile_size, is_decoder=False, fast_decoder=fast_decoder,
            fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)
        decoder.forward = VAEHook(
            decoder, decoder_tile_size, is_decoder=True, fast_decoder=fast_decoder,
            fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)