__init__.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import cv2
  2. import numpy as np
  3. import torch
  4. import os
  5. from modules import devices, shared
  6. from annotator.annotator_path import models_path
  7. from torchvision.transforms import transforms
  8. # AdelaiDepth/LeReS imports
  9. from .leres.depthmap import estimateleres, estimateboost
  10. from .leres.multi_depth_model_woauxi import RelDepthModel
  11. from .leres.net_tools import strip_prefix_if_present
  12. # pix2pix/merge net imports
  13. from .pix2pix.options.test_options import TestOptions
  14. from .pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel
  15. base_model_path = os.path.join(models_path, "leres")
  16. old_modeldir = os.path.dirname(os.path.realpath(__file__))
  17. remote_model_path_leres = "https://huggingface.co/lllyasviel/Annotators/resolve/main/res101.pth"
  18. remote_model_path_pix2pix = "https://huggingface.co/lllyasviel/Annotators/resolve/main/latest_net_G.pth"
  19. model = None
  20. pix2pixmodel = None
  21. def unload_leres_model():
  22. global model, pix2pixmodel
  23. if model is not None:
  24. model = model.cpu()
  25. if pix2pixmodel is not None:
  26. pix2pixmodel = pix2pixmodel.unload_network('G')
  27. def apply_leres(input_image, thr_a, thr_b, boost=False):
  28. global model, pix2pixmodel
  29. if model is None:
  30. model_path = os.path.join(base_model_path, "res101.pth")
  31. old_model_path = os.path.join(old_modeldir, "res101.pth")
  32. if os.path.exists(old_model_path):
  33. model_path = old_model_path
  34. elif not os.path.exists(model_path):
  35. from basicsr.utils.download_util import load_file_from_url
  36. load_file_from_url(remote_model_path_leres, model_dir=base_model_path)
  37. if torch.cuda.is_available():
  38. checkpoint = torch.load(model_path)
  39. else:
  40. checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
  41. model = RelDepthModel(backbone='resnext101')
  42. model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), strict=True)
  43. del checkpoint
  44. if boost and pix2pixmodel is None:
  45. pix2pixmodel_path = os.path.join(base_model_path, "latest_net_G.pth")
  46. if not os.path.exists(pix2pixmodel_path):
  47. from basicsr.utils.download_util import load_file_from_url
  48. load_file_from_url(remote_model_path_pix2pix, model_dir=base_model_path)
  49. opt = TestOptions().parse()
  50. if not torch.cuda.is_available():
  51. opt.gpu_ids = [] # cpu mode
  52. pix2pixmodel = Pix2Pix4DepthModel(opt)
  53. pix2pixmodel.save_dir = base_model_path
  54. pix2pixmodel.load_networks('latest')
  55. pix2pixmodel.eval()
  56. if devices.get_device_for("controlnet").type != 'mps':
  57. model = model.to(devices.get_device_for("controlnet"))
  58. assert input_image.ndim == 3
  59. height, width, dim = input_image.shape
  60. with torch.no_grad():
  61. if boost:
  62. depth = estimateboost(input_image, model, 0, pix2pixmodel, max(width, height))
  63. else:
  64. depth = estimateleres(input_image, model, width, height)
  65. numbytes=2
  66. depth_min = depth.min()
  67. depth_max = depth.max()
  68. max_val = (2**(8*numbytes))-1
  69. # check output before normalizing and mapping to 16 bit
  70. if depth_max - depth_min > np.finfo("float").eps:
  71. out = max_val * (depth - depth_min) / (depth_max - depth_min)
  72. else:
  73. out = np.zeros(depth.shape)
  74. # single channel, 16 bit image
  75. depth_image = out.astype("uint16")
  76. # convert to uint8
  77. depth_image = cv2.convertScaleAbs(depth_image, alpha=(255.0/65535.0))
  78. # remove near
  79. if thr_a != 0:
  80. thr_a = ((thr_a/100)*255)
  81. depth_image = cv2.threshold(depth_image, thr_a, 255, cv2.THRESH_TOZERO)[1]
  82. # invert image
  83. depth_image = cv2.bitwise_not(depth_image)
  84. # remove bg
  85. if thr_b != 0:
  86. thr_b = ((thr_b/100)*255)
  87. depth_image = cv2.threshold(depth_image, thr_b, 255, cv2.THRESH_TOZERO)[1]
  88. return depth_image