safetensors_hack.py

import io
import os
import mmap
import torch
import json
import hashlib
import safetensors
import safetensors.torch

from modules import sd_models

# PyTorch 1.13 and later have _UntypedStorage renamed to UntypedStorage
UntypedStorage = torch.storage.UntypedStorage if hasattr(torch.storage, 'UntypedStorage') else torch.storage._UntypedStorage

def read_metadata(filename):
    """Reads the JSON metadata from a .safetensors file."""
    # Open in binary mode: the file is not UTF-8 text, and only the file
    # descriptor is used for the memory map anyway.
    with open(filename, mode="rb") as file_obj:
        with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
            # A .safetensors file begins with an 8-byte little-endian integer
            # giving the length of the JSON header that follows it.
            header = m.read(8)
            n = int.from_bytes(header, "little")
            metadata_bytes = m.read(n)
            metadata = json.loads(metadata_bytes)
            return metadata.get("__metadata__", {})
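
# A minimal usage sketch (not part of the original module): dump the embedded
# user metadata of a model. The path is a hypothetical placeholder.
def _demo_read_metadata():
    metadata = read_metadata("model.safetensors")
    for key, value in metadata.items():
        print(f"{key} = {value}")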

def load_file(filename, device):
    """Loads a .safetensors file without memory mapping that locks the model file.
    Works around safetensors issue: https://github.com/huggingface/safetensors/issues/164"""
    with open(filename, mode="rb") as file_obj:
        with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
            header = m.read(8)
            n = int.from_bytes(header, "little")
            metadata_bytes = m.read(n)
            metadata = json.loads(metadata_bytes)

    # Back the tensors with a storage mapped from the file; create_tensor clones
    # each tensor out of it, so no handle to the file outlives this function.
    # Note: `device` is unused here; tensors are materialized on the CPU.
    size = os.stat(filename).st_size
    storage = UntypedStorage.from_file(filename, False, size)

    offset = n + 8
    md = metadata.get("__metadata__", {})
    return {name: create_tensor(storage, info, offset) for name, info in metadata.items() if name != "__metadata__"}, md
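
# A minimal sketch of calling load_file (hypothetical path). Because every
# tensor is cloned out of the temporary file-backed storage, the file is not
# kept locked afterwards, even on Windows.
def _demo_load_file():
    tensors, metadata = load_file("model.safetensors", "cpu")
    print(f"{len(tensors)} tensors; metadata keys: {sorted(metadata)}")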

def hash_file(filename):
    """Hashes a .safetensors file using the new hashing method.
    Only hashes the weights of the model."""
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    # Read the header length to find where the tensor data begins.
    with open(filename, mode="rb") as file_obj:
        with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
            header = m.read(8)
            n = int.from_bytes(header, "little")

    # Hash everything after the JSON header, i.e. only the tensor data.
    with open(filename, mode="rb") as file_obj:
        offset = n + 8
        file_obj.seek(offset)
        for chunk in iter(lambda: file_obj.read(blksize), b""):
            hash_sha256.update(chunk)

    return hash_sha256.hexdigest()
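
# Sketch of the property the new hash provides: two files that differ only in
# their JSON headers hash identically, since only the tensor region is read.
# Paths are hypothetical; the same tensors are re-saved with different metadata.
def _demo_hash_stability():
    tensors, _ = load_file("model.safetensors", "cpu")
    for note in ("a", "b"):
        with open(f"model_{note}.safetensors", "wb") as f:
            f.write(safetensors.torch.save(tensors, {"note": note}))
    assert hash_file("model_a.safetensors") == hash_file("model_b.safetensors")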

def legacy_hash_file(filename):
    """Hashes a model file using the legacy `sd_models.model_hash()` method."""
    hash_sha256 = hashlib.sha256()
    metadata = read_metadata(filename)

    # For compatibility with legacy models: this replicates the behavior of
    # sd_models.model_hash as if there were no user-specified metadata in the
    # .safetensors file. That leaves the training parameters, which are
    # immutable. It is important that the hash not include the embedded user
    # metadata, as that would mean the hash could change every time the user
    # updates the name/description/etc. The new hashing method fixes this
    # problem by hashing only the region of the file containing the tensors.
    if any(not k.startswith("ss_") for k in metadata):
        # Strip the user metadata, re-serialize the file as if it were freshly
        # created from sd-scripts, and hash that with model_hash's behavior.
        tensors, metadata = load_file(filename, "cpu")
        metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
        model_bytes = safetensors.torch.save(tensors, metadata)

        hash_sha256.update(model_bytes[0x100000:0x110000])
        return hash_sha256.hexdigest()[0:8]
    else:
        # This should work fine with model_hash, since the user metadata system
        # had not been implemented yet when the legacy hashing method was in use.
        return sd_models.model_hash(filename)
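
# For reference, a standalone sketch of the legacy scheme the slice above
# mirrors: SHA-256 over the 64 KiB of the file starting at offset 0x100000
# (1 MiB), truncated to 8 hex characters. Illustrative only; the real
# implementation lives in sd_models.model_hash.
def _legacy_hash_sketch(filename):
    m = hashlib.sha256()
    with open(filename, "rb") as f:
        f.seek(0x100000)
        m.update(f.read(0x10000))
    return m.hexdigest()[0:8]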

# safetensors dtype strings -> torch dtypes. The unsigned 16/32/64-bit entries
# are commented out because torch has no equivalents for them (only uint8).
DTYPES = {
    "F64": torch.float64,
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I64": torch.int64,
    # "U64": torch.uint64,
    "I32": torch.int32,
    # "U32": torch.uint32,
    "I16": torch.int16,
    # "U16": torch.uint16,
    "I8": torch.int8,
    "U8": torch.uint8,
    "BOOL": torch.bool,
}

def create_tensor(storage, info, offset):
    """Creates a tensor without holding on to an open handle to the parent model
    file."""
    dtype = DTYPES[info["dtype"]]
    shape = info["shape"]
    start, stop = info["data_offsets"]
    # Slice this tensor's byte range out of the file-backed storage, view it as
    # the target dtype, reshape, and clone so the result no longer references
    # the memory-mapped file.
    return torch.asarray(storage[start + offset : stop + offset], dtype=torch.uint8).view(dtype=dtype).reshape(shape).clone().detach()
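
# Sketch of the header entry create_tensor consumes. Per the safetensors
# format, each tensor name maps to {"dtype": ..., "shape": ...,
# "data_offsets": [begin, end]}, with offsets relative to the start of the
# tensor-data region (hence the `offset` shift above). All values below are
# hypothetical.
def _demo_create_tensor():
    filename = "model.safetensors"
    storage = UntypedStorage.from_file(filename, False, os.stat(filename).st_size)
    info = {"dtype": "F16", "shape": [320, 768], "data_offsets": [0, 491520]}
    tensor = create_tensor(storage, info, offset=8 + 0x2000)  # 0x2000: assumed header size
    print(tensor.shape, tensor.dtype)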