# 9.9 KB (file-size label left over from the source listing)

import argparse
import glob
import math
import os
import sys
import time

import numpy as np
import torch
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import trange

from main import instantiate_from_config, DataModuleFromConfig
  10. def save_image(x, path):
  11. c,h,w = x.shape
  12. assert c==3
  13. x = ((x.detach().cpu().numpy().transpose(1,2,0)+1.0)*127.5).clip(0,255).astype(np.uint8)
  14. Image.fromarray(x).save(path)
  15. @torch.no_grad()
  16. def run_conditional(model, dsets, outdir, top_k, temperature, batch_size=1):
  17. if len(dsets.datasets) > 1:
  18. split = sorted(dsets.datasets.keys())[0]
  19. dset = dsets.datasets[split]
  20. else:
  21. dset = next(iter(dsets.datasets.values()))
  22. print("Dataset: ", dset.__class__.__name__)
  23. for start_idx in trange(0,len(dset)-batch_size+1,batch_size):
  24. indices = list(range(start_idx, start_idx+batch_size))
  25. example = default_collate([dset[i] for i in indices])
  26. x = model.get_input("image", example).to(model.device)
  27. for i in range(x.shape[0]):
  28. save_image(x[i], os.path.join(outdir, "originals",
  29. "{:06}.png".format(indices[i])))
  30. cond_key = model.cond_stage_key
  31. c = model.get_input(cond_key, example).to(model.device)
  32. scale_factor = 1.0
  33. quant_z, z_indices = model.encode_to_z(x)
  34. quant_c, c_indices = model.encode_to_c(c)
  35. cshape = quant_z.shape
  36. xrec = model.first_stage_model.decode(quant_z)
  37. for i in range(xrec.shape[0]):
  38. save_image(xrec[i], os.path.join(outdir, "reconstructions",
  39. "{:06}.png".format(indices[i])))
  40. if cond_key == "segmentation":
  41. # get image from segmentation mask
  42. num_classes = c.shape[1]
  43. c = torch.argmax(c, dim=1, keepdim=True)
  44. c = torch.nn.functional.one_hot(c, num_classes=num_classes)
  45. c = c.squeeze(1).permute(0, 3, 1, 2).float()
  46. c = model.cond_stage_model.to_rgb(c)
  47. idx = z_indices
  48. half_sample = False
  49. if half_sample:
  50. start = idx.shape[1]//2
  51. else:
  52. start = 0
  53. idx[:,start:] = 0
  54. idx = idx.reshape(cshape[0],cshape[2],cshape[3])
  55. start_i = start//cshape[3]
  56. start_j = start %cshape[3]
  57. cidx = c_indices
  58. cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
  59. sample = True
  60. for i in range(start_i,cshape[2]-0):
  61. if i <= 8:
  62. local_i = i
  63. elif cshape[2]-i < 8:
  64. local_i = 16-(cshape[2]-i)
  65. else:
  66. local_i = 8
  67. for j in range(start_j,cshape[3]-0):
  68. if j <= 8:
  69. local_j = j
  70. elif cshape[3]-j < 8:
  71. local_j = 16-(cshape[3]-j)
  72. else:
  73. local_j = 8
  74. i_start = i-local_i
  75. i_end = i_start+16
  76. j_start = j-local_j
  77. j_end = j_start+16
  78. patch = idx[:,i_start:i_end,j_start:j_end]
  79. patch = patch.reshape(patch.shape[0],-1)
  80. cpatch = cidx[:, i_start:i_end, j_start:j_end]
  81. cpatch = cpatch.reshape(cpatch.shape[0], -1)
  82. patch =, patch), dim=1)
  83. logits,_ = model.transformer(patch[:,:-1])
  84. logits = logits[:, -256:, :]
  85. logits = logits.reshape(cshape[0],16,16,-1)
  86. logits = logits[:,local_i,local_j,:]
  87. logits = logits/temperature
  88. if top_k is not None:
  89. logits = model.top_k_logits(logits, top_k)
  90. # apply softmax to convert to probabilities
  91. probs = torch.nn.functional.softmax(logits, dim=-1)
  92. # sample from the distribution or take the most likely
  93. if sample:
  94. ix = torch.multinomial(probs, num_samples=1)
  95. else:
  96. _, ix = torch.topk(probs, k=1, dim=-1)
  97. idx[:,i,j] = ix
  98. xsample = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
  99. for i in range(xsample.shape[0]):
  100. save_image(xsample[i], os.path.join(outdir, "samples",
  101. "{:06}.png".format(indices[i])))
  102. def get_parser():
  103. parser = argparse.ArgumentParser()
  104. parser.add_argument(
  105. "-r",
  106. "--resume",
  107. type=str,
  108. nargs="?",
  109. help="load from logdir or checkpoint in logdir",
  110. )
  111. parser.add_argument(
  112. "-b",
  113. "--base",
  114. nargs="*",
  115. metavar="base_config.yaml",
  116. help="paths to base configs. Loaded from left-to-right. "
  117. "Parameters can be overwritten or added with command-line options of the form `--key value`.",
  118. default=list(),
  119. )
  120. parser.add_argument(
  121. "-c",
  122. "--config",
  123. nargs="?",
  124. metavar="single_config.yaml",
  125. help="path to single config. If specified, base configs will be ignored "
  126. "(except for the last one if left unspecified).",
  127. const=True,
  128. default="",
  129. )
  130. parser.add_argument(
  131. "--ignore_base_data",
  132. action="store_true",
  133. help="Ignore data specification from base configs. Useful if you want "
  134. "to specify a custom datasets on the command line.",
  135. )
  136. parser.add_argument(
  137. "--outdir",
  138. required=True,
  139. type=str,
  140. help="Where to write outputs to.",
  141. )
  142. parser.add_argument(
  143. "--top_k",
  144. type=int,
  145. default=100,
  146. help="Sample from among top-k predictions.",
  147. )
  148. parser.add_argument(
  149. "--temperature",
  150. type=float,
  151. default=1.0,
  152. help="Sampling temperature.",
  153. )
  154. return parser
  155. def load_model_from_config(config, sd, gpu=True, eval_mode=True):
  156. if "ckpt_path" in config.params:
  157. print("Deleting the restore-ckpt path from the config...")
  158. config.params.ckpt_path = None
  159. if "downsample_cond_size" in config.params:
  160. print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
  161. config.params.downsample_cond_size = -1
  162. config.params["downsample_cond_factor"] = 0.5
  163. try:
  164. if "ckpt_path" in config.params.first_stage_config.params:
  165. config.params.first_stage_config.params.ckpt_path = None
  166. print("Deleting the first-stage restore-ckpt path from the config...")
  167. if "ckpt_path" in config.params.cond_stage_config.params:
  168. config.params.cond_stage_config.params.ckpt_path = None
  169. print("Deleting the cond-stage restore-ckpt path from the config...")
  170. except:
  171. pass
  172. model = instantiate_from_config(config)
  173. if sd is not None:
  174. missing, unexpected = model.load_state_dict(sd, strict=False)
  175. print(f"Missing Keys in State Dict: {missing}")
  176. print(f"Unexpected Keys in State Dict: {unexpected}")
  177. if gpu:
  178. model.cuda()
  179. if eval_mode:
  180. model.eval()
  181. return {"model": model}
  182. def get_data(config):
  183. # get data
  184. data = instantiate_from_config(
  185. data.prepare_data()
  186. data.setup()
  187. return data
  188. def load_model_and_dset(config, ckpt, gpu, eval_mode):
  189. # get data
  190. dsets = get_data(config) # calls data.config ...
  191. # now load the specified checkpoint
  192. if ckpt:
  193. pl_sd = torch.load(ckpt, map_location="cpu")
  194. global_step = pl_sd["global_step"]
  195. else:
  196. pl_sd = {"state_dict": None}
  197. global_step = None
  198. model = load_model_from_config(config.model,
  199. pl_sd["state_dict"],
  200. gpu=gpu,
  201. eval_mode=eval_mode)["model"]
  202. return dsets, model, global_step
  203. if __name__ == "__main__":
  204. sys.path.append(os.getcwd())
  205. parser = get_parser()
  206. opt, unknown = parser.parse_known_args()
  207. ckpt = None
  208. if opt.resume:
  209. if not os.path.exists(opt.resume):
  210. raise ValueError("Cannot find {}".format(opt.resume))
  211. if os.path.isfile(opt.resume):
  212. paths = opt.resume.split("/")
  213. try:
  214. idx = len(paths)-paths[::-1].index("logs")+1
  215. except ValueError:
  216. idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
  217. logdir = "/".join(paths[:idx])
  218. ckpt = opt.resume
  219. else:
  220. assert os.path.isdir(opt.resume), opt.resume
  221. logdir = opt.resume.rstrip("/")
  222. ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
  223. print(f"logdir:{logdir}")
  224. base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
  225. opt.base = base_configs+opt.base
  226. if opt.config:
  227. if type(opt.config) == str:
  228. opt.base = [opt.config]
  229. else:
  230. opt.base = [opt.base[-1]]
  231. configs = [OmegaConf.load(cfg) for cfg in opt.base]
  232. cli = OmegaConf.from_dotlist(unknown)
  233. if opt.ignore_base_data:
  234. for config in configs:
  235. if hasattr(config, "data"): del config["data"]
  236. config = OmegaConf.merge(*configs, cli)
  237. print(ckpt)
  238. gpu = True
  239. eval_mode = True
  240. show_config = False
  241. if show_config:
  242. print(OmegaConf.to_container(config))
  243. dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
  244. print(f"Global step: {global_step}")
  245. outdir = os.path.join(opt.outdir, "{:06}_{}_{}".format(global_step,
  246. opt.top_k,
  247. opt.temperature))
  248. os.makedirs(outdir, exist_ok=True)
  249. print("Writing samples to ", outdir)
  250. for k in ["originals", "reconstructions", "samples"]:
  251. os.makedirs(os.path.join(outdir, k), exist_ok=True)
  252. run_conditional(model, dsets, outdir, opt.top_k, opt.temperature)