_graph_pickler.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606
  1. import dataclasses
  2. import importlib
  3. import io
  4. import pickle
  5. from abc import abstractmethod
  6. from typing import Any, Callable, NewType, Optional, TypeVar, Union
  7. from typing_extensions import override, Self
  8. import torch
  9. import torch.utils._pytree as pytree
  10. from torch._guards import TracingContext
  11. from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, Tensor
  12. from torch._subclasses.meta_utils import (
  13. MetaConverter,
  14. MetaTensorDesc,
  15. MetaTensorDescriber,
  16. )
  17. from torch.fx.experimental.sym_node import SymNode
  18. from torch.fx.experimental.symbolic_shapes import ShapeEnv
  19. from torch.utils._mode_utils import no_dispatch
  20. _SymNodeT = TypeVar("_SymNodeT", torch.SymInt, torch.SymFloat)
  21. def _ops_filter_safe(name: str) -> bool:
  22. """
  23. An ops filter which allows pickle-safe ops. Pickle-safe ops are built-in
  24. ones where it will be possible to unpickle on any machine which has PyTorch.
  25. """
  26. # TODO: This list is pretty pessimistic right now. What's the full list?
  27. return name.startswith(
  28. (
  29. "torch.ops.aten",
  30. "torch.ops.fbgemm",
  31. )
  32. )
  33. @dataclasses.dataclass
  34. class Options:
  35. # A filter for which ops will cause the pickler to raise a
  36. # BypassFxGraphCache exception. If None then all ops are allowed.
  37. ops_filter: Optional[Callable[[str], bool]] = _ops_filter_safe
  38. class GraphPickler(pickle.Pickler):
  39. """
  40. GraphPickler is a Pickler which helps pickling fx graph - in particular
  41. GraphModule.
  42. """
  43. def __init__(self, file: io.BytesIO, options: Optional[Options] = None) -> None:
  44. super().__init__(file)
  45. self.options = options or Options()
  46. # This abomination is so we can pass external decoding state to the
  47. # unpickler functions. We serialize _unpickle_state as a persistent
  48. # external item and when we deserialize it we return the common state
  49. # object.
  50. self._unpickle_state = _UnpickleStateToken(object())
  51. # This is used to describe tensors. It needs to be common across the
  52. # pickle so that duplicates and views are properly handled.
  53. self._meta_tensor_describer = MetaTensorDescriber(copy_data=False)
  54. @override
  55. def reducer_override(
  56. self, obj: object
  57. ) -> tuple[Callable[..., Any], tuple[Any, ...]]:
  58. # This function is supposed to return either NotImplemented (meaning to
  59. # do the default pickle behavior) or a pair of (unpickle callable, data
  60. # to pass to unpickle).
  61. # We could instead teach individual classes how to pickle themselves but
  62. # that has a few problems:
  63. #
  64. # 1. If we have some special needs (maybe for this use-case we don't
  65. # want to fully serialize every field) then we're adding private
  66. # details to a public interface.
  67. #
  68. # 2. If we need to have some common shared data (such as a
  69. # FakeTensorMode) which is passed to each value it's harder to
  70. # support.
  71. # These are the types that need special handling. See the individual
  72. # *PickleData classes for details on pickling that particular type.
  73. if isinstance(obj, FakeTensor):
  74. return _TensorPickleData.reduce_helper(self, obj)
  75. elif isinstance(obj, torch.fx.GraphModule):
  76. return _GraphModulePickleData.reduce_helper(self, obj)
  77. elif isinstance(obj, (torch._ops.OperatorBase, torch._ops.OpOverloadPacket)):
  78. return _OpPickleData.reduce_helper(self, obj)
  79. elif isinstance(obj, ShapeEnv):
  80. return _ShapeEnvPickleData.reduce_helper(self, obj)
  81. elif isinstance(obj, torch.SymInt):
  82. return _SymNodePickleData.reduce_helper(self, obj)
  83. elif isinstance(obj, torch._guards.TracingContext):
  84. return _TracingContextPickleData.reduce_helper(self, obj)
  85. else:
  86. # We should never get a raw Node!
  87. assert not isinstance(obj, torch.fx.Node)
  88. if reduce := _TorchNumpyPickleData.reduce_helper(self, obj):
  89. return reduce
  90. # returning `NotImplemented` causes pickle to revert to the default
  91. # behavior for this object.
  92. return NotImplemented
  93. @override
  94. def persistent_id(self, obj: object) -> Optional[str]:
  95. if obj is self._unpickle_state:
  96. return "unpickle_state"
  97. else:
  98. return None
  99. @classmethod
  100. def dumps(cls, obj: object, options: Optional[Options] = None) -> bytes:
  101. """
  102. Pickle an object.
  103. """
  104. with io.BytesIO() as stream:
  105. pickler = cls(stream, options)
  106. pickler.dump(obj)
  107. return stream.getvalue()
  108. @staticmethod
  109. def loads(data: bytes, fake_mode: FakeTensorMode) -> object:
  110. """
  111. Unpickle an object.
  112. """
  113. state = _UnpickleState(fake_mode)
  114. with io.BytesIO(data) as stream:
  115. unpickler = _GraphUnpickler(stream, state)
  116. return unpickler.load()
  117. class _UnpickleState:
  118. def __init__(self, fake_mode: FakeTensorMode) -> None:
  119. self.fake_mode = fake_mode
  120. self.meta_converter: MetaConverter[FakeTensor] = MetaConverter()
  121. # This token is passed when pickling to indicate that we want to use the
  122. # unpickler's _UnpickleState as a parameter in that position.
  123. _UnpickleStateToken = NewType("_UnpickleStateToken", object)
  124. class _GraphUnpickler(pickle.Unpickler):
  125. def __init__(self, stream: io.BytesIO, unpickle_state: _UnpickleState) -> None:
  126. super().__init__(stream)
  127. self._unpickle_state = unpickle_state
  128. @override
  129. def persistent_load(self, pid: object) -> object:
  130. if pid == "unpickle_state":
  131. return self._unpickle_state
  132. else:
  133. raise pickle.UnpicklingError("Invalid persistent ID")
  134. class _ShapeEnvPickleData:
  135. data: dict[str, object]
  136. @classmethod
  137. def reduce_helper(
  138. cls, pickler: GraphPickler, obj: ShapeEnv
  139. ) -> tuple[
  140. Callable[[Self, _UnpickleState], ShapeEnv], tuple[Self, _UnpickleStateToken]
  141. ]:
  142. return cls.unpickle, (cls(obj), pickler._unpickle_state)
  143. def __init__(self, env: ShapeEnv) -> None:
  144. # In theory pickle should recognize that a given ShapeEnv was already
  145. # pickled and reuse the resulting _ShapeEnvPickleData (so two objects
  146. # pointing at the same ShapeEnv get the same ShapeEnv out).
  147. assert not env._translation_validation_enabled
  148. self.data = env.__dict__.copy()
  149. del self.data["tracked_fakes"]
  150. del self.data["fake_tensor_cache"]
  151. def unpickle(self, unpickle_state: _UnpickleState) -> ShapeEnv:
  152. # Fill in the existing ShapeEnv rather than creating a new one
  153. assert unpickle_state.fake_mode
  154. assert unpickle_state.fake_mode.shape_env
  155. for k, v in self.data.items():
  156. setattr(unpickle_state.fake_mode.shape_env, k, v)
  157. return unpickle_state.fake_mode.shape_env
  158. class _SymNodePickleData:
  159. @classmethod
  160. def reduce_helper(
  161. cls,
  162. pickler: GraphPickler,
  163. obj: _SymNodeT,
  164. ) -> tuple[
  165. Callable[[Self, _UnpickleState], _SymNodeT], tuple[Self, _UnpickleStateToken]
  166. ]:
  167. args = (cls(obj.node), pickler._unpickle_state)
  168. if isinstance(obj, torch.SymInt):
  169. return _SymNodePickleData.unpickle_sym_int, args
  170. else:
  171. raise NotImplementedError(f"Unhandled SymNode type {type(obj)}")
  172. def __init__(self, node: SymNode) -> None:
  173. self.expr = node._expr
  174. self.shape_env = node.shape_env
  175. self.pytype = node.pytype
  176. self.hint = node._hint
  177. def _to_sym_node(self) -> SymNode:
  178. assert self.shape_env is not None
  179. return SymNode(self.expr, self.shape_env, self.pytype, self.hint)
  180. def unpickle_sym_int(self, unpickle_state: _UnpickleState) -> torch.SymInt:
  181. return torch.SymInt(self._to_sym_node())
  182. class _TensorPickleData:
  183. metadata: MetaTensorDesc[FakeTensor]
  184. @classmethod
  185. def reduce_helper(
  186. cls, pickler: GraphPickler, obj: FakeTensor
  187. ) -> tuple[
  188. Callable[[Self, _UnpickleState], FakeTensor], tuple[Self, _UnpickleStateToken]
  189. ]:
  190. return cls.unpickle, (
  191. cls(pickler._meta_tensor_describer, obj),
  192. pickler._unpickle_state,
  193. )
  194. def __init__(self, describer: MetaTensorDescriber, t: Tensor) -> None:
  195. # THINGS TO WORRY ABOUT:
  196. # 1. Need to make sure that two tensors with the same id end up with the
  197. # same id on the other side of the wire.
  198. metadata = describer.describe_tensor(t)
  199. # view_func is fine if it's either None or a _FakeTensorViewFunc. A
  200. # custom one (which is basically a lambda) can't be serialized.
  201. assert not metadata.view_func or isinstance(
  202. metadata.view_func, torch._subclasses.meta_utils._FakeTensorViewFunc
  203. )
  204. self.metadata = dataclasses.replace(metadata, fake_mode=None)
  205. # Some debugging/verification
  206. for k in MetaTensorDesc._UNSERIALIZABLE:
  207. if k in ("fake_mode", "view_func"):
  208. continue
  209. assert getattr(self.metadata, k) is None, (
  210. f"not None: {k}: {getattr(self.metadata, k)}"
  211. )
  212. def unpickle(self, unpickle_state: _UnpickleState) -> FakeTensor:
  213. # TODO: make common w/ _output_from_cache_entry() in fake_tensor.py?
  214. metadata = dataclasses.replace(
  215. self.metadata,
  216. fake_mode=unpickle_state.fake_mode,
  217. )
  218. def with_fake(
  219. make_meta_t: Callable[[], torch.Tensor], device: Union[torch.device, str]
  220. ) -> FakeTensor:
  221. with no_dispatch():
  222. return FakeTensor(
  223. unpickle_state.fake_mode,
  224. make_meta_t(),
  225. device,
  226. )
  227. return unpickle_state.meta_converter.meta_tensor(
  228. metadata,
  229. unpickle_state.fake_mode.shape_env,
  230. with_fake,
  231. None,
  232. None,
  233. )
  234. class _TorchNumpyPickleData:
  235. @classmethod
  236. def reduce_helper(
  237. cls, pickler: GraphPickler, obj: object
  238. ) -> Optional[
  239. tuple[
  240. Callable[[Self, _UnpickleState], object], tuple[Self, _UnpickleStateToken]
  241. ]
  242. ]:
  243. if data := cls.from_object(obj):
  244. return (cls.unpickle, (data, pickler._unpickle_state))
  245. else:
  246. return None
  247. def __init__(self, mod: str, name: str) -> None:
  248. self.mod = mod
  249. self.name = name
  250. def unpickle(self, unpickle_state: _UnpickleState) -> Callable[..., object]:
  251. np = getattr(importlib.import_module(self.mod), self.name)
  252. return torch._dynamo.variables.misc.get_np_to_tnp_map()[np]
  253. @classmethod
  254. def from_object(cls, tnp: object) -> Optional[Self]:
  255. if not callable(tnp):
  256. return None
  257. tnp_to_np = torch._dynamo.variables.misc.get_tnp_to_np_map()
  258. try:
  259. if not (np := tnp_to_np.get(tnp)):
  260. return None
  261. except TypeError:
  262. return None
  263. if not (mod := getattr(np, "__module__", None)):
  264. mod = "numpy"
  265. if not (name := getattr(np, "__name__", None)):
  266. return None
  267. assert np == getattr(importlib.import_module(mod), name)
  268. return cls(mod, name)
  269. class _GraphModulePickleData:
  270. @classmethod
  271. def reduce_helper(
  272. cls, pickler: GraphPickler, obj: torch.fx.GraphModule
  273. ) -> tuple[
  274. Callable[[Self, _UnpickleState], torch.fx.GraphModule],
  275. tuple[Self, _UnpickleStateToken],
  276. ]:
  277. return cls.unpickle, (
  278. cls(obj, pickler.options),
  279. pickler._unpickle_state,
  280. )
  281. def __init__(self, gm: torch.fx.GraphModule, options: Options) -> None:
  282. # Need to do this to ensure the code is created for later pickling.
  283. if isinstance(gm, torch.fx._lazy_graph_module._LazyGraphModule):
  284. _python_code = gm._real_recompile()
  285. else:
  286. _python_code = gm.recompile()
  287. self.gm_dict = gm.__dict__.copy()
  288. del self.gm_dict["_graph"]
  289. self.graph = _GraphPickleData(gm._graph, options)
  290. def unpickle(self, unpickle_state: _UnpickleState) -> torch.fx.GraphModule:
  291. gm = torch.fx.GraphModule.__new__(torch.fx.GraphModule)
  292. gm.__dict__ = self.gm_dict
  293. gm._graph = self.graph.unpickle(gm, unpickle_state)
  294. return gm
  295. class _NodePickleData:
  296. def __init__(
  297. self,
  298. node: torch.fx.Node,
  299. mapping: dict[torch.fx.Node, "_NodePickleData"],
  300. options: Options,
  301. ) -> None:
  302. self.args = pytree.tree_map_only(torch.fx.Node, lambda n: mapping[n], node.args)
  303. self.kwargs = pytree.tree_map_only(
  304. torch.fx.Node, lambda n: mapping[n], node.kwargs
  305. )
  306. # -- self.graph = node.graph
  307. self.name = node.name
  308. self.op = node.op
  309. self.target = _OpPickleData.pickle(node.target, options)
  310. # self.input_nodes = node._input_nodes
  311. # self.users = node.users
  312. self.type = node.type
  313. # self.sort_key = node._sort_key
  314. # self.repr_fn = node._repr_fn
  315. # self.meta = node.meta
  316. self.meta = node.meta
  317. def unpickle(
  318. self,
  319. graph: torch.fx.Graph,
  320. mapping: dict["_NodePickleData", torch.fx.Node],
  321. unpickle_state: _UnpickleState,
  322. ) -> torch.fx.Node:
  323. args = pytree.tree_map_only(_NodePickleData, lambda n: mapping[n], self.args)
  324. kwargs = pytree.tree_map_only(
  325. _NodePickleData, lambda n: mapping[n], self.kwargs
  326. )
  327. target = self.target.unpickle(unpickle_state)
  328. assert callable(target) or isinstance(target, str)
  329. node = graph.create_node(self.op, target, args, kwargs, self.name, self.type)
  330. node.meta = self.meta
  331. return node
  332. class _OpPickleData:
  333. @classmethod
  334. def reduce_helper(
  335. cls, pickler: GraphPickler, op: object
  336. ) -> tuple[Callable[[_UnpickleState], object], tuple[_UnpickleStateToken]]:
  337. result = cls.pickle(op, pickler.options)
  338. return (result.unpickle, (pickler._unpickle_state,))
  339. @classmethod
  340. def pickle(cls, op: object, options: Options) -> "_OpPickleData":
  341. if isinstance(op, str):
  342. return _OpStrPickleData(op)
  343. name = torch.fx.Node._pretty_print_target(op)
  344. if isinstance(op, torch._ops.OpOverload):
  345. return cls._pickle_op(name, _OpOverloadPickleData, options)
  346. elif isinstance(op, torch._ops.OpOverloadPacket):
  347. return cls._pickle_op(name, _OpOverloadPacketPickleData, options)
  348. elif name.startswith(("builtins.", "math.", "torch.")):
  349. root, detail = name.split(".", 1)
  350. return _OpBuiltinPickleData(root, detail)
  351. elif name.startswith("operator."):
  352. _, detail = name.split(".", 1)
  353. return _OpOperatorPickleData(detail)
  354. else:
  355. # TODO: raise a BypassFxGraphCache so we will just bypass this one...
  356. raise NotImplementedError(f"TARGET: {type(op)} {op} {name}")
  357. @staticmethod
  358. def _pickle_op(
  359. name: str,
  360. datacls: Union[
  361. type["_OpOverloadPickleData"], type["_OpOverloadPacketPickleData"]
  362. ],
  363. options: Options,
  364. ) -> "_OpPickleData":
  365. if (ops_filter := options.ops_filter) and not ops_filter(name):
  366. from torch._inductor.codecache import BypassFxGraphCache
  367. raise BypassFxGraphCache(f"Unable to pickle non-standard op: {name}")
  368. return datacls(name)
  369. @abstractmethod
  370. def unpickle(self, unpickle_state: _UnpickleState) -> object:
  371. pass
  372. @classmethod
  373. def _lookup_global_by_name(cls, name: str) -> object:
  374. """
  375. Like `globals()[name]` but supports dotted names.
  376. """
  377. if "." in name:
  378. mod, rest = name.split(".", 1)
  379. root = globals()[mod]
  380. return cls._getattr_by_name(root, rest)
  381. else:
  382. return globals()[name]
  383. @staticmethod
  384. def _getattr_by_name(root: object, name: str) -> object:
  385. """
  386. Like `getattr(root, name)` but supports dotted names.
  387. """
  388. while "." in name:
  389. mod, name = name.split(".", 1)
  390. root = getattr(root, mod)
  391. return getattr(root, name)
  392. class _OpStrPickleData(_OpPickleData):
  393. def __init__(self, name: str) -> None:
  394. self.name = name
  395. def unpickle(self, unpickle_state: _UnpickleState) -> str:
  396. return self.name
  397. class _OpOverloadPickleData(_OpPickleData):
  398. def __init__(self, name: str) -> None:
  399. self.name = name
  400. def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverload:
  401. obj = self._lookup_global_by_name(self.name)
  402. assert isinstance(obj, torch._ops.OpOverload)
  403. return obj
  404. class _OpOverloadPacketPickleData(_OpPickleData):
  405. def __init__(self, name: str) -> None:
  406. self.name = name
  407. def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverloadPacket:
  408. obj = self._lookup_global_by_name(self.name)
  409. assert isinstance(obj, torch._ops.OpOverloadPacket)
  410. return obj
  411. class _OpBuiltinPickleData(_OpPickleData):
  412. def __init__(self, root: str, name: str) -> None:
  413. self.root = root
  414. self.name = name
  415. def unpickle(self, unpickle_state: _UnpickleState) -> object:
  416. if self.root == "builtins":
  417. return __builtins__.get(self.name) # type: ignore[attr-defined]
  418. elif self.root == "math":
  419. import math
  420. return self._getattr_by_name(math, self.name)
  421. elif self.root == "torch":
  422. return self._getattr_by_name(torch, self.name)
  423. else:
  424. raise NotImplementedError
  425. class _OpOperatorPickleData(_OpPickleData):
  426. def __init__(self, name: str) -> None:
  427. self.name = name
  428. def unpickle(self, unpickle_state: _UnpickleState) -> object:
  429. import operator
  430. return self._getattr_by_name(operator, self.name)
  431. class _GraphPickleData:
  432. def __init__(self, graph: torch.fx.Graph, options: Options) -> None:
  433. self.tracer_cls = graph._tracer_cls
  434. self.tracer_extras = graph._tracer_extras
  435. nodes: dict[torch.fx.Node, _NodePickleData] = {}
  436. for node in graph.nodes:
  437. nodes[node] = _NodePickleData(node, nodes, options)
  438. self.nodes = tuple(nodes.values())
  439. # Unpickled variables:
  440. # self._used_names = graph._used_names
  441. # -- self._insert = self._root.prepend
  442. # self._len = graph._len
  443. # self._graph_namespace = graph._graph_namespace
  444. # self._owning_module = graph._owning_module
  445. # self._codegen = graph._codegen
  446. # self._co_fields: Dict[str, Any] = graph._co_fields
  447. # -- self._find_nodes_lookup_table = _FindNodesLookupTable()
  448. def unpickle(
  449. self, gm: torch.fx.GraphModule, unpickle_state: _UnpickleState
  450. ) -> torch.fx.Graph:
  451. graph = torch.fx.Graph(gm, self.tracer_cls, self.tracer_extras)
  452. nodes: dict[_NodePickleData, torch.fx.Node] = {}
  453. for nd in self.nodes:
  454. nodes[nd] = nd.unpickle(graph, nodes, unpickle_state)
  455. return graph
  456. class _TracingContextPickleData:
  457. @classmethod
  458. def reduce_helper(
  459. cls, pickler: GraphPickler, obj: torch._guards.TracingContext
  460. ) -> tuple[
  461. Callable[[Self, _UnpickleState], torch._guards.TracingContext],
  462. tuple[Self, _UnpickleStateToken],
  463. ]:
  464. return (
  465. cls.unpickle,
  466. (
  467. cls(obj),
  468. pickler._unpickle_state,
  469. ),
  470. )
  471. def __init__(self, context: TracingContext) -> None:
  472. # TODO: Do we really need all of this?
  473. self.module_context = context.module_context
  474. self.frame_summary_stack = context.frame_summary_stack
  475. self.loc_in_frame = context.loc_in_frame
  476. self.aot_graph_name = context.aot_graph_name
  477. self.params_flat = context.params_flat
  478. self.params_flat_unwrap_subclasses = context.params_flat_unwrap_subclasses
  479. self.params_unwrapped_to_flat_index = context.params_unwrapped_to_flat_index
  480. self.output_strides = context.output_strides
  481. self.force_unspec_int_unbacked_size_like = (
  482. context.force_unspec_int_unbacked_size_like
  483. )
  484. # Not saved (because it's difficult and maybe not needed?):
  485. # self.fw_metadata = context.fw_metadata
  486. # self.guards_context = None
  487. # self.global_context = None
  488. # self.fake_mode = None
  489. # self.fakify_first_call = None
  490. # self.hop_dispatch_set_cache = None
  491. # self.tensor_to_context = context.tensor_to_context
  492. def unpickle(self, unpickle_state: _UnpickleState) -> TracingContext:
  493. context = TracingContext(unpickle_state.fake_mode)
  494. context.module_context = self.module_context
  495. context.frame_summary_stack = self.frame_summary_stack
  496. context.loc_in_frame = self.loc_in_frame
  497. context.aot_graph_name = self.aot_graph_name
  498. context.params_flat = self.params_flat
  499. context.params_flat_unwrap_subclasses = self.params_flat_unwrap_subclasses
  500. context.params_unwrapped_to_flat_index = self.params_unwrapped_to_flat_index
  501. context.output_strides = self.output_strides
  502. context.force_unspec_int_unbacked_size_like = (
  503. self.force_unspec_int_unbacked_size_like
  504. )
  505. return context