prefetch_dataloader.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import queue as Queue
  2. import threading
  3. import torch
  4. from torch.utils.data import DataLoader
  5. class PrefetchGenerator(threading.Thread):
  6. """A general prefetch generator.
  7. Ref:
  8. https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
  9. Args:
  10. generator: Python generator.
  11. num_prefetch_queue (int): Number of prefetch queue.
  12. """
  13. def __init__(self, generator, num_prefetch_queue):
  14. threading.Thread.__init__(self)
  15. self.queue = Queue.Queue(num_prefetch_queue)
  16. self.generator = generator
  17. self.daemon = True
  18. self.start()
  19. def run(self):
  20. for item in self.generator:
  21. self.queue.put(item)
  22. self.queue.put(None)
  23. def __next__(self):
  24. next_item = self.queue.get()
  25. if next_item is None:
  26. raise StopIteration
  27. return next_item
  28. def __iter__(self):
  29. return self
  30. class PrefetchDataLoader(DataLoader):
  31. """Prefetch version of dataloader.
  32. Ref:
  33. https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
  34. TODO:
  35. Need to test on single gpu and ddp (multi-gpu). There is a known issue in
  36. ddp.
  37. Args:
  38. num_prefetch_queue (int): Number of prefetch queue.
  39. kwargs (dict): Other arguments for dataloader.
  40. """
  41. def __init__(self, num_prefetch_queue, **kwargs):
  42. self.num_prefetch_queue = num_prefetch_queue
  43. super(PrefetchDataLoader, self).__init__(**kwargs)
  44. def __iter__(self):
  45. return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
  46. class CPUPrefetcher():
  47. """CPU prefetcher.
  48. Args:
  49. loader: Dataloader.
  50. """
  51. def __init__(self, loader):
  52. self.ori_loader = loader
  53. self.loader = iter(loader)
  54. def next(self):
  55. try:
  56. return next(self.loader)
  57. except StopIteration:
  58. return None
  59. def reset(self):
  60. self.loader = iter(self.ori_loader)
  61. class CUDAPrefetcher():
  62. """CUDA prefetcher.
  63. Ref:
  64. https://github.com/NVIDIA/apex/issues/304#
  65. It may consums more GPU memory.
  66. Args:
  67. loader: Dataloader.
  68. opt (dict): Options.
  69. """
  70. def __init__(self, loader, opt):
  71. self.ori_loader = loader
  72. self.loader = iter(loader)
  73. self.opt = opt
  74. self.stream = torch.cuda.Stream()
  75. self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
  76. self.preload()
  77. def preload(self):
  78. try:
  79. self.batch = next(self.loader) # self.batch is a dict
  80. except StopIteration:
  81. self.batch = None
  82. return None
  83. # put tensors to gpu
  84. with torch.cuda.stream(self.stream):
  85. for k, v in self.batch.items():
  86. if torch.is_tensor(v):
  87. self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
  88. def next(self):
  89. torch.cuda.current_stream().wait_stream(self.stream)
  90. batch = self.batch
  91. self.preload()
  92. return batch
  93. def reset(self):
  94. self.loader = iter(self.ori_loader)
  95. self.preload()