_pytree.py 70 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068
"""
Contains utility functions for working with nested python data structures.

A *pytree* is a Python nested data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.

pytrees are useful for working with nested collections of Tensors. For example,
one can use `tree_map` to map a function over all Tensors inside some nested
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.

This pytree implementation is not very performant due to Python overhead.
To improve the performance we can move parts of the implementation to C++.
"""
  14. import dataclasses
  15. import functools
  16. import importlib
  17. import importlib.metadata
  18. import json
  19. import sys
  20. import threading
  21. import types
  22. import warnings
  23. from collections import defaultdict, deque, namedtuple, OrderedDict
  24. from collections.abc import Hashable, Iterable, Mapping, Sequence
  25. from enum import Enum
  26. from typing import (
  27. Any,
  28. Callable,
  29. cast,
  30. ClassVar,
  31. Final,
  32. Generic,
  33. NoReturn,
  34. Optional,
  35. overload,
  36. Protocol,
  37. TypeVar,
  38. Union,
  39. )
  40. from typing_extensions import deprecated, NamedTuple, Self
  41. from torch.torch_version import TorchVersion as _TorchVersion
# Public API of this module (what `from torch.utils._pytree import *` exports).
__all__ = [
    "PyTree",
    "Context",
    "FlattenFunc",
    "UnflattenFunc",
    "DumpableContext",
    "ToDumpableContextFn",
    "FromDumpableContextFn",
    "TreeSpec",
    "LeafSpec",
    "keystr",
    "key_get",
    "register_pytree_node",
    "tree_is_leaf",
    "tree_flatten",
    "tree_flatten_with_path",
    "tree_unflatten",
    "tree_iter",
    "tree_leaves",
    "tree_leaves_with_path",
    "tree_structure",
    "tree_map",
    "tree_map_with_path",
    "tree_map_",
    "tree_map_only",
    "tree_map_only_",
    "tree_all",
    "tree_any",
    "tree_all_only",
    "tree_any_only",
    "treespec_dumps",
    "treespec_loads",
    "treespec_pprint",
    "is_namedtuple",
    "is_namedtuple_class",
    "is_namedtuple_instance",
    "is_structseq",
    "is_structseq_class",
    "is_structseq_instance",
]

# Generic type variables used in the annotations of this module
# (e.g. ``SequenceKey(Generic[T])``).
T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
R = TypeVar("R")

# Default protocol version for serialized TreeSpecs. Presumably consumed by
# `treespec_dumps` (not visible in this chunk) -- confirm.
DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1
# Placeholder serialized name recorded by `_private_register_pytree_node` for
# node types registered without an explicit `serialized_type_name`.
NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND"
class KeyEntry(Protocol):
    """Structural interface for one step of a key path into a pytree.

    The concrete key classes in this module (``SequenceKey``, ``MappingKey``
    and ``GetAttrKey``) satisfy it: they are hashable, comparable, printable,
    and can fetch the child they address from a parent container.
    """

    def __hash__(self) -> int: ...

    def __eq__(self, other: object) -> bool: ...

    # Human-readable form of the key, e.g. ``[0]`` or ``.attr`` for the
    # concrete key classes in this module.
    def __str__(self) -> str: ...

    # Return the child of ``parent`` that this key addresses.
    def get(self, parent: Any) -> Any: ...
  93. class EnumEncoder(json.JSONEncoder):
  94. def default(self, obj: object) -> Union[str, dict[str, Any]]:
  95. if isinstance(obj, Enum):
  96. return {
  97. "__enum__": True,
  98. "fqn": f"{obj.__class__.__module__}:{obj.__class__.__qualname__}",
  99. "name": obj.name,
  100. }
  101. return cast(str, super().default(obj))
# Opaque auxiliary data returned by a node's flatten function and passed back
# to its unflatten function.
Context = Any
# Any Python object that may be a pytree (a registered container or a leaf).
PyTree = Any
# node -> (children, context)
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
# (children, context) -> node
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
DumpableContext = Any  # Any json dumpable text
# Converts a node's context into a JSON-dumpable value ...
ToDumpableContextFn = Callable[[Context], DumpableContext]
# ... and converts such a value back into the node's context.
FromDumpableContextFn = Callable[[DumpableContext], Context]
# Types of the deprecated `to_str_fn` / `maybe_from_str_fn` arguments accepted
# (and ignored) by `_register_pytree_node`.
ToStrFunc = Callable[["TreeSpec", list[str]], str]
MaybeFromStrFunc = Callable[[str], Optional[tuple[Any, Context, str]]]
# A sequence of KeyEntry steps leading from a pytree's root to a subtree.
KeyPath = tuple[KeyEntry, ...]
# Like FlattenFunc, but each child is paired with the KeyEntry addressing it.
FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]]
# A NodeDef describes how to flatten/unflatten one registered node type. It holds
# the registered type plus three callables:
# - flatten_fn should take the collection and return a flat list of values.
#   It can also return some context that is used in reconstructing the
#   collection.
# - unflatten_fn should take a flat list of values and some context
#   (returned by flatten_fn). It returns the collection by reconstructing
#   it from the list and the context.
# - flatten_with_keys_fn (optional) takes a pytree and returns a list of
#   (KeyEntry, child) pairs and a context.
class NodeDef(NamedTuple):
    type: type[Any]
    flatten_fn: FlattenFunc
    unflatten_fn: UnflattenFunc
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc]


# Guards mutations of the node registries. It is reentrant because public
# helpers such as `register_pytree_node` and `register_constant` hold it while
# calling `_private_register_pytree_node`, which acquires it again.
_NODE_REGISTRY_LOCK = threading.RLock()
# Maps every registered node type to its NodeDef.
SUPPORTED_NODES: dict[type[Any], NodeDef] = {}
# _SerializeNodeDef holds the following:
# - typ: the registered Python type of the node (e.g. dict, list).
# - serialized_type_name: the fully qualified name of the type, e.g. "collections.OrderedDict"
# - to_dumpable_context: converts the node's context into a JSON-dumpable
#   representation, or None if no custom conversion was registered.
# - from_dumpable_context: converts that dumpable representation back into the
#   node's context; None exactly when to_dumpable_context is None.
class _SerializeNodeDef(NamedTuple):
    typ: type[Any]
    serialized_type_name: str
    to_dumpable_context: Optional[ToDumpableContextFn]
    from_dumpable_context: Optional[FromDumpableContextFn]


# Maps every registered node type to its serialization metadata.
SUPPORTED_SERIALIZED_TYPES: dict[type[Any], _SerializeNodeDef] = {}
# Reverse lookup: serialized type name -> Python type. Every type registered
# without an explicit name shares the key NO_SERIALIZED_TYPE_NAME_FOUND, so that
# entry only reflects the most recent such registration.
SERIALIZED_TYPE_TO_PYTHON_TYPE: dict[str, type[Any]] = {}
# NB: we try really hard to not import _cxx_pytree (which depends on optree)
# as much as possible. This is for isolation: a user who is not using C++ pytree
# shouldn't pay for it, and it helps make things like cpython upgrades easier.
# The probe below only reads optree's installed version from package metadata;
# it does not import optree.
_optree_minimum_version = _TorchVersion("0.13.0")
try:
    _optree_version = importlib.metadata.version("optree")
except importlib.metadata.PackageNotFoundError:
    # No optree package found
    _cxx_pytree_dynamo_traceable = _cxx_pytree_exists = False
    _optree_version = _TorchVersion("0.0.0a0")
else:
    # Normalize the version string so it can be compared with the minimum.
    _optree_version = _TorchVersion(_optree_version)
    if _optree_version < _optree_minimum_version:
        # optree package less than our required minimum version.
        # Pretend the optree package doesn't exist.
        # NB: We will raise ImportError if the user directly tries to
        # `import torch.utils._cxx_pytree` (look in that file for the check).
        _cxx_pytree_dynamo_traceable = _cxx_pytree_exists = False
    else:
        _cxx_pytree_dynamo_traceable = _cxx_pytree_exists = True

# While False, `register_pytree_node` does not touch the C++ registry and
# instead appends (args, kwargs) pairs to `_cxx_pytree_pending_imports`.
# NOTE(review): the code that sets this flag and drains the queue presumably
# lives in _cxx_pytree and is not visible here -- confirm.
_cxx_pytree_imported = False
_cxx_pytree_pending_imports: list[Any] = []
  165. def register_pytree_node(
  166. cls: type[Any],
  167. flatten_fn: FlattenFunc,
  168. unflatten_fn: UnflattenFunc,
  169. *,
  170. serialized_type_name: Optional[str] = None,
  171. to_dumpable_context: Optional[ToDumpableContextFn] = None,
  172. from_dumpable_context: Optional[FromDumpableContextFn] = None,
  173. flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
  174. ) -> None:
  175. """Register a container-like type as pytree node.
  176. Note:
  177. :func:`register_dataclass` is a simpler way of registering a container-like
  178. type as a pytree node.
  179. Args:
  180. cls: the type to register
  181. flatten_fn: A callable that takes a pytree and returns a flattened
  182. representation of the pytree and additional context to represent the
  183. flattened pytree.
  184. unflatten_fn: A callable that takes a flattened version of the pytree,
  185. additional context, and returns an unflattened pytree.
  186. serialized_type_name: A keyword argument used to specify the fully qualified
  187. name used when serializing the tree spec.
  188. to_dumpable_context: An optional keyword argument to custom specify how
  189. to convert the context of the pytree to a custom json dumpable
  190. representation. This is used for json serialization, which is being
  191. used in torch.export right now.
  192. from_dumpable_context: An optional keyword argument to custom specify how
  193. to convert the custom json dumpable representation of the context
  194. back to the original context. This is used for json deserialization,
  195. which is being used in torch.export right now.
  196. flatten_with_keys_fn: An optional keyword argument to specify how to
  197. access each pytree leaf's keypath when flattening and tree-mapping.
  198. Like ``flatten_fn``, but in place of a List[leaf], it should return
  199. a List[(keypath, leaf)].
  200. """
  201. with _NODE_REGISTRY_LOCK:
  202. if cls in SUPPORTED_NODES:
  203. raise ValueError(f"{cls} is already registered as pytree node.")
  204. _private_register_pytree_node(
  205. cls,
  206. flatten_fn,
  207. unflatten_fn,
  208. serialized_type_name=serialized_type_name,
  209. to_dumpable_context=to_dumpable_context,
  210. from_dumpable_context=from_dumpable_context,
  211. flatten_with_keys_fn=flatten_with_keys_fn,
  212. )
  213. if not _cxx_pytree_exists:
  214. return
  215. if _cxx_pytree_imported:
  216. from . import _cxx_pytree as cxx
  217. cxx._private_register_pytree_node(
  218. cls,
  219. flatten_fn,
  220. unflatten_fn,
  221. serialized_type_name=serialized_type_name,
  222. to_dumpable_context=to_dumpable_context,
  223. from_dumpable_context=from_dumpable_context,
  224. )
  225. else:
  226. args = (cls, flatten_fn, unflatten_fn)
  227. kwargs = {
  228. "serialized_type_name": serialized_type_name,
  229. "to_dumpable_context": to_dumpable_context,
  230. "from_dumpable_context": from_dumpable_context,
  231. }
  232. _cxx_pytree_pending_imports.append((args, kwargs))
  233. def register_dataclass(
  234. cls: type[Any],
  235. *,
  236. field_names: Optional[list[str]] = None,
  237. drop_field_names: Optional[list[str]] = None,
  238. serialized_type_name: Optional[str] = None,
  239. ) -> None:
  240. """
  241. Registers a type that has the semantics of a ``dataclasses.dataclass`` type
  242. as a pytree node.
  243. This is a simpler API than :func:`register_pytree_node` for registering
  244. a dataclass or a custom class with the semantics of a dataclass.
  245. Args:
  246. cls: The python type to register. The class must have the semantics of a
  247. dataclass; in particular, it must be constructed by passing the fields
  248. in.
  249. field_names (Optional[List[str]]): A list of field names that correspond
  250. to the **non-constant data** in this class. This list must contain
  251. all the fields that are used to initialize the class. This argument
  252. is optional if ``cls`` is a dataclass, in which case the fields will
  253. be taken from ``dataclasses.fields()``.
  254. drop_field_names (Optional[List[str]]): A list of field names that
  255. should not be included in the pytree.
  256. serialized_type_name: A keyword argument used to specify the fully
  257. qualified name used when serializing the tree spec. This is only
  258. needed for serializing the treespec in torch.export.
  259. Example:
  260. >>> from torch import Tensor
  261. >>> from dataclasses import dataclass
  262. >>> import torch.utils._pytree as pytree
  263. >>>
  264. >>> @dataclass
  265. >>> class Point:
  266. >>> x: Tensor
  267. >>> y: Tensor
  268. >>>
  269. >>> pytree.register_dataclass(Point)
  270. >>>
  271. >>> point = Point(torch.tensor(0), torch.tensor(1))
  272. >>> point = pytree.tree_map(lambda x: x + 1, point)
  273. >>> assert torch.allclose(point.x, torch.tensor(1))
  274. >>> assert torch.allclose(point.y, torch.tensor(2))
  275. """
  276. drop_field_names = drop_field_names or []
  277. if not dataclasses.is_dataclass(cls):
  278. if field_names is None:
  279. raise ValueError(
  280. "field_names must be specified with a list of all fields used to "
  281. f"initialize {cls}, as it is not a dataclass."
  282. )
  283. elif field_names is None:
  284. field_names = [f.name for f in dataclasses.fields(cls) if f.init]
  285. else:
  286. dataclass_init_fields = {f.name for f in dataclasses.fields(cls) if f.init}
  287. dataclass_init_fields.difference_update(drop_field_names)
  288. if dataclass_init_fields != set(field_names):
  289. error_msg = "field_names does not include all dataclass fields.\n"
  290. if missing := dataclass_init_fields - set(field_names):
  291. error_msg += (
  292. f"Missing fields in `field_names`: {missing}. If you want "
  293. "to include these fields in the pytree, please add them "
  294. "to `field_names`, otherwise please add them to "
  295. "`drop_field_names`.\n"
  296. )
  297. if unexpected := set(field_names) - dataclass_init_fields:
  298. error_msg += (
  299. f"Unexpected fields in `field_names`: {unexpected}. "
  300. "Please remove these fields, or add them to `drop_field_names`.\n"
  301. )
  302. raise ValueError(error_msg)
  303. def _flatten_fn(obj: Any) -> tuple[list[Any], Context]:
  304. flattened = []
  305. flat_names = []
  306. none_names = []
  307. for name in field_names:
  308. val = getattr(obj, name)
  309. if val is not None:
  310. flattened.append(val)
  311. flat_names.append(name)
  312. else:
  313. none_names.append(name)
  314. return flattened, [flat_names, none_names]
  315. def _unflatten_fn(values: Iterable[Any], context: Context) -> Any:
  316. flat_names, none_names = context
  317. return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
  318. def _flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]:
  319. flattened, (flat_names, _none_names) = _flatten_fn(obj) # type: ignore[misc]
  320. return [(GetAttrKey(k), v) for k, v in zip(flat_names, flattened)], flat_names
  321. _private_register_pytree_node(
  322. cls,
  323. _flatten_fn,
  324. _unflatten_fn,
  325. serialized_type_name=serialized_type_name,
  326. flatten_with_keys_fn=_flatten_fn_with_keys,
  327. )
  328. CONSTANT_NODES: set[type] = set()
  329. def register_constant(cls: type[Any]) -> None:
  330. """Registers a type as a pytree node with no leaves.
  331. In a :func:`torch.compile` region, if instances of these types get passed to
  332. :func:`torch._dynamo.nonstrict_trace`-ed function, they treated as a
  333. constant (sometimes referred to as "static"):
  334. 1. if the instance object existed before the :func:`torch.compile` region,
  335. we _assume_ no mutation will happen to it inside the :func:`torch.compile`
  336. region, require that it has non-default `__eq__` and `__hash__` methods, and
  337. we guard on the instance based on its `__eq__` method, i.e., if a new
  338. instance fails to match any instances from the previous compilations,
  339. :func:`torch.compile` will recompile the function using the new instance.
  340. 2. else if the instance object is created inside the :func:`torch.compile`
  341. region, we currently don't support using it in a
  342. :func:`torch._dynamo.nonstrict_trace`-ed function.
  343. In general, if your class holds Tensors or dynamic int/float/bool (values that
  344. may change from run-to-run of a function being compiled), then you probably
  345. do not want to register it as a constant.
  346. Otherwise if you want to pass instance of a class to a
  347. :func:`torch._dynamo.nonstrict_trace`-ed function, but you either can't use
  348. :func:`register_pytree_node` on the class, or the class is "constant" enough
  349. that you don't want to bother using :func:`register_pytree_node`, you should
  350. consider using this function.
  351. Args:
  352. cls: the type to register as a constant. This type must be hashable.
  353. Example:
  354. >>> from dataclasses import dataclass
  355. >>> import torch.utils._pytree as pytree
  356. >>>
  357. >>> @dataclass(frozen=True)
  358. >>> class Config:
  359. >>> norm: str
  360. >>>
  361. >>> pytree.register_constant(Config)
  362. >>>
  363. >>> config = Config("l2")
  364. >>> values, spec = pytree.tree_flatten(config)
  365. >>> assert len(values) == 0
  366. """
  367. if cls.__eq__ is object.__eq__: # type: ignore[comparison-overlap]
  368. raise TypeError(
  369. "register_constant(cls) expects `cls` to have a non-default `__eq__` implementation."
  370. )
  371. # Class with a custom `__eq__` without `__hash__` won't inherit the default
  372. # `__hash__` from object; see https://stackoverflow.com/a/1608907.
  373. if cls.__hash__ is None: # type: ignore[comparison-overlap]
  374. raise TypeError(
  375. "register_constant(cls) expects `cls` to have a non-default `__hash__` implementation."
  376. )
  377. def _flatten(x): # type: ignore[no-untyped-def]
  378. return [], ConstantNode(x)
  379. def _unflatten(_, context): # type: ignore[no-untyped-def]
  380. return context.value
  381. def _flatten_with_keys(x): # type: ignore[no-untyped-def]
  382. return [], ConstantNode(x)
  383. with _NODE_REGISTRY_LOCK:
  384. _private_register_pytree_node(
  385. cls,
  386. _flatten,
  387. _unflatten,
  388. flatten_with_keys_fn=_flatten_with_keys,
  389. )
  390. CONSTANT_NODES.add(cls)
  391. def is_constant_class(cls: type[Any]) -> bool:
  392. return isinstance(cls, type) and cls in CONSTANT_NODES
  393. @dataclasses.dataclass(frozen=True)
  394. class ConstantNode:
  395. value: Any
  396. def _is_constant_holder(spec: "TreeSpec") -> bool:
  397. """Checks if the spec is from a pytree registered with register_constant"""
  398. return isinstance(spec.context, ConstantNode)
  399. def _retrieve_constant(spec: "TreeSpec") -> Any:
  400. """Given a spec from a pytree registered with register_constant, retrieves the constant"""
  401. assert _is_constant_holder(spec)
  402. return tree_unflatten([], spec)
  403. def _register_namedtuple(
  404. cls: type[Any],
  405. *,
  406. serialized_type_name: str,
  407. ) -> None:
  408. """
  409. Registers a namedtuple as a valid pytree node. By default namedtuples are
  410. valid pytree nodes, but they are not serializable. This API provides the
  411. argument `serialized_type_name` which allows these namedtuples to be
  412. serialized.
  413. Args:
  414. cls: the dataclass type to register
  415. serialized_type_name: The serialized name for the dataclass. This is
  416. required if you want to serialize the pytree TreeSpec containing this
  417. namedtuple.
  418. """
  419. _private_register_pytree_node(
  420. cls,
  421. _namedtuple_flatten,
  422. _namedtuple_unflatten,
  423. serialized_type_name=serialized_type_name,
  424. to_dumpable_context=_namedtuple_serialize,
  425. from_dumpable_context=_namedtuple_deserialize,
  426. flatten_with_keys_fn=_namedtuple_flatten_with_keys,
  427. )
  428. @deprecated(
  429. "`torch.utils._pytree._register_pytree_node` is deprecated. "
  430. "Please use `torch.utils._pytree.register_pytree_node` instead.",
  431. category=FutureWarning,
  432. )
  433. def _register_pytree_node(
  434. cls: type[Any],
  435. flatten_fn: FlattenFunc,
  436. unflatten_fn: UnflattenFunc,
  437. to_str_fn: Optional[ToStrFunc] = None, # deprecated
  438. maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated
  439. *,
  440. serialized_type_name: Optional[str] = None,
  441. to_dumpable_context: Optional[ToDumpableContextFn] = None,
  442. from_dumpable_context: Optional[FromDumpableContextFn] = None,
  443. flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
  444. ) -> None:
  445. """Register a container-like type as pytree node for the Python pytree only.
  446. Args:
  447. cls: the type to register
  448. flatten_fn: A callable that takes a pytree and returns a flattened
  449. representation of the pytree and additional context to represent the
  450. flattened pytree.
  451. unflatten_fn: A callable that takes a flattened version of the pytree,
  452. additional context, and returns an unflattened pytree.
  453. serialized_type_name: A keyword argument used to specify the fully qualified
  454. name used when serializing the tree spec.
  455. to_dumpable_context: An optional keyword argument to custom specify how
  456. to convert the context of the pytree to a custom json dumpable
  457. representation. This is used for json serialization, which is being
  458. used in torch.export right now.
  459. from_dumpable_context: An optional keyword argument to custom specify how
  460. to convert the custom json dumpable representation of the context
  461. back to the original context. This is used for json deserialization,
  462. which is being used in torch.export right now.
  463. flatten_with_keys_fn: An optional keyword argument to specify how to
  464. access each pytree leaf's keypath when flattening and tree-mapping.
  465. Like ``flatten_fn``, but in place of a List[leaf], it should return
  466. a List[(keypath, leaf)].
  467. """
  468. if to_str_fn is not None or maybe_from_str_fn is not None:
  469. warnings.warn(
  470. "`to_str_fn` and `maybe_from_str_fn` is deprecated. "
  471. "Please use `to_dumpable_context` and `from_dumpable_context` instead.",
  472. FutureWarning,
  473. stacklevel=2,
  474. )
  475. _private_register_pytree_node(
  476. cls,
  477. flatten_fn,
  478. unflatten_fn,
  479. serialized_type_name=serialized_type_name,
  480. to_dumpable_context=to_dumpable_context,
  481. from_dumpable_context=from_dumpable_context,
  482. flatten_with_keys_fn=flatten_with_keys_fn,
  483. )
  484. def _deregister_pytree_node(
  485. cls: type[Any],
  486. ) -> None:
  487. """This is an internal function that is used to deregister a pytree node type
  488. for the Python pytree only. This should be only used inside PyTorch.
  489. """
  490. with _NODE_REGISTRY_LOCK:
  491. del SUPPORTED_NODES[cls]
  492. node_def = SUPPORTED_SERIALIZED_TYPES[cls]
  493. del SERIALIZED_TYPE_TO_PYTHON_TYPE[node_def.serialized_type_name]
  494. del SUPPORTED_SERIALIZED_TYPES[cls]
  495. CONSTANT_NODES.discard(cls)
  496. def _private_register_pytree_node(
  497. cls: type[Any],
  498. flatten_fn: FlattenFunc,
  499. unflatten_fn: UnflattenFunc,
  500. *,
  501. serialized_type_name: Optional[str] = None,
  502. to_dumpable_context: Optional[ToDumpableContextFn] = None,
  503. from_dumpable_context: Optional[FromDumpableContextFn] = None,
  504. flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
  505. ) -> None:
  506. """This is an internal function that is used to register a pytree node type
  507. for the Python pytree only. End-users should use :func:`register_pytree_node`
  508. instead.
  509. """
  510. with _NODE_REGISTRY_LOCK:
  511. if cls in SUPPORTED_NODES:
  512. # TODO: change this warning to an error after OSS/internal stabilize
  513. warnings.warn(
  514. f"{cls} is already registered as pytree node. "
  515. "Overwriting the previous registration.",
  516. )
  517. node_def = NodeDef(cls, flatten_fn, unflatten_fn, flatten_with_keys_fn)
  518. SUPPORTED_NODES[cls] = node_def
  519. if (to_dumpable_context is None) ^ (from_dumpable_context is None):
  520. raise ValueError(
  521. f"Both to_dumpable_context and from_dumpable_context for {cls} must "
  522. "be None or registered."
  523. )
  524. if serialized_type_name is None:
  525. serialized_type_name = NO_SERIALIZED_TYPE_NAME_FOUND
  526. serialize_node_def = _SerializeNodeDef(
  527. cls,
  528. serialized_type_name,
  529. to_dumpable_context,
  530. from_dumpable_context,
  531. )
  532. SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def
  533. SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls
  534. @dataclasses.dataclass(frozen=True)
  535. class SequenceKey(Generic[T]):
  536. idx: int
  537. def __str__(self) -> str:
  538. return f"[{self.idx!r}]"
  539. def get(self, sequence: Sequence[T]) -> T:
  540. return sequence[self.idx]
  541. K = TypeVar("K", bound=Hashable)
  542. @dataclasses.dataclass(frozen=True)
  543. class MappingKey(Generic[K, T]):
  544. key: K
  545. def __str__(self) -> str:
  546. return f"[{self.key!r}]"
  547. def get(self, mapping: Mapping[K, T]) -> T:
  548. return mapping[self.key]
  549. @dataclasses.dataclass(frozen=True)
  550. class GetAttrKey:
  551. name: str
  552. def __str__(self) -> str:
  553. return f".{self.name}"
  554. def get(self, obj: Any) -> Any:
  555. return getattr(obj, self.name)
  556. # Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
  557. def is_namedtuple(obj: Union[object, type]) -> bool:
  558. """Return whether the object is an instance of namedtuple or a subclass of namedtuple."""
  559. cls = obj if isinstance(obj, type) else type(obj)
  560. return is_namedtuple_class(cls)
  561. # Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
  562. def is_namedtuple_class(cls: type) -> bool:
  563. """Return whether the class is a subclass of namedtuple."""
  564. return (
  565. isinstance(cls, type)
  566. and issubclass(cls, tuple)
  567. and isinstance(getattr(cls, "_fields", None), tuple)
  568. and all(type(field) is str for field in cls._fields) # type: ignore[attr-defined]
  569. and callable(getattr(cls, "_make", None))
  570. and callable(getattr(cls, "_asdict", None))
  571. )
  572. # Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
  573. def is_namedtuple_instance(obj: object) -> bool:
  574. """Return whether the object is an instance of namedtuple."""
  575. return is_namedtuple_class(type(obj))
# Covariant element type used to parameterize the ``structseq`` stub below.
_T_co = TypeVar("_T_co", covariant=True)


# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
class structseq(tuple[_T_co, ...]):
    """A generic type stub for CPython's ``PyStructSequence`` type."""

    # Empty ``__slots__``: the stub adds no per-instance attributes.
    __slots__: ClassVar[tuple[()]] = ()

    # Field counts declared for type checkers only; they are annotations and
    # are not assigned on this stub.
    n_fields: Final[int]  # type: ignore[misc]
    n_sequence_fields: Final[int]  # type: ignore[misc]
    n_unnamed_fields: Final[int]  # type: ignore[misc]

    def __init_subclass__(cls) -> NoReturn:
        """Prohibit subclassing."""
        raise TypeError("type 'structseq' is not an acceptable base type")

    def __new__(
        cls: type[Self],
        sequence: Iterable[_T_co],
        dict: dict[str, Any] = ...,
    ) -> Self:
        # Typing stub only: the signature is declared, but constructing this
        # class directly always raises.
        raise NotImplementedError
  593. # Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
  594. def is_structseq(obj: Union[object, type]) -> bool:
  595. """Return whether the object is an instance of PyStructSequence or a class of PyStructSequence."""
  596. cls = obj if isinstance(obj, type) else type(obj)
  597. return is_structseq_class(cls)
  598. # Set if the type allows subclassing (see CPython's Include/object.h)
  599. Py_TPFLAGS_BASETYPE: int = 1 << 10
  600. # Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
  601. def is_structseq_class(cls: type) -> bool:
  602. """Return whether the class is a class of PyStructSequence."""
  603. return (
  604. isinstance(cls, type)
  605. # Check direct inheritance from `tuple` rather than `issubclass(cls, tuple)`
  606. and cls.__bases__ == (tuple,)
  607. # Check PyStructSequence members
  608. and isinstance(getattr(cls, "n_fields", None), int)
  609. and isinstance(getattr(cls, "n_sequence_fields", None), int)
  610. and isinstance(getattr(cls, "n_unnamed_fields", None), int)
  611. # Check the type does not allow subclassing
  612. and not bool(cls.__flags__ & Py_TPFLAGS_BASETYPE) # only works for CPython
  613. )
  614. # Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
  615. def is_structseq_instance(obj: object) -> bool:
  616. """Return whether the object is an instance of PyStructSequence."""
  617. return is_structseq_class(type(obj))
  618. def _tuple_flatten(d: tuple[T, ...]) -> tuple[list[T], Context]:
  619. return list(d), None
  620. def _tuple_flatten_with_keys(
  621. d: tuple[T, ...],
  622. ) -> tuple[list[tuple[KeyEntry, T]], Context]:
  623. values, context = _tuple_flatten(d)
  624. return [(SequenceKey(i), v) for i, v in enumerate(values)], context
  625. def _tuple_unflatten(values: Iterable[T], context: Context) -> tuple[T, ...]:
  626. return tuple(values)
  627. def _list_flatten(d: list[T]) -> tuple[list[T], Context]:
  628. return d, None
  629. def _list_flatten_with_keys(d: list[T]) -> tuple[list[tuple[KeyEntry, T]], Context]:
  630. values, context = _list_flatten(d)
  631. return [(SequenceKey(i), v) for i, v in enumerate(values)], context
  632. def _list_unflatten(values: Iterable[T], context: Context) -> list[T]:
  633. return list(values)
  634. def _dict_flatten(d: dict[Any, T]) -> tuple[list[T], Context]:
  635. return list(d.values()), list(d.keys())
  636. def _dict_flatten_with_keys(
  637. d: dict[Any, T],
  638. ) -> tuple[list[tuple[KeyEntry, T]], Context]:
  639. values, context = _dict_flatten(d)
  640. return [(MappingKey(k), v) for k, v in zip(context, values)], context
  641. def _dict_unflatten(values: Iterable[T], context: Context) -> dict[Any, T]:
  642. return dict(zip(context, values))
  643. def _namedtuple_flatten(d: NamedTuple) -> tuple[list[Any], Context]:
  644. return list(d), type(d)
  645. def _namedtuple_flatten_with_keys(
  646. d: NamedTuple,
  647. ) -> tuple[list[tuple[KeyEntry, Any]], Context]:
  648. values, context = _namedtuple_flatten(d)
  649. return (
  650. [(GetAttrKey(field), v) for field, v in zip(context._fields, values)],
  651. context,
  652. )
  653. def _namedtuple_unflatten(values: Iterable[T], context: Context) -> NamedTuple:
  654. return cast(NamedTuple, context(*values))
  655. def _namedtuple_serialize(context: Context) -> DumpableContext:
  656. if context not in SUPPORTED_SERIALIZED_TYPES:
  657. raise NotImplementedError(
  658. f"Can't serialize TreeSpec of namedtuple class {context} because we "
  659. "didn't register a serializated_type_name. Please register using "
  660. "`_register_namedtuple`."
  661. )
  662. serialize_node_def = SUPPORTED_SERIALIZED_TYPES[context]
  663. serialized_type_name = serialize_node_def.serialized_type_name
  664. if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND:
  665. raise NotImplementedError(
  666. f"Can't serialize TreeSpec of namedtuple class {context} because we "
  667. "couldn't find a serializated_type_name. Please register using "
  668. "`_register_namedtuple`."
  669. )
  670. return serialized_type_name
  671. def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context:
  672. if dumpable_context not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
  673. raise NotImplementedError(
  674. f"Can't deserialize TreeSpec of namedtuple class {dumpable_context} "
  675. "because we couldn't find a serializated name."
  676. )
  677. typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[dumpable_context]
  678. return typ
  679. def _ordereddict_flatten(d: OrderedDict[Any, T]) -> tuple[list[T], Context]:
  680. return list(d.values()), list(d.keys())
  681. def _ordereddict_flatten_with_keys(
  682. d: OrderedDict[Any, T],
  683. ) -> tuple[list[tuple[KeyEntry, T]], Context]:
  684. values, context = _ordereddict_flatten(d)
  685. return [(MappingKey(k), v) for k, v in zip(context, values)], context
  686. def _ordereddict_unflatten(
  687. values: Iterable[T],
  688. context: Context,
  689. ) -> OrderedDict[Any, T]:
  690. return OrderedDict((key, value) for key, value in zip(context, values))
  691. _odict_flatten = _ordereddict_flatten
  692. _odict_unflatten = _ordereddict_unflatten
  693. def _defaultdict_flatten(d: defaultdict[Any, T]) -> tuple[list[T], Context]:
  694. values, dict_context = _dict_flatten(d)
  695. return values, [d.default_factory, dict_context]
  696. def _defaultdict_flatten_with_keys(
  697. d: defaultdict[Any, T],
  698. ) -> tuple[list[tuple[KeyEntry, T]], Context]:
  699. values, context = _defaultdict_flatten(d)
  700. _, dict_context = context
  701. return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context
  702. def _defaultdict_unflatten(
  703. values: Iterable[T],
  704. context: Context,
  705. ) -> defaultdict[Any, T]:
  706. default_factory, dict_context = context
  707. return defaultdict(default_factory, _dict_unflatten(values, dict_context))
  708. def _defaultdict_serialize(context: Context) -> DumpableContext:
  709. default_factory, dict_context = context
  710. json_defaultdict = {
  711. "default_factory_module": default_factory.__module__,
  712. "default_factory_name": default_factory.__qualname__,
  713. "dict_context": dict_context,
  714. }
  715. return json_defaultdict
  716. def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context:
  717. assert isinstance(dumpable_context, dict)
  718. assert set(dumpable_context) == {
  719. "default_factory_module",
  720. "default_factory_name",
  721. "dict_context",
  722. }
  723. default_factory_module = dumpable_context["default_factory_module"]
  724. default_factory_name = dumpable_context["default_factory_name"]
  725. assert isinstance(default_factory_module, str)
  726. assert isinstance(default_factory_name, str)
  727. module = importlib.import_module(default_factory_module)
  728. default_factory = getattr(module, default_factory_name)
  729. dict_context = dumpable_context["dict_context"]
  730. return [default_factory, dict_context]
  731. def _deque_flatten(d: deque[T]) -> tuple[list[T], Context]:
  732. return list(d), d.maxlen
  733. def _deque_flatten_with_keys(
  734. d: deque[T],
  735. ) -> tuple[list[tuple[KeyEntry, T]], Context]:
  736. values, context = _deque_flatten(d)
  737. return [(SequenceKey(i), v) for i, v in enumerate(values)], context
  738. def _deque_unflatten(values: Iterable[T], context: Context) -> deque[T]:
  739. return deque(values, maxlen=context)
# Register the builtin container types as pytree nodes. Each call supplies a
# flatten/unflatten pair, a ``serialized_type_name`` and a key-aware flatten
# function. The two types whose contexts hold Python objects rather than plain
# data (a namedtuple's class, a defaultdict's factory) also supply converters
# to and from a dumpable context.
_private_register_pytree_node(
    tuple,
    _tuple_flatten,
    _tuple_unflatten,
    serialized_type_name="builtins.tuple",
    flatten_with_keys_fn=_tuple_flatten_with_keys,
)
_private_register_pytree_node(
    list,
    _list_flatten,
    _list_unflatten,
    serialized_type_name="builtins.list",
    flatten_with_keys_fn=_list_flatten_with_keys,
)
_private_register_pytree_node(
    dict,
    _dict_flatten,
    _dict_unflatten,
    serialized_type_name="builtins.dict",
    flatten_with_keys_fn=_dict_flatten_with_keys,
)
# ``namedtuple`` (the factory function) is used as the single registry key for
# every namedtuple class; the concrete class travels in the context.
_private_register_pytree_node(
    namedtuple,  # type: ignore[arg-type]
    _namedtuple_flatten,
    _namedtuple_unflatten,
    serialized_type_name="collections.namedtuple",
    to_dumpable_context=_namedtuple_serialize,
    from_dumpable_context=_namedtuple_deserialize,
    flatten_with_keys_fn=_namedtuple_flatten_with_keys,
)
_private_register_pytree_node(
    OrderedDict,
    _ordereddict_flatten,
    _ordereddict_unflatten,
    serialized_type_name="collections.OrderedDict",
    flatten_with_keys_fn=_ordereddict_flatten_with_keys,
)
_private_register_pytree_node(
    defaultdict,
    _defaultdict_flatten,
    _defaultdict_unflatten,
    serialized_type_name="collections.defaultdict",
    to_dumpable_context=_defaultdict_serialize,
    from_dumpable_context=_defaultdict_deserialize,
    flatten_with_keys_fn=_defaultdict_flatten_with_keys,
)
_private_register_pytree_node(
    deque,
    _deque_flatten,
    _deque_unflatten,
    serialized_type_name="collections.deque",
    flatten_with_keys_fn=_deque_flatten_with_keys,
)
  793. STANDARD_DICT_TYPES: frozenset[type] = frozenset({dict, OrderedDict, defaultdict})
  794. BUILTIN_TYPES: frozenset[type] = frozenset(
  795. {
  796. tuple,
  797. list,
  798. dict,
  799. namedtuple, # type: ignore[arg-type]
  800. OrderedDict,
  801. defaultdict,
  802. deque,
  803. },
  804. )
  805. @deprecated(
  806. "torch.utils._pytree._is_namedtuple_instance is private and will be removed in a future release. "
  807. "Please use torch.utils._pytree.is_namedtuple_instance instead.",
  808. category=FutureWarning,
  809. )
  810. def _is_namedtuple_instance(tree: Any) -> bool:
  811. return is_namedtuple_instance(tree)
  812. def _get_node_type(tree: Any) -> Any:
  813. node_type = type(tree)
  814. # All namedtuple types are implicitly registered as pytree nodes.
  815. # XXX: Other parts of the codebase expect namedtuple types always return
  816. # `namedtuple` instead of the actual namedtuple type. Even if the type
  817. # is explicitly registered.
  818. if is_namedtuple_class(node_type):
  819. return namedtuple
  820. return node_type
  821. # A leaf is defined as anything that is not a Node.
  822. def tree_is_leaf(
  823. tree: PyTree,
  824. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  825. ) -> bool:
  826. """Check if a pytree is a leaf.
  827. >>> tree_is_leaf(1)
  828. True
  829. >>> tree_is_leaf(None)
  830. True
  831. >>> tree_is_leaf([1, 2, 3])
  832. False
  833. >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple))
  834. True
  835. >>> tree_is_leaf({"a": 1, "b": 2, "c": 3})
  836. False
  837. >>> tree_is_leaf({"a": 1, "b": 2, "c": None})
  838. False
  839. """
  840. if is_leaf is not None and is_leaf(tree):
  841. return True
  842. return _get_node_type(tree) not in SUPPORTED_NODES
  843. @deprecated(
  844. "torch.utils._pytree._is_leaf is private and will be removed in a future release. "
  845. "Please use torch.utils._pytree.tree_is_leaf instead.",
  846. category=FutureWarning,
  847. )
  848. def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool:
  849. return tree_is_leaf(tree, is_leaf=is_leaf)
# A TreeSpec represents the structure of a pytree. It holds:
# "type": the type of root Node of the pytree
# context: some context that is useful in unflattening the pytree
# children_specs: specs for each child of the root Node
# num_leaves: the number of leaves
# It also derives num_nodes (this node plus all descendant specs) and
# num_children (len(children_specs)).
@dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False)
class TreeSpec:
    """Immutable description of a pytree's structure.

    Constructed from ``type``, ``context`` and ``children_specs``. The counts
    ``num_nodes``, ``num_leaves`` and ``num_children`` are derived in
    ``__post_init__`` and are not constructor arguments. ``__eq__`` and
    ``__hash__`` are defined explicitly below.
    """

    type: Any
    context: Context
    children_specs: list["TreeSpec"]
    # Derived fields (``init=False``), populated by ``__post_init__``.
    num_nodes: int = dataclasses.field(init=False)
    num_leaves: int = dataclasses.field(init=False)
    num_children: int = dataclasses.field(init=False)

    def __post_init__(self) -> None:
        # Aggregate counts from the child specs; ``start=1`` counts this node.
        num_nodes = sum((spec.num_nodes for spec in self.children_specs), start=1)
        num_leaves = sum(spec.num_leaves for spec in self.children_specs)
        num_children = len(self.children_specs)
        # The dataclass is frozen, so ``object.__setattr__`` is needed to
        # populate the derived fields.
        object.__setattr__(self, "num_nodes", num_nodes)
        object.__setattr__(self, "num_leaves", num_leaves)
        object.__setattr__(self, "num_children", num_children)

    def __repr__(self, indent: int = 0) -> str:
        """Render as ``TreeSpec(<type name>, <context>, [<children>])``.

        ``indent`` is the current nesting column. The first child follows the
        opening bracket directly; each later child starts on a new line
        indented by ``indent + 2`` spaces.
        """
        repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, ["
        children_specs_str: str = ""
        if self.num_children > 0:
            indent += 2
            children_specs_str += self.children_specs[0].__repr__(indent)
            children_specs_str += "," if self.num_children > 1 else ""
            children_specs_str += ",".join(
                [
                    "\n" + " " * indent + child.__repr__(indent)
                    for child in self.children_specs[1:]
                ]
            )
        repr_suffix: str = f"{children_specs_str}])"
        return repr_prefix + repr_suffix

    def __eq__(self, other: PyTree) -> bool:
        """Structural equality: same class, same type ``str``, equal context and children.

        Returns ``NotImplemented`` when ``other`` is not exactly the same class.
        """
        if self is other:
            return True
        elif other.__class__ is self.__class__:
            # Node types are compared via ``str`` rather than identity.
            # NOTE(review): ``__hash__`` hashes ``self.type`` itself, so two
            # distinct types with the same ``str`` compare equal here but may
            # hash differently -- confirm this is intended.
            if str(self.type) != str(other.type):
                return False
            if self.context != other.context:
                return False
            elif self.children_specs != other.children_specs:
                return False
            return True
        return NotImplemented

    def is_leaf(self) -> bool:
        """Return whether this spec describes a single leaf.

        A spec with no children gets ``num_leaves == 0`` from ``__post_init__``,
        so only a spec that fixes both counts to 1 (``LeafSpec``) is a leaf.
        """
        return self.num_nodes == 1 and self.num_leaves == 1

    def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
        """Flatten ``tree`` only as deep as this spec's structure.

        Each leaf position of the spec yields the corresponding subtree of
        ``tree`` unchanged (it may itself be a container).

        Raises:
            ValueError: on node type, arity, dict-key or context mismatch.
                Builtin dict-like types may substitute for each other (keys must
                match). A defaultdict's ``default_factory`` and a deque's
                ``maxlen`` are not compared.
        """

        def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None:
            # Spec leaf: emit the whole subtree as-is.
            if treespec.is_leaf():
                subtrees.append(tree)
                return
            node_type = _get_node_type(tree)
            if treespec.type not in BUILTIN_TYPES:
                # Always require custom node types to match exactly
                if node_type != treespec.type:
                    raise ValueError(
                        f"Type mismatch; "
                        f"expected {treespec.type!r}, but got {node_type!r}.",
                    )
                flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
                children, context = flatten_fn(tree)
                if len(children) != treespec.num_children:
                    raise ValueError(
                        f"Node arity mismatch; "
                        f"expected {treespec.num_children}, but got {len(children)}.",
                    )
                if context != treespec.context:
                    raise ValueError(
                        f"Node context mismatch for custom node type {treespec.type!r}.",
                    )
            else:
                # For builtin dictionary types, we allow some flexibility
                # Otherwise, we require exact matches
                both_standard_dict = (
                    treespec.type in STANDARD_DICT_TYPES
                    and node_type in STANDARD_DICT_TYPES
                )
                if not both_standard_dict and node_type != treespec.type:
                    raise ValueError(
                        f"Node type mismatch; "
                        f"expected {treespec.type!r}, but got {node_type!r}.",
                    )
                # Builtin containers all support ``len``, so arity is checked
                # before flattening.
                if len(tree) != treespec.num_children:
                    raise ValueError(
                        f"Node arity mismatch; "
                        f"expected {treespec.num_children}, but got {len(tree)}.",
                    )
                if both_standard_dict:
                    # dictionary types are compatible with each other
                    dict_context = (
                        treespec.context
                        if treespec.type is not defaultdict
                        # ignore mismatch of `default_factory` for defaultdict
                        else treespec.context[1]
                    )
                    expected_keys = dict_context
                    got_key_set = set(tree)
                    expected_key_set = set(expected_keys)
                    if got_key_set != expected_key_set:
                        missing_keys = expected_key_set.difference(got_key_set)
                        extra_keys = got_key_set.difference(expected_key_set)
                        message = ""
                        if missing_keys:
                            message += f"; missing key(s): {missing_keys}"
                        if extra_keys:
                            message += f"; extra key(s): {extra_keys}"
                        raise ValueError(f"Node keys mismatch{message}.")
                    # Children are taken in the spec's key order, not the tree's.
                    children = [tree[key] for key in expected_keys]
                else:
                    # node_type is treespec.type
                    flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
                    children, context = flatten_fn(tree)
                    if (
                        node_type is not deque  # ignore mismatch of `maxlen` for deque
                    ) and context != treespec.context:
                        raise ValueError(
                            f"Node context mismatch for node type {treespec.type!r}; "
                            f"expected {treespec.context!r}, but got {context!r}.",  # namedtuple type mismatch
                        )
            # Recurse pairwise into children and their specs (arity already checked).
            for subtree, subspec in zip(children, treespec.children_specs):
                helper(subspec, subtree, subtrees)

        subtrees: list[PyTree] = []
        helper(self, tree, subtrees)
        return subtrees

    def unflatten(self, leaves: Iterable[Any]) -> PyTree:
        """Rebuild a pytree with this spec's structure from ``leaves``.

        Raises:
            ValueError: if the number of leaves differs from ``num_leaves``.
        """
        # Materialize once so ``len`` and slicing work on arbitrary iterables.
        if not isinstance(leaves, (list, tuple)):
            leaves = list(leaves)
        if len(leaves) != self.num_leaves:
            raise ValueError(
                f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
                f"but the spec refers to a pytree that holds {self.num_leaves} "
                f"items ({self}).",
            )
        if self.is_leaf():
            return leaves[0]

        unflatten_fn = SUPPORTED_NODES[self.type].unflatten_fn

        # Recursively unflatten the children
        # Each child consumes the next ``child_spec.num_leaves`` leaves in order.
        start = 0
        end = 0
        child_pytrees = []
        for child_spec in self.children_specs:
            end += child_spec.num_leaves
            child_pytrees.append(child_spec.unflatten(leaves[start:end]))
            start = end

        return unflatten_fn(child_pytrees, self.context)

    def __hash__(self) -> int:
        # Builds a hashable stand-in for ``context``: list-based contexts of the
        # dict types are converted to tuples.
        node_type = self.type
        if node_type is defaultdict:
            default_factory, dict_context = self.context
            hashable_context = (default_factory, tuple(dict_context))
        elif node_type in (dict, OrderedDict):
            hashable_context = tuple(self.context)
        elif node_type is None or node_type in BUILTIN_TYPES:
            hashable_context = self.context
        elif isinstance(self.context, ConstantNode):
            hashable_context = self.context.value
        else:
            # The context for user-defined node types might not be hashable.
            # Ignore it for hashing.
            # This does not break the correctness that equal objects imply the
            # same hash. This might increase the hash collision rate, but we
            # don't care about that.
            hashable_context = None

        return hash((node_type, hashable_context, tuple(self.children_specs)))
  1017. # NOTE: subclassing a dataclass is subtle. In order to enable reasoning about
  1018. # this class with `dataclasses.fields`, etc., while having a simplified
  1019. # constructor that takes no argument, we wrap with `dataclass(init=True, ...)`
  1020. # again, with fields that have `init=False`.
  1021. @dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False)
  1022. class LeafSpec(TreeSpec):
  1023. type: Any = dataclasses.field(default=None, init=False)
  1024. context: Context = dataclasses.field(default=None, init=False)
  1025. children_specs: list["TreeSpec"] = dataclasses.field(
  1026. default_factory=list, init=False
  1027. )
  1028. def __post_init__(self) -> None:
  1029. # Override `__post_init__` for `num_leaves` derivation.
  1030. object.__setattr__(self, "num_nodes", 1)
  1031. object.__setattr__(self, "num_leaves", 1)
  1032. object.__setattr__(self, "num_children", 0)
  1033. def __repr__(self, indent: int = 0) -> str:
  1034. return "*"
  1035. # All leaves are equivalent, so represent with a single object to save on
  1036. # object construction time
  1037. _LEAF_SPEC = LeafSpec()
  1038. def tree_flatten(
  1039. tree: PyTree,
  1040. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1041. ) -> tuple[list[Any], TreeSpec]:
  1042. """Flattens a pytree into a list of values and a TreeSpec that can be used
  1043. to reconstruct the pytree.
  1044. """
  1045. def helper(node: PyTree, leaves: list[Any]) -> TreeSpec:
  1046. if tree_is_leaf(node, is_leaf=is_leaf):
  1047. leaves.append(node)
  1048. return _LEAF_SPEC
  1049. node_type = _get_node_type(node)
  1050. flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
  1051. children, context = flatten_fn(node)
  1052. # Recursively flatten the children
  1053. subspecs = [helper(child, leaves) for child in children]
  1054. return TreeSpec(node_type, context, subspecs)
  1055. leaves: list[Any] = []
  1056. treespec = helper(tree, leaves)
  1057. return leaves, treespec
  1058. def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
  1059. """Given a list of values and a TreeSpec, builds a pytree.
  1060. This is the inverse operation of `tree_flatten`.
  1061. """
  1062. if not isinstance(treespec, TreeSpec):
  1063. raise TypeError(
  1064. f"tree_unflatten(leaves, treespec): Expected `treespec` to be "
  1065. f"instance of TreeSpec but got item of type {type(treespec)}.",
  1066. )
  1067. return treespec.unflatten(leaves)
  1068. def tree_iter(
  1069. tree: PyTree,
  1070. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1071. ) -> Iterable[Any]:
  1072. """Get an iterator over the leaves of a pytree."""
  1073. if tree_is_leaf(tree, is_leaf=is_leaf):
  1074. yield tree
  1075. else:
  1076. node_type = _get_node_type(tree)
  1077. flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
  1078. child_pytrees, _ = flatten_fn(tree)
  1079. # Recursively flatten the children
  1080. for child in child_pytrees:
  1081. yield from tree_iter(child, is_leaf=is_leaf)
  1082. def tree_leaves(
  1083. tree: PyTree,
  1084. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1085. ) -> list[Any]:
  1086. """Get a list of leaves of a pytree."""
  1087. return list(tree_iter(tree, is_leaf=is_leaf))
  1088. def tree_structure(
  1089. tree: PyTree,
  1090. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1091. ) -> TreeSpec:
  1092. """Get the TreeSpec for a pytree."""
  1093. return tree_flatten(tree, is_leaf=is_leaf)[1]
  1094. def tree_map(
  1095. func: Callable[..., Any],
  1096. tree: PyTree,
  1097. *rests: PyTree,
  1098. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1099. ) -> PyTree:
  1100. """Map a multi-input function over pytree args to produce a new pytree.
  1101. See also :func:`tree_map_`.
  1102. >>> tree_map(lambda x: x + 1, {"x": 7, "y": (42, 64)})
  1103. {'x': 8, 'y': (43, 65)}
  1104. >>> tree_map(lambda x: x is None, {"x": 7, "y": (42, 64), "z": None})
  1105. {'x': False, 'y': (False, False), 'z': True}
  1106. If multiple inputs are given, the structure of the tree is taken from the first input;
  1107. subsequent inputs need only have ``tree`` as a prefix:
  1108. >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
  1109. [[5, 7, 9], [6, 1, 2]]
  1110. Args:
  1111. func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
  1112. corresponding leaves of the pytrees.
  1113. tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
  1114. argument to function ``func``.
  1115. rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
  1116. ``tree`` or has ``tree`` as a prefix.
  1117. is_leaf (callable, optional): An extra leaf predicate function that will be called at each
  1118. flattening step. The function should have a single argument with signature
  1119. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  1120. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  1121. leaf or not. If the function is not specified, the default pytree registry will be used.
  1122. Returns:
  1123. A new pytree with the same structure as ``tree`` but with the value at each leaf given by
  1124. ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
  1125. is the tuple of values at corresponding nodes in ``rests``.
  1126. """
  1127. leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
  1128. flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
  1129. return treespec.unflatten(map(func, *flat_args))
  1130. def tree_map_(
  1131. func: Callable[..., Any],
  1132. tree: PyTree,
  1133. *rests: PyTree,
  1134. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1135. ) -> PyTree:
  1136. """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.
  1137. See also :func:`tree_map`.
  1138. Args:
  1139. func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
  1140. corresponding leaves of the pytrees.
  1141. tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
  1142. argument to function ``func``.
  1143. rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
  1144. ``tree`` or has ``tree`` as a prefix.
  1145. is_leaf (callable, optional): An extra leaf predicate function that will be called at each
  1146. flattening step. The function should have a single argument with signature
  1147. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  1148. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  1149. leaf or not. If the function is not specified, the default pytree registry will be used.
  1150. Returns:
  1151. The original ``tree`` with the value at each leaf is given by the side-effect of function
  1152. ``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf
  1153. in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``.
  1154. """
  1155. leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
  1156. flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
  1157. deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
  1158. return tree
# Tuples of two / three classes, used by the ``map_only``-style overloads so
# the element types flow into the annotated function's parameter type.
Type2 = tuple[type[T], type[S]]
Type3 = tuple[type[T], type[S], type[U]]
# Any class specification passed to ``isinstance`` in ``map_only``. ``X | Y``
# union objects (``types.UnionType``) only exist on Python 3.10+. The
# ``if sys.version_info`` statement form is kept so type checkers can resolve
# the alias.
if sys.version_info >= (3, 10):
    TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType]
else:
    TypeAny = Union[type[Any], tuple[type[Any], ...]]

# Callable aliases for functions taking a value of the filtered type(s).
Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]

# Return type of ``map_only``: a decorator taking a ``T`` (the function to wrap)
# and returning a function that accepts any value.
MapOnlyFn = Callable[[T], Callable[[Any], Any]]
  1170. # These specializations help with type inference on the lambda passed to this
  1171. # function
  1172. @overload
  1173. def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ...
  1174. @overload
  1175. def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ...
  1176. @overload
  1177. def map_only(
  1178. type_or_types_or_pred: Type3[T, S, U], /
  1179. ) -> MapOnlyFn[Fn3[T, S, U, Any]]: ...
  1180. # This specialization is needed for the implementations below that call
  1181. @overload
  1182. def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ...
  1183. @overload
  1184. def map_only(
  1185. type_or_types_or_pred: Callable[[Any], bool], /
  1186. ) -> MapOnlyFn[FnAny[Any]]: ...
  1187. def map_only(
  1188. type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], /
  1189. ) -> MapOnlyFn[FnAny[Any]]:
  1190. """
  1191. Suppose you are writing a tree_map over tensors, leaving everything
  1192. else unchanged. Ordinarily you would have to write:
  1193. def go(t):
  1194. if isinstance(t, Tensor):
  1195. return ...
  1196. else:
  1197. return t
  1198. With this function, you only need to write:
  1199. @map_only(Tensor)
  1200. def go(t):
  1201. return ...
  1202. You can also directly use 'tree_map_only'
  1203. """
  1204. if isinstance(type_or_types_or_pred, (type, tuple)) or (
  1205. sys.version_info >= (3, 10)
  1206. and isinstance(type_or_types_or_pred, types.UnionType)
  1207. ):
  1208. def pred(x: Any) -> bool:
  1209. return isinstance(x, type_or_types_or_pred) # type: ignore[arg-type]
  1210. elif callable(type_or_types_or_pred):
  1211. pred = type_or_types_or_pred # type: ignore[assignment]
  1212. else:
  1213. raise TypeError("Argument must be a type, a tuple of types, or a callable.")
  1214. def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]:
  1215. @functools.wraps(func)
  1216. def wrapped(x: T) -> Any:
  1217. if pred(x):
  1218. return func(x)
  1219. return x
  1220. return wrapped
  1221. return wrapper
  1222. @overload
  1223. def tree_map_only(
  1224. type_or_types_or_pred: type[T],
  1225. /,
  1226. func: Fn[T, Any],
  1227. tree: PyTree,
  1228. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1229. ) -> PyTree: ...
  1230. @overload
  1231. def tree_map_only(
  1232. type_or_types_or_pred: Type2[T, S],
  1233. /,
  1234. func: Fn2[T, S, Any],
  1235. tree: PyTree,
  1236. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1237. ) -> PyTree: ...
  1238. @overload
  1239. def tree_map_only(
  1240. type_or_types_or_pred: Type3[T, S, U],
  1241. /,
  1242. func: Fn3[T, S, U, Any],
  1243. tree: PyTree,
  1244. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1245. ) -> PyTree: ...
  1246. @overload
  1247. def tree_map_only(
  1248. type_or_types_or_pred: TypeAny,
  1249. /,
  1250. func: FnAny[Any],
  1251. tree: PyTree,
  1252. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1253. ) -> PyTree: ...
  1254. @overload
  1255. def tree_map_only(
  1256. type_or_types_or_pred: Callable[[Any], bool],
  1257. /,
  1258. func: FnAny[Any],
  1259. tree: PyTree,
  1260. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1261. ) -> PyTree: ...
  1262. def tree_map_only(
  1263. type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
  1264. /,
  1265. func: FnAny[Any],
  1266. tree: PyTree,
  1267. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1268. ) -> PyTree:
  1269. return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
  1270. @overload
  1271. def tree_map_only_(
  1272. type_or_types_or_pred: type[T],
  1273. /,
  1274. func: Fn[T, Any],
  1275. tree: PyTree,
  1276. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1277. ) -> PyTree: ...
  1278. @overload
  1279. def tree_map_only_(
  1280. type_or_types_or_pred: Type2[T, S],
  1281. /,
  1282. func: Fn2[T, S, Any],
  1283. tree: PyTree,
  1284. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1285. ) -> PyTree: ...
  1286. @overload
  1287. def tree_map_only_(
  1288. type_or_types_or_pred: Type3[T, S, U],
  1289. /,
  1290. func: Fn3[T, S, U, Any],
  1291. tree: PyTree,
  1292. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1293. ) -> PyTree: ...
  1294. @overload
  1295. def tree_map_only_(
  1296. type_or_types_or_pred: TypeAny,
  1297. /,
  1298. func: FnAny[Any],
  1299. tree: PyTree,
  1300. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1301. ) -> PyTree: ...
  1302. @overload
  1303. def tree_map_only_(
  1304. type_or_types_or_pred: Callable[[Any], bool],
  1305. /,
  1306. func: FnAny[Any],
  1307. tree: PyTree,
  1308. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1309. ) -> PyTree: ...
  1310. def tree_map_only_(
  1311. type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
  1312. /,
  1313. func: FnAny[Any],
  1314. tree: PyTree,
  1315. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1316. ) -> PyTree:
  1317. return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
  1318. def tree_all(
  1319. pred: Callable[[Any], bool],
  1320. tree: PyTree,
  1321. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1322. ) -> bool:
  1323. flat_args = tree_iter(tree, is_leaf=is_leaf)
  1324. return all(map(pred, flat_args))
  1325. def tree_any(
  1326. pred: Callable[[Any], bool],
  1327. tree: PyTree,
  1328. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1329. ) -> bool:
  1330. flat_args = tree_iter(tree, is_leaf=is_leaf)
  1331. return any(map(pred, flat_args))
  1332. @overload
  1333. def tree_all_only(
  1334. type_or_types: type[T],
  1335. /,
  1336. pred: Fn[T, bool],
  1337. tree: PyTree,
  1338. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1339. ) -> bool: ...
  1340. @overload
  1341. def tree_all_only(
  1342. type_or_types: Type2[T, S],
  1343. /,
  1344. pred: Fn2[T, S, bool],
  1345. tree: PyTree,
  1346. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1347. ) -> bool: ...
  1348. @overload
  1349. def tree_all_only(
  1350. type_or_types: Type3[T, S, U],
  1351. /,
  1352. pred: Fn3[T, S, U, bool],
  1353. tree: PyTree,
  1354. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1355. ) -> bool: ...
  1356. def tree_all_only(
  1357. type_or_types: TypeAny,
  1358. /,
  1359. pred: FnAny[bool],
  1360. tree: PyTree,
  1361. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1362. ) -> bool:
  1363. flat_args = tree_iter(tree, is_leaf=is_leaf)
  1364. return all(pred(x) for x in flat_args if isinstance(x, type_or_types))
  1365. @overload
  1366. def tree_any_only(
  1367. type_or_types: type[T],
  1368. /,
  1369. pred: Fn[T, bool],
  1370. tree: PyTree,
  1371. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1372. ) -> bool: ...
  1373. @overload
  1374. def tree_any_only(
  1375. type_or_types: Type2[T, S],
  1376. /,
  1377. pred: Fn2[T, S, bool],
  1378. tree: PyTree,
  1379. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1380. ) -> bool: ...
  1381. @overload
  1382. def tree_any_only(
  1383. type_or_types: Type3[T, S, U],
  1384. /,
  1385. pred: Fn3[T, S, U, bool],
  1386. tree: PyTree,
  1387. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1388. ) -> bool: ...
  1389. def tree_any_only(
  1390. type_or_types: TypeAny,
  1391. /,
  1392. pred: FnAny[bool],
  1393. tree: PyTree,
  1394. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1395. ) -> bool:
  1396. flat_args = tree_iter(tree, is_leaf=is_leaf)
  1397. return any(pred(x) for x in flat_args if isinstance(x, type_or_types))
  1398. # Broadcasts a pytree to the provided TreeSpec and returns the flattened
  1399. # values. If this is not possible, then this function returns None.
  1400. #
  1401. # For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
  1402. # would return [0, 0]. This is useful for part of the vmap implementation:
  1403. # a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
  1404. # broadcastable to the tree structure of `inputs` and we use
  1405. # _broadcast_to_and_flatten to check this.
  1406. def _broadcast_to_and_flatten(
  1407. tree: PyTree,
  1408. treespec: TreeSpec,
  1409. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1410. ) -> Optional[list[Any]]:
  1411. assert isinstance(treespec, TreeSpec)
  1412. if tree_is_leaf(tree, is_leaf=is_leaf):
  1413. return [tree] * treespec.num_leaves
  1414. if treespec.is_leaf():
  1415. return None
  1416. node_type = _get_node_type(tree)
  1417. if node_type != treespec.type:
  1418. return None
  1419. flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
  1420. child_pytrees, ctx = flatten_fn(tree)
  1421. # Check if the Node is different from the spec
  1422. if len(child_pytrees) != treespec.num_children or ctx != treespec.context:
  1423. return None
  1424. # Recursively flatten the children
  1425. result: list[Any] = []
  1426. for child, child_spec in zip(child_pytrees, treespec.children_specs):
  1427. flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
  1428. if flat is not None:
  1429. result += flat
  1430. else:
  1431. return None
  1432. return result
@dataclasses.dataclass
class _TreeSpecSchema:
    """
    _TreeSpecSchema is the schema used to serialize the TreeSpec.

    One instance describes one node of a spec, and ``children_spec`` nests the
    schemas of its children. ``treespec_dumps`` converts the nested schema to
    plain dicts with ``dataclasses.asdict`` before JSON-encoding it.

    It contains the following fields:
    - type: The registered serialized type name of the node. null for the case
      of a LeafSpec.
    - context: The node's context in a JSON-dumpable form. By default this is
      the ``json.dumps`` string of the context; it is the result of the node
      type's ``to_dumpable_context`` when one is registered. null for a LeafSpec.
    - children_spec: A list of children serialized specs (empty for a LeafSpec).
    """

    # The field names double as the JSON keys that _json_to_treespec reads back.
    # Renaming them changes the serialized format.
    type: Optional[str]
    context: DumpableContext
    children_spec: list["_TreeSpecSchema"]
class _ProtocolFn(NamedTuple):
    """Serializer/deserializer pair that implements one treespec serialization protocol."""

    # Converts a TreeSpec to a schema object. treespec_dumps passes the result
    # to dataclasses.asdict, so in practice it must return a dataclass instance.
    treespec_to_json: Callable[[TreeSpec], DumpableContext]
    # Converts the JSON-decoded schema (as produced inside treespec_loads)
    # back into a TreeSpec.
    json_to_treespec: Callable[[DumpableContext], TreeSpec]


# Registry of serialization protocols, keyed by protocol number.
# treespec_dumps writes the protocol number next to the schema, and
# treespec_loads uses it to select the matching json_to_treespec.
_SUPPORTED_PROTOCOLS: dict[int, _ProtocolFn] = {}
  1449. def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
  1450. if treespec.is_leaf():
  1451. return _TreeSpecSchema(None, None, [])
  1452. if treespec.type not in SUPPORTED_SERIALIZED_TYPES:
  1453. raise NotImplementedError(
  1454. f"Serializing {treespec.type} in pytree is not registered.",
  1455. )
  1456. serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type]
  1457. serialized_type_name = serialize_node_def.serialized_type_name
  1458. if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND:
  1459. raise NotImplementedError(
  1460. f"No registered serialization name for {treespec.type} found. "
  1461. "Please update your _register_pytree_node call with a `serialized_type_name` kwarg."
  1462. )
  1463. if serialize_node_def.to_dumpable_context is None:
  1464. try:
  1465. serialized_context = json.dumps(treespec.context, cls=EnumEncoder)
  1466. except TypeError as e:
  1467. raise TypeError(
  1468. "Unable to serialize context. "
  1469. "Please make the context json dump-able, or register a "
  1470. "custom serializer using _register_pytree_node."
  1471. ) from e
  1472. else:
  1473. serialized_context = serialize_node_def.to_dumpable_context(treespec.context)
  1474. child_schemas = [_treespec_to_json(child) for child in treespec.children_specs]
  1475. return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)
  1476. def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]:
  1477. if "__enum__" in obj:
  1478. modname, _, classname = obj["fqn"].partition(":")
  1479. mod = importlib.import_module(modname)
  1480. enum_cls = mod
  1481. for attr in classname.split("."):
  1482. enum_cls = getattr(enum_cls, attr)
  1483. enum_cls = cast(type[Enum], enum_cls)
  1484. return enum_cls[obj["name"]]
  1485. return obj
  1486. def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
  1487. if (
  1488. json_schema["type"] is None
  1489. and json_schema["context"] is None
  1490. and len(json_schema["children_spec"]) == 0
  1491. ):
  1492. return _LEAF_SPEC
  1493. if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
  1494. raise NotImplementedError(
  1495. f"Deserializing {json_schema['type']} in pytree is not registered.",
  1496. )
  1497. typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]]
  1498. serialize_node_def = SUPPORTED_SERIALIZED_TYPES[typ]
  1499. if serialize_node_def.from_dumpable_context is None:
  1500. try:
  1501. context = json.loads(json_schema["context"], object_hook=enum_object_hook)
  1502. except TypeError as ex:
  1503. raise TypeError(
  1504. "Unable to deserialize context. "
  1505. "Please make the context json load-able, or register a "
  1506. "custom serializer using _register_pytree_node.",
  1507. ) from ex
  1508. else:
  1509. context = serialize_node_def.from_dumpable_context(json_schema["context"])
  1510. children_specs = [
  1511. _json_to_treespec(child_string) for child_string in json_schema["children_spec"]
  1512. ]
  1513. return TreeSpec(typ, context, children_specs)
# Register protocol 1: the _TreeSpecSchema-based JSON encoding implemented by
# _treespec_to_json and _json_to_treespec.
_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec)
  1515. def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
  1516. if not isinstance(treespec, TreeSpec):
  1517. raise TypeError(
  1518. f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of "
  1519. f"TreeSpec but got item of type {type(treespec)}.",
  1520. )
  1521. if protocol is None:
  1522. protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL
  1523. if protocol in _SUPPORTED_PROTOCOLS:
  1524. json_spec = _SUPPORTED_PROTOCOLS[protocol].treespec_to_json(treespec)
  1525. else:
  1526. raise ValueError(
  1527. f"Unknown protocol {protocol}. "
  1528. f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}",
  1529. )
  1530. str_spec = json.dumps((protocol, dataclasses.asdict(json_spec)), cls=EnumEncoder)
  1531. return str_spec
  1532. @functools.lru_cache
  1533. def treespec_loads(serialized: str) -> TreeSpec:
  1534. protocol, json_schema = json.loads(serialized)
  1535. if protocol in _SUPPORTED_PROTOCOLS:
  1536. return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema)
  1537. raise ValueError(
  1538. f"Unknown protocol {protocol}. "
  1539. f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}",
  1540. )
class _DummyLeaf:
    """Placeholder leaf whose repr is ``*``.

    ``treespec_pprint`` uses it to render a TreeSpec's shape without real leaf
    values.
    """

    def __repr__(self) -> str:
        return "*"
  1544. def treespec_pprint(treespec: TreeSpec) -> str:
  1545. dummy_tree = tree_unflatten(
  1546. [_DummyLeaf() for _ in range(treespec.num_leaves)],
  1547. treespec,
  1548. )
  1549. return repr(dummy_tree)
# TODO(angelayi): remove this function after OSS/internal stabilize
@deprecated(
    "`pytree_to_str` is deprecated. Please use `treespec_dumps` instead.",
    category=FutureWarning,
)
def pytree_to_str(treespec: TreeSpec) -> str:
    """Deprecated alias for ``treespec_dumps(treespec)`` with the default protocol.

    It is marked ``@deprecated`` with the ``FutureWarning`` category; new code
    should call ``treespec_dumps`` directly.
    """
    return treespec_dumps(treespec)
# TODO(angelayi): remove this function after OSS/internal stabilize
@deprecated(
    "`str_to_pytree` is deprecated. Please use `treespec_loads` instead.",
    category=FutureWarning,
)
def str_to_pytree(json: str) -> TreeSpec:
    """Deprecated alias for ``treespec_loads(json)``.

    It is marked ``@deprecated`` with the ``FutureWarning`` category; new code
    should call ``treespec_loads`` directly.
    """
    # NOTE: the parameter name shadows the `json` module inside this function.
    # It is kept as-is because callers may pass it by keyword.
    return treespec_loads(json)
  1564. def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> list[Any]:
  1565. """Get a flat list of arguments to this function
  1566. A slightly faster version of tree_leaves((args, kwargs))
  1567. """
  1568. leaves: list[Any] = []
  1569. for a in args:
  1570. leaves.extend(tree_iter(a))
  1571. for a in kwargs.values():
  1572. leaves.extend(tree_iter(a))
  1573. return leaves
  1574. def tree_flatten_with_path(
  1575. tree: PyTree,
  1576. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1577. ) -> tuple[list[tuple[KeyPath, Any]], TreeSpec]:
  1578. """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path.
  1579. Args:
  1580. tree: a pytree to flatten. If it contains a custom type, that type must be
  1581. registered with an appropriate `tree_flatten_with_path_fn` when registered
  1582. with :func:`register_pytree_node`.
  1583. is_leaf: An extra leaf predicate function that will be called at each
  1584. flattening step. The function should have a single argument with signature
  1585. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  1586. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  1587. leaf or not. If the function is not specified, the default pytree registry will be used.
  1588. Returns:
  1589. A tuple where the first element is a list of (key path, leaf) pairs, and the
  1590. second element is a :class:`TreeSpec` representing the structure of the flattened
  1591. tree.
  1592. """
  1593. _, treespec = tree_flatten(tree, is_leaf)
  1594. return list(_generate_key_paths((), tree, is_leaf)), treespec
  1595. def tree_leaves_with_path(
  1596. tree: PyTree,
  1597. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1598. ) -> list[tuple[KeyPath, Any]]:
  1599. """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
  1600. Args:
  1601. tree: a pytree. If it contains a custom type, that type must be
  1602. registered with an appropriate `tree_flatten_with_path_fn` when registered
  1603. with :func:`register_pytree_node`.
  1604. is_leaf: An extra leaf predicate function that will be called at each
  1605. flattening step. The function should have a single argument with signature
  1606. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  1607. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  1608. leaf or not. If the function is not specified, the default pytree registry will be used.
  1609. Returns:
  1610. A list of (key path, leaf) pairs.
  1611. """
  1612. return list(_generate_key_paths((), tree, is_leaf))
  1613. def _generate_key_paths(
  1614. key_path: KeyPath,
  1615. tree: PyTree,
  1616. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1617. ) -> Iterable[tuple[KeyPath, Any]]:
  1618. if is_leaf and is_leaf(tree):
  1619. yield key_path, tree
  1620. return
  1621. node_type = _get_node_type(tree)
  1622. handler = SUPPORTED_NODES.get(node_type)
  1623. if not handler:
  1624. # This is a leaf
  1625. yield key_path, tree
  1626. return
  1627. flatten_with_keys = handler.flatten_with_keys_fn
  1628. if flatten_with_keys:
  1629. key_children, _ = flatten_with_keys(tree)
  1630. for k, c in key_children:
  1631. yield from _generate_key_paths((*key_path, k), c, is_leaf)
  1632. else:
  1633. # We registered this pytree but didn't add a flatten_with_keys_fn, complain.
  1634. raise ValueError(
  1635. f"Did not find a flatten_with_keys_fn for type: {node_type}. "
  1636. "Please pass a flatten_with_keys_fn argument to register_pytree_node."
  1637. )
  1638. def tree_map_with_path(
  1639. func: Callable[..., Any],
  1640. tree: PyTree,
  1641. *rests: PyTree,
  1642. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  1643. ) -> PyTree:
  1644. """Like :func:`tree_map`, but the provided callable takes an additional key path argument.
  1645. Args:
  1646. func: A function that takes ``2 + len(rests)`` arguments, to be applied at the
  1647. corresponding leaves of the pytrees. The first positional argument
  1648. to ``func`` is the key path of the leaf in question. The second
  1649. positional argument is the value of the leaf.
  1650. tree: A pytree to be mapped over, with each leaf providing the first positional
  1651. argument to function ``func``.
  1652. rests: A tuple of pytrees, each of which has the same structure as
  1653. ``tree`` or has ``tree`` as a prefix.
  1654. is_leaf: An extra leaf predicate function that will be called at each
  1655. flattening step. The function should have a single argument with signature
  1656. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  1657. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  1658. leaf or not. If the function is not specified, the default pytree registry will be used.
  1659. Returns
  1660. A new pytree with the same structure as ``tree`` but with the value at each leaf given by
  1661. ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the
  1662. corresponding leaf in ``tree``, ``x`` is the value at that leaf, and
  1663. ``xs`` is the tuple of values at corresponding nodes in ``rests``.
  1664. """
  1665. keypath_leaves, treespec = tree_flatten_with_path(tree, is_leaf)
  1666. keypath_leaves = list(zip(*keypath_leaves))
  1667. all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
  1668. return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
  1669. def keystr(kp: KeyPath) -> str:
  1670. """Given a key path, return a pretty-printed representation."""
  1671. return "".join([str(k) for k in kp])
  1672. def key_get(obj: Any, kp: KeyPath) -> Any:
  1673. """Given an object and a key path, return the value at the key path."""
  1674. for k in kp:
  1675. obj = k.get(obj)
  1676. return obj