# extract_submodel.py (549 B)

  1. import torch
  2. import sys
  3. if __name__ == "__main__":
  4. inpath = sys.argv[1]
  5. outpath = sys.argv[2]
  6. submodel = "cond_stage_model"
  7. if len(sys.argv) > 3:
  8. submodel = sys.argv[3]
  9. print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))
  10. sd = torch.load(inpath, map_location="cpu")
  11. new_sd = {"state_dict": dict((k.split(".", 1)[-1],v)
  12. for k,v in sd["state_dict"].items()
  13. if k.startswith("cond_stage_model"))}
  14. torch.save(new_sd, outpath)