_serialization.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. # mypy: allow-untyped-defs
  2. """Serialization.
  3. This module contains functionality for serializing TorchScript modules, notably:
  4. * torch.jit.save
  5. * torch.jit.load
  6. This is not intended to be imported directly; please use the exposed
  7. functionalities in `torch.jit`.
  8. """
  9. import os
  10. import torch
  11. from torch._jit_internal import _get_model_id
  12. from torch._utils_internal import log_torchscript_usage
  13. from torch.jit._recursive import wrap_cpp_module
  14. from torch.serialization import validate_cuda_device
  15. def save(m, f, _extra_files=None):
  16. r"""
  17. Save an offline version of this module for use in a separate process.
  18. The saved module serializes all of the methods, submodules, parameters, and
  19. attributes of this module. It can be loaded into the C++ API using
  20. ``torch::jit::load(filename)`` or into the Python API with
  21. :func:`torch.jit.load <torch.jit.load>`.
  22. To be able to save a module, it must not make any calls to native Python
  23. functions. This means that all submodules must be subclasses of
  24. :class:`ScriptModule` as well.
  25. .. DANGER::
  26. All modules, no matter their device, are always loaded onto the CPU
  27. during loading. This is different from :func:`torch.load`'s semantics
  28. and may change in the future.
  29. Args:
  30. m: A :class:`ScriptModule` to save.
  31. f: A file-like object (has to implement write and flush) or a string
  32. containing a file name.
  33. _extra_files: Map from filename to contents which will be stored as part of `f`.
  34. .. note::
  35. torch.jit.save attempts to preserve the behavior of some operators
  36. across versions. For example, dividing two integer tensors in
  37. PyTorch 1.5 performed floor division, and if the module
  38. containing that code is saved in PyTorch 1.5 and loaded in PyTorch 1.6
  39. its division behavior will be preserved. The same module saved in
  40. PyTorch 1.6 will fail to load in PyTorch 1.5, however, since the
  41. behavior of division changed in 1.6, and 1.5 does not know how to
  42. replicate the 1.6 behavior.
  43. Example:
  44. .. testcode::
  45. import torch
  46. import io
  47. class MyModule(torch.nn.Module):
  48. def forward(self, x):
  49. return x + 10
  50. m = torch.jit.script(MyModule())
  51. # Save to file
  52. torch.jit.save(m, 'scriptmodule.pt')
  53. # This line is equivalent to the previous
  54. m.save("scriptmodule.pt")
  55. # Save to io.BytesIO buffer
  56. buffer = io.BytesIO()
  57. torch.jit.save(m, buffer)
  58. # Save with extra files
  59. extra_files = {'foo.txt': b'bar'}
  60. torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
  61. """
  62. log_torchscript_usage("save", model_id=_get_model_id(m))
  63. if _extra_files is None:
  64. _extra_files = {}
  65. if isinstance(f, (str, os.PathLike)):
  66. m.save(f, _extra_files=_extra_files)
  67. else:
  68. ret = m.save_to_buffer(_extra_files=_extra_files)
  69. f.write(ret)
  70. def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
  71. r"""
  72. Load a :class:`ScriptModule` or :class:`ScriptFunction` previously saved with :func:`torch.jit.save <torch.jit.save>`.
  73. All previously saved modules, no matter their device, are first loaded onto CPU,
  74. and then are moved to the devices they were saved from. If this fails (e.g.
  75. because the run time system doesn't have certain devices), an exception is
  76. raised.
  77. Args:
  78. f: a file-like object (has to implement read, readline, tell, and seek),
  79. or a string containing a file name
  80. map_location (string or torch.device): A simplified version of
  81. ``map_location`` in `torch.jit.save` used to dynamically remap
  82. storages to an alternative set of devices.
  83. _extra_files (dictionary of filename to content): The extra
  84. filenames given in the map would be loaded and their content
  85. would be stored in the provided map.
  86. _restore_shapes (bool): Whether or not to retrace the module on load using stored inputs
  87. Returns:
  88. A :class:`ScriptModule` object.
  89. .. warning::
  90. It is possible to construct malicious pickle data which will execute arbitrary code
  91. during func:`torch.jit.load`. Never load data that could have come from an untrusted
  92. source, or that could have been tampered with. **Only load data you trust**.
  93. Example:
  94. .. testcode::
  95. import torch
  96. import io
  97. torch.jit.load('scriptmodule.pt')
  98. # Load ScriptModule from io.BytesIO object
  99. with open('scriptmodule.pt', 'rb') as f:
  100. buffer = io.BytesIO(f.read())
  101. # Load all tensors to the original device
  102. torch.jit.load(buffer)
  103. # Load all tensors onto CPU, using a device
  104. buffer.seek(0)
  105. torch.jit.load(buffer, map_location=torch.device('cpu'))
  106. # Load all tensors onto CPU, using a string
  107. buffer.seek(0)
  108. torch.jit.load(buffer, map_location='cpu')
  109. # Load with extra files.
  110. extra_files = {'foo.txt': ''} # values will be replaced with data
  111. torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
  112. print(extra_files['foo.txt'])
  113. .. testoutput::
  114. :hide:
  115. ...
  116. .. testcleanup::
  117. import os
  118. os.remove("scriptmodule.pt")
  119. """
  120. if isinstance(f, (str, os.PathLike)):
  121. if not os.path.exists(f):
  122. raise ValueError(f"The provided filename {f} does not exist")
  123. if os.path.isdir(f):
  124. raise ValueError(f"The provided filename {f} is a directory")
  125. map_location = validate_map_location(map_location)
  126. if _extra_files is None:
  127. _extra_files = {}
  128. cu = torch._C.CompilationUnit()
  129. if isinstance(f, (str, os.PathLike)):
  130. cpp_module = torch._C.import_ir_module(
  131. cu, os.fspath(f), map_location, _extra_files, _restore_shapes
  132. ) # type: ignore[call-arg]
  133. else:
  134. cpp_module = torch._C.import_ir_module_from_buffer(
  135. cu, f.read(), map_location, _extra_files, _restore_shapes
  136. ) # type: ignore[call-arg]
  137. # TODO: Pretty sure this approach loses ConstSequential status and such
  138. ret = wrap_cpp_module(cpp_module)
  139. log_torchscript_usage("load", model_id=_get_model_id(ret))
  140. return ret
  141. def validate_map_location(map_location=None):
  142. if isinstance(map_location, str):
  143. map_location = torch.device(map_location)
  144. elif not (map_location is None or isinstance(map_location, torch.device)):
  145. raise ValueError(
  146. "map_location should be either None, string or torch.device, "
  147. "but got type: " + str(type(map_location))
  148. )
  149. if str(map_location).startswith("cuda"):
  150. validate_cuda_device(map_location)
  151. return map_location
  152. def jit_module_from_flatbuffer(f):
  153. if isinstance(f, (str, os.PathLike)):
  154. f = os.fspath(f)
  155. return wrap_cpp_module(torch._C._load_jit_module_from_file(f))
  156. else:
  157. return wrap_cpp_module(torch._C._load_jit_module_from_bytes(f.read()))
  158. def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
  159. r"""
  160. Save an offline version of this module for use in a separate process.
  161. The saved module serializes all of the methods, submodules, parameters, and
  162. attributes of this module. It can be loaded into the C++ API using
  163. ``torch::jit::load_jit_module_from_file(filename)`` or into the Python API with
  164. :func:`torch.jit.jit_module_from_flatbuffer<torch.jit.jit_module_from_flatbuffer>`.
  165. To be able to save a module, it must not make any calls to native Python
  166. functions. This means that all submodules must be subclasses of
  167. :class:`ScriptModule` as well.
  168. .. DANGER::
  169. All modules, no matter their device, are always loaded onto the CPU
  170. during loading. This is different from :func:`torch.load`'s semantics
  171. and may change in the future.
  172. Args:
  173. m: A :class:`ScriptModule` to save.
  174. f: A string for file path
  175. Example:
  176. .. testcode::
  177. import torch
  178. import io
  179. class MyModule(torch.nn.Module):
  180. def forward(self, x):
  181. return x + 10
  182. m = torch.jit.script(MyModule())
  183. # Save to file
  184. torch.jit.save_jit_module_to_flatbuffer(m, 'scriptmodule.ff')
  185. """
  186. extra_files = _extra_files
  187. if extra_files is None:
  188. extra_files = {}
  189. if isinstance(f, (str, os.PathLike)):
  190. f = os.fspath(f)
  191. torch._C._save_jit_module(m._c, f, extra_files)
  192. else:
  193. s = torch._C._save_jit_module_to_bytes(m._c, extra_files)
  194. f.write(s)
  195. def get_flatbuffer_module_info(path_or_file):
  196. r"""Get some information regarding a model file in flatbuffer format.
  197. Args:
  198. path_or_file: Either str, Path or file like object (BytesIO OK).
  199. If it's str or Path, we will read the file referenced by that
  200. path as Bytes.
  201. Returns:
  202. A dict with metadata on what that file contains, currently looks like
  203. this:
  204. {
  205. 'bytecode_version': 4, # int
  206. 'operator_version': 4, # int
  207. 'function_names': {
  208. '__torch__.___torch_mangle_0.Foo.forward'}, # set
  209. 'type_names': set(), # set
  210. 'opname_to_num_args': {'aten::linear': 3} # Dict[str, int]
  211. }
  212. """
  213. if isinstance(path_or_file, (str, os.PathLike)):
  214. with open(path_or_file, "rb") as f:
  215. all_bytes = f.read()
  216. else:
  217. all_bytes = path_or_file.read()
  218. return torch._C._get_module_info_from_flatbuffer(all_bytes)