imagenet.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558
  1. import os, tarfile, glob, shutil
  2. import yaml
  3. import numpy as np
  4. from tqdm import tqdm
  5. from PIL import Image
  6. import albumentations
  7. from omegaconf import OmegaConf
  8. from torch.utils.data import Dataset
  9. from taming.data.base import ImagePaths
  10. from taming.util import download, retrieve
  11. import taming.data.utils as bdu
  12. def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
  13. synsets = []
  14. with open(path_to_yaml) as f:
  15. di2s = yaml.load(f)
  16. for idx in indices:
  17. synsets.append(str(di2s[idx]))
  18. print("Using {} different synsets for construction of Restriced Imagenet.".format(len(synsets)))
  19. return synsets
  20. def str_to_indices(string):
  21. """Expects a string in the format '32-123, 256, 280-321'"""
  22. assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string)
  23. subs = string.split(",")
  24. indices = []
  25. for sub in subs:
  26. subsubs = sub.split("-")
  27. assert len(subsubs) > 0
  28. if len(subsubs) == 1:
  29. indices.append(int(subsubs[0]))
  30. else:
  31. rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))]
  32. indices.extend(rang)
  33. return sorted(indices)
  34. class ImageNetBase(Dataset):
  35. def __init__(self, config=None):
  36. self.config = config or OmegaConf.create()
  37. if not type(self.config)==dict:
  38. self.config = OmegaConf.to_container(self.config)
  39. self._prepare()
  40. self._prepare_synset_to_human()
  41. self._prepare_idx_to_synset()
  42. self._load()
  43. def __len__(self):
  44. return len(self.data)
  45. def __getitem__(self, i):
  46. return self.data[i]
  47. def _prepare(self):
  48. raise NotImplementedError()
  49. def _filter_relpaths(self, relpaths):
  50. ignore = set([
  51. "n06596364_9591.JPEG",
  52. ])
  53. relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
  54. if "sub_indices" in self.config:
  55. indices = str_to_indices(self.config["sub_indices"])
  56. synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
  57. files = []
  58. for rpath in relpaths:
  59. syn = rpath.split("/")[0]
  60. if syn in synsets:
  61. files.append(rpath)
  62. return files
  63. else:
  64. return relpaths
  65. def _prepare_synset_to_human(self):
  66. SIZE = 2655750
  67. URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
  68. self.human_dict = os.path.join(self.root, "synset_human.txt")
  69. if (not os.path.exists(self.human_dict) or
  70. not os.path.getsize(self.human_dict)==SIZE):
  71. download(URL, self.human_dict)
  72. def _prepare_idx_to_synset(self):
  73. URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
  74. self.idx2syn = os.path.join(self.root, "index_synset.yaml")
  75. if (not os.path.exists(self.idx2syn)):
  76. download(URL, self.idx2syn)
  77. def _load(self):
  78. with open(self.txt_filelist, "r") as f:
  79. self.relpaths = f.read().splitlines()
  80. l1 = len(self.relpaths)
  81. self.relpaths = self._filter_relpaths(self.relpaths)
  82. print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
  83. self.synsets = [p.split("/")[0] for p in self.relpaths]
  84. self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
  85. unique_synsets = np.unique(self.synsets)
  86. class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
  87. self.class_labels = [class_dict[s] for s in self.synsets]
  88. with open(self.human_dict, "r") as f:
  89. human_dict = f.read().splitlines()
  90. human_dict = dict(line.split(maxsplit=1) for line in human_dict)
  91. self.human_labels = [human_dict[s] for s in self.synsets]
  92. labels = {
  93. "relpath": np.array(self.relpaths),
  94. "synsets": np.array(self.synsets),
  95. "class_label": np.array(self.class_labels),
  96. "human_label": np.array(self.human_labels),
  97. }
  98. self.data = ImagePaths(self.abspaths,
  99. labels=labels,
  100. size=retrieve(self.config, "size", default=0),
  101. random_crop=self.random_crop)
  102. class ImageNetTrain(ImageNetBase):
  103. NAME = "ILSVRC2012_train"
  104. URL = "http://www.image-net.org/challenges/LSVRC/2012/"
  105. AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
  106. FILES = [
  107. "ILSVRC2012_img_train.tar",
  108. ]
  109. SIZES = [
  110. 147897477120,
  111. ]
  112. def _prepare(self):
  113. self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
  114. default=True)
  115. cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
  116. self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
  117. self.datadir = os.path.join(self.root, "data")
  118. self.txt_filelist = os.path.join(self.root, "filelist.txt")
  119. self.expected_length = 1281167
  120. if not bdu.is_prepared(self.root):
  121. # prep
  122. print("Preparing dataset {} in {}".format(self.NAME, self.root))
  123. datadir = self.datadir
  124. if not os.path.exists(datadir):
  125. path = os.path.join(self.root, self.FILES[0])
  126. if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
  127. import academictorrents as at
  128. atpath = at.get(self.AT_HASH, datastore=self.root)
  129. assert atpath == path
  130. print("Extracting {} to {}".format(path, datadir))
  131. os.makedirs(datadir, exist_ok=True)
  132. with tarfile.open(path, "r:") as tar:
  133. tar.extractall(path=datadir)
  134. print("Extracting sub-tars.")
  135. subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
  136. for subpath in tqdm(subpaths):
  137. subdir = subpath[:-len(".tar")]
  138. os.makedirs(subdir, exist_ok=True)
  139. with tarfile.open(subpath, "r:") as tar:
  140. tar.extractall(path=subdir)
  141. filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
  142. filelist = [os.path.relpath(p, start=datadir) for p in filelist]
  143. filelist = sorted(filelist)
  144. filelist = "\n".join(filelist)+"\n"
  145. with open(self.txt_filelist, "w") as f:
  146. f.write(filelist)
  147. bdu.mark_prepared(self.root)
  148. class ImageNetValidation(ImageNetBase):
  149. NAME = "ILSVRC2012_validation"
  150. URL = "http://www.image-net.org/challenges/LSVRC/2012/"
  151. AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
  152. VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
  153. FILES = [
  154. "ILSVRC2012_img_val.tar",
  155. "validation_synset.txt",
  156. ]
  157. SIZES = [
  158. 6744924160,
  159. 1950000,
  160. ]
  161. def _prepare(self):
  162. self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
  163. default=False)
  164. cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
  165. self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
  166. self.datadir = os.path.join(self.root, "data")
  167. self.txt_filelist = os.path.join(self.root, "filelist.txt")
  168. self.expected_length = 50000
  169. if not bdu.is_prepared(self.root):
  170. # prep
  171. print("Preparing dataset {} in {}".format(self.NAME, self.root))
  172. datadir = self.datadir
  173. if not os.path.exists(datadir):
  174. path = os.path.join(self.root, self.FILES[0])
  175. if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
  176. import academictorrents as at
  177. atpath = at.get(self.AT_HASH, datastore=self.root)
  178. assert atpath == path
  179. print("Extracting {} to {}".format(path, datadir))
  180. os.makedirs(datadir, exist_ok=True)
  181. with tarfile.open(path, "r:") as tar:
  182. tar.extractall(path=datadir)
  183. vspath = os.path.join(self.root, self.FILES[1])
  184. if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
  185. download(self.VS_URL, vspath)
  186. with open(vspath, "r") as f:
  187. synset_dict = f.read().splitlines()
  188. synset_dict = dict(line.split() for line in synset_dict)
  189. print("Reorganizing into synset folders")
  190. synsets = np.unique(list(synset_dict.values()))
  191. for s in synsets:
  192. os.makedirs(os.path.join(datadir, s), exist_ok=True)
  193. for k, v in synset_dict.items():
  194. src = os.path.join(datadir, k)
  195. dst = os.path.join(datadir, v)
  196. shutil.move(src, dst)
  197. filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
  198. filelist = [os.path.relpath(p, start=datadir) for p in filelist]
  199. filelist = sorted(filelist)
  200. filelist = "\n".join(filelist)+"\n"
  201. with open(self.txt_filelist, "w") as f:
  202. f.write(filelist)
  203. bdu.mark_prepared(self.root)
  204. def get_preprocessor(size=None, random_crop=False, additional_targets=None,
  205. crop_size=None):
  206. if size is not None and size > 0:
  207. transforms = list()
  208. rescaler = albumentations.SmallestMaxSize(max_size = size)
  209. transforms.append(rescaler)
  210. if not random_crop:
  211. cropper = albumentations.CenterCrop(height=size,width=size)
  212. transforms.append(cropper)
  213. else:
  214. cropper = albumentations.RandomCrop(height=size,width=size)
  215. transforms.append(cropper)
  216. flipper = albumentations.HorizontalFlip()
  217. transforms.append(flipper)
  218. preprocessor = albumentations.Compose(transforms,
  219. additional_targets=additional_targets)
  220. elif crop_size is not None and crop_size > 0:
  221. if not random_crop:
  222. cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
  223. else:
  224. cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
  225. transforms = [cropper]
  226. preprocessor = albumentations.Compose(transforms,
  227. additional_targets=additional_targets)
  228. else:
  229. preprocessor = lambda **kwargs: kwargs
  230. return preprocessor
  231. def rgba_to_depth(x):
  232. assert x.dtype == np.uint8
  233. assert len(x.shape) == 3 and x.shape[2] == 4
  234. y = x.copy()
  235. y.dtype = np.float32
  236. y = y.reshape(x.shape[:2])
  237. return np.ascontiguousarray(y)
  238. class BaseWithDepth(Dataset):
  239. DEFAULT_DEPTH_ROOT="data/imagenet_depth"
  240. def __init__(self, config=None, size=None, random_crop=False,
  241. crop_size=None, root=None):
  242. self.config = config
  243. self.base_dset = self.get_base_dset()
  244. self.preprocessor = get_preprocessor(
  245. size=size,
  246. crop_size=crop_size,
  247. random_crop=random_crop,
  248. additional_targets={"depth": "image"})
  249. self.crop_size = crop_size
  250. if self.crop_size is not None:
  251. self.rescaler = albumentations.Compose(
  252. [albumentations.SmallestMaxSize(max_size = self.crop_size)],
  253. additional_targets={"depth": "image"})
  254. if root is not None:
  255. self.DEFAULT_DEPTH_ROOT = root
  256. def __len__(self):
  257. return len(self.base_dset)
  258. def preprocess_depth(self, path):
  259. rgba = np.array(Image.open(path))
  260. depth = rgba_to_depth(rgba)
  261. depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
  262. depth = 2.0*depth-1.0
  263. return depth
  264. def __getitem__(self, i):
  265. e = self.base_dset[i]
  266. e["depth"] = self.preprocess_depth(self.get_depth_path(e))
  267. # up if necessary
  268. h,w,c = e["image"].shape
  269. if self.crop_size and min(h,w) < self.crop_size:
  270. # have to upscale to be able to crop - this just uses bilinear
  271. out = self.rescaler(image=e["image"], depth=e["depth"])
  272. e["image"] = out["image"]
  273. e["depth"] = out["depth"]
  274. transformed = self.preprocessor(image=e["image"], depth=e["depth"])
  275. e["image"] = transformed["image"]
  276. e["depth"] = transformed["depth"]
  277. return e
  278. class ImageNetTrainWithDepth(BaseWithDepth):
  279. # default to random_crop=True
  280. def __init__(self, random_crop=True, sub_indices=None, **kwargs):
  281. self.sub_indices = sub_indices
  282. super().__init__(random_crop=random_crop, **kwargs)
  283. def get_base_dset(self):
  284. if self.sub_indices is None:
  285. return ImageNetTrain()
  286. else:
  287. return ImageNetTrain({"sub_indices": self.sub_indices})
  288. def get_depth_path(self, e):
  289. fid = os.path.splitext(e["relpath"])[0]+".png"
  290. fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid)
  291. return fid
  292. class ImageNetValidationWithDepth(BaseWithDepth):
  293. def __init__(self, sub_indices=None, **kwargs):
  294. self.sub_indices = sub_indices
  295. super().__init__(**kwargs)
  296. def get_base_dset(self):
  297. if self.sub_indices is None:
  298. return ImageNetValidation()
  299. else:
  300. return ImageNetValidation({"sub_indices": self.sub_indices})
  301. def get_depth_path(self, e):
  302. fid = os.path.splitext(e["relpath"])[0]+".png"
  303. fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid)
  304. return fid
  305. class RINTrainWithDepth(ImageNetTrainWithDepth):
  306. def __init__(self, config=None, size=None, random_crop=True, crop_size=None):
  307. sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
  308. super().__init__(config=config, size=size, random_crop=random_crop,
  309. sub_indices=sub_indices, crop_size=crop_size)
  310. class RINValidationWithDepth(ImageNetValidationWithDepth):
  311. def __init__(self, config=None, size=None, random_crop=False, crop_size=None):
  312. sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
  313. super().__init__(config=config, size=size, random_crop=random_crop,
  314. sub_indices=sub_indices, crop_size=crop_size)
  315. class DRINExamples(Dataset):
  316. def __init__(self):
  317. self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"})
  318. with open("data/drin_examples.txt", "r") as f:
  319. relpaths = f.read().splitlines()
  320. self.image_paths = [os.path.join("data/drin_images",
  321. relpath) for relpath in relpaths]
  322. self.depth_paths = [os.path.join("data/drin_depth",
  323. relpath.replace(".JPEG", ".png")) for relpath in relpaths]
  324. def __len__(self):
  325. return len(self.image_paths)
  326. def preprocess_image(self, image_path):
  327. image = Image.open(image_path)
  328. if not image.mode == "RGB":
  329. image = image.convert("RGB")
  330. image = np.array(image).astype(np.uint8)
  331. image = self.preprocessor(image=image)["image"]
  332. image = (image/127.5 - 1.0).astype(np.float32)
  333. return image
  334. def preprocess_depth(self, path):
  335. rgba = np.array(Image.open(path))
  336. depth = rgba_to_depth(rgba)
  337. depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
  338. depth = 2.0*depth-1.0
  339. return depth
  340. def __getitem__(self, i):
  341. e = dict()
  342. e["image"] = self.preprocess_image(self.image_paths[i])
  343. e["depth"] = self.preprocess_depth(self.depth_paths[i])
  344. transformed = self.preprocessor(image=e["image"], depth=e["depth"])
  345. e["image"] = transformed["image"]
  346. e["depth"] = transformed["depth"]
  347. return e
  348. def imscale(x, factor, keepshapes=False, keepmode="bicubic"):
  349. if factor is None or factor==1:
  350. return x
  351. dtype = x.dtype
  352. assert dtype in [np.float32, np.float64]
  353. assert x.min() >= -1
  354. assert x.max() <= 1
  355. keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR,
  356. "bicubic": Image.BICUBIC}[keepmode]
  357. lr = (x+1.0)*127.5
  358. lr = lr.clip(0,255).astype(np.uint8)
  359. lr = Image.fromarray(lr)
  360. h, w, _ = x.shape
  361. nh = h//factor
  362. nw = w//factor
  363. assert nh > 0 and nw > 0, (nh, nw)
  364. lr = lr.resize((nw,nh), Image.BICUBIC)
  365. if keepshapes:
  366. lr = lr.resize((w,h), keepmode)
  367. lr = np.array(lr)/127.5-1.0
  368. lr = lr.astype(dtype)
  369. return lr
  370. class ImageNetScale(Dataset):
  371. def __init__(self, size=None, crop_size=None, random_crop=False,
  372. up_factor=None, hr_factor=None, keep_mode="bicubic"):
  373. self.base = self.get_base()
  374. self.size = size
  375. self.crop_size = crop_size if crop_size is not None else self.size
  376. self.random_crop = random_crop
  377. self.up_factor = up_factor
  378. self.hr_factor = hr_factor
  379. self.keep_mode = keep_mode
  380. transforms = list()
  381. if self.size is not None and self.size > 0:
  382. rescaler = albumentations.SmallestMaxSize(max_size = self.size)
  383. self.rescaler = rescaler
  384. transforms.append(rescaler)
  385. if self.crop_size is not None and self.crop_size > 0:
  386. if len(transforms) == 0:
  387. self.rescaler = albumentations.SmallestMaxSize(max_size = self.crop_size)
  388. if not self.random_crop:
  389. cropper = albumentations.CenterCrop(height=self.crop_size,width=self.crop_size)
  390. else:
  391. cropper = albumentations.RandomCrop(height=self.crop_size,width=self.crop_size)
  392. transforms.append(cropper)
  393. if len(transforms) > 0:
  394. if self.up_factor is not None:
  395. additional_targets = {"lr": "image"}
  396. else:
  397. additional_targets = None
  398. self.preprocessor = albumentations.Compose(transforms,
  399. additional_targets=additional_targets)
  400. else:
  401. self.preprocessor = lambda **kwargs: kwargs
  402. def __len__(self):
  403. return len(self.base)
  404. def __getitem__(self, i):
  405. example = self.base[i]
  406. image = example["image"]
  407. # adjust resolution
  408. image = imscale(image, self.hr_factor, keepshapes=False)
  409. h,w,c = image.shape
  410. if self.crop_size and min(h,w) < self.crop_size:
  411. # have to upscale to be able to crop - this just uses bilinear
  412. image = self.rescaler(image=image)["image"]
  413. if self.up_factor is None:
  414. image = self.preprocessor(image=image)["image"]
  415. example["image"] = image
  416. else:
  417. lr = imscale(image, self.up_factor, keepshapes=True,
  418. keepmode=self.keep_mode)
  419. out = self.preprocessor(image=image, lr=lr)
  420. example["image"] = out["image"]
  421. example["lr"] = out["lr"]
  422. return example
  423. class ImageNetScaleTrain(ImageNetScale):
  424. def __init__(self, random_crop=True, **kwargs):
  425. super().__init__(random_crop=random_crop, **kwargs)
  426. def get_base(self):
  427. return ImageNetTrain()
  428. class ImageNetScaleValidation(ImageNetScale):
  429. def get_base(self):
  430. return ImageNetValidation()
  431. from skimage.feature import canny
  432. from skimage.color import rgb2gray
  433. class ImageNetEdges(ImageNetScale):
  434. def __init__(self, up_factor=1, **kwargs):
  435. super().__init__(up_factor=1, **kwargs)
  436. def __getitem__(self, i):
  437. example = self.base[i]
  438. image = example["image"]
  439. h,w,c = image.shape
  440. if self.crop_size and min(h,w) < self.crop_size:
  441. # have to upscale to be able to crop - this just uses bilinear
  442. image = self.rescaler(image=image)["image"]
  443. lr = canny(rgb2gray(image), sigma=2)
  444. lr = lr.astype(np.float32)
  445. lr = lr[:,:,None][:,:,[0,0,0]]
  446. out = self.preprocessor(image=image, lr=lr)
  447. example["image"] = out["image"]
  448. example["lr"] = out["lr"]
  449. return example
  450. class ImageNetEdgesTrain(ImageNetEdges):
  451. def __init__(self, random_crop=True, **kwargs):
  452. super().__init__(random_crop=random_crop, **kwargs)
  453. def get_base(self):
  454. return ImageNetTrain()
  455. class ImageNetEdgesValidation(ImageNetEdges):
  456. def get_base(self):
  457. return ImageNetValidation()