model_util.py

import os
import os.path
import re
import shutil
import json
import stat

import filelock
import tqdm
from collections import OrderedDict
from multiprocessing.pool import ThreadPool as Pool

from modules import shared, sd_models, hashes
from scripts import safetensors_hack, model_util, util
import modules.scripts as scripts

# MAX_MODEL_COUNT = shared.cmd_opts.addnet_max_model_count or 5
MAX_MODEL_COUNT = shared.cmd_opts.addnet_max_model_count if hasattr(shared.cmd_opts, "addnet_max_model_count") else 5
LORA_MODEL_EXTS = [".pt", ".ckpt", ".safetensors"]
re_legacy_hash = re.compile(r"\(([0-9a-f]{8})\)$")  # matches 8-character legacy hashes; the new hash has 12 characters

lora_models = {}  # "My_Lora(abcdef123456)" -> "C:/path/to/model.safetensors"
lora_model_names = {}  # "my_lora" -> "My_Lora(abcdef123456)"
legacy_model_names = {}  # "abcd1234" (legacy hash) -> "My_Lora(abcdef123456)"
lora_models_dir = os.path.join(scripts.basedir(), "models/lora")
os.makedirs(lora_models_dir, exist_ok=True)


def is_safetensors(filename):
    return os.path.splitext(filename)[1] == ".safetensors"


def read_model_metadata(model_path, module):
    if model_path.startswith("\"") and model_path.endswith("\""):  # trim '"' at start/end
        model_path = model_path[1:-1]
    if not os.path.exists(model_path):
        return None

    metadata = None
    if module == "LoRA":
        if os.path.splitext(model_path)[1] == ".safetensors":
            metadata = safetensors_hack.read_metadata(model_path)

    return metadata
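
# Illustrative usage (hypothetical path). For a LoRA in safetensors format this
# returns the header metadata dict; for other formats or a missing file, None:
#
#   meta = read_model_metadata("C:/models/lora/my_lora.safetensors", "LoRA")
#   if meta is not None:
#       print(meta.get("ss_network_dim"))  # a kohya-style training key, if present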


def write_model_metadata(model_path, module, updates):
    if model_path.startswith("\"") and model_path.endswith("\""):  # trim '"' at start/end
        model_path = model_path[1:-1]
    if not os.path.exists(model_path):
        return None

    from safetensors.torch import save_file

    back_up = shared.opts.data.get("additional_networks_back_up_model_when_saving", True)
    if back_up:
        backup_path = model_path + ".backup"
        if not os.path.exists(backup_path):
            print(f"[MetadataEditor] Backing up current model to {backup_path}")
            shutil.copyfile(model_path, backup_path)

    metadata = None
    tensors = {}
    if module == "LoRA":
        if os.path.splitext(model_path)[1] == ".safetensors":
            tensors, metadata = safetensors_hack.load_file(model_path, "cpu")

            for k, v in updates.items():
                metadata[k] = str(v)

            save_file(tensors, model_path, metadata)
            print(f"[MetadataEditor] Model saved: {model_path}")
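
# Illustrative usage (hypothetical path/values). safetensors metadata is
# str -> str, so every value is stringified before saving; a ".backup" copy is
# made first unless the backup option is disabled:
#
#   write_model_metadata("C:/models/lora/my_lora.safetensors", "LoRA",
#                        {"ssmd_rating": 5, "ssmd_description": "my notes"})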


def get_model_list(module, model, model_dir, sort_by):
    if model_dir == "":
        # Get the list of models in the same folder as this one
        model_path = lora_models.get(model, None)
        if model_path is None:
            return []
        model_dir = os.path.dirname(model_path)

    if not os.path.isdir(model_dir):
        return []

    found, _ = get_all_models([model_dir], sort_by, "")
    return found.keys()
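
# Illustrative usage (hypothetical model name). With model_dir == "", this
# lists the models that live in the same folder as an already-registered model:
#
#   names = get_model_list("LoRA", "My_Lora(abcdef123456)", "", "name")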


def traverse_all_files(curr_path, model_list):
    f_list = [(os.path.join(curr_path, entry.name), entry.stat()) for entry in os.scandir(curr_path)]
    for f_info in f_list:
        fname, fstat = f_info
        if os.path.splitext(fname)[1] in LORA_MODEL_EXTS:
            model_list.append(f_info)
        elif stat.S_ISDIR(fstat.st_mode):
            model_list = traverse_all_files(fname, model_list)
    return model_list
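
# Illustrative usage (hypothetical directory). Recursively collects
# (path, os.stat_result) tuples for every file with a LORA_MODEL_EXTS suffix:
#
#   fileinfos = traverse_all_files("C:/models/lora", [])
#   paths = [fname for fname, fstat in fileinfos]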


def get_model_hash(metadata, filename):
    if metadata is None:
        return hashes.calculate_sha256(filename)

    if "sshs_model_hash" in metadata:
        return metadata["sshs_model_hash"]

    return safetensors_hack.hash_file(filename)


def get_legacy_hash(metadata, filename):
    if metadata is None:
        return sd_models.model_hash(filename)

    if "sshs_legacy_hash" in metadata:
        return metadata["sshs_legacy_hash"]

    return safetensors_hack.legacy_hash_file(filename)
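
# Both helpers prefer a hash precomputed into the safetensors metadata
# ("sshs_model_hash" / "sshs_legacy_hash") and only hash the file contents as a
# fallback. Illustrative usage (hypothetical path variable):
#
#   meta = safetensors_hack.read_metadata(path)
#   model_hash = get_model_hash(meta, path)    # long hash; first 12 chars are displayed
#   legacy_hash = get_legacy_hash(meta, path)  # old-style 8-character hash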


cache_filename = os.path.join(scripts.basedir(), "hashes.json")
cache_data = None


def cache(subsection):
    global cache_data

    if cache_data is None:
        with filelock.FileLock(cache_filename + ".lock"):
            if not os.path.isfile(cache_filename):
                cache_data = {}
            else:
                with open(cache_filename, "r", encoding="utf8") as file:
                    cache_data = json.load(file)

    s = cache_data.get(subsection, {})
    cache_data[subsection] = s

    return s


def dump_cache():
    with filelock.FileLock(cache_filename + ".lock"):
        with open(cache_filename, "w", encoding="utf8") as file:
            json.dump(cache_data, file, indent=4)
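
# Illustrative usage. The cache is a dict of subsections persisted to
# hashes.json, with a file lock guarding reads and writes:
#
#   hashes_section = cache("hashes")           # load (or create) a subsection
#   hashes_section[filename] = {"mtime": 0.0}  # mutate it in place (hypothetical entry)
#   dump_cache()                               # write the whole cache back to disk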


def get_model_rating(filename):
    if not model_util.is_safetensors(filename):
        return 0

    metadata = safetensors_hack.read_metadata(filename)
    return int(metadata.get("ssmd_rating", "0"))


def has_user_metadata(filename):
    if not model_util.is_safetensors(filename):
        return False

    metadata = safetensors_hack.read_metadata(filename)
    return any(k.startswith("ssmd_") for k in metadata.keys())
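
# Illustrative checks (hypothetical path). User-editable fields written by the
# metadata editor are prefixed with "ssmd_":
#
#   get_model_rating("C:/models/lora/my_lora.safetensors")   # int rating, 0 if unrated
#   has_user_metadata("C:/models/lora/my_lora.safetensors")  # True if any ssmd_* key exists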


def hash_model_file(finfo):
    filename = finfo[0]
    fstat = finfo[1]  # renamed from `stat` to avoid shadowing the stat module
    name = os.path.splitext(os.path.basename(filename))[0]

    # Prevent a hypothetical "None.pt" from being listed.
    if name == "None":
        return None

    metadata = None

    cached = cache("hashes").get(filename, None)
    if cached is None or fstat.st_mtime != cached["mtime"]:
        if metadata is None and model_util.is_safetensors(filename):
            try:
                metadata = safetensors_hack.read_metadata(filename)
            except Exception as ex:
                return {"error": ex, "filename": filename}
        model_hash = get_model_hash(metadata, filename)
        legacy_hash = get_legacy_hash(metadata, filename)
    else:
        model_hash = cached["model"]
        legacy_hash = cached["legacy"]

    return {"model": model_hash, "legacy": legacy_hash, "fileinfo": finfo}
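
# Illustrative result shapes (hypothetical values):
#
#   {"model": "abcdef123456...", "legacy": "abcd1234", "fileinfo": finfo}  # success
#   {"error": ex, "filename": filename}                                    # unreadable file
#   None                                                                   # skipped "None" file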


def get_all_models(paths, sort_by, filter_by):
    fileinfos = []
    for path in paths:
        if os.path.isdir(path):
            fileinfos += traverse_all_files(path, [])

    show_only_safetensors = shared.opts.data.get("additional_networks_show_only_safetensors", False)
    show_only_missing_meta = shared.opts.data.get("additional_networks_show_only_models_with_metadata", "disabled")

    if show_only_safetensors:
        fileinfos = [x for x in fileinfos if is_safetensors(x[0])]

    if show_only_missing_meta == "has metadata":
        fileinfos = [x for x in fileinfos if has_user_metadata(x[0])]
    elif show_only_missing_meta == "missing metadata":
        fileinfos = [x for x in fileinfos if not has_user_metadata(x[0])]

    print("[AddNet] Updating model hashes...")
    data = []
    thread_count = max(1, int(shared.opts.data.get("additional_networks_hash_thread_count", 1)))
    p = Pool(processes=thread_count)
    with tqdm.tqdm(total=len(fileinfos)) as pbar:
        for res in p.imap_unordered(hash_model_file, fileinfos):
            pbar.update()
            if res is None:  # skipped file (a hypothetical "None" model)
                continue
            if "error" in res:
                print(f"Failed to read model file {res['filename']}: {res['error']}")
            else:
                data.append(res)
    p.close()

    cache_hashes = cache("hashes")

    res = OrderedDict()
    res_legacy = OrderedDict()
    filter_by = filter_by.strip(" ")
    if len(filter_by) != 0:
        data = [x for x in data if filter_by.lower() in os.path.basename(x["fileinfo"][0]).lower()]

    if sort_by == "name":
        data = sorted(data, key=lambda x: os.path.basename(x["fileinfo"][0]))
    elif sort_by == "date":
        data = sorted(data, key=lambda x: -x["fileinfo"][1].st_mtime)
    elif sort_by == "path name":
        data = sorted(data, key=lambda x: x["fileinfo"][0])
    elif sort_by == "rating":
        data = sorted(data, key=lambda x: get_model_rating(x["fileinfo"][0]), reverse=True)
    elif sort_by == "has user metadata":
        data = sorted(data, key=lambda x: os.path.basename(x["fileinfo"][0]) if has_user_metadata(x["fileinfo"][0]) else "", reverse=True)

    reverse = shared.opts.data.get("additional_networks_reverse_sort_order", False)
    if reverse:
        data = reversed(data)

    for result in data:
        finfo = result["fileinfo"]
        filename = finfo[0]
        fstat = finfo[1]
        model_hash = result["model"]
        legacy_hash = result["legacy"]

        name = os.path.splitext(os.path.basename(filename))[0]

        # Commas in the model name will mess up infotext restoration, since the
        # infotext is delimited by commas.
        name = name.replace(",", "_")

        # Prevent a hypothetical "None.pt" from being listed.
        if name != "None":
            full_name = name + f"({model_hash[0:12]})"
            res[full_name] = filename
            res_legacy[legacy_hash] = full_name
            cache_hashes[filename] = {"model": model_hash, "legacy": legacy_hash, "mtime": fstat.st_mtime}

    return res, res_legacy
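
# Illustrative usage (hypothetical directory). Returns an OrderedDict mapping
# "Name(hash12)" -> path, plus a legacy-hash -> "Name(hash12)" mapping:
#
#   models, legacy = get_all_models(["C:/models/lora"], "name", "")
#   # models: {"My_Lora(abcdef123456)": "C:/models/lora/My_Lora.safetensors", ...}
#   # legacy: {"abcd1234": "My_Lora(abcdef123456)", ...}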


def find_closest_lora_model_name(search: str):
    if not search or search == "None":
        return None

    # Match name and hash, case-sensitive
    # "MyModel-epoch00002(abcdef123456)"
    if search in lora_models:
        return search

    # Match model path, case-sensitive (from the metadata editor)
    # "C:/path/to/mymodel-epoch00002.safetensors"
    if os.path.isfile(search):
        find = os.path.normpath(search)
        value = next((k for k in lora_models.keys() if lora_models[k] == find), None)
        if value:
            return value

    search = search.lower()

    # Match full name, case-insensitive
    # "mymodel-epoch00002"
    if search in lora_model_names:
        return lora_model_names.get(search)

    # Match legacy hash (8 characters)
    # "MyModel(abcd1234)"
    result = re_legacy_hash.search(search)
    if result is not None:
        model_hash = result.group(1)
        if model_hash in legacy_model_names:
            return legacy_model_names[model_hash]

    # Fall back to any model whose name contains the search term,
    # case-insensitive, preferring the shortest matching name
    # "mymodel"
    applicable = [name for name in lora_model_names.keys() if search in name.lower()]
    if not applicable:
        return None
    applicable = sorted(applicable, key=lambda name: len(name))
    return lora_model_names[applicable[0]]
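
# Illustrative lookups (hypothetical names), in decreasing match priority:
#
#   find_closest_lora_model_name("My_Lora(abcdef123456)")               # exact name+hash
#   find_closest_lora_model_name("C:/models/lora/My_Lora.safetensors")  # exact file path
#   find_closest_lora_model_name("my_lora")                             # case-insensitive name
#   find_closest_lora_model_name("My_Lora(abcd1234)")                   # legacy 8-char hash
#   find_closest_lora_model_name("my_lo")                               # shortest containing name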


def update_models():
    global lora_models, lora_model_names, legacy_model_names
    paths = [lora_models_dir]
    extra_lora_paths = util.split_path_list(shared.opts.data.get("additional_networks_extra_lora_path", ""))
    for path in extra_lora_paths:
        path = path.lstrip()
        if os.path.isdir(path):
            paths.append(path)

    sort_by = shared.opts.data.get("additional_networks_sort_models_by", "name")
    filter_by = shared.opts.data.get("additional_networks_model_name_filter", "")
    res, res_legacy = get_all_models(paths, sort_by, filter_by)

    lora_models.clear()
    lora_models["None"] = None
    lora_models.update(res)

    lora_model_names.clear()  # drop stale entries from any previous scan
    for name_and_hash, filename in lora_models.items():
        if filename is None:
            continue
        name = os.path.splitext(os.path.basename(filename))[0].lower()
        lora_model_names[name] = name_and_hash

    legacy_model_names = res_legacy
    dump_cache()


update_models()
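
# The module scans once at import time; callers can re-scan after files change,
# e.g. (illustrative):
#
#   update_models()
#   name = find_closest_lora_model_name("my_lora")  # hypothetical model
#   path = lora_models.get(name)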