# basicsr/data/__init__.py
  1. import importlib
  2. import numpy as np
  3. import random
  4. import torch
  5. import torch.utils.data
  6. from copy import deepcopy
  7. from functools import partial
  8. from os import path as osp
  9. from basicsr.data.prefetch_dataloader import PrefetchDataLoader
  10. from basicsr.utils import get_root_logger, scandir
  11. from basicsr.utils.dist_util import get_dist_info
  12. from basicsr.utils.registry import DATASET_REGISTRY
  13. __all__ = ['build_dataset', 'build_dataloader']
# Automatically scan and import dataset modules at package-import time.
# Every file in this data folder whose name ends with '_dataset.py' is
# collected (by basename, without extension) and imported below.
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# Import all the dataset modules; presumably each module registers its dataset
# classes into DATASET_REGISTRY on import (TODO confirm against the modules).
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
  20. def build_dataset(dataset_opt):
  21. """Build dataset from options.
  22. Args:
  23. dataset_opt (dict): Configuration for dataset. It must constain:
  24. name (str): Dataset name.
  25. type (str): Dataset type.
  26. """
  27. dataset_opt = deepcopy(dataset_opt)
  28. dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
  29. logger = get_root_logger()
  30. logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
  31. return dataset
  32. def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
  33. """Build dataloader.
  34. Args:
  35. dataset (torch.utils.data.Dataset): Dataset.
  36. dataset_opt (dict): Dataset options. It contains the following keys:
  37. phase (str): 'train' or 'val'.
  38. num_worker_per_gpu (int): Number of workers for each GPU.
  39. batch_size_per_gpu (int): Training batch size for each GPU.
  40. num_gpu (int): Number of GPUs. Used only in the train phase.
  41. Default: 1.
  42. dist (bool): Whether in distributed training. Used only in the train
  43. phase. Default: False.
  44. sampler (torch.utils.data.sampler): Data sampler. Default: None.
  45. seed (int | None): Seed. Default: None
  46. """
  47. phase = dataset_opt['phase']
  48. rank, _ = get_dist_info()
  49. if phase == 'train':
  50. if dist: # distributed training
  51. batch_size = dataset_opt['batch_size_per_gpu']
  52. num_workers = dataset_opt['num_worker_per_gpu']
  53. else: # non-distributed training
  54. multiplier = 1 if num_gpu == 0 else num_gpu
  55. batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
  56. num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
  57. dataloader_args = dict(
  58. dataset=dataset,
  59. batch_size=batch_size,
  60. shuffle=False,
  61. num_workers=num_workers,
  62. sampler=sampler,
  63. drop_last=True)
  64. if sampler is None:
  65. dataloader_args['shuffle'] = True
  66. dataloader_args['worker_init_fn'] = partial(
  67. worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
  68. elif phase in ['val', 'test']: # validation
  69. dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
  70. else:
  71. raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
  72. dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
  73. prefetch_mode = dataset_opt.get('prefetch_mode')
  74. if prefetch_mode == 'cpu': # CPUPrefetcher
  75. num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
  76. logger = get_root_logger()
  77. logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
  78. return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
  79. else:
  80. # prefetch_mode=None: Normal dataloader
  81. # prefetch_mode='cuda': dataloader for CUDAPrefetcher
  82. return torch.utils.data.DataLoader(**dataloader_args)
  83. def worker_init_fn(worker_id, num_workers, rank, seed):
  84. # Set the worker seed to num_workers * rank + worker_id + seed
  85. worker_seed = num_workers * rank + worker_id + seed
  86. np.random.seed(worker_seed)
  87. random.seed(worker_seed)