model.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import os
  2. import re
  3. import json
  4. from . import util
  5. from . import libdata
  6. from modules import shared
# Module name used as the prefix of the "source" tag passed to util.console.error
# (e.g. "model.load_model_info").
source_filename = "model"
  8. def get_db_models():
  9. rgx = re.compile(r"\[.*\]")
  10. output = [""]
  11. try:
  12. out_dir = libdata.dreambooth_models_path
  13. if os.path.exists(out_dir):
  14. for item in os.listdir(out_dir):
  15. check_path = os.path.join(out_dir, item)
  16. if os.path.isdir(check_path) and not rgx.search(item):
  17. json_path = os.path.join(check_path, libdata.dreambooth_setting_file_name)
  18. if not os.path.isfile(json_path):
  19. continue
  20. output.append(item)
  21. except Exception:
  22. pass
  23. return output
  24. def get_db_model_setting(model_name):
  25. try:
  26. model_path = os.path.join(libdata.dreambooth_models_path, model_name, libdata.dreambooth_setting_file_name)
  27. return load_model_info(model_path)
  28. except Exception as e1:
  29. return
  30. def get_custom_model_folder():
  31. """load model folder by user setting"""
  32. util.console.log("Get Custom Model Folder")
  33. if shared.cmd_opts.embeddings_dir and os.path.isdir(shared.cmd_opts.embeddings_dir):
  34. libdata.folders["ti"] = shared.cmd_opts.embeddings_dir
  35. if shared.cmd_opts.hypernetwork_dir and os.path.isdir(shared.cmd_opts.hypernetwork_dir):
  36. libdata.folders["hyper"] = shared.cmd_opts.hypernetwork_dir
  37. if shared.cmd_opts.ckpt_dir and os.path.isdir(shared.cmd_opts.ckpt_dir):
  38. libdata.folders["ckp"] = shared.cmd_opts.ckpt_dir
  39. if hasattr(shared.cmd_opts, "lora_dir"):
  40. if shared.cmd_opts.lora_dir and os.path.isdir(shared.cmd_opts.lora_dir):
  41. libdata.folders["lora"] = shared.cmd_opts.lora_dir
  42. if hasattr(shared.cmd_opts, "lyco_dir"):
  43. if shared.cmd_opts.lyco_dir and os.path.isdir(shared.cmd_opts.lyco_dir):
  44. libdata.folders["lyco"] = shared.cmd_opts.lyco_dir
  45. def write_model_info(path, model_info):
  46. """write model JSON data
  47. Parameters
  48. ----------
  49. path
  50. file path to write
  51. model_info
  52. data to write
  53. """
  54. util.console.log("Write model info to file: " + path)
  55. with open(os.path.realpath(path), 'w') as f:
  56. f.write(json.dumps(model_info, indent=4))
  57. def load_model_info(path):
  58. """load model JSON data
  59. Parameters
  60. ----------
  61. path
  62. file path to load
  63. Returns
  64. -------
  65. JSON
  66. loadded JSON data
  67. """
  68. model_info = None
  69. try:
  70. with open(os.path.realpath(path), 'r') as f:
  71. try:
  72. model_info = json.load(f)
  73. except Exception as e:
  74. util.console.error("Selected file is not json: " + path, f"{source_filename}.load_model_info")
  75. util.console.log(e)
  76. return
  77. except Exception as e1:
  78. util.console.error("file not found: " + path, f"{source_filename}.load_model_info")
  79. return
  80. return model_info