# train_textual_inversion.py

import importlib
import argparse
import gc
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm

import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler

import library.train_util as train_util
import library.huggingface_util as huggingface_util
import library.config_util as config_util
from library.config_util import (
    ConfigSanitizer,
    BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, pyramid_noise_like
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]
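
# The two template sets above are the object/style prompt templates from the
# original Textual Inversion method; one of them replaces dataset captions when
# --use_object_template or --use_style_template is passed (see train() below).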
def train(args):
    if args.output_name is None:
        args.output_name = args.token_string
    use_template = args.use_object_template or args.use_style_template

    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)

    cache_latents = args.cache_latents

    if args.seed is not None:
        set_seed(args.seed)

    tokenizer = train_util.load_tokenizer(args)

    # prepare accelerator
    print("prepare accelerator")
    accelerator, unwrap_model = train_util.prepare_accelerator(args)

    # prepare dtypes for mixed precision and cast as needed
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    # load the models
    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)

    # Convert the init_word to token_id
    if args.init_word is not None:
        init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
        if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
            print(
                f"token length of init words does not match num_vectors_per_token; init words will be repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
            )
    else:
        init_token_ids = None

    # add new words to the tokenizer; count is num_vectors_per_token
    token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
    num_added_tokens = tokenizer.add_tokens(token_strings)
    assert (
        num_added_tokens == args.num_vectors_per_token
    ), f"tokenizer already contains the token string, please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"

    token_ids = tokenizer.convert_tokens_to_ids(token_strings)
    print(f"tokens are added: {token_ids}")
    assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, "token ids are not ordered"
    assert len(tokenizer) - 1 == token_ids[-1], f"token ids do not end the tokenizer vocabulary: {len(tokenizer)}"
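    # e.g. --token_string "mychar" with --num_vectors_per_token 3 adds "mychar",
    # "mychar1", "mychar2" as consecutive ids at the end of the vocabulary
    # ("mychar" is an illustrative value)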

    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Initialise the newly added placeholder tokens with the embeddings of the initializer tokens
    token_embeds = text_encoder.get_input_embeddings().weight.data
    if init_token_ids is not None:
        for i, token_id in enumerate(token_ids):
            token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())

    # load weights
    if args.weights is not None:
        embeddings = load_weights(args.weights)
        assert len(token_ids) == len(
            embeddings
        ), f"num_vectors_per_token does not match the loaded weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
        # print(token_ids, embeddings.size())
        for token_id, embedding in zip(token_ids, embeddings):
            token_embeds[token_id] = embedding
            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
        print("weights loaded")

    print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")

    # prepare the dataset
    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
    if args.dataset_config is not None:
        print(f"Load dataset config from {args.dataset_config}")
        user_config = config_util.load_user_config(args.dataset_config)
        ignored = ["train_data_dir", "reg_data_dir", "in_json"]
        if any(getattr(args, attr) is not None for attr in ignored):
            print(
                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                    ", ".join(ignored)
                )
            )
    else:
        use_dreambooth_method = args.in_json is None
        if use_dreambooth_method:
            print("Use DreamBooth method.")
            user_config = {
                "datasets": [
                    {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
                ]
            }
        else:
            print("Train with captions.")
            user_config = {
                "datasets": [
                    {
                        "subsets": [
                            {
                                "image_dir": args.train_data_dir,
                                "metadata_file": args.in_json,
                            }
                        ]
                    }
                ]
            }

    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)

    # make captions: a very crude implementation that rewrites captions to the
    # string "tokenstring tokenstring1 tokenstring2 ... tokenstringN"
    if use_template:
        print(f"use template for training captions. is object: {args.use_object_template}")
        templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
        replace_to = " ".join(token_strings)
        captions = []
        for tmpl in templates:
            captions.append(tmpl.format(replace_to))
        train_dataset_group.add_replacement("", captions)

        if args.num_vectors_per_token > 1:
            prompt_replacement = (args.token_string, replace_to)
        else:
            prompt_replacement = None
    else:
        if args.num_vectors_per_token > 1:
            replace_to = " ".join(token_strings)
            train_dataset_group.add_replacement(args.token_string, replace_to)
            prompt_replacement = (args.token_string, replace_to)
        else:
            prompt_replacement = None
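    # e.g. with num_vectors_per_token=3 and token_string "mychar" (illustrative),
    # "mychar" in captions and sample prompts is expanded to "mychar mychar1 mychar2"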

    if args.debug_dataset:
        train_util.debug_dataset(train_dataset_group, show_input_ids=True)
        return
    if len(train_dataset_group) == 0:
        print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
        return

    if cache_latents:
        assert (
            train_dataset_group.is_latent_cacheable()
        ), "when caching latents, color_aug and random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

    # incorporate xformers or memory efficient attention into the model
    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

    # prepare for training
    if cache_latents:
        vae.to(accelerator.device, dtype=weight_dtype)
        vae.requires_grad_(False)
        vae.eval()
        with torch.no_grad():
            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
        vae.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        accelerator.wait_for_everyone()

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    # prepare the classes needed for training
    print("prepare optimizer, data loader etc.")
    trainable_params = text_encoder.get_input_embeddings().parameters()
    _, _, optimizer = train_util.get_optimizer(args, trainable_params)

    # prepare the DataLoader
    # number of DataLoader worker processes: 0 means the main process
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified value
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collater,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )
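    # batch_size=1 here because the dataset group already yields whole (bucketed)
    # batches; the effective batch size is controlled by --train_batch_size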

    # compute the number of training steps
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * math.ceil(
            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
        )
        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

    # send the max train steps to the dataset side as well
    train_dataset_group.set_max_train_steps(args.max_train_steps)

    # prepare the lr scheduler
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # accelerator takes care of device placement and distributed wrapping
    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    # transform DDP after prepare
    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)

    index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
    # print(len(index_no_updates), torch.sum(index_no_updates))
    orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
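    # index_no_updates masks every embedding row below the first newly added token
    # id; after each optimizer step those rows are restored from orig_embeds_params,
    # so only the new token embeddings are actually trained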

    # Freeze all parameters except for the token embeddings in the text encoder
    text_encoder.requires_grad_(True)
    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
    # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)

    unet.requires_grad_(False)
    unet.to(accelerator.device, dtype=weight_dtype)
    if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
        unet.train()
    else:
        unet.eval()

    if not cache_latents:
        vae.requires_grad_(False)
        vae.eval()
        vae.to(accelerator.device, dtype=weight_dtype)

    # experimental feature: fp16 training including gradients
    # patch PyTorch to enable grad scaling in fp16
    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)
        text_encoder.to(weight_dtype)

    # resume training if specified
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    # compute the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

    # train
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    print("running training / 学習開始")
    print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
    print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
    print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    print(f"  num epochs / epoch数: {num_train_epochs}")
    print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
    print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
    print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
    )
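    # these scheduler parameters match the noise schedule Stable Diffusion v1 was
    # trained with (scaled-linear betas from 0.00085 to 0.012 over 1000 timesteps)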

    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)

    # functions for saving/removing checkpoints
    def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
        os.makedirs(args.output_dir, exist_ok=True)
        ckpt_file = os.path.join(args.output_dir, ckpt_name)

        print(f"saving checkpoint: {ckpt_file}")
        save_weights(ckpt_file, embs, save_dtype)
        if args.huggingface_repo_id is not None:
            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)

    def remove_model(old_ckpt_name):
        old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
        if os.path.exists(old_ckpt_file):
            print(f"removing old checkpoint: {old_ckpt_file}")
            os.remove(old_ckpt_file)

    # training loop
    for epoch in range(num_train_epochs):
        print(f"epoch {epoch+1}/{num_train_epochs}")
        current_epoch.value = epoch + 1

        text_encoder.train()

        loss_total = 0

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step
            with accelerator.accumulate(text_encoder):
                with torch.no_grad():
                    if "latents" in batch and batch["latents"] is not None:
                        latents = batch["latents"].to(accelerator.device)
                    else:
                        # encode images into latents
                        latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                        latents = latents * 0.18215  # scaling factor of the SD latent space
                b_size = latents.shape[0]

                # Get the text embedding for conditioning
                input_ids = batch["input_ids"].to(accelerator.device)
                # use float instead of fp16/bf16 because text encoder is float
                encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents, device=latents.device)
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
                elif args.multires_noise_iterations:
                    noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)

                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
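                # i.e. noisy_latents = sqrt(alpha_bar_t) * latents + sqrt(1 - alpha_bar_t) * noise,
                # where alpha_bar_t is the cumulative schedule coefficient at each sampled timestep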

                # Predict the noise residual
                with accelerator.autocast():
                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                if args.v_parameterization:
                    # v-parameterization training
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    target = noise
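                # for v-parameterization the target is the velocity
                # v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents,
                # which is what DDPMScheduler.get_velocity computes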

                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean([1, 2, 3])

                if args.min_snr_gamma:
                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)

                loss_weights = batch["loss_weights"]  # per-sample weight
                loss = loss * loss_weights

                loss = loss.mean()  # already a mean, so no need to divide by batch_size

                accelerator.backward(loss)
                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                    params_to_clip = text_encoder.get_input_embeddings().parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

                # Let's make sure we don't update any embedding weights besides the newly added token
                with torch.no_grad():
                    unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
                        index_no_updates
                    ]

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                train_util.sample_images(
                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
                )

                # save the model at the specified step interval
                if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()

                        ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
                        save_model(ckpt_name, updated_embs, global_step, epoch)

                        if args.save_state:
                            train_util.save_and_remove_state_stepwise(args, accelerator, global_step)

                        remove_step_no = train_util.get_remove_step_no(args, global_step)
                        if remove_step_no is not None:
                            remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
                            remove_model(remove_ckpt_name)

            current_loss = loss.detach().item()
            if args.logging_dir is not None:
                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
                if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
                    logs["lr/d*lr"] = (
                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                    )
                accelerator.log(logs, step=global_step)

            loss_total += current_loss
            avr_loss = loss_total / (step + 1)
            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if args.logging_dir is not None:
            logs = {"loss/epoch": loss_total / len(train_dataloader)}
            accelerator.log(logs, step=epoch + 1)

        accelerator.wait_for_everyone()

        updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()

        if args.save_every_n_epochs is not None:
            saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
            if accelerator.is_main_process and saving:
                ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
                save_model(ckpt_name, updated_embs, global_step, epoch + 1)

                remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
                if remove_epoch_no is not None:
                    remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
                    remove_model(remove_ckpt_name)

                if args.save_state:
                    train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)

        train_util.sample_images(
            accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
        )

        # end of epoch

    is_main_process = accelerator.is_main_process
    if is_main_process:
        text_encoder = unwrap_model(text_encoder)

    accelerator.end_training()

    if args.save_state and is_main_process:
        train_util.save_state_on_train_end(args, accelerator)

    updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()

    del accelerator  # delete this because memory is needed afterwards

    if is_main_process:
        ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
        save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)

        print("model saved.")

def save_weights(file, updated_embs, save_dtype):
    state_dict = {"emb_params": updated_embs}

    if save_dtype is not None:
        for key in list(state_dict.keys()):
            v = state_dict[key]
            v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import save_file

        save_file(state_dict, file)
    else:
        torch.save(state_dict, file)  # can be loaded in Web UI
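
# the saved file holds a single "emb_params" tensor of shape
# (num_vectors_per_token, embedding_dim); load_weights below reads this layout
# as well as the Web UI's "string_to_param" format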

def load_weights(file):
    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import load_file

        data = load_file(file)
    else:
        # compatible with the Web UI's file format
        data = torch.load(file, map_location="cpu")

    if type(data) != dict:
        raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")

    if "string_to_param" in data:  # textual inversion embeddings
        data = data["string_to_param"]
        if hasattr(data, "_parameters"):  # support old PyTorch?
            data = getattr(data, "_parameters")

    emb = next(iter(data.values()))

    if type(emb) != torch.Tensor:
        raise ValueError(f"weight file does not contain Tensor / 重みファイルのデータがTensorではありません: {file}")

    if len(emb.size()) == 1:
        emb = emb.unsqueeze(0)

    return emb

def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, True, True, False)
    train_util.add_training_arguments(parser, True)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    custom_train_functions.add_custom_train_arguments(parser, False)

    parser.add_argument(
        "--save_model_as",
        type=str,
        default="pt",
        choices=[None, "ckpt", "pt", "safetensors"],
        help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)",
    )

    parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
    parser.add_argument(
        "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
    )
    parser.add_argument(
        "--token_string",
        type=str,
        default=None,
        help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
    )
    parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルの初期化に使用する単語、複数可")
    parser.add_argument(
        "--use_object_template",
        action="store_true",
        help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
    )
    parser.add_argument(
        "--use_style_template",
        action="store_true",
        help="ignore caption and use default templates for style / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
    )

    return parser

if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    train(args)
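
# Example invocation (illustrative values; model and data paths are assumptions,
# not part of this script):
#   accelerate launch train_textual_inversion.py \
#     --pretrained_model_name_or_path=model.safetensors \
#     --train_data_dir=train_images --output_dir=output \
#     --token_string=mychar --init_word=girl --num_vectors_per_token=4 \
#     --max_train_steps=1600 --save_model_as=safetensors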