addnet_xyz_grid_support.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import os
  2. import os.path
  3. from modules import shared
  4. import modules.scripts as scripts
  5. from scripts import model_util, util
  6. from scripts.model_util import MAX_MODEL_COUNT
  7. LORA_TRAIN_METADATA_NAMES = {
  8. "ss_session_id": "Session ID",
  9. "ss_training_started_at": "Training started at",
  10. "ss_output_name": "Output name",
  11. "ss_learning_rate": "Learning rate",
  12. "ss_text_encoder_lr": "Text encoder LR",
  13. "ss_unet_lr": "UNet LR",
  14. "ss_num_train_images": "# of training images",
  15. "ss_num_reg_images": "# of reg images",
  16. "ss_num_batches_per_epoch": "Batches per epoch",
  17. "ss_num_epochs": "Total epochs",
  18. "ss_epoch": "Epoch",
  19. "ss_batch_size_per_device": "Batch size/device",
  20. "ss_total_batch_size": "Total batch size",
  21. "ss_gradient_checkpointing": "Gradient checkpointing",
  22. "ss_gradient_accumulation_steps": "Gradient accum. steps",
  23. "ss_max_train_steps": "Max train steps",
  24. "ss_lr_warmup_steps": "LR warmup steps",
  25. "ss_lr_scheduler": "LR scheduler",
  26. "ss_network_module": "Network module",
  27. "ss_network_dim": "Network dim",
  28. "ss_network_alpha": "Network alpha",
  29. "ss_mixed_precision": "Mixed precision",
  30. "ss_full_fp16": "Full FP16",
  31. "ss_v2": "V2",
  32. "ss_resolution": "Resolution",
  33. "ss_clip_skip": "Clip skip",
  34. "ss_max_token_length": "Max token length",
  35. "ss_color_aug": "Color aug",
  36. "ss_flip_aug": "Flip aug",
  37. "ss_random_crop": "Random crop",
  38. "ss_shuffle_caption": "Shuffle caption",
  39. "ss_cache_latents": "Cache latents",
  40. "ss_enable_bucket": "Enable bucket",
  41. "ss_min_bucket_reso": "Min bucket reso.",
  42. "ss_max_bucket_reso": "Max bucket reso.",
  43. "ss_seed": "Seed",
  44. "ss_keep_tokens": "Keep tokens",
  45. "ss_dataset_dirs": "Dataset dirs.",
  46. "ss_reg_dataset_dirs": "Reg dataset dirs.",
  47. "ss_sd_model_name": "SD model name",
  48. "ss_vae_name": "VAE name",
  49. "ss_training_comment": "Comment"
  50. }
  51. xy_grid = None # XY Grid module
  52. script_class = None # additional_networks scripts.Script class
  53. axis_params = [{}] * MAX_MODEL_COUNT
  54. def update_axis_params(i, module, model):
  55. axis_params[i] = {"module": module, "model": model}
  56. def get_axis_model_choices(i):
  57. module = axis_params[i].get("module", "None")
  58. model = axis_params[i].get("model", "None")
  59. if module == "LoRA":
  60. if model != "None":
  61. sort_by = shared.opts.data.get("additional_networks_sort_models_by", "name")
  62. return model_util.get_model_list(module, model, "", sort_by)
  63. return [f"select `Model {i+1}` in `Additional Networks`. models in same folder for selected one will be shown here."]
  64. def update_script_args(p, value, arg_idx):
  65. global script_class
  66. for s in scripts.scripts_txt2img.alwayson_scripts:
  67. if isinstance(s, script_class):
  68. args = list(p.script_args)
  69. # print(f"Changed arg {arg_idx} from {args[s.args_from + arg_idx - 1]} to {value}")
  70. args[s.args_from + arg_idx] = value
  71. p.script_args = tuple(args)
  72. break
  73. def confirm_models(p, xs):
  74. for x in xs:
  75. if x in ["", "None"]:
  76. continue
  77. if not model_util.find_closest_lora_model_name(x):
  78. raise RuntimeError(f"Unknown LoRA model: {x}")
  79. def apply_module(p, x, xs, i):
  80. update_script_args(p, True, 0) # set Enabled to True
  81. update_script_args(p, x, 2 + 4 * i) # enabled, separate_weights, ({module}, model, weight_unet, weight_tenc), ...
  82. def apply_model(p, x, xs, i):
  83. name = model_util.find_closest_lora_model_name(x)
  84. update_script_args(p, True, 0)
  85. update_script_args(p, name, 3 + 4 * i) # enabled, separate_weights, (module, {model}, weight_unet, weight_tenc), ...
  86. def apply_weight(p, x, xs, i):
  87. update_script_args(p, True, 0)
  88. update_script_args(p, x, 4 + 4 * i ) # enabled, separate_weights, (module, model, {weight_unet, weight_tenc}), ...
  89. update_script_args(p, x, 5 + 4 * i)
  90. def apply_weight_unet(p, x, xs, i):
  91. update_script_args(p, True, 0)
  92. update_script_args(p, x, 4 + 4 * i) # enabled, separate_weights, (module, model, {weight_unet}, weight_tenc), ...
  93. def apply_weight_tenc(p, x, xs, i):
  94. update_script_args(p, True, 0)
  95. update_script_args(p, x, 5 + 4 * i) # enabled, separate_weights, (module, model, weight_unet, {weight_tenc}), ...
  96. def format_lora_model(p, opt, x):
  97. global xy_grid
  98. model = model_util.find_closest_lora_model_name(x)
  99. if model is None or model.lower() in ["", "none"]:
  100. return "None"
  101. value = xy_grid.format_value(p, opt, model)
  102. model_path = model_util.lora_models.get(model)
  103. metadata = model_util.read_model_metadata(model_path, "LoRA")
  104. if not metadata:
  105. return value
  106. metadata_names = util.split_path_list(shared.opts.data.get("additional_networks_xy_grid_model_metadata", ""))
  107. if not metadata_names:
  108. return value
  109. for name in metadata_names:
  110. name = name.strip()
  111. if name in metadata:
  112. formatted_name = LORA_TRAIN_METADATA_NAMES.get(name, name)
  113. value += f"\n{formatted_name}: {metadata[name]}, "
  114. return value.strip(" ").strip(",")
  115. def initialize(script):
  116. global xy_grid, script_class
  117. xy_grid = None
  118. script_class = script
  119. for scriptDataTuple in scripts.scripts_data:
  120. if os.path.basename(scriptDataTuple.path) == "xy_grid.py" or os.path.basename(scriptDataTuple.path) == "xyz_grid.py":
  121. xy_grid = scriptDataTuple.module
  122. for i in range(MAX_MODEL_COUNT):
  123. model = xy_grid.AxisOption(f"AddNet Model {i+1}", str, lambda p, x, xs, i=i: apply_model(p, x, xs, i), format_lora_model, confirm_models, cost=0.5, choices=lambda i=i: get_axis_model_choices(i))
  124. weight = xy_grid.AxisOption(f"AddNet Weight {i+1}", float, lambda p, x, xs, i=i: apply_weight(p, x, xs, i), xy_grid.format_value_add_label, None, cost=0.5)
  125. weight_unet = xy_grid.AxisOption(f"AddNet UNet Weight {i+1}", float, lambda p, x, xs, i=i: apply_weight_unet(p, x, xs, i), xy_grid.format_value_add_label, None, cost=0.5)
  126. weight_tenc = xy_grid.AxisOption(f"AddNet TEnc Weight {i+1}", float, lambda p, x, xs, i=i: apply_weight_tenc(p, x, xs, i), xy_grid.format_value_add_label, None, cost=0.5)
  127. xy_grid.axis_options.extend([model, weight, weight_unet, weight_tenc])