# sd_hijack_optimizations.py
  1. import math
  2. import sys
  3. import traceback
  4. import psutil
  5. import torch
  6. from torch import einsum
  7. from ldm.util import default
  8. from einops import rearrange
  9. from modules import shared, errors, devices
  10. from modules.hypernetworks import hypernetwork
  11. from .sub_quadratic_attention import efficient_dot_product_attention
  12. if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
  13. try:
  14. import xformers.ops
  15. shared.xformers_available = True
  16. except Exception:
  17. print("Cannot import xformers", file=sys.stderr)
  18. print(traceback.format_exc(), file=sys.stderr)
  19. def get_available_vram():
  20. if shared.device.type == 'cuda':
  21. stats = torch.cuda.memory_stats(shared.device)
  22. mem_active = stats['active_bytes.all.current']
  23. mem_reserved = stats['reserved_bytes.all.current']
  24. mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
  25. mem_free_torch = mem_reserved - mem_active
  26. mem_free_total = mem_free_cuda + mem_free_torch
  27. return mem_free_total
  28. else:
  29. return psutil.virtual_memory().available
  30. # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
  31. def split_cross_attention_forward_v1(self, x, context=None, mask=None):
  32. h = self.heads
  33. q_in = self.to_q(x)
  34. context = default(context, x)
  35. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  36. k_in = self.to_k(context_k)
  37. v_in = self.to_v(context_v)
  38. del context, context_k, context_v, x
  39. q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
  40. del q_in, k_in, v_in
  41. dtype = q.dtype
  42. if shared.opts.upcast_attn:
  43. q, k, v = q.float(), k.float(), v.float()
  44. with devices.without_autocast(disable=not shared.opts.upcast_attn):
  45. r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
  46. for i in range(0, q.shape[0], 2):
  47. end = i + 2
  48. s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
  49. s1 *= self.scale
  50. s2 = s1.softmax(dim=-1)
  51. del s1
  52. r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
  53. del s2
  54. del q, k, v
  55. r1 = r1.to(dtype)
  56. r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
  57. del r1
  58. return self.to_out(r2)
  59. # taken from https://github.com/Doggettx/stable-diffusion and modified
  60. def split_cross_attention_forward(self, x, context=None, mask=None):
  61. h = self.heads
  62. q_in = self.to_q(x)
  63. context = default(context, x)
  64. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  65. k_in = self.to_k(context_k)
  66. v_in = self.to_v(context_v)
  67. dtype = q_in.dtype
  68. if shared.opts.upcast_attn:
  69. q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
  70. with devices.without_autocast(disable=not shared.opts.upcast_attn):
  71. k_in = k_in * self.scale
  72. del context, x
  73. q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
  74. del q_in, k_in, v_in
  75. r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
  76. mem_free_total = get_available_vram()
  77. gb = 1024 ** 3
  78. tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
  79. modifier = 3 if q.element_size() == 2 else 2.5
  80. mem_required = tensor_size * modifier
  81. steps = 1
  82. if mem_required > mem_free_total:
  83. steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
  84. # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
  85. # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
  86. if steps > 64:
  87. max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
  88. raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
  89. f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
  90. slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
  91. for i in range(0, q.shape[1], slice_size):
  92. end = i + slice_size
  93. s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
  94. s2 = s1.softmax(dim=-1, dtype=q.dtype)
  95. del s1
  96. r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
  97. del s2
  98. del q, k, v
  99. r1 = r1.to(dtype)
  100. r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
  101. del r1
  102. return self.to_out(r2)
  103. # -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
  104. mem_total_gb = psutil.virtual_memory().total // (1 << 30)
  105. def einsum_op_compvis(q, k, v):
  106. s = einsum('b i d, b j d -> b i j', q, k)
  107. s = s.softmax(dim=-1, dtype=s.dtype)
  108. return einsum('b i j, b j d -> b i d', s, v)
  109. def einsum_op_slice_0(q, k, v, slice_size):
  110. r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
  111. for i in range(0, q.shape[0], slice_size):
  112. end = i + slice_size
  113. r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
  114. return r
  115. def einsum_op_slice_1(q, k, v, slice_size):
  116. r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
  117. for i in range(0, q.shape[1], slice_size):
  118. end = i + slice_size
  119. r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
  120. return r
  121. def einsum_op_mps_v1(q, k, v):
  122. if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
  123. return einsum_op_compvis(q, k, v)
  124. else:
  125. slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
  126. if slice_size % 4096 == 0:
  127. slice_size -= 1
  128. return einsum_op_slice_1(q, k, v, slice_size)
  129. def einsum_op_mps_v2(q, k, v):
  130. if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
  131. return einsum_op_compvis(q, k, v)
  132. else:
  133. return einsum_op_slice_0(q, k, v, 1)
  134. def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
  135. size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
  136. if size_mb <= max_tensor_mb:
  137. return einsum_op_compvis(q, k, v)
  138. div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
  139. if div <= q.shape[0]:
  140. return einsum_op_slice_0(q, k, v, q.shape[0] // div)
  141. return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
  142. def einsum_op_cuda(q, k, v):
  143. stats = torch.cuda.memory_stats(q.device)
  144. mem_active = stats['active_bytes.all.current']
  145. mem_reserved = stats['reserved_bytes.all.current']
  146. mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
  147. mem_free_torch = mem_reserved - mem_active
  148. mem_free_total = mem_free_cuda + mem_free_torch
  149. # Divide factor of safety as there's copying and fragmentation
  150. return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
  151. def einsum_op(q, k, v):
  152. if q.device.type == 'cuda':
  153. return einsum_op_cuda(q, k, v)
  154. if q.device.type == 'mps':
  155. if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
  156. return einsum_op_mps_v1(q, k, v)
  157. return einsum_op_mps_v2(q, k, v)
  158. # Smaller slices are faster due to L2/L3/SLC caches.
  159. # Tested on i7 with 8MB L3 cache.
  160. return einsum_op_tensor_mem(q, k, v, 32)
  161. def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
  162. h = self.heads
  163. q = self.to_q(x)
  164. context = default(context, x)
  165. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  166. k = self.to_k(context_k)
  167. v = self.to_v(context_v)
  168. del context, context_k, context_v, x
  169. dtype = q.dtype
  170. if shared.opts.upcast_attn:
  171. q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
  172. with devices.without_autocast(disable=not shared.opts.upcast_attn):
  173. k = k * self.scale
  174. q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
  175. r = einsum_op(q, k, v)
  176. r = r.to(dtype)
  177. return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
# -- End of code from https://github.com/invoke-ai/InvokeAI --
  179. # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
  180. # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
  181. def sub_quad_attention_forward(self, x, context=None, mask=None):
  182. assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
  183. h = self.heads
  184. q = self.to_q(x)
  185. context = default(context, x)
  186. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  187. k = self.to_k(context_k)
  188. v = self.to_v(context_v)
  189. del context, context_k, context_v, x
  190. q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
  191. k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
  192. v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
  193. dtype = q.dtype
  194. if shared.opts.upcast_attn:
  195. q, k = q.float(), k.float()
  196. x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
  197. x = x.to(dtype)
  198. x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
  199. out_proj, dropout = self.to_out
  200. x = out_proj(x)
  201. x = dropout(x)
  202. return x
  203. def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
  204. bytes_per_token = torch.finfo(q.dtype).bits//8
  205. batch_x_heads, q_tokens, _ = q.shape
  206. _, k_tokens, _ = k.shape
  207. qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
  208. if chunk_threshold is None:
  209. chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
  210. elif chunk_threshold == 0:
  211. chunk_threshold_bytes = None
  212. else:
  213. chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
  214. if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
  215. kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
  216. elif kv_chunk_size_min == 0:
  217. kv_chunk_size_min = None
  218. if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
  219. # the big matmul fits into our memory limit; do everything in 1 chunk,
  220. # i.e. send it down the unchunked fast-path
  221. query_chunk_size = q_tokens
  222. kv_chunk_size = k_tokens
  223. with devices.without_autocast(disable=q.dtype == v.dtype):
  224. return efficient_dot_product_attention(
  225. q,
  226. k,
  227. v,
  228. query_chunk_size=q_chunk_size,
  229. kv_chunk_size=kv_chunk_size,
  230. kv_chunk_size_min = kv_chunk_size_min,
  231. use_checkpoint=use_checkpoint,
  232. )
  233. def get_xformers_flash_attention_op(q, k, v):
  234. if not shared.cmd_opts.xformers_flash_attention:
  235. return None
  236. try:
  237. flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
  238. fw, bw = flash_attention_op
  239. if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
  240. return flash_attention_op
  241. except Exception as e:
  242. errors.display_once(e, "enabling flash attention")
  243. return None
  244. def xformers_attention_forward(self, x, context=None, mask=None):
  245. h = self.heads
  246. q_in = self.to_q(x)
  247. context = default(context, x)
  248. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  249. k_in = self.to_k(context_k)
  250. v_in = self.to_v(context_v)
  251. q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
  252. del q_in, k_in, v_in
  253. dtype = q.dtype
  254. if shared.opts.upcast_attn:
  255. q, k, v = q.float(), k.float(), v.float()
  256. out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
  257. out = out.to(dtype)
  258. out = rearrange(out, 'b n h d -> b n (h d)', h=h)
  259. return self.to_out(out)
  260. # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
  261. # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
  262. def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
  263. batch_size, sequence_length, inner_dim = x.shape
  264. if mask is not None:
  265. mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
  266. mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
  267. h = self.heads
  268. q_in = self.to_q(x)
  269. context = default(context, x)
  270. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  271. k_in = self.to_k(context_k)
  272. v_in = self.to_v(context_v)
  273. head_dim = inner_dim // h
  274. q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
  275. k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
  276. v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
  277. del q_in, k_in, v_in
  278. dtype = q.dtype
  279. if shared.opts.upcast_attn:
  280. q, k, v = q.float(), k.float(), v.float()
  281. # the output of sdp = (batch, num_heads, seq_len, head_dim)
  282. hidden_states = torch.nn.functional.scaled_dot_product_attention(
  283. q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
  284. )
  285. hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
  286. hidden_states = hidden_states.to(dtype)
  287. # linear proj
  288. hidden_states = self.to_out[0](hidden_states)
  289. # dropout
  290. hidden_states = self.to_out[1](hidden_states)
  291. return hidden_states
  292. def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
  293. with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
  294. return scaled_dot_product_attention_forward(self, x, context, mask)
  295. def cross_attention_attnblock_forward(self, x):
  296. h_ = x
  297. h_ = self.norm(h_)
  298. q1 = self.q(h_)
  299. k1 = self.k(h_)
  300. v = self.v(h_)
  301. # compute attention
  302. b, c, h, w = q1.shape
  303. q2 = q1.reshape(b, c, h*w)
  304. del q1
  305. q = q2.permute(0, 2, 1) # b,hw,c
  306. del q2
  307. k = k1.reshape(b, c, h*w) # b,c,hw
  308. del k1
  309. h_ = torch.zeros_like(k, device=q.device)
  310. mem_free_total = get_available_vram()
  311. tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
  312. mem_required = tensor_size * 2.5
  313. steps = 1
  314. if mem_required > mem_free_total:
  315. steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
  316. slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
  317. for i in range(0, q.shape[1], slice_size):
  318. end = i + slice_size
  319. w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
  320. w2 = w1 * (int(c)**(-0.5))
  321. del w1
  322. w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
  323. del w2
  324. # attend to values
  325. v1 = v.reshape(b, c, h*w)
  326. w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
  327. del w3
  328. h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
  329. del v1, w4
  330. h2 = h_.reshape(b, c, h, w)
  331. del h_
  332. h3 = self.proj_out(h2)
  333. del h2
  334. h3 += x
  335. return h3
  336. def xformers_attnblock_forward(self, x):
  337. try:
  338. h_ = x
  339. h_ = self.norm(h_)
  340. q = self.q(h_)
  341. k = self.k(h_)
  342. v = self.v(h_)
  343. b, c, h, w = q.shape
  344. q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
  345. dtype = q.dtype
  346. if shared.opts.upcast_attn:
  347. q, k = q.float(), k.float()
  348. q = q.contiguous()
  349. k = k.contiguous()
  350. v = v.contiguous()
  351. out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
  352. out = out.to(dtype)
  353. out = rearrange(out, 'b (h w) c -> b c h w', h=h)
  354. out = self.proj_out(out)
  355. return x + out
  356. except NotImplementedError:
  357. return cross_attention_attnblock_forward(self, x)
  358. def sdp_attnblock_forward(self, x):
  359. h_ = x
  360. h_ = self.norm(h_)
  361. q = self.q(h_)
  362. k = self.k(h_)
  363. v = self.v(h_)
  364. b, c, h, w = q.shape
  365. q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
  366. dtype = q.dtype
  367. if shared.opts.upcast_attn:
  368. q, k = q.float(), k.float()
  369. q = q.contiguous()
  370. k = k.contiguous()
  371. v = v.contiguous()
  372. out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
  373. out = out.to(dtype)
  374. out = rearrange(out, 'b (h w) c -> b c h w', h=h)
  375. out = self.proj_out(out)
  376. return x + out
  377. def sdp_no_mem_attnblock_forward(self, x):
  378. with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
  379. return sdp_attnblock_forward(self, x)
  380. def sub_quad_attnblock_forward(self, x):
  381. h_ = x
  382. h_ = self.norm(h_)
  383. q = self.q(h_)
  384. k = self.k(h_)
  385. v = self.v(h_)
  386. b, c, h, w = q.shape
  387. q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
  388. q = q.contiguous()
  389. k = k.contiguous()
  390. v = v.contiguous()
  391. out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
  392. out = rearrange(out, 'b (h w) c -> b c h w', h=h)
  393. out = self.proj_out(out)
  394. return x + out