extract_depth.py

import os
import torch
import numpy as np
from tqdm import trange
from PIL import Image


def get_state(gpu):
    # Load the MiDaS monocular depth model and its preprocessing transform
    # from torch.hub.
    midas = torch.hub.load("intel-isl/MiDaS", "MiDaS")
    if gpu:
        midas.cuda()
    midas.eval()

    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
    transform = midas_transforms.default_transform

    state = {"model": midas,
             "transform": transform}
    return state


def depth_to_rgba(x):
    # Losslessly pack a float32 depth map into a uint8 RGBA image by
    # reinterpreting its raw bytes (4 bytes per float -> 4 channels),
    # so it can be stored as a PNG.
    assert x.dtype == np.float32
    assert len(x.shape) == 2
    y = x.copy()
    y.dtype = np.uint8
    y = y.reshape(x.shape + (4,))
    return np.ascontiguousarray(y)


def rgba_to_depth(x):
    # Inverse of depth_to_rgba: reinterpret the four uint8 channels as the
    # raw bytes of a float32 depth map.
    assert x.dtype == np.uint8
    assert len(x.shape) == 3 and x.shape[2] == 4
    y = x.copy()
    y.dtype = np.float32
    y = y.reshape(x.shape[:2])
    return np.ascontiguousarray(y)
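
# Sanity check (illustrative sketch, not part of the extraction pipeline):
# the two conversions above are exact inverses, since they only reinterpret
# the underlying bytes:
#
#   d = np.random.rand(4, 4).astype(np.float32)
#   assert np.array_equal(rgba_to_depth(depth_to_rgba(d)), d)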


def run(x, state):
    # x is expected to be an HxWxC image with values in [-1, 1]; it is mapped
    # to [0, 255] for the MiDaS transform. Returns an HxW float32 depth map
    # resized back to the input resolution.
    model = state["model"]
    transform = state["transform"]
    hw = x.shape[:2]
    with torch.no_grad():
        prediction = model(transform((x + 1.0) * 127.5).cuda())
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=hw,
            mode="bicubic",
            align_corners=False,
        ).squeeze()
    output = prediction.cpu().numpy()
    return output
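
# Typical usage (sketch, assuming `img` is an HxWx3 array in [-1, 1] as
# produced by the taming ImageNet datasets):
#
#   state = get_state(gpu=True)
#   depth = run(img, state)  # HxW float32 depth map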


def get_filename(relpath, level=-2):
    # save class folder structure and filename:
    fn = relpath.split(os.sep)[level:]
    folder = fn[-2]
    file = fn[-1].split('.')[0]
    return folder, file
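
# For example (illustrative only, exact names depend on the dataset layout),
# a relpath like "n01440764/n01440764_10026.JPEG" yields
# ("n01440764", "n01440764_10026").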


def save_depth(dataset, path, debug=False):
    os.makedirs(path)
    N = len(dataset)
    if debug:
        N = 10
    state = get_state(gpu=True)
    for idx in trange(N, desc="Data"):
        ex = dataset[idx]
        image, relpath = ex["image"], ex["relpath"]
        folder, filename = get_filename(relpath)
        # prepare output folder
        folderabspath = os.path.join(path, folder)
        os.makedirs(folderabspath, exist_ok=True)
        savepath = os.path.join(folderabspath, filename)
        # run model and store the depth map as a lossless RGBA png
        xout = run(image, state)
        I = depth_to_rgba(xout)
        Image.fromarray(I).save("{}.png".format(savepath))
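
# To read a saved depth map back later (sketch):
#
#   depth = rgba_to_depth(np.array(Image.open("path/to/depth.png")))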


if __name__ == "__main__":
    from taming.data.imagenet import ImageNetTrain, ImageNetValidation

    out = "data/imagenet_depth"
    if not os.path.exists(out):
        print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
              "(be prepared that the output size will be larger than ImageNet itself).")
        exit(1)

    # go
    dset = ImageNetValidation()
    abspath = os.path.join(out, "val")
    if os.path.exists(abspath):
        print("{} exists - not doing anything.".format(abspath))
    else:
        print("preparing {}".format(abspath))
        save_depth(dset, abspath)
        print("done with validation split")

    dset = ImageNetTrain()
    abspath = os.path.join(out, "train")
    if os.path.exists(abspath):
        print("{} exists - not doing anything.".format(abspath))
    else:
        print("preparing {}".format(abspath))
        save_depth(dset, abspath)
        print("done with train split")

    print("done done.")