# data_util.py
  1. import cv2
  2. import math
  3. import numpy as np
  4. import torch
  5. from os import path as osp
  6. from PIL import Image, ImageDraw
  7. from torch.nn import functional as F
  8. from basicsr.data.transforms import mod_crop
  9. from basicsr.utils import img2tensor, scandir
  10. def read_img_seq(path, require_mod_crop=False, scale=1):
  11. """Read a sequence of images from a given folder path.
  12. Args:
  13. path (list[str] | str): List of image paths or image folder path.
  14. require_mod_crop (bool): Require mod crop for each image.
  15. Default: False.
  16. scale (int): Scale factor for mod_crop. Default: 1.
  17. Returns:
  18. Tensor: size (t, c, h, w), RGB, [0, 1].
  19. """
  20. if isinstance(path, list):
  21. img_paths = path
  22. else:
  23. img_paths = sorted(list(scandir(path, full_path=True)))
  24. imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
  25. if require_mod_crop:
  26. imgs = [mod_crop(img, scale) for img in imgs]
  27. imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
  28. imgs = torch.stack(imgs, dim=0)
  29. return imgs
  30. def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
  31. """Generate an index list for reading `num_frames` frames from a sequence
  32. of images.
  33. Args:
  34. crt_idx (int): Current center index.
  35. max_frame_num (int): Max number of the sequence of images (from 1).
  36. num_frames (int): Reading num_frames frames.
  37. padding (str): Padding mode, one of
  38. 'replicate' | 'reflection' | 'reflection_circle' | 'circle'
  39. Examples: current_idx = 0, num_frames = 5
  40. The generated frame indices under different padding mode:
  41. replicate: [0, 0, 0, 1, 2]
  42. reflection: [2, 1, 0, 1, 2]
  43. reflection_circle: [4, 3, 0, 1, 2]
  44. circle: [3, 4, 0, 1, 2]
  45. Returns:
  46. list[int]: A list of indices.
  47. """
  48. assert num_frames % 2 == 1, 'num_frames should be an odd number.'
  49. assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
  50. max_frame_num = max_frame_num - 1 # start from 0
  51. num_pad = num_frames // 2
  52. indices = []
  53. for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
  54. if i < 0:
  55. if padding == 'replicate':
  56. pad_idx = 0
  57. elif padding == 'reflection':
  58. pad_idx = -i
  59. elif padding == 'reflection_circle':
  60. pad_idx = crt_idx + num_pad - i
  61. else:
  62. pad_idx = num_frames + i
  63. elif i > max_frame_num:
  64. if padding == 'replicate':
  65. pad_idx = max_frame_num
  66. elif padding == 'reflection':
  67. pad_idx = max_frame_num * 2 - i
  68. elif padding == 'reflection_circle':
  69. pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
  70. else:
  71. pad_idx = i - num_frames
  72. else:
  73. pad_idx = i
  74. indices.append(pad_idx)
  75. return indices
  76. def paired_paths_from_lmdb(folders, keys):
  77. """Generate paired paths from lmdb files.
  78. Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
  79. lq.lmdb
  80. ├── data.mdb
  81. ├── lock.mdb
  82. ├── meta_info.txt
  83. The data.mdb and lock.mdb are standard lmdb files and you can refer to
  84. https://lmdb.readthedocs.io/en/release/ for more details.
  85. The meta_info.txt is a specified txt file to record the meta information
  86. of our datasets. It will be automatically created when preparing
  87. datasets by our provided dataset tools.
  88. Each line in the txt file records
  89. 1)image name (with extension),
  90. 2)image shape,
  91. 3)compression level, separated by a white space.
  92. Example: `baboon.png (120,125,3) 1`
  93. We use the image name without extension as the lmdb key.
  94. Note that we use the same key for the corresponding lq and gt images.
  95. Args:
  96. folders (list[str]): A list of folder path. The order of list should
  97. be [input_folder, gt_folder].
  98. keys (list[str]): A list of keys identifying folders. The order should
  99. be in consistent with folders, e.g., ['lq', 'gt'].
  100. Note that this key is different from lmdb keys.
  101. Returns:
  102. list[str]: Returned path list.
  103. """
  104. assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
  105. f'But got {len(folders)}')
  106. assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
  107. input_folder, gt_folder = folders
  108. input_key, gt_key = keys
  109. if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
  110. raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
  111. f'formats. But received {input_key}: {input_folder}; '
  112. f'{gt_key}: {gt_folder}')
  113. # ensure that the two meta_info files are the same
  114. with open(osp.join(input_folder, 'meta_info.txt')) as fin:
  115. input_lmdb_keys = [line.split('.')[0] for line in fin]
  116. with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
  117. gt_lmdb_keys = [line.split('.')[0] for line in fin]
  118. if set(input_lmdb_keys) != set(gt_lmdb_keys):
  119. raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
  120. else:
  121. paths = []
  122. for lmdb_key in sorted(input_lmdb_keys):
  123. paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
  124. return paths
  125. def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
  126. """Generate paired paths from an meta information file.
  127. Each line in the meta information file contains the image names and
  128. image shape (usually for gt), separated by a white space.
  129. Example of an meta information file:
  130. ```
  131. 0001_s001.png (480,480,3)
  132. 0001_s002.png (480,480,3)
  133. ```
  134. Args:
  135. folders (list[str]): A list of folder path. The order of list should
  136. be [input_folder, gt_folder].
  137. keys (list[str]): A list of keys identifying folders. The order should
  138. be in consistent with folders, e.g., ['lq', 'gt'].
  139. meta_info_file (str): Path to the meta information file.
  140. filename_tmpl (str): Template for each filename. Note that the
  141. template excludes the file extension. Usually the filename_tmpl is
  142. for files in the input folder.
  143. Returns:
  144. list[str]: Returned path list.
  145. """
  146. assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
  147. f'But got {len(folders)}')
  148. assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
  149. input_folder, gt_folder = folders
  150. input_key, gt_key = keys
  151. with open(meta_info_file, 'r') as fin:
  152. gt_names = [line.split(' ')[0] for line in fin]
  153. paths = []
  154. for gt_name in gt_names:
  155. basename, ext = osp.splitext(osp.basename(gt_name))
  156. input_name = f'{filename_tmpl.format(basename)}{ext}'
  157. input_path = osp.join(input_folder, input_name)
  158. gt_path = osp.join(gt_folder, gt_name)
  159. paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
  160. return paths
  161. def paired_paths_from_folder(folders, keys, filename_tmpl):
  162. """Generate paired paths from folders.
  163. Args:
  164. folders (list[str]): A list of folder path. The order of list should
  165. be [input_folder, gt_folder].
  166. keys (list[str]): A list of keys identifying folders. The order should
  167. be in consistent with folders, e.g., ['lq', 'gt'].
  168. filename_tmpl (str): Template for each filename. Note that the
  169. template excludes the file extension. Usually the filename_tmpl is
  170. for files in the input folder.
  171. Returns:
  172. list[str]: Returned path list.
  173. """
  174. assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
  175. f'But got {len(folders)}')
  176. assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
  177. input_folder, gt_folder = folders
  178. input_key, gt_key = keys
  179. input_paths = list(scandir(input_folder))
  180. gt_paths = list(scandir(gt_folder))
  181. assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
  182. f'{len(input_paths)}, {len(gt_paths)}.')
  183. paths = []
  184. for gt_path in gt_paths:
  185. basename, ext = osp.splitext(osp.basename(gt_path))
  186. input_name = f'{filename_tmpl.format(basename)}{ext}'
  187. input_path = osp.join(input_folder, input_name)
  188. assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.')
  189. gt_path = osp.join(gt_folder, gt_path)
  190. paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
  191. return paths
  192. def paths_from_folder(folder):
  193. """Generate paths from folder.
  194. Args:
  195. folder (str): Folder path.
  196. Returns:
  197. list[str]: Returned path list.
  198. """
  199. paths = list(scandir(folder))
  200. paths = [osp.join(folder, path) for path in paths]
  201. return paths
  202. def paths_from_lmdb(folder):
  203. """Generate paths from lmdb.
  204. Args:
  205. folder (str): Folder path.
  206. Returns:
  207. list[str]: Returned path list.
  208. """
  209. if not folder.endswith('.lmdb'):
  210. raise ValueError(f'Folder {folder}folder should in lmdb format.')
  211. with open(osp.join(folder, 'meta_info.txt')) as fin:
  212. paths = [line.split('.')[0] for line in fin]
  213. return paths
  214. def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
  215. """Generate Gaussian kernel used in `duf_downsample`.
  216. Args:
  217. kernel_size (int): Kernel size. Default: 13.
  218. sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
  219. Returns:
  220. np.array: The Gaussian kernel.
  221. """
  222. from scipy.ndimage import filters as filters
  223. kernel = np.zeros((kernel_size, kernel_size))
  224. # set element at the middle to one, a dirac delta
  225. kernel[kernel_size // 2, kernel_size // 2] = 1
  226. # gaussian-smooth the dirac, resulting in a gaussian filter
  227. return filters.gaussian_filter(kernel, sigma)
  228. def duf_downsample(x, kernel_size=13, scale=4):
  229. """Downsamping with Gaussian kernel used in the DUF official code.
  230. Args:
  231. x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
  232. kernel_size (int): Kernel size. Default: 13.
  233. scale (int): Downsampling factor. Supported scale: (2, 3, 4).
  234. Default: 4.
  235. Returns:
  236. Tensor: DUF downsampled frames.
  237. """
  238. assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
  239. squeeze_flag = False
  240. if x.ndim == 4:
  241. squeeze_flag = True
  242. x = x.unsqueeze(0)
  243. b, t, c, h, w = x.size()
  244. x = x.view(-1, 1, h, w)
  245. pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
  246. x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
  247. gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
  248. gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
  249. x = F.conv2d(x, gaussian_filter, stride=scale)
  250. x = x[:, :, 2:-2, 2:-2]
  251. x = x.view(b, t, c, x.size(2), x.size(3))
  252. if squeeze_flag:
  253. x = x.squeeze(0)
  254. return x
def brush_stroke_mask(img, color=(255,255,255)):
    """Draw random thick brush strokes onto a PIL image and return it.

    Each stroke is a polyline of random vertices joined by thick line
    segments, with a disc stamped at every vertex so joints are rounded.

    Args:
        img: PIL Image to draw on (it is modified in place and returned).
        color (tuple[int]): RGB fill color of the strokes.
            Default: (255, 255, 255).

    Returns:
        The input PIL Image with 1-3 random strokes drawn onto it.
    """
    min_num_vertex = 8
    max_num_vertex = 28
    mean_angle = 2*math.pi / 5
    angle_range = 2*math.pi / 12
    # training large mask ratio (training setting)
    min_width = 30
    max_width = 70
    # very large mask ratio (test setting and refine after 200k)
    # min_width = 80
    # max_width = 120
    def generate_mask(H, W, img=None):
        # Stroke step length is scaled to the image diagonal.
        average_radius = math.sqrt(H*H+W*W) / 8
        mask = Image.new('RGB', (W, H), 0)
        if img is not None: mask = img # Image.fromarray(img)
        # 1 to 3 independent strokes.
        for _ in range(np.random.randint(1, 4)):
            num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
            angle_min = mean_angle - np.random.uniform(0, angle_range)
            angle_max = mean_angle + np.random.uniform(0, angle_range)
            angles = []
            vertex = []
            # Alternate turn direction so the stroke zig-zags rather than spirals.
            for i in range(num_vertex):
                if i % 2 == 0:
                    angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
                else:
                    angles.append(np.random.uniform(angle_min, angle_max))
            # NOTE(review): PIL's Image.size is (width, height), so this binds
            # h=width and w=height — an upstream quirk kept as-is; the bounds
            # below are therefore swapped relative to their names. Verify
            # against the upstream generator before changing.
            h, w = mask.size
            vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
            # Walk from the start point: each segment has a clipped-normal
            # length and the pre-drawn turn angle.
            for i in range(num_vertex):
                r = np.clip(
                    np.random.normal(loc=average_radius, scale=average_radius//2),
                    0, 2*average_radius)
                new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
                new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
                vertex.append((int(new_x), int(new_y)))
            draw = ImageDraw.Draw(mask)
            width = int(np.random.uniform(min_width, max_width))
            draw.line(vertex, fill=color, width=width)
            # Stamp a disc at every vertex so segment joints are rounded.
            for v in vertex:
                draw.ellipse((v[0] - width//2,
                              v[1] - width//2,
                              v[0] + width//2,
                              v[1] + width//2),
                             fill=color)
        return mask
    width, height = img.size
    mask = generate_mask(height, width, img)
    return mask
  303. def random_ff_mask(shape, max_angle = 10, max_len = 100, max_width = 70, times = 10):
  304. """Generate a random free form mask with configuration.
  305. Args:
  306. config: Config should have configuration including IMG_SHAPES,
  307. VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.
  308. Returns:
  309. tuple: (top, left, height, width)
  310. Link:
  311. https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py
  312. """
  313. height = shape[0]
  314. width = shape[1]
  315. mask = np.zeros((height, width), np.float32)
  316. times = np.random.randint(times-5, times)
  317. for i in range(times):
  318. start_x = np.random.randint(width)
  319. start_y = np.random.randint(height)
  320. for j in range(1 + np.random.randint(5)):
  321. angle = 0.01 + np.random.randint(max_angle)
  322. if i % 2 == 0:
  323. angle = 2 * 3.1415926 - angle
  324. length = 10 + np.random.randint(max_len-20, max_len)
  325. brush_w = 5 + np.random.randint(max_width-30, max_width)
  326. end_x = (start_x + length * np.sin(angle)).astype(np.int32)
  327. end_y = (start_y + length * np.cos(angle)).astype(np.int32)
  328. cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
  329. start_x, start_y = end_x, end_y
  330. return mask.astype(np.float32)