# LoRA network module
# reference:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
import copy
import math
import re
from typing import NamedTuple

import torch


class LoRAInfo(NamedTuple):
    lora_name: str
    module_name: str
    module: torch.nn.Module
    multiplier: float
    dim: int
    alpha: float


class LoRAModule(torch.nn.Module):
    """
    replaces the forward method of the original Linear (or Conv2d), instead of replacing the original module itself.
    """

    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            # self.lora_dim = min(self.lora_dim, in_dim, out_dim)
            # if self.lora_dim != lora_dim:
            #     print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # treated as a constant (kept in the state_dict, not trained)

        # same initialization as Microsoft's LoRA implementation
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_forward = org_module.forward
        self.org_module = org_module  # removed in apply_to()
        self.mask_dic = None
        self.mask = None
        self.mask_area = -1

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def set_mask_dic(self, mask_dic):
        # called before every generation

        # check whether this module relates to h,w (not to context or time emb)
        if "attn2_to_k" in self.lora_name or "attn2_to_v" in self.lora_name or "emb_layers" in self.lora_name:
            # print(f"LoRA for context or time emb: {self.lora_name}")
            self.mask_dic = None
        else:
            self.mask_dic = mask_dic

        self.mask = None

    def forward(self, x):
        """
        may be cascaded.
        """
        if self.mask_dic is None:
            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

        # regional LoRA

        # calculate lora output and get its size
        lx = self.lora_up(self.lora_down(x))

        if len(lx.size()) == 4:  # b,c,h,w
            area = lx.size()[2] * lx.size()[3]
        else:
            area = lx.size()[1]  # b,seq,dim

        if self.mask is None or self.mask_area != area:
            # get mask
            # print(self.lora_name, x.size(), lx.size(), area)
            mask = self.mask_dic[area]
            if len(lx.size()) == 3:
                mask = torch.reshape(mask, (1, -1, 1))
            self.mask = mask
            self.mask_area = area

        return self.org_forward(x) + lx * self.multiplier * self.scale * self.mask
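

# Illustrative note: with the hook in place, a wrapped layer with original
# weight W effectively computes
#
#     y = W x + multiplier * (alpha / lora_dim) * lora_up(lora_down(x))
#
# Because lora_up is zero-initialized, a freshly constructed LoRAModule leaves
# the wrapped layer's output unchanged until trained LoRA weights are loaded.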


def create_network_and_apply_compvis(du_state_dict, multiplier_tenc, multiplier_unet, text_encoder, unet, **kwargs):
    # get device and dtype from unet
    for module in unet.modules():
        if module.__class__.__name__ == "Linear":
            param: torch.nn.Parameter = module.weight
            # device = param.device
            dtype = param.dtype
            break

    # get dims (rank) and alpha from state dict
    modules_dim = {}
    modules_alpha = {}
    for key, value in du_state_dict.items():
        if "." not in key:
            continue

        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = float(value.detach().to(torch.float).cpu().numpy())
        elif "lora_down" in key:
            dim = value.size()[0]
            modules_dim[lora_name] = dim

    # support old LoRA without alpha
    for key in modules_dim.keys():
        if key not in modules_alpha:
            modules_alpha[key] = modules_dim[key]

    print(
        f"dimension: {set(modules_dim.values())}, alpha: {set(modules_alpha.values())}, multiplier_unet: {multiplier_unet}, multiplier_tenc: {multiplier_tenc}"
    )

    # if network_dim is None:
    #     print(f"The selected model is not LoRA or not trained by `sd-scripts`?")
    #     network_dim = 4
    #     network_alpha = 1

    # create, apply and load weights
    network = LoRANetworkCompvis(text_encoder, unet, multiplier_tenc, multiplier_unet, modules_dim, modules_alpha)
    state_dict = network.apply_lora_modules(du_state_dict)  # some weights are applied to the text encoder
    network.to(dtype)  # cast first so the model is still usable even if the next line raises an error
    info = network.load_state_dict(state_dict, strict=False)

    # remove redundant warnings
    if len(info.missing_keys) > 4:
        missing_keys = []
        alpha_count = 0
        for key in info.missing_keys:
            if "alpha" not in key:
                missing_keys.append(key)
            else:
                if alpha_count == 0:
                    missing_keys.append(key)
                alpha_count += 1
        if alpha_count > 1:
            missing_keys.append(
                f"... and {alpha_count-1} alphas. The model doesn't have alpha, so dim (rank) is used as alpha. You can ignore this message."
            )

        info = torch.nn.modules.module._IncompatibleKeys(missing_keys, info.unexpected_keys)

    return network, info
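

# Illustrative usage sketch (hypothetical caller): `lora_sd` stands for a LoRA
# state dict loaded e.g. via safetensors or torch.load, and `sd_model` for the
# loaded CompVis checkpoint; the attribute paths below are the usual webui ones
# and are assumptions here, not guaranteed by this module.
#
#     network, info = create_network_and_apply_compvis(
#         lora_sd, 0.8, 0.8, sd_model.cond_stage_model, sd_model.model.diffusion_model
#     )
#     print(info)  # reports missing/unexpected keys
#     ...
#     network.restore(sd_model.cond_stage_model, sd_model.model.diffusion_model)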


class LoRANetworkCompvis(torch.nn.Module):
    # UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
    # TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    UNET_TARGET_REPLACE_MODULE = ["SpatialTransformer", "ResBlock", "Downsample", "Upsample"]  # , "Attention"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["ResidualAttentionBlock", "CLIPAttention", "CLIPMLP"]

    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    @classmethod
    def convert_diffusers_name_to_compvis(cls, v2, du_name):
        """
        convert diffusers's LoRA name to CompVis
        """
        cv_name = None
        if "lora_unet_" in du_name:
            m = re.search(r"_down_blocks_(\d+)_attentions_(\d+)_(.+)", du_name)
            if m:
                du_block_index = int(m.group(1))
                du_attn_index = int(m.group(2))
                du_suffix = m.group(3)

                cv_index = 1 + du_block_index * 3 + du_attn_index  # 1,2, 4,5, 7,8
                cv_name = f"lora_unet_input_blocks_{cv_index}_1_{du_suffix}"
                return cv_name

            m = re.search(r"_mid_block_attentions_(\d+)_(.+)", du_name)
            if m:
                du_suffix = m.group(2)
                cv_name = f"lora_unet_middle_block_1_{du_suffix}"
                return cv_name

            m = re.search(r"_up_blocks_(\d+)_attentions_(\d+)_(.+)", du_name)
            if m:
                du_block_index = int(m.group(1))
                du_attn_index = int(m.group(2))
                du_suffix = m.group(3)

                cv_index = du_block_index * 3 + du_attn_index  # 3,4,5, 6,7,8, 9,10,11
                cv_name = f"lora_unet_output_blocks_{cv_index}_1_{du_suffix}"
                return cv_name

            m = re.search(r"_down_blocks_(\d+)_resnets_(\d+)_(.+)", du_name)
            if m:
                du_block_index = int(m.group(1))
                du_res_index = int(m.group(2))
                du_suffix = m.group(3)
                cv_suffix = {
                    "conv1": "in_layers_2",
                    "conv2": "out_layers_3",
                    "time_emb_proj": "emb_layers_1",
                    "conv_shortcut": "skip_connection",
                }[du_suffix]

                cv_index = 1 + du_block_index * 3 + du_res_index  # 1,2, 4,5, 7,8
                cv_name = f"lora_unet_input_blocks_{cv_index}_0_{cv_suffix}"
                return cv_name

            m = re.search(r"_down_blocks_(\d+)_downsamplers_0_conv", du_name)
            if m:
                block_index = int(m.group(1))
                cv_index = 3 + block_index * 3

                cv_name = f"lora_unet_input_blocks_{cv_index}_0_op"
                return cv_name

            m = re.search(r"_mid_block_resnets_(\d+)_(.+)", du_name)
            if m:
                index = int(m.group(1))
                du_suffix = m.group(2)
                cv_suffix = {
                    "conv1": "in_layers_2",
                    "conv2": "out_layers_3",
                    "time_emb_proj": "emb_layers_1",
                    "conv_shortcut": "skip_connection",
                }[du_suffix]

                cv_name = f"lora_unet_middle_block_{index*2}_{cv_suffix}"
                return cv_name

            m = re.search(r"_up_blocks_(\d+)_resnets_(\d+)_(.+)", du_name)
            if m:
                du_block_index = int(m.group(1))
                du_res_index = int(m.group(2))
                du_suffix = m.group(3)
                cv_suffix = {
                    "conv1": "in_layers_2",
                    "conv2": "out_layers_3",
                    "time_emb_proj": "emb_layers_1",
                    "conv_shortcut": "skip_connection",
                }[du_suffix]

                cv_index = du_block_index * 3 + du_res_index  # 1,2, 4,5, 7,8
                cv_name = f"lora_unet_output_blocks_{cv_index}_0_{cv_suffix}"
                return cv_name

            m = re.search(r"_up_blocks_(\d+)_upsamplers_0_conv", du_name)
            if m:
                block_index = int(m.group(1))
                cv_index = block_index * 3 + 2

                cv_name = f"lora_unet_output_blocks_{cv_index}_{bool(block_index)+1}_conv"
                return cv_name

        elif "lora_te_" in du_name:
            m = re.search(r"_model_encoder_layers_(\d+)_(.+)", du_name)
            if m:
                du_block_index = int(m.group(1))
                du_suffix = m.group(2)

                cv_index = du_block_index
                if v2:
                    if "mlp_fc1" in du_suffix:
                        cv_name = (
                            f"lora_te_wrapped_model_transformer_resblocks_{cv_index}_{du_suffix.replace('mlp_fc1', 'mlp_c_fc')}"
                        )
                    elif "mlp_fc2" in du_suffix:
                        cv_name = (
                            f"lora_te_wrapped_model_transformer_resblocks_{cv_index}_{du_suffix.replace('mlp_fc2', 'mlp_c_proj')}"
                        )
  246. elif "self_attn":
  247. # handled later
  248. cv_name = f"lora_te_wrapped_model_transformer_resblocks_{cv_index}_{du_suffix.replace('self_attn', 'attn')}"
  249. else:
  250. cv_name = f"lora_te_wrapped_transformer_text_model_encoder_layers_{cv_index}_{du_suffix}"
  251. assert cv_name is not None, f"conversion failed: {du_name}. the model may not be trained by `sd-scripts`."
  252. return cv_name
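
    # Illustrative examples of the mapping above (derived from the regexes):
    #   v2=False:
    #     lora_unet_down_blocks_0_attentions_1_transformer_blocks_0_attn1_to_q
    #       -> lora_unet_input_blocks_2_1_transformer_blocks_0_attn1_to_q   (cv_index = 1 + 0*3 + 1)
    #     lora_te_text_model_encoder_layers_0_self_attn_q_proj
    #       -> lora_te_wrapped_transformer_text_model_encoder_layers_0_self_attn_q_proj
    #   v2=True:
    #     lora_te_text_model_encoder_layers_0_self_attn_q_proj
    #       -> lora_te_wrapped_model_transformer_resblocks_0_attn_q_proj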

    @classmethod
    def convert_state_dict_name_to_compvis(cls, v2, state_dict):
        """
        convert keys in state dict to load it by load_state_dict
        """
        new_sd = {}
        for key, value in state_dict.items():
            tokens = key.split(".")
            compvis_name = LoRANetworkCompvis.convert_diffusers_name_to_compvis(v2, tokens[0])
            new_key = compvis_name + "." + ".".join(tokens[1:])

            new_sd[new_key] = value

        return new_sd
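
    # Illustrative example: a full state-dict key is rewritten like
    #   "lora_unet_mid_block_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"
    #     -> "lora_unet_middle_block_1_transformer_blocks_0_attn1_to_q.lora_down.weight"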

    def __init__(self, text_encoder, unet, multiplier_tenc=1.0, multiplier_unet=1.0, modules_dim=None, modules_alpha=None) -> None:
        super().__init__()
        self.multiplier_unet = multiplier_unet
        self.multiplier_tenc = multiplier_tenc
        self.latest_mask_info = None

        # check v1 or v2
        self.v2 = False
        for _, module in text_encoder.named_modules():
            for _, child_module in module.named_modules():
                if child_module.__class__.__name__ == "MultiheadAttention":
                    self.v2 = True
                    break
            if self.v2:
                break

        # convert lora names to CompVis and get dim and alpha
        comp_vis_loras_dim_alpha = {}
        for du_lora_name in modules_dim.keys():
            dim = modules_dim[du_lora_name]
            alpha = modules_alpha[du_lora_name]
            comp_vis_lora_name = LoRANetworkCompvis.convert_diffusers_name_to_compvis(self.v2, du_lora_name)
            comp_vis_loras_dim_alpha[comp_vis_lora_name] = (dim, alpha)

        # create module instances
        def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules, multiplier):
            loras = []
            replaced_modules = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        # enumerate all Linear and Conv2d
                        if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
                            lora_name = prefix + "." + name + "." + child_name
                            lora_name = lora_name.replace(".", "_")
                            if "_resblocks_23_" in lora_name:  # ignore last block in StabilityAi Text Encoder
                                break
                            if lora_name not in comp_vis_loras_dim_alpha:
                                continue

                            dim, alpha = comp_vis_loras_dim_alpha[lora_name]
                            lora = LoRAModule(lora_name, child_module, multiplier, dim, alpha)
                            loras.append(lora)

                            replaced_modules.append(child_module)
                        elif child_module.__class__.__name__ == "MultiheadAttention":
                            # make four modules: do not replace the forward method; the weights are merged later
                            for suffix in ["q_proj", "k_proj", "v_proj", "out_proj"]:
                                module_name = prefix + "." + name + "." + child_name  # ~.attn
                                module_name = module_name.replace(".", "_")
                                if "_resblocks_23_" in module_name:  # ignore last block in StabilityAi Text Encoder
                                    break
                                lora_name = module_name + "_" + suffix
                                if lora_name not in comp_vis_loras_dim_alpha:
                                    continue
                                dim, alpha = comp_vis_loras_dim_alpha[lora_name]
                                lora_info = LoRAInfo(lora_name, module_name, child_module, multiplier, dim, alpha)
                                loras.append(lora_info)

                                replaced_modules.append(child_module)
            return loras, replaced_modules

        self.text_encoder_loras, te_rep_modules = create_modules(
            LoRANetworkCompvis.LORA_PREFIX_TEXT_ENCODER,
            text_encoder,
            LoRANetworkCompvis.TEXT_ENCODER_TARGET_REPLACE_MODULE,
            self.multiplier_tenc,
        )
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        self.unet_loras, unet_rep_modules = create_modules(
            LoRANetworkCompvis.LORA_PREFIX_UNET, unet, LoRANetworkCompvis.UNET_TARGET_REPLACE_MODULE, self.multiplier_unet
        )
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        # back up the original forward methods / weights; if multiple LoRA networks are applied, only the 1st one makes the backup
        backed_up = False  # messaging purpose only
        for rep_module in te_rep_modules + unet_rep_modules:
            if (
                rep_module.__class__.__name__ == "MultiheadAttention"
            ):  # the same MultiheadAttention appears multiple times in the list; back up its weights (not its forward) only once
                if not hasattr(rep_module, "_lora_org_weights"):
                    # deep copy to avoid updating the original weights: state_dict() only holds references to them
                    rep_module._lora_org_weights = copy.deepcopy(rep_module.state_dict())
                    backed_up = True
            elif not hasattr(rep_module, "_lora_org_forward"):
                rep_module._lora_org_forward = rep_module.forward
                backed_up = True
        if backed_up:
            print("original forward/weights were backed up.")

        # assertion
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

    def restore(self, text_encoder, unet):
        # restore forward/weights from the backup attributes for all modules
        restored = False  # messaging purpose only
        modules = []
        modules.extend(text_encoder.modules())
        modules.extend(unet.modules())
        for module in modules:
            if hasattr(module, "_lora_org_forward"):
                module.forward = module._lora_org_forward
                del module._lora_org_forward
                restored = True
            if hasattr(
                module, "_lora_org_weights"
            ):  # currently a module never has both a backed-up forward and weights, but keep supporting both for future changes
                module.load_state_dict(module._lora_org_weights)
                del module._lora_org_weights
                restored = True

        if restored:
            print("original forward/weights were restored.")

    def apply_lora_modules(self, du_state_dict):
        # conversion 1st step: convert names in state_dict
        state_dict = LoRANetworkCompvis.convert_state_dict_name_to_compvis(self.v2, du_state_dict)

        # check whether state_dict has weights for the text encoder and/or U-Net
        weights_has_text_encoder = weights_has_unet = False
        for key in state_dict.keys():
            if key.startswith(LoRANetworkCompvis.LORA_PREFIX_TEXT_ENCODER):
                weights_has_text_encoder = True
            elif key.startswith(LoRANetworkCompvis.LORA_PREFIX_UNET):
                weights_has_unet = True
            if weights_has_text_encoder and weights_has_unet:
                break

        apply_text_encoder = weights_has_text_encoder
        apply_unet = weights_has_unet

        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        # add modules to the network: this makes the state_dict obtainable from LoRANetwork
        mha_loras = {}
        for lora in self.text_encoder_loras + self.unet_loras:
            if type(lora) == LoRAModule:
                lora.apply_to()  # make sure the reference to the original Linear is removed: keeping it would add its weights to the state_dict
                self.add_module(lora.lora_name, lora)
            else:
                # SD2.x MultiheadAttention: merge the LoRA weights into the MHA weights
                lora_info: LoRAInfo = lora
                if lora_info.module_name not in mha_loras:
                    mha_loras[lora_info.module_name] = {}

                lora_dic = mha_loras[lora_info.module_name]
                lora_dic[lora_info.lora_name] = lora_info
                if len(lora_dic) == 4:
                    # calculate and apply
                    module = lora_info.module
                    module_name = lora_info.module_name
                    w_q_dw = state_dict.get(module_name + "_q_proj.lora_down.weight")
                    if w_q_dw is not None:  # corresponding LoRA module exists
                        w_q_up = state_dict[module_name + "_q_proj.lora_up.weight"]
                        w_k_dw = state_dict[module_name + "_k_proj.lora_down.weight"]
                        w_k_up = state_dict[module_name + "_k_proj.lora_up.weight"]
                        w_v_dw = state_dict[module_name + "_v_proj.lora_down.weight"]
                        w_v_up = state_dict[module_name + "_v_proj.lora_up.weight"]
                        w_out_dw = state_dict[module_name + "_out_proj.lora_down.weight"]
                        w_out_up = state_dict[module_name + "_out_proj.lora_up.weight"]
                        q_lora_info = lora_dic[module_name + "_q_proj"]
                        k_lora_info = lora_dic[module_name + "_k_proj"]
                        v_lora_info = lora_dic[module_name + "_v_proj"]
                        out_lora_info = lora_dic[module_name + "_out_proj"]

                        sd = module.state_dict()
                        qkv_weight = sd["in_proj_weight"]
                        out_weight = sd["out_proj.weight"]
                        dev = qkv_weight.device

                        def merge_weights(l_info, weight, up_weight, down_weight):
                            # calculate in float: W <- W + multiplier * (up @ down) * (alpha / dim)
                            scale = l_info.alpha / l_info.dim
                            dtype = weight.dtype
                            weight = (
                                weight.float()
                                + l_info.multiplier
                                * (up_weight.to(dev, dtype=torch.float) @ down_weight.to(dev, dtype=torch.float))
                                * scale
                            )
                            weight = weight.to(dtype)
                            return weight

                        q_weight, k_weight, v_weight = torch.chunk(qkv_weight, 3)
                        if q_weight.size()[1] == w_q_up.size()[0]:
                            q_weight = merge_weights(q_lora_info, q_weight, w_q_up, w_q_dw)
                            k_weight = merge_weights(k_lora_info, k_weight, w_k_up, w_k_dw)
                            v_weight = merge_weights(v_lora_info, v_weight, w_v_up, w_v_dw)
                            qkv_weight = torch.cat([q_weight, k_weight, v_weight])

                            out_weight = merge_weights(out_lora_info, out_weight, w_out_up, w_out_dw)

                            sd["in_proj_weight"] = qkv_weight.to(dev)
                            sd["out_proj.weight"] = out_weight.to(dev)

                            lora_info.module.load_state_dict(sd)
                        else:
                            # different dim, version mismatch
                            print(f"shape of weight is different: {module_name}. SD version may be different")

                        for t in ["q", "k", "v", "out"]:
                            del state_dict[f"{module_name}_{t}_proj.lora_down.weight"]
                            del state_dict[f"{module_name}_{t}_proj.lora_up.weight"]
                            alpha_key = f"{module_name}_{t}_proj.alpha"
                            if alpha_key in state_dict:
                                del state_dict[alpha_key]
                    else:
                        # corresponding weight doesn't exist: version mismatch
                        pass

        # conversion 2nd step: convert weight shapes (and handle 'wrapped')
        state_dict = self.convert_state_dict_shape_to_compvis(state_dict)

        return state_dict

    def convert_state_dict_shape_to_compvis(self, state_dict):
        # shape conversion
        current_sd = self.state_dict()  # to get the target shapes
        wrapped = False
        count = 0
        for key in list(state_dict.keys()):
            if key not in current_sd:
                continue  # might be an error, or a weight from another version
            if "wrapped" in key:
                wrapped = True

            value: torch.Tensor = state_dict[key]
            if value.size() != current_sd[key].size():
                # print(f"convert weights shape: {key}, from: {value.size()}, {len(value.size())}")
                count += 1
                if len(value.size()) == 4:
                    value = value.squeeze(3).squeeze(2)
                else:
                    value = value.unsqueeze(2).unsqueeze(3)
                state_dict[key] = value
            if tuple(value.size()) != tuple(current_sd[key].size()):
                print(
                    f"weight's shape is different: {key} expected {current_sd[key].size()} found {value.size()}. SD version may be different"
                )
                del state_dict[key]

        print(f"shapes for {count} weights are converted.")

        # convert wrapped
        if not wrapped:
            print("remove 'wrapped' from keys")
            for key in list(state_dict.keys()):
                if "_wrapped_" in key:
                    new_key = key.replace("_wrapped_", "_")
                    state_dict[new_key] = state_dict[key]
                    del state_dict[key]

        return state_dict
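
    # Illustrative note: the shape conversion above covers the case where one
    # side stores a projection as a Linear (2D weight, e.g. (dim, in_features))
    # and the other as a 1x1 Conv2d (4D weight, e.g. (dim, in_features, 1, 1));
    # the weight is squeezed or unsqueezed to match the target module.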

    def set_mask(self, mask, height=None, width=None, hr_height=None, hr_width=None):
        if mask is None:
            # clear latest mask
            # print("clear mask")
            self.latest_mask_info = None
            for lora in self.unet_loras:
                lora.set_mask_dic(None)
            return

        # check whether the mask image and h/w are the same as last time
        if (
            self.latest_mask_info is not None
            and torch.equal(mask, self.latest_mask_info[0])
            and (height, width, hr_height, hr_width) == self.latest_mask_info[1:]
        ):
            # print("mask not changed")
            return

        self.latest_mask_info = (mask, height, width, hr_height, hr_width)

        org_dtype = mask.dtype
        if mask.dtype == torch.bfloat16:
            mask = mask.to(torch.float)

        mask_dic = {}
        mask = mask.unsqueeze(0).unsqueeze(1)  # b(1),c(1),h,w

        def resize_add(mh, mw):
            # print(mh, mw, mh * mw)
            m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear")  # doesn't work in bf16
            m = m.to(org_dtype)
            mask_dic[mh * mw] = m  # keyed by flattened area (h*w) so LoRAModule.forward can pick the right resolution

        for h, w in [(height, width), (hr_height, hr_width)]:
            h = h // 8
            w = w // 8
            for i in range(4):
                resize_add(h, w)
                if h % 2 == 1 or w % 2 == 1:  # add an extra shape if h/w is not divisible by 2
                    resize_add(h + h % 2, w + w % 2)
                h = (h + 1) // 2
                w = (w + 1) // 2

        for lora in self.unet_loras:
            lora.set_mask_dic(mask_dic)
        return
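

if __name__ == "__main__":
    # Minimal smoke-test sketch: exercises LoRAModule on a plain Linear layer,
    # so it runs with just PyTorch. The layer size and lora_name are arbitrary
    # and hypothetical; this block is an illustration, not part of the webui flow.
    base = torch.nn.Linear(16, 16)
    reference = base(torch.ones(2, 16))  # output before hooking

    lora = LoRAModule("lora_demo_linear", base, multiplier=1.0, lora_dim=4, alpha=4)
    lora.apply_to()  # base.forward now routes through lora.forward

    hooked = base(torch.ones(2, 16))
    # lora_up is zero-initialized, so the hooked output matches the original one
    print("max difference after hooking:", (hooked - reference).abs().max().item())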