utils.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. """Utils for monoDepth."""
  2. import sys
  3. import re
  4. import numpy as np
  5. import cv2
  6. import torch
  7. def read_pfm(path):
  8. """Read pfm file.
  9. Args:
  10. path (str): path to file
  11. Returns:
  12. tuple: (data, scale)
  13. """
  14. with open(path, "rb") as file:
  15. color = None
  16. width = None
  17. height = None
  18. scale = None
  19. endian = None
  20. header = file.readline().rstrip()
  21. if header.decode("ascii") == "PF":
  22. color = True
  23. elif header.decode("ascii") == "Pf":
  24. color = False
  25. else:
  26. raise Exception("Not a PFM file: " + path)
  27. dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
  28. if dim_match:
  29. width, height = list(map(int, dim_match.groups()))
  30. else:
  31. raise Exception("Malformed PFM header.")
  32. scale = float(file.readline().decode("ascii").rstrip())
  33. if scale < 0:
  34. # little-endian
  35. endian = "<"
  36. scale = -scale
  37. else:
  38. # big-endian
  39. endian = ">"
  40. data = np.fromfile(file, endian + "f")
  41. shape = (height, width, 3) if color else (height, width)
  42. data = np.reshape(data, shape)
  43. data = np.flipud(data)
  44. return data, scale
  45. def write_pfm(path, image, scale=1):
  46. """Write pfm file.
  47. Args:
  48. path (str): pathto file
  49. image (array): data
  50. scale (int, optional): Scale. Defaults to 1.
  51. """
  52. with open(path, "wb") as file:
  53. color = None
  54. if image.dtype.name != "float32":
  55. raise Exception("Image dtype must be float32.")
  56. image = np.flipud(image)
  57. if len(image.shape) == 3 and image.shape[2] == 3: # color image
  58. color = True
  59. elif (
  60. len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
  61. ): # greyscale
  62. color = False
  63. else:
  64. raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
  65. file.write("PF\n" if color else "Pf\n".encode())
  66. file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
  67. endian = image.dtype.byteorder
  68. if endian == "<" or endian == "=" and sys.byteorder == "little":
  69. scale = -scale
  70. file.write("%f\n".encode() % scale)
  71. image.tofile(file)
  72. def read_image(path):
  73. """Read image and output RGB image (0-1).
  74. Args:
  75. path (str): path to file
  76. Returns:
  77. array: RGB image (0-1)
  78. """
  79. img = cv2.imread(path)
  80. if img.ndim == 2:
  81. img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  82. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
  83. return img
  84. def resize_image(img):
  85. """Resize image and make it fit for network.
  86. Args:
  87. img (array): image
  88. Returns:
  89. tensor: data ready for network
  90. """
  91. height_orig = img.shape[0]
  92. width_orig = img.shape[1]
  93. if width_orig > height_orig:
  94. scale = width_orig / 384
  95. else:
  96. scale = height_orig / 384
  97. height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
  98. width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
  99. img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
  100. img_resized = (
  101. torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
  102. )
  103. img_resized = img_resized.unsqueeze(0)
  104. return img_resized
  105. def resize_depth(depth, width, height):
  106. """Resize depth map and bring to CPU (numpy).
  107. Args:
  108. depth (tensor): depth
  109. width (int): image width
  110. height (int): image height
  111. Returns:
  112. array: processed depth
  113. """
  114. depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
  115. depth_resized = cv2.resize(
  116. depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
  117. )
  118. return depth_resized
  119. def write_depth(path, depth, bits=1):
  120. """Write depth map to pfm and png file.
  121. Args:
  122. path (str): filepath without extension
  123. depth (array): depth
  124. """
  125. write_pfm(path + ".pfm", depth.astype(np.float32))
  126. depth_min = depth.min()
  127. depth_max = depth.max()
  128. max_val = (2**(8*bits))-1
  129. if depth_max - depth_min > np.finfo("float").eps:
  130. out = max_val * (depth - depth_min) / (depth_max - depth_min)
  131. else:
  132. out = np.zeros(depth.shape, dtype=depth.type)
  133. if bits == 1:
  134. cv2.imwrite(path + ".png", out.astype("uint8"))
  135. elif bits == 2:
  136. cv2.imwrite(path + ".png", out.astype("uint16"))
  137. return