_pytree.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from collections import namedtuple
  2. from typing import Any, Callable, Optional, TypeVar
  3. from typing_extensions import NamedTuple
  4. import torch.return_types
  5. from torch.utils._pytree import PyTree, tree_flatten, TreeSpec
  6. FlattenFuncSpec = Callable[[PyTree, TreeSpec], list]
  7. FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool]
  8. SUPPORTED_NODES: dict[type[Any], FlattenFuncSpec] = {}
  9. SUPPORTED_NODES_EXACT_MATCH: dict[type[Any], Optional[FlattenFuncExactMatchSpec]] = {}
  10. _T = TypeVar("_T")
  11. _K = TypeVar("_K")
  12. _V = TypeVar("_V")
  13. def register_pytree_flatten_spec(
  14. cls: type[Any],
  15. flatten_fn_spec: FlattenFuncSpec,
  16. flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None,
  17. ) -> None:
  18. SUPPORTED_NODES[cls] = flatten_fn_spec
  19. SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec
  20. def _deregister_pytree_flatten_spec(
  21. cls: type[Any],
  22. ) -> None:
  23. del SUPPORTED_NODES[cls]
  24. del SUPPORTED_NODES_EXACT_MATCH[cls]
  25. def tree_flatten_spec(
  26. pytree: PyTree,
  27. spec: TreeSpec,
  28. ) -> list[Any]:
  29. if spec.is_leaf():
  30. return [pytree]
  31. # I guess these exist for BC, FC reasons.
  32. # In general, we should be able to directly
  33. # use pytree tree flattener to flatten them,
  34. # as export serializes the pytree separately.
  35. # Will remove it in follow up PR.
  36. if spec.type in SUPPORTED_NODES:
  37. flatten_fn_spec = SUPPORTED_NODES[spec.type]
  38. child_pytrees = flatten_fn_spec(pytree, spec)
  39. result = []
  40. for child, child_spec in zip(child_pytrees, spec.children_specs):
  41. flat = tree_flatten_spec(child, child_spec)
  42. result += flat
  43. return result
  44. flat_result, real_spec = tree_flatten(pytree)
  45. if spec != real_spec:
  46. raise RuntimeError(
  47. f"Real spec {real_spec} of object {pytree} is different from expected spec {spec}. "
  48. f"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml"
  49. )
  50. return flat_result
  51. def _dict_flatten_spec(d: dict[_K, _V], spec: TreeSpec) -> list[_V]:
  52. return [d[k] for k in spec.context]
  53. def _list_flatten_spec(d: list[_T], spec: TreeSpec) -> list[_T]:
  54. return [d[i] for i in range(spec.num_children)]
  55. def _tuple_flatten_spec(d: tuple[_T, ...], spec: TreeSpec) -> list[_T]:
  56. return [d[i] for i in range(spec.num_children)]
  57. def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> list[Any]:
  58. return [d[i] for i in range(spec.num_children)]
  59. def _dict_flatten_spec_exact_match(d: dict[_K, _V], spec: TreeSpec) -> bool:
  60. return len(d) == spec.num_children
  61. def _list_flatten_spec_exact_match(d: list[_T], spec: TreeSpec) -> bool:
  62. return len(d) == spec.num_children
  63. def _tuple_flatten_spec_exact_match(d: tuple[_T, ...], spec: TreeSpec) -> bool:
  64. return len(d) == spec.num_children
  65. def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool:
  66. return len(d) == spec.num_children
  67. register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match)
  68. register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match)
  69. register_pytree_flatten_spec(
  70. tuple,
  71. _tuple_flatten_spec,
  72. _tuple_flatten_spec_exact_match,
  73. )
  74. for return_type in torch.return_types.all_return_types:
  75. register_pytree_flatten_spec(
  76. return_type,
  77. _tuple_flatten_spec,
  78. _tuple_flatten_spec_exact_match,
  79. )
  80. register_pytree_flatten_spec(
  81. namedtuple, # type: ignore[arg-type]
  82. _namedtuple_flatten_spec,
  83. _namedtuple_flatten_spec_exact_match,
  84. )