pytree.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. """
  2. Python polyfills for torch.utils.pytree
  3. """
  4. from __future__ import annotations
  5. from collections import deque
  6. from dataclasses import dataclass, field
  7. from typing import Any, Callable, Literal, TYPE_CHECKING
  8. from typing_extensions import TypeIs
  9. import torch.utils._pytree as python_pytree
  10. from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
  11. from ..decorators import substitute_in_graph
  12. if TYPE_CHECKING:
  13. import builtins
  14. from collections.abc import Iterable
  15. from typing_extensions import Self
# Names exported by this module. The polyfills registered below append their
# names to this list as they are defined.
__all__: list[str] = []
  17. if python_pytree._cxx_pytree_dynamo_traceable:
  18. import optree
  19. import optree._C
  20. import torch.utils._cxx_pytree as cxx_pytree
  21. if TYPE_CHECKING:
  22. from torch.utils._cxx_pytree import PyTree
  23. @substitute_in_graph(
  24. optree._C.is_dict_insertion_ordered,
  25. can_constant_fold_through=True,
  26. )
  27. def _(*args: Any, **kwargs: Any) -> bool:
  28. # In namespace 'torch', the dictionary is always traversed in insertion order.
  29. # This function returns True.
  30. raise ValueError(
  31. "Should not be called directly "
  32. "because the original function will be called in the constant fold path."
  33. )
  34. __name = ""
  35. for __name in (
  36. "is_namedtuple",
  37. "is_namedtuple_class",
  38. "is_namedtuple_instance",
  39. "is_structseq",
  40. "is_structseq_class",
  41. "is_structseq_instance",
  42. "namedtuple_fields",
  43. "structseq_fields",
  44. ):
  45. __func = getattr(optree, __name)
  46. globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)(
  47. __func.__python_implementation__
  48. )
  49. __all__ += [__name] # noqa: PLE0604
  50. del __func
  51. del __name
  52. @substitute_in_graph(cxx_pytree.tree_is_leaf, can_constant_fold_through=True)
  53. def tree_is_leaf(
  54. tree: PyTree,
  55. is_leaf: Callable[[PyTree], bool] | None = None,
  56. ) -> bool:
  57. if tree is None or (is_leaf is not None and is_leaf(tree)):
  58. return True
  59. if optree.register_pytree_node.get(type(tree), namespace="torch") is None: # type: ignore[attr-defined]
  60. return True
  61. return False
  62. @substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False)
  63. def tree_iter(
  64. tree: PyTree,
  65. is_leaf: Callable[[PyTree], bool] | None = None,
  66. ) -> Iterable[Any]:
  67. stack = [tree]
  68. while stack:
  69. node = stack.pop()
  70. if tree_is_leaf(node, is_leaf=is_leaf):
  71. yield node
  72. continue
  73. children, *_ = optree.tree_flatten_one_level(
  74. node,
  75. is_leaf=is_leaf,
  76. none_is_leaf=True,
  77. namespace="torch",
  78. )
  79. stack.extend(reversed(children))
  80. __all__ += ["tree_iter"]
  81. @substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True)
  82. def tree_leaves(
  83. tree: PyTree,
  84. is_leaf: Callable[[PyTree], bool] | None = None,
  85. ) -> list[Any]:
  86. return list(tree_iter(tree, is_leaf=is_leaf))
  87. __all__ += ["tree_leaves"]
  88. class _Asterisk(str):
  89. __slots__ = ()
  90. def __new__(cls) -> Self:
  91. return super().__new__(cls, "*")
  92. def __repr__(self) -> str:
  93. return "*" # no quotes
  94. _asterisk = _Asterisk()
  95. del _Asterisk
  96. @dataclass(frozen=True)
  97. class PyTreeSpec:
  98. """Analog for :class:`optree.PyTreeSpec` in Python."""
  99. _children: tuple[PyTreeSpec, ...]
  100. _type: builtins.type | None
  101. _metadata: Any
  102. _entries: tuple[Any, ...]
  103. _unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None
  104. num_nodes: int = field(init=False)
  105. num_leaves: int = field(init=False)
  106. num_children: int = field(init=False)
  107. none_is_leaf: Literal[True] = field(init=False)
  108. namespace: Literal["torch"] = field(init=False)
  109. def __post_init__(self) -> None:
  110. if self._type is None:
  111. assert len(self._children) == 0
  112. assert self._metadata is None
  113. assert self._entries == ()
  114. assert self._unflatten_func is None
  115. num_nodes = 1
  116. num_leaves = 1
  117. num_children = 0
  118. else:
  119. assert callable(self._unflatten_func)
  120. num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
  121. num_leaves = sum(spec.num_leaves for spec in self._children)
  122. num_children = len(self._children)
  123. object.__setattr__(self, "num_nodes", num_nodes)
  124. object.__setattr__(self, "num_leaves", num_leaves)
  125. object.__setattr__(self, "num_children", num_children)
  126. object.__setattr__(self, "none_is_leaf", True)
  127. object.__setattr__(self, "namespace", "torch")
  128. def __repr__(self) -> str:
  129. def helper(treespec: PyTreeSpec) -> str:
  130. if treespec.is_leaf():
  131. assert treespec.type is None
  132. return _asterisk
  133. assert treespec.type is not None
  134. assert callable(treespec._unflatten_func)
  135. children_representations = [
  136. helper(subspec) for subspec in treespec._children
  137. ]
  138. if (
  139. treespec.type in BUILTIN_TYPES
  140. or optree.is_namedtuple_class(treespec.type)
  141. or optree.is_structseq_class(treespec.type)
  142. ):
  143. return treespec._unflatten_func(
  144. treespec._metadata,
  145. children_representations,
  146. )
  147. return (
  148. f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], "
  149. f"[{', '.join(children_representations)}])"
  150. )
  151. return (
  152. f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})"
  153. )
  154. def __len__(self) -> int:
  155. return self.num_leaves
  156. @property
  157. def type(self) -> builtins.type | None:
  158. return self._type
  159. def is_leaf(self) -> bool:
  160. return self.num_nodes == 1 and self.num_leaves == 1
  161. def children(self) -> list[PyTreeSpec]:
  162. return list(self._children)
  163. def child(self, index: int) -> PyTreeSpec:
  164. return self._children[index]
  165. def entries(self) -> list[Any]:
  166. return list(self._entries)
  167. def entry(self, index: int) -> Any:
  168. return self._entries[index]
  169. def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
  170. def helper(
  171. treespec: PyTreeSpec,
  172. node: PyTree,
  173. subtrees: list[PyTree],
  174. ) -> None:
  175. if treespec.is_leaf():
  176. subtrees.append(node)
  177. return
  178. node_type = type(node)
  179. if treespec.type not in BUILTIN_TYPES:
  180. # Always require custom node types to match exactly
  181. if node_type != treespec.type:
  182. raise ValueError(
  183. f"Type mismatch; "
  184. f"expected {treespec.type!r}, but got {node_type!r}.",
  185. )
  186. children, metadata, *_ = optree.tree_flatten_one_level(
  187. node,
  188. none_is_leaf=True,
  189. namespace="torch",
  190. )
  191. if len(children) != treespec.num_children:
  192. raise ValueError(
  193. f"Node arity mismatch; "
  194. f"expected {treespec.num_children}, but got {len(children)}.",
  195. )
  196. if metadata != treespec._metadata:
  197. raise ValueError(
  198. f"Node context mismatch for custom node type {treespec.type!r}.",
  199. )
  200. else:
  201. # For builtin dictionary types, we allow some flexibility
  202. # Otherwise, we require exact matches
  203. both_standard_dict = (
  204. treespec.type in STANDARD_DICT_TYPES
  205. and node_type in STANDARD_DICT_TYPES
  206. )
  207. if not both_standard_dict and node_type != treespec.type:
  208. raise ValueError(
  209. f"Node type mismatch; "
  210. f"expected {treespec.type!r}, but got {node_type!r}.",
  211. )
  212. if len(node) != treespec.num_children:
  213. raise ValueError(
  214. f"Node arity mismatch; "
  215. f"expected {treespec.num_children}, but got {len(node)}.",
  216. )
  217. if both_standard_dict:
  218. # dictionary types are compatible with each other
  219. expected_keys = treespec.entries()
  220. got_key_set = set(node)
  221. expected_key_set = set(expected_keys)
  222. if got_key_set != expected_key_set:
  223. missing_keys = expected_key_set.difference(got_key_set)
  224. extra_keys = got_key_set.difference(expected_key_set)
  225. message = ""
  226. if missing_keys:
  227. message += f"; missing key(s): {missing_keys}"
  228. if extra_keys:
  229. message += f"; extra key(s): {extra_keys}"
  230. raise ValueError(f"Node keys mismatch{message}.")
  231. children = [node[key] for key in expected_keys]
  232. else:
  233. # node_type is treespec.type
  234. children, metadata, *_ = optree.tree_flatten_one_level(
  235. node,
  236. none_is_leaf=True,
  237. namespace="torch",
  238. )
  239. if (
  240. node_type
  241. is not deque # ignore mismatch of `maxlen` for deque
  242. ) and metadata != treespec._metadata:
  243. raise ValueError(
  244. f"Node metadata mismatch for node type {treespec.type!r}; "
  245. f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch
  246. )
  247. for subtree, subspec in zip(children, treespec._children):
  248. helper(subspec, subtree, subtrees)
  249. subtrees: list[PyTree] = []
  250. helper(self, tree, subtrees)
  251. return subtrees
  252. def unflatten(self, leaves: Iterable[Any]) -> PyTree:
  253. if not isinstance(leaves, (list, tuple)):
  254. leaves = list(leaves)
  255. if len(leaves) != self.num_leaves:
  256. raise ValueError(
  257. f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
  258. f"but the spec refers to a pytree that holds {self.num_leaves} "
  259. f"items ({self}).",
  260. )
  261. if self.is_leaf():
  262. return leaves[0]
  263. # Recursively unflatten the children
  264. start = 0
  265. end = 0
  266. subtrees = []
  267. for subspec in self._children:
  268. end += subspec.num_leaves
  269. subtrees.append(subspec.unflatten(leaves[start:end]))
  270. start = end
  271. assert callable(self._unflatten_func)
  272. return self._unflatten_func(self._metadata, subtrees)
  273. _LEAF_SPEC = PyTreeSpec((), None, None, (), None)
  274. def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
  275. return isinstance(obj, PyTreeSpec)
  276. @substitute_in_graph( # type: ignore[arg-type]
  277. cxx_pytree.tree_flatten,
  278. # We need to disable constant folding here because we want the function to reference the
  279. # PyTreeSpec class defined above, not the one in the C++ module.
  280. can_constant_fold_through=False,
  281. )
  282. def tree_flatten(
  283. tree: PyTree,
  284. is_leaf: Callable[[PyTree], bool] | None = None,
  285. ) -> tuple[list[Any], PyTreeSpec]:
  286. def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
  287. if tree_is_leaf(node, is_leaf=is_leaf):
  288. leaves.append(node)
  289. return _LEAF_SPEC
  290. (
  291. children,
  292. metadata,
  293. entries,
  294. unflatten_func,
  295. ) = optree.tree_flatten_one_level(
  296. node,
  297. is_leaf=is_leaf,
  298. none_is_leaf=True,
  299. namespace="torch",
  300. )
  301. # Recursively flatten the children
  302. subspecs = tuple(helper(child, leaves) for child in children)
  303. return PyTreeSpec(subspecs, type(node), metadata, entries, unflatten_func) # type: ignore[arg-type]
  304. leaves: list[Any] = []
  305. treespec = helper(tree, leaves)
  306. return leaves, treespec
  307. __all__ += ["tree_flatten"]
  308. @substitute_in_graph( # type: ignore[arg-type]
  309. cxx_pytree.tree_structure,
  310. # We need to disable constant folding here because we want the function to reference the
  311. # PyTreeSpec class defined above, not the one in the C++ module.
  312. can_constant_fold_through=False,
  313. )
  314. def tree_structure(
  315. tree: PyTree,
  316. is_leaf: Callable[[PyTree], bool] | None = None,
  317. ) -> PyTreeSpec:
  318. return tree_flatten(tree, is_leaf=is_leaf)[1] # type: ignore[return-value]
  319. __all__ += ["tree_structure"]
  320. @substitute_in_graph( # type: ignore[arg-type]
  321. cxx_pytree.tree_unflatten,
  322. # We need to disable constant folding here because we want the function to reference the
  323. # PyTreeSpec class defined above, not the one in the C++ module.
  324. can_constant_fold_through=False,
  325. )
  326. def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
  327. if not _is_pytreespec_instance(treespec):
  328. raise TypeError(
  329. f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
  330. f"PyTreeSpec but got item of type {type(treespec)}."
  331. )
  332. return treespec.unflatten(leaves)
  333. __all__ += ["tree_unflatten"]
  334. @substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True)
  335. def tree_map(
  336. func: Callable[..., Any],
  337. tree: PyTree,
  338. *rests: PyTree,
  339. is_leaf: Callable[[PyTree], bool] | None = None,
  340. ) -> PyTree:
  341. leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
  342. flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
  343. return treespec.unflatten(map(func, *flat_args))
  344. __all__ += ["tree_map"]
  345. @substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True)
  346. def tree_map_(
  347. func: Callable[..., Any],
  348. tree: PyTree,
  349. *rests: PyTree,
  350. is_leaf: Callable[[PyTree], bool] | None = None,
  351. ) -> PyTree:
  352. leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
  353. flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
  354. deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
  355. return tree
  356. __all__ += ["tree_map_"]