123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334 |
- import os
- import os.path
- import re
- import shutil
- import json
- import stat
- import tqdm
- from collections import OrderedDict
- from multiprocessing.pool import ThreadPool as Pool
- from modules import shared, sd_models, hashes
- from scripts import safetensors_hack, model_util, util
- import modules.scripts as scripts
# Upper bound read from shared.cmd_opts.addnet_max_model_count; 5 is used when
# that attribute is not present on shared.cmd_opts.
MAX_MODEL_COUNT = getattr(shared.cmd_opts, "addnet_max_model_count", 5)
# File extensions picked up as model files when scanning directories.
LORA_MODEL_EXTS = [".pt", ".ckpt", ".safetensors"]
# Matches a trailing "(xxxxxxxx)" 8-hex-digit legacy hash at the end of a
# (lower-cased) search string; new-style hashes are 12 characters and do not match.
# Raw string: "\(" in a normal literal is an invalid escape sequence.
re_legacy_hash = re.compile(r"\(([0-9a-f]{8})\)$")
lora_models = {}  # "My_Lora(abcdef123456)" -> "C:/path/to/model.safetensors"
lora_model_names = {}  # "my_lora" -> "My_Lora(abcdef123456)"
legacy_model_names = {}  # legacy hash -> "My_Lora(abcdef123456)"
# Default model directory inside this extension's base directory; created
# eagerly (no error if it already exists) so it is present before any scan.
lora_models_dir = os.path.join(scripts.basedir(), "models/lora")
os.makedirs(lora_models_dir, mode=0o777, exist_ok=True)
def is_safetensors(filename):
    """Return True when `filename`'s extension is exactly ".safetensors" (case-sensitive)."""
    _, extension = os.path.splitext(filename)
    return extension == ".safetensors"
def read_model_metadata(model_path, module):
    """Read the metadata embedded in a model file.

    One pair of surrounding double quotes is stripped from `model_path`.
    Returns None when the file does not exist, when `module` is not "LoRA",
    or when the file is not a .safetensors file; otherwise returns whatever
    safetensors_hack.read_metadata gives for the path.
    """
    quoted = model_path.startswith("\"") and model_path.endswith("\"")
    path = model_path[1:-1] if quoted else model_path  # trim '"' at start/end

    if not os.path.exists(path):
        return None
    if module != "LoRA" or os.path.splitext(path)[1] != '.safetensors':
        return None
    return safetensors_hack.read_metadata(path)
def write_model_metadata(model_path, module, updates):
    """Merge `updates` into a LoRA .safetensors file's metadata and save it in place.

    One pair of surrounding double quotes is stripped from `model_path`.
    Nothing happens (and no backup is made) when the file does not exist,
    `module` is not "LoRA", or the file is not a .safetensors file.

    If the "additional_networks_back_up_model_when_saving" option is enabled
    (default True), the file is first copied to "<model_path>.backup" unless
    that backup already exists. Every value in `updates` is stored as str(v).

    Returns None.
    """
    if model_path.startswith("\"") and model_path.endswith("\""):  # trim '"' at start/end
        model_path = model_path[1:-1]
    if not os.path.exists(model_path):
        return None

    # Check the format before importing safetensors or copying the (possibly
    # very large) file: unsupported files are never rewritten, so a backup of
    # them is pointless.
    if module != "LoRA" or os.path.splitext(model_path)[1] != '.safetensors':
        return None

    from safetensors.torch import save_file

    back_up = shared.opts.data.get("additional_networks_back_up_model_when_saving", True)
    if back_up:
        backup_path = model_path + ".backup"
        if not os.path.exists(backup_path):
            print(f"[MetadataEditor] Backing up current model to {backup_path}")
            shutil.copyfile(model_path, backup_path)

    tensors, metadata = safetensors_hack.load_file(model_path, "cpu")
    # NOTE(review): defensive guard in case load_file can yield no metadata
    # mapping; confirm against safetensors_hack.load_file.
    if metadata is None:
        metadata = {}
    for k, v in updates.items():
        metadata[k] = str(v)

    save_file(tensors, model_path, metadata)
    print(f"[MetadataEditor] Model saved: {model_path}")
    return None
def get_model_list(module, model, model_dir, sort_by):
    """Return the "name(hash)" keys of the models found in `model_dir`.

    If `model_dir` is "", the directory of the registered model `model`
    (looked up in `lora_models`) is used instead. `module` is accepted for
    interface compatibility and is not used. Results are ordered by `sort_by`
    (see get_all_models).

    Returns a list; it is empty when the directory cannot be determined or
    does not exist.
    """
    if model_dir == "":
        # Get list of models with same folder as this one
        model_path = lora_models.get(model, None)
        if model_path is None:
            return []
        model_dir = os.path.dirname(model_path)

    if not os.path.isdir(model_dir):
        return []

    found, _ = get_all_models([model_dir], sort_by, "")
    # Materialize the keys so every branch returns a plain list (a dict_keys
    # view is not JSON-serializable and differs from the [] returned above).
    return list(found.keys())
def traverse_all_files(curr_path, model_list, exts=None):
    """Recursively collect model files below `curr_path`.

    Appends a `(path, os.stat_result)` tuple to `model_list` for each file
    whose extension is in `exts` (defaults to LORA_MODEL_EXTS), and returns
    the same list. Directories are always descended into, even when their name
    ends in a model extension. Entries whose stat() fails (e.g. a broken
    symlink or a file removed mid-scan) are skipped instead of aborting the scan.
    """
    if exts is None:
        exts = LORA_MODEL_EXTS

    entries = []
    with os.scandir(curr_path) as it:
        for entry in it:
            try:
                entries.append((os.path.join(curr_path, entry.name), entry.stat()))
            except OSError:
                continue

    for f_info in entries:
        fname, fstat = f_info
        # Check for directories first: a folder named "x.pt" is not a model file.
        if stat.S_ISDIR(fstat.st_mode):
            model_list = traverse_all_files(fname, model_list, exts)
        elif os.path.splitext(fname)[1] in exts:
            model_list.append(f_info)
    return model_list
def get_model_hash(metadata, filename):
    """Return the new-style hash of a model file.

    With `metadata` None (hash_model_file passes None for non-.safetensors
    files) the file is hashed via hashes.calculate_sha256. Otherwise a stored
    "sshs_model_hash" entry wins, falling back to safetensors_hack.hash_file.
    """
    if metadata is not None:
        if "sshs_model_hash" in metadata:
            return metadata["sshs_model_hash"]
        return safetensors_hack.hash_file(filename)
    return hashes.calculate_sha256(filename)
def get_legacy_hash(metadata, filename):
    """Return the legacy (short) hash of a model file.

    With `metadata` None the file is hashed via sd_models.model_hash;
    otherwise a stored "sshs_legacy_hash" entry is returned when present,
    else safetensors_hack.legacy_hash_file is used.
    """
    if metadata is None:
        return sd_models.model_hash(filename)
    return metadata["sshs_legacy_hash"] if "sshs_legacy_hash" in metadata else safetensors_hack.legacy_hash_file(filename)
- import filelock
# Persistent hash cache stored in this extension's base directory.
cache_filename = os.path.join(scripts.basedir(), "hashes.json")
# In-memory cache contents; None until cache() first loads the file.
cache_data = None


def cache(subsection):
    """Return the mutable dict stored under `subsection` of the hash cache.

    On first use the cache is loaded from `cache_filename` while holding a
    file lock. A missing, unreadable, or corrupt file (or one whose top level
    is not a JSON object) is treated as an empty cache, since it only saves
    recomputation. The returned dict is kept in the in-memory cache, so
    changes to it are written out by the next dump_cache().
    """
    global cache_data

    if cache_data is None:
        with filelock.FileLock(cache_filename + ".lock"):
            loaded = {}
            if os.path.isfile(cache_filename):
                try:
                    with open(cache_filename, "r", encoding="utf8") as file:
                        loaded = json.load(file)
                except (OSError, ValueError) as ex:
                    # A truncated or corrupt cache must not stop the extension
                    # from loading; hashes will simply be recomputed.
                    print(f"[AddNet] Ignoring unreadable hash cache {cache_filename}: {ex}")
                    loaded = {}
            if not isinstance(loaded, dict):
                loaded = {}
            cache_data = loaded

    s = cache_data.get(subsection, {})
    cache_data[subsection] = s

    return s
def dump_cache():
    """Write the in-memory hash cache to `cache_filename`.

    The JSON goes to a temporary file first and is then moved over the cache
    with os.replace (atomic on POSIX), so an interrupted write cannot leave a
    truncated cache behind. Does nothing if the cache was never loaded, so an
    existing file is not overwritten with "null".
    """
    if cache_data is None:
        return

    tmp_filename = cache_filename + ".tmp"
    with filelock.FileLock(cache_filename + ".lock"):
        with open(tmp_filename, "w", encoding="utf8") as file:
            json.dump(cache_data, file, indent=4)
        os.replace(tmp_filename, cache_filename)
def get_model_rating(filename):
    """Return the integer user rating stored under "ssmd_rating" in a model's metadata.

    Non-.safetensors files, files without a rating, and ratings that cannot be
    parsed as a number all give 0, so sorting by rating never raises on a
    malformed value. Float-formatted ratings such as "3.0" are truncated to int.
    """
    if not model_util.is_safetensors(filename):
        return 0

    # NOTE(review): `or {}` guards against read_metadata returning None; confirm whether it can.
    metadata = safetensors_hack.read_metadata(filename) or {}
    raw = metadata.get("ssmd_rating", "0")
    try:
        return int(raw)
    except (TypeError, ValueError):
        pass
    try:
        return int(float(raw))
    except (TypeError, ValueError, OverflowError):
        return 0
def has_user_metadata(filename):
    """Return True if a .safetensors file has any metadata key starting with "ssmd_".

    Non-.safetensors files always give False without being opened.
    """
    if model_util.is_safetensors(filename):
        keys = safetensors_hack.read_metadata(filename).keys()
        return any(key.startswith("ssmd_") for key in keys)
    return False
def hash_model_file(finfo):
    """Compute, or fetch from the "hashes" cache, the hashes of one model file.

    `finfo` is a `(path, os.stat_result)` tuple as produced by
    traverse_all_files. A cached entry is reused when its "mtime" equals the
    file's current st_mtime; otherwise the hashes are recomputed.

    Returns {"model": ..., "legacy": ..., "fileinfo": finfo} on success, or
    {"error": exception, "filename": path} if reading or hashing fails, so a
    single bad file does not abort a pooled scan. A file whose stem is "None"
    gets None hashes; the caller does not list such names.
    """
    filename = finfo[0]
    fstat = finfo[1]
    name = os.path.splitext(os.path.basename(filename))[0]

    # Prevent a hypothetical "None.pt" from being listed. Return placeholder
    # hashes rather than falling through to the return with unbound locals.
    if name == "None":
        return {"model": None, "legacy": None, "fileinfo": finfo}

    cached = cache("hashes").get(filename, None)
    if cached is not None and fstat.st_mtime == cached["mtime"]:
        return {"model": cached["model"], "legacy": cached["legacy"], "fileinfo": finfo}

    try:
        metadata = safetensors_hack.read_metadata(filename) if model_util.is_safetensors(filename) else None
        model_hash = get_model_hash(metadata, filename)
        legacy_hash = get_legacy_hash(metadata, filename)
    except Exception as ex:  # reported by the caller; keep the rest of the scan going
        return {"error": ex, "filename": filename}

    return {"model": model_hash, "legacy": legacy_hash, "fileinfo": finfo}
def _list_model_files(paths):
    """Collect (path, stat) tuples under `paths`, applying the safetensors/metadata display options."""
    fileinfos = []
    for path in paths:
        if os.path.isdir(path):
            fileinfos += traverse_all_files(path, [])

    show_only_safetensors = shared.opts.data.get("additional_networks_show_only_safetensors", False)
    show_only_missing_meta = shared.opts.data.get("additional_networks_show_only_models_with_metadata", "disabled")

    if show_only_safetensors:
        fileinfos = [x for x in fileinfos if is_safetensors(x[0])]

    if show_only_missing_meta == "has metadata":
        fileinfos = [x for x in fileinfos if has_user_metadata(x[0])]
    elif show_only_missing_meta == "missing metadata":
        fileinfos = [x for x in fileinfos if not has_user_metadata(x[0])]

    return fileinfos


def _hash_model_files(fileinfos):
    """Hash `fileinfos` on a thread pool; files that fail are printed and dropped."""
    print("[AddNet] Updating model hashes...")
    data = []
    thread_count = max(1, int(shared.opts.data.get("additional_networks_hash_thread_count", 1)))
    # The pool's context manager terminates it on exit, even if a worker
    # raises, so threads are not leaked the way a skipped close() would.
    with Pool(processes=thread_count) as pool, tqdm.tqdm(total=len(fileinfos)) as pbar:
        for res in pool.imap_unordered(hash_model_file, fileinfos):
            pbar.update()
            if "error" in res:
                print(f"Failed to read model file {res['filename']}: {res['error']}")
            else:
                data.append(res)
    return data


def _sort_hashed_models(data, sort_by):
    """Order hash results by the `sort_by` option; unknown options keep the input order."""
    def path_of(entry):
        return entry["fileinfo"][0]

    if sort_by == "name":
        return sorted(data, key=lambda x: os.path.basename(path_of(x)))
    if sort_by == "date":
        return sorted(data, key=lambda x: -x["fileinfo"][1].st_mtime)
    if sort_by == "path name":
        return sorted(data, key=path_of)
    if sort_by == "rating":
        return sorted(data, key=lambda x: get_model_rating(path_of(x)), reverse=True)
    if sort_by == "has user metadata":
        return sorted(data, key=lambda x: os.path.basename(path_of(x)) if has_user_metadata(path_of(x)) else "", reverse=True)
    return data


def get_all_models(paths, sort_by, filter_by):
    """Scan `paths` for model files and build their display names.

    Returns `(res, res_legacy)`:
    - `res` maps "name(first 12 hash chars)" to the file path.
    - `res_legacy` maps each legacy hash to that same display name.

    Files are filtered by the display options and by `filter_by`, a
    case-insensitive substring of the base name after stripping spaces. They
    are then ordered by `sort_by`; the order is reversed if
    "additional_networks_reverse_sort_order" is set. Commas in names become
    "_". As a side effect, the in-memory "hashes" cache entries of listed
    files are refreshed; dump_cache() writes them to disk.
    """
    data = _hash_model_files(_list_model_files(paths))

    cache_hashes = cache("hashes")

    res = OrderedDict()
    res_legacy = OrderedDict()

    filter_by = filter_by.strip(" ")
    if len(filter_by) != 0:
        data = [x for x in data if filter_by.lower() in os.path.basename(x["fileinfo"][0]).lower()]

    data = _sort_hashed_models(data, sort_by)

    reverse = shared.opts.data.get("additional_networks_reverse_sort_order", False)
    if reverse:
        data = reversed(data)

    for result in data:
        finfo = result["fileinfo"]
        filename = finfo[0]
        fstat = finfo[1]
        model_hash = result["model"]
        legacy_hash = result["legacy"]

        name = os.path.splitext(os.path.basename(filename))[0]

        # Commas in the model name will mess up infotext restoration since the
        # infotext is delimited by commas
        name = name.replace(",", "_")

        # Prevent a hypothetical "None.pt" from being listed.
        if name != "None":
            full_name = name + f"({model_hash[0:12]})"
            res[full_name] = filename
            res_legacy[legacy_hash] = full_name
            cache_hashes[filename] = {"model": model_hash, "legacy": legacy_hash, "mtime": fstat.st_mtime}

    return res, res_legacy
def find_closest_lora_model_name(search: str):
    """Resolve a user-supplied model reference to a key of `lora_models`.

    Tries, in order:
    1. An exact "name(hash)" key (case-sensitive).
    2. An existing file path that matches a registered path after normalization.
    3. A lower-cased file stem in `lora_model_names`.
    4. A trailing 8-character legacy hash "(abcd1234)".
    5. The shortest registered stem containing `search` as a substring
       (case-insensitive).

    Returns the matching key, or None for empty/"None" input or no match.
    """
    if not search or search == "None":
        return None

    # Match name and hash, case-sensitive
    # "MyModel-epoch00002(abcdef123456)"
    if search in lora_models:
        return search

    # Match model path, case-sensitive (from metadata editor)
    # "C:/path/to/mymodel-epoch00002.safetensors"
    if os.path.isfile(search):
        find = os.path.normpath(search)
        # Stored paths are built with os.path.join and may contain mixed or
        # redundant separators, so normalize both sides before comparing.
        value = next(
            (k for k, path in lora_models.items() if path is not None and os.path.normpath(path) == find),
            None,
        )
        if value is not None:
            return value

    search = search.lower()

    # Match full name, case-insensitive
    # "mymodel-epoch00002"
    if search in lora_model_names:
        return lora_model_names[search]

    # Match legacy hash (8 characters)
    # "MyModel(abcd1234)"
    result = re_legacy_hash.search(search)
    if result is not None:
        model_hash = result.group(1)
        if model_hash in legacy_model_names:
            return legacy_model_names[model_hash]

    # Use any model whose name contains the search term (substring match,
    # case-insensitive), preferring the shortest name; ties keep table order.
    # "mymodel"
    applicable = [name for name in lora_model_names if search in name.lower()]
    if not applicable:
        return None
    return lora_model_names[min(applicable, key=len)]
def update_models():
    """Rescan the model directories and rebuild the module-level lookup tables.

    Scans `lora_models_dir` plus every existing directory listed in the
    "additional_networks_extra_lora_path" option. It then refreshes:
    - `lora_models`: "name(hash)" -> path, plus "None" -> None.
    - `lora_model_names`: lower-cased file stem -> "name(hash)".
    - `legacy_model_names`: legacy hash -> "name(hash)".

    Finally the hash cache is written to disk.
    """
    paths = [lora_models_dir]
    extra_lora_paths = util.split_path_list(shared.opts.data.get("additional_networks_extra_lora_path", ""))
    for path in extra_lora_paths:
        path = path.lstrip()
        if os.path.isdir(path):
            paths.append(path)

    sort_by = shared.opts.data.get("additional_networks_sort_models_by", "name")
    filter_by = shared.opts.data.get("additional_networks_model_name_filter", "")
    res, res_legacy = get_all_models(paths, sort_by, filter_by)

    # All three tables are mutated in place rather than rebound, so any
    # reference to them held elsewhere sees the refreshed contents.
    lora_models.clear()
    lora_models["None"] = None
    lora_models.update(res)

    # Clear first: otherwise deleted or renamed models stay resolvable by name
    # and point at keys that are no longer in lora_models.
    lora_model_names.clear()
    for name_and_hash, filename in lora_models.items():
        if filename is None:
            continue
        name = os.path.splitext(os.path.basename(filename))[0].lower()
        lora_model_names[name] = name_and_hash

    legacy_model_names.clear()
    legacy_model_names.update(res_legacy)

    dump_cache()


# Populate the tables once when this module is imported.
update_models()
|