transforms.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import cv2
  2. import random
  3. def mod_crop(img, scale):
  4. """Mod crop images, used during testing.
  5. Args:
  6. img (ndarray): Input image.
  7. scale (int): Scale factor.
  8. Returns:
  9. ndarray: Result image.
  10. """
  11. img = img.copy()
  12. if img.ndim in (2, 3):
  13. h, w = img.shape[0], img.shape[1]
  14. h_remainder, w_remainder = h % scale, w % scale
  15. img = img[:h - h_remainder, :w - w_remainder, ...]
  16. else:
  17. raise ValueError(f'Wrong img ndim: {img.ndim}.')
  18. return img
  19. def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
  20. """Paired random crop.
  21. It crops lists of lq and gt images with corresponding locations.
  22. Args:
  23. img_gts (list[ndarray] | ndarray): GT images. Note that all images
  24. should have the same shape. If the input is an ndarray, it will
  25. be transformed to a list containing itself.
  26. img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
  27. should have the same shape. If the input is an ndarray, it will
  28. be transformed to a list containing itself.
  29. gt_patch_size (int): GT patch size.
  30. scale (int): Scale factor.
  31. gt_path (str): Path to ground-truth.
  32. Returns:
  33. list[ndarray] | ndarray: GT images and LQ images. If returned results
  34. only have one element, just return ndarray.
  35. """
  36. if not isinstance(img_gts, list):
  37. img_gts = [img_gts]
  38. if not isinstance(img_lqs, list):
  39. img_lqs = [img_lqs]
  40. h_lq, w_lq, _ = img_lqs[0].shape
  41. h_gt, w_gt, _ = img_gts[0].shape
  42. lq_patch_size = gt_patch_size // scale
  43. if h_gt != h_lq * scale or w_gt != w_lq * scale:
  44. raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
  45. f'multiplication of LQ ({h_lq}, {w_lq}).')
  46. if h_lq < lq_patch_size or w_lq < lq_patch_size:
  47. raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
  48. f'({lq_patch_size}, {lq_patch_size}). '
  49. f'Please remove {gt_path}.')
  50. # randomly choose top and left coordinates for lq patch
  51. top = random.randint(0, h_lq - lq_patch_size)
  52. left = random.randint(0, w_lq - lq_patch_size)
  53. # crop lq patch
  54. img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
  55. # crop corresponding gt patch
  56. top_gt, left_gt = int(top * scale), int(left * scale)
  57. img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
  58. if len(img_gts) == 1:
  59. img_gts = img_gts[0]
  60. if len(img_lqs) == 1:
  61. img_lqs = img_lqs[0]
  62. return img_gts, img_lqs
  63. def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
  64. """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
  65. We use vertical flip and transpose for rotation implementation.
  66. All the images in the list use the same augmentation.
  67. Args:
  68. imgs (list[ndarray] | ndarray): Images to be augmented. If the input
  69. is an ndarray, it will be transformed to a list.
  70. hflip (bool): Horizontal flip. Default: True.
  71. rotation (bool): Ratotation. Default: True.
  72. flows (list[ndarray]: Flows to be augmented. If the input is an
  73. ndarray, it will be transformed to a list.
  74. Dimension is (h, w, 2). Default: None.
  75. return_status (bool): Return the status of flip and rotation.
  76. Default: False.
  77. Returns:
  78. list[ndarray] | ndarray: Augmented images and flows. If returned
  79. results only have one element, just return ndarray.
  80. """
  81. hflip = hflip and random.random() < 0.5
  82. vflip = rotation and random.random() < 0.5
  83. rot90 = rotation and random.random() < 0.5
  84. def _augment(img):
  85. if hflip: # horizontal
  86. cv2.flip(img, 1, img)
  87. if vflip: # vertical
  88. cv2.flip(img, 0, img)
  89. if rot90:
  90. img = img.transpose(1, 0, 2)
  91. return img
  92. def _augment_flow(flow):
  93. if hflip: # horizontal
  94. cv2.flip(flow, 1, flow)
  95. flow[:, :, 0] *= -1
  96. if vflip: # vertical
  97. cv2.flip(flow, 0, flow)
  98. flow[:, :, 1] *= -1
  99. if rot90:
  100. flow = flow.transpose(1, 0, 2)
  101. flow = flow[:, :, [1, 0]]
  102. return flow
  103. if not isinstance(imgs, list):
  104. imgs = [imgs]
  105. imgs = [_augment(img) for img in imgs]
  106. if len(imgs) == 1:
  107. imgs = imgs[0]
  108. if flows is not None:
  109. if not isinstance(flows, list):
  110. flows = [flows]
  111. flows = [_augment_flow(flow) for flow in flows]
  112. if len(flows) == 1:
  113. flows = flows[0]
  114. return imgs, flows
  115. else:
  116. if return_status:
  117. return imgs, (hflip, vflip, rot90)
  118. else:
  119. return imgs
  120. def img_rotate(img, angle, center=None, scale=1.0):
  121. """Rotate image.
  122. Args:
  123. img (ndarray): Image to be rotated.
  124. angle (float): Rotation angle in degrees. Positive values mean
  125. counter-clockwise rotation.
  126. center (tuple[int]): Rotation center. If the center is None,
  127. initialize it as the center of the image. Default: None.
  128. scale (float): Isotropic scale factor. Default: 1.0.
  129. """
  130. (h, w) = img.shape[:2]
  131. if center is None:
  132. center = (w // 2, h // 2)
  133. matrix = cv2.getRotationMatrix2D(center, angle, scale)
  134. rotated_img = cv2.warpAffine(img, matrix, (w, h))
  135. return rotated_img