paired_image_dataset.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. from torch.utils import data as data
  2. from torchvision.transforms.functional import normalize
  3. from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
  4. from basicsr.data.transforms import augment, paired_random_crop
  5. from basicsr.utils import FileClient, imfrombytes, img2tensor
  6. from basicsr.utils.registry import DATASET_REGISTRY
  7. @DATASET_REGISTRY.register()
  8. class PairedImageDataset(data.Dataset):
  9. """Paired image dataset for image restoration.
  10. Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
  11. GT image pairs.
  12. There are three modes:
  13. 1. 'lmdb': Use lmdb files.
  14. If opt['io_backend'] == lmdb.
  15. 2. 'meta_info_file': Use meta information file to generate paths.
  16. If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
  17. 3. 'folder': Scan folders to generate paths.
  18. The rest.
  19. Args:
  20. opt (dict): Config for train datasets. It contains the following keys:
  21. dataroot_gt (str): Data root path for gt.
  22. dataroot_lq (str): Data root path for lq.
  23. meta_info_file (str): Path for meta information file.
  24. io_backend (dict): IO backend type and other kwarg.
  25. filename_tmpl (str): Template for each filename. Note that the
  26. template excludes the file extension. Default: '{}'.
  27. gt_size (int): Cropped patched size for gt patches.
  28. use_flip (bool): Use horizontal flips.
  29. use_rot (bool): Use rotation (use vertical flip and transposing h
  30. and w for implementation).
  31. scale (bool): Scale, which will be added automatically.
  32. phase (str): 'train' or 'val'.
  33. """
  34. def __init__(self, opt):
  35. super(PairedImageDataset, self).__init__()
  36. self.opt = opt
  37. # file client (io backend)
  38. self.file_client = None
  39. self.io_backend_opt = opt['io_backend']
  40. self.mean = opt['mean'] if 'mean' in opt else None
  41. self.std = opt['std'] if 'std' in opt else None
  42. self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
  43. if 'filename_tmpl' in opt:
  44. self.filename_tmpl = opt['filename_tmpl']
  45. else:
  46. self.filename_tmpl = '{}'
  47. if self.io_backend_opt['type'] == 'lmdb':
  48. self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
  49. self.io_backend_opt['client_keys'] = ['lq', 'gt']
  50. self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
  51. elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
  52. self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
  53. self.opt['meta_info_file'], self.filename_tmpl)
  54. else:
  55. self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
  56. def __getitem__(self, index):
  57. if self.file_client is None:
  58. self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
  59. scale = self.opt['scale']
  60. # Load gt and lq images. Dimension order: HWC; channel order: BGR;
  61. # image range: [0, 1], float32.
  62. gt_path = self.paths[index]['gt_path']
  63. img_bytes = self.file_client.get(gt_path, 'gt')
  64. img_gt = imfrombytes(img_bytes, float32=True)
  65. lq_path = self.paths[index]['lq_path']
  66. img_bytes = self.file_client.get(lq_path, 'lq')
  67. img_lq = imfrombytes(img_bytes, float32=True)
  68. # augmentation for training
  69. if self.opt['phase'] == 'train':
  70. gt_size = self.opt['gt_size']
  71. # random crop
  72. img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
  73. # flip, rotation
  74. img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])
  75. # TODO: color space transform
  76. # BGR to RGB, HWC to CHW, numpy to tensor
  77. img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
  78. # normalize
  79. if self.mean is not None or self.std is not None:
  80. normalize(img_lq, self.mean, self.std, inplace=True)
  81. normalize(img_gt, self.mean, self.std, inplace=True)
  82. return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
  83. def __len__(self):
  84. return len(self.paths)