_serialization.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import pickle
  2. from dataclasses import dataclass
  3. from io import BufferedIOBase
  4. from typing import Any
  5. import torch
  6. import torch._weights_only_unpickler as _weights_only_unpickler
  7. from torch.serialization import _load, _save, DEFAULT_PROTOCOL, MAP_LOCATION
  8. __all__: list[str] = []
@dataclass
class _Entry:
    """Header describing one record in the streamed payload.

    ``_PseudoZipFile.write_to`` pickles a list of these ahead of the raw
    record payloads, and ``_PseudoZipFile.read_from`` walks the same list to
    split the rest of the stream back into records.
    """

    # Record name; used as the key of ``_PseudoZipFile.records``.
    key: str
    # True when the payload was a ``torch.UntypedStorage`` at write time, so
    # ``read_from`` rebuilds a storage instead of keeping the raw bytes.
    is_storage: bool
    # Number of bytes ``read_from`` requests from the stream for this record.
    length: int


# Register ``_Entry`` with the weights-only unpickler's safe globals;
# ``_PseudoZipFile.read_from`` loads the header list with that unpickler.
_weights_only_unpickler._add_safe_globals([_Entry])
  15. class _PseudoZipFile:
  16. def __init__(self) -> None:
  17. self.records: dict[str, tuple[object, int]] = {}
  18. def write_record(self, key: str, data: object, length: int) -> None:
  19. self.records[key] = (data, length)
  20. def write_to(self, f: BufferedIOBase) -> None:
  21. entries = []
  22. for key, (data, length) in self.records.items():
  23. entries.append(
  24. _Entry(
  25. key=key,
  26. is_storage=isinstance(data, torch.UntypedStorage),
  27. length=length,
  28. )
  29. )
  30. pickle.dump(entries, f, protocol=DEFAULT_PROTOCOL)
  31. for key, (data, length) in self.records.items():
  32. if isinstance(data, bytes):
  33. f.write(data)
  34. elif isinstance(data, str):
  35. f.write(data.encode("utf-8"))
  36. elif isinstance(data, torch.UntypedStorage):
  37. data._write_file(f, False, False, 1)
  38. else:
  39. raise TypeError(f"unknown type: {type(data)}")
  40. def read_from(self, f: BufferedIOBase) -> None:
  41. entries = _weights_only_unpickler.load(f)
  42. for entry in entries:
  43. data = f.read(entry.length)
  44. if entry.is_storage:
  45. storage = torch.frombuffer(
  46. data,
  47. dtype=torch.uint8,
  48. ).untyped_storage()
  49. self.records[entry.key] = (
  50. storage,
  51. entry.length,
  52. )
  53. else:
  54. self.records[entry.key] = (data, entry.length)
  55. def has_record(self, key: str) -> bool:
  56. return key in self.records
  57. def get_record(self, key: str) -> object:
  58. return self.records[key][0]
  59. def get_storage_from_record(
  60. self, key: str, _length: int, _type: int
  61. ) -> torch.Tensor:
  62. return torch.tensor(self.records[key][0], dtype=torch.uint8)
  63. def serialization_id(self) -> str:
  64. return "torchft"
  65. def _streaming_save(
  66. obj: object,
  67. f: BufferedIOBase,
  68. pickle_module: Any = pickle,
  69. pickle_protocol: int = DEFAULT_PROTOCOL,
  70. ) -> None:
  71. """
  72. Save the object to a file-like object in a streaming fashion compatible with
  73. network sockets.
  74. This behaves similarly to :func:`torch.save` with a few notable differences:
  75. * A non-seekable file like object can be used when loading.
  76. * No forwards/backwards compatibility is provided for the serialization
  77. format. This is only intended to be used with a single version of PyTorch
  78. with transient storage (i.e. sockets or temp files).
  79. * mmap is not supported
  80. See :func:`torch.save` for more details on specific arguments.
  81. """
  82. zip_file = _PseudoZipFile()
  83. _save(
  84. obj,
  85. zip_file=zip_file,
  86. pickle_module=pickle_module,
  87. pickle_protocol=pickle_protocol,
  88. _disable_byteorder_record=False,
  89. )
  90. zip_file.write_to(f)
  91. def _streaming_load(
  92. f: BufferedIOBase,
  93. map_location: MAP_LOCATION = None,
  94. pickle_module: Any = None,
  95. *,
  96. weights_only: bool = True,
  97. **pickle_load_args: Any,
  98. ) -> object:
  99. """
  100. Load the object from a file-like object in a streaming fashion compatible with
  101. network sockets.
  102. See :func:`_streaming_save` for more details about the streaming behavior.
  103. See :func:`torch.load` for more details on specific arguments.
  104. """
  105. if weights_only:
  106. if pickle_module is not None:
  107. raise RuntimeError(
  108. "Can not safely load weights when explicit pickle_module is specified"
  109. )
  110. pickle_module = _weights_only_unpickler
  111. else:
  112. if pickle_module is None:
  113. pickle_module = pickle
  114. if "encoding" not in pickle_load_args.keys():
  115. pickle_load_args["encoding"] = "utf-8"
  116. zip_file = _PseudoZipFile()
  117. zip_file.read_from(f)
  118. return _load(
  119. zip_file=zip_file,
  120. map_location=map_location,
  121. pickle_module=pickle_module,
  122. **pickle_load_args,
  123. )