txt2img.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. import argparse, os
  2. import cv2
  3. import torch
  4. import numpy as np
  5. from omegaconf import OmegaConf
  6. from PIL import Image
  7. from tqdm import tqdm, trange
  8. from itertools import islice
  9. from einops import rearrange
  10. from torchvision.utils import make_grid
  11. from pytorch_lightning import seed_everything
  12. from torch import autocast
  13. from contextlib import nullcontext
  14. from imwatermark import WatermarkEncoder
  15. from ldm.util import instantiate_from_config
  16. from ldm.models.diffusion.ddim import DDIMSampler
  17. from ldm.models.diffusion.plms import PLMSSampler
  18. from ldm.models.diffusion.dpm_solver import DPMSolverSampler
  # This script only runs inference, so switch autograd off once for the whole process.
  torch.set_grad_enabled(mode=False)
  def chunk(it, size):
      """Lazily split *it* into consecutive tuples of *size* items.

      The last tuple holds whatever is left over and may be shorter. An empty
      input yields nothing. ``iter(it)`` is taken immediately, so a
      non-iterable argument fails at call time.
      """
      source = iter(it)

      def _batches():
          while True:
              batch = tuple(islice(source, size))
              if not batch:
                  return
              yield batch

      return _batches()
  def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=False):
      """Build the model described by *config*, load weights from *ckpt*, move it to *device*.

      Args:
          config: loaded config; ``config.model`` is passed to
              ``instantiate_from_config``.
          ckpt: path of the checkpoint file to load.
          device: ``torch.device`` or device string. Only the ``cuda`` and
              ``cpu`` device types are supported; a CUDA index such as
              ``"cuda:1"`` is honoured.
          verbose: if true, print missing and unexpected state-dict keys.

      Returns:
          The model, in eval mode, on *device*.

      Raises:
          ValueError: if *device* is not a valid CUDA or CPU device. This is
              checked before the checkpoint is read, so a bad flag fails fast.
      """
      try:
          device = torch.device(device)
      except (RuntimeError, TypeError) as err:
          raise ValueError(f"Incorrect device name. Received: {device}") from err
      if device.type not in ("cuda", "cpu"):
          raise ValueError(f"Incorrect device name. Received: {device}")

      print(f"Loading model from {ckpt}")
      # NOTE: torch.load unpickles the file, which can execute arbitrary code;
      # only load checkpoints from trusted sources.
      pl_sd = torch.load(ckpt, map_location="cpu")
      if "global_step" in pl_sd:
          print(f"Global Step: {pl_sd['global_step']}")
      # Lightning checkpoints nest the weights under "state_dict"; a bare
      # state dict is accepted as-is.
      sd = pl_sd.get("state_dict", pl_sd)

      model = instantiate_from_config(config.model)
      missing, unexpected = model.load_state_dict(sd, strict=False)
      if verbose and missing:
          print("missing keys:")
          print(missing)
      if verbose and unexpected:
          print("unexpected keys:")
          print(unexpected)

      if device.type == "cuda":
          # index None means "current CUDA device", same as the old model.cuda()
          model.cuda(device.index)
      else:
          model.cpu()
          # NOTE(review): the text encoder keeps its own `device` attribute,
          # which .cpu() apparently does not update; set it explicitly.
          model.cond_stage_model.device = "cpu"
      model.eval()
      return model
  def parse_args(args=None):
      """Parse the txt2img command line.

      Args:
          args: list of argument strings to parse. ``None`` (the default)
              parses ``sys.argv[1:]`` as before; passing a list lets tests and
              other scripts drive the parser directly.

      Returns:
          ``argparse.Namespace`` with one attribute per option (``--from-file``
          becomes ``from_file``).
      """
      parser = argparse.ArgumentParser()
      parser.add_argument(
          "--prompt",
          type=str,
          nargs="?",
          default="a professional photograph of an astronaut riding a triceratops",
          help="the prompt to render"
      )
      parser.add_argument(
          "--outdir",
          type=str,
          nargs="?",
          help="dir to write results to",
          default="outputs/txt2img-samples"
      )
      parser.add_argument(
          "--steps",
          type=int,
          default=50,
          help="number of ddim sampling steps",
      )
      parser.add_argument(
          "--plms",
          action='store_true',
          help="use plms sampling",
      )
      parser.add_argument(
          "--dpm",
          action='store_true',
          help="use DPM (2) sampler",
      )
      parser.add_argument(
          "--fixed_code",
          action='store_true',
          help="if enabled, uses the same starting code across all samples ",
      )
      parser.add_argument(
          "--ddim_eta",
          type=float,
          default=0.0,
          help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
      )
      parser.add_argument(
          "--n_iter",
          type=int,
          default=3,
          help="sample this often",
      )
      parser.add_argument(
          "--H",
          type=int,
          default=512,
          help="image height, in pixel space",
      )
      parser.add_argument(
          "--W",
          type=int,
          default=512,
          help="image width, in pixel space",
      )
      parser.add_argument(
          "--C",
          type=int,
          default=4,
          help="latent channels",
      )
      parser.add_argument(
          "--f",
          type=int,
          default=8,
          help="downsampling factor, most often 8 or 16",
      )
      parser.add_argument(
          "--n_samples",
          type=int,
          default=3,
          help="how many samples to produce for each given prompt. A.k.a batch size",
      )
      parser.add_argument(
          "--n_rows",
          type=int,
          default=0,
          help="rows in the grid (default: n_samples)",
      )
      parser.add_argument(
          "--scale",
          type=float,
          default=9.0,
          help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
      )
      parser.add_argument(
          "--from-file",
          type=str,
          help="if specified, load prompts from this file, separated by newlines",
      )
      parser.add_argument(
          "--config",
          type=str,
          default="configs/stable-diffusion/v2-inference.yaml",
          help="path to config which constructs model",
      )
      # Required: without it main() would try to load a checkpoint named "None".
      parser.add_argument(
          "--ckpt",
          type=str,
          required=True,
          help="path to checkpoint of model",
      )
      parser.add_argument(
          "--seed",
          type=int,
          default=42,
          help="the seed (for reproducible sampling)",
      )
      parser.add_argument(
          "--precision",
          type=str,
          help="evaluate at this precision",
          choices=["full", "autocast"],
          default="autocast"
      )
      parser.add_argument(
          "--repeat",
          type=int,
          default=1,
          help="repeat each prompt in file this often",
      )
      parser.add_argument(
          "--device",
          type=str,
          help="Device on which Stable Diffusion will be run",
          choices=["cpu", "cuda"],
          default="cpu"
      )
      parser.add_argument(
          "--torchscript",
          action='store_true',
          help="Use TorchScript",
      )
      parser.add_argument(
          "--ipex",
          action='store_true',
          help="Use Intel® Extension for PyTorch*",
      )
      parser.add_argument(
          "--bf16",
          action='store_true',
          help="Use bfloat16",
      )
      return parser.parse_args(args)
  def put_watermark(img, wm_encoder=None):
      """Embed an invisible watermark into a PIL image.

      Args:
          img: PIL image in RGB channel order.
          wm_encoder: configured watermark encoder, or ``None`` to skip
              watermarking.

      Returns:
          A new watermarked PIL image, or *img* itself when *wm_encoder* is
          ``None``.
      """
      if wm_encoder is None:
          return img
      # Hand the encoder a BGR array (OpenCV channel order) ...
      bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
      marked = wm_encoder.encode(bgr, 'dwtDct')
      # ... and flip the channel axis back to RGB for PIL.
      return Image.fromarray(marked[:, :, ::-1])
  def _next_index(directory, prefix="", suffix=".png"):
      """Return the index for the next ``<prefix><number><suffix>`` file in *directory*.

      The result is one past the largest numeric index already present (0 if
      there is none). Unlike counting directory entries, this never reuses an
      index after files were deleted and ignores unrelated files.
      """
      highest = -1
      for name in os.listdir(directory):
          if not (name.startswith(prefix) and name.endswith(suffix)):
              continue
          stem = name[len(prefix):len(name) - len(suffix)]
          if stem.isdecimal():
              highest = max(highest, int(stem))
      return highest + 1


  def _load_prompt_batches(opt, batch_size):
      """Return the prompts to render as a list of batches.

      With ``--from-file``, each line of the file is one prompt. Each prompt is
      repeated ``opt.repeat`` times and the result is grouped into tuples of at
      most *batch_size* prompts; the last tuple may be shorter. Otherwise a
      single batch holding *batch_size* copies of ``opt.prompt`` is returned.

      Raises:
          ValueError: if no prompt is given or the file yields no prompts.
      """
      if not opt.from_file:
          if opt.prompt is None:
              raise ValueError("No prompt given: pass --prompt TEXT or --from-file PATH.")
          return [batch_size * [opt.prompt]]
      print(f"reading prompts from {opt.from_file}")
      with open(opt.from_file, "r", encoding="utf-8") as f:
          lines = f.read().splitlines()
      repeated = [p for p in lines for _ in range(opt.repeat)]
      if not repeated:
          raise ValueError(f"No prompts found in {opt.from_file}")
      return list(chunk(repeated, batch_size))


  def _apply_cpu_optimizations(opt, model, device, additional_context):
      """Check dtypes, then apply IPEX and/or TorchScript optimizations to *model*.

      When tracing, the traced UNet is stored as
      ``model.model.scripted_diffusion_model`` and the traced decoder replaces
      ``model.first_stage_model.decoder``.

      Raises:
          ValueError: on a dtype/config combination the optimized path cannot
              run, or when the UNet uses gradient checkpointing with tracing.
      """
      transformer = model.cond_stage_model.model
      unet = model.model.diffusion_model
      decoder = model.first_stage_model.decoder

      if opt.bf16 and unet.dtype != torch.bfloat16:
          raise ValueError("Use configs/stable-diffusion/intel/ configs with bf16 enabled if " +
                           "you'd like to use bfloat16 with CPU.")
      if unet.dtype == torch.float16 and device == torch.device("cpu"):
          raise ValueError("Use configs/stable-diffusion/intel/ configs for your model if you'd like to run it on CPU.")

      if opt.ipex:
          import intel_extension_for_pytorch as ipex
          bf16_dtype = torch.bfloat16 if opt.bf16 else None
          # NOTE(review): the optimized transformer is never re-attached to
          # `model`; this presumably relies on inplace=True — confirm.
          transformer = transformer.to(memory_format=torch.channels_last)
          transformer = ipex.optimize(transformer, level="O1", inplace=True)
          unet = unet.to(memory_format=torch.channels_last)
          unet = ipex.optimize(unet, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype)
          decoder = decoder.to(memory_format=torch.channels_last)
          decoder = ipex.optimize(decoder, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype)

      if opt.torchscript:
          with torch.no_grad(), additional_context:
              if unet.use_checkpoint:
                  raise ValueError("Gradient checkpoint won't work with tracing. " +
                                   "Use configs/stable-diffusion/intel/ configs for your model or disable checkpoint in your config.")
              # Trace with the latent size actually requested (previously a
              # fixed 96x96), so traced shapes match the shapes used at runtime.
              latent_hw = (opt.H // opt.f, opt.W // opt.f)
              img_in = torch.ones(2, opt.C, *latent_hw, dtype=torch.float32)
              t_in = torch.ones(2, dtype=torch.int64)
              # NOTE(review): 77x1024 presumably matches the SD2 text encoder
              # output — confirm for other configs.
              context = torch.ones(2, 77, 1024, dtype=torch.float32)
              scripted_unet = torch.jit.trace(unet, (img_in, t_in, context))
              scripted_unet = torch.jit.optimize_for_inference(scripted_unet)
              print(type(scripted_unet))
              model.model.scripted_diffusion_model = scripted_unet

              latent_in = torch.ones(1, opt.C, *latent_hw, dtype=torch.float32)
              scripted_decoder = torch.jit.trace(decoder, latent_in)
              scripted_decoder = torch.jit.optimize_for_inference(scripted_decoder)
              print(type(scripted_decoder))
              model.first_stage_model.decoder = scripted_decoder


  def _warm_up(opt, model, sampler, prompts, shape, start_code, additional_context):
      """Run a short 5-step sampling pass and three decoder passes on *prompts*.

      This initializes the optimized modules before the real generation loop.
      """
      print("Running a forward pass to initialize optimizations")
      prompts = list(prompts)
      n_prompts = len(prompts)
      uc = None
      if opt.scale != 1.0:
          uc = model.get_learned_conditioning(n_prompts * [""])
      x_T = None if start_code is None else start_code[:n_prompts]
      with torch.no_grad(), additional_context:
          for _ in range(3):
              c = model.get_learned_conditioning(prompts)
          samples, _ = sampler.sample(S=5,
                                      conditioning=c,
                                      batch_size=n_prompts,
                                      shape=shape,
                                      verbose=False,
                                      unconditional_guidance_scale=opt.scale,
                                      unconditional_conditioning=uc,
                                      eta=opt.ddim_eta,
                                      x_T=x_T)
          print("Running a forward pass for decoder")
          for _ in range(3):
              model.decode_first_stage(samples)


  def main(opt):
      """Generate images for the prompts described by *opt* and write them to disk.

      Each image is saved as ``<outdir>/samples/NNNNN.png``. A grid of every
      image from the run is saved as ``<outdir>/grid-NNNN.png``. Both are
      passed through ``put_watermark``.

      Raises:
          ValueError: for an unsupported flag combination, or a missing
              checkpoint or prompt.
      """
      # Validated before the model is loaded so bad flags fail fast. This check
      # used to sit inside the torchscript/ipex branch, where it could never fire.
      if opt.bf16 and not (opt.torchscript or opt.ipex):
          raise ValueError('Bfloat16 is supported only for torchscript+ipex')
      if opt.ckpt is None:
          raise ValueError("No checkpoint given: pass --ckpt PATH.")

      seed_everything(opt.seed)
      config = OmegaConf.load(opt.config)
      device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
      model = load_model_from_config(config, opt.ckpt, device)

      if opt.plms:
          sampler = PLMSSampler(model, device=device)
      elif opt.dpm:
          sampler = DPMSolverSampler(model, device=device)
      else:
          sampler = DDIMSampler(model, device=device)

      outpath = opt.outdir
      os.makedirs(outpath, exist_ok=True)

      print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
      wm = "SDV2"
      wm_encoder = WatermarkEncoder()
      wm_encoder.set_watermark('bytes', wm.encode('utf-8'))

      batch_size = opt.n_samples
      n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
      data = _load_prompt_batches(opt, batch_size)

      sample_path = os.path.join(outpath, "samples")
      os.makedirs(sample_path, exist_ok=True)
      base_count = _next_index(sample_path)
      grid_count = _next_index(outpath, prefix="grid-")

      shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
      start_code = None
      if opt.fixed_code:
          start_code = torch.randn([batch_size, *shape], device=device)

      if opt.torchscript or opt.ipex:
          additional_context = torch.cpu.amp.autocast() if opt.bf16 else nullcontext()
          _apply_cpu_optimizations(opt, model, device, additional_context)
          _warm_up(opt, model, sampler, data[0], shape, start_code, additional_context)

      precision_scope = autocast if opt.precision == "autocast" or opt.bf16 else nullcontext
      with torch.no_grad(), \
              precision_scope(opt.device), \
              model.ema_scope():
          all_samples = []
          for _iteration in trange(opt.n_iter, desc="Sampling"):
              for prompts in tqdm(data, desc="data"):
                  prompts = list(prompts)
                  # The last batch read from a file may hold fewer than
                  # batch_size prompts; size everything from the batch itself.
                  n_prompts = len(prompts)
                  uc = None
                  if opt.scale != 1.0:
                      uc = model.get_learned_conditioning(n_prompts * [""])
                  c = model.get_learned_conditioning(prompts)
                  x_T = None if start_code is None else start_code[:n_prompts]
                  samples, _ = sampler.sample(S=opt.steps,
                                              conditioning=c,
                                              batch_size=n_prompts,
                                              shape=shape,
                                              verbose=False,
                                              unconditional_guidance_scale=opt.scale,
                                              unconditional_conditioning=uc,
                                              eta=opt.ddim_eta,
                                              x_T=x_T)
                  x_samples = model.decode_first_stage(samples)
                  x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                  for x_sample in x_samples:
                      x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                      img = Image.fromarray(x_sample.astype(np.uint8))
                      img = put_watermark(img, wm_encoder)
                      img.save(os.path.join(sample_path, f"{base_count:05}.png"))
                      base_count += 1
                  all_samples.append(x_samples)

          if all_samples:
              # Additionally save everything as one grid. torch.cat (rather than
              # stack) tolerates a shorter final batch.
              grid = torch.cat(all_samples, 0)
              grid = make_grid(grid, nrow=n_rows)
              grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
              grid = Image.fromarray(grid.astype(np.uint8))
              grid = put_watermark(grid, wm_encoder)
              grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))

      print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
            f" \nEnjoy.")
  if __name__ == "__main__":
      # Script entry point: parse the command line and run generation.
      main(opt=parse_args())