hypernetwork_nai.py

# NAI compatible
import torch

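# Residual block of two linear layers (dim -> dim*2 -> dim, no activation),
# scaled by a multiplier. A NAI hypernetwork stores one (key, value) pair of
# these per supported context size.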
class HypernetworkModule(torch.nn.Module):
  def __init__(self, dim, multiplier=1.0):
    super().__init__()

    linear1 = torch.nn.Linear(dim, dim * 2)
    linear2 = torch.nn.Linear(dim * 2, dim)
    linear1.weight.data.normal_(mean=0.0, std=0.01)
    linear1.bias.data.zero_()
    linear2.weight.data.normal_(mean=0.0, std=0.01)
    linear2.bias.data.zero_()
    linears = [linear1, linear2]

    self.linear = torch.nn.Sequential(*linears)
    self.multiplier = multiplier

  def forward(self, x):
    return x + self.linear(x) * self.multiplier

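# NAI-style hypernetwork: one HypernetworkModule pair (for attention keys and
# values) per supported cross-attention context dimension.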
class Hypernetwork(torch.nn.Module):
  enable_sizes = [320, 640, 768, 1280]

  def __init__(self, multiplier=1.0) -> None:
    super().__init__()
    self.modules = []  # note: this attribute shadows torch.nn.Module.modules()
    for size in Hypernetwork.enable_sizes:
      # one module for the attention key, one for the value, per context size;
      # register_module makes their parameters visible to PyTorch
      self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
      self.register_module(f"{size}_0", self.modules[-1][0])
      self.register_module(f"{size}_1", self.modules[-1][1])

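  # Attach this hypernetwork to every attention module of an original
  # (CompVis/LDM) Stable Diffusion U-Net by setting attn.hypernetwork.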
  def apply_to_stable_diffusion(self, text_encoder, vae, unet):
    blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
    for block in blocks:
      for subblk in block:
        if 'SpatialTransformer' in str(type(subblk)):
          for tf_block in subblk.transformer_blocks:
            for attn in [tf_block.attn1, tf_block.attn2]:
              size = attn.context_dim
              if size in Hypernetwork.enable_sizes:
                attn.hypernetwork = self
              else:
                attn.hypernetwork = None

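  # Same attachment for a Diffusers UNet2DConditionModel; here the context
  # size is read from the to_k projection's in_features.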
  def apply_to_diffusers(self, text_encoder, vae, unet):
    blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
    for block in blocks:
      if hasattr(block, 'attentions'):
        for subblk in block.attentions:
          # 'SpatialTransformer' in diffusers 0.6.0, 'Transformer2DModel' in 0.7 and later
          if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)):
            for tf_block in subblk.transformer_blocks:
              for attn in [tf_block.attn1, tf_block.attn2]:
                size = attn.to_k.in_features
                if size in Hypernetwork.enable_sizes:
                  attn.hypernetwork = self
                else:
                  attn.hypernetwork = None
    return True  # TODO error checking

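  # Transform the attention context; returns the tensors fed to the key and
  # value projections. x (the query-side input) is unused.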
  def forward(self, x, context):
    size = context.shape[-1]
    assert size in Hypernetwork.enable_sizes
    module = self.modules[Hypernetwork.enable_sizes.index(size)]
    return module[0].forward(context), module[1].forward(context)

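  # Load a NAI-format state dict, {context_size: [key_sd, value_sd], ...},
  # remapping legacy layer names where present.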
  def load_from_state_dict(self, state_dict):
    # remap keys from the old attribute layout (linear1/linear2) to the Sequential layout
    changes = {
        'linear1.bias': 'linear.0.bias',
        'linear1.weight': 'linear.0.weight',
        'linear2.bias': 'linear.1.bias',
        'linear2.weight': 'linear.1.weight',
    }
    for key_from, key_to in changes.items():
      if key_from in state_dict:
        state_dict[key_to] = state_dict[key_from]
        del state_dict[key_from]
    for size, sd in state_dict.items():
      if isinstance(size, int):  # integer keys are context sizes; others are metadata
        self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
        self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
    return True

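  # Export weights in the same NAI-compatible layout: {size: [key_sd, value_sd]}.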
  def get_state_dict(self):
    state_dict = {}
    for i, size in enumerate(Hypernetwork.enable_sizes):
      sd0 = self.modules[i][0].state_dict()
      sd1 = self.modules[i][1].state_dict()
      state_dict[size] = [sd0, sd1]
    return state_dict

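if __name__ == "__main__":
  # Usage sketch (illustration only, not part of the original module): load a
  # NAI-format hypernetwork and push a dummy context through it. The path
  # "hypernet.pt" is a hypothetical example.
  hypernet = Hypernetwork()
  hypernet.load_from_state_dict(torch.load("hypernet.pt", map_location="cpu"))

  # forward ignores x and returns the contexts fed to the key and value
  # projections of cross attention; 768 is the CLIP text embedding width.
  context = torch.randn(1, 77, 768)
  context_k, context_v = hypernet.forward(None, context)
  print(context_k.shape, context_v.shape)  # both torch.Size([1, 77, 768])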