make_scene_samples.py

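"""Sample scenes from a trained Net2NetTransformer conditioned on object layouts.

Loads a checkpoint and its project config, builds layout conditionals
('objects_bbox' or 'objects_center_points') from validation-set annotations,
draws `--n_samples_per_layout` images per layout and writes each grid,
together with a rendered plot of the layout, to the output directory.
For target resolutions above the training resolution (256 px) the latent
tokens are sampled one at a time with a sliding crop window (see sample()).
"""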
import glob
import os
import sys
from itertools import product
from pathlib import Path
from typing import Literal, List, Optional, Tuple

import numpy as np
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from torch import Tensor
from torchvision.utils import save_image
from tqdm import tqdm

from scripts.make_samples import get_parser, load_model_and_dset
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
from taming.data.helper_types import BoundingBox, Annotation
from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
from taming.models.cond_transformer import Net2NetTransformer

seed_everything(42424242)

device: Literal['cuda', 'cpu'] = 'cuda'
first_stage_factor = 16   # spatial downsampling factor of the first-stage VQ model
trained_on_res = 256      # image resolution the transformer was trained on


def _helper(coord: int, coord_max: int, coord_window: int) -> int:
    # Left/top index of a window of length `coord_window`, centred on `coord`
    # as closely as possible while staying inside [0, coord_max).
    assert 0 <= coord < coord_max
    coord_desired_center = (coord_window - 1) // 2
    return np.clip(coord - coord_desired_center, 0, coord_max - coord_window)


def get_crop_coordinates(x: int, y: int) -> BoundingBox:
    WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
    x0 = _helper(x, WIDTH, first_stage_factor) / WIDTH
    y0 = _helper(y, HEIGHT, first_stage_factor) / HEIGHT
    w = first_stage_factor / WIDTH
    h = first_stage_factor / HEIGHT
    return x0, y0, w, h
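
# Worked example (hypothetical values, assuming desired_z_shape == (32, 32),
# i.e. a 512x512 target with first_stage_factor == 16):
#   get_crop_coordinates(0, 0)   -> (0.0, 0.0, 0.5, 0.5)
#   get_crop_coordinates(20, 20) -> (0.40625, 0.40625, 0.5, 0.5)
# The crop is always first_stage_factor x first_stage_factor latent cells,
# expressed relative to the full latent grid.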


def get_z_indices_crop_out(z_indices: Tensor, predict_x: int, predict_y: int) -> Tensor:
    WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
    x0 = _helper(predict_x, WIDTH, first_stage_factor)
    y0 = _helper(predict_y, HEIGHT, first_stage_factor)
    no_images = z_indices.shape[0]
    cut_out_1 = z_indices[:, y0:predict_y, x0:x0+first_stage_factor].reshape((no_images, -1))
    cut_out_2 = z_indices[:, predict_y, x0:predict_x]
    return torch.cat((cut_out_1, cut_out_2), dim=1)
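
# The concatenation above yields an L-shaped context in raster-scan order:
# all full-width window rows above `predict_y`, plus the current row up to
# (but not including) `predict_x`. This is the context the transformer sees
# when predicting the token at (predict_y, predict_x) during sliding-window
# sampling in sample() below.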


@torch.no_grad()
def sample(model: Net2NetTransformer, annotations: List[Annotation], dataset: AnnotatedObjectsDataset,
           conditional_builder: ObjectsCenterPointsConditionalBuilder, no_samples: int,
           temperature: float, top_k: int) -> Tensor:
    x_max, y_max = desired_z_shape[1], desired_z_shape[0]

    annotations = [a._replace(category_no=dataset.get_category_number(a.category_id)) for a in annotations]

    recompute_conditional = any((desired_resolution[0] > trained_on_res, desired_resolution[1] > trained_on_res))
    if not recompute_conditional:
        # Target resolution matches the training setup: sample all tokens in one pass.
        crop_coordinates = get_crop_coordinates(0, 0)
        conditional_indices = conditional_builder.build(annotations, crop_coordinates)
        c_indices = conditional_indices.to(device).repeat(no_samples, 1)
        z_indices = torch.zeros((no_samples, 0), device=device).long()
        output_indices = model.sample(z_indices, c_indices, steps=x_max*y_max, temperature=temperature,
                                      sample=True, top_k=top_k)
    else:
        # Target resolution exceeds the training resolution: sample token by token,
        # rebuilding the conditional for a sliding crop window around each position.
        output_indices = torch.zeros((no_samples, y_max, x_max), device=device).long()
        for predict_y, predict_x in tqdm(product(range(y_max), range(x_max)), desc='sampling_image',
                                         total=x_max*y_max):
            crop_coordinates = get_crop_coordinates(predict_x, predict_y)
            z_indices = get_z_indices_crop_out(output_indices, predict_x, predict_y)
            conditional_indices = conditional_builder.build(annotations, crop_coordinates)
            c_indices = conditional_indices.to(device).repeat(no_samples, 1)
            new_index = model.sample(z_indices, c_indices, steps=1, temperature=temperature, sample=True, top_k=top_k)
            output_indices[:, predict_y, predict_x] = new_index[:, -1]
    z_shape = (
        no_samples,
        model.first_stage_model.quantize.e_dim,  # codebook embed_dim
        desired_z_shape[0],  # z_height
        desired_z_shape[1]   # z_width
    )
    x_sample = model.decode_to_img(output_indices, z_shape) * 0.5 + 0.5
    x_sample = x_sample.to('cpu')

    plotter = conditional_builder.plot
    figure_size = (x_sample.shape[2], x_sample.shape[3])
    scene_graph = conditional_builder.build(annotations, (0., 0., 1., 1.))
    plot = plotter(scene_graph, dataset.get_textual_label_for_category_no, figure_size)
    return torch.cat((x_sample, plot.unsqueeze(0)))
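
# Each call returns no_samples decoded images plus one rendered layout plot,
# which is why the grids saved in __main__ below use nrow = n_samples_per_layout + 1.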


def get_resolution(resolution_str: str) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    if not resolution_str.count(',') == 1:
        raise ValueError("Give resolution as in 'height,width'")
    res_h, res_w = resolution_str.split(',')
    res_h = max(int(res_h), trained_on_res)
    res_w = max(int(res_w), trained_on_res)
    z_h = int(round(res_h/first_stage_factor))
    z_w = int(round(res_w/first_stage_factor))
    return (z_h, z_w), (z_h*first_stage_factor, z_w*first_stage_factor)
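
# Example: get_resolution('512,320') -> ((32, 20), (512, 320)).
# Requests below the training resolution are clamped up, e.g.
# get_resolution('128,128') -> ((16, 16), (256, 256)).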


def add_arg_to_parser(parser):
    parser.add_argument(
        "-R",
        "--resolution",
        type=str,
        default='256,256',
        help=f"give resolution in multiples of {first_stage_factor}, default is '256,256'",
    )
    parser.add_argument(
        "-C",
        "--conditional",
        type=str,
        default='objects_bbox',
        help="objects_bbox or objects_center_points",
    )
    parser.add_argument(
        "-N",
        "--n_samples_per_layout",
        type=int,
        default=4,
        help="how many samples to generate per layout",
    )
    return parser
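
# Hypothetical invocation (the --resume/--outdir/--top_k/--temperature flag names
# are assumed to match the opt attributes used below, as defined by
# scripts.make_samples.get_parser, and may differ in your checkout):
#   python scripts/make_scene_samples.py --resume logs/<run>/checkpoints/last.ckpt \
#       --outdir samples/ -R 512,512 -C objects_bbox -N 4 --top_k 100 --temperature 1.0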


if __name__ == "__main__":
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = add_arg_to_parser(parser)

    opt, unknown = parser.parse_known_args()

    ckpt = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            try:
                idx = len(paths)-paths[::-1].index("logs")+1
            except ValueError:
                idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        print(f"logdir: {logdir}")
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        opt.base = base_configs+opt.base

    if opt.config:
        if type(opt.config) == str:
            opt.base = [opt.config]
        else:
            opt.base = [opt.base[-1]]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    if opt.ignore_base_data:
        for config in configs:
            if hasattr(config, "data"):
                del config["data"]
    config = OmegaConf.merge(*configs, cli)
    desired_z_shape, desired_resolution = get_resolution(opt.resolution)
    conditional = opt.conditional

    print(ckpt)
    gpu = True
    eval_mode = True
    show_config = False
    if show_config:
        print(OmegaConf.to_container(config))

    dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
    print(f"Global step: {global_step}")

    data_loader = dsets.val_dataloader()
    print(dsets.datasets["validation"].conditional_builders)
    conditional_builder = dsets.datasets["validation"].conditional_builders[conditional]

    outdir = Path(opt.outdir).joinpath(f"{global_step:06}_{opt.top_k}_{opt.temperature}")
    outdir.mkdir(exist_ok=True, parents=True)
    print("Writing samples to", outdir)

    p_bar_1 = tqdm(enumerate(iter(data_loader)), desc='batch', total=len(data_loader))
    for batch_no, batch in p_bar_1:
        save_img: Optional[Tensor] = None
        for i, annotations in tqdm(enumerate(batch['annotations']), desc='within_batch',
                                   total=data_loader.batch_size):
            imgs = sample(model, annotations, dsets.datasets["validation"], conditional_builder,
                          opt.n_samples_per_layout, opt.temperature, opt.top_k)
            # `nrow` (not `n_row`) is the keyword torchvision's save_image forwards to make_grid.
            save_image(imgs, outdir.joinpath(f'{batch_no:04}_{i:02}.png'), nrow=opt.n_samples_per_layout+1)