_wrap.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. """
  2. Python implementation of function wrapping functionality for functorch.dim.
  3. """
  4. from __future__ import annotations
  5. import functools
  6. from typing import Any, Optional, TYPE_CHECKING
  7. import torch
  8. from torch.utils._pytree import tree_map
  9. from ._dim_entry import DimEntry
  10. from ._enable_all_layers import EnableAllLayers
  11. from ._tensor_info import TensorInfo
  12. if TYPE_CHECKING:
  13. from collections.abc import Callable
  14. def handle_from_tensor(tensor: torch.Tensor) -> torch.Tensor:
  15. """Handle tensor conversion for torch function integration."""
  16. return tensor
  17. class WrappedOperator:
  18. """
  19. This class wraps PyTorch operations to support first-class dimensions.
  20. """
  21. def __init__(
  22. self, orig: Callable, wrapper_implementation: Callable, dim_name: str = "dim"
  23. ):
  24. self.orig = orig
  25. self.wrapper_implementation = wrapper_implementation
  26. self.name = getattr(orig, "__name__", "")
  27. self.doc = getattr(orig, "__doc__", None)
  28. self.dim_name = dim_name
  29. self.is_pointwise = False
  30. self.dim_offset = 0
  31. self.keepdim_offset = 1
  32. self.single_dim = False
  33. self.reduce = True
  34. # Update docstring if we have a dim_name
  35. if self.doc and self.dim_name:
  36. self.doc = f"{self.doc}\nArgument '{self.dim_name}' can be either an integer or a torchdim.Dim object.\n"
  37. def function(self) -> Callable:
  38. """Create a wrapped function that calls our wrapper implementation."""
  39. def wrapped_func(*args: Any, **kwargs: Any) -> Any:
  40. return self.wrapper_implementation(self, *args, **kwargs)
  41. # Copy metadata using functools.update_wrapper for just __name__ and __doc__
  42. functools.update_wrapper(
  43. wrapped_func, self.orig, assigned=("__name__",), updated=()
  44. )
  45. wrapped_func.__doc__ = self.doc
  46. return wrapped_func
  47. def _wrap_dim(dim: Any, ndim: int, keepdim: bool = False) -> DimEntry:
  48. """Convert single dimension specification to DimEntry object."""
  49. from . import Dim
  50. if isinstance(dim, Dim):
  51. if keepdim:
  52. raise ValueError("cannot preserve first-class dimensions with keepdim=True")
  53. return DimEntry(dim)
  54. elif isinstance(dim, int):
  55. i = dim
  56. while i >= 0:
  57. i -= ndim
  58. return DimEntry(i)
  59. else:
  60. return DimEntry()
  61. def _wrap_dims(dim: Any, ndim: int, keepdim: bool = False) -> list[DimEntry]:
  62. """Convert dimension specification to list of DimEntry objects."""
  63. de = _wrap_dim(dim, ndim, keepdim)
  64. result = []
  65. if not de.is_none():
  66. result.append(de)
  67. else:
  68. for d in dim:
  69. result.append(_wrap_dim(d, ndim, keepdim))
  70. return result
  71. def patched_dim_method(wrapper: WrappedOperator, *args: Any, **kwargs: Any) -> Any:
  72. """
  73. This is the core method that handles dimension-aware operations.
  74. """
  75. if not args:
  76. raise ValueError("Expected at least one argument (self)")
  77. # Get dimension argument
  78. dim_arg = kwargs.get(wrapper.dim_name)
  79. if dim_arg is None and wrapper.dim_offset < len(args):
  80. # Try to get dim from positional args (accounting for self at index 0)
  81. dim_idx = wrapper.dim_offset + 1
  82. if dim_idx < len(args):
  83. dim_arg = args[dim_idx]
  84. # If no dimension argument provided, fall back to standard functorch handling
  85. if dim_arg is None:
  86. info = TensorInfo.create(args[0], ensure_batched=True, ensure_present=False)
  87. if not info:
  88. return wrapper.orig(*args, **kwargs)
  89. with EnableAllLayers(info.levels) as guard:
  90. assert info.batchedtensor is not None
  91. guard.inplace_update_layers(info.batchedtensor, info.levels)
  92. new_args = list(args)
  93. new_args[0] = handle_from_tensor(info.batchedtensor)
  94. result = wrapper.orig(*new_args, **kwargs)
  95. return guard.from_batched(result, info.has_device)
  96. # Handle dimension-aware operation
  97. info = TensorInfo.create(args[0])
  98. if not info:
  99. return wrapper.orig(*args, **kwargs)
  100. # Check for keepdim parameter
  101. keepdim = False
  102. if wrapper.reduce:
  103. keepdim_arg = kwargs.get("keepdim")
  104. if keepdim_arg is None and wrapper.keepdim_offset < len(args):
  105. keepdim_idx = wrapper.keepdim_offset + 1
  106. if keepdim_idx < len(args):
  107. keepdim_arg = args[keepdim_idx]
  108. if keepdim_arg is not None:
  109. keepdim = bool(keepdim_arg)
  110. # Wrap dimensions
  111. ndim = info.ndim()
  112. dims = _wrap_dims(dim_arg, ndim, keepdim)
  113. # Convert dimensions to indices and validate
  114. dim_indices: list[int] = []
  115. seen = [False] * len(info.levels)
  116. for d in dims:
  117. midx = None
  118. for i, level in enumerate(info.levels):
  119. if level == d:
  120. midx = i
  121. break
  122. if midx is None:
  123. # Try to match by position/name more flexibly
  124. for i, level in enumerate(info.levels):
  125. if hasattr(level, "matches") and level.matches(d):
  126. midx = i
  127. break
  128. if midx is None:
  129. level_strs = [str(level) for level in info.levels]
  130. raise ValueError(
  131. f"Tensor with dimensions {level_strs} does not contain {d}"
  132. )
  133. seen[midx] = True
  134. dim_indices.append(midx)
  135. # Determine new levels after reduction
  136. new_levels = []
  137. if wrapper.reduce and not keepdim:
  138. for i, level in enumerate(info.levels):
  139. if not seen[i]:
  140. new_levels.append(level)
  141. else:
  142. new_levels = info.levels[:]
  143. # Create dimension indices for the original function
  144. if len(dim_indices) == 1:
  145. py_indices: Any = dim_indices[0]
  146. else:
  147. py_indices = tuple(dim_indices)
  148. # Update arguments
  149. new_args = list(args)
  150. new_kwargs = kwargs.copy()
  151. assert info.tensor is not None
  152. new_args[0] = handle_from_tensor(info.tensor)
  153. # Update dimension argument
  154. if wrapper.dim_name in new_kwargs:
  155. new_kwargs[wrapper.dim_name] = py_indices
  156. else:
  157. dim_idx = wrapper.dim_offset + 1
  158. if dim_idx < len(new_args):
  159. new_args = list(new_args)
  160. new_args[dim_idx] = py_indices
  161. # Call original function
  162. result = wrapper.orig(*new_args, **new_kwargs)
  163. # Wrap results
  164. def wrap_result(obj: Any) -> Any:
  165. if isinstance(obj, torch.Tensor):
  166. from . import Tensor
  167. return Tensor.from_positional(obj, new_levels, info.has_device)
  168. return obj
  169. return tree_map(wrap_result, result)
  170. def _wrap(
  171. orig: Callable,
  172. dim_offset: Optional[int] = None,
  173. keepdim_offset: Optional[int] = None,
  174. dim_name: Optional[str] = None,
  175. single_dim: Optional[bool] = None,
  176. reduce: Optional[bool] = None,
  177. ) -> Callable:
  178. """
  179. Wrap a PyTorch function to support first-class dimensions.
  180. Args:
  181. orig: Original function to wrap
  182. dim_offset: Offset for dimension argument (default: 0)
  183. keepdim_offset: Offset for keepdim argument (default: 1)
  184. dim_name: Name of dimension parameter (default: "dim")
  185. single_dim: Whether function takes single dimension (default: False)
  186. reduce: Whether function reduces dimensions (default: True)
  187. """
  188. dim_name = dim_name or "dim"
  189. wrapper = WrappedOperator(orig, patched_dim_method, dim_name)
  190. if dim_offset is not None:
  191. wrapper.dim_offset = dim_offset
  192. if keepdim_offset is not None:
  193. wrapper.keepdim_offset = keepdim_offset
  194. if single_dim is not None:
  195. wrapper.single_dim = single_dim
  196. if reduce is not None:
  197. wrapper.reduce = reduce
  198. return wrapper.function()
  199. def call_torch_function(
  200. wrapper: WrappedOperator,
  201. func: Callable,
  202. types: tuple,
  203. args: tuple = (),
  204. kwargs: Optional[dict] = None,
  205. ) -> Any:
  206. """
  207. Handle __torch_function__ calls for wrapped operators.
  208. """
  209. if kwargs is None:
  210. kwargs = {}
  211. # Import here to avoid circular imports
  212. from . import _Tensor
  213. # Use the torch function mechanism from _Tensor
  214. return _Tensor.__torch_function__(func, types, args, kwargs)