train_db.py

# DreamBooth training
# XXX dropped option: fine_tune

import gc
import time
import argparse
import itertools
import math
import os
import toml
from multiprocessing import Value

from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler

import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
    ConfigSanitizer,
    BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like


def train(args):
    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, False)

    cache_latents = args.cache_latents

    if args.seed is not None:
        set_seed(args.seed)  # initialize the random number generator

    tokenizer = train_util.load_tokenizer(args)

    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
    if args.dataset_config is not None:
        print(f"Load dataset config from {args.dataset_config}")
        user_config = config_util.load_user_config(args.dataset_config)
        ignored = ["train_data_dir", "reg_data_dir"]
        if any(getattr(args, attr) is not None for attr in ignored):
            print(
                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                    ", ".join(ignored)
                )
            )
    else:
        user_config = {
            "datasets": [
                {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
            ]
        }

    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)

    if args.no_token_padding:
        train_dataset_group.disable_token_padding()

    if args.debug_dataset:
        train_util.debug_dataset(train_dataset_group)
        return
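
    # Latents are encoded once and reused on every step, so per-step pixel
    # augmentation (color_aug) and per-step cropping (random_crop) would invalidate
    # the cache and are therefore rejected here.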
    if cache_latents:
        assert (
            train_dataset_group.is_latent_cacheable()
        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

    # prepare the accelerator
    print("prepare accelerator")

    if args.gradient_accumulation_steps > 1:
        print(
            f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
        )
        print(
            f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
        )

    accelerator, unwrap_model = train_util.prepare_accelerator(args)

    # prepare dtypes for mixed precision and cast as needed
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    # load the models
    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)

    # verify load/save model formats
    if load_stable_diffusion_format:
        src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
        src_diffusers_model_path = None
    else:
        src_stable_diffusion_ckpt = None
        src_diffusers_model_path = args.pretrained_model_name_or_path

    if args.save_model_as is None:
        save_stable_diffusion_format = load_stable_diffusion_format
        use_safetensors = args.use_safetensors
    else:
        save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
        use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())

    # integrate xformers or memory-efficient attention into the model
    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

    # prepare for training
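    # Encode all training images to latents in a single pass; the VAE is then moved
    # back to the CPU so its VRAM is freed for the U-Net and Text Encoder.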
    if cache_latents:
        vae.to(accelerator.device, dtype=weight_dtype)
        vae.requires_grad_(False)
        vae.eval()
        with torch.no_grad():
            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
        vae.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        accelerator.wait_for_everyone()

    # prepare for training: put the models into the proper state
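    # --stop_text_encoder_training semantics: None or a value >= 0 keeps the Text
    # Encoder trainable for now (it may be frozen mid-run in the loop below), while
    # a negative value disables Text Encoder training from the start.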
    train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
    unet.requires_grad_(True)  # set explicitly, just in case
    text_encoder.requires_grad_(train_text_encoder)
    if not train_text_encoder:
        print("Text Encoder is not trained.")

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    if not cache_latents:
        vae.requires_grad_(False)
        vae.eval()
        vae.to(accelerator.device, dtype=weight_dtype)

    # prepare the classes needed for training
    print("prepare optimizer, data loader etc.")
    if train_text_encoder:
        trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters())
    else:
        trainable_params = unet.parameters()

    _, _, optimizer = train_util.get_optimizer(args, trainable_params)

    # prepare the DataLoader
    # number of DataLoader worker processes: 0 means loading runs in the main process
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, but at most the specified number
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collater,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # compute the number of training steps
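    # One optimizer step consumes num_processes * gradient_accumulation_steps
    # dataloader batches, hence both divisions when converting epochs to steps.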
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * math.ceil(
            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
        )
        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

    # also send the max training steps to the dataset side
    train_dataset_group.set_max_train_steps(args.max_train_steps)

    if args.stop_text_encoder_training is None:
        args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until the end

    # prepare the lr scheduler  TODO: the handling of gradient_accumulation_steps may be off here; check later
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # experimental feature: train in fp16 including gradients; cast the entire model to fp16
    if args.full_fp16:
        assert (
            args.mixed_precision == "fp16"
        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
        print("enable full fp16 training.")
        unet.to(weight_dtype)
        text_encoder.to(weight_dtype)

    # the accelerator supposedly takes care of the rest
    if train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

    # transform DDP after prepare
    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)

    if not train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)  # to avoid 'cpu' vs 'cuda' error

    # experimental feature: train in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)

    # resume training if specified
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    # compute the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

    # start training
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    print("running training / 学習開始")
    print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
    print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
    print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    print(f"  num epochs / epoch数: {num_train_epochs}")
    print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
    print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
    print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0
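
    # The DDPM schedule below (scaled_linear betas from 0.00085 to 0.012 over 1000
    # timesteps) matches the schedule Stable Diffusion was trained with.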
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
    )

    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)

    loss_list = []
    loss_total = 0.0

    for epoch in range(num_train_epochs):
        print(f"epoch {epoch+1}/{num_train_epochs}")
        current_epoch.value = epoch + 1

        # train the Text Encoder only up to the specified step: state at the start of the epoch
        unet.train()
        # train==True is required to enable gradient_checkpointing
        if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
            text_encoder.train()

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step
            # stop Text Encoder training at the specified step
            if global_step == args.stop_text_encoder_training:
                print(f"stop text encoder training at step {global_step}")
                if not args.gradient_checkpointing:
                    text_encoder.train(False)
                text_encoder.requires_grad_(False)

            with accelerator.accumulate(unet):
                with torch.no_grad():
                    # convert images to latents
                    if cache_latents:
                        latents = batch["latents"].to(accelerator.device)
                    else:
                        latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                    # 0.18215 is the latent scaling factor of the Stable Diffusion VAE
                    latents = latents * 0.18215
                b_size = latents.shape[0]

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents, device=latents.device)
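                # Noise offset adds a per-sample, per-channel constant to the noise so the
                # model can learn global brightness shifts (see the linked CrossLabs post);
                # pyramid (multires) noise instead mixes in downscaled noise octaves.
                # The elif makes the two options mutually exclusive.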
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
                elif args.multires_noise_iterations:
                    noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)

                # Get the text embedding for conditioning
                with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
                    if args.weighted_captions:
                        encoder_hidden_states = get_weighted_text_embeddings(
                            tokenizer,
                            text_encoder,
                            batch["captions"],
                            accelerator.device,
                            args.max_token_length // 75 if args.max_token_length else 1,
                            clip_skip=args.clip_skip,
                        )
                    else:
                        input_ids = batch["input_ids"].to(accelerator.device)
                        encoder_hidden_states = train_util.get_hidden_states(
                            args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
                        )

                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Predict the noise residual
                with accelerator.autocast():
                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
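
                # v-parameterization (Salimans & Ho, "Progressive Distillation for Fast
                # Sampling of Diffusion Models", arXiv:2202.00512): the target becomes
                # v = sqrt(alphas_cumprod) * noise - sqrt(1 - alphas_cumprod) * latents,
                # as used by v-prediction models such as Stable Diffusion 2.x (768-v).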
                if args.v_parameterization:
                    # v-parameterization training
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    target = noise

                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean([1, 2, 3])

                loss_weights = batch["loss_weights"]  # per-sample weight
                loss = loss * loss_weights
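
                # Min-SNR weighting (arXiv:2303.09556) scales each sample's loss by
                # min(SNR_t, gamma) / SNR_t so that low-noise (high-SNR) timesteps do
                # not dominate the objective.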
                if args.min_snr_gamma:
                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)

                loss = loss.mean()  # already a mean over the batch, so no division by batch_size is needed

                accelerator.backward(loss)
                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                    # clip only when gradients are actually synced (end of an accumulation cycle)
                    if train_text_encoder:
                        params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
                    else:
                        params_to_clip = unet.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                train_util.sample_images(
                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
                )

                # save the model every specified number of steps
                if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
                        train_util.save_sd_model_on_epoch_end_or_stepwise(
                            args,
                            False,
                            accelerator,
                            src_path,
                            save_stable_diffusion_format,
                            use_safetensors,
                            save_dtype,
                            epoch,
                            num_train_epochs,
                            global_step,
                            unwrap_model(text_encoder),
                            unwrap_model(unet),
                            vae,
                        )

            current_loss = loss.detach().item()
            if args.logging_dir is not None:
                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
                if args.optimizer_type.lower() == "dadaptation":  # track the D-Adaptation d*lr value
                    logs["lr/d*lr"] = (
                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                    )
                accelerator.log(logs, step=global_step)
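
            # Rolling average of the loss over the most recent epoch's worth of steps:
            # epoch 0 fills the window; later epochs replace the entry at the same step index.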
            if epoch == 0:
                loss_list.append(current_loss)
            else:
                loss_total -= loss_list[step]
                loss_list[step] = current_loss
            loss_total += current_loss
            avr_loss = loss_total / len(loss_list)
            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if args.logging_dir is not None:
            logs = {"loss/epoch": loss_total / len(loss_list)}
            accelerator.log(logs, step=epoch + 1)

        accelerator.wait_for_everyone()

        # save the model at the end of the epoch if requested
        if args.save_every_n_epochs is not None:
            if accelerator.is_main_process:
                # whether this epoch actually gets saved is decided inside the util function
                src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
                train_util.save_sd_model_on_epoch_end_or_stepwise(
                    args,
                    True,
                    accelerator,
                    src_path,
                    save_stable_diffusion_format,
                    use_safetensors,
                    save_dtype,
                    epoch,
                    num_train_epochs,
                    global_step,
                    unwrap_model(text_encoder),
                    unwrap_model(unet),
                    vae,
                )

        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

    is_main_process = accelerator.is_main_process
    if is_main_process:
        unet = unwrap_model(unet)
        text_encoder = unwrap_model(text_encoder)

    accelerator.end_training()

    if args.save_state and is_main_process:
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator  # delete this here because memory is needed afterwards

    if is_main_process:
        src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
        train_util.save_sd_model_on_train_end(
            args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
        )
        print("model saved.")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, True, False, True)
    train_util.add_training_arguments(parser, True)
    train_util.add_sd_saving_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    custom_train_functions.add_custom_train_arguments(parser)

    parser.add_argument(
        "--no_token_padding",
        action="store_true",
        help="disable token padding (same as Diffusers' DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)",
    )
    parser.add_argument(
        "--stop_text_encoder_training",
        type=int,
        default=None,
        help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
    )

    return parser
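

# Illustrative invocation (the two flags above are defined in this script; the rest
# come from the shared train_util/config_util argument helpers, so exact names may
# vary with the library version):
#
#   accelerate launch train_db.py \
#     --pretrained_model_name_or_path=<base SD checkpoint or Diffusers dir> \
#     --train_data_dir=<instance images> --reg_data_dir=<regularization images> \
#     --output_dir=<save dir> --max_train_steps=1600 --stop_text_encoder_training=800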


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    train(args)