# NAI compatible

import torch


class HypernetworkModule(torch.nn.Module):
    def __init__(self, dim, multiplier=1.0):
        super().__init__()

        # Two linear layers (dim -> dim*2 -> dim) initialized with a small
        # std, so the residual branch starts near zero and the module
        # initially behaves almost like the identity.
        linear1 = torch.nn.Linear(dim, dim * 2)
        linear2 = torch.nn.Linear(dim * 2, dim)
        linear1.weight.data.normal_(mean=0.0, std=0.01)
        linear1.bias.data.zero_()
        linear2.weight.data.normal_(mean=0.0, std=0.01)
        linear2.bias.data.zero_()
        linears = [linear1, linear2]

        self.linear = torch.nn.Sequential(*linears)
        self.multiplier = multiplier

    def forward(self, x):
        # Residual connection; multiplier scales the hypernetwork's influence.
        return x + self.linear(x) * self.multiplier
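# A quick sanity check (illustrative, not part of the original file): the
# module is shape-preserving, so it can be inserted in front of an attention
# projection without changing tensor shapes.
#
#   m = HypernetworkModule(768)
#   ctx = torch.randn(2, 77, 768)   # e.g. a CLIP text context
#   assert m(ctx).shape == ctx.shape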
class Hypernetwork(torch.nn.Module):
    # Context dimensions this hypernetwork can hook: 320/640/1280 are the
    # UNet attention dims at each resolution level, 768 is the CLIP text
    # embedding dim (SD v1).
    enable_sizes = [320, 640, 768, 1280]

    def __init__(self, multiplier=1.0) -> None:
        super().__init__()
        # Note: this plain list shadows torch.nn.Module.modules().
        self.modules = []
        for size in Hypernetwork.enable_sizes:
            # One module pair per size: [0] transforms the key context,
            # [1] the value context.
            self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
            self.register_module(f"{size}_0", self.modules[-1][0])
            self.register_module(f"{size}_1", self.modules[-1][1])
    def apply_to_stable_diffusion(self, text_encoder, vae, unet):
        # Attach this hypernetwork to every attention layer of a CompVis/LDM
        # UNet. The patched CrossAttention forward is expected to consult
        # `attn.hypernetwork` (see the sketch after forward() below).
        blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
        for block in blocks:
            for subblk in block:
                if 'SpatialTransformer' in str(type(subblk)):
                    for tf_block in subblk.transformer_blocks:
                        for attn in [tf_block.attn1, tf_block.attn2]:
                            size = attn.context_dim
                            if size in Hypernetwork.enable_sizes:
                                attn.hypernetwork = self
                            else:
                                attn.hypernetwork = None
    def apply_to_diffusers(self, text_encoder, vae, unet):
        blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
        for block in blocks:
            if hasattr(block, 'attentions'):
                for subblk in block.attentions:
                    # SpatialTransformer in Diffusers 0.6.0, Transformer2DModel from 0.7 on
                    if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)):
                        for tf_block in subblk.transformer_blocks:
                            for attn in [tf_block.attn1, tf_block.attn2]:
                                size = attn.to_k.in_features
                                if size in Hypernetwork.enable_sizes:
                                    attn.hypernetwork = self
                                else:
                                    attn.hypernetwork = None
        return True  # TODO error checking
    def forward(self, x, context):
        size = context.shape[-1]
        assert size in Hypernetwork.enable_sizes
        module = self.modules[Hypernetwork.enable_sizes.index(size)]
        # Returns the transformed (key context, value context) pair.
        return module[0](context), module[1](context)
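    # How the attached hypernetwork is consumed (a sketch of the patched
    # attention forward, not part of the original file):
    #
    #   if getattr(self, 'hypernetwork', None) is not None:
    #       context_k, context_v = self.hypernetwork(x, context)
    #   else:
    #       context_k = context_v = context
    #   k = self.to_k(context_k)
    #   v = self.to_v(context_v)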
    def load_from_state_dict(self, state_dict):
        # Rename keys written by old versions to the current layout.
        changes = {
            'linear1.bias': 'linear.0.bias',
            'linear1.weight': 'linear.0.weight',
            'linear2.bias': 'linear.1.bias',
            'linear2.weight': 'linear.1.weight',
        }
        for key_from, key_to in changes.items():
            if key_from in state_dict:
                state_dict[key_to] = state_dict[key_from]
                del state_dict[key_from]

        # NAI format: {size: [key_module_sd, value_module_sd], ...};
        # non-integer keys (e.g. metadata) are skipped.
        for size, sd in state_dict.items():
            if type(size) == int:
                self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
                self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
        return True
    def get_state_dict(self):
        # Serialize back into the NAI format: {size: [key_module_sd, value_module_sd]}.
        state_dict = {}
        for i, size in enumerate(Hypernetwork.enable_sizes):
            sd0 = self.modules[i][0].state_dict()
            sd1 = self.modules[i][1].state_dict()
            state_dict[size] = [sd0, sd1]
        return state_dict
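

if __name__ == "__main__":
    # Minimal usage sketch, under two assumptions that do not ship with this
    # file: a NAI-format "hypernetwork.pt" on disk, and an attention forward
    # that has been patched to consult `attn.hypernetwork` (Diffusers'
    # stock attention ignores the attribute).
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    hypernetwork = Hypernetwork(multiplier=1.0)
    hypernetwork.load_from_state_dict(torch.load("hypernetwork.pt", map_location="cpu"))
    hypernetwork.apply_to_diffusers(pipe.text_encoder, pipe.vae, pipe.unet)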