cond_transformer.py

import os, math
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from main import instantiate_from_config
from taming.modules.util import SOSProvider


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self

class Net2NetTransformer(pl.LightningModule):
    def __init__(self,
                 transformer_config,
                 first_stage_config,
                 cond_stage_config,
                 permuter_config=None,
                 ckpt_path=None,
                 ignore_keys=[],
                 first_stage_key="image",
                 cond_stage_key="depth",
                 downsample_cond_size=-1,
                 pkeep=1.0,
                 sos_token=0,
                 unconditional=False,
                 ):
        super().__init__()
        self.be_unconditional = unconditional
        self.sos_token = sos_token
        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        self.init_first_stage_from_ckpt(first_stage_config)
        self.init_cond_stage_from_ckpt(cond_stage_config)
        if permuter_config is None:
            permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
        self.permuter = instantiate_from_config(config=permuter_config)
        self.transformer = instantiate_from_config(config=transformer_config)

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.downsample_cond_size = downsample_cond_size
        self.pkeep = pkeep

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        # iterate over a snapshot of the keys so matching entries can be
        # deleted without mutating the dict during iteration
        for k in list(sd.keys()):
            for ik in ignore_keys:
                if k.startswith(ik):
                    self.print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def init_first_stage_from_ckpt(self, config):
        model = instantiate_from_config(config)
        model = model.eval()
        model.train = disabled_train
        self.first_stage_model = model

    def init_cond_stage_from_ckpt(self, config):
        if config == "__is_first_stage__":
            print("Using first stage also as cond stage.")
            self.cond_stage_model = self.first_stage_model
        elif config == "__is_unconditional__" or self.be_unconditional:
            print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
                  f"Prepending {self.sos_token} as a sos token.")
            self.be_unconditional = True
            self.cond_stage_key = self.first_stage_key
            self.cond_stage_model = SOSProvider(self.sos_token)
        else:
            model = instantiate_from_config(config)
            model = model.eval()
            model.train = disabled_train
            self.cond_stage_model = model

    def forward(self, x, c):
        # one step to produce the logits
        _, z_indices = self.encode_to_z(x)
        _, c_indices = self.encode_to_c(c)

        if self.training and self.pkeep < 1.0:
            mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
                                                         device=z_indices.device))
            mask = mask.round().to(dtype=torch.int64)
            r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
            a_indices = mask*z_indices+(1-mask)*r_indices
        else:
            a_indices = z_indices

        cz_indices = torch.cat((c_indices, a_indices), dim=1)

        # target includes all sequence elements (no need to handle first one
        # differently because we are conditioning)
        target = z_indices

        # make the prediction
        logits, _ = self.transformer(cz_indices[:, :-1])

        # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
        logits = logits[:, c_indices.shape[1]-1:]

        return logits, target
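
    # Shape bookkeeping for forward() (illustrative note, not executed code):
    # with c_indices of length Lc and z_indices of length Lz, the transformer
    # sees cz_indices[:, :-1] of length Lc + Lz - 1 and emits one logit per
    # position. Dropping the first Lc - 1 outputs leaves exactly Lz logits,
    # where output i predicts p(z_i | z_{<i}, c) and therefore lines up
    # one-to-one with target = z_indices in shared_step's cross-entropy.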

    def top_k_logits(self, logits, k):
        v, ix = torch.topk(logits, k)
        out = logits.clone()
        out[out < v[..., [-1]]] = -float('Inf')
        return out
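    # e.g. top_k_logits(torch.tensor([[1., 3., 2., 0.]]), k=2) keeps only the
    # two largest entries and masks the rest: [[-inf, 3., 2., -inf]], so the
    # subsequent softmax assigns zero probability to everything outside the top k.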

    @torch.no_grad()
    def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
               callback=lambda k: None):
        x = torch.cat((c, x), dim=1)
        block_size = self.transformer.get_block_size()
        assert not self.transformer.training
        if self.pkeep <= 0.0:
            # one pass suffices since input is pure noise anyway
            assert len(x.shape) == 2
            noise_shape = (x.shape[0], steps-1)
            # noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
            noise = c.clone()[:, x.shape[1]-c.shape[1]:-1]
            x = torch.cat((x, noise), dim=1)
            logits, _ = self.transformer(x)
            # take all logits for now and scale by temp
            logits = logits / temperature
            # optionally crop probabilities to only the top k options
            if top_k is not None:
                logits = self.top_k_logits(logits, top_k)
            # apply softmax to convert to probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution or take the most likely
            if sample:
                shape = probs.shape
                probs = probs.reshape(shape[0]*shape[1], shape[2])
                ix = torch.multinomial(probs, num_samples=1)
                probs = probs.reshape(shape[0], shape[1], shape[2])
                ix = ix.reshape(shape[0], shape[1])
            else:
                _, ix = torch.topk(probs, k=1, dim=-1)
            # cut off conditioning
            x = ix[:, c.shape[1]-1:]
        else:
            for k in range(steps):
                callback(k)
                assert x.size(1) <= block_size  # make sure model can see conditioning
                x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
                logits, _ = self.transformer(x_cond)
                # pluck the logits at the final step and scale by temperature
                logits = logits[:, -1, :] / temperature
                # optionally crop probabilities to only the top k options
                if top_k is not None:
                    logits = self.top_k_logits(logits, top_k)
                # apply softmax to convert to probabilities
                probs = F.softmax(logits, dim=-1)
                # sample from the distribution or take the most likely
                if sample:
                    ix = torch.multinomial(probs, num_samples=1)
                else:
                    _, ix = torch.topk(probs, k=1, dim=-1)
                # append to the sequence and continue
                x = torch.cat((x, ix), dim=1)
            # cut off conditioning
            x = x[:, c.shape[1]:]
        return x
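
    # Typical call (sketch): with an empty z prefix, sample() generates the
    # whole latent sequence conditioned only on c, matching the way the
    # "nopix" samples are produced in log_images below, e.g.
    #   index_sample = self.sample(z_indices[:, :0], c_indices,
    #                              steps=z_indices.shape[1],
    #                              temperature=1.0, sample=True, top_k=100)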

    @torch.no_grad()
    def encode_to_z(self, x):
        quant_z, _, info = self.first_stage_model.encode(x)
        indices = info[2].view(quant_z.shape[0], -1)
        indices = self.permuter(indices)
        return quant_z, indices

    @torch.no_grad()
    def encode_to_c(self, c):
        if self.downsample_cond_size > -1:
            c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
        quant_c, _, [_, _, indices] = self.cond_stage_model.encode(c)
        if len(indices.shape) > 2:
            indices = indices.view(c.shape[0], -1)
        return quant_c, indices

    @torch.no_grad()
    def decode_to_img(self, index, zshape):
        index = self.permuter(index, reverse=True)
        bhwc = (zshape[0], zshape[2], zshape[3], zshape[1])
        quant_z = self.first_stage_model.quantize.get_codebook_entry(
            index.reshape(-1), shape=bhwc)
        x = self.first_stage_model.decode(quant_z)
        return x

    @torch.no_grad()
    def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
        log = dict()

        N = 4
        if lr_interface:
            x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
        else:
            x, c = self.get_xc(batch, N)
        x = x.to(device=self.device)
        c = c.to(device=self.device)

        quant_z, z_indices = self.encode_to_z(x)
        quant_c, c_indices = self.encode_to_c(c)

        # create a "half" sample
        z_start_indices = z_indices[:, :z_indices.shape[1]//2]
        index_sample = self.sample(z_start_indices, c_indices,
                                   steps=z_indices.shape[1]-z_start_indices.shape[1],
                                   temperature=temperature if temperature is not None else 1.0,
                                   sample=True,
                                   top_k=top_k if top_k is not None else 100,
                                   callback=callback if callback is not None else lambda k: None)
        x_sample = self.decode_to_img(index_sample, quant_z.shape)

        # sample
        z_start_indices = z_indices[:, :0]
        index_sample = self.sample(z_start_indices, c_indices,
                                   steps=z_indices.shape[1],
                                   temperature=temperature if temperature is not None else 1.0,
                                   sample=True,
                                   top_k=top_k if top_k is not None else 100,
                                   callback=callback if callback is not None else lambda k: None)
        x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)

        # det sample
        z_start_indices = z_indices[:, :0]
        index_sample = self.sample(z_start_indices, c_indices,
                                   steps=z_indices.shape[1],
                                   sample=False,
                                   callback=callback if callback is not None else lambda k: None)
        x_sample_det = self.decode_to_img(index_sample, quant_z.shape)

        # reconstruction
        x_rec = self.decode_to_img(z_indices, quant_z.shape)

        log["inputs"] = x
        log["reconstructions"] = x_rec

        if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
            figure_size = (x_rec.shape[2], x_rec.shape[3])
            dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
            label_for_category_no = dataset.get_textual_label_for_category_no
            plotter = dataset.conditional_builders[self.cond_stage_key].plot
            log["conditioning"] = torch.zeros_like(log["reconstructions"])
            for i in range(quant_c.shape[0]):
                log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
            log["conditioning_rec"] = log["conditioning"]
        elif self.cond_stage_key != "image":
            cond_rec = self.cond_stage_model.decode(quant_c)
            if self.cond_stage_key == "segmentation":
                # get image from segmentation mask
                num_classes = cond_rec.shape[1]

                c = torch.argmax(c, dim=1, keepdim=True)
                c = F.one_hot(c, num_classes=num_classes)
                c = c.squeeze(1).permute(0, 3, 1, 2).float()
                c = self.cond_stage_model.to_rgb(c)

                cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
                cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
                cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
                cond_rec = self.cond_stage_model.to_rgb(cond_rec)
            log["conditioning_rec"] = cond_rec
            log["conditioning"] = c

        log["samples_half"] = x_sample
        log["samples_nopix"] = x_sample_nopix
        log["samples_det"] = x_sample_det
        return log

    def get_input(self, key, batch):
        x = batch[key]
        if len(x.shape) == 3:
            x = x[..., None]
        if len(x.shape) == 4:
            x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        if x.dtype == torch.double:
            x = x.float()
        return x

    def get_xc(self, batch, N=None):
        x = self.get_input(self.first_stage_key, batch)
        c = self.get_input(self.cond_stage_key, batch)
        if N is not None:
            x = x[:N]
            c = c[:N]
        return x, c

    def shared_step(self, batch, batch_idx):
        x, c = self.get_xc(batch)
        logits, target = self(x, c)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch, batch_idx)
        self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch, batch_idx)
        self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """
        Following minGPT:
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.transformer.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                           % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
        return optimizer
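
# ----------------------------------------------------------------------------
# Usage sketch (not part of the original module): how this LightningModule is
# typically driven from a project config. The YAML path and learning-rate
# value below are assumptions modelled on the taming-transformers setup, not
# guarantees about any particular checkout.
#
#   from omegaconf import OmegaConf
#   from main import instantiate_from_config
#
#   config = OmegaConf.load("configs/example.yaml")   # hypothetical config path
#   model = instantiate_from_config(config.model)     # builds Net2NetTransformer
#   model.learning_rate = 4.5e-06                     # configure_optimizers reads this attribute
#
#   x, c = model.get_xc(batch)                        # batch from a hypothetical dataloader
#   logits, target = model(x, c)                      # teacher-forced logits over z-code indices
# ----------------------------------------------------------------------------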