train_textual_inversion_XTI.py

import importlib
import argparse
import gc
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm

import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler

import library.train_util as train_util
import library.huggingface_util as huggingface_util
import library.config_util as config_util
from library.config_util import (
    ConfigSanitizer,
    BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, pyramid_noise_like
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI

imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]


def train(args):
    if args.output_name is None:
        args.output_name = args.token_string
    use_template = args.use_object_template or args.use_style_template

    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)

    if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
        print(
            "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
        )

    cache_latents = args.cache_latents

    if args.seed is not None:
        set_seed(args.seed)

    tokenizer = train_util.load_tokenizer(args)

    # prepare the accelerator
    print("prepare accelerator")
    accelerator, unwrap_model = train_util.prepare_accelerator(args)

    # prepare dtypes for mixed precision and cast as needed
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    # load the model
    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
    # Convert the init_word to token_id
    if args.init_word is not None:
        init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
        if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
            print(
                f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
            )
    else:
        init_token_ids = None

    # add new word to tokenizer, count is num_vectors_per_token
    token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
    num_added_tokens = tokenizer.add_tokens(token_strings)
    assert (
        num_added_tokens == args.num_vectors_per_token
    ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"

    token_ids = tokenizer.convert_tokens_to_ids(token_strings)
    print(f"tokens are added: {token_ids}")
    assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
    assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
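
    # XTI (Extended Textual Inversion / P+) learns a separate embedding for each U-Net
    # cross-attention layer, so every token string gets one suffixed copy per layer in
    # the list below (16 layers in total).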
    token_strings_XTI = []
    XTI_layers = [
        "IN01",
        "IN02",
        "IN04",
        "IN05",
        "IN07",
        "IN08",
        "MID",
        "OUT03",
        "OUT04",
        "OUT05",
        "OUT06",
        "OUT07",
        "OUT08",
        "OUT09",
        "OUT10",
        "OUT11",
    ]
    for layer_name in XTI_layers:
        token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]

    tokenizer.add_tokens(token_strings_XTI)
    token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
    print(f"tokens are added (XTI): {token_ids_XTI}")

    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder.resize_token_embeddings(len(tokenizer))
    # Initialise the newly added placeholder token with the embeddings of the initializer token
    token_embeds = text_encoder.get_input_embeddings().weight.data
    if init_token_ids is not None:
        for i, token_id in enumerate(token_ids_XTI):
            token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())

    # load weights
    if args.weights is not None:
        embeddings = load_weights(args.weights)
        assert len(token_ids_XTI) == len(
            embeddings
        ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
        # print(token_ids, embeddings.size())
        for token_id, embedding in zip(token_ids_XTI, embeddings):
            token_embeds[token_id] = embedding
            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
        print("weights loaded")

    print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
    # prepare the dataset
    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
    if args.dataset_config is not None:
        print(f"Load dataset config from {args.dataset_config}")
        user_config = config_util.load_user_config(args.dataset_config)
        ignored = ["train_data_dir", "reg_data_dir", "in_json"]
        if any(getattr(args, attr) is not None for attr in ignored):
            print(
                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                    ", ".join(ignored)
                )
            )
    else:
        use_dreambooth_method = args.in_json is None
        if use_dreambooth_method:
            print("Use DreamBooth method.")
            user_config = {
                "datasets": [
                    {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
                ]
            }
        else:
            print("Train with captions.")
            user_config = {
                "datasets": [
                    {
                        "subsets": [
                            {
                                "image_dir": args.train_data_dir,
                                "metadata_file": args.in_json,
                            }
                        ]
                    }
                ]
            }

    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
    # make captions: a very crude implementation that rewrites the caption into the string "tokenstring tokenstring1 tokenstring2 ... tokenstringn"
    if use_template:
        print(f"use template for training captions. is object: {args.use_object_template}")
        templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
        replace_to = " ".join(token_strings)
        captions = []
        for tmpl in templates:
            captions.append(tmpl.format(replace_to))
        train_dataset_group.add_replacement("", captions)

        if args.num_vectors_per_token > 1:
            prompt_replacement = (args.token_string, replace_to)
        else:
            prompt_replacement = None
    else:
        if args.num_vectors_per_token > 1:
            replace_to = " ".join(token_strings)
            train_dataset_group.add_replacement(args.token_string, replace_to)
            prompt_replacement = (args.token_string, replace_to)
        else:
            prompt_replacement = None

    if args.debug_dataset:
        train_util.debug_dataset(train_dataset_group, show_input_ids=True)
        return
    if len(train_dataset_group) == 0:
        print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
        return

    if cache_latents:
        assert (
            train_dataset_group.is_latent_cacheable()
        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

    # set up xformers or memory efficient attention in the model
    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
    diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
    diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
    diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI

    # prepare for training
    if cache_latents:
        vae.to(accelerator.device, dtype=weight_dtype)
        vae.requires_grad_(False)
        vae.eval()
        with torch.no_grad():
            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
        vae.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        accelerator.wait_for_everyone()
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    # prepare the classes needed for training
    print("prepare optimizer, data loader etc.")
    trainable_params = text_encoder.get_input_embeddings().parameters()
    _, _, optimizer = train_util.get_optimizer(args, trainable_params)

    # prepare the DataLoader
    # number of DataLoader workers: 0 means the main process handles loading
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, but at most the specified number
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collater,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )
    # compute the number of training steps
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * math.ceil(
            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
        )
        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

    # send the number of training steps to the dataset side as well
    train_dataset_group.set_max_train_steps(args.max_train_steps)

    # prepare the lr scheduler
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # accelerator takes care of device placement and wrapping
    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    # transform DDP after prepare
    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
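
    # token ids below token_ids_XTI[0] are the original vocabulary plus the plain placeholder
    # tokens; only the per-layer XTI embeddings above that index are trained, and everything
    # else is restored from orig_embeds_params after every optimizer step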
    index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
    # print(len(index_no_updates), torch.sum(index_no_updates))
    orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()

    # Freeze all parameters except for the token embeddings in text encoder
    text_encoder.requires_grad_(True)
    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
    # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)

    unet.requires_grad_(False)
    unet.to(accelerator.device, dtype=weight_dtype)
    if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
        unet.train()
    else:
        unet.eval()

    if not cache_latents:
        vae.requires_grad_(False)
        vae.eval()
        vae.to(accelerator.device, dtype=weight_dtype)

    # experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)
        text_encoder.to(weight_dtype)
    # resume training if specified
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    # compute the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

    # train
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    print("running training / 学習開始")
    print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
    print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
    print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    print(f" num epochs / epoch数: {num_train_epochs}")
    print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
    print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
    print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
    )
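    # (the betas above are the standard Stable Diffusion v1.x training schedule: scaled_linear
    # from 0.00085 to 0.012 over 1000 timesteps)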

    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)

    # function for saving/removing
    def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
        os.makedirs(args.output_dir, exist_ok=True)
        ckpt_file = os.path.join(args.output_dir, ckpt_name)

        print(f"saving checkpoint: {ckpt_file}")
        save_weights(ckpt_file, embs, save_dtype)
        if args.huggingface_repo_id is not None:
            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)

    def remove_model(old_ckpt_name):
        old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
        if os.path.exists(old_ckpt_file):
            print(f"removing old checkpoint: {old_ckpt_file}")
            os.remove(old_ckpt_file)
    # training loop
    for epoch in range(num_train_epochs):
        print(f"epoch {epoch+1}/{num_train_epochs}")
        current_epoch.value = epoch + 1

        text_encoder.train()

        loss_total = 0

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step
            with accelerator.accumulate(text_encoder):
                with torch.no_grad():
                    if "latents" in batch and batch["latents"] is not None:
                        latents = batch["latents"].to(accelerator.device)
                    else:
                        # convert images to latents
                        latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                    latents = latents * 0.18215
                b_size = latents.shape[0]

                # Get the text embedding for conditioning
                input_ids = batch["input_ids"].to(accelerator.device)
                # use float instead of fp16/bf16 because text encoder is float
                encoder_hidden_states = torch.stack(
                    [
                        train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype)
                        for s in torch.split(input_ids, 1, dim=1)
                    ]
                )
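                # XTI uses a separate conditioning per U-Net block: the dataset emits one token
                # sequence per layer, and the stacked hidden states are consumed per cross-attention
                # layer by the hijacked forwards (unet_forward_XTI and the *block_forward_XTI)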

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents, device=latents.device)
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
                elif args.multires_noise_iterations:
                    noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)

                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Predict the noise residual
                with accelerator.autocast():
                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample

                if args.v_parameterization:
                    # v-parameterization training
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    target = noise

                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean([1, 2, 3])

                if args.min_snr_gamma:
                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)

                loss_weights = batch["loss_weights"]  # per-sample weights
                loss = loss * loss_weights

                loss = loss.mean()  # already the mean, so there is no need to divide by batch_size

                accelerator.backward(loss)
                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                    params_to_clip = text_encoder.get_input_embeddings().parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

                # Let's make sure we don't update any embedding weights besides the newly added token
                with torch.no_grad():
                    unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
                        index_no_updates
                    ]
            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                # TODO: fix sample_images
                # train_util.sample_images(
                #     accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
                # )

                # save the model every specified number of steps
                if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()

                        ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
                        save_model(ckpt_name, updated_embs, global_step, epoch)

                        if args.save_state:
                            train_util.save_and_remove_state_stepwise(args, accelerator, global_step)

                        remove_step_no = train_util.get_remove_step_no(args, global_step)
                        if remove_step_no is not None:
                            remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
                            remove_model(remove_ckpt_name)

            current_loss = loss.detach().item()
            if args.logging_dir is not None:
                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
                if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
                    logs["lr/d*lr"] = (
                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                    )
                accelerator.log(logs, step=global_step)

            loss_total += current_loss
            avr_loss = loss_total / (step + 1)
            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if args.logging_dir is not None:
            logs = {"loss/epoch": loss_total / len(train_dataloader)}
            accelerator.log(logs, step=epoch + 1)

        accelerator.wait_for_everyone()

        updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()

        if args.save_every_n_epochs is not None:
            saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
            if accelerator.is_main_process and saving:
                ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
                save_model(ckpt_name, updated_embs, global_step, epoch + 1)

                remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
                if remove_epoch_no is not None:
                    remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
                    remove_model(remove_ckpt_name)

                if args.save_state:
                    train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)

        # TODO: fix sample_images
        # train_util.sample_images(
        #     accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
        # )

        # end of epoch

    is_main_process = accelerator.is_main_process
    if is_main_process:
        text_encoder = unwrap_model(text_encoder)

    accelerator.end_training()

    if args.save_state and is_main_process:
        train_util.save_state_on_train_end(args, accelerator)

    updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()

    del accelerator  # delete the accelerator because memory is needed after this point

    if is_main_process:
        ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
        save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)

        print("model saved.")
def save_weights(file, updated_embs, save_dtype):
    updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1])
    updated_embs = updated_embs.chunk(16)
    XTI_layers = [
        "IN01",
        "IN02",
        "IN04",
        "IN05",
        "IN07",
        "IN08",
        "MID",
        "OUT03",
        "OUT04",
        "OUT05",
        "OUT06",
        "OUT07",
        "OUT08",
        "OUT09",
        "OUT10",
        "OUT11",
    ]
    state_dict = {}
    for i, layer_name in enumerate(XTI_layers):
        state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype)

    # if save_dtype is not None:
    #     for key in list(state_dict.keys()):
    #         v = state_dict[key]
    #         v = v.detach().clone().to("cpu").to(save_dtype)
    #         state_dict[key] = v

    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import save_file

        save_file(state_dict, file)
    else:
        torch.save(state_dict, file)  # can be loaded in Web UI


def load_weights(file):
    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import load_file

        data = load_file(file)
    else:
        raise ValueError(f"NOT XTI: {file}")

    if len(data.values()) != 16:
        raise ValueError(f"NOT XTI: {file}")

    emb = torch.concat([x for x in data.values()])

    return emb


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, True, True, False)
    train_util.add_training_arguments(parser, True)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    custom_train_functions.add_custom_train_arguments(parser, False)

    parser.add_argument(
        "--save_model_as",
        type=str,
        default="pt",
        choices=[None, "ckpt", "pt", "safetensors"],
        help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)",
    )

    parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
    parser.add_argument(
        "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
    )
    parser.add_argument(
        "--token_string",
        type=str,
        default=None,
        help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
    )
    parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
    parser.add_argument(
        "--use_object_template",
        action="store_true",
        help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
    )
    parser.add_argument(
        "--use_style_template",
        action="store_true",
        help="ignore caption and use default templates for style / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
    )

    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    train(args)
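
# Example invocation (illustrative only; paths are placeholders, and most flags are added by the
# shared train_util/config_util argument helpers, so check --help for the exact names in your version):
#
#   accelerate launch train_textual_inversion_XTI.py \
#       --pretrained_model_name_or_path=<model file or Diffusers dir> \
#       --train_data_dir=<image dir> --output_dir=<output dir> \
#       --token_string=mytoken --init_word=girl --num_vectors_per_token=1 \
#       --max_train_steps=1500 --save_model_as=safetensors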