data_sampler.py

import math

import torch
from torch.utils.data.sampler import Sampler


class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler.
    Supports enlarging the dataset for iteration-based training, saving
    time when restarting the dataloader after each epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # each rank draws num_samples indices from the enlarged dataset
        self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()

        # map enlarged indices back into the real dataset range
        dataset_size = len(self.dataset)
        indices = [v % dataset_size for v in indices]

        # subsample: each rank takes every num_replicas-th index
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
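

# Minimal usage sketch (illustrative only): the toy dataset, world size, and
# ratio below are assumptions, not part of the original module. With ratio > 1
# the sampler yields more indices per epoch than the dataset holds, wrapping
# around via the modulo step above.
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(10))  # toy dataset of 10 items
    sampler = EnlargedSampler(dataset, num_replicas=2, rank=0, ratio=3)

    # 10 samples * ratio 3 / 2 replicas = 15 indices for this rank
    print(len(sampler))

    loader = DataLoader(dataset, batch_size=4, sampler=sampler)
    for epoch in range(2):
        sampler.set_epoch(epoch)  # reshuffle deterministically each epoch
        for (batch,) in loader:
            pass  # training step would go here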