cldm.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. import torch
  2. import torch.nn as nn
  3. from omegaconf import OmegaConf
  4. from modules import devices, shared
  5. cond_cast_unet = getattr(devices, 'cond_cast_unet', lambda x: x)
  6. from ldm.util import exists
  7. from ldm.modules.attention import SpatialTransformer
  8. from ldm.modules.diffusionmodules.util import conv_nd, linear, zero_module, timestep_embedding
  9. from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
  10. class TorchHijackForUnet:
  11. """
  12. This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
  13. this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
  14. """
  15. def __getattr__(self, item):
  16. if item == 'cat':
  17. return self.cat
  18. if hasattr(torch, item):
  19. return getattr(torch, item)
  20. raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
  21. def cat(self, tensors, *args, **kwargs):
  22. if len(tensors) == 2:
  23. a, b = tensors
  24. if a.shape[-2:] != b.shape[-2:]:
  25. a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
  26. tensors = (a, b)
  27. return torch.cat(tensors, *args, **kwargs)
  28. th = TorchHijackForUnet()
  29. def align(hint, size):
  30. b, c, h1, w1 = hint.shape
  31. h, w = size
  32. if h != h1 or w != w1:
  33. hint = th.nn.functional.interpolate(hint, size=size, mode="nearest")
  34. return hint
  35. def get_node_name(name, parent_name):
  36. if len(name) <= len(parent_name):
  37. return False, ''
  38. p = name[:len(parent_name)]
  39. if p != parent_name:
  40. return False, ''
  41. return True, name[len(parent_name):]
  42. class PlugableControlModel(nn.Module):
  43. def __init__(self, state_dict, config_path, lowvram=False, base_model=None) -> None:
  44. super().__init__()
  45. self.config = OmegaConf.load(config_path)
  46. self.control_model = ControlNet(**self.config.model.params.control_stage_config.params)
  47. if any([k.startswith("control_model.") for k, v in state_dict.items()]):
  48. if 'difference' in state_dict and base_model is not None:
  49. print('We will stop supporting diff models soon because of its lack of robustness.')
  50. print('Please begin to use official models as soon as possible.')
  51. unet_state_dict = base_model.state_dict()
  52. unet_state_dict_keys = unet_state_dict.keys()
  53. final_state_dict = {}
  54. counter = 0
  55. for key in state_dict.keys():
  56. if not key.startswith("control_model."):
  57. continue
  58. p = state_dict[key]
  59. is_control, node_name = get_node_name(key, 'control_')
  60. key_name = node_name.replace("model.", "") if is_control else key
  61. if key_name in unet_state_dict_keys:
  62. p_new = p + unet_state_dict[key_name].clone().cpu()
  63. counter += 1
  64. else:
  65. p_new = p
  66. final_state_dict[key] = p_new
  67. print(f'Diff model cloned: {counter} values')
  68. state_dict = final_state_dict
  69. state_dict = {k.replace("control_model.", ""): v for k, v in state_dict.items() if k.startswith("control_model.")}
  70. self.control_model.load_state_dict(state_dict)
  71. if not lowvram:
  72. self.control_model.to(devices.get_device_for("controlnet"))
  73. def reset(self):
  74. pass
  75. def forward(self, *args, **kwargs):
  76. return self.control_model(*args, **kwargs)
  77. class ControlNet(nn.Module):
  78. def __init__(
  79. self,
  80. image_size,
  81. in_channels,
  82. model_channels,
  83. hint_channels,
  84. num_res_blocks,
  85. attention_resolutions,
  86. dropout=0,
  87. channel_mult=(1, 2, 4, 8),
  88. conv_resample=True,
  89. dims=2,
  90. use_checkpoint=False,
  91. use_fp16=False,
  92. num_heads=-1,
  93. num_head_channels=-1,
  94. num_heads_upsample=-1,
  95. use_scale_shift_norm=False,
  96. resblock_updown=False,
  97. use_new_attention_order=False,
  98. use_spatial_transformer=False, # custom transformer support
  99. transformer_depth=1, # custom transformer support
  100. context_dim=None, # custom transformer support
  101. # custom support for prediction of discrete ids into codebook of first stage vq model
  102. n_embed=None,
  103. legacy=True,
  104. disable_self_attentions=None,
  105. num_attention_blocks=None,
  106. disable_middle_self_attn=False,
  107. use_linear_in_transformer=False,
  108. ):
  109. use_fp16 = getattr(devices, 'dtype_unet', devices.dtype) == th.float16 and not getattr(shared.cmd_opts, "no_half_controlnet", False)
  110. super().__init__()
  111. if use_spatial_transformer:
  112. assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
  113. if context_dim is not None:
  114. assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
  115. from omegaconf.listconfig import ListConfig
  116. if type(context_dim) == ListConfig:
  117. context_dim = list(context_dim)
  118. if num_heads_upsample == -1:
  119. num_heads_upsample = num_heads
  120. if num_heads == -1:
  121. assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
  122. if num_head_channels == -1:
  123. assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
  124. self.dims = dims
  125. self.image_size = image_size
  126. self.in_channels = in_channels
  127. self.model_channels = model_channels
  128. if isinstance(num_res_blocks, int):
  129. self.num_res_blocks = len(channel_mult) * [num_res_blocks]
  130. else:
  131. if len(num_res_blocks) != len(channel_mult):
  132. raise ValueError("provide num_res_blocks either as an int (globally constant) or "
  133. "as a list/tuple (per-level) with the same length as channel_mult")
  134. self.num_res_blocks = num_res_blocks
  135. if disable_self_attentions is not None:
  136. # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
  137. assert len(disable_self_attentions) == len(channel_mult)
  138. if num_attention_blocks is not None:
  139. assert len(num_attention_blocks) == len(self.num_res_blocks)
  140. assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(
  141. len(num_attention_blocks))))
  142. print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
  143. f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
  144. f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
  145. f"attention will still not be set.")
  146. self.attention_resolutions = attention_resolutions
  147. self.dropout = dropout
  148. self.channel_mult = channel_mult
  149. self.conv_resample = conv_resample
  150. self.use_checkpoint = use_checkpoint
  151. self.dtype = th.float16 if use_fp16 else th.float32
  152. self.num_heads = num_heads
  153. self.num_head_channels = num_head_channels
  154. self.num_heads_upsample = num_heads_upsample
  155. self.predict_codebook_ids = n_embed is not None
  156. time_embed_dim = model_channels * 4
  157. self.time_embed = nn.Sequential(
  158. linear(model_channels, time_embed_dim),
  159. nn.SiLU(),
  160. linear(time_embed_dim, time_embed_dim),
  161. )
  162. self.input_blocks = nn.ModuleList(
  163. [
  164. TimestepEmbedSequential(
  165. conv_nd(dims, in_channels, model_channels, 3, padding=1)
  166. )
  167. ]
  168. )
  169. self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
  170. self.input_hint_block = TimestepEmbedSequential(
  171. conv_nd(dims, hint_channels, 16, 3, padding=1),
  172. nn.SiLU(),
  173. conv_nd(dims, 16, 16, 3, padding=1),
  174. nn.SiLU(),
  175. conv_nd(dims, 16, 32, 3, padding=1, stride=2),
  176. nn.SiLU(),
  177. conv_nd(dims, 32, 32, 3, padding=1),
  178. nn.SiLU(),
  179. conv_nd(dims, 32, 96, 3, padding=1, stride=2),
  180. nn.SiLU(),
  181. conv_nd(dims, 96, 96, 3, padding=1),
  182. nn.SiLU(),
  183. conv_nd(dims, 96, 256, 3, padding=1, stride=2),
  184. nn.SiLU(),
  185. zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
  186. )
  187. self._feature_size = model_channels
  188. input_block_chans = [model_channels]
  189. ch = model_channels
  190. ds = 1
  191. for level, mult in enumerate(channel_mult):
  192. for nr in range(self.num_res_blocks[level]):
  193. layers = [
  194. ResBlock(
  195. ch,
  196. time_embed_dim,
  197. dropout,
  198. out_channels=mult * model_channels,
  199. dims=dims,
  200. use_checkpoint=use_checkpoint,
  201. use_scale_shift_norm=use_scale_shift_norm,
  202. )
  203. ]
  204. ch = mult * model_channels
  205. if ds in attention_resolutions:
  206. if num_head_channels == -1:
  207. dim_head = ch // num_heads
  208. else:
  209. num_heads = ch // num_head_channels
  210. dim_head = num_head_channels
  211. if legacy:
  212. #num_heads = 1
  213. dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
  214. if exists(disable_self_attentions):
  215. disabled_sa = disable_self_attentions[level]
  216. else:
  217. disabled_sa = False
  218. if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
  219. layers.append(
  220. AttentionBlock(
  221. ch,
  222. use_checkpoint=use_checkpoint,
  223. num_heads=num_heads,
  224. num_head_channels=dim_head,
  225. use_new_attention_order=use_new_attention_order,
  226. ) if not use_spatial_transformer else SpatialTransformer(
  227. ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
  228. disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
  229. use_checkpoint=use_checkpoint
  230. )
  231. )
  232. self.input_blocks.append(TimestepEmbedSequential(*layers))
  233. self.zero_convs.append(self.make_zero_conv(ch))
  234. self._feature_size += ch
  235. input_block_chans.append(ch)
  236. if level != len(channel_mult) - 1:
  237. out_ch = ch
  238. self.input_blocks.append(
  239. TimestepEmbedSequential(
  240. ResBlock(
  241. ch,
  242. time_embed_dim,
  243. dropout,
  244. out_channels=out_ch,
  245. dims=dims,
  246. use_checkpoint=use_checkpoint,
  247. use_scale_shift_norm=use_scale_shift_norm,
  248. down=True,
  249. )
  250. if resblock_updown
  251. else Downsample(
  252. ch, conv_resample, dims=dims, out_channels=out_ch
  253. )
  254. )
  255. )
  256. ch = out_ch
  257. input_block_chans.append(ch)
  258. self.zero_convs.append(self.make_zero_conv(ch))
  259. ds *= 2
  260. self._feature_size += ch
  261. if num_head_channels == -1:
  262. dim_head = ch // num_heads
  263. else:
  264. num_heads = ch // num_head_channels
  265. dim_head = num_head_channels
  266. if legacy:
  267. #num_heads = 1
  268. dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
  269. self.middle_block = TimestepEmbedSequential(
  270. ResBlock(
  271. ch,
  272. time_embed_dim,
  273. dropout,
  274. dims=dims,
  275. use_checkpoint=use_checkpoint,
  276. use_scale_shift_norm=use_scale_shift_norm,
  277. ),
  278. AttentionBlock(
  279. ch,
  280. use_checkpoint=use_checkpoint,
  281. num_heads=num_heads,
  282. num_head_channels=dim_head,
  283. use_new_attention_order=use_new_attention_order,
  284. # always uses a self-attn
  285. ) if not use_spatial_transformer else SpatialTransformer(
  286. ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
  287. disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
  288. use_checkpoint=use_checkpoint
  289. ),
  290. ResBlock(
  291. ch,
  292. time_embed_dim,
  293. dropout,
  294. dims=dims,
  295. use_checkpoint=use_checkpoint,
  296. use_scale_shift_norm=use_scale_shift_norm,
  297. ),
  298. )
  299. self.middle_block_out = self.make_zero_conv(ch)
  300. self._feature_size += ch
  301. def make_zero_conv(self, channels):
  302. return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
  303. def align(self, hint, h, w):
  304. b, c, h1, w1 = hint.shape
  305. if h != h1 or w != w1:
  306. return align(hint, (h, w))
  307. return hint
  308. def forward(self, x, hint, timesteps, context, **kwargs):
  309. t_emb = cond_cast_unet(timestep_embedding(timesteps, self.model_channels, repeat_only=False))
  310. emb = self.time_embed(t_emb)
  311. guided_hint = self.input_hint_block(cond_cast_unet(hint), emb, context)
  312. outs = []
  313. h1, w1 = x.shape[-2:]
  314. guided_hint = self.align(guided_hint, h1, w1)
  315. h = x.type(self.dtype)
  316. for module, zero_conv in zip(self.input_blocks, self.zero_convs):
  317. if guided_hint is not None:
  318. h = module(h, emb, context)
  319. h += guided_hint
  320. guided_hint = None
  321. else:
  322. h = module(h, emb, context)
  323. outs.append(zero_conv(h, emb, context))
  324. h = self.middle_block(h, emb, context)
  325. outs.append(self.middle_block_out(h, emb, context))
  326. return outs