__init__.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import os
  2. import cv2
  3. import torch
  4. import numpy as np
  5. import torch.nn as nn
  6. from einops import rearrange
  7. from modules import devices
  8. from annotator.annotator_path import models_path
  9. norm_layer = nn.InstanceNorm2d
  10. class ResidualBlock(nn.Module):
  11. def __init__(self, in_features):
  12. super(ResidualBlock, self).__init__()
  13. conv_block = [ nn.ReflectionPad2d(1),
  14. nn.Conv2d(in_features, in_features, 3),
  15. norm_layer(in_features),
  16. nn.ReLU(inplace=True),
  17. nn.ReflectionPad2d(1),
  18. nn.Conv2d(in_features, in_features, 3),
  19. norm_layer(in_features)
  20. ]
  21. self.conv_block = nn.Sequential(*conv_block)
  22. def forward(self, x):
  23. return x + self.conv_block(x)
  24. class Generator(nn.Module):
  25. def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
  26. super(Generator, self).__init__()
  27. # Initial convolution block
  28. model0 = [ nn.ReflectionPad2d(3),
  29. nn.Conv2d(input_nc, 64, 7),
  30. norm_layer(64),
  31. nn.ReLU(inplace=True) ]
  32. self.model0 = nn.Sequential(*model0)
  33. # Downsampling
  34. model1 = []
  35. in_features = 64
  36. out_features = in_features*2
  37. for _ in range(2):
  38. model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
  39. norm_layer(out_features),
  40. nn.ReLU(inplace=True) ]
  41. in_features = out_features
  42. out_features = in_features*2
  43. self.model1 = nn.Sequential(*model1)
  44. model2 = []
  45. # Residual blocks
  46. for _ in range(n_residual_blocks):
  47. model2 += [ResidualBlock(in_features)]
  48. self.model2 = nn.Sequential(*model2)
  49. # Upsampling
  50. model3 = []
  51. out_features = in_features//2
  52. for _ in range(2):
  53. model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
  54. norm_layer(out_features),
  55. nn.ReLU(inplace=True) ]
  56. in_features = out_features
  57. out_features = in_features//2
  58. self.model3 = nn.Sequential(*model3)
  59. # Output layer
  60. model4 = [ nn.ReflectionPad2d(3),
  61. nn.Conv2d(64, output_nc, 7)]
  62. if sigmoid:
  63. model4 += [nn.Sigmoid()]
  64. self.model4 = nn.Sequential(*model4)
  65. def forward(self, x, cond=None):
  66. out = self.model0(x)
  67. out = self.model1(out)
  68. out = self.model2(out)
  69. out = self.model3(out)
  70. out = self.model4(out)
  71. return out
  72. class LineartDetector:
  73. model_dir = os.path.join(models_path, "lineart")
  74. model_default = 'sk_model.pth'
  75. model_coarse = 'sk_model2.pth'
  76. def __init__(self, model_name):
  77. self.model = None
  78. self.model_name = model_name
  79. self.device = devices.get_device_for("controlnet")
  80. def load_model(self, name):
  81. remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/" + name
  82. model_path = os.path.join(self.model_dir, name)
  83. if not os.path.exists(model_path):
  84. from basicsr.utils.download_util import load_file_from_url
  85. load_file_from_url(remote_model_path, model_dir=self.model_dir)
  86. model = Generator(3, 1, 3)
  87. model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
  88. model.eval()
  89. self.model = model.to(self.device)
  90. def unload_model(self):
  91. if self.model is not None:
  92. self.model.cpu()
  93. def __call__(self, input_image):
  94. if self.model is None:
  95. self.load_model(self.model_name)
  96. self.model.to(self.device)
  97. assert input_image.ndim == 3
  98. image = input_image
  99. with torch.no_grad():
  100. image = torch.from_numpy(image).float().to(self.device)
  101. image = image / 255.0
  102. image = rearrange(image, 'h w c -> 1 c h w')
  103. line = self.model(image)[0][0]
  104. line = line.cpu().numpy()
  105. line = (line * 255.0).clip(0, 255).astype(np.uint8)
  106. return line