ffhq_blind_joint_dataset.py

import cv2
import math
import random
import numpy as np
import os.path as osp
import torch
import torch.utils.data as data
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
                                               adjust_hue, adjust_saturation, normalize)

from basicsr.data import gaussian_kernels as gaussian_kernels
from basicsr.data.transforms import augment
from basicsr.data.data_util import paths_from_folder
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class FFHQBlindJointDataset(data.Dataset):
    """FFHQ dataset for joint blind face restoration training.

    For each GT face image it returns the clean GT plus two degraded copies:
    one produced with the 'small' degradation settings and one with the
    '_large' settings, so a model can be trained on both levels jointly.
    """

    def __init__(self, opt):
        super(FFHQBlindJointDataset, self).__init__()
        logger = get_root_logger()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.gt_size = opt.get('gt_size', 512)
        self.in_size = opt.get('in_size', 512)
        assert self.gt_size >= self.in_size, 'Wrong setting: gt_size must be >= in_size.'

        self.mean = opt.get('mean', [0.5, 0.5, 0.5])
        self.std = opt.get('std', [0.5, 0.5, 0.5])

        self.component_path = opt.get('component_path', None)
        self.latent_gt_path = opt.get('latent_gt_path', None)

        if self.component_path is not None:
            self.crop_components = True
            self.components_dict = torch.load(self.component_path)
            self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
            self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
            self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)
        else:
            self.crop_components = False

        if self.latent_gt_path is not None:
            self.load_latent_gt = True
            self.latent_gt_dict = torch.load(self.latent_gt_path)
        else:
            self.load_latent_gt = False

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError("'dataroot_gt' should end with '.lmdb', "
                                 f'but received {self.gt_folder}')
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            self.paths = paths_from_folder(self.gt_folder)

        # degradation switches
        self.use_corrupt = opt.get('use_corrupt', True)
        self.use_motion_kernel = False
        # self.use_motion_kernel = opt.get('use_motion_kernel', True)

        if self.use_motion_kernel:
            self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
            motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
            self.motion_kernels = torch.load(motion_kernel_path)

        if self.use_corrupt:
            # degradation configurations
            self.blur_kernel_size = self.opt['blur_kernel_size']
            self.kernel_list = self.opt['kernel_list']
            self.kernel_prob = self.opt['kernel_prob']
            # small degradation
            self.blur_sigma = self.opt['blur_sigma']
            self.downsample_range = self.opt['downsample_range']
            self.noise_range = self.opt['noise_range']
            self.jpeg_range = self.opt['jpeg_range']
            # large degradation
            self.blur_sigma_large = self.opt['blur_sigma_large']
            self.downsample_range_large = self.opt['downsample_range_large']
            self.noise_range_large = self.opt['noise_range_large']
            self.jpeg_range_large = self.opt['jpeg_range_large']

            # log the (small) degradation settings
            logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
            logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
            logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
            logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')

        # color jitter
        self.color_jitter_prob = opt.get('color_jitter_prob', None)
        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
        if self.color_jitter_prob is not None:
            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')

        # to gray
        self.gray_prob = opt.get('gray_prob', 0.0)
        if self.gray_prob:  # only log when grayscale augmentation is actually enabled
            logger.info(f'Use random gray. Prob: {self.gray_prob}')

        # convert the shift from 0-255 units to the [0, 1] image range
        self.color_jitter_shift /= 255.
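
    # Notes on the degradation ranges read above (units as used in __getitem__):
    #   blur_sigma(_large)        sigma range of the random Gaussian blur kernel
    #   downsample_range(_large)  gt_size is divided by a scale drawn from this range
    #   noise_range(_large)       Gaussian noise sigma, expressed in 0-255 units
    #   jpeg_range(_large)        JPEG quality factor (lower means stronger compression)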

    @staticmethod
    def color_jitter(img, shift):
        """Jitter color: randomly shift the RGB values, in numpy format."""
        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
        img = img + jitter_val
        img = np.clip(img, 0, 1)
        return img

    @staticmethod
    def color_jitter_pt(img, brightness, contrast, saturation, hue):
        """Jitter color: randomly adjust brightness, contrast, saturation and hue, in torch Tensor format."""
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)
            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)
            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)
            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img
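
    # The four adjustments above run in a random order (via torch.randperm),
    # matching torchvision's ColorJitter behavior. A sketch of a call on a CHW
    # float tensor in [0, 1], using the same default ranges that __getitem__
    # reads from `opt`:
    #
    #   img = FFHQBlindJointDataset.color_jitter_pt(
    #       img, brightness=(0.5, 1.5), contrast=(0.5, 1.5),
    #       saturation=(0, 1.5), hue=(-0.1, 0.1))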

    def get_component_locations(self, name, status):
        components_bbox = self.components_dict[name]
        if status[0]:  # hflip
            # exchange right and left eye
            tmp = components_bbox['left_eye']
            components_bbox['left_eye'] = components_bbox['right_eye']
            components_bbox['right_eye'] = tmp
            # modify the width coordinate
            components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0]
            components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0]
            components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0]
            components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0]

        locations_gt = {}
        locations_in = {}
        for part in ['left_eye', 'right_eye', 'nose', 'mouth']:
            mean = components_bbox[part][0:2]
            half_len = components_bbox[part][2]
            if 'eye' in part:
                half_len *= self.eye_enlarge_ratio
            elif part == 'nose':
                half_len *= self.nose_enlarge_ratio
            elif part == 'mouth':
                half_len *= self.mouth_enlarge_ratio
            loc = np.hstack((mean - half_len + 1, mean + half_len))
            loc = torch.from_numpy(loc).float()
            locations_gt[part] = loc
            loc_in = loc / (self.gt_size // self.in_size)
            locations_in[part] = loc_in
        return locations_gt, locations_in
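
    # Each self.components_dict entry is keyed by image name and stores, per
    # facial part, a vector [center_x, center_y, half_len] in gt-size coordinates
    # (this layout is inferred from the indexing above); an illustrative entry:
    #
    #   components_dict['00000'] = {
    #       'left_eye':  np.array([180., 240., 40.]),
    #       'right_eye': np.array([330., 240., 40.]),
    #       'nose':      np.array([256., 310., 50.]),
    #       'mouth':     np.array([256., 380., 60.]),
    #   }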

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        name = osp.basename(gt_path)[:-4]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)

        if self.load_latent_gt:
            if status[0]:  # hflip
                latent_gt = self.latent_gt_dict['hflip'][name]
            else:
                latent_gt = self.latent_gt_dict['orig'][name]

        if self.crop_components:
            locations_gt, locations_in = self.get_component_locations(name, status)
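
        # The two degradation branches below share the same pipeline order
        # (motion blur -> Gaussian blur -> downsample -> noise -> JPEG -> resize
        # back to in_size); they differ only in drawing from the small vs the
        # `_large` parameter ranges.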

        # generate the mildly degraded input
        img_in = img_gt
        if self.use_corrupt:
            # motion blur
            if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
                m_i = random.randint(0, 31)
                k = self.motion_kernels[f'{m_i:02d}']
                img_in = cv2.filter2D(img_in, -1, k)

            # gaussian blur
            kernel = gaussian_kernels.random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                self.blur_kernel_size,
                self.blur_sigma,
                self.blur_sigma,
                [-math.pi, math.pi],
                noise_range=None)
            img_in = cv2.filter2D(img_in, -1, kernel)

            # downsample
            scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
            img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)

            # noise
            if self.noise_range is not None:
                noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
                noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
                img_in = img_in + noise
                img_in = np.clip(img_in, 0, 1)

            # jpeg
            if self.jpeg_range is not None:
                jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
                encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_p]
                _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
                img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.

            # resize to in_size
            img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)

        # generate in_large with large degradation
        img_in_large = img_gt
        if self.use_corrupt:
            # motion blur
            if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
                m_i = random.randint(0, 31)
                k = self.motion_kernels[f'{m_i:02d}']
                img_in_large = cv2.filter2D(img_in_large, -1, k)

            # gaussian blur
            kernel = gaussian_kernels.random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                self.blur_kernel_size,
                self.blur_sigma_large,
                self.blur_sigma_large,
                [-math.pi, math.pi],
                noise_range=None)
            img_in_large = cv2.filter2D(img_in_large, -1, kernel)

            # downsample
            scale = np.random.uniform(self.downsample_range_large[0], self.downsample_range_large[1])
            img_in_large = cv2.resize(img_in_large, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)

            # noise
            if self.noise_range_large is not None:
                noise_sigma = np.random.uniform(self.noise_range_large[0] / 255., self.noise_range_large[1] / 255.)
                noise = np.float32(np.random.randn(*(img_in_large.shape))) * noise_sigma
                img_in_large = img_in_large + noise
                img_in_large = np.clip(img_in_large, 0, 1)

            # jpeg
            if self.jpeg_range_large is not None:
                jpeg_p = np.random.uniform(self.jpeg_range_large[0], self.jpeg_range_large[1])
                encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_p]
                _, encimg = cv2.imencode('.jpg', img_in_large * 255., encode_param)
                img_in_large = np.float32(cv2.imdecode(encimg, 1)) / 255.

            # resize to in_size
            img_in_large = cv2.resize(img_in_large, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)

        # random color jitter (only for lq)
        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
            img_in = self.color_jitter(img_in, self.color_jitter_shift)
            img_in_large = self.color_jitter(img_in_large, self.color_jitter_shift)

        # random to gray (only for lq)
        if self.gray_prob and np.random.uniform() < self.gray_prob:
            img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
            img_in = np.tile(img_in[:, :, None], [1, 1, 3])
            img_in_large = cv2.cvtColor(img_in_large, cv2.COLOR_BGR2GRAY)
            img_in_large = np.tile(img_in_large[:, :, None], [1, 1, 3])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_in, img_in_large, img_gt = img2tensor([img_in, img_in_large, img_gt], bgr2rgb=True, float32=True)

        # random color jitter (pytorch version) (only for lq)
        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
            img_in_large = self.color_jitter_pt(img_in_large, brightness, contrast, saturation, hue)

        # round and clip (np.clip falls back to the tensors' own .clip method)
        img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255.
        img_in_large = np.clip((img_in_large * 255.0).round(), 0, 255) / 255.

        # normalize (set the VGG loss's range_norm=True if this normalization is used)
        normalize(img_in, self.mean, self.std, inplace=True)
        normalize(img_in_large, self.mean, self.std, inplace=True)
        normalize(img_gt, self.mean, self.std, inplace=True)

        return_dict = {'in': img_in, 'in_large_de': img_in_large, 'gt': img_gt, 'gt_path': gt_path}

        if self.crop_components:
            return_dict['locations_in'] = locations_in
            return_dict['locations_gt'] = locations_gt

        if self.load_latent_gt:
            return_dict['latent_gt'] = latent_gt

        return return_dict

    def __len__(self):
        return len(self.paths)
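

if __name__ == '__main__':
    # Minimal smoke-test sketch: the path and degradation values below are
    # illustrative placeholders (assumed, not the authors' training settings),
    # and 'dataroot_gt' must point at a real folder of 512x512 face images.
    opt = {
        'io_backend': {'type': 'disk'},
        'dataroot_gt': 'datasets/ffhq/ffhq_512',  # hypothetical path
        'use_hflip': True,
        'blur_kernel_size': 41,
        'kernel_list': ['iso', 'aniso'],
        'kernel_prob': [0.5, 0.5],
        # small degradation
        'blur_sigma': [0.1, 10],
        'downsample_range': [0.8, 8],
        'noise_range': [0, 20],
        'jpeg_range': [60, 100],
        # large degradation
        'blur_sigma_large': [10, 15],
        'downsample_range_large': [8, 16],
        'noise_range_large': [20, 40],
        'jpeg_range_large': [30, 60],
    }
    dataset = FFHQBlindJointDataset(opt)
    loader = data.DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    # 'in' is the mildly degraded input, 'in_large_de' the heavily degraded one,
    # and 'gt' the clean target; all are normalized CHW float tensors.
    print(batch['in'].shape, batch['in_large_de'].shape, batch['gt'].shape)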