serialization.py 83 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133
  1. # mypy: allow-untyped-defs
  2. import copyreg
  3. import difflib
  4. import functools
  5. import io
  6. import os
  7. import pickle
  8. import re
  9. import shutil
  10. import struct
  11. import sys
  12. import tarfile
  13. import tempfile
  14. import threading
  15. import warnings
  16. from contextlib import closing, contextmanager
  17. from enum import Enum
  18. from typing import Any, Callable, cast, Generic, IO, Optional, TypeVar, Union
  19. from typing_extensions import TypeAlias, TypeIs
  20. import torch
  21. import torch._weights_only_unpickler as _weights_only_unpickler
  22. from torch._sources import get_source_lines_and_file
  23. from torch._utils import _import_dotted_name
  24. from torch.storage import _get_dtype_from_pickle_storage_type
  25. from torch.types import FileLike, Storage
  26. __all__ = [
  27. "SourceChangeWarning",
  28. "mkdtemp",
  29. "register_package",
  30. "check_module_version_greater_or_equal",
  31. "validate_cuda_device",
  32. "validate_hpu_device",
  33. "location_tag",
  34. "default_restore_location",
  35. "normalize_storage_type",
  36. "storage_to_tensor_type",
  37. "save",
  38. "load",
  39. "StorageType",
  40. "LoadEndianness",
  41. "get_crc32_options",
  42. "set_crc32_options",
  43. "get_default_load_endianness",
  44. "set_default_load_endianness",
  45. "get_default_mmap_options",
  46. "set_default_mmap_options",
  47. "clear_safe_globals",
  48. "get_safe_globals",
  49. "add_safe_globals",
  50. "safe_globals",
  51. "get_unsafe_globals_in_checkpoint",
  52. "skip_data",
  53. ]
  54. DEFAULT_PROTOCOL = 2
  55. LONG_SIZE = struct.Struct("=l").size
  56. INT_SIZE = struct.Struct("=i").size
  57. SHORT_SIZE = struct.Struct("=h").size
  58. MAGIC_NUMBER = 0x1950A86A20F9469CFC6C
  59. PROTOCOL_VERSION = 1001
  60. STORAGE_KEY_SEPARATOR = ","
  61. MAP_LOCATION: TypeAlias = Optional[
  62. Union[Callable[[Storage, str], Storage], torch.device, str, dict[str, str]]
  63. ]
  64. STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]
  65. IS_WINDOWS = sys.platform == "win32"
  66. UNSAFE_MESSAGE = (
  67. "In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` "
  68. "from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
  69. "but it can result in arbitrary code execution. Do it only if you got the file from a "
  70. "trusted source."
  71. )
  72. if not IS_WINDOWS:
  73. from mmap import MAP_PRIVATE, MAP_SHARED
  74. else:
  75. MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment]
  76. def _default_to_weights_only(pickle_module):
  77. is_fbcode = not hasattr(torch.version, "git_version")
  78. return pickle_module is None and not is_fbcode
  79. # _serialization_tls is used to store thread local state specific to serialization
  80. # that needs to be propagated to other files, in particular we use this for
  81. # (1) map_location (needed for wrapper subclasses/third party devices to torch._utils)
  82. # (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx)
  83. # (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx)
  84. class _SerializationLocal(threading.local):
  85. def __init__(self):
  86. super().__init__()
  87. self.map_location: Optional[MAP_LOCATION] = None
  88. self.skip_data: bool = False
  89. self.materialize_fake_tensors: bool = False
  90. _serialization_tls = _SerializationLocal()
  91. class SourceChangeWarning(Warning):
  92. pass
  93. @contextmanager
  94. def mkdtemp():
  95. path = tempfile.mkdtemp()
  96. try:
  97. yield path
  98. finally:
  99. shutil.rmtree(path)
  100. _package_registry: list[
  101. tuple[
  102. int,
  103. Callable[[STORAGE], Optional[str]],
  104. Callable[[STORAGE, str], Optional[STORAGE]],
  105. ]
  106. ] = []
  107. class LoadEndianness(Enum):
  108. NATIVE = 1
  109. LITTLE = 2
  110. BIG = 3
  111. def get_default_load_endianness() -> Optional[LoadEndianness]:
  112. """
  113. Get fallback byte order for loading files
  114. If byteorder mark is not present in saved checkpoint,
  115. this byte order is used as fallback.
  116. By default, it's "native" byte order.
  117. Returns:
  118. default_load_endian: Optional[LoadEndianness]
  119. """
  120. from torch.utils.serialization import config
  121. return config.load.endianness
  122. def set_default_load_endianness(endianness):
  123. """
  124. Set fallback byte order for loading files
  125. If byteorder mark is not present in saved checkpoint,
  126. this byte order is used as fallback.
  127. By default, it's "native" byte order.
  128. Args:
  129. endianness: the new fallback byte order
  130. """
  131. if not isinstance(endianness, LoadEndianness) and endianness is not None:
  132. raise TypeError("Invalid argument type in function set_default_load_endianness")
  133. from torch.utils.serialization import config
  134. config.load.endianness = endianness
  135. def get_crc32_options() -> bool:
  136. """
  137. Get whether :func:`torch.save` computes and writes crc32 for each record.
  138. Defaults to ``True``.
  139. """
  140. from torch.utils.serialization import config
  141. return config.save.compute_crc32
  142. def set_crc32_options(compute_crc32: bool):
  143. """
  144. Set whether :func:`torch.save` computes and writes crc32 for each record.
  145. .. note::
  146. Setting this to ``False`` may make unzipping of the ``torch.save`` output
  147. fail or warn due to corrupted CRC32. However ``torch.load`` will be
  148. able to load the file.
  149. Args:
  150. compute_crc32 (bool): set crc32 computation flag
  151. """
  152. from torch.utils.serialization import config
  153. config.save.compute_crc32 = compute_crc32
  154. def get_default_mmap_options() -> Optional[int]:
  155. """
  156. Get default mmap options for :func:`torch.load` with ``mmap=True``.
  157. Defaults to ``mmap.MAP_PRIVATE``.
  158. Returns:
  159. default_mmap_options: int
  160. """
  161. from torch.utils.serialization import config
  162. return config.load.mmap_flags
  163. def _get_storage_alignment() -> int:
  164. """
  165. Gets alignment for storages in torch.save files/
  166. Defaults to 64.
  167. Returns:
  168. storage_alginment: int
  169. """
  170. from torch.utils.serialization import config
  171. return config.save.storage_alignment
  172. class set_default_mmap_options:
  173. """
  174. Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
  175. For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported.
  176. Please open an issue if you need any other option to be added here.
  177. .. note::
  178. This feature is currently not supported for Windows.
  179. Args:
  180. flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``
  181. """
  182. def __init__(self, flags: int) -> None:
  183. if IS_WINDOWS:
  184. raise RuntimeError(
  185. "Changing the default mmap options is currently not supported for Windows"
  186. )
  187. if flags != MAP_PRIVATE and flags != MAP_SHARED:
  188. raise ValueError(
  189. "Invalid argument in function set_default_mmap_options, "
  190. f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}"
  191. )
  192. # global config
  193. from torch.utils.serialization import config
  194. self.prev = config.load.mmap_flags
  195. config.load.mmap_flags = flags
  196. def __enter__(self) -> None:
  197. pass
  198. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
  199. from torch.utils.serialization import config
  200. config.load.mmap_flags = self.prev
  201. def clear_safe_globals() -> None:
  202. """
  203. Clears the list of globals that are safe for ``weights_only`` load.
  204. """
  205. _weights_only_unpickler._clear_safe_globals()
  206. def get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]:
  207. """
  208. Returns the list of user-added globals that are safe for ``weights_only`` load.
  209. """
  210. return _weights_only_unpickler._get_safe_globals()
  211. def add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]) -> None:
  212. """
  213. Marks the given globals as safe for ``weights_only`` load. For example, functions
  214. added to this list can be called during unpickling, classes could be instantiated
  215. and have state set.
  216. Each item in the list can either be a function/class or a tuple of the form
  217. (function/class, string) where string is the full path of the function/class.
  218. Within the serialized format, each function is identified with its full
  219. path as ``{__module__}.{__qualname__}``. When calling this API, you can provide this
  220. full path that should match the one in the checkpoint otherwise the default
  221. ``{fn.__module__}.{fn.__qualname__}`` will be used.
  222. Args:
  223. safe_globals (List[Union[Callable, Tuple[Callable, str]]]): list of globals to mark as safe
  224. Example:
  225. >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
  226. >>> import tempfile
  227. >>> class MyTensor(torch.Tensor):
  228. ... pass
  229. >>> t = MyTensor(torch.randn(2, 3))
  230. >>> with tempfile.NamedTemporaryFile() as f:
  231. ... torch.save(t, f.name)
  232. # Running `torch.load(f.name, weights_only=True)` will fail with
  233. # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
  234. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
  235. ... torch.serialization.add_safe_globals([MyTensor])
  236. ... torch.load(f.name, weights_only=True)
  237. # MyTensor([[-0.5024, -1.8152, -0.5455],
  238. # [-0.8234, 2.0500, -0.3657]])
  239. """
  240. _weights_only_unpickler._add_safe_globals(safe_globals)
  241. class safe_globals(_weights_only_unpickler._safe_globals):
  242. r"""Context-manager that adds certain globals as safe for ``weights_only`` load.
  243. Args:
  244. safe_globals: List of globals for weights_only load.
  245. Example:
  246. >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
  247. >>> import tempfile
  248. >>> class MyTensor(torch.Tensor):
  249. ... pass
  250. >>> t = MyTensor(torch.randn(2, 3))
  251. >>> with tempfile.NamedTemporaryFile() as f:
  252. ... torch.save(t, f.name)
  253. # Running `torch.load(f.name, weights_only=True)` will fail with
  254. # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
  255. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
  256. ... with torch.serialization.safe_globals([MyTensor]):
  257. ... torch.load(f.name, weights_only=True)
  258. # MyTensor([[-0.5024, -1.8152, -0.5455],
  259. # [-0.8234, 2.0500, -0.3657]])
  260. >>> assert torch.serialization.get_safe_globals() == []
  261. """
  262. def get_unsafe_globals_in_checkpoint(f: FileLike) -> list[str]:
  263. """Returns a list of strings of functions/classes in a ``torch.save`` object that are not safe for ``weights_only``.
  264. For a given function or class ``f``, the corresponding string will be of the form
  265. ``{f.__module__}.{f.__name__}``.
  266. This function will return any GLOBALs in the checkpoint that are not in the set marked safe
  267. for ``weights_only`` (either via :func:`add_safe_globals` or :class:`safe_globals` context or
  268. allowlisted by ``torch`` by default).
  269. .. note::
  270. This function will statically disassemble the pickle file in the checkpoint.
  271. The implication is any classes dynamically pushed onto the stack during unpickling
  272. will not be included in the output.
  273. Args:
  274. f: File-like object or string containing the checkpoint object saved via ``torch.save``
  275. Returns:
  276. A list of strings of pickle GLOBALs in the checkpoint that are not allowlisted for ``weights_only``.
  277. """
  278. default_safe_globals_strings = set(
  279. _weights_only_unpickler._get_allowed_globals().keys()
  280. )
  281. user_safe_global_strings = set(
  282. _weights_only_unpickler._get_user_allowed_globals().keys()
  283. )
  284. safe_global_strings = default_safe_globals_strings.union(user_safe_global_strings)
  285. with _open_file_like(f, "rb") as opened_file:
  286. if not _is_zipfile(opened_file):
  287. raise ValueError("Expected input to be a checkpoint returned by torch.save")
  288. with _open_zipfile_reader(opened_file) as zip_file:
  289. if _is_torchscript_zip(zip_file):
  290. raise ValueError(
  291. "Expected input to be a checkpoint returned by torch.save but got a torchscript checkpoint"
  292. )
  293. data_file = io.BytesIO(zip_file.get_record("data.pkl"))
  294. all_globals = _weights_only_unpickler.get_globals_in_pkl(data_file)
  295. return list(all_globals.difference(safe_global_strings))
  296. class skip_data:
  297. """
  298. Context-manager that skips writing/reading storage bytes for ``torch.save`` / ``torch.load`` calls.
  299. For the save path, storages will still be saved, but the space that their bytes would usually be written to
  300. will be empty space. The storage bytes can then be populated in a separate pass.
  301. For the load path, tensors will be loaded per the checkpoint but their storages will not be populated with data.
  302. .. warning::
  303. The ``skip_data`` context manager is an early prototype and is subject to change.
  304. Args:
  305. materialize_fake_tensors: Whether to materialize FakeTensors during save. This is a no-op for the load path.
  306. Example:
  307. >>> # xdoctest: +SKIP("NamedTemporaryFile on Windows")
  308. >>> import tempfile
  309. >>> t = torch.randn(2, 3)
  310. >>> with tempfile.NamedTemporaryFile() as f:
  311. ... with torch.serialization.skip_data():
  312. ... torch.save(t, f.name)
  313. ... torch.load(f.name, weights_only=True)
  314. tensor([[0., 0., 0.],
  315. [0., 0., 0.]])
  316. """
  317. def __init__(self, materialize_fake_tensors: bool = False):
  318. self.materialize_fake_tensors = materialize_fake_tensors
  319. def __enter__(self):
  320. global _serialization_tls
  321. self._old_skip_data = _serialization_tls.skip_data
  322. self._old_materialize_fake_tensors = _serialization_tls.materialize_fake_tensors
  323. _serialization_tls.skip_data = True
  324. _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors
  325. def __exit__(self, type, value, tb):
  326. global _serialization_tls
  327. _serialization_tls.skip_data = self._old_skip_data
  328. _serialization_tls.materialize_fake_tensors = self._old_materialize_fake_tensors
  329. def _is_zipfile(f) -> bool:
  330. # This is a stricter implementation than zipfile.is_zipfile().
  331. # zipfile.is_zipfile() is True if the magic number appears anywhere in the
  332. # binary. Since we expect the files here to be generated by torch.save or
  333. # torch.jit.save, it's safe to only check the start bytes and avoid
  334. # collisions and assume the zip has only 1 file.
  335. # See bugs.python.org/issue28494.
  336. start = f.tell()
  337. # Read the first few bytes and match against the ZIP file signature
  338. local_header_magic_number = b"PK\x03\x04"
  339. read_bytes = f.read(len(local_header_magic_number))
  340. f.seek(start)
  341. return read_bytes == local_header_magic_number
  342. def register_package(
  343. priority: int,
  344. tagger: Callable[[STORAGE], Optional[str]],
  345. deserializer: Callable[[STORAGE, str], Optional[STORAGE]],
  346. ):
  347. """
  348. Registers callables for tagging and deserializing storage objects with an associated priority.
  349. Tagging associates a device with a storage object at save time while deserializing moves a
  350. storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
  351. are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
  352. value that is not `None`.
  353. To override the deserialization behavior for a device in the global registry, one can register a
  354. tagger with a higher priority than the existing tagger.
  355. This function can also be used to register a tagger and deserializer for new devices.
  356. Args:
  357. priority: Indicates the priority associated with the tagger and deserializer, where a lower
  358. value indicates higher priority.
  359. tagger: Callable that takes in a storage object and returns its tagged device as a string
  360. or None.
  361. deserializer: Callable that takes in storage object and a device string and returns a storage
  362. object on the appropriate device or None.
  363. Returns:
  364. `None`
  365. Example:
  366. >>> def ipu_tag(obj):
  367. >>> if obj.device.type == 'ipu':
  368. >>> return 'ipu'
  369. >>> def ipu_deserialize(obj, location):
  370. >>> if location.startswith('ipu'):
  371. >>> ipu = getattr(torch, "ipu", None)
  372. >>> assert ipu is not None, "IPU device module is not loaded"
  373. >>> assert torch.ipu.is_available(), "ipu is not available"
  374. >>> return obj.ipu(location)
  375. >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
  376. """
  377. queue_elem = (priority, tagger, deserializer)
  378. _package_registry.append(queue_elem)
  379. _package_registry.sort()
  380. def check_module_version_greater_or_equal(
  381. module,
  382. req_version_tuple,
  383. error_if_malformed=True,
  384. ):
  385. """
  386. Check if a module's version satisfies requirements
  387. Usually, a module's version string will be like 'x.y.z', which would be represented
  388. as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version
  389. string does not match the given tuple's format up to the length of the tuple, then
  390. error and exit or emit a warning.
  391. Args:
  392. module: the module to check the version of
  393. req_version_tuple: tuple (usually of ints) representing the required version
  394. error_if_malformed: whether we should exit if module version string is malformed
  395. Returns:
  396. requirement_is_met: bool
  397. """
  398. try:
  399. version_strs = module.__version__.split(".")
  400. # Cast module version fields to match the types of the required version
  401. module_version = tuple(
  402. type(req_field)(version_strs[idx])
  403. for idx, req_field in enumerate(req_version_tuple)
  404. )
  405. requirement_is_met = module_version >= req_version_tuple
  406. except Exception as e:
  407. message = (
  408. f"'{module.__name__}' module version string is malformed '{module.__version__}' and cannot be compared"
  409. f" with tuple {str(req_version_tuple)}"
  410. )
  411. if error_if_malformed:
  412. raise RuntimeError(message) from e
  413. else:
  414. warnings.warn(message + ", but continuing assuming that requirement is met")
  415. requirement_is_met = True
  416. return requirement_is_met
  417. def _cpu_tag(obj):
  418. if obj.device.type == "cpu":
  419. return "cpu"
  420. def _mps_tag(obj):
  421. if obj.device.type == "mps":
  422. return "mps"
  423. def _meta_tag(obj):
  424. if obj.device.type == "meta":
  425. return "meta"
  426. def _backend_tag(backend_name, obj):
  427. if backend_name == "privateuse1":
  428. backend_name = torch._C._get_privateuse1_backend_name()
  429. if obj.device.type == backend_name:
  430. if obj.device.index is None:
  431. return backend_name
  432. else:
  433. return backend_name + ":" + str(obj.device.index)
  434. def _cpu_deserialize(obj, location):
  435. if location == "cpu":
  436. return obj
  437. def _mps_deserialize(obj, location):
  438. if location.startswith("mps"):
  439. return obj.mps()
  440. def _meta_deserialize(obj, location):
  441. if location == "meta":
  442. return torch.UntypedStorage(obj.nbytes(), device="meta")
  443. def _validate_device(location, backend_name):
  444. """
  445. Check whether the device index of specified backend is valid
  446. In case of privateuse1 backend, your must first register a device_module for
  447. privateuse1 using torch._register_device_module. Implement the following
  448. methods in device_module like cuda: device_module._utils._get_device_index(location, True),
  449. device_module.device_count().
  450. Args:
  451. location: string of device
  452. backend_name: the backend name or the name of privateuse1, which can be renamed
  453. Returns:
  454. device_index: int
  455. """
  456. if not hasattr(torch, backend_name):
  457. raise RuntimeError(
  458. f"The {backend_name.upper()} device module is not registered. "
  459. "If you are running on a CPU-only machine, "
  460. "please use torch.load with map_location=torch.device('cpu') "
  461. "to map your storages to the CPU."
  462. )
  463. device_module = getattr(torch, backend_name)
  464. if hasattr(device_module, "_utils") and hasattr(
  465. device_module._utils, "_get_device_index"
  466. ):
  467. device_index = device_module._utils._get_device_index(location, True)
  468. device = torch.device(backend_name, device_index)
  469. else:
  470. device = torch.device(location)
  471. device_index = device.index if device.index else 0
  472. if hasattr(device_module, "is_available") and not device_module.is_available():
  473. raise RuntimeError(
  474. f"Attempting to deserialize object on a {backend_name.upper()} "
  475. f"device but torch.{backend_name}.is_available() is False. "
  476. "If you are running on a CPU-only machine, "
  477. "please use torch.load with map_location=torch.device('cpu') "
  478. "to map your storages to the CPU."
  479. )
  480. if hasattr(device_module, "device_count"):
  481. device_count = device_module.device_count()
  482. if device_index >= device_count:
  483. raise RuntimeError(
  484. f"Attempting to deserialize object on {backend_name.upper()} device "
  485. f"{device_index} but torch.{backend_name}.device_count() is {device_count}. "
  486. "Please use torch.load with map_location to map your storages "
  487. "to an existing device."
  488. )
  489. return device
  490. def validate_cuda_device(location):
  491. return _validate_device(location, "cuda").index
  492. def validate_hpu_device(location):
  493. return _validate_device(location, "hpu").index
  494. def _deserialize(backend_name, obj, location):
  495. if backend_name == "privateuse1":
  496. backend_name = torch._C._get_privateuse1_backend_name()
  497. if location.startswith(backend_name):
  498. device = _validate_device(location, backend_name)
  499. return obj.to(device=device)
  500. register_package(10, _cpu_tag, _cpu_deserialize)
  501. register_package(
  502. 20,
  503. functools.partial(_backend_tag, "cuda"),
  504. functools.partial(_deserialize, "cuda"),
  505. )
  506. register_package(21, _mps_tag, _mps_deserialize)
  507. register_package(22, _meta_tag, _meta_deserialize)
  508. register_package(
  509. 23,
  510. functools.partial(_backend_tag, "privateuse1"),
  511. functools.partial(_deserialize, "privateuse1"),
  512. )
  513. register_package(
  514. 24,
  515. functools.partial(_backend_tag, "hpu"),
  516. functools.partial(_deserialize, "hpu"),
  517. )
  518. register_package(
  519. 25,
  520. functools.partial(_backend_tag, "xpu"),
  521. functools.partial(_deserialize, "xpu"),
  522. )
  523. def location_tag(
  524. storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage],
  525. ):
  526. for _, tagger, _ in _package_registry:
  527. location = tagger(storage)
  528. if location:
  529. return location
  530. raise RuntimeError(
  531. "don't know how to determine data location of " + torch.typename(storage)
  532. )
  533. def default_restore_location(storage, location):
  534. """
  535. Restores `storage` using a deserializer function registered for the `location`.
  536. This function looks in the registry for deserializer functions that match the `location`.
  537. If found, it attempts to use them, in priority order, to restore `storage` until one
  538. returns a not `None` result. If no deserializer can be found in the registry, or all found fail
  539. to bear a result, it raises a `RuntimeError`.
  540. Args:
  541. storage (STORAGE): the storage object to restore
  542. location (str): the location tag associated with the storage object
  543. Returns:
  544. storage: Optional[STORAGE]
  545. Raises:
  546. RuntimeError: If no deserializer matching `location` is found in the registry or if
  547. all matching ones return `None`.
  548. """
  549. for _, _, fn in _package_registry:
  550. result = fn(storage, location)
  551. if result is not None:
  552. return result
  553. raise RuntimeError(
  554. "don't know how to restore data location of "
  555. + torch.typename(storage)
  556. + " (tagged with "
  557. + location
  558. + ")"
  559. )
  560. def normalize_storage_type(storage_type):
  561. return getattr(torch, storage_type.__name__)
  562. def storage_to_tensor_type(storage):
  563. storage_type = type(storage)
  564. module = _import_dotted_name(storage_type.__module__)
  565. return getattr(module, storage_type.__name__.replace("Storage", "Tensor"))
  566. def _is_path(name_or_buffer: object) -> TypeIs[Union[str, os.PathLike]]:
  567. return isinstance(name_or_buffer, (str, os.PathLike))
  568. T = TypeVar("T")
  569. class _opener(Generic[T]):
  570. def __init__(self, file_like: T) -> None:
  571. self.file_like: T = file_like
  572. def __enter__(self):
  573. return self.file_like
  574. def __exit__(self, *args):
  575. pass
  576. class _open_file(_opener[IO[bytes]]):
  577. def __init__(self, name: Union[str, os.PathLike[str]], mode: str) -> None:
  578. super().__init__(open(name, mode))
  579. def __exit__(self, *args):
  580. self.file_like.close()
  581. class _open_buffer_reader(_opener[IO[bytes]]):
  582. def __init__(self, buffer: IO[bytes]) -> None:
  583. super().__init__(buffer)
  584. _check_seekable(buffer)
  585. class _open_buffer_writer(_opener[IO[bytes]]):
  586. def __exit__(self, *args):
  587. self.file_like.flush()
  588. def _open_file_like(name_or_buffer: FileLike, mode: str) -> _opener[IO[bytes]]:
  589. if _is_path(name_or_buffer):
  590. return _open_file(name_or_buffer, mode)
  591. else:
  592. if "w" in mode:
  593. return _open_buffer_writer(name_or_buffer)
  594. elif "r" in mode:
  595. return _open_buffer_reader(name_or_buffer)
  596. else:
  597. raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
  598. class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]):
  599. def __init__(self, name_or_buffer: Union[str, IO[bytes]]) -> None:
  600. super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
  601. class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
  602. def __init__(self, name: str) -> None:
  603. self.file_stream = None
  604. self.name = name
  605. try:
  606. self.name.encode("ascii")
  607. except UnicodeEncodeError:
  608. # PyTorchFileWriter only supports ascii filename.
  609. # For filenames with non-ascii characters, we rely on Python
  610. # for writing out the file.
  611. self.file_stream = io.FileIO(self.name, mode="w")
  612. super().__init__(
  613. torch._C.PyTorchFileWriter(
  614. self.file_stream, get_crc32_options(), _get_storage_alignment()
  615. )
  616. )
  617. else:
  618. super().__init__(
  619. torch._C.PyTorchFileWriter(
  620. self.name, get_crc32_options(), _get_storage_alignment()
  621. )
  622. )
  623. def __exit__(self, *args) -> None:
  624. self.file_like.write_end_of_file()
  625. if self.file_stream is not None:
  626. self.file_stream.close()
  627. class _open_zipfile_writer_buffer(_opener[torch._C.PyTorchFileWriter]):
  628. def __init__(self, buffer: IO[bytes]) -> None:
  629. if not callable(getattr(buffer, "write", None)):
  630. msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
  631. if not hasattr(buffer, "write"):
  632. raise AttributeError(msg)
  633. raise TypeError(msg)
  634. self.buffer = buffer
  635. super().__init__(
  636. torch._C.PyTorchFileWriter(
  637. buffer, get_crc32_options(), _get_storage_alignment()
  638. )
  639. )
  640. def __exit__(self, *args) -> None:
  641. self.file_like.write_end_of_file()
  642. self.buffer.flush()
  643. def _open_zipfile_writer(name_or_buffer: Union[str, IO[bytes]]) -> _opener:
  644. container: type[_opener]
  645. if _is_path(name_or_buffer):
  646. container = _open_zipfile_writer_file
  647. else:
  648. container = _open_zipfile_writer_buffer
  649. return container(name_or_buffer) # type: ignore[arg-type]
  650. def _is_compressed_file(f) -> bool:
  651. compress_modules = ["gzip"]
  652. try:
  653. return f.__module__ in compress_modules
  654. except AttributeError:
  655. return False
  656. def _should_read_directly(f):
  657. """
  658. Checks if f is a file that should be read directly. It should be read
  659. directly if it is backed by a real file (has a fileno) and is not a
  660. a compressed file (e.g. gzip)
  661. """
  662. if _is_compressed_file(f):
  663. return False
  664. try:
  665. return f.fileno() >= 0
  666. except io.UnsupportedOperation:
  667. return False
  668. except AttributeError:
  669. return False
  670. def _check_seekable(f) -> bool:
  671. def raise_err_msg(patterns, e):
  672. for p in patterns:
  673. if p in str(e):
  674. msg = (
  675. str(e)
  676. + ". You can only torch.load from a file that is seekable."
  677. + " Please pre-load the data into a buffer like io.BytesIO and"
  678. + " try to load from it instead."
  679. )
  680. raise type(e)(msg)
  681. raise e
  682. try:
  683. f.seek(f.tell())
  684. return True
  685. except (io.UnsupportedOperation, AttributeError) as e:
  686. raise_err_msg(["seek", "tell"], e)
  687. return False
  688. def _check_dill_version(pickle_module) -> None:
  689. """Checks if using dill as the pickle module, and if so, checks if it is the correct version.
  690. If dill version is lower than 0.3.1, a ValueError is raised.
  691. Args:
  692. pickle_module: module used for pickling metadata and objects
  693. """
  694. if pickle_module is not None and pickle_module.__name__ == "dill":
  695. required_dill_version = (0, 3, 1)
  696. if not check_module_version_greater_or_equal(
  697. pickle_module, required_dill_version, False
  698. ):
  699. raise ValueError(
  700. (
  701. "'torch' supports dill >= {}, but you have dill {}."
  702. " Please upgrade dill or switch to 'pickle'"
  703. ).format(
  704. ".".join([str(num) for num in required_dill_version]),
  705. pickle_module.__version__,
  706. )
  707. )
  708. def _check_save_filelike(f):
  709. if not _is_path(f) and not hasattr(f, "write"):
  710. raise AttributeError(
  711. "expected 'f' to be string, path, or a file-like object with "
  712. "a 'write' attribute"
  713. )
  714. def save(
  715. obj: object,
  716. f: FileLike,
  717. pickle_module: Any = pickle,
  718. pickle_protocol: int = DEFAULT_PROTOCOL,
  719. _use_new_zipfile_serialization: bool = True,
  720. _disable_byteorder_record: bool = False,
  721. ) -> None:
  722. # Reference: https://github.com/pytorch/pytorch/issues/54354
  723. # The first line of this docstring overrides the one Sphinx generates for the
  724. # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
  725. # the build environment (e.g. `<module 'pickle' from '/leaked/path').
  726. """save(obj, f, pickle_module=pickle, pickle_protocol=2, _use_new_zipfile_serialization=True)
  727. Saves an object to a disk file.
  728. See also: :ref:`saving-loading-tensors`
  729. See :ref:`layout-control` for more advanced tools to manipulate a checkpoint.
  730. Args:
  731. obj: saved object
  732. f: a file-like object (has to implement write and flush) or a string or
  733. os.PathLike object containing a file name
  734. pickle_module: module used for pickling metadata and objects
  735. pickle_protocol: can be specified to override the default protocol
  736. .. note::
  737. A common PyTorch convention is to save tensors using .pt file extension.
  738. .. note::
  739. PyTorch preserves storage sharing across serialization. See
  740. :ref:`preserve-storage-sharing` for more details.
  741. .. note::
  742. The 1.6 release of PyTorch switched ``torch.save`` to use a new
  743. zipfile-based file format. ``torch.load`` still retains the ability to
  744. load files in the old format. If for any reason you want ``torch.save``
  745. to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.
  746. Example:
  747. >>> # xdoctest: +SKIP("makes cwd dirty")
  748. >>> # Save to file
  749. >>> x = torch.tensor([0, 1, 2, 3, 4])
  750. >>> torch.save(x, "tensor.pt")
  751. >>> # Save to io.BytesIO buffer
  752. >>> buffer = io.BytesIO()
  753. >>> torch.save(x, buffer)
  754. """
  755. torch._C._log_api_usage_once("torch.save")
  756. _check_dill_version(pickle_module)
  757. _check_save_filelike(f)
  758. if isinstance(f, (str, os.PathLike)):
  759. f = os.fspath(f)
  760. if _use_new_zipfile_serialization:
  761. with _open_zipfile_writer(f) as opened_zipfile:
  762. _save(
  763. obj,
  764. opened_zipfile,
  765. pickle_module,
  766. pickle_protocol,
  767. _disable_byteorder_record,
  768. )
  769. return
  770. else:
  771. global _serialization_tls
  772. if _serialization_tls.skip_data:
  773. raise RuntimeError(
  774. "Cannot use skip_data=True with _use_new_zipfile_serialization=False"
  775. )
  776. with _open_file_like(f, "wb") as opened_file:
  777. _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
  778. def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
  779. import torch.nn as nn
  780. serialized_container_types = {}
  781. serialized_storages: dict[str, tuple[torch.UntypedStorage, torch.dtype]] = {}
  782. # Since loading storages that view the same data with different dtypes is
  783. # not supported, we need to keep track of the dtype associated with each
  784. # storage data_ptr and throw an error if the dtype is ever different.
  785. # TODO: This feature could be added in the future
  786. storage_dtypes: dict[int, torch.dtype] = {}
  787. def persistent_id(obj: Any) -> Optional[tuple]:
  788. # FIXME: the docs say that persistent_id should only return a string
  789. # but torch store returns tuples. This works only in the binary protocol
  790. # see
  791. # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
  792. # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
  793. if isinstance(obj, type) and issubclass(obj, nn.Module):
  794. if obj in serialized_container_types:
  795. return None
  796. serialized_container_types[obj] = True
  797. source_file = source = None
  798. try:
  799. source_lines, _, source_file = get_source_lines_and_file(obj)
  800. source = "".join(source_lines)
  801. except (
  802. Exception
  803. ): # saving the source is optional, so we can ignore any errors
  804. warnings.warn(
  805. "Couldn't retrieve source code for container of "
  806. "type " + obj.__name__ + ". It won't be checked "
  807. "for correctness upon loading."
  808. )
  809. return ("module", obj, source_file, source)
  810. if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
  811. storage: torch.UntypedStorage
  812. if isinstance(obj, torch.storage.TypedStorage):
  813. # TODO: Once we decide to break serialization FC, this case
  814. # can be deleted
  815. storage = obj._untyped_storage
  816. storage_dtype = obj.dtype
  817. storage_type_str = obj._pickle_storage_type()
  818. storage_type = getattr(torch, storage_type_str)
  819. dtype = obj.dtype
  820. storage_numel = obj._size()
  821. elif isinstance(obj, torch.UntypedStorage):
  822. storage = obj
  823. storage_dtype = torch.uint8
  824. storage_type = normalize_storage_type(type(obj))
  825. dtype = torch.uint8
  826. storage_numel = storage.nbytes()
  827. else:
  828. raise TypeError(f"type not recognized: {type(obj)}")
  829. # If storage is allocated, ensure that any other saved storages
  830. # pointing to the same data all have the same dtype. If storage is
  831. # not allocated, don't perform this check
  832. if storage.data_ptr() != 0:
  833. if storage.data_ptr() in storage_dtypes:
  834. if storage_dtype != storage_dtypes[storage.data_ptr()]:
  835. raise RuntimeError(
  836. "Cannot save multiple tensors or storages that "
  837. "view the same data as different types"
  838. )
  839. else:
  840. storage_dtypes[storage.data_ptr()] = storage_dtype
  841. view_metadata: Optional[tuple[str, int, int]]
  842. # Offset is always 0, but we keep it for backwards compatibility
  843. # with the old serialization format (which supported storage views)
  844. offset = 0
  845. storage_key = str(storage._cdata)
  846. location = location_tag(storage)
  847. # TODO: There's an issue here with FC. It might be impossible to
  848. # solve, but it's worth noting. Imagine we save a list `[storage,
  849. # tensor]`, where `tensor.storage()` is the same as `storage`, and
  850. # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
  851. # torch.float`. The storage will be serialized with element size
  852. # of 1, since we're choosing to serialize the first occurrence of
  853. # a duplicate storage. Since this legacy serialization format saves
  854. # the numel of the storage, rather than nbytes directly, we'll be
  855. # effectively saving nbytes in this case. We'll be able to load it
  856. # and the tensor back up with no problems in _this_ and future
  857. # versions of pytorch, but in older versions, here's the problem:
  858. # the storage will be loaded up as a UntypedStorage, and then the
  859. # FloatTensor will loaded and the UntypedStorage will be assigned to
  860. # it. Since the storage dtype does not match the tensor dtype, this
  861. # will cause an error. If we reverse the list, like `[tensor,
  862. # storage]`, then we will save the `tensor.storage()` as a faked
  863. # `FloatStorage`, and the saved size will be the correct
  864. # dtype-specific numel count that old versions expect. `tensor`
  865. # will be able to load up properly in old versions, pointing to
  866. # a FloatStorage. However, `storage` is still being translated to
  867. # a UntypedStorage, and it will try to resolve to the same
  868. # FloatStorage that `tensor` contains. This will also cause an
  869. # error. It doesn't seem like there's any way around this.
  870. # Probably, we just cannot maintain FC for the legacy format if the
  871. # saved list contains both a tensor and a storage that point to the
  872. # same data. We should still be able to maintain FC for lists of
  873. # just tensors, as long as all views share the same dtype as the
  874. # tensor they are viewing.
  875. if storage_key not in serialized_storages:
  876. serialized_storages[storage_key] = (storage, dtype)
  877. is_view = storage._cdata != storage._cdata
  878. if is_view:
  879. view_metadata = (str(storage._cdata), offset, storage.nbytes())
  880. else:
  881. view_metadata = None
  882. res = (
  883. "storage",
  884. storage_type,
  885. storage_key,
  886. location,
  887. storage_numel,
  888. view_metadata,
  889. )
  890. return res
  891. return None
  892. sys_info = {
  893. "protocol_version": PROTOCOL_VERSION,
  894. "little_endian": sys.byteorder == "little",
  895. "type_sizes": {
  896. "short": SHORT_SIZE,
  897. "int": INT_SIZE,
  898. "long": LONG_SIZE,
  899. },
  900. }
  901. pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
  902. pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
  903. pickle_module.dump(sys_info, f, protocol=pickle_protocol)
  904. class PyTorchLegacyPickler(pickle_module.Pickler):
  905. def persistent_id(self, obj):
  906. return persistent_id(obj)
  907. pickler = PyTorchLegacyPickler(f, protocol=pickle_protocol)
  908. pickler.dump(obj)
  909. serialized_storage_keys = sorted(serialized_storages.keys())
  910. pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
  911. f.flush()
  912. for key in serialized_storage_keys:
  913. storage, dtype = serialized_storages[key]
  914. storage._write_file(
  915. f, _should_read_directly(f), True, torch._utils._element_size(dtype)
  916. )
  917. def _save(
  918. obj,
  919. zip_file,
  920. pickle_module,
  921. pickle_protocol,
  922. _disable_byteorder_record,
  923. ):
  924. serialized_storages: dict[str, torch.storage.UntypedStorage] = {}
  925. id_map: dict[int, str] = {}
  926. # Since loading storages that view the same data with different dtypes is
  927. # not supported, we need to keep track of the dtype associated with each
  928. # storage data_ptr and throw an error if the dtype is ever different.
  929. # TODO: This feature could be added in the future
  930. storage_dtypes: dict[int, torch.dtype] = {}
  931. def persistent_id(obj):
  932. # FIXME: the docs say that persistent_id should only return a string
  933. # but torch store returns tuples. This works only in the binary protocol
  934. # see
  935. # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
  936. # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
  937. if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
  938. if isinstance(obj, torch.storage.TypedStorage):
  939. # TODO: Once we decide to break serialization FC, this case
  940. # can be deleted
  941. storage = obj._untyped_storage
  942. storage_dtype = obj.dtype
  943. storage_type_str = obj._pickle_storage_type()
  944. storage_type = getattr(torch, storage_type_str)
  945. storage_numel = obj._size()
  946. else:
  947. storage = obj
  948. storage_dtype = torch.uint8
  949. storage_type = normalize_storage_type(type(obj))
  950. storage_numel = storage.nbytes()
  951. # If storage is allocated, ensure that any other saved storages
  952. # pointing to the same data all have the same dtype. If storage is
  953. # not allocated, don't perform this check
  954. if str(storage.device) != "meta" and storage.data_ptr() != 0:
  955. if storage.data_ptr() in storage_dtypes:
  956. if storage_dtype != storage_dtypes[storage.data_ptr()]:
  957. raise RuntimeError(
  958. "Cannot save multiple tensors or storages that "
  959. "view the same data as different types"
  960. )
  961. else:
  962. storage_dtypes[storage.data_ptr()] = storage_dtype
  963. storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
  964. if hasattr(obj, "_fake_device") and obj._fake_device is not None:
  965. location = str(obj._fake_device)
  966. else:
  967. location = location_tag(storage)
  968. serialized_storages[storage_key] = storage
  969. return ("storage", storage_type, storage_key, location, storage_numel)
  970. return None
  971. # Write the pickle data for `obj`
  972. data_buf = io.BytesIO()
  973. class PyTorchPickler(pickle_module.Pickler): # type: ignore[name-defined]
  974. def persistent_id(self, obj):
  975. return persistent_id(obj)
  976. pickler = PyTorchPickler(data_buf, protocol=pickle_protocol)
  977. pickler.dump(obj)
  978. data_value = data_buf.getvalue()
  979. zip_file.write_record("data.pkl", data_value, len(data_value))
  980. # .format_version is used to track
  981. # 1. version 1 represents the order of storages being changed from
  982. # lexicographical based on keys to numerically ordered based on keys
  983. # 2. version 2 represents including storage_alignment as a record
  984. # within the zipfile
  985. zip_file.write_record(".format_version", "1", len("1"))
  986. storage_alignment = str(_get_storage_alignment())
  987. zip_file.write_record(
  988. ".storage_alignment", storage_alignment, len(storage_alignment)
  989. )
  990. # Write byte order marker
  991. if not _disable_byteorder_record:
  992. if sys.byteorder not in ["little", "big"]:
  993. raise ValueError("Unknown endianness type: " + sys.byteorder)
  994. zip_file.write_record("byteorder", sys.byteorder, len(sys.byteorder))
  995. # Write each tensor to a file named tensor/the_tensor_key in the zip archive
  996. for key in serialized_storages.keys():
  997. name = f"data/{key}"
  998. storage = serialized_storages[key]
  999. num_bytes = storage.nbytes()
  1000. global _serialization_tls
  1001. if _serialization_tls.skip_data:
  1002. zip_file.write_record_metadata(name, num_bytes)
  1003. else:
  1004. # given that we copy things around anyway, we might use storage.cpu()
  1005. # this means to that to get tensors serialized, you need to implement
  1006. # .cpu() on the underlying Storage
  1007. if storage.device.type != "cpu":
  1008. from torch.utils.serialization import config
  1009. if (
  1010. config.save.use_pinned_memory_for_d2h
  1011. and (
  1012. acc := torch.accelerator.current_accelerator(
  1013. check_available=True
  1014. )
  1015. )
  1016. is not None
  1017. and acc.type == storage.device.type
  1018. ):
  1019. new_storage = torch.empty(
  1020. num_bytes, dtype=torch.uint8, device="cpu", pin_memory=True
  1021. ).untyped_storage()
  1022. new_storage.copy_(storage)
  1023. torch.accelerator.current_stream(storage.device.index).synchronize()
  1024. storage = new_storage
  1025. else:
  1026. storage = storage.cpu()
  1027. # Now that it is on the CPU we can directly copy it into the zip file
  1028. zip_file.write_record(name, storage, num_bytes)
  1029. def load(
  1030. f: FileLike,
  1031. map_location: MAP_LOCATION = None,
  1032. pickle_module: Any = None,
  1033. *,
  1034. weights_only: Optional[bool] = None,
  1035. mmap: Optional[bool] = None,
  1036. **pickle_load_args: Any,
  1037. ) -> Any:
  1038. # Reference: https://github.com/pytorch/pytorch/issues/54354
  1039. # The first line of this docstring overrides the one Sphinx generates for the
  1040. # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
  1041. # the build environment (e.g. `<module 'pickle' from '/leaked/path').
  1042. """load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)
  1043. Loads an object saved with :func:`torch.save` from a file.
  1044. :func:`torch.load` uses Python's unpickling facilities but treats storages,
  1045. which underlie tensors, specially. They are first deserialized on the
  1046. CPU and are then moved to the device they were saved from. If this fails
  1047. (e.g. because the run time system doesn't have certain devices), an exception
  1048. is raised. However, storages can be dynamically remapped to an alternative
  1049. set of devices using the :attr:`map_location` argument.
  1050. If :attr:`map_location` is a callable, it will be called once for each serialized
  1051. storage with two arguments: storage and location. The storage argument
  1052. will be the initial deserialization of the storage, residing on the CPU.
  1053. Each serialized storage has a location tag associated with it which
  1054. identifies the device it was saved from, and this tag is the second
  1055. argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
  1056. for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
  1057. :attr:`map_location` should return either ``None`` or a storage. If
  1058. :attr:`map_location` returns a storage, it will be used as the final deserialized
  1059. object, already moved to the right device. Otherwise, :func:`torch.load` will
  1060. fall back to the default behavior, as if :attr:`map_location` wasn't specified.
  1061. If :attr:`map_location` is a :class:`torch.device` object or a string containing
  1062. a device tag, it indicates the location where all tensors should be loaded.
  1063. Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
  1064. appearing in the file (keys), to ones that specify where to put the
  1065. storages (values).
  1066. User extensions can register their own location tags and tagging and
  1067. deserialization methods using :func:`torch.serialization.register_package`.
  1068. See :ref:`layout-control` for more advanced tools to manipulate a checkpoint.
  1069. Args:
  1070. f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
  1071. or a string or os.PathLike object containing a file name
  1072. map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
  1073. locations
  1074. pickle_module: module used for unpickling metadata and objects (has to
  1075. match the :attr:`pickle_module` used to serialize file)
  1076. weights_only: Indicates whether unpickler should be restricted to
  1077. loading only tensors, primitive types, dictionaries
  1078. and any types added via :func:`torch.serialization.add_safe_globals`.
  1079. See :ref:`weights-only` for more details.
  1080. mmap: Indicates whether the file should be mapped rather than loading all the storages into memory.
  1081. Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
  1082. are moved to the location that they were tagged with when saving, or specified by ``map_location``. This
  1083. second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the
  1084. tensor storages from disk to CPU memory in the first step, ``f`` is mapped, which means tensor storages
  1085. will be lazily loaded when their data is accessed.
  1086. pickle_load_args: (Python 3 only) optional keyword arguments passed over to
  1087. :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
  1088. :attr:`errors=...`.
  1089. .. warning::
  1090. :func:`torch.load()` unless `weights_only` parameter is set to `True`,
  1091. uses ``pickle`` module implicitly, which is known to be insecure.
  1092. It is possible to construct malicious pickle data which will execute arbitrary code
  1093. during unpickling. Never load data that could have come from an untrusted
  1094. source in an unsafe mode, or that could have been tampered with. **Only load data you trust**.
  1095. .. note::
  1096. When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
  1097. will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')``
  1098. and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.
  1099. .. note::
  1100. By default, we decode byte strings as ``utf-8``. This is to avoid a common error
  1101. case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
  1102. when loading files saved by Python 2 in Python 3. If this default
  1103. is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
  1104. these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
  1105. to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
  1106. as byte arrays which can be decoded later with ``byte_array.decode(...)``.
  1107. Example:
  1108. >>> # xdoctest: +SKIP("undefined filepaths")
  1109. >>> torch.load("tensors.pt", weights_only=True)
  1110. # Load all tensors onto the CPU
  1111. >>> torch.load(
  1112. ... "tensors.pt",
  1113. ... map_location=torch.device("cpu"),
  1114. ... weights_only=True,
  1115. ... )
  1116. # Load all tensors onto the CPU, using a function
  1117. >>> torch.load(
  1118. ... "tensors.pt",
  1119. ... map_location=lambda storage, loc: storage,
  1120. ... weights_only=True,
  1121. ... )
  1122. # Load all tensors onto GPU 1
  1123. >>> torch.load(
  1124. ... "tensors.pt",
  1125. ... map_location=lambda storage, loc: storage.cuda(1),
  1126. ... weights_only=True,
  1127. ... ) # type: ignore[attr-defined]
  1128. # Map tensors from GPU 1 to GPU 0
  1129. >>> torch.load(
  1130. ... "tensors.pt",
  1131. ... map_location={"cuda:1": "cuda:0"},
  1132. ... weights_only=True,
  1133. ... )
  1134. # Load tensor from io.BytesIO object
  1135. # Loading from a buffer setting weights_only=False, warning this can be unsafe
  1136. >>> with open("tensor.pt", "rb") as f:
  1137. ... buffer = io.BytesIO(f.read())
  1138. >>> torch.load(buffer, weights_only=False)
  1139. # Load a module with 'ascii' encoding for unpickling
  1140. # Loading from a module setting weights_only=False, warning this can be unsafe
  1141. >>> torch.load("module.pt", encoding="ascii", weights_only=False)
  1142. """
  1143. torch._C._log_api_usage_once("torch.load")
  1144. DOCS_MESSAGE = (
  1145. "\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
  1146. "weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
  1147. )
  1148. def _get_wo_message(message: str) -> str:
  1149. unsafe_global_pattern = r"GLOBAL (\S+) was not an allowed global by default."
  1150. has_unsafe_global = re.search(unsafe_global_pattern, message) is not None
  1151. blocklist_pattern = r"whose module (\S+) is blocked"
  1152. has_blocklist = re.search(blocklist_pattern, message) is not None
  1153. import_pattern = r"(\S+) must be (\S+) to load"
  1154. has_import = re.search(import_pattern, message) is not None
  1155. if has_unsafe_global:
  1156. updated_message = (
  1157. "Weights only load failed. This file can still be loaded, to do so you have two options, "
  1158. "\033[1mdo those steps only if you trust the source of the checkpoint\033[0m. "
  1159. f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check "
  1160. "the recommended steps in the following error message.\n\tWeightsUnpickler error: "
  1161. + message
  1162. )
  1163. else:
  1164. if has_import:
  1165. return f"Weights only load failed. {message}\n {UNSAFE_MESSAGE}\n"
  1166. else:
  1167. updated_message = f"Weights only load failed. {UNSAFE_MESSAGE}\n"
  1168. if not has_blocklist:
  1169. updated_message += (
  1170. "Please file an issue with the following so that we can make "
  1171. "`weights_only=True` compatible with your use case: WeightsUnpickler error: "
  1172. )
  1173. updated_message += "\n\n" + message
  1174. return updated_message + DOCS_MESSAGE
  1175. weights_only_not_set = weights_only is None
  1176. if weights_only_not_set:
  1177. weights_only = _default_to_weights_only(pickle_module)
  1178. true_values = ["1", "y", "yes", "true"]
  1179. # Add ability to force safe only or non-safe weight loads via environment variables
  1180. force_weights_only_load = (
  1181. os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0") in true_values
  1182. )
  1183. force_no_weights_only_load = (
  1184. os.getenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "0") in true_values
  1185. )
  1186. if force_weights_only_load and force_no_weights_only_load:
  1187. raise RuntimeError(
  1188. "Only one of `TORCH_FORCE_WEIGHTS_ONLY_LOAD` or `TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD` "
  1189. "should be set, but both were set."
  1190. )
  1191. elif force_weights_only_load:
  1192. weights_only = True
  1193. elif force_no_weights_only_load:
  1194. # TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD can only override if callsite did not explicitly set weights_only
  1195. if weights_only_not_set:
  1196. warnings.warn(
  1197. "Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the"
  1198. "`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.",
  1199. UserWarning,
  1200. stacklevel=2,
  1201. )
  1202. weights_only = False
  1203. if weights_only:
  1204. if pickle_module is not None:
  1205. raise RuntimeError(
  1206. "Can not safely load weights when explicit pickle_module is specified"
  1207. )
  1208. else:
  1209. if pickle_module is None:
  1210. pickle_module = pickle
  1211. # make flipping default BC-compatible
  1212. if mmap is None:
  1213. from torch.utils.serialization import config
  1214. mmap = config.load.mmap
  1215. _check_dill_version(pickle_module)
  1216. if "encoding" not in pickle_load_args.keys():
  1217. pickle_load_args["encoding"] = "utf-8"
  1218. with _open_file_like(f, "rb") as opened_file:
  1219. if _is_zipfile(opened_file):
  1220. # The zipfile reader is going to advance the current file position.
  1221. # If we want to actually tail call to torch.jit.load, we need to
  1222. # reset back to the original position.
  1223. orig_position = opened_file.tell()
  1224. overall_storage = None
  1225. with _open_zipfile_reader(opened_file) as opened_zipfile:
  1226. if _is_torchscript_zip(opened_zipfile):
  1227. warnings.warn(
  1228. "'torch.load' received a zip file that looks like a TorchScript archive"
  1229. " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
  1230. " silence this warning)",
  1231. UserWarning,
  1232. )
  1233. if weights_only:
  1234. raise RuntimeError(
  1235. "Cannot use ``weights_only=True`` with TorchScript archives passed to "
  1236. "``torch.load``. " + UNSAFE_MESSAGE
  1237. )
  1238. opened_file.seek(orig_position)
  1239. return torch.jit.load(opened_file, map_location=map_location)
  1240. if mmap:
  1241. if not _is_path(f):
  1242. raise ValueError(
  1243. "f must be a file path in order to use the mmap argument"
  1244. )
  1245. size = os.path.getsize(f)
  1246. if not IS_WINDOWS:
  1247. shared = get_default_mmap_options() == MAP_SHARED
  1248. else:
  1249. shared = False
  1250. overall_storage = torch.UntypedStorage.from_file(
  1251. os.fspath(f), shared, size
  1252. )
  1253. if weights_only:
  1254. try:
  1255. return _load(
  1256. opened_zipfile,
  1257. map_location,
  1258. _weights_only_unpickler,
  1259. overall_storage=overall_storage,
  1260. **pickle_load_args,
  1261. )
  1262. except pickle.UnpicklingError as e:
  1263. raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
  1264. return _load(
  1265. opened_zipfile,
  1266. map_location,
  1267. pickle_module,
  1268. overall_storage=overall_storage,
  1269. **pickle_load_args,
  1270. )
  1271. if mmap:
  1272. f_name = "" if not isinstance(f, str) else f"{f}, "
  1273. raise RuntimeError(
  1274. "mmap can only be used with files saved with "
  1275. f"`torch.save({f_name}_use_new_zipfile_serialization=True), "
  1276. "please torch.save your checkpoint with this option in order to use mmap."
  1277. )
  1278. if weights_only:
  1279. try:
  1280. return _legacy_load(
  1281. opened_file,
  1282. map_location,
  1283. _weights_only_unpickler,
  1284. **pickle_load_args,
  1285. )
  1286. except pickle.UnpicklingError as e:
  1287. raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
  1288. return _legacy_load(
  1289. opened_file, map_location, pickle_module, **pickle_load_args
  1290. )
  1291. # Register pickling support for layout instances such as
  1292. # torch.sparse_coo, etc
  1293. def _get_layout(name):
  1294. """Get layout extension object from its string representation."""
  1295. cache = _get_layout.cache # type: ignore[attr-defined]
  1296. if not cache:
  1297. for v in torch.__dict__.values():
  1298. if isinstance(v, torch.layout):
  1299. cache[str(v)] = v
  1300. return cache[name]
  1301. # There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087
  1302. _get_layout.cache = {} # type: ignore[attr-defined]
  1303. copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    """Load an object saved in one of the legacy (pre-zipfile) formats.

    Two layouts are handled:

    * a tar archive with ``storages``, ``tensors`` and ``pickle`` members. This
      is only attempted when ``f`` should be read directly and is positioned at
      offset 0;
    * a flat stream of pickled records: magic number, protocol version, system
      info, the object graph, the list of storage keys, and then the raw bytes
      of each of those storages in order.

    Args:
        f: file-like object to read from; it must be seekable.
        map_location: forwarded to ``_get_restore_location`` to build the
            callback that relocates each storage.
        pickle_module: module providing ``load`` and ``Unpickler``. If it is
            ``_weights_only_unpickler``, tar-format files are rejected.
        **pickle_load_args: forwarded to every ``pickle_module.load`` call and
            to the ``Unpickler``.

    Returns:
        The unpickled object.

    Raises:
        RuntimeError: for weights-only loads of tar files, zip archives, a bad
            magic number or protocol version, or an unknown persistent id type.
    """
    deserialized_objects: dict[int, Any] = {}

    restore_location = _get_restore_location(map_location)

    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
        # Replace pickled legacy "*Storage" class references with a StorageType
        # that only carries the dtype; anything else resolves normally.
        def find_class(self, mod_name, name):
            if type(name) is str and "Storage" in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            return super().find_class(mod_name, name)

    def _check_container_source(container_type, source_file, original_source):
        # Warn when the source saved with a container class differs from the
        # source currently retrievable for that class. If
        # ``container_type.dump_patches`` is set, also write a reverse unified
        # diff to "<ClassName>.patch" in the current working directory.
        try:
            current_source = "".join(get_source_lines_and_file(container_type)[0])
        except Exception:  # saving the source is optional, so we can ignore any errors
            warnings.warn(
                "Couldn't retrieve source code for container of "
                "type " + container_type.__name__ + ". It won't be checked "
                "for correctness upon loading."
            )
            return
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + ".patch"
                diff = difflib.unified_diff(
                    current_source.split("\n"),
                    original_source.split("\n"),
                    source_file,
                    source_file,
                    lineterm="",
                )
                lines = "\n".join(diff)
                try:
                    # "a+" creates the file if missing. An existing non-empty
                    # file is only accepted if it already holds this exact
                    # patch; otherwise OSError is raised to report failure.
                    with open(file_name, "a+") as f:
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            raise OSError
                    msg = (
                        "Saved a reverse patch to " + file_name + ". "
                        "Run `patch -p0 < " + file_name + "` to revert your "
                        "changes."
                    )
                except OSError:
                    msg = (
                        "Tried to save a patch, but couldn't create a "
                        "writable file " + file_name + ". Make sure it "
                        "doesn't exist and your working directory is "
                        "writable."
                    )
            else:
                msg = (
                    "you can retrieve the original source code by "
                    "accessing the object's source attribute or set "
                    "`torch.nn.Module.dump_patches = True` and use the "
                    "patch tool to revert the changes."
                )
            msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
            warnings.warn(msg, SourceChangeWarning)

    def legacy_load(f):
        # Load the tar-archive variant of the legacy format.
        # Raises tarfile.TarError if ``f`` is not a tar file.
        deserialized_objects: dict[int, Any] = {}

        def persistent_load(saved_id):
            # Tuples describe containers; other ids are keys into
            # deserialized_objects (storages and tensors built below).
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        with (
            closing(
                tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
            ) as tar,
            mkdtemp() as tmpdir,
        ):
            if pickle_module is _weights_only_unpickler:
                raise RuntimeError(
                    "Cannot use ``weights_only=True`` with files saved in the "
                    "legacy .tar format. " + UNSAFE_MESSAGE
                )
            # "storages" member: a count, then per storage a pickled
            # (key, location, storage_type) header followed by its raw data,
            # then a pickled list of views into those storages.
            # NOTE: this ``with`` rebinds the local ``f`` to the extracted file.
            tar.extract("storages", path=tmpdir)
            with open(os.path.join(tmpdir, "storages"), "rb", 0) as f:
                num_storages = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_storages):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, location, storage_type = args
                    dtype = storage_type._dtype
                    obj = cast(Storage, torch.UntypedStorage)._new_with_file(
                        f, torch._utils._element_size(dtype)
                    )
                    obj = restore_location(obj, location)
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[key] = torch.storage.TypedStorage(
                        wrap_storage=obj, dtype=dtype, _internal=True
                    )

                storage_views = pickle_module.load(f, **pickle_load_args)
                for target_cdata, root_cdata, offset, numel in storage_views:
                    root = deserialized_objects[root_cdata]
                    element_size = torch._utils._element_size(root.dtype)
                    # offset/numel are in elements; slice the untyped root in bytes.
                    offset_bytes = offset * element_size
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[target_cdata] = torch.storage.TypedStorage(
                        wrap_storage=root._untyped_storage[
                            offset_bytes : offset_bytes + numel * element_size
                        ],
                        dtype=root.dtype,
                        _internal=True,
                    )

            # "tensors" member: a count, then per tensor a pickled
            # (key, storage_id, tensor_type) header followed by little-endian
            # packed ndim, sizes, strides and storage offset.
            tar.extract("tensors", path=tmpdir)
            with open(os.path.join(tmpdir, "tensors"), "rb", 0) as f:
                num_tensors = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_tensors):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, storage_id, _original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    (ndim,) = struct.unpack("<i", f.read(4))
                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                    f.read(4)
                    numel = struct.unpack(f"<{ndim}q", f.read(8 * ndim))
                    stride = struct.unpack(f"<{ndim}q", f.read(8 * ndim))
                    (storage_offset,) = struct.unpack("<q", f.read(8))
                    tensor = torch.empty((0,), dtype=storage.dtype).set_(
                        storage._untyped_storage, storage_offset, numel, stride
                    )
                    deserialized_objects[key] = tensor

            # "pickle" member: the object graph, whose persistent ids refer to
            # the storages/tensors rebuilt above.
            pickle_file = tar.extractfile("pickle")
            unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result

    # Rebinds the dict created at the top of this function. It is used by the
    # flat-stream ``persistent_load`` below and the storage-filling loop at the end.
    deserialized_objects = {}

    def persistent_load(saved_id):
        # Resolve a persistent id from the flat-stream format: either a module
        # container (source check only) or a storage, optionally a view into
        # a root storage. Storages are allocated here; their bytes are filled
        # in after unpickling (see the _set_from_file loop below).
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == "module":
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == "storage":
            storage_type, root_key, location, numel, view_metadata = data
            location = _maybe_decode_ascii(location)
            dtype = storage_type.dtype

            nbytes = numel * torch._utils._element_size(dtype)

            if root_key not in deserialized_objects:
                if torch._guards.active_fake_mode() is not None:
                    obj = cast(Storage, torch.UntypedStorage(nbytes, device="meta"))
                elif _serialization_tls.skip_data:
                    obj = cast(Storage, torch.UntypedStorage(nbytes))
                    obj = restore_location(obj, location)
                else:
                    obj = cast(Storage, torch.UntypedStorage(nbytes))
                    # Marks the storage as allocated but not yet read from f.
                    obj._torch_load_uninitialized = True
                    obj = restore_location(obj, location)
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with TypedStorage
                typed_storage = torch.storage.TypedStorage(
                    wrap_storage=obj, dtype=dtype, _internal=True
                )
                deserialized_objects[root_key] = typed_storage
            else:
                typed_storage = deserialized_objects[root_key]
                if typed_storage._data_ptr() == 0:
                    # Null data pointer: hand out a fresh TypedStorage on the
                    # same device rather than sharing the cached object.
                    typed_storage = torch.storage.TypedStorage(
                        device=typed_storage._untyped_storage.device,
                        dtype=dtype,
                        _internal=True,
                    )

            if view_metadata is not None:
                # offset/view_size are in elements; the untyped slice is in bytes.
                view_key, offset, view_size = view_metadata
                offset_bytes = offset * torch._utils._element_size(dtype)
                view_size_bytes = view_size * torch._utils._element_size(dtype)
                if view_key not in deserialized_objects:
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[view_key] = torch.storage.TypedStorage(
                        wrap_storage=typed_storage._untyped_storage[
                            offset_bytes : offset_bytes + view_size_bytes
                        ],
                        dtype=dtype,
                        _internal=True,
                    )
                res = deserialized_objects[view_key]

            else:
                res = typed_storage
            return res
        else:
            raise RuntimeError(f"Unknown saved id type: {saved_id[0]}")

    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # legacy_load requires that f has fileno()
        # only if offset is zero we can attempt the legacy tar file loader
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if _is_zipfile(f):
                # .zip is used for torch.jit.save and will throw an un-pickling error here
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)"
                ) from None
            # if not a tarfile, reset file offset and proceed
            f.seek(0)

    # Flat-stream format: header records, then the object graph.
    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError(f"Invalid protocol version: {protocol_version}")

    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = UnpicklerWrapper(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

    # The raw bytes of each storage follow, in deserialized_storage_keys order.
    # They are skipped entirely under fake mode or skip_data.
    if torch._guards.active_fake_mode() is None and not _serialization_tls.skip_data:
        offset = f.tell() if f_should_read_directly else None
        for key in deserialized_storage_keys:
            assert key in deserialized_objects
            typed_storage = deserialized_objects[key]
            typed_storage._untyped_storage._set_from_file(
                f,
                offset,
                f_should_read_directly,
                torch._utils._element_size(typed_storage.dtype),
            )
            if offset is not None:
                offset = f.tell()

    torch._utils._validate_loaded_sparse_tensors()

    return result
  1537. def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
  1538. # When using encoding='bytes' in Py3, some **internal** keys stored as
  1539. # strings in Py2 are loaded as bytes. This function decodes them with
  1540. # ascii encoding, one that Py3 uses by default.
  1541. #
  1542. # NOTE: This should only be used on internal keys (e.g., `typename` and
  1543. # `location` in `persistent_load` below!
  1544. if isinstance(bytes_str, bytes):
  1545. return bytes_str.decode("ascii")
  1546. return bytes_str
  1547. def _get_restore_location(map_location):
  1548. if map_location is None:
  1549. restore_location = default_restore_location
  1550. elif isinstance(map_location, dict):
  1551. def restore_location(storage, location):
  1552. location = map_location.get(location, location)
  1553. return default_restore_location(storage, location)
  1554. elif isinstance(map_location, (str, bytes)):
  1555. def restore_location(storage, location):
  1556. return default_restore_location(storage, map_location)
  1557. elif isinstance(map_location, torch.device):
  1558. def restore_location(storage, location):
  1559. return default_restore_location(storage, str(map_location))
  1560. else:
  1561. def restore_location(storage, location):
  1562. result = map_location(storage, location)
  1563. if result is None:
  1564. result = default_restore_location(storage, location)
  1565. return result
  1566. return restore_location
  1567. class StorageType:
  1568. def __init__(self, name):
  1569. self._dtype = _get_dtype_from_pickle_storage_type(name)
  1570. @property
  1571. def dtype(self):
  1572. return self._dtype
  1573. def __str__(self):
  1574. return f"StorageType(dtype={self.dtype})"
  1575. def _load(
  1576. zip_file,
  1577. map_location,
  1578. pickle_module,
  1579. pickle_file="data.pkl",
  1580. overall_storage=None,
  1581. **pickle_load_args,
  1582. ):
  1583. restore_location = _get_restore_location(map_location)
  1584. loaded_storages = {}
  1585. can_calculate_storage_offsets = False
  1586. if zip_file.has_record(".format_version"):
  1587. version = zip_file.get_record(".format_version")
  1588. can_calculate_storage_offsets = version >= b"1"
  1589. # check if byteswapping is needed
  1590. byteordername = "byteorder"
  1591. byteorderdata = None
  1592. if zip_file.has_record(byteordername):
  1593. byteorderdata = zip_file.get_record(byteordername)
  1594. if byteorderdata not in [b"little", b"big"]:
  1595. raise ValueError("Unknown endianness type: " + byteorderdata.decode())
  1596. elif (
  1597. get_default_load_endianness() == LoadEndianness.LITTLE
  1598. or get_default_load_endianness() is None
  1599. ):
  1600. byteorderdata = b"little"
  1601. elif get_default_load_endianness() == LoadEndianness.BIG:
  1602. byteorderdata = b"big"
  1603. elif get_default_load_endianness() == LoadEndianness.NATIVE:
  1604. pass
  1605. else:
  1606. raise ValueError("Invalid load endianness type")
  1607. storage_alignment = 64
  1608. if zip_file.has_record(".storage_alignment"):
  1609. storage_alignment = int(zip_file.get_record(".storage_alignment"))
  1610. if (
  1611. not zip_file.has_record(byteordername)
  1612. and get_default_load_endianness() is None
  1613. and sys.byteorder == "big"
  1614. ):
  1615. # Default behaviour was changed
  1616. # See https://github.com/pytorch/pytorch/issues/101688
  1617. warnings.warn(
  1618. "The default load endianness for checkpoints without a byteorder mark "
  1619. "on big endian machines was changed from 'native' to 'little' endian, "
  1620. "to avoid this behavior please use "
  1621. "torch.serialization.set_default_load_endianness to set "
  1622. "the desired default load endianness",
  1623. UserWarning,
  1624. )
  1625. from torch.utils.serialization import config
  1626. calculate_storage_offsets = config.load.calculate_storage_offsets
  1627. run_debug_asserts = os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1"
  1628. current_offset = None
  1629. # constants from miniz.h/miniz.c
  1630. data_descripter_size64 = 24
  1631. data_descripter_size32 = 16
  1632. mz_uint32_max = 0xFFFFFFFF
  1633. offsets: dict[str, int] = dict()
  1634. def _get_offset(key, name, numel):
  1635. """
  1636. Return the offset of the storage associated with key with record name `name` and size numel.
  1637. It is expected that the zipfile header of this storage starts at current_offset.
  1638. WARNING: This function relies on the behavior of the zipwriter in miniz.c. In particular,
  1639. the behavior of `mz_zip_writer_add_mem_ex_v2`. The behavior of this function must be kept
  1640. in sync with that of miniz!
  1641. After reading a storage of size numel that starts at storage_offset
  1642. if it is the first time that storage was read, update nonlocal variable
  1643. current_offset to the start of the next zipfile header by incrementing
  1644. it by numel and the data descriptor size.
  1645. """
  1646. nonlocal current_offset, offsets
  1647. if name in offsets:
  1648. storage_offset = offsets[name]
  1649. return storage_offset
  1650. if current_offset is None:
  1651. assert key == "0"
  1652. current_offset = zip_file.get_record_offset(name)
  1653. local_header_offset = zip_file.get_record_header_offset(name)
  1654. storage_offset = current_offset
  1655. else:
  1656. storage_offset = zip_file.get_record_offset_no_read(
  1657. current_offset, name, numel, storage_alignment
  1658. )
  1659. local_header_offset = current_offset
  1660. # This is only actually needed for storages that have typed_storage._data_ptr() == 0
  1661. # after being read. Otherwise persistent_load would never "re-call" load_tensor
  1662. # for a given key.
  1663. offsets[name] = storage_offset
  1664. # Increment current_offset to offset where next zipfile header starts
  1665. current_offset = storage_offset + numel
  1666. # add size of data descriptor after payload
  1667. if numel > 0:
  1668. if local_header_offset >= mz_uint32_max or numel >= mz_uint32_max:
  1669. current_offset += data_descripter_size64
  1670. else:
  1671. current_offset += data_descripter_size32
  1672. return storage_offset
  1673. def load_tensor(dtype, numel, key, location):
  1674. name = f"data/{key}"
  1675. if torch._guards.detect_fake_mode(None) is not None:
  1676. nbytes = numel * torch._utils._element_size(dtype)
  1677. storage = torch.UntypedStorage(nbytes, device="meta")
  1678. if can_calculate_storage_offsets:
  1679. storage._checkpoint_offset = _get_offset(key, name, numel)
  1680. else:
  1681. storage._checkpoint_offset = zip_file.get_record_offset(name)
  1682. elif _serialization_tls.skip_data:
  1683. nbytes = numel * torch._utils._element_size(dtype)
  1684. storage = torch.UntypedStorage(nbytes)
  1685. elif overall_storage is not None:
  1686. if can_calculate_storage_offsets and calculate_storage_offsets:
  1687. storage_offset = _get_offset(key, name, numel)
  1688. if run_debug_asserts:
  1689. if storage_offset != zip_file.get_record_offset(name):
  1690. raise RuntimeError(
  1691. "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment "
  1692. f"variable was set: Incorrect offset for {name}, got {storage_offset} expected "
  1693. f"{zip_file.get_record_offset(name)}"
  1694. )
  1695. else:
  1696. storage_offset = zip_file.get_record_offset(name)
  1697. storage = overall_storage[storage_offset : storage_offset + numel]
  1698. else:
  1699. if can_calculate_storage_offsets and run_debug_asserts:
  1700. # This is debug code that we use to test the validity of
  1701. # torch.utils.serialization.config.load.calculate_storage_offsets throughout CI
  1702. storage_offset = _get_offset(key, name, numel)
  1703. if storage_offset != zip_file.get_record_offset(name):
  1704. raise RuntimeError(
  1705. "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment "
  1706. f"variable was set: Incorrect offset for {name}, got {storage_offset} expected "
  1707. f"{zip_file.get_record_offset(name)}"
  1708. )
  1709. storage = (
  1710. zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)
  1711. ._typed_storage()
  1712. ._untyped_storage
  1713. )
  1714. # swap here if byteswapping is needed
  1715. if byteorderdata is not None:
  1716. if byteorderdata.decode() != sys.byteorder:
  1717. storage.byteswap(dtype)
  1718. # TODO: Once we decide to break serialization FC, we can
  1719. # stop wrapping with TypedStorage
  1720. if torch._guards.detect_fake_mode(None) is None:
  1721. wrap_storage = restore_location(storage, location)
  1722. else:
  1723. storage._fake_device = location
  1724. wrap_storage = storage
  1725. typed_storage = torch.storage.TypedStorage(
  1726. wrap_storage=wrap_storage,
  1727. dtype=dtype,
  1728. _internal=True,
  1729. )
  1730. if typed_storage._data_ptr() != 0:
  1731. loaded_storages[key] = typed_storage
  1732. return typed_storage
  1733. def persistent_load(saved_id):
  1734. assert isinstance(saved_id, tuple)
  1735. typename = _maybe_decode_ascii(saved_id[0])
  1736. data = saved_id[1:]
  1737. assert typename == "storage", (
  1738. f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
  1739. )
  1740. storage_type, key, location, numel = data
  1741. if storage_type is torch.UntypedStorage:
  1742. dtype = torch.uint8
  1743. else:
  1744. dtype = storage_type.dtype
  1745. if key in loaded_storages:
  1746. typed_storage = loaded_storages[key]
  1747. else:
  1748. nbytes = numel * torch._utils._element_size(dtype)
  1749. typed_storage = load_tensor(
  1750. dtype, nbytes, key, _maybe_decode_ascii(location)
  1751. )
  1752. return typed_storage
  1753. load_module_mapping: dict[str, str] = {
  1754. # See https://github.com/pytorch/pytorch/pull/51633
  1755. "torch.tensor": "torch._tensor"
  1756. }
  1757. # Need to subclass Unpickler instead of directly monkey-patching the find_class method
  1758. # because it's marked readonly in pickle.
  1759. # The type: ignore is because mypy can't statically determine the type of this class.
  1760. class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
  1761. # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
  1762. # Lets us override the imports that pickle uses when unpickling an object.
  1763. # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
  1764. def find_class(self, mod_name, name):
  1765. if type(name) is str and "Storage" in name:
  1766. try:
  1767. return StorageType(name)
  1768. except KeyError:
  1769. pass
  1770. mod_name = load_module_mapping.get(mod_name, mod_name)
  1771. return super().find_class(mod_name, name)
  1772. # Load the data (which may in turn use `persistent_load` to load tensors)
  1773. data_file = io.BytesIO(zip_file.get_record(pickle_file))
  1774. unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
  1775. unpickler.persistent_load = persistent_load
  1776. # Needed for tensors where storage device and rebuild tensor device are
  1777. # not connected (wrapper subclasses and tensors rebuilt using numpy)
  1778. global _serialization_tls
  1779. _serialization_tls.map_location = map_location
  1780. result = unpickler.load()
  1781. _serialization_tls.map_location = None
  1782. torch._utils._validate_loaded_sparse_tensors()
  1783. torch._C._log_api_usage_metadata(
  1784. "torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
  1785. )
  1786. return result
  1787. def _is_torchscript_zip(zip_file):
  1788. return "constants.pkl" in zip_file.get_all_records()