| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068 |
- """
- Contains utility functions for working with nested python data structures.
- A *pytree* is a Python nested data structure. It is a tree in the sense that
- nodes are Python collections (e.g., list, tuple, dict) and the leaves are
- Python values. Furthermore, a pytree should not contain reference cycles.
- pytrees are useful for working with nested collections of Tensors. For example,
- one can use `tree_map` to map a function over all Tensors inside some nested
- collection of Tensors and `tree_leaves` to get a flat list of all Tensors
- inside some nested collection. pytrees are helpful for implementing nested
- collection support for PyTorch APIs.
- This pytree implementation is not very performant due to Python overhead.
- To improve the performance, we can move parts of the implementation to C++.
- """
- import dataclasses
- import functools
- import importlib
- import importlib.metadata
- import json
- import sys
- import threading
- import types
- import warnings
- from collections import defaultdict, deque, namedtuple, OrderedDict
- from collections.abc import Hashable, Iterable, Mapping, Sequence
- from enum import Enum
- from typing import (
- Any,
- Callable,
- cast,
- ClassVar,
- Final,
- Generic,
- NoReturn,
- Optional,
- overload,
- Protocol,
- TypeVar,
- Union,
- )
- from typing_extensions import deprecated, NamedTuple, Self
- from torch.torch_version import TorchVersion as _TorchVersion
# Public API of this module; this list is also what ``from ... import *`` exports.
__all__ = [
    # Type aliases
    "PyTree",
    "Context",
    "FlattenFunc",
    "UnflattenFunc",
    "DumpableContext",
    "ToDumpableContextFn",
    "FromDumpableContextFn",
    # Tree specifications and key paths
    "TreeSpec",
    "LeafSpec",
    "keystr",
    "key_get",
    # Registration and traversal
    "register_pytree_node",
    "tree_is_leaf",
    "tree_flatten",
    "tree_flatten_with_path",
    "tree_unflatten",
    "tree_iter",
    "tree_leaves",
    "tree_leaves_with_path",
    "tree_structure",
    "tree_map",
    "tree_map_with_path",
    "tree_map_",
    "tree_map_only",
    "tree_map_only_",
    "tree_all",
    "tree_any",
    "tree_all_only",
    "tree_any_only",
    # TreeSpec (de)serialization and printing
    "treespec_dumps",
    "treespec_loads",
    "treespec_pprint",
    # Type predicates
    "is_namedtuple",
    "is_namedtuple_class",
    "is_namedtuple_instance",
    "is_structseq",
    "is_structseq_class",
    "is_structseq_instance",
]

# Generic type variables used in annotations throughout the module.
T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
R = TypeVar("R")

# NOTE(review): presumably the default protocol version used by
# `treespec_dumps` (defined outside this view) -- confirm.
DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1
# Sentinel stored as ``serialized_type_name`` for node types registered
# without one (see `_private_register_pytree_node`).
NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND"
class KeyEntry(Protocol):
    """Structural type for one step of a key path into a pytree.

    Implementations (e.g. ``SequenceKey``, ``MappingKey`` and ``GetAttrKey``
    in this module) are hashable, comparable, printable, and can fetch the
    child they address from its parent container.
    """

    def __hash__(self) -> int:
        """Hash of this key entry."""

    def __eq__(self, other: object) -> bool:
        """Whether ``other`` denotes the same key entry."""

    def __str__(self) -> str:
        """Human-readable rendering of this key entry."""

    def get(self, parent: Any) -> Any:
        """Return the child of ``parent`` addressed by this key entry."""
class EnumEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``Enum`` members as tagged dictionaries.

    An enum member is emitted as
    ``{"__enum__": True, "fqn": "<module>:<qualname>", "name": <member name>}``.
    Any other object the base encoder cannot handle is delegated to
    :meth:`json.JSONEncoder.default`, which raises ``TypeError``.
    """

    def default(self, obj: object) -> Union[str, dict[str, Any]]:
        if not isinstance(obj, Enum):
            return cast(str, super().default(obj))
        enum_cls = obj.__class__
        return {
            "__enum__": True,
            "fqn": f"{enum_cls.__module__}:{enum_cls.__qualname__}",
            "name": obj.name,
        }
# Opaque per-node data returned by a flatten function and handed back to the
# matching unflatten function.
Context = Any
# Any (possibly nested) Python structure; typed as ``Any``.
PyTree = Any
# Splits a node into (list of children, context).
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
# Rebuilds a node from its children and the context produced when flattening.
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
DumpableContext = Any  # Any json dumpable text
# Convert a Context to / from its JSON-dumpable representation.
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
# Legacy string hooks; used by the deprecated ``to_str_fn`` /
# ``maybe_from_str_fn`` arguments of ``_register_pytree_node``.
ToStrFunc = Callable[["TreeSpec", list[str]], str]
MaybeFromStrFunc = Callable[[str], Optional[tuple[Any, Context, str]]]
# A tuple of key entries (presumably the path from the root to a node).
KeyPath = tuple[KeyEntry, ...]
# Like FlattenFunc, but each child is paired with the KeyEntry addressing it.
FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]]
# A NodeDef describes how a registered container type is taken apart and
# rebuilt. It holds the registered type plus three callables:
# - flatten_fn takes the collection and returns a flat list of children
#   together with a context used when reconstructing the collection.
# - unflatten_fn takes a flat list of values and the context returned by
#   flatten_fn, and rebuilds the collection from them.
# - flatten_with_keys_fn (optional) is like flatten_fn, but returns a list of
#   (KeyEntry, child) pairs instead of bare children, plus a context.
class NodeDef(NamedTuple):
    type: type[Any]  # the registered container type
    flatten_fn: FlattenFunc
    unflatten_fn: UnflattenFunc
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc]


# Guards the node registries. It is an RLock, so code already holding it may
# call helpers (such as `_private_register_pytree_node`) that acquire it again.
_NODE_REGISTRY_LOCK = threading.RLock()
# Maps each registered container type to its NodeDef.
SUPPORTED_NODES: dict[type[Any], NodeDef] = {}
# _SerializeNodeDef holds the serialization metadata of a registered type:
# - typ: the type of the node (e.g. ``dict``, ``list``)
# - serialized_type_name: the fully qualified name of the type, e.g.
#   "collections.OrderedDict" (NO_SERIALIZED_TYPE_NAME_FOUND if none was given)
# - to_dumpable_context: optional callable converting the node's context into
#   a JSON-dumpable representation
# - from_dumpable_context: optional callable converting that representation
#   back into the original context
# The two context converters are either both None or both set (enforced in
# `_private_register_pytree_node`).
class _SerializeNodeDef(NamedTuple):
    typ: type[Any]
    serialized_type_name: str
    to_dumpable_context: Optional[ToDumpableContextFn]
    from_dumpable_context: Optional[FromDumpableContextFn]


# Maps each registered type to its serialization metadata.
SUPPORTED_SERIALIZED_TYPES: dict[type[Any], _SerializeNodeDef] = {}
# Reverse lookup: serialized type name -> Python type.
SERIALIZED_TYPE_TO_PYTHON_TYPE: dict[str, type[Any]] = {}
# NB: we try really hard to not import _cxx_pytree (which depends on optree)
# as much as possible. This is for isolation: a user who is not using C++ pytree
# shouldn't pay for it, and it helps make things like cpython upgrades easier.
#
# Only the installed optree *version* is read here (from package metadata);
# optree itself is not imported.
_optree_minimum_version = _TorchVersion("0.13.0")
try:
    _optree_version = importlib.metadata.version("optree")
except importlib.metadata.PackageNotFoundError:
    # No optree package found.
    _optree_version = _TorchVersion("0.0.0a0")
    _cxx_pytree_exists = False
else:
    _optree_version = _TorchVersion(_optree_version)
    # An optree older than our required minimum is treated as if it were not
    # installed. NB: We will raise ImportError if the user directly tries to
    # `import torch.utils._cxx_pytree` (look in that file for the check).
    _cxx_pytree_exists = not (_optree_version < _optree_minimum_version)
# The C++ pytree is treated as dynamo-traceable exactly when it is usable.
_cxx_pytree_dynamo_traceable = _cxx_pytree_exists

# Whether `torch.utils._cxx_pytree` has been imported yet, and the
# registrations queued by `register_pytree_node` in the meantime.
_cxx_pytree_imported = False
_cxx_pytree_pending_imports: list[Any] = []
def register_pytree_node(
    cls: type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
    """Register a container-like type as a pytree node.

    The type is always registered with the Python pytree. When a compatible
    ``optree`` is installed it is also registered with the C++ pytree:
    immediately if ``torch.utils._cxx_pytree`` has already been imported,
    otherwise the registration is queued in ``_cxx_pytree_pending_imports``.

    Note:
        :func:`register_dataclass` is a simpler way of registering a
        container-like type as a pytree node.

    Args:
        cls: the type to register.
        flatten_fn: callable mapping an instance of ``cls`` to its flat list of
            children plus a context describing how to rebuild it.
        unflatten_fn: callable that rebuilds an instance of ``cls`` from the
            flat children and the context produced by ``flatten_fn``.
        serialized_type_name: keyword-only; fully qualified name recorded when
            serializing the tree spec.
        to_dumpable_context: keyword-only; optional converter from the context
            to a custom JSON-dumpable representation (used for JSON
            serialization, e.g. by torch.export).
        from_dumpable_context: keyword-only; optional inverse of
            ``to_dumpable_context`` (used for JSON deserialization).
        flatten_with_keys_fn: keyword-only; like ``flatten_fn`` but yields
            ``(keypath, leaf)`` pairs instead of bare leaves, so each leaf's
            key path is available when flattening and tree-mapping.

    Raises:
        ValueError: if ``cls`` is already registered as a pytree node, or (from
            `_private_register_pytree_node`) if only one of
            ``to_dumpable_context`` / ``from_dumpable_context`` is given.
    """
    with _NODE_REGISTRY_LOCK:
        if cls in SUPPORTED_NODES:
            raise ValueError(f"{cls} is already registered as pytree node.")
        _private_register_pytree_node(
            cls,
            flatten_fn,
            unflatten_fn,
            serialized_type_name=serialized_type_name,
            to_dumpable_context=to_dumpable_context,
            from_dumpable_context=from_dumpable_context,
            flatten_with_keys_fn=flatten_with_keys_fn,
        )

    if not _cxx_pytree_exists:
        return

    # ``flatten_with_keys_fn`` is not forwarded to the C++ pytree.
    cxx_args = (cls, flatten_fn, unflatten_fn)
    cxx_kwargs = {
        "serialized_type_name": serialized_type_name,
        "to_dumpable_context": to_dumpable_context,
        "from_dumpable_context": from_dumpable_context,
    }
    if not _cxx_pytree_imported:
        # Avoid importing _cxx_pytree eagerly; queue the registration instead.
        _cxx_pytree_pending_imports.append((cxx_args, cxx_kwargs))
        return

    from . import _cxx_pytree as cxx

    cxx._private_register_pytree_node(*cxx_args, **cxx_kwargs)
def register_dataclass(
    cls: type[Any],
    *,
    field_names: Optional[list[str]] = None,
    drop_field_names: Optional[list[str]] = None,
    serialized_type_name: Optional[str] = None,
) -> None:
    """
    Registers a type that has the semantics of a ``dataclasses.dataclass`` type
    as a pytree node.

    This is a simpler API than :func:`register_pytree_node` for registering
    a dataclass or a custom class with the semantics of a dataclass.

    Fields whose value is ``None`` are not treated as children: their names are
    recorded in the context and they are passed back as ``None`` on unflatten.

    Args:
        cls: The python type to register. The class must have the semantics of a
            dataclass; in particular, it must be constructed by passing the fields
            in.
        field_names (Optional[List[str]]): A list of field names that correspond
            to the **non-constant data** in this class. This list must contain
            all the fields that are used to initialize the class. This argument
            is optional if ``cls`` is a dataclass, in which case the fields will
            be taken from ``dataclasses.fields()`` (``init=True`` fields only,
            excluding any listed in ``drop_field_names``).
        drop_field_names (Optional[List[str]]): A list of field names that
            should not be included in the pytree. Only consulted when ``cls`` is
            a dataclass. Dropped fields are not passed to ``cls`` on unflatten,
            so they need a default value.
        serialized_type_name: A keyword argument used to specify the fully
            qualified name used when serializing the tree spec. This is only
            needed for serializing the treespec in torch.export.

    Raises:
        ValueError: if ``cls`` is not a dataclass and ``field_names`` is not
            given, or if ``field_names`` does not match the dataclass's init
            fields minus ``drop_field_names``.

    Example:
        >>> import torch
        >>> from torch import Tensor
        >>> from dataclasses import dataclass
        >>> import torch.utils._pytree as pytree
        >>> @dataclass
        ... class Point:
        ...     x: Tensor
        ...     y: Tensor
        >>> pytree.register_dataclass(Point)
        >>> point = Point(torch.tensor(0), torch.tensor(1))
        >>> point = pytree.tree_map(lambda x: x + 1, point)
        >>> assert torch.allclose(point.x, torch.tensor(1))
        >>> assert torch.allclose(point.y, torch.tensor(2))
    """
    drop_field_names = drop_field_names or []

    if not dataclasses.is_dataclass(cls):
        if field_names is None:
            raise ValueError(
                "field_names must be specified with a list of all fields used to "
                f"initialize {cls}, as it is not a dataclass."
            )
    elif field_names is None:
        # Honor drop_field_names here too; previously it was silently ignored
        # whenever field_names was inferred from the dataclass.
        field_names = [
            f.name
            for f in dataclasses.fields(cls)
            if f.init and f.name not in drop_field_names
        ]
    else:
        dataclass_init_fields = {f.name for f in dataclasses.fields(cls) if f.init}
        dataclass_init_fields.difference_update(drop_field_names)

        if dataclass_init_fields != set(field_names):
            error_msg = "field_names does not include all dataclass fields.\n"

            if missing := dataclass_init_fields - set(field_names):
                error_msg += (
                    f"Missing fields in `field_names`: {missing}. If you want "
                    "to include these fields in the pytree, please add them "
                    "to `field_names`, otherwise please add them to "
                    "`drop_field_names`.\n"
                )

            if unexpected := set(field_names) - dataclass_init_fields:
                error_msg += (
                    f"Unexpected fields in `field_names`: {unexpected}. "
                    "Please remove these fields, or add them to `drop_field_names`.\n"
                )

            raise ValueError(error_msg)

    # Snapshot the names so later mutation of the caller's list cannot change
    # how instances of ``cls`` are flattened.
    field_names = list(field_names)

    def _flatten_fn(obj: Any) -> tuple[list[Any], Context]:
        # None-valued fields are not children; remember their names so they
        # can be restored as None on unflatten.
        flattened = []
        flat_names = []
        none_names = []
        for name in field_names:
            val = getattr(obj, name)
            if val is not None:
                flattened.append(val)
                flat_names.append(name)
            else:
                none_names.append(name)
        return flattened, [flat_names, none_names]

    def _unflatten_fn(values: Iterable[Any], context: Context) -> Any:
        flat_names, none_names = context
        return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))

    def _flatten_fn_with_keys(
        obj: Any,
    ) -> tuple[list[tuple[KeyEntry, Any]], Context]:
        # Return the same context as _flatten_fn, like the other
        # *_flatten_with_keys helpers; previously only ``flat_names`` was
        # returned, which _unflatten_fn cannot consume.
        flattened, context = _flatten_fn(obj)
        flat_names, _none_names = context
        return [(GetAttrKey(k), v) for k, v in zip(flat_names, flattened)], context

    _private_register_pytree_node(
        cls,
        _flatten_fn,
        _unflatten_fn,
        serialized_type_name=serialized_type_name,
        flatten_with_keys_fn=_flatten_fn_with_keys,
    )
# Types registered through `register_constant`.
CONSTANT_NODES: set[type] = set()


def register_constant(cls: type[Any]) -> None:
    """Registers a type as a pytree node with no leaves.

    In a :func:`torch.compile` region, if instances of these types get passed to
    :func:`torch._dynamo.nonstrict_trace`-ed function, they are treated as a
    constant (sometimes referred to as "static"):

    1. if the instance object existed before the :func:`torch.compile` region,
       we _assume_ no mutation will happen to it inside the :func:`torch.compile`
       region, require that it has non-default `__eq__` and `__hash__` methods, and
       we guard on the instance based on its `__eq__` method, i.e., if a new
       instance fails to match any instances from the previous compilations,
       :func:`torch.compile` will recompile the function using the new instance.

    2. else if the instance object is created inside the :func:`torch.compile`
       region, we currently don't support using it in a
       :func:`torch._dynamo.nonstrict_trace`-ed function.

    In general, if your class holds Tensors or dynamic int/float/bool (values that
    may change from run-to-run of a function being compiled), then you probably
    do not want to register it as a constant.

    Otherwise, if you want to pass an instance of a class to a
    :func:`torch._dynamo.nonstrict_trace`-ed function, but you either can't use
    :func:`register_pytree_node` on the class, or the class is "constant" enough
    that you don't want to bother using :func:`register_pytree_node`, you should
    consider using this function.

    Args:
        cls: the type to register as a constant. This type must be hashable.

    Raises:
        TypeError: if ``cls`` uses the default ``object.__eq__`` or has no
            ``__hash__``.

    Example:
        >>> from dataclasses import dataclass
        >>> import torch.utils._pytree as pytree
        >>> @dataclass(frozen=True)
        ... class Config:
        ...     norm: str
        >>> pytree.register_constant(Config)
        >>> config = Config("l2")
        >>> values, spec = pytree.tree_flatten(config)
        >>> assert len(values) == 0
    """
    has_default_eq = cls.__eq__ is object.__eq__  # type: ignore[comparison-overlap]
    if has_default_eq:
        raise TypeError(
            "register_constant(cls) expects `cls` to have a non-default `__eq__` implementation."
        )

    # A class defining `__eq__` without `__hash__` gets `__hash__ = None`
    # rather than inheriting `object.__hash__`; see https://stackoverflow.com/a/1608907.
    if cls.__hash__ is None:  # type: ignore[comparison-overlap]
        raise TypeError(
            "register_constant(cls) expects `cls` to have a non-default `__hash__` implementation."
        )

    def _flatten(x):  # type: ignore[no-untyped-def]
        # No children: the whole instance travels in the context.
        return [], ConstantNode(x)

    def _unflatten(_, context):  # type: ignore[no-untyped-def]
        return context.value

    with _NODE_REGISTRY_LOCK:
        _private_register_pytree_node(
            cls,
            _flatten,
            _unflatten,
            # With no children there is nothing to key, so keyed flattening is
            # the same as plain flattening.
            flatten_with_keys_fn=_flatten,
        )
        CONSTANT_NODES.add(cls)


def is_constant_class(cls: type[Any]) -> bool:
    """Return whether ``cls`` is a type currently registered via `register_constant`."""
    if not isinstance(cls, type):
        return False
    return cls in CONSTANT_NODES
@dataclasses.dataclass(frozen=True)
class ConstantNode:
    """Context object produced when flattening a `register_constant` type.

    It wraps the original instance so the instance itself can be handed back
    on unflatten.
    """

    value: Any  # the instance that was flattened


def _is_constant_holder(spec: "TreeSpec") -> bool:
    """Checks if the spec is from a pytree registered with register_constant,
    i.e. whether its context is a ``ConstantNode``."""
    context = spec.context
    return isinstance(context, ConstantNode)


def _retrieve_constant(spec: "TreeSpec") -> Any:
    """Given a spec from a pytree registered with register_constant, retrieves the constant.

    Delegates to ``tree_unflatten`` with no leaves; the unflatten function
    installed by `register_constant` returns ``context.value``.
    """
    assert _is_constant_holder(spec)
    no_leaves: list[Any] = []
    return tree_unflatten(no_leaves, spec)
def _register_namedtuple(
    cls: type[Any],
    *,
    serialized_type_name: str,
) -> None:
    """
    Registers a namedtuple as a valid pytree node. By default namedtuples are
    valid pytree nodes, but they are not serializable. This API provides the
    argument `serialized_type_name` which allows these namedtuples to be
    serialized.

    Args:
        cls: the namedtuple type to register.
        serialized_type_name: The serialized name for the namedtuple. This is
            required if you want to serialize the pytree TreeSpec containing this
            namedtuple.
    """
    namedtuple_hooks = {
        "to_dumpable_context": _namedtuple_serialize,
        "from_dumpable_context": _namedtuple_deserialize,
        "flatten_with_keys_fn": _namedtuple_flatten_with_keys,
    }
    _private_register_pytree_node(
        cls,
        _namedtuple_flatten,
        _namedtuple_unflatten,
        serialized_type_name=serialized_type_name,
        **namedtuple_hooks,
    )
@deprecated(
    "`torch.utils._pytree._register_pytree_node` is deprecated. "
    "Please use `torch.utils._pytree.register_pytree_node` instead.",
    category=FutureWarning,
)
def _register_pytree_node(
    cls: type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    to_str_fn: Optional[ToStrFunc] = None,  # deprecated
    maybe_from_str_fn: Optional[MaybeFromStrFunc] = None,  # deprecated
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
    """Register a container-like type as a pytree node for the Python pytree only.

    Deprecated: use :func:`register_pytree_node` instead.

    Args:
        cls: the type to register.
        flatten_fn: callable mapping an instance of ``cls`` to its flat list of
            children plus a context describing how to rebuild it.
        unflatten_fn: callable that rebuilds an instance of ``cls`` from the
            flat children and the context produced by ``flatten_fn``.
        to_str_fn: deprecated and ignored apart from emitting a FutureWarning.
        maybe_from_str_fn: deprecated and ignored apart from emitting a
            FutureWarning.
        serialized_type_name: keyword-only; fully qualified name recorded when
            serializing the tree spec.
        to_dumpable_context: keyword-only; optional converter from the context
            to a custom JSON-dumpable representation (used for JSON
            serialization, e.g. by torch.export).
        from_dumpable_context: keyword-only; optional inverse of
            ``to_dumpable_context`` (used for JSON deserialization).
        flatten_with_keys_fn: keyword-only; like ``flatten_fn`` but yields
            ``(keypath, leaf)`` pairs instead of bare leaves.
    """
    uses_legacy_str_hooks = to_str_fn is not None or maybe_from_str_fn is not None
    if uses_legacy_str_hooks:
        warnings.warn(
            "`to_str_fn` and `maybe_from_str_fn` is deprecated. "
            "Please use `to_dumpable_context` and `from_dumpable_context` instead.",
            FutureWarning,
            stacklevel=2,
        )

    # The legacy string hooks are deliberately not forwarded.
    _private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
        flatten_with_keys_fn=flatten_with_keys_fn,
    )
def _deregister_pytree_node(
    cls: type[Any],
) -> None:
    """This is an internal function that is used to deregister a pytree node type
    for the Python pytree only. This should be only used inside PyTorch.

    Raises:
        KeyError: if ``cls`` is not currently registered as a pytree node.
    """
    with _NODE_REGISTRY_LOCK:
        # Raises KeyError for unregistered types, as before.
        del SUPPORTED_NODES[cls]
        serialize_node_def = SUPPORTED_SERIALIZED_TYPES.pop(cls, None)
        if serialize_node_def is not None:
            name = serialize_node_def.serialized_type_name
            # Several types can share one name: every type registered without a
            # name maps to NO_SERIALIZED_TYPE_NAME_FOUND, and re-registration
            # can reuse a name. Only drop the reverse entry if it still points
            # at ``cls``. Deleting it unconditionally raised KeyError on the
            # second such deregistration, after SUPPORTED_NODES had already
            # been modified, and could unmap a different type.
            if SERIALIZED_TYPE_TO_PYTHON_TYPE.get(name) is cls:
                del SERIALIZED_TYPE_TO_PYTHON_TYPE[name]
        CONSTANT_NODES.discard(cls)
def _private_register_pytree_node(
    cls: type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
    """This is an internal function that is used to register a pytree node type
    for the Python pytree only. End-users should use :func:`register_pytree_node`
    instead.

    Re-registering an already registered type emits a warning and overwrites
    the previous registration.

    Raises:
        ValueError: if exactly one of ``to_dumpable_context`` and
            ``from_dumpable_context`` is given. The check runs before any
            registry is touched, so a rejected call leaves no partial
            registration behind.
    """
    # Validate before mutating anything: previously SUPPORTED_NODES was updated
    # first, so this error left ``cls`` flattenable but absent from
    # SUPPORTED_SERIALIZED_TYPES.
    if (to_dumpable_context is None) ^ (from_dumpable_context is None):
        raise ValueError(
            f"Both to_dumpable_context and from_dumpable_context for {cls} must "
            "be None or registered."
        )

    if serialized_type_name is None:
        serialized_type_name = NO_SERIALIZED_TYPE_NAME_FOUND

    node_def = NodeDef(cls, flatten_fn, unflatten_fn, flatten_with_keys_fn)
    serialize_node_def = _SerializeNodeDef(
        cls,
        serialized_type_name,
        to_dumpable_context,
        from_dumpable_context,
    )

    with _NODE_REGISTRY_LOCK:
        if cls in SUPPORTED_NODES:
            # TODO: change this warning to an error after OSS/internal stabilize
            warnings.warn(
                f"{cls} is already registered as pytree node. "
                "Overwriting the previous registration.",
            )
        SUPPORTED_NODES[cls] = node_def
        SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def
        SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls
@dataclasses.dataclass(frozen=True)
class SequenceKey(Generic[T]):
    """Key entry addressing an element of a sequence by integer index."""

    idx: int

    def __str__(self) -> str:
        return "[{!r}]".format(self.idx)

    def get(self, sequence: Sequence[T]) -> T:
        """Return ``sequence[idx]``."""
        item = sequence[self.idx]
        return item


K = TypeVar("K", bound=Hashable)


@dataclasses.dataclass(frozen=True)
class MappingKey(Generic[K, T]):
    """Key entry addressing a value of a mapping by its key."""

    key: K

    def __str__(self) -> str:
        return "[{!r}]".format(self.key)

    def get(self, mapping: Mapping[K, T]) -> T:
        """Return ``mapping[key]``."""
        item = mapping[self.key]
        return item
@dataclasses.dataclass(frozen=True)
class GetAttrKey:
    """Key entry addressing an attribute of an object by name."""

    name: str

    def __str__(self) -> str:
        return ".{}".format(self.name)

    def get(self, obj: Any) -> Any:
        """Return ``getattr(obj, name)``."""
        attr_name = self.name
        return getattr(obj, attr_name)
# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
def is_namedtuple(obj: Union[object, type]) -> bool:
    """Return whether the object is an instance of namedtuple or a subclass of namedtuple."""
    if isinstance(obj, type):
        return is_namedtuple_class(obj)
    return is_namedtuple_class(type(obj))


# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
def is_namedtuple_class(cls: type) -> bool:
    """Return whether the class is a subclass of namedtuple.

    The check is structural: a ``tuple`` subclass whose ``_fields`` is a tuple
    of plain ``str`` and which has callable ``_make`` and ``_asdict``.
    """
    if not (isinstance(cls, type) and issubclass(cls, tuple)):
        return False
    fields = getattr(cls, "_fields", None)
    if not isinstance(fields, tuple):
        return False
    if any(type(field) is not str for field in fields):
        return False
    has_make = callable(getattr(cls, "_make", None))
    return has_make and callable(getattr(cls, "_asdict", None))


# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
def is_namedtuple_instance(obj: object) -> bool:
    """Return whether the object is an instance of namedtuple."""
    obj_type = type(obj)
    return is_namedtuple_class(obj_type)
# Covariant element type for the `structseq` stub below.
_T_co = TypeVar("_T_co", covariant=True)


# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
class structseq(tuple[_T_co, ...]):
    """A generic type stub for CPython's ``PyStructSequence`` type.

    It only declares the attributes real struct sequences carry (the same
    ``n_*`` counters that `is_structseq_class` checks). It cannot be
    subclassed or constructed.
    """

    __slots__: ClassVar[tuple[()]] = ()  # no per-instance attributes

    n_fields: Final[int]  # type: ignore[misc]
    n_sequence_fields: Final[int]  # type: ignore[misc]
    n_unnamed_fields: Final[int]  # type: ignore[misc]

    def __init_subclass__(cls) -> NoReturn:
        """Prohibit subclassing."""
        raise TypeError("type 'structseq' is not an acceptable base type")

    def __new__(
        cls: type[Self],
        sequence: Iterable[_T_co],
        dict: dict[str, Any] = ...,
    ) -> Self:
        # Always raises: this typing stub cannot be instantiated.
        raise NotImplementedError
# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
def is_structseq(obj: Union[object, type]) -> bool:
    """Return whether the object is an instance of PyStructSequence or a class of PyStructSequence."""
    if not isinstance(obj, type):
        obj = type(obj)
    return is_structseq_class(obj)


# Set if the type allows subclassing (see CPython's Include/object.h)
Py_TPFLAGS_BASETYPE: int = 1 << 10


# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
def is_structseq_class(cls: type) -> bool:
    """Return whether the class is a class of PyStructSequence."""
    if not isinstance(cls, type):
        return False
    # Require direct inheritance from `tuple`, not merely `issubclass(cls, tuple)`.
    if cls.__bases__ != (tuple,):
        return False
    # PyStructSequence types expose these integer counters.
    counters = ("n_fields", "n_sequence_fields", "n_unnamed_fields")
    if not all(isinstance(getattr(cls, attr, None), int) for attr in counters):
        return False
    # The type must not allow subclassing (this flag check only works for CPython).
    return (cls.__flags__ & Py_TPFLAGS_BASETYPE) == 0


# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
def is_structseq_instance(obj: object) -> bool:
    """Return whether the object is an instance of PyStructSequence."""
    obj_type = type(obj)
    return is_structseq_class(obj_type)
def _tuple_flatten(d: tuple[T, ...]) -> tuple[list[T], Context]:
    """Flatten a tuple into a list of its elements; no context is needed."""
    children = list(d)
    return children, None


def _tuple_flatten_with_keys(
    d: tuple[T, ...],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
    """Like `_tuple_flatten`, pairing each element with its `SequenceKey`."""
    children, context = _tuple_flatten(d)
    keyed = [(SequenceKey(pos), child) for pos, child in enumerate(children)]
    return keyed, context


def _tuple_unflatten(values: Iterable[T], context: Context) -> tuple[T, ...]:
    """Rebuild a tuple from its elements; ``context`` is unused."""
    return tuple(values)


def _list_flatten(d: list[T]) -> tuple[list[T], Context]:
    """Flatten a list; the input list itself (not a copy) is returned as the children."""
    children = d
    return children, None


def _list_flatten_with_keys(d: list[T]) -> tuple[list[tuple[KeyEntry, T]], Context]:
    """Like `_list_flatten`, pairing each element with its `SequenceKey`."""
    children, context = _list_flatten(d)
    keyed = [(SequenceKey(pos), child) for pos, child in enumerate(children)]
    return keyed, context


def _list_unflatten(values: Iterable[T], context: Context) -> list[T]:
    """Rebuild a list from its elements; ``context`` is unused."""
    return list(values)


def _dict_flatten(d: dict[Any, T]) -> tuple[list[T], Context]:
    """Flatten a dict into its values; the context is the list of its keys."""
    values = list(d.values())
    keys = list(d.keys())
    return values, keys


def _dict_flatten_with_keys(
    d: dict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
    """Like `_dict_flatten`, pairing each value with its `MappingKey`."""
    values, keys = _dict_flatten(d)
    keyed = [(MappingKey(key), value) for key, value in zip(keys, values)]
    return keyed, keys


def _dict_unflatten(values: Iterable[T], context: Context) -> dict[Any, T]:
    """Rebuild a dict by pairing the keys in ``context`` with ``values``."""
    return {key: value for key, value in zip(context, values)}
- def _namedtuple_flatten(d: NamedTuple) -> tuple[list[Any], Context]:
- return list(d), type(d)
- def _namedtuple_flatten_with_keys(
- d: NamedTuple,
- ) -> tuple[list[tuple[KeyEntry, Any]], Context]:
- values, context = _namedtuple_flatten(d)
- return (
- [(GetAttrKey(field), v) for field, v in zip(context._fields, values)],
- context,
- )
- def _namedtuple_unflatten(values: Iterable[T], context: Context) -> NamedTuple:
- return cast(NamedTuple, context(*values))
def _namedtuple_serialize(context: Context) -> DumpableContext:
    """Serialize a namedtuple context (the namedtuple class) to its registered type name.

    Args:
        context: the namedtuple class recorded by ``_namedtuple_flatten``.

    Returns:
        The ``serialized_type_name`` registered for that class.

    Raises:
        NotImplementedError: if the class was never registered for serialization, or
            was registered without a ``serialized_type_name``.
    """
    if context not in SUPPORTED_SERIALIZED_TYPES:
        raise NotImplementedError(
            f"Can't serialize TreeSpec of namedtuple class {context} because we "
            "didn't register a serialized_type_name. Please register using "
            "`_register_namedtuple`."
        )

    serialize_node_def = SUPPORTED_SERIALIZED_TYPES[context]
    serialized_type_name = serialize_node_def.serialized_type_name

    # Registered, but without an explicit name: the sentinel cannot be round-tripped.
    if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND:
        raise NotImplementedError(
            f"Can't serialize TreeSpec of namedtuple class {context} because we "
            "couldn't find a serialized_type_name. Please register using "
            "`_register_namedtuple`."
        )
    return serialized_type_name
def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context:
    """Map a serialized namedtuple type name back to the registered namedtuple class.

    Args:
        dumpable_context: the type name produced by ``_namedtuple_serialize``.

    Returns:
        The namedtuple class registered under that name.

    Raises:
        NotImplementedError: if no class is registered under ``dumpable_context``.
    """
    if dumpable_context not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
        raise NotImplementedError(
            f"Can't deserialize TreeSpec of namedtuple class {dumpable_context} "
            "because we couldn't find a class registered under that serialized name."
        )
    typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[dumpable_context]
    return typ
- def _ordereddict_flatten(d: OrderedDict[Any, T]) -> tuple[list[T], Context]:
- return list(d.values()), list(d.keys())
- def _ordereddict_flatten_with_keys(
- d: OrderedDict[Any, T],
- ) -> tuple[list[tuple[KeyEntry, T]], Context]:
- values, context = _ordereddict_flatten(d)
- return [(MappingKey(k), v) for k, v in zip(context, values)], context
- def _ordereddict_unflatten(
- values: Iterable[T],
- context: Context,
- ) -> OrderedDict[Any, T]:
- return OrderedDict((key, value) for key, value in zip(context, values))
- _odict_flatten = _ordereddict_flatten
- _odict_unflatten = _ordereddict_unflatten
def _defaultdict_flatten(d: defaultdict[Any, T]) -> tuple[list[T], Context]:
    """Flatten a defaultdict into its values.

    The context is ``[default_factory, keys]``, where ``keys`` is the plain-dict
    context produced by ``_dict_flatten``.
    """
    values, keys = _dict_flatten(d)
    context = [d.default_factory, keys]
    return values, context


def _defaultdict_flatten_with_keys(
    d: defaultdict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
    """Like ``_defaultdict_flatten`` but pairs each value with a ``MappingKey``."""
    values, context = _defaultdict_flatten(d)
    _factory, keys = context
    keyed_children = [(MappingKey(key), value) for key, value in zip(keys, values)]
    return keyed_children, context


def _defaultdict_unflatten(
    values: Iterable[T],
    context: Context,
) -> defaultdict[Any, T]:
    """Rebuild a defaultdict from its children and ``[default_factory, keys]`` context."""
    default_factory, keys = context
    plain_dict = _dict_unflatten(values, keys)
    return defaultdict(default_factory, plain_dict)
- def _defaultdict_serialize(context: Context) -> DumpableContext:
- default_factory, dict_context = context
- json_defaultdict = {
- "default_factory_module": default_factory.__module__,
- "default_factory_name": default_factory.__qualname__,
- "dict_context": dict_context,
- }
- return json_defaultdict
- def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context:
- assert isinstance(dumpable_context, dict)
- assert set(dumpable_context) == {
- "default_factory_module",
- "default_factory_name",
- "dict_context",
- }
- default_factory_module = dumpable_context["default_factory_module"]
- default_factory_name = dumpable_context["default_factory_name"]
- assert isinstance(default_factory_module, str)
- assert isinstance(default_factory_name, str)
- module = importlib.import_module(default_factory_module)
- default_factory = getattr(module, default_factory_name)
- dict_context = dumpable_context["dict_context"]
- return [default_factory, dict_context]
- def _deque_flatten(d: deque[T]) -> tuple[list[T], Context]:
- return list(d), d.maxlen
- def _deque_flatten_with_keys(
- d: deque[T],
- ) -> tuple[list[tuple[KeyEntry, T]], Context]:
- values, context = _deque_flatten(d)
- return [(SequenceKey(i), v) for i, v in enumerate(values)], context
- def _deque_unflatten(values: Iterable[T], context: Context) -> deque[T]:
- return deque(values, maxlen=context)
# Register the builtin container types as pytree nodes. Each registration
# supplies the flatten/unflatten pair defined above, the
# ``serialized_type_name`` used when serializing treespecs, and the
# key-path-aware ``flatten_with_keys_fn``.
_private_register_pytree_node(
    tuple,
    _tuple_flatten,
    _tuple_unflatten,
    serialized_type_name="builtins.tuple",
    flatten_with_keys_fn=_tuple_flatten_with_keys,
)
_private_register_pytree_node(
    list,
    _list_flatten,
    _list_unflatten,
    serialized_type_name="builtins.list",
    flatten_with_keys_fn=_list_flatten_with_keys,
)
_private_register_pytree_node(
    dict,
    _dict_flatten,
    _dict_unflatten,
    serialized_type_name="builtins.dict",
    flatten_with_keys_fn=_dict_flatten_with_keys,
)
# The ``namedtuple`` factory itself is the registry key for every namedtuple
# class (``_get_node_type`` maps namedtuple classes to it). The concrete class
# travels in the context, which json.dumps cannot handle, hence the custom
# (de)serializers.
_private_register_pytree_node(
    namedtuple,  # type: ignore[arg-type]
    _namedtuple_flatten,
    _namedtuple_unflatten,
    serialized_type_name="collections.namedtuple",
    to_dumpable_context=_namedtuple_serialize,
    from_dumpable_context=_namedtuple_deserialize,
    flatten_with_keys_fn=_namedtuple_flatten_with_keys,
)
_private_register_pytree_node(
    OrderedDict,
    _ordereddict_flatten,
    _ordereddict_unflatten,
    serialized_type_name="collections.OrderedDict",
    flatten_with_keys_fn=_ordereddict_flatten_with_keys,
)
# The defaultdict context contains the ``default_factory`` callable, which
# needs custom (de)serializers to be stored as JSON.
_private_register_pytree_node(
    defaultdict,
    _defaultdict_flatten,
    _defaultdict_unflatten,
    serialized_type_name="collections.defaultdict",
    to_dumpable_context=_defaultdict_serialize,
    from_dumpable_context=_defaultdict_deserialize,
    flatten_with_keys_fn=_defaultdict_flatten_with_keys,
)
# The deque context is just ``maxlen`` (an int or None), so the default JSON
# context serialization suffices.
_private_register_pytree_node(
    deque,
    _deque_flatten,
    _deque_unflatten,
    serialized_type_name="collections.deque",
    flatten_with_keys_fn=_deque_flatten_with_keys,
)
# Dictionary-like builtin node types; TreeSpec.flatten_up_to treats these as
# mutually compatible.
STANDARD_DICT_TYPES: frozenset[type] = frozenset((dict, OrderedDict, defaultdict))
# Every builtin node type registered above (the dict-like ones plus the
# sequence-like ones).
BUILTIN_TYPES: frozenset[type] = STANDARD_DICT_TYPES.union(
    (
        tuple,
        list,
        namedtuple,  # type: ignore[arg-type]
        deque,
    ),
)
@deprecated(
    "torch.utils._pytree._is_namedtuple_instance is private and will be removed in a future release. "
    "Please use torch.utils._pytree.is_namedtuple_instance instead.",
    category=FutureWarning,
)
def _is_namedtuple_instance(tree: Any) -> bool:
    """Deprecated private alias that forwards to ``is_namedtuple_instance``."""
    result = is_namedtuple_instance(tree)
    return result
def _get_node_type(tree: Any) -> Any:
    """Return the registry key used to look up ``tree`` in the node registry.

    This is ``type(tree)``, except that every namedtuple class collapses to the
    ``namedtuple`` factory: all namedtuple types are implicitly registered as
    pytree nodes under that single key.
    """
    tree_type = type(tree)
    # XXX: Other parts of the codebase expect namedtuple types always return
    # `namedtuple` instead of the actual namedtuple type. Even if the type
    # is explicitly registered.
    return namedtuple if is_namedtuple_class(tree_type) else tree_type
# A leaf is defined as anything that is not a Node.
def tree_is_leaf(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Check if a pytree is a leaf.

    ``tree`` is a leaf when the optional ``is_leaf`` predicate accepts it, or
    when its node type is not registered in ``SUPPORTED_NODES``.

    >>> tree_is_leaf(1)
    True
    >>> tree_is_leaf(None)
    True
    >>> tree_is_leaf([1, 2, 3])
    False
    >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple))
    True
    >>> tree_is_leaf({"a": 1, "b": 2, "c": 3})
    False
    >>> tree_is_leaf({"a": 1, "b": 2, "c": None})
    False
    """
    if is_leaf is None or not is_leaf(tree):
        # No user override applies: fall back to the node registry.
        return _get_node_type(tree) not in SUPPORTED_NODES
    return True
@deprecated(
    "torch.utils._pytree._is_leaf is private and will be removed in a future release. "
    "Please use torch.utils._pytree.tree_is_leaf instead.",
    category=FutureWarning,
)
def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool:
    """Deprecated private alias that forwards to ``tree_is_leaf``."""
    leaf = tree_is_leaf(tree, is_leaf=is_leaf)
    return leaf
# A TreeSpec represents the structure of a pytree. It holds:
#   "type": the type of root Node of the pytree
#   context: some context that is useful in unflattening the pytree
#   children_specs: specs for each child of the root Node
# and three counts derived in __post_init__:
#   num_nodes: the number of nodes in this subtree, including the root
#   num_leaves: the number of leaves
#   num_children: the number of direct children of the root Node
@dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False)
class TreeSpec:
    """Frozen description of a pytree's structure (everything except its leaf values)."""

    type: Any
    context: Context
    children_specs: list["TreeSpec"]

    # Derived counts: not constructor arguments, filled in by __post_init__.
    num_nodes: int = dataclasses.field(init=False)
    num_leaves: int = dataclasses.field(init=False)
    num_children: int = dataclasses.field(init=False)

    def __post_init__(self) -> None:
        """Derive the node/leaf/child counts from ``children_specs``."""
        # ``start=1`` counts this node itself.
        num_nodes = sum((spec.num_nodes for spec in self.children_specs), start=1)
        num_leaves = sum(spec.num_leaves for spec in self.children_specs)
        num_children = len(self.children_specs)
        # The dataclass is frozen, so bypass its __setattr__ guard.
        object.__setattr__(self, "num_nodes", num_nodes)
        object.__setattr__(self, "num_leaves", num_leaves)
        object.__setattr__(self, "num_children", num_children)

    def __repr__(self, indent: int = 0) -> str:
        """Render the spec; ``indent`` aligns children after the first one on new lines."""
        repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, ["
        children_specs_str: str = ""
        if self.num_children > 0:
            indent += 2
            children_specs_str += self.children_specs[0].__repr__(indent)
            children_specs_str += "," if self.num_children > 1 else ""
            children_specs_str += ",".join(
                [
                    "\n" + " " * indent + child.__repr__(indent)
                    for child in self.children_specs[1:]
                ]
            )
        repr_suffix: str = f"{children_specs_str}])"
        return repr_prefix + repr_suffix

    def __eq__(self, other: PyTree) -> bool:
        """Structural equality: node types (compared via ``str``), contexts and children.

        Returns ``NotImplemented`` when ``other`` is not of exactly the same class.
        """
        if self is other:
            return True
        elif other.__class__ is self.__class__:
            if str(self.type) != str(other.type):
                return False
            if self.context != other.context:
                return False
            elif self.children_specs != other.children_specs:
                return False
            return True
        return NotImplemented

    def is_leaf(self) -> bool:
        """Return whether this spec describes a single leaf.

        LeafSpec sets both counts to 1. A childless non-leaf node has
        ``num_leaves == 0``, so it is not reported as a leaf.
        """
        return self.num_nodes == 1 and self.num_leaves == 1

    def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
        """Flatten ``tree`` only as deep as this spec goes.

        Returns one subtree of ``tree`` per leaf of this spec, in leaf order. This is
        how ``tree_map`` lines up extra inputs that have this spec as a prefix.

        Raises:
            ValueError: if a node of ``tree`` does not match the spec in type,
                arity, dict keys, or context.
        """

        def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None:
            # A spec leaf swallows whatever subtree sits at that position.
            if treespec.is_leaf():
                subtrees.append(tree)
                return

            node_type = _get_node_type(tree)
            if treespec.type not in BUILTIN_TYPES:
                # Always require custom node types to match exactly
                if node_type != treespec.type:
                    raise ValueError(
                        f"Type mismatch; "
                        f"expected {treespec.type!r}, but got {node_type!r}.",
                    )
                flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
                children, context = flatten_fn(tree)
                if len(children) != treespec.num_children:
                    raise ValueError(
                        f"Node arity mismatch; "
                        f"expected {treespec.num_children}, but got {len(children)}.",
                    )
                if context != treespec.context:
                    raise ValueError(
                        f"Node context mismatch for custom node type {treespec.type!r}.",
                    )
            else:
                # For builtin dictionary types, we allow some flexibility
                # Otherwise, we require exact matches
                both_standard_dict = (
                    treespec.type in STANDARD_DICT_TYPES
                    and node_type in STANDARD_DICT_TYPES
                )
                if not both_standard_dict and node_type != treespec.type:
                    raise ValueError(
                        f"Node type mismatch; "
                        f"expected {treespec.type!r}, but got {node_type!r}.",
                    )
                if len(tree) != treespec.num_children:
                    raise ValueError(
                        f"Node arity mismatch; "
                        f"expected {treespec.num_children}, but got {len(tree)}.",
                    )

                if both_standard_dict:
                    # dictionary types are compatible with each other
                    dict_context = (
                        treespec.context
                        if treespec.type is not defaultdict
                        # ignore mismatch of `default_factory` for defaultdict
                        else treespec.context[1]
                    )
                    expected_keys = dict_context
                    got_key_set = set(tree)
                    expected_key_set = set(expected_keys)
                    if got_key_set != expected_key_set:
                        missing_keys = expected_key_set.difference(got_key_set)
                        extra_keys = got_key_set.difference(expected_key_set)
                        message = ""
                        if missing_keys:
                            message += f"; missing key(s): {missing_keys}"
                        if extra_keys:
                            message += f"; extra key(s): {extra_keys}"
                        raise ValueError(f"Node keys mismatch{message}.")
                    # Take children in the spec's key order, not the tree's own order.
                    children = [tree[key] for key in expected_keys]
                else:
                    # node_type is treespec.type
                    flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
                    children, context = flatten_fn(tree)
                    if (
                        node_type is not deque  # ignore mismatch of `maxlen` for deque
                    ) and context != treespec.context:
                        raise ValueError(
                            f"Node context mismatch for node type {treespec.type!r}; "
                            f"expected {treespec.context!r}, but got {context!r}.",  # namedtuple type mismatch
                        )

            for subtree, subspec in zip(children, treespec.children_specs):
                helper(subspec, subtree, subtrees)

        subtrees: list[PyTree] = []
        helper(self, tree, subtrees)
        return subtrees

    def unflatten(self, leaves: Iterable[Any]) -> PyTree:
        """Build a pytree with this structure, filling its leaves from ``leaves`` in order.

        Raises:
            ValueError: if the number of leaves differs from ``self.num_leaves``.
        """
        # Materialize once so the length can be checked and slices taken.
        if not isinstance(leaves, (list, tuple)):
            leaves = list(leaves)
        if len(leaves) != self.num_leaves:
            raise ValueError(
                f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
                f"but the spec refers to a pytree that holds {self.num_leaves} "
                f"items ({self}).",
            )
        if self.is_leaf():
            return leaves[0]

        unflatten_fn = SUPPORTED_NODES[self.type].unflatten_fn

        # Recursively unflatten the children
        # Each child consumes the next ``child_spec.num_leaves`` leaves.
        start = 0
        end = 0
        child_pytrees = []
        for child_spec in self.children_specs:
            end += child_spec.num_leaves
            child_pytrees.append(child_spec.unflatten(leaves[start:end]))
            start = end

        return unflatten_fn(child_pytrees, self.context)

    def __hash__(self) -> int:
        """Hash the node type, a hashable view of the context, and the children specs."""
        node_type = self.type
        if node_type is defaultdict:
            # Context is [default_factory, keys]; the key list must become a tuple.
            default_factory, dict_context = self.context
            hashable_context = (default_factory, tuple(dict_context))
        elif node_type in (dict, OrderedDict):
            # Context is the list of keys.
            hashable_context = tuple(self.context)
        elif node_type is None or node_type in BUILTIN_TYPES:
            hashable_context = self.context
        elif isinstance(self.context, ConstantNode):
            hashable_context = self.context.value
        else:
            # The context for user-defined node types might not be hashable.
            # Ignore it for hashing.
            # This does not break the correctness that equal objects imply the
            # same hash. This might increase the hash collision rate, but we
            # don't care about that.
            hashable_context = None
        return hash((node_type, hashable_context, tuple(self.children_specs)))
# NOTE: subclassing a dataclass is subtle. In order to enable reasoning about
# this class with `dataclasses.fields`, etc., while having a simplified
# constructor that takes no argument, we wrap with `dataclass(init=True, ...)`
# again, with fields that have `init=False`.
@dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False)
class LeafSpec(TreeSpec):
    """Spec of a single leaf: no node type, no context, no children."""

    type: Any = dataclasses.field(default=None, init=False)
    context: Context = dataclasses.field(default=None, init=False)
    children_specs: list["TreeSpec"] = dataclasses.field(
        default_factory=list, init=False
    )

    def __post_init__(self) -> None:
        # Replaces TreeSpec's derivation from children: a leaf is exactly one
        # node and one leaf, with no children. The instance is frozen, hence
        # object.__setattr__.
        fixed_counts = (("num_nodes", 1), ("num_leaves", 1), ("num_children", 0))
        for field_name, count in fixed_counts:
            object.__setattr__(self, field_name, count)

    def __repr__(self, indent: int = 0) -> str:
        return "*"


# All leaves are equivalent, so one shared instance is reused everywhere instead
# of constructing a new object per leaf.
_LEAF_SPEC = LeafSpec()
def tree_flatten(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> tuple[list[Any], TreeSpec]:
    """Flatten a pytree into its leaves plus a TreeSpec that can rebuild it.

    Leaves are collected depth-first, left to right, which is the order
    ``TreeSpec.unflatten`` consumes them in.
    """
    collected_leaves: list[Any] = []

    def build_spec(node: PyTree) -> TreeSpec:
        if tree_is_leaf(node, is_leaf=is_leaf):
            collected_leaves.append(node)
            return _LEAF_SPEC
        node_type = _get_node_type(node)
        children, context = SUPPORTED_NODES[node_type].flatten_fn(node)
        child_specs = [build_spec(child) for child in children]
        return TreeSpec(node_type, context, child_specs)

    spec = build_spec(tree)
    return collected_leaves, spec
def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
    """Rebuild a pytree from ``leaves`` using the structure in ``treespec``.

    This is the inverse of ``tree_flatten``.

    Raises:
        TypeError: if ``treespec`` is not a TreeSpec.
    """
    if isinstance(treespec, TreeSpec):
        return treespec.unflatten(leaves)
    raise TypeError(
        f"tree_unflatten(leaves, treespec): Expected `treespec` to be "
        f"instance of TreeSpec but got item of type {type(treespec)}.",
    )
def tree_iter(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Iterable[Any]:
    """Lazily yield the leaves of a pytree, depth-first and left to right."""
    if tree_is_leaf(tree, is_leaf=is_leaf):
        yield tree
        return
    node_def = SUPPORTED_NODES[_get_node_type(tree)]
    subtrees, _context = node_def.flatten_fn(tree)
    for subtree in subtrees:
        yield from tree_iter(subtree, is_leaf=is_leaf)
def tree_leaves(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> list[Any]:
    """Return the leaves of a pytree as a list, in ``tree_iter`` order."""
    return [*tree_iter(tree, is_leaf=is_leaf)]
def tree_structure(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> TreeSpec:
    """Return only the TreeSpec of a pytree, discarding its leaves."""
    _leaves, spec = tree_flatten(tree, is_leaf=is_leaf)
    return spec
def tree_map(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Map a multi-input function over pytree args to produce a new pytree.

    See also :func:`tree_map_`.

    >>> tree_map(lambda x: x + 1, {"x": 7, "y": (42, 64)})
    {'x': 8, 'y': (43, 65)}
    >>> tree_map(lambda x: x is None, {"x": 7, "y": (42, 64), "z": None})
    {'x': False, 'y': (False, False), 'z': True}

    If multiple inputs are given, the structure of the tree is taken from the first input;
    subsequent inputs need only have ``tree`` as a prefix:

    >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
    [[5, 7, 9], [6, 1, 2]]

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
            leaf or not. If the function is not specified, the default pytree registry will be used.

    Returns:
        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
        ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
        is the tuple of values at corresponding nodes in ``rests``.
    """
    leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    # Cut every extra input at ``tree``'s structure so its pieces line up with ``leaves``.
    flat_rests = [treespec.flatten_up_to(other) for other in rests]
    mapped_leaves = map(func, leaves, *flat_rests)
    return treespec.unflatten(mapped_leaves)
def tree_map_(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map`, but call ``func`` on each leaf for its side effect and return the original tree.

    See also :func:`tree_map`.

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
            leaf or not. If the function is not specified, the default pytree registry will be used.

    Returns:
        The original ``tree``, after ``func(x, *xs)`` has been called for its side effect (the
        return value is discarded), where ``x`` is the value at the corresponding leaf in
        ``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``.
    """
    leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    flat_rests = [treespec.flatten_up_to(other) for other in rests]
    # Call in leaf order, stopping at the shortest input (as ``map`` would).
    for call_args in zip(leaves, *flat_rests):
        func(*call_args)
    return tree
# Tuples of exactly two / three classes, accepted by the ``*_only`` helpers
# below so the callback's argument type can be inferred.
Type2 = tuple[type[T], type[S]]
Type3 = tuple[type[T], type[S], type[U]]
# Anything usable as an ``isinstance`` filter: a class, a tuple of classes, or
# (Python 3.10+) an ``X | Y`` union object. Kept as an ``if`` statement, which
# static type checkers understand for version-dependent aliases.
if sys.version_info >= (3, 10):
    TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType]
else:
    TypeAny = Union[type[Any], tuple[type[Any], ...]]

# Callbacks over values of the filtered type(s), returning R.
Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]

# Return type of ``map_only(...)``: a decorator taking a callback T and
# returning a function that accepts any value.
MapOnlyFn = Callable[[T], Callable[[Any], Any]]
# These specializations help with type inference on the lambda passed to this
# function
@overload
def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ...


@overload
def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ...


@overload
def map_only(
    type_or_types_or_pred: Type3[T, S, U], /
) -> MapOnlyFn[Fn3[T, S, U, Any]]: ...


# This specialization is needed for the implementations below that call
# ``map_only`` with an argument typed as ``TypeAny`` (e.g. ``tree_map_only``).
@overload
def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ...


@overload
def map_only(
    type_or_types_or_pred: Callable[[Any], bool], /
) -> MapOnlyFn[FnAny[Any]]: ...


def map_only(
    type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], /
) -> MapOnlyFn[FnAny[Any]]:
    """Build a decorator that applies a function only to selected values.

    Suppose you are writing a tree_map over tensors, leaving everything
    else unchanged. Ordinarily you would have to write::

        def go(t):
            if isinstance(t, Tensor):
                return ...
            else:
                return t

    With this function, you only need to write::

        @map_only(Tensor)
        def go(t):
            return ...

    The selector may be a type, a tuple of types, an ``X | Y`` union (3.10+),
    or a predicate callable. You can also directly use ``tree_map_only``.

    Raises:
        TypeError: if the selector is none of the above.
    """
    # Types are callable too, so test for type-like selectors first.
    selects_by_type = isinstance(type_or_types_or_pred, (type, tuple)) or (
        sys.version_info >= (3, 10)
        and isinstance(type_or_types_or_pred, types.UnionType)
    )
    if selects_by_type:

        def pred(x: Any) -> bool:
            return isinstance(x, type_or_types_or_pred)  # type: ignore[arg-type]

    elif callable(type_or_types_or_pred):
        pred = type_or_types_or_pred  # type: ignore[assignment]
    else:
        raise TypeError("Argument must be a type, a tuple of types, or a callable.")

    def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]:
        @functools.wraps(func)
        def wrapped(x: T) -> Any:
            # Selected values are transformed; everything else passes through untouched.
            return func(x) if pred(x) else x

        return wrapped

    return wrapper
@overload
def tree_map_only(
    type_or_types_or_pred: type[T],
    /,
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


@overload
def tree_map_only(
    type_or_types_or_pred: Type2[T, S],
    /,
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


@overload
def tree_map_only(
    type_or_types_or_pred: Type3[T, S, U],
    /,
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


@overload
def tree_map_only(
    type_or_types_or_pred: TypeAny,
    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


@overload
def tree_map_only(
    type_or_types_or_pred: Callable[[Any], bool],
    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


def tree_map_only(
    type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """``tree_map`` that applies ``func`` only to leaves selected by the type(s) or predicate.

    Unselected leaves are returned unchanged (see ``map_only``).
    """
    selective_func = map_only(type_or_types_or_pred)(func)
    return tree_map(selective_func, tree, is_leaf=is_leaf)
@overload
def tree_map_only_(
    type_or_types_or_pred: type[T],
    /,
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


@overload
def tree_map_only_(
    type_or_types_or_pred: Type2[T, S],
    /,
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


@overload
def tree_map_only_(
    type_or_types_or_pred: Type3[T, S, U],
    /,
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


@overload
def tree_map_only_(
    type_or_types_or_pred: TypeAny,
    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


@overload
def tree_map_only_(
    type_or_types_or_pred: Callable[[Any], bool],
    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


def tree_map_only_(
    type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """``tree_map_`` variant: call ``func`` only on selected leaves; return the original tree."""
    selective_func = map_only(type_or_types_or_pred)(func)
    return tree_map_(selective_func, tree, is_leaf=is_leaf)
def tree_all(
    pred: Callable[[Any], bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` holds for every leaf (short-circuits on the first failure)."""
    return all(pred(leaf) for leaf in tree_iter(tree, is_leaf=is_leaf))
def tree_any(
    pred: Callable[[Any], bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` holds for some leaf (short-circuits on the first success)."""
    return any(pred(leaf) for leaf in tree_iter(tree, is_leaf=is_leaf))
@overload
def tree_all_only(
    type_or_types: type[T],
    /,
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ...


@overload
def tree_all_only(
    type_or_types: Type2[T, S],
    /,
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ...


@overload
def tree_all_only(
    type_or_types: Type3[T, S, U],
    /,
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ...


def tree_all_only(
    type_or_types: TypeAny,
    /,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` holds for every leaf that is an instance of ``type_or_types``.

    Leaves of other types are skipped; a tree with no matching leaves yields ``True``.
    """
    for leaf in tree_iter(tree, is_leaf=is_leaf):
        if isinstance(leaf, type_or_types) and not pred(leaf):
            return False
    return True
@overload
def tree_any_only(
    type_or_types: type[T],
    /,
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ...


@overload
def tree_any_only(
    type_or_types: Type2[T, S],
    /,
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ...


@overload
def tree_any_only(
    type_or_types: Type3[T, S, U],
    /,
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ...


def tree_any_only(
    type_or_types: TypeAny,
    /,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` holds for some leaf that is an instance of ``type_or_types``.

    Leaves of other types are skipped; a tree with no matching leaves yields ``False``.
    """
    for leaf in tree_iter(tree, is_leaf=is_leaf):
        if isinstance(leaf, type_or_types) and pred(leaf):
            return True
    return False
# Broadcasts a pytree to the provided TreeSpec and returns the flattened
# values. If this is not possible, then this function returns None.
#
# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
# would return [0, 0]. This is useful for part of the vmap implementation:
# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
# broadcastable to the tree structure of `inputs` and we use
# _broadcast_to_and_flatten to check this.
def _broadcast_to_and_flatten(
    tree: PyTree,
    treespec: TreeSpec,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[list[Any]]:
    """Broadcast ``tree`` to ``treespec`` and flatten it, or return ``None`` if impossible."""
    assert isinstance(treespec, TreeSpec)

    # A leaf of ``tree`` fills every leaf position of the (sub)spec.
    if tree_is_leaf(tree, is_leaf=is_leaf):
        return [tree] * treespec.num_leaves
    # ``tree`` has more structure than the spec at this position.
    if treespec.is_leaf():
        return None

    node_type = _get_node_type(tree)
    if node_type != treespec.type:
        return None

    child_pytrees, ctx = SUPPORTED_NODES[node_type].flatten_fn(tree)
    if len(child_pytrees) != treespec.num_children or ctx != treespec.context:
        return None

    flattened: list[Any] = []
    for child, child_spec in zip(child_pytrees, treespec.children_specs):
        child_flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
        if child_flat is None:
            return None
        flattened.extend(child_flat)
    return flattened
@dataclasses.dataclass
class _TreeSpecSchema:
    """
    _TreeSpecSchema is the schema used to serialize the TreeSpec.

    ``_treespec_to_json`` builds it recursively, and ``treespec_dumps`` turns it
    into a plain dict with ``dataclasses.asdict`` before JSON encoding.

    It contains the following fields:
    - type: The registered serialized type name of the node; None (JSON null)
      for a LeafSpec.
    - context: The node's context in a JSON-dumpable form: either a JSON string
      produced by json.dumps, or the output of a registered
      ``to_dumpable_context``.
    - children_spec: The serialized schemas of the node's children, in order.
    """

    type: Optional[str]
    context: DumpableContext
    children_spec: list["_TreeSpecSchema"]
class _ProtocolFn(NamedTuple):
    """Encoder/decoder pair for one treespec serialization protocol version."""

    # Converts a TreeSpec into its serializable schema.
    treespec_to_json: Callable[[TreeSpec], DumpableContext]
    # Rebuilds a TreeSpec from the JSON-decoded schema.
    json_to_treespec: Callable[[DumpableContext], TreeSpec]


# Protocol version number -> codec pair; consulted by treespec_dumps and
# treespec_loads.
_SUPPORTED_PROTOCOLS: dict[int, _ProtocolFn] = {}
def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
    """Convert ``treespec`` (recursively) into a ``_TreeSpecSchema``.

    The context is encoded with the type's registered ``to_dumpable_context``
    when there is one, and otherwise with ``json.dumps``.

    Raises:
        NotImplementedError: if a node type is not registered for serialization,
            or has no serialized type name.
        TypeError: if a context without a custom serializer is not JSON-dumpable.
    """
    if treespec.is_leaf():
        return _TreeSpecSchema(None, None, [])

    if treespec.type not in SUPPORTED_SERIALIZED_TYPES:
        raise NotImplementedError(
            f"Serializing {treespec.type} in pytree is not registered.",
        )
    serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type]

    type_name = serialize_node_def.serialized_type_name
    if type_name == NO_SERIALIZED_TYPE_NAME_FOUND:
        raise NotImplementedError(
            f"No registered serialization name for {treespec.type} found. "
            "Please update your _register_pytree_node call with a `serialized_type_name` kwarg."
        )

    to_dumpable = serialize_node_def.to_dumpable_context
    if to_dumpable is not None:
        serialized_context = to_dumpable(treespec.context)
    else:
        try:
            serialized_context = json.dumps(treespec.context, cls=EnumEncoder)
        except TypeError as e:
            raise TypeError(
                "Unable to serialize context. "
                "Please make the context json dump-able, or register a "
                "custom serializer using _register_pytree_node."
            ) from e

    children_schemas = [_treespec_to_json(child_spec) for child_spec in treespec.children_specs]
    return _TreeSpecSchema(type_name, serialized_context, children_schemas)
def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]:
    """``json.loads`` object hook that turns encoded enum members back into enums.

    A dict containing an ``"__enum__"`` key is resolved as member ``obj["name"]``
    of the class located by ``obj["fqn"]`` (``"module:Qual.Name"``). Any other
    dict is returned unchanged.
    """
    if "__enum__" not in obj:
        return obj
    modname, _, classname = obj["fqn"].partition(":")
    module = importlib.import_module(modname)
    # Walk a possibly dotted qualified name one attribute at a time.
    enum_cls = cast(type[Enum], functools.reduce(getattr, classname.split("."), module))
    return enum_cls[obj["name"]]
def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
    """Rebuild a TreeSpec from a JSON-decoded schema dict.

    ``json_schema`` is expected to have the ``type``, ``context`` and
    ``children_spec`` keys written by ``_treespec_to_json``.

    Raises:
        NotImplementedError: if ``type`` is not a registered serialized type name.
        TypeError: if the default JSON context decoding raises ``TypeError``.
    """
    type_name = json_schema["type"]
    is_leaf_schema = (
        type_name is None
        and json_schema["context"] is None
        and len(json_schema["children_spec"]) == 0
    )
    if is_leaf_schema:
        return _LEAF_SPEC

    if type_name not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
        raise NotImplementedError(
            f"Deserializing {json_schema['type']} in pytree is not registered.",
        )
    typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[type_name]
    serialize_node_def = SUPPORTED_SERIALIZED_TYPES[typ]

    from_dumpable = serialize_node_def.from_dumpable_context
    if from_dumpable is not None:
        context = from_dumpable(json_schema["context"])
    else:
        try:
            context = json.loads(json_schema["context"], object_hook=enum_object_hook)
        except TypeError as ex:
            raise TypeError(
                "Unable to deserialize context. "
                "Please make the context json load-able, or register a "
                "custom serializer using _register_pytree_node.",
            ) from ex

    children_specs = [
        _json_to_treespec(child_schema) for child_schema in json_schema["children_spec"]
    ]
    return TreeSpec(typ, context, children_specs)
# Protocol 1 is the schema produced by _treespec_to_json and consumed by
# _json_to_treespec.
_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(
    _treespec_to_json,
    _json_to_treespec,
)
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
    """Serialize ``treespec`` to a JSON string.

    Args:
        treespec: the TreeSpec to serialize.
        protocol: the serialization protocol number. ``None`` selects
            ``DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL``.

    Returns:
        A JSON string holding the two-element array ``[protocol, schema]``.

    Raises:
        TypeError: ``treespec`` is not a ``TreeSpec``.
        ValueError: ``protocol`` is not registered in ``_SUPPORTED_PROTOCOLS``.
    """
    if not isinstance(treespec, TreeSpec):
        raise TypeError(
            f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of "
            f"TreeSpec but got item of type {type(treespec)}.",
        )

    chosen = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL if protocol is None else protocol
    if chosen not in _SUPPORTED_PROTOCOLS:
        raise ValueError(
            f"Unknown protocol {chosen}. "
            f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}",
        )

    schema = _SUPPORTED_PROTOCOLS[chosen].treespec_to_json(treespec)
    return json.dumps((chosen, dataclasses.asdict(schema)), cls=EnumEncoder)
@functools.lru_cache
def treespec_loads(serialized: str) -> TreeSpec:
    """Deserialize a string produced by ``treespec_dumps`` back into a TreeSpec.

    Results are memoized per input string by ``functools.lru_cache`` (default
    ``maxsize``), so a repeated load may return the same TreeSpec object.

    Raises:
        ValueError: the protocol number embedded in ``serialized`` is not
            registered in ``_SUPPORTED_PROTOCOLS``.
    """
    protocol, json_schema = json.loads(serialized)
    if protocol not in _SUPPORTED_PROTOCOLS:
        raise ValueError(
            f"Unknown protocol {protocol}. "
            f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}",
        )
    return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema)
- class _DummyLeaf:
- def __repr__(self) -> str:
- return "*"
def treespec_pprint(treespec: TreeSpec) -> str:
    """Render ``treespec`` as the repr of a tree whose leaves all print as ``*``."""
    leaf_count = treespec.num_leaves
    # A fresh placeholder object is used for every leaf slot.
    skeleton = tree_unflatten([_DummyLeaf() for _ in range(leaf_count)], treespec)
    return repr(skeleton)
# TODO(angelayi): remove this function after OSS/internal stabilize
@deprecated(
    "`pytree_to_str` is deprecated. Please use `treespec_dumps` instead.",
    category=FutureWarning,
)
def pytree_to_str(treespec: TreeSpec) -> str:
    """Deprecated alias for :func:`treespec_dumps` with the default protocol."""
    serialized = treespec_dumps(treespec)
    return serialized
# TODO(angelayi): remove this function after OSS/internal stabilize
@deprecated(
    "`str_to_pytree` is deprecated. Please use `treespec_loads` instead.",
    category=FutureWarning,
)
def str_to_pytree(json: str) -> TreeSpec:
    """Deprecated alias for :func:`treespec_loads`.

    The parameter keeps the name ``json`` so keyword callers keep working,
    even though it shadows the ``json`` module inside this function.
    """
    spec = treespec_loads(json)
    return spec
def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> list[Any]:
    """Return the leaves of all positional, then all keyword, arguments as one list.

    Intended as a slightly faster version of ``tree_leaves((args, kwargs))``.
    """
    return [
        leaf
        for argument in (*args, *kwargs.values())
        for leaf in tree_iter(argument)
    ]
def tree_flatten_with_path(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> tuple[list[tuple[KeyPath, Any]], TreeSpec]:
    """Flatten a pytree like :func:`tree_flatten`, also returning each leaf's key path.

    Args:
        tree: the pytree to flatten. Any custom type it contains must have been
            registered via :func:`register_pytree_node` with a
            `tree_flatten_with_path_fn`.
        is_leaf: optional extra leaf predicate called at every flattening step
            as ``is_leaf(node) -> bool``. When it returns :data:`True`, the
            whole subtree is treated as a leaf. Otherwise, or when the predicate
            is omitted, the default pytree registry decides whether a node is a
            leaf.

    Returns:
        A pair ``(keyed_leaves, treespec)``. ``keyed_leaves`` is a list of
        ``(key path, leaf)`` tuples, and ``treespec`` is the :class:`TreeSpec`
        describing the structure of ``tree``.
    """
    # The structure is computed first, then the key paths, in a second pass.
    treespec = tree_flatten(tree, is_leaf)[1]
    keyed_leaves = list(_generate_key_paths((), tree, is_leaf))
    return keyed_leaves, treespec
def tree_leaves_with_path(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> list[tuple[KeyPath, Any]]:
    """Return the leaves of a pytree, like ``tree_leaves``, paired with their key paths.

    Args:
        tree: a pytree. Any custom type it contains must have been registered
            via :func:`register_pytree_node` with a `tree_flatten_with_path_fn`.
        is_leaf: optional extra leaf predicate called at every flattening step
            as ``is_leaf(node) -> bool``. When it returns :data:`True`, the
            whole subtree is treated as a leaf. Otherwise, or when the predicate
            is omitted, the default pytree registry decides whether a node is a
            leaf.

    Returns:
        A list of ``(key path, leaf)`` tuples.
    """
    return [*_generate_key_paths((), tree, is_leaf)]
def _generate_key_paths(
    key_path: KeyPath,
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Iterable[tuple[KeyPath, Any]]:
    """Lazily yield ``(key path, leaf)`` pairs for ``tree``, depth first.

    ``key_path`` is the prefix accumulated so far. Each child's key is
    appended to it on the way down.

    Raises:
        ValueError: (when iterated) a registered node type has no
            ``flatten_with_keys_fn``.
    """
    # The caller's predicate wins over the registry.
    if is_leaf and is_leaf(tree):
        yield key_path, tree
        return

    node_type = _get_node_type(tree)
    handler = SUPPORTED_NODES.get(node_type)
    if not handler:
        # Types without a registered handler are leaves.
        yield key_path, tree
        return

    flatten_with_keys = handler.flatten_with_keys_fn
    if not flatten_with_keys:
        # The type is registered, but it cannot report keys, so no path exists.
        raise ValueError(
            f"Did not find a flatten_with_keys_fn for type: {node_type}. "
            "Please pass a flatten_with_keys_fn argument to register_pytree_node."
        )

    keyed_children, _ = flatten_with_keys(tree)
    for key, child in keyed_children:
        yield from _generate_key_paths((*key_path, key), child, is_leaf)
def tree_map_with_path(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map`, but the provided callable takes an additional key path argument.

    Args:
        func: A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees. The first positional argument
            to ``func`` is the key path of the leaf in question. The second
            positional argument is the value of the leaf.
        tree: A pytree to be mapped over, with each leaf providing the second positional
            argument to function ``func``.
        rests: A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf: An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
            node is a leaf. If the function is not specified, the default pytree registry will be
            used.

    Returns:
        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
        ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the
        corresponding leaf in ``tree``, ``x`` is the value at that leaf, and
        ``xs`` is the tuple of values at corresponding nodes in ``rests``.
    """
    keypath_leaves, treespec = tree_flatten_with_path(tree, is_leaf)
    # Each tree in ``rests`` is flattened up to ``tree``'s structure, so its
    # entries line up one-to-one with ``keypath_leaves``.
    rests_leaves = [treespec.flatten_up_to(r) for r in rests]
    # Unpack each (keypath, leaf) pair directly instead of transposing the
    # list into separate key-path / leaf columns first.
    return treespec.unflatten(
        func(keypath, leaf, *xs)
        for (keypath, leaf), *xs in zip(keypath_leaves, *rests_leaves)
    )
def keystr(kp: KeyPath) -> str:
    """Return a pretty-printed form of the key path ``kp``.

    Each key entry is converted with ``str``, and the pieces are concatenated
    with no separator.
    """
    return "".join(map(str, kp))
def key_get(obj: Any, kp: KeyPath) -> Any:
    """Follow the key path ``kp`` starting from ``obj`` and return the value reached.

    Each entry's ``get`` method is applied to the current value in turn. An
    empty path returns ``obj`` itself.
    """
    value = obj
    for entry in kp:
        value = entry.get(value)
    return value
|