# extract_controlnet_diff.py 3.7 KB
import argparse
from typing import Dict, List, Optional, Tuple

import torch
from safetensors.torch import load_file, save_file
  4. if __name__ == "__main__":
  5. parser = argparse.ArgumentParser()
  6. parser.add_argument("--sd15", default=None, type=str, required=True, help="Path to the original sd15.")
  7. parser.add_argument("--control", default=None, type=str, required=True, help="Path to the sd15 with control.")
  8. parser.add_argument("--dst", default=None, type=str, required=True, help="Path to the output difference model.")
  9. parser.add_argument("--fp16", action="store_true", help="Save as fp16.")
  10. parser.add_argument("--bf16", action="store_true", help="Save as bf16.")
  11. args = parser.parse_args()
  12. assert args.sd15 is not None, "Must provide a original sd15 model path!"
  13. assert args.control is not None, "Must provide a sd15 with control model path!"
  14. assert args.dst is not None, "Must provide a output path!"
  15. # make differences: copy from https://github.com/lllyasviel/ControlNet/blob/main/tool_transfer_control.py
  16. def get_node_name(name, parent_name):
  17. if len(name) <= len(parent_name):
  18. return False, ''
  19. p = name[:len(parent_name)]
  20. if p != parent_name:
  21. return False, ''
  22. return True, name[len(parent_name):]
  23. # remove first/cond stage from sd to reduce memory usage
  24. def remove_first_and_cond(sd):
  25. keys = list(sd.keys())
  26. for key in keys:
  27. is_first_stage, _ = get_node_name(key, 'first_stage_model')
  28. is_cond_stage, _ = get_node_name(key, 'cond_stage_model')
  29. if is_first_stage or is_cond_stage:
  30. sd.pop(key, None)
  31. return sd
  32. print(f"loading: {args.sd15}")
  33. if args.sd15.endswith(".safetensors"):
  34. sd15_state_dict = load_file(args.sd15)
  35. else:
  36. sd15_state_dict = torch.load(args.sd15)
  37. sd15_state_dict = sd15_state_dict.pop("state_dict", sd15_state_dict)
  38. sd15_state_dict = remove_first_and_cond(sd15_state_dict)
  39. print(f"loading: {args.control}")
  40. if args.control.endswith(".safetensors"):
  41. control_state_dict = load_file(args.control)
  42. else:
  43. control_state_dict = torch.load(args.control)
  44. control_state_dict = remove_first_and_cond(control_state_dict)
  45. # make diff of original and control
  46. print(f"create difference")
  47. keys = list(control_state_dict.keys())
  48. final_state_dict = {"difference": torch.tensor(1.0)} # indicates difference
  49. for key in keys:
  50. p = control_state_dict.pop(key)
  51. is_control, node_name = get_node_name(key, 'control_')
  52. if not is_control:
  53. continue
  54. sd15_key_name = 'model.diffusion_' + node_name
  55. if sd15_key_name in sd15_state_dict: # part of U-Net
  56. # print("in sd15", key, sd15_key_name)
  57. p_new = p - sd15_state_dict.pop(sd15_key_name)
  58. if torch.max(torch.abs(p_new)) < 1e-6: # no difference?
  59. print("no diff", key, sd15_key_name)
  60. continue
  61. else:
  62. # print("not in sd15", key, sd15_key_name)
  63. p_new = p # hint or zero_conv
  64. final_state_dict[key] = p_new
  65. save_dtype = None
  66. if args.fp16:
  67. save_dtype = torch.float16
  68. elif args.bf16:
  69. save_dtype = torch.bfloat16
  70. if save_dtype is not None:
  71. for key in final_state_dict.keys():
  72. final_state_dict[key] = final_state_dict[key].to(save_dtype)
  73. print("saving difference.")
  74. if args.dst.endswith(".safetensors"):
  75. save_file(final_state_dict, args.dst)
  76. else:
  77. torch.save({"state_dict": final_state_dict}, args.dst)
  78. print("done!")