_weights_only_unpickler.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
  1. # mypy: allow-untyped-defs
  2. # Unpickler restricted to loading only state dicts
  3. # Restrict constructing types to a list defined in _get_allowed_globals()
  4. # Restrict BUILD operation to `Tensor`, `Parameter`, `OrderedDict` and allowlisted types only (default globals or those added via `add_safe_globals`)
  5. # Restrict APPEND/APPENDS to `list`
  6. # In the `GLOBAL` operation, do not look up classes by name; instead rely on the dictionary
  7. # built by the `_get_allowed_globals()` function, which contains:
  8. # - torch types (Storage, dtypes, Tensor, `torch.Size`),
  9. # - `torch._utils._rebuild` functions.
  10. # - `torch.nn.Parameter`
  11. # - `collections.Counter`
  12. # - `collections.OrderedDict`
  13. # Additionally, users can use an allowlist for adding classes they have deemed as safe using
  14. # `_add_safe_globals()` (`torch.serialization.add_safe_globals`)
  15. # `_clear_safe_globals()` (`torch.serialization.clear_safe_globals`)
  16. # `_get_safe_globals()` (`torch.serialization.get_safe_globals`)
  17. # Based on https://github.com/python/cpython/blob/main/Lib/pickle.py
  18. # Expected to be useful for loading PyTorch model weights
  19. # For example:
  20. # data = urllib.request.urlopen('https://download.pytorch.org/models/resnet50-0676ba61.pth').read()
  21. # buf = io.BytesIO(data)
  22. # weights = torch.load(buf, weights_only = True)
  23. import functools as _functools
  24. import warnings
  25. from _codecs import encode
  26. from collections import Counter, OrderedDict
  27. from pickle import (
  28. APPEND,
  29. APPENDS,
  30. BINFLOAT,
  31. BINGET,
  32. BININT,
  33. BININT1,
  34. BININT2,
  35. BINPERSID,
  36. BINPUT,
  37. BINUNICODE,
  38. BUILD,
  39. bytes_types,
  40. decode_long,
  41. EMPTY_DICT,
  42. EMPTY_LIST,
  43. EMPTY_SET,
  44. EMPTY_TUPLE,
  45. GLOBAL,
  46. LONG1,
  47. LONG_BINGET,
  48. LONG_BINPUT,
  49. MARK,
  50. NEWFALSE,
  51. NEWOBJ,
  52. NEWTRUE,
  53. NONE,
  54. PROTO,
  55. REDUCE,
  56. SETITEM,
  57. SETITEMS,
  58. SHORT_BINSTRING,
  59. STOP,
  60. TUPLE,
  61. TUPLE1,
  62. TUPLE2,
  63. TUPLE3,
  64. UnpicklingError,
  65. )
  66. from struct import unpack
  67. from sys import maxsize
  68. from typing import Any, Callable, Union
  69. import torch
  70. from torch._utils import _sparse_tensors_to_validate, IMPORT_MAPPING, NAME_MAPPING
  71. # modules in this list are never allowed, even if the user attempts to allowlist
  72. # functions/classes from them
  73. _blocklisted_modules = [
  74. "sys",
  75. "os",
  76. "posix",
  77. "nt",
  78. ]
  79. _marked_safe_globals_set: set[Union[Callable, tuple[Callable, str]]] = set()
  80. def _add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]):
  81. global _marked_safe_globals_set
  82. _marked_safe_globals_set = _marked_safe_globals_set.union(set(safe_globals))
  83. def _get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]:
  84. global _marked_safe_globals_set
  85. return list(_marked_safe_globals_set)
  86. def _clear_safe_globals():
  87. global _marked_safe_globals_set
  88. _marked_safe_globals_set = set()
  89. def _remove_safe_globals(
  90. globals_to_remove: list[Union[Callable, tuple[Callable, str]]],
  91. ):
  92. global _marked_safe_globals_set
  93. _marked_safe_globals_set = _marked_safe_globals_set - set(globals_to_remove)
  94. class _safe_globals:
  95. def __init__(self, safe_globals: list[Union[Callable, tuple[Callable, str]]]):
  96. self.safe_globals = safe_globals
  97. def __enter__(self):
  98. _add_safe_globals(self.safe_globals)
  99. def __exit__(self, type, value, tb):
  100. _remove_safe_globals(self.safe_globals)
  101. # Separate from _get_allowed_globals because of the lru_cache on _get_allowed_globals
  102. # For example if user had a script like
  103. # torch.load(file_a)
  104. # torch.serialization._add_safe_globals([torch.foo])
  105. # torch.load(file_b)
  106. # the dynamic additions to safe_globals would not be picked up by
  107. # _get_allowed_globals due to the lru_cache
  108. def _get_user_allowed_globals():
  109. rc: dict[str, Any] = {}
  110. for f in _marked_safe_globals_set:
  111. if isinstance(f, tuple):
  112. if len(f) != 2:
  113. raise ValueError(
  114. f"Expected tuple of length 2 (global, str of callable full path), but got tuple of length: {len(f)}"
  115. )
  116. if type(f[1]) is not str:
  117. raise TypeError(
  118. f"Expected second item in tuple to be str of callable full path, but got: {type(f[1])}"
  119. )
  120. f, name = f
  121. rc[name] = f
  122. else:
  123. module, name = f.__module__, f.__qualname__
  124. rc[f"{module}.{name}"] = f
  125. return rc
  126. def _tensor_rebuild_functions():
  127. return {
  128. torch._utils._rebuild_parameter,
  129. torch._utils._rebuild_parameter_with_state,
  130. torch._utils._rebuild_qtensor,
  131. torch._utils._rebuild_tensor,
  132. torch._utils._rebuild_tensor_v2,
  133. torch._utils._rebuild_tensor_v3,
  134. torch._utils._rebuild_sparse_tensor,
  135. torch._utils._rebuild_meta_tensor_no_storage,
  136. torch._utils._rebuild_nested_tensor,
  137. torch._utils._rebuild_wrapper_subclass,
  138. # Allowlisting this, but not allowlisting the numpy functions by default
  139. # Reasoning is that we don't have control over the numpy functions, but
  140. # this utility is provided by pytorch
  141. torch._utils._rebuild_device_tensor_from_numpy,
  142. # In 2.6, we should no longer have a dependency on numpy and the above
  143. # _rebuild_device_tensor_from_numpy function.
  144. torch._utils._rebuild_device_tensor_from_cpu_tensor,
  145. }
  146. # Unpickling machinery
  147. @_functools.lru_cache(maxsize=1)
  148. def _get_allowed_globals():
  149. rc: dict[str, Any] = {
  150. "collections.OrderedDict": OrderedDict,
  151. "collections.Counter": Counter,
  152. "torch.nn.parameter.Parameter": torch.nn.Parameter,
  153. "torch.serialization._get_layout": torch.serialization._get_layout,
  154. "torch.Size": torch.Size,
  155. "torch.Tensor": torch.Tensor,
  156. "torch.device": torch.device,
  157. "_codecs.encode": encode, # for bytes
  158. "builtins.bytearray": bytearray, # for bytearray
  159. "builtins.set": set, # for set
  160. "builtins.complex": complex, # for complex
  161. }
  162. # dtype
  163. for t in torch.storage._dtype_to_storage_type_map().keys():
  164. rc[str(t)] = t
  165. for t in torch.storage._new_dtypes():
  166. rc[str(t)] = t
  167. for t in [getattr(torch, f"uint{x}") for x in range(1, 8)]:
  168. rc[str(t)] = t
  169. for t in [getattr(torch, f"int{x}") for x in range(1, 8)]:
  170. rc[str(t)] = t
  171. # Tensor classes
  172. for tt in torch._tensor_classes:
  173. rc[f"{tt.__module__}.{tt.__name__}"] = tt
  174. # Storage classes
  175. for ts in torch._storage_classes:
  176. if ts not in (torch.storage.TypedStorage, torch.storage.UntypedStorage):
  177. # Wrap legacy storage types in a dummy class
  178. rc[f"{ts.__module__}.{ts.__name__}"] = torch.serialization.StorageType(
  179. ts.__name__
  180. )
  181. else:
  182. rc[f"{ts.__module__}.{ts.__name__}"] = ts
  183. # Quantization specific
  184. for qt in [
  185. torch.per_tensor_affine,
  186. torch.per_tensor_symmetric,
  187. torch.per_channel_affine,
  188. torch.per_channel_symmetric,
  189. torch.per_channel_affine_float_qparams,
  190. ]:
  191. rc[str(qt)] = qt
  192. # Rebuild functions
  193. for f in _tensor_rebuild_functions():
  194. rc[f"torch._utils.{f.__name__}"] = f
  195. # Handles Tensor Subclasses, Tensor's with attributes.
  196. # NOTE: It calls into above rebuild functions for regular Tensor types.
  197. rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2
  198. return rc
  199. def _read_global_instruction(readline: Callable) -> tuple[str, str]:
  200. module = readline()[:-1].decode("utf-8")
  201. name = readline()[:-1].decode("utf-8")
  202. # Patch since torch.save default protocol is 2
  203. # users will be running this code in python > 3
  204. if (module, name) in NAME_MAPPING:
  205. module, name = NAME_MAPPING[(module, name)]
  206. elif module in IMPORT_MAPPING:
  207. module = IMPORT_MAPPING[module]
  208. return module, name
  209. def get_globals_in_pkl(file) -> set[str]:
  210. globals_in_checkpoint = set()
  211. read = file.read
  212. readline = file.readline
  213. op_to_bytes_to_read = {
  214. NEWOBJ[0]: 0,
  215. REDUCE[0]: 0,
  216. BUILD[0]: 0,
  217. APPEND[0]: 0,
  218. APPENDS[0]: 0,
  219. SETITEM[0]: 0,
  220. SETITEMS[0]: 0,
  221. MARK[0]: 0,
  222. TUPLE[0]: 0,
  223. TUPLE1[0]: 0,
  224. TUPLE2[0]: 0,
  225. TUPLE3[0]: 0,
  226. NONE[0]: 0,
  227. NEWFALSE[0]: 0,
  228. NEWTRUE[0]: 0,
  229. EMPTY_TUPLE[0]: 0,
  230. EMPTY_LIST[0]: 0,
  231. EMPTY_DICT[0]: 0,
  232. EMPTY_SET[0]: 0,
  233. BINPERSID[0]: 0,
  234. BININT[0]: 4,
  235. BININT1[0]: 1,
  236. BININT2[0]: 2,
  237. BINFLOAT[0]: 8,
  238. BINGET[0]: 1,
  239. LONG_BINGET[0]: 4,
  240. BINPUT[0]: 1,
  241. LONG_BINPUT[0]: 4,
  242. }
  243. while True:
  244. key = read(1)
  245. if not key:
  246. raise EOFError
  247. assert isinstance(key, bytes_types)
  248. if key[0] == GLOBAL[0]:
  249. module, name = _read_global_instruction(readline)
  250. globals_in_checkpoint.add(f"{module}.{name}")
  251. elif key[0] in op_to_bytes_to_read:
  252. bytes_to_read = op_to_bytes_to_read[key[0]]
  253. if bytes_to_read:
  254. read(bytes_to_read)
  255. # ops where bytes to read depends on the data
  256. elif key[0] == BINUNICODE[0]:
  257. strlen = unpack("<I", read(4))[0]
  258. if strlen > maxsize:
  259. raise UnpicklingError("String is too long")
  260. read(strlen)
  261. elif key[0] in {SHORT_BINSTRING[0], LONG1[0]}:
  262. strlen = read(1)[0]
  263. read(strlen)
  264. # first and last op
  265. elif key[0] == PROTO[0]:
  266. read(1)[0]
  267. elif key[0] == STOP[0]:
  268. return globals_in_checkpoint
  269. else:
  270. raise UnpicklingError(f"Unsupported operand {key[0]}")
  271. class Unpickler:
  272. def __init__(self, file, *, encoding: str = "bytes"):
  273. self.encoding = encoding
  274. self.readline = file.readline
  275. self.read = file.read
  276. self.memo: dict[int, Any] = {}
  277. self.proto: int = -1
  278. def load(self):
  279. """Read a pickled object representation from the open file.
  280. Return the reconstituted object hierarchy specified in the file.
  281. """
  282. self.metastack = []
  283. self.stack: list[Any] = []
  284. self.append = self.stack.append
  285. read = self.read
  286. while True:
  287. key = read(1)
  288. if not key:
  289. raise EOFError
  290. assert isinstance(key, bytes_types)
  291. # Risky operators
  292. if key[0] == GLOBAL[0]:
  293. module, name = _read_global_instruction(self.readline)
  294. full_path = f"{module}.{name}"
  295. if module in _blocklisted_modules:
  296. raise UnpicklingError(
  297. f"Trying to load unsupported GLOBAL {full_path} whose module {module} is blocked."
  298. )
  299. if full_path in _get_allowed_globals():
  300. self.append(_get_allowed_globals()[full_path])
  301. elif full_path in _get_user_allowed_globals():
  302. self.append(_get_user_allowed_globals()[full_path])
  303. elif full_path in (
  304. [
  305. "torch.nested._internal.nested_tensor.NestedTensor",
  306. "torch.nested._internal.nested_tensor._rebuild_njt",
  307. "torch._dynamo.decorators._DimRange",
  308. ]
  309. ):
  310. raise UnpicklingError(
  311. "``torch.nested`` and ``torch._dynamo`` must be imported to load nested jagged tensors (NJTs)"
  312. )
  313. elif full_path in (
  314. [
  315. "torch.distributed.device_mesh.DeviceMesh",
  316. "torch.distributed.tensor._dtensor_spec.DTensorSpec",
  317. "torch.distributed.tensor._dtensor_spec.TensorMeta",
  318. "torch.distributed.tensor.DTensor",
  319. "torch.distributed.tensor.placement_types.Partial",
  320. "torch.distributed.tensor.placement_types.Replicate",
  321. "torch.distributed.tensor.placement_types.Shard",
  322. ]
  323. ):
  324. raise UnpicklingError(
  325. "``torch.distributed.tensor`` must be imported to load DTensors"
  326. )
  327. else:
  328. builtins_name = "builtins"
  329. if (
  330. builtins_name in full_path
  331. and builtins_name == full_path[: len(builtins_name)]
  332. ):
  333. full_path = full_path[len(builtins_name) :]
  334. full_path = (
  335. full_path[1:]
  336. if len(full_path) > 0 and full_path[0] == "."
  337. else builtins_name + full_path
  338. )
  339. raise UnpicklingError(
  340. f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
  341. f"Please use `torch.serialization.add_safe_globals([{full_path}])` or the "
  342. f"`torch.serialization.safe_globals([{full_path}])` context manager to allowlist this global "
  343. "if you trust this class/function."
  344. )
  345. elif key[0] == NEWOBJ[0]:
  346. args = self.stack.pop()
  347. cls = self.stack.pop()
  348. if cls is torch.nn.Parameter:
  349. self.append(torch.nn.Parameter(*args))
  350. elif (
  351. cls in _get_user_allowed_globals().values()
  352. or cls in _get_allowed_globals().values()
  353. ):
  354. result = cls.__new__(cls, *args)
  355. if cls in torch._tensor_classes and "sparse" in cls.__module__:
  356. _sparse_tensors_to_validate.append(result)
  357. self.append(result)
  358. else:
  359. raise UnpicklingError(
  360. "Can only create new object for nn.Parameter or classes allowlisted "
  361. f"via `add_safe_globals` but got {cls}"
  362. )
  363. elif key[0] == REDUCE[0]:
  364. args = self.stack.pop()
  365. func = self.stack[-1]
  366. if (
  367. func not in _get_allowed_globals().values()
  368. and func not in _get_user_allowed_globals().values()
  369. ):
  370. error_msg = (
  371. f"Trying to call reduce for unrecognized function {func}"
  372. )
  373. if hasattr(func, "__self__"):
  374. error_msg += f" which belongs to {func.__self__}"
  375. raise UnpicklingError(error_msg)
  376. result = func(*args)
  377. if func in torch._tensor_classes and "sparse" in func.__module__:
  378. _sparse_tensors_to_validate.append(result)
  379. self.stack[-1] = result
  380. elif key[0] == BUILD[0]:
  381. state = self.stack.pop()
  382. inst = self.stack[-1]
  383. if type(inst) is torch.Tensor:
  384. # Legacy unpickling
  385. inst.set_(*state)
  386. elif type(inst) is torch.nn.Parameter:
  387. inst.__setstate__(state)
  388. elif type(inst) is OrderedDict:
  389. inst.__dict__.update(state)
  390. elif (
  391. type(inst) in _get_user_allowed_globals().values()
  392. or type(inst) in _get_allowed_globals().values()
  393. ):
  394. if hasattr(inst, "__setstate__"):
  395. inst.__setstate__(state)
  396. else:
  397. # mimics load_build in pickle
  398. # https://github.com/python/cpython/blob/f0c6fccd08904787a39269367f09f263d496114c/Lib/pickle.py#L1854-L1867
  399. slotstate = None
  400. if isinstance(state, tuple) and len(state) == 2:
  401. state, slotstate = state
  402. if state:
  403. inst.__dict__.update(state)
  404. if slotstate:
  405. for k, v in slotstate.items():
  406. setattr(inst, k, v)
  407. else:
  408. raise UnpicklingError(
  409. "Can only build Tensor, Parameter, OrderedDict or types allowlisted "
  410. f"via `add_safe_globals`, but got {type(inst)}"
  411. )
  412. # Stack manipulation
  413. elif key[0] == APPEND[0]:
  414. item = self.stack.pop()
  415. list_obj = self.stack[-1]
  416. if type(list_obj) is not list:
  417. raise UnpicklingError(
  418. f"Can only append to lists, but got {type(list_obj)}"
  419. )
  420. list_obj.append(item)
  421. elif key[0] == APPENDS[0]:
  422. items = self.pop_mark()
  423. list_obj = self.stack[-1]
  424. if type(list_obj) is not list:
  425. raise UnpicklingError(
  426. f"Can only extend lists, but got {type(list_obj)}"
  427. )
  428. list_obj.extend(items)
  429. elif key[0] == SETITEM[0]:
  430. (v, k) = (self.stack.pop(), self.stack.pop())
  431. self.stack[-1][k] = v
  432. elif key[0] == SETITEMS[0]:
  433. items = self.pop_mark()
  434. for i in range(0, len(items), 2):
  435. self.stack[-1][items[i]] = items[i + 1]
  436. elif key[0] == MARK[0]:
  437. self.metastack.append(self.stack)
  438. self.stack = []
  439. self.append = self.stack.append
  440. elif key[0] == TUPLE[0]:
  441. items = self.pop_mark()
  442. self.append(tuple(items))
  443. elif key[0] == TUPLE1[0]:
  444. self.stack[-1] = (self.stack[-1],)
  445. elif key[0] == TUPLE2[0]:
  446. self.stack[-2:] = [(self.stack[-2], self.stack[-1])]
  447. elif key[0] == TUPLE3[0]:
  448. self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])]
  449. # Basic types construction
  450. elif key[0] == NONE[0]:
  451. self.append(None)
  452. elif key[0] == NEWFALSE[0]:
  453. self.append(False)
  454. elif key[0] == NEWTRUE[0]:
  455. self.append(True)
  456. elif key[0] == EMPTY_TUPLE[0]:
  457. self.append(())
  458. elif key[0] == EMPTY_LIST[0]:
  459. self.append([])
  460. elif key[0] == EMPTY_DICT[0]:
  461. self.append({})
  462. elif key[0] == EMPTY_SET[0]:
  463. self.append(set())
  464. elif key[0] == BININT[0]:
  465. self.append(unpack("<i", read(4))[0])
  466. elif key[0] == BININT1[0]:
  467. self.append(self.read(1)[0])
  468. elif key[0] == BININT2[0]:
  469. self.append(unpack("<H", read(2))[0])
  470. elif key[0] == BINFLOAT[0]:
  471. self.append(unpack(">d", self.read(8))[0])
  472. elif key[0] == BINUNICODE[0]:
  473. strlen = unpack("<I", read(4))[0]
  474. if strlen > maxsize:
  475. raise UnpicklingError("String is too long")
  476. strval = str(read(strlen), "utf-8", "surrogatepass")
  477. self.append(strval)
  478. elif key[0] == SHORT_BINSTRING[0]:
  479. strlen = read(1)[0]
  480. strdata = read(strlen)
  481. if self.encoding != "bytes":
  482. strdata = strdata.decode(self.encoding, "strict")
  483. self.append(strdata)
  484. elif key[0] == BINPERSID[0]:
  485. pid = self.stack.pop()
  486. # Only allow persistent load of storage
  487. if type(pid) is not tuple and type(pid) is not int:
  488. raise UnpicklingError(
  489. f"persistent_load id must be tuple or int, but got {type(pid)}"
  490. )
  491. if (
  492. type(pid) is tuple
  493. and len(pid) > 0
  494. and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
  495. ):
  496. raise UnpicklingError(
  497. f"Only persistent_load of storage is allowed, but got {pid[0]}"
  498. )
  499. self.append(self.persistent_load(pid))
  500. elif key[0] in [BINGET[0], LONG_BINGET[0]]:
  501. idx = (read(1) if key[0] == BINGET[0] else unpack("<I", read(4)))[0]
  502. self.append(self.memo[idx])
  503. elif key[0] in [BINPUT[0], LONG_BINPUT[0]]:
  504. i = (read(1) if key[0] == BINPUT[0] else unpack("<I", read(4)))[0]
  505. if i < 0:
  506. raise ValueError("negative argument")
  507. self.memo[i] = self.stack[-1]
  508. elif key[0] == LONG1[0]:
  509. n = read(1)[0]
  510. data = read(n)
  511. self.append(decode_long(data))
  512. # First and last deserializer ops
  513. elif key[0] == PROTO[0]:
  514. self.proto = read(1)[0]
  515. if self.proto != 2:
  516. warnings.warn(
  517. f"Detected pickle protocol {self.proto} in the checkpoint, which was "
  518. "not the default pickle protocol used by `torch.load` (2). The weights_only "
  519. "Unpickler might not support all instructions implemented by this protocol, "
  520. "please file an issue for adding support if you encounter this."
  521. )
  522. elif key[0] == STOP[0]:
  523. rc = self.stack.pop()
  524. return rc
  525. else:
  526. raise UnpicklingError(f"Unsupported operand {key[0]}")
  527. # Return a list of items pushed in the stack after last MARK instruction.
  528. def pop_mark(self):
  529. items = self.stack
  530. self.stack = self.metastack.pop()
  531. self.append = self.stack.append
  532. return items
  533. def persistent_load(self, pid):
  534. raise UnpicklingError("unsupported persistent id encountered")
  535. def load(file, *, encoding: str = "ASCII"):
  536. return Unpickler(file, encoding=encoding).load()