make_samples.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. import argparse, os, sys, glob, math, time
  2. import torch
  3. import numpy as np
  4. from omegaconf import OmegaConf
  5. from PIL import Image
  6. from main import instantiate_from_config, DataModuleFromConfig
  7. from torch.utils.data import DataLoader
  8. from torch.utils.data.dataloader import default_collate
  9. from tqdm import trange
  10. def save_image(x, path):
  11. c,h,w = x.shape
  12. assert c==3
  13. x = ((x.detach().cpu().numpy().transpose(1,2,0)+1.0)*127.5).clip(0,255).astype(np.uint8)
  14. Image.fromarray(x).save(path)
  15. @torch.no_grad()
  16. def run_conditional(model, dsets, outdir, top_k, temperature, batch_size=1):
  17. if len(dsets.datasets) > 1:
  18. split = sorted(dsets.datasets.keys())[0]
  19. dset = dsets.datasets[split]
  20. else:
  21. dset = next(iter(dsets.datasets.values()))
  22. print("Dataset: ", dset.__class__.__name__)
  23. for start_idx in trange(0,len(dset)-batch_size+1,batch_size):
  24. indices = list(range(start_idx, start_idx+batch_size))
  25. example = default_collate([dset[i] for i in indices])
  26. x = model.get_input("image", example).to(model.device)
  27. for i in range(x.shape[0]):
  28. save_image(x[i], os.path.join(outdir, "originals",
  29. "{:06}.png".format(indices[i])))
  30. cond_key = model.cond_stage_key
  31. c = model.get_input(cond_key, example).to(model.device)
  32. scale_factor = 1.0
  33. quant_z, z_indices = model.encode_to_z(x)
  34. quant_c, c_indices = model.encode_to_c(c)
  35. cshape = quant_z.shape
  36. xrec = model.first_stage_model.decode(quant_z)
  37. for i in range(xrec.shape[0]):
  38. save_image(xrec[i], os.path.join(outdir, "reconstructions",
  39. "{:06}.png".format(indices[i])))
  40. if cond_key == "segmentation":
  41. # get image from segmentation mask
  42. num_classes = c.shape[1]
  43. c = torch.argmax(c, dim=1, keepdim=True)
  44. c = torch.nn.functional.one_hot(c, num_classes=num_classes)
  45. c = c.squeeze(1).permute(0, 3, 1, 2).float()
  46. c = model.cond_stage_model.to_rgb(c)
  47. idx = z_indices
  48. half_sample = False
  49. if half_sample:
  50. start = idx.shape[1]//2
  51. else:
  52. start = 0
  53. idx[:,start:] = 0
  54. idx = idx.reshape(cshape[0],cshape[2],cshape[3])
  55. start_i = start//cshape[3]
  56. start_j = start %cshape[3]
  57. cidx = c_indices
  58. cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
  59. sample = True
  60. for i in range(start_i,cshape[2]-0):
  61. if i <= 8:
  62. local_i = i
  63. elif cshape[2]-i < 8:
  64. local_i = 16-(cshape[2]-i)
  65. else:
  66. local_i = 8
  67. for j in range(start_j,cshape[3]-0):
  68. if j <= 8:
  69. local_j = j
  70. elif cshape[3]-j < 8:
  71. local_j = 16-(cshape[3]-j)
  72. else:
  73. local_j = 8
  74. i_start = i-local_i
  75. i_end = i_start+16
  76. j_start = j-local_j
  77. j_end = j_start+16
  78. patch = idx[:,i_start:i_end,j_start:j_end]
  79. patch = patch.reshape(patch.shape[0],-1)
  80. cpatch = cidx[:, i_start:i_end, j_start:j_end]
  81. cpatch = cpatch.reshape(cpatch.shape[0], -1)
  82. patch = torch.cat((cpatch, patch), dim=1)
  83. logits,_ = model.transformer(patch[:,:-1])
  84. logits = logits[:, -256:, :]
  85. logits = logits.reshape(cshape[0],16,16,-1)
  86. logits = logits[:,local_i,local_j,:]
  87. logits = logits/temperature
  88. if top_k is not None:
  89. logits = model.top_k_logits(logits, top_k)
  90. # apply softmax to convert to probabilities
  91. probs = torch.nn.functional.softmax(logits, dim=-1)
  92. # sample from the distribution or take the most likely
  93. if sample:
  94. ix = torch.multinomial(probs, num_samples=1)
  95. else:
  96. _, ix = torch.topk(probs, k=1, dim=-1)
  97. idx[:,i,j] = ix
  98. xsample = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
  99. for i in range(xsample.shape[0]):
  100. save_image(xsample[i], os.path.join(outdir, "samples",
  101. "{:06}.png".format(indices[i])))
  102. def get_parser():
  103. parser = argparse.ArgumentParser()
  104. parser.add_argument(
  105. "-r",
  106. "--resume",
  107. type=str,
  108. nargs="?",
  109. help="load from logdir or checkpoint in logdir",
  110. )
  111. parser.add_argument(
  112. "-b",
  113. "--base",
  114. nargs="*",
  115. metavar="base_config.yaml",
  116. help="paths to base configs. Loaded from left-to-right. "
  117. "Parameters can be overwritten or added with command-line options of the form `--key value`.",
  118. default=list(),
  119. )
  120. parser.add_argument(
  121. "-c",
  122. "--config",
  123. nargs="?",
  124. metavar="single_config.yaml",
  125. help="path to single config. If specified, base configs will be ignored "
  126. "(except for the last one if left unspecified).",
  127. const=True,
  128. default="",
  129. )
  130. parser.add_argument(
  131. "--ignore_base_data",
  132. action="store_true",
  133. help="Ignore data specification from base configs. Useful if you want "
  134. "to specify a custom datasets on the command line.",
  135. )
  136. parser.add_argument(
  137. "--outdir",
  138. required=True,
  139. type=str,
  140. help="Where to write outputs to.",
  141. )
  142. parser.add_argument(
  143. "--top_k",
  144. type=int,
  145. default=100,
  146. help="Sample from among top-k predictions.",
  147. )
  148. parser.add_argument(
  149. "--temperature",
  150. type=float,
  151. default=1.0,
  152. help="Sampling temperature.",
  153. )
  154. return parser
  155. def load_model_from_config(config, sd, gpu=True, eval_mode=True):
  156. if "ckpt_path" in config.params:
  157. print("Deleting the restore-ckpt path from the config...")
  158. config.params.ckpt_path = None
  159. if "downsample_cond_size" in config.params:
  160. print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
  161. config.params.downsample_cond_size = -1
  162. config.params["downsample_cond_factor"] = 0.5
  163. try:
  164. if "ckpt_path" in config.params.first_stage_config.params:
  165. config.params.first_stage_config.params.ckpt_path = None
  166. print("Deleting the first-stage restore-ckpt path from the config...")
  167. if "ckpt_path" in config.params.cond_stage_config.params:
  168. config.params.cond_stage_config.params.ckpt_path = None
  169. print("Deleting the cond-stage restore-ckpt path from the config...")
  170. except:
  171. pass
  172. model = instantiate_from_config(config)
  173. if sd is not None:
  174. missing, unexpected = model.load_state_dict(sd, strict=False)
  175. print(f"Missing Keys in State Dict: {missing}")
  176. print(f"Unexpected Keys in State Dict: {unexpected}")
  177. if gpu:
  178. model.cuda()
  179. if eval_mode:
  180. model.eval()
  181. return {"model": model}
  182. def get_data(config):
  183. # get data
  184. data = instantiate_from_config(config.data)
  185. data.prepare_data()
  186. data.setup()
  187. return data
  188. def load_model_and_dset(config, ckpt, gpu, eval_mode):
  189. # get data
  190. dsets = get_data(config) # calls data.config ...
  191. # now load the specified checkpoint
  192. if ckpt:
  193. pl_sd = torch.load(ckpt, map_location="cpu")
  194. global_step = pl_sd["global_step"]
  195. else:
  196. pl_sd = {"state_dict": None}
  197. global_step = None
  198. model = load_model_from_config(config.model,
  199. pl_sd["state_dict"],
  200. gpu=gpu,
  201. eval_mode=eval_mode)["model"]
  202. return dsets, model, global_step
  203. if __name__ == "__main__":
  204. sys.path.append(os.getcwd())
  205. parser = get_parser()
  206. opt, unknown = parser.parse_known_args()
  207. ckpt = None
  208. if opt.resume:
  209. if not os.path.exists(opt.resume):
  210. raise ValueError("Cannot find {}".format(opt.resume))
  211. if os.path.isfile(opt.resume):
  212. paths = opt.resume.split("/")
  213. try:
  214. idx = len(paths)-paths[::-1].index("logs")+1
  215. except ValueError:
  216. idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
  217. logdir = "/".join(paths[:idx])
  218. ckpt = opt.resume
  219. else:
  220. assert os.path.isdir(opt.resume), opt.resume
  221. logdir = opt.resume.rstrip("/")
  222. ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
  223. print(f"logdir:{logdir}")
  224. base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
  225. opt.base = base_configs+opt.base
  226. if opt.config:
  227. if type(opt.config) == str:
  228. opt.base = [opt.config]
  229. else:
  230. opt.base = [opt.base[-1]]
  231. configs = [OmegaConf.load(cfg) for cfg in opt.base]
  232. cli = OmegaConf.from_dotlist(unknown)
  233. if opt.ignore_base_data:
  234. for config in configs:
  235. if hasattr(config, "data"): del config["data"]
  236. config = OmegaConf.merge(*configs, cli)
  237. print(ckpt)
  238. gpu = True
  239. eval_mode = True
  240. show_config = False
  241. if show_config:
  242. print(OmegaConf.to_container(config))
  243. dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
  244. print(f"Global step: {global_step}")
  245. outdir = os.path.join(opt.outdir, "{:06}_{}_{}".format(global_step,
  246. opt.top_k,
  247. opt.temperature))
  248. os.makedirs(outdir, exist_ok=True)
  249. print("Writing samples to ", outdir)
  250. for k in ["originals", "reconstructions", "samples"]:
  251. os.makedirs(os.path.join(outdir, k), exist_ok=True)
  252. run_conditional(model, dsets, outdir, opt.top_k, opt.temperature)