faceshq.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import os
  2. import numpy as np
  3. import albumentations
  4. from torch.utils.data import Dataset
  5. from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
  class FacesBase(Dataset):
      """Common base for the face datasets below.

      Subclasses populate ``self.data`` with an indexable collection of
      dict-like examples and ``self.keys`` with either ``None`` or an iterable
      of key names.  Indexing returns the underlying example unchanged when
      ``keys`` is ``None``; otherwise a new dict restricted to those keys, in
      the order given by ``keys``.
      """

      def __init__(self, *args, **kwargs):
          # Positional/keyword arguments are accepted but not used here.
          super().__init__()
          self.data = None  # filled in by subclasses
          self.keys = None  # None means "return every key of the example"

      def __len__(self):
          return len(self.data)

      def __getitem__(self, i):
          item = self.data[i]
          if self.keys is None:
              # No filtering requested: hand back the original object itself.
              return item
          return {name: item[name] for name in self.keys}
  class CelebAHQTrain(FacesBase):
      """CelebA-HQ training split served through ``NumpyPaths``.

      Relative file paths are read, one per line, from
      ``data/celebahqtrain.txt`` and joined onto ``data/celebahq``.

      Args:
          size: Forwarded unchanged to ``NumpyPaths`` as its ``size``.
          keys: Optional subset of example keys to return; ``None`` returns
              the full example (see ``FacesBase.__getitem__``).
      """

      def __init__(self, size, keys=None):
          super().__init__()
          root = "data/celebahq"
          # Explicit encoding so the split file decodes the same on every platform.
          with open("data/celebahqtrain.txt", "r", encoding="utf-8") as f:
              # Skip blank lines: os.path.join(root, "") would produce the bare
              # directory path, which cannot be loaded as an example.
              relpaths = [line for line in f.read().splitlines() if line.strip()]
          paths = [os.path.join(root, relpath) for relpath in relpaths]
          # NOTE(review): NumpyPaths presumably loads .npy arrays; random_crop=False
          # is passed through as-is — confirm its cropping semantics in taming.data.base.
          self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
          self.keys = keys
  class CelebAHQValidation(FacesBase):
      """CelebA-HQ validation split served through ``NumpyPaths``.

      Relative file paths are read, one per line, from
      ``data/celebahqvalidation.txt`` and joined onto ``data/celebahq``.

      Args:
          size: Forwarded unchanged to ``NumpyPaths`` as its ``size``.
          keys: Optional subset of example keys to return; ``None`` returns
              the full example (see ``FacesBase.__getitem__``).
      """

      def __init__(self, size, keys=None):
          super().__init__()
          root = "data/celebahq"
          # Explicit encoding so the split file decodes the same on every platform.
          with open("data/celebahqvalidation.txt", "r", encoding="utf-8") as f:
              # Skip blank lines: os.path.join(root, "") would produce the bare
              # directory path, which cannot be loaded as an example.
              relpaths = [line for line in f.read().splitlines() if line.strip()]
          paths = [os.path.join(root, relpath) for relpath in relpaths]
          # NOTE(review): NumpyPaths presumably loads .npy arrays; random_crop=False
          # is passed through as-is — confirm its cropping semantics in taming.data.base.
          self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
          self.keys = keys
  class FFHQTrain(FacesBase):
      """FFHQ training split served through ``ImagePaths``.

      Relative file paths are read, one per line, from ``data/ffhqtrain.txt``
      and joined onto ``data/ffhq``.

      Args:
          size: Forwarded unchanged to ``ImagePaths`` as its ``size``.
          keys: Optional subset of example keys to return; ``None`` returns
              the full example (see ``FacesBase.__getitem__``).
      """

      def __init__(self, size, keys=None):
          super().__init__()
          root = "data/ffhq"
          # Explicit encoding so the split file decodes the same on every platform.
          with open("data/ffhqtrain.txt", "r", encoding="utf-8") as f:
              # Skip blank lines: os.path.join(root, "") would produce the bare
              # directory path, which cannot be loaded as an example.
              relpaths = [line for line in f.read().splitlines() if line.strip()]
          paths = [os.path.join(root, relpath) for relpath in relpaths]
          # NOTE(review): ImagePaths presumably decodes regular image files;
          # random_crop=False is passed through as-is — confirm in taming.data.base.
          self.data = ImagePaths(paths=paths, size=size, random_crop=False)
          self.keys = keys
  class FFHQValidation(FacesBase):
      """FFHQ validation split served through ``ImagePaths``.

      Relative file paths are read, one per line, from
      ``data/ffhqvalidation.txt`` and joined onto ``data/ffhq``.

      Args:
          size: Forwarded unchanged to ``ImagePaths`` as its ``size``.
          keys: Optional subset of example keys to return; ``None`` returns
              the full example (see ``FacesBase.__getitem__``).
      """

      def __init__(self, size, keys=None):
          super().__init__()
          root = "data/ffhq"
          # Explicit encoding so the split file decodes the same on every platform.
          with open("data/ffhqvalidation.txt", "r", encoding="utf-8") as f:
              # Skip blank lines: os.path.join(root, "") would produce the bare
              # directory path, which cannot be loaded as an example.
              relpaths = [line for line in f.read().splitlines() if line.strip()]
          paths = [os.path.join(root, relpath) for relpath in relpaths]
          # NOTE(review): ImagePaths presumably decodes regular image files;
          # random_crop=False is passed through as-is — confirm in taming.data.base.
          self.data = ImagePaths(paths=paths, size=size, random_crop=False)
          self.keys = keys
  class FacesHQTrain(Dataset):
      """Training set concatenating CelebAHQ [0] + FFHQ [1].

      Each item is the example dict from the underlying dataset with
      ``"class"`` set to the index reported by ``ConcatDatasetWithIndex``
      (0 for CelebAHQ, 1 for FFHQ, per construction order).

      Args:
          size: Forwarded to both ``CelebAHQTrain`` and ``FFHQTrain``.
          keys: Forwarded to both sub-datasets to restrict example keys.
          crop_size: If given, ``"image"`` is randomly cropped to
              ``crop_size x crop_size`` with ``albumentations.RandomCrop``.
          coord: Only has an effect together with ``crop_size``.  When set, a
              ``(h, w, 1)`` map of row-major pixel indices divided by ``h*w``
              (values in ``[0, 1)``) is cropped with the same window as the
              image and returned under ``"coord"``.
      """

      def __init__(self, size, keys=None, crop_size=None, coord=False):
          super().__init__()
          d1 = CelebAHQTrain(size=size, keys=keys)
          d2 = FFHQTrain(size=size, keys=keys)
          self.data = ConcatDatasetWithIndex([d1, d2])
          self.coord = coord
          if crop_size is not None:
              cropper = albumentations.RandomCrop(height=crop_size, width=crop_size)
              if self.coord:
                  # Register "coord" as an image-type target so Compose applies
                  # the identical crop window to it and to "image".
                  cropper = albumentations.Compose(
                      [cropper], additional_targets={"coord": "image"}
                  )
              # Attribute only exists when cropping is enabled.
              self.cropper = cropper

      def __len__(self):
          return len(self.data)

      def __getitem__(self, i):
          ex, y = self.data[i]
          cropper = getattr(self, "cropper", None)
          if cropper is not None:
              if self.coord:
                  h, w, _ = ex["image"].shape
                  # Normalized row-major pixel index, shape (h, w, 1).
                  coord = np.arange(h * w).reshape(h, w, 1) / (h * w)
                  out = cropper(image=ex["image"], coord=coord)
                  ex["coord"] = out["coord"]
              else:
                  out = cropper(image=ex["image"])
              ex["image"] = out["image"]
          ex["class"] = y
          return ex
  class FacesHQValidation(Dataset):
      """Validation set concatenating CelebAHQ [0] + FFHQ [1].

      Each item is the example dict from the underlying dataset with
      ``"class"`` set to the index reported by ``ConcatDatasetWithIndex``
      (0 for CelebAHQ, 1 for FFHQ, per construction order).

      Args:
          size: Forwarded to both ``CelebAHQValidation`` and ``FFHQValidation``.
          keys: Forwarded to both sub-datasets to restrict example keys.
          crop_size: If given, ``"image"`` is center-cropped to
              ``crop_size x crop_size`` with ``albumentations.CenterCrop``.
          coord: Only has an effect together with ``crop_size``.  When set, a
              ``(h, w, 1)`` map of row-major pixel indices divided by ``h*w``
              (values in ``[0, 1)``) is cropped with the same window as the
              image and returned under ``"coord"``.
      """

      def __init__(self, size, keys=None, crop_size=None, coord=False):
          super().__init__()
          d1 = CelebAHQValidation(size=size, keys=keys)
          d2 = FFHQValidation(size=size, keys=keys)
          self.data = ConcatDatasetWithIndex([d1, d2])
          self.coord = coord
          if crop_size is not None:
              cropper = albumentations.CenterCrop(height=crop_size, width=crop_size)
              if self.coord:
                  # Register "coord" as an image-type target so Compose applies
                  # the identical crop window to it and to "image".
                  cropper = albumentations.Compose(
                      [cropper], additional_targets={"coord": "image"}
                  )
              # Attribute only exists when cropping is enabled.
              self.cropper = cropper

      def __len__(self):
          return len(self.data)

      def __getitem__(self, i):
          ex, y = self.data[i]
          cropper = getattr(self, "cropper", None)
          if cropper is not None:
              if self.coord:
                  h, w, _ = ex["image"].shape
                  # Normalized row-major pixel index, shape (h, w, 1).
                  coord = np.arange(h * w).reshape(h, w, 1) / (h * w)
                  out = cropper(image=ex["image"], coord=coord)
                  ex["coord"] = out["coord"]
              else:
                  out = cropper(image=ex["image"])
              ex["image"] = out["image"]
          ex["class"] = y
          return ex