123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125 |
- import queue as Queue
- import threading
- import torch
- from torch.utils.data import DataLoader
class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    A daemon thread drains ``generator`` into a bounded queue so that items are
    produced ahead of the consumer.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator (any iterable).
        num_prefetch_queue (int): Queue size passed to ``queue.Queue``, i.e. how
            many items the background thread may buffer ahead of the consumer.

    Items may be ``None``. An exception raised by ``generator`` is re-raised
    in the consuming thread, and iteration then ends.
    """

    # Private end-of-stream marker; unlike ``None`` it can never be a real item.
    _END = object()

    def __init__(self, generator, num_prefetch_queue):
        super().__init__(daemon=True)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self._prefetch_done = False  # set once the end marker has been consumed
        self._prefetch_error = None  # exception raised by ``generator``, if any
        self.start()

    def run(self):
        try:
            for item in self.generator:
                self.queue.put(item)
        except BaseException as err:  # forwarded to the consumer in __next__
            self._prefetch_error = err
        finally:
            # Always enqueue the end marker, otherwise a failing generator
            # would leave the consumer blocked forever on queue.get().
            self.queue.put(self._END)

    def __next__(self):
        if self._prefetch_done:
            # The end marker was already consumed; do not block on get() again.
            raise StopIteration
        next_item = self.queue.get()
        if next_item is self._END:
            self._prefetch_done = True
            if self._prefetch_error is not None:
                err, self._prefetch_error = self._prefetch_error, None
                raise err
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
class PrefetchDataLoader(DataLoader):
    """DataLoader whose iterator is wrapped in a background prefetch thread.

    Ref:
    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Queue size handed to ``PrefetchGenerator``.
        kwargs (dict): Remaining keyword arguments, forwarded unchanged to
            ``DataLoader.__init__``.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super().__init__(**kwargs)

    def __iter__(self):
        # Each call starts a fresh DataLoader iterator and a new prefetch thread.
        base_iterator = super().__iter__()
        return PrefetchGenerator(base_iterator, self.num_prefetch_queue)
class CPUPrefetcher:
    """Hand out batches from a loader one ``next()`` call at a time.

    Args:
        loader: Dataloader (any iterable; ``reset`` re-iterates it).
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        """Return the next batch, or ``None`` once the loader is exhausted."""
        return next(self.loader, None)

    def reset(self):
        """Start over with a fresh iterator over the original loader."""
        self.loader = iter(self.ori_loader)
class CUDAPrefetcher:
    """CUDA prefetcher.

    Ref:
    https://github.com/NVIDIA/apex/issues/304#

    Keeps one batch loaded ahead. When the target device is CUDA, its tensors
    are copied on a side CUDA stream. It may consume more GPU memory.

    Args:
        loader: Dataloader. Each item is expected to be a dict (``preload``
            calls ``.items()`` on it); tensor values are moved to the device.
        opt (dict): Options. ``opt['num_gpu']`` selects the device: ``'cuda'``
            when non-zero, otherwise ``'cpu'``.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        # Only create a side stream for the GPU path: torch.cuda.Stream() fails
        # on machines without CUDA, which would break the 'cpu' device choice.
        self.stream = torch.cuda.Stream() if self.device.type == 'cuda' else None
        self.preload()

    def _move_batch_to_device(self):
        """Replace every tensor value of ``self.batch`` with a copy on ``self.device``."""
        for k, v in self.batch.items():
            if torch.is_tensor(v):
                self.batch[k] = v.to(device=self.device, non_blocking=True)

    def preload(self):
        """Fetch the next batch into ``self.batch`` and start its device copy.

        ``self.batch`` becomes ``None`` once the loader is exhausted.
        """
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        if self.stream is None:
            self._move_batch_to_device()
        else:
            # Issue the copies on the side stream so they can overlap with work
            # queued on the current stream.
            with torch.cuda.stream(self.stream):
                self._move_batch_to_device()

    def next(self):
        """Return the prefetched batch (``None`` when exhausted) and preload the next one."""
        batch = self.batch
        if self.stream is not None:
            current = torch.cuda.current_stream()
            # Make the current stream wait for the side-stream copies.
            current.wait_stream(self.stream)
            if batch is not None:
                # These tensors were allocated on the side stream but will be
                # used on the current stream; record that use so the caching
                # allocator does not hand their memory out again too early.
                for v in batch.values():
                    if torch.is_tensor(v):
                        v.record_stream(current)
        self.preload()
        return batch

    def reset(self):
        """Restart from the beginning of the original loader and preload its first batch."""
        self.loader = iter(self.ori_loader)
        self.preload()
|