safe.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. # this code is adapted from the script contributed by anon from /h/
  2. import io
  3. import pickle
  4. import collections
  5. import sys
  6. import traceback
  7. import torch
  8. import numpy
  9. import _codecs
  10. import zipfile
  11. import re
  12. # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
  13. TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
  14. def encode(*args):
  15. out = _codecs.encode(*args)
  16. return out
  17. class RestrictedUnpickler(pickle.Unpickler):
  18. extra_handler = None
  19. def persistent_load(self, saved_id):
  20. assert saved_id[0] == 'storage'
  21. return TypedStorage()
  22. def find_class(self, module, name):
  23. if self.extra_handler is not None:
  24. res = self.extra_handler(module, name)
  25. if res is not None:
  26. return res
  27. if module == 'collections' and name == 'OrderedDict':
  28. return getattr(collections, name)
  29. if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
  30. return getattr(torch._utils, name)
  31. if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
  32. return getattr(torch, name)
  33. if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
  34. return getattr(torch.nn.modules.container, name)
  35. if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
  36. return getattr(numpy.core.multiarray, name)
  37. if module == 'numpy' and name in ['dtype', 'ndarray']:
  38. return getattr(numpy, name)
  39. if module == '_codecs' and name == 'encode':
  40. return encode
  41. if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
  42. import pytorch_lightning.callbacks
  43. return pytorch_lightning.callbacks.model_checkpoint
  44. if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
  45. import pytorch_lightning.callbacks.model_checkpoint
  46. return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
  47. if module == "__builtin__" and name == 'set':
  48. return set
  49. # Forbid everything else.
  50. raise Exception(f"global '{module}/{name}' is forbidden")
  51. # Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
  52. allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
  53. data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
  54. def check_zip_filenames(filename, names):
  55. for name in names:
  56. if allowed_zip_names_re.match(name):
  57. continue
  58. raise Exception(f"bad file inside {filename}: {name}")
  59. def check_pt(filename, extra_handler):
  60. try:
  61. # new pytorch format is a zip file
  62. with zipfile.ZipFile(filename) as z:
  63. check_zip_filenames(filename, z.namelist())
  64. # find filename of data.pkl in zip file: '<directory name>/data.pkl'
  65. data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
  66. if len(data_pkl_filenames) == 0:
  67. raise Exception(f"data.pkl not found in {filename}")
  68. if len(data_pkl_filenames) > 1:
  69. raise Exception(f"Multiple data.pkl found in {filename}")
  70. with z.open(data_pkl_filenames[0]) as file:
  71. unpickler = RestrictedUnpickler(file)
  72. unpickler.extra_handler = extra_handler
  73. unpickler.load()
  74. except zipfile.BadZipfile:
  75. # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
  76. with open(filename, "rb") as file:
  77. unpickler = RestrictedUnpickler(file)
  78. unpickler.extra_handler = extra_handler
  79. for i in range(5):
  80. unpickler.load()
  81. def load(filename, *args, **kwargs):
  82. return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
  83. def load_with_extra(filename, extra_handler=None, *args, **kwargs):
  84. """
  85. this function is intended to be used by extensions that want to load models with
  86. some extra classes in them that the usual unpickler would find suspicious.
  87. Use the extra_handler argument to specify a function that takes module and field name as text,
  88. and returns that field's value:
  89. ```python
  90. def extra(module, name):
  91. if module == 'collections' and name == 'OrderedDict':
  92. return collections.OrderedDict
  93. return None
  94. safe.load_with_extra('model.pt', extra_handler=extra)
  95. ```
  96. The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
  97. definitely unsafe.
  98. """
  99. from modules import shared
  100. try:
  101. if not shared.cmd_opts.disable_safe_unpickle:
  102. check_pt(filename, extra_handler)
  103. except pickle.UnpicklingError:
  104. print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
  105. print(traceback.format_exc(), file=sys.stderr)
  106. print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
  107. print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
  108. return None
  109. except Exception:
  110. print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
  111. print(traceback.format_exc(), file=sys.stderr)
  112. print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
  113. print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
  114. return None
  115. return unsafe_torch_load(filename, *args, **kwargs)
  116. class Extra:
  117. """
  118. A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
  119. (because it's not your code making the torch.load call). The intended use is like this:
  120. ```
  121. import torch
  122. from modules import safe
  123. def handler(module, name):
  124. if module == 'torch' and name in ['float64', 'float16']:
  125. return getattr(torch, name)
  126. return None
  127. with safe.Extra(handler):
  128. x = torch.load('model.pt')
  129. ```
  130. """
  131. def __init__(self, handler):
  132. self.handler = handler
  133. def __enter__(self):
  134. global global_extra_handler
  135. assert global_extra_handler is None, 'already inside an Extra() block'
  136. global_extra_handler = self.handler
  137. def __exit__(self, exc_type, exc_val, exc_tb):
  138. global global_extra_handler
  139. global_extra_handler = None
  140. unsafe_torch_load = torch.load
  141. torch.load = load
  142. global_extra_handler = None