fine_tune.py

# training with captions
# XXX dropped option: hypernetwork training

import argparse
import gc
import math
import os

import toml
from multiprocessing import Value
from tqdm import tqdm

import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler

import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
    ConfigSanitizer,
    BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like


def train(args):
    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)

    cache_latents = args.cache_latents

    if args.seed is not None:
        set_seed(args.seed)  # initialize the random seed

    tokenizer = train_util.load_tokenizer(args)

    blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
    if args.dataset_config is not None:
        print(f"Load dataset config from {args.dataset_config}")
        user_config = config_util.load_user_config(args.dataset_config)
        ignored = ["train_data_dir", "in_json"]
        if any(getattr(args, attr) is not None for attr in ignored):
            print(
                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                    ", ".join(ignored)
                )
            )
    else:
        user_config = {
            "datasets": [
                {
                    "subsets": [
                        {
                            "image_dir": args.train_data_dir,
                            "metadata_file": args.in_json,
                        }
                    ]
                }
            ]
        }

    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)

    if args.debug_dataset:
        train_util.debug_dataset(train_dataset_group)
        return
    if len(train_dataset_group) == 0:
        print(
            "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
        )
        return

    if cache_latents:
        assert (
            train_dataset_group.is_latent_cacheable()
        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
    # prepare the accelerator
    print("prepare accelerator")
    accelerator, unwrap_model = train_util.prepare_accelerator(args)

    # prepare dtypes that match the mixed precision setting and cast as appropriate
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    # load the models
    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)

    # verify load/save model formats
    if load_stable_diffusion_format:
        src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
        src_diffusers_model_path = None
    else:
        src_stable_diffusion_ckpt = None
        src_diffusers_model_path = args.pretrained_model_name_or_path

    if args.save_model_as is None:
        save_stable_diffusion_format = load_stable_diffusion_format
        use_safetensors = args.use_safetensors
    else:
        save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
        use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())

    # function to set the Diffusers xformers flag
    def set_diffusers_xformers_flag(model, valid):
        # model.set_use_memory_efficient_attention_xformers(valid)  # likely to be removed in an upcoming release
        # the pipeline walks its modules recursively and calls set_use_memory_efficient_attention_xformers for us,
        # but there is no such helper when only the U-Net is used, so the same recursive walk is copied here
        # (diffusers 0.10.2 went back to setting the flag per module)

        # Recursively walk through all the children.
        # Any children which exposes the set_use_memory_efficient_attention_xformers method
        # gets the message
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        fn_recursive_set_mem_eff(model)

    # enable xformers / memory efficient attention in the model
    if args.diffusers_xformers:
        print("Use xformers by Diffusers")
        set_diffusers_xformers_flag(unet, True)
    else:
        # the Windows build of xformers cannot train in float, so a configuration without xformers must remain possible
        print("Disable Diffusers' xformers")
        set_diffusers_xformers_flag(unet, False)
        train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
    # prepare for training
    if cache_latents:
        vae.to(accelerator.device, dtype=weight_dtype)
        vae.requires_grad_(False)
        vae.eval()
        with torch.no_grad():
            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
        vae.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        accelerator.wait_for_everyone()
    # prepare for training: put the models into the proper state
    training_models = []
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
    training_models.append(unet)

    if args.train_text_encoder:
        print("enable text encoder training")
        if args.gradient_checkpointing:
            text_encoder.gradient_checkpointing_enable()
        training_models.append(text_encoder)
    else:
        text_encoder.to(accelerator.device, dtype=weight_dtype)
        text_encoder.requires_grad_(False)  # the text encoder is not trained
        if args.gradient_checkpointing:
            text_encoder.gradient_checkpointing_enable()
            text_encoder.train()  # required for gradient_checkpointing
        else:
            text_encoder.eval()

    if not cache_latents:
        vae.requires_grad_(False)
        vae.eval()
        vae.to(accelerator.device, dtype=weight_dtype)

    for m in training_models:
        m.requires_grad_(True)
    params = []
    for m in training_models:
        params.extend(m.parameters())
    params_to_optimize = params
    # prepare the classes needed for training
    print("prepare optimizer, data loader etc.")
    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)

    # prepare the dataloader
    # number of DataLoader worker processes: 0 means loading runs in the main process
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, but at most the specified number
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collater,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )
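    # batch_size=1 here is intentional: the dataset group appears to yield already-batched (bucketed)
    # samples, so the effective batch size is governed by --train_batch_size via the dataset config,
    # not by the DataLoader itself.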
    # calculate the number of training steps
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * math.ceil(
            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
        )
        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

    # also send the number of training steps to the dataset side
    train_dataset_group.set_max_train_steps(args.max_train_steps)

    # prepare the lr scheduler
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
    # experimental feature: perform fp16 training including gradients; cast the whole model to fp16
    if args.full_fp16:
        assert (
            args.mixed_precision == "fp16"
        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
        print("enable full fp16 training.")
        unet.to(weight_dtype)
        text_encoder.to(weight_dtype)

    # the accelerator handles device placement and wrapping from here on
    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

    # transform DDP after prepare
    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)

    # experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)

    # resume training if specified
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    # calculate the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
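        # i.e. roughly save_n_epoch_ratio checkpoints end up spread over the whole run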
    # start training
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    print("running training / 学習開始")
    print(f"  num examples / サンプル数: {train_dataset_group.num_train_images}")
    print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    print(f"  num epochs / epoch数: {num_train_epochs}")
    print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
    print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
    print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
    )
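    # these betas (scaled_linear, 0.00085 -> 0.012 over 1000 steps) match the noise schedule
    # Stable Diffusion v1/v2 was trained with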
    if accelerator.is_main_process:
        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)

    for epoch in range(num_train_epochs):
        print(f"epoch {epoch+1}/{num_train_epochs}")
        current_epoch.value = epoch + 1

        for m in training_models:
            m.train()

        loss_total = 0
        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step
            with accelerator.accumulate(training_models[0]):  # does not seem to support multiple models, but this works for now
                with torch.no_grad():
                    if "latents" in batch and batch["latents"] is not None:
                        latents = batch["latents"].to(accelerator.device)  # .to(dtype=weight_dtype)
                    else:
                        # convert images to latents
                        latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                        latents = latents * 0.18215
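                        # 0.18215 is the standard latent scaling factor of the SD VAE,
                        # bringing the latents to roughly unit variance before the U-Net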
                b_size = latents.shape[0]

                with torch.set_grad_enabled(args.train_text_encoder):
                    # Get the text embedding for conditioning
                    if args.weighted_captions:
                        encoder_hidden_states = get_weighted_text_embeddings(
                            tokenizer,
                            text_encoder,
                            batch["captions"],
                            accelerator.device,
                            args.max_token_length // 75 if args.max_token_length else 1,
                            clip_skip=args.clip_skip,
                        )
                    else:
                        input_ids = batch["input_ids"].to(accelerator.device)
                        encoder_hidden_states = train_util.get_hidden_states(
                            args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
                        )

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents, device=latents.device)
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
                elif args.multires_noise_iterations:
                    noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
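                # noise_offset adds a per-sample, per-channel constant shift to the noise (the linked post
                # argues this helps the model learn very dark and very bright images); multires (pyramid)
                # noise instead mixes in progressively downscaled noise octaves weighted by multires_noise_discount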
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Predict the noise residual
                with accelerator.autocast():
                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                if args.v_parameterization:
                    # v-parameterization training
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    target = noise
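                # for v-parameterization the regression target is the velocity
                # v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents (what get_velocity computes);
                # otherwise the target is the added noise itself (epsilon-prediction)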
                if args.min_snr_gamma:
                    # do not mean over batch dimension for snr weight
                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                    loss = loss.mean([1, 2, 3])
                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
                    loss = loss.mean()  # mean over batch dimension
                else:
                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
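                # Min-SNR-gamma weighting (Hang et al., 2023): apply_snr_weight presumably rescales each
                # sample's loss by about min(SNR_t, gamma) / SNR_t, down-weighting low-noise (high-SNR) timesteps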
                accelerator.backward(loss)
                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                    params_to_clip = []
                    for m in training_models:
                        params_to_clip.extend(m.parameters())
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                train_util.sample_images(
                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
                )

                # save the model every specified number of steps
                if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
                        train_util.save_sd_model_on_epoch_end_or_stepwise(
                            args,
                            False,
                            accelerator,
                            src_path,
                            save_stable_diffusion_format,
                            use_safetensors,
                            save_dtype,
                            epoch,
                            num_train_epochs,
                            global_step,
                            unwrap_model(text_encoder),
                            unwrap_model(unet),
                            vae,
                        )
            current_loss = loss.detach().item()  # this is a mean, so the batch size should not matter
            if args.logging_dir is not None:
                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
                if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
                    logs["lr/d*lr"] = (
                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                    )
                accelerator.log(logs, step=global_step)

            # TODO use a moving average instead
            loss_total += current_loss
            avr_loss = loss_total / (step + 1)
            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if args.logging_dir is not None:
            logs = {"loss/epoch": loss_total / len(train_dataloader)}
            accelerator.log(logs, step=epoch + 1)

        accelerator.wait_for_everyone()
        if args.save_every_n_epochs is not None:
            if accelerator.is_main_process:
                src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
                train_util.save_sd_model_on_epoch_end_or_stepwise(
                    args,
                    True,
                    accelerator,
                    src_path,
                    save_stable_diffusion_format,
                    use_safetensors,
                    save_dtype,
                    epoch,
                    num_train_epochs,
                    global_step,
                    unwrap_model(text_encoder),
                    unwrap_model(unet),
                    vae,
                )

        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
    is_main_process = accelerator.is_main_process
    if is_main_process:
        unet = unwrap_model(unet)
        text_encoder = unwrap_model(text_encoder)

    accelerator.end_training()

    if args.save_state and is_main_process:
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator  # delete it because memory is needed for saving afterwards

    if is_main_process:
        src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
        train_util.save_sd_model_on_train_end(
            args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
        )
        print("model saved.")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, False, True, True)
    train_util.add_training_arguments(parser, False)
    train_util.add_sd_saving_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    custom_train_functions.add_custom_train_arguments(parser)

    parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
    parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")

    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    train(args)