_getsetitem.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. from __future__ import annotations
  2. from dataclasses import dataclass, field
  3. from typing import Any, Optional, TYPE_CHECKING, Union
  4. import torch
  5. from ._dim_entry import _match_levels, DimEntry
  6. from ._tensor_info import TensorInfo
  7. if TYPE_CHECKING:
  8. from . import Dim
  9. def _safe_index(lst: list, item: Any) -> Optional[int]:
  10. """
  11. Helper function to find index of item in list.
  12. For DimEntry objects, uses __eq__ comparison which properly handles
  13. both positional and Dim entries.
  14. Returns the index if found, None if not found.
  15. """
  16. for i, list_item in enumerate(lst):
  17. # Use == for DimEntry objects as they have proper __eq__ implementation
  18. if isinstance(item, DimEntry) and isinstance(list_item, DimEntry):
  19. if list_item == item:
  20. return i
  21. elif list_item is item:
  22. return i
  23. return None
@dataclass
class IndexingInfo:
    """
    Result of analysing an indexing expression (see getsetitem / getsetitem_flat).

    getitem/setitem read ``can_call_original`` first; when it is False the
    remaining fields describe how to perform the indexing.
    """

    # True when the plain torch indexing implementation should be used as is.
    can_call_original: bool = False
    # True when self_tensor must be indexed with tuple(flat_inputs); when False
    # self_tensor is used directly (getitem) or copied into (setitem).
    advanced_indexing: bool = False
    # Tensor to index or copy into.
    self_tensor: Optional[torch.Tensor] = None
    # Per-position index entries that are turned into the indexing tuple.
    flat_inputs: list[Any] = field(default_factory=list)
    # Levels handed to Tensor.from_positional (getitem) or _match_levels (setitem).
    result_levels: list[DimEntry] = field(default_factory=list)
    # Forwarded to Tensor.from_positional as its has_device argument.
    has_device: bool = False
  32. def has_dims(obj: Any) -> bool:
  33. """
  34. Check if an object has first-class dimensions.
  35. This function checks if the object is either a Dim or a functorch Tensor
  36. that has first-class dimensions, using the proper check_exact methods.
  37. """
  38. from . import Dim, Tensor
  39. return Dim.check_exact(obj) or Tensor.check_exact(obj)
  40. def _bind_dims_to_size(sz: int, sd: int, dims: list, nsz: list, nsd: list) -> None:
  41. """
  42. Bind dimensions to size and calculate proper strides for dim packs.
  43. """
  44. from . import DimensionBindError
  45. rhs_prod = 1
  46. for i, dim in enumerate(dims):
  47. if not dim.is_bound:
  48. # Check for multiple unbound dimensions
  49. for j in range(i + 1, len(dims)):
  50. if not dims[j].is_bound:
  51. raise DimensionBindError(
  52. f"cannot infer the sizes of two dimensions at once {dim!r} and {dims[j]!r}"
  53. )
  54. rhs_prod *= dims[j].size
  55. # Calculate the size for this unbound dimension
  56. if sz % rhs_prod != 0:
  57. tup = tuple(dim.size if dim.is_bound else "?" for dim in dims)
  58. raise DimensionBindError(
  59. f"inferred dimension does not evenly fit into larger dimension: {sz} vs {tup}"
  60. )
  61. inferred_size = sz // rhs_prod
  62. dim.size = inferred_size
  63. rhs_prod = sz
  64. break
  65. else:
  66. rhs_prod *= dim.size
  67. # Final validation that dimensions match
  68. if rhs_prod != sz:
  69. tup = tuple(dims)
  70. raise DimensionBindError(
  71. f"Dimension sizes to do not match ({sz} != {rhs_prod}) when matching dimension pack {tup}"
  72. )
  73. # Calculate new sizes and strides for each dimension in the pack
  74. # First calculate all strides by iterating in reverse
  75. new_strides = [0] * len(dims)
  76. current_stride = sd
  77. for i in reversed(range(len(dims))):
  78. new_strides[i] = current_stride
  79. current_stride *= dims[i].size
  80. # Then append sizes and strides in forward order
  81. for i in range(len(dims)):
  82. nsz.append(dims[i].size)
  83. nsd.append(new_strides[i])
  84. def slice_to_tuple(flat_inputs: list) -> tuple:
  85. return tuple(flat_inputs)
  86. def extractIndices(index: Any, indices: list) -> bool:
  87. if isinstance(index, tuple): # mpy::tuple_view::check
  88. indices.extend(index)
  89. return True
  90. elif isinstance(index, torch.Tensor): # THPVariable_Check
  91. indices.append(index)
  92. return False
  93. elif not hasattr(index, "__iter__") or isinstance(
  94. index, (str, bytes)
  95. ): # !mpy::is_sequence
  96. indices.append(index)
  97. return False
  98. # Handle sequence case (list)
  99. if isinstance(index, list):
  100. if len(index) >= 32:
  101. indices.extend(index)
  102. return True
  103. # Check each item in the sequence
  104. for item in index:
  105. if (
  106. isinstance(item, (torch.Tensor, slice))
  107. or hasattr(item, "__iter__")
  108. or item is ...
  109. or item is None
  110. or has_dims(item)
  111. ):
  112. indices.extend(index)
  113. return True
  114. # If we got here, treat as single index
  115. indices.append(index)
  116. return False
  117. # Default case
  118. indices.append(index)
  119. return False
  120. def getitem(cls: Any, func: Any, types: Any, args: Any, kwargs: Any) -> Any:
  121. self = args[0]
  122. index = args[1]
  123. iinfo = getsetitem(self, index, has_dims(self))
  124. if iinfo.can_call_original:
  125. # Call original tensor __getitem__ directly, bypassing __torch_function__
  126. return torch.Tensor.__getitem__(self, index)
  127. return invoke_getitem(iinfo)
  128. def setitem(self: Any, index: Any, rhs: Any) -> None:
  129. """Set values in tensor using first-class dimensions."""
  130. from . import DimensionBindError, TensorInfo
  131. iinfo = getsetitem(self, index, has_dims(self) or has_dims(rhs))
  132. if iinfo.can_call_original:
  133. # Call original tensor __setitem__ directly, bypassing __torch_function__
  134. torch._C.TensorBase.__setitem__(self, index, rhs)
  135. return
  136. # Handle RHS tensor with dimensions
  137. rhs_info = TensorInfo.create(rhs, False, False)
  138. if rhs_info:
  139. # Check that rhs dimensions are compatible with result dimensions
  140. for l in rhs_info.levels:
  141. if not l.is_positional():
  142. # Find this dimension in result levels
  143. found = False
  144. for result_level in iinfo.result_levels:
  145. if (
  146. not result_level.is_positional()
  147. and result_level.dim() is l.dim()
  148. ):
  149. found = True
  150. break
  151. if not found:
  152. # Create tuple representation of result levels for error message
  153. result_dims: list[Union[int, Dim]] = []
  154. for rl in iinfo.result_levels:
  155. if rl.is_positional():
  156. result_dims.append(rl.position())
  157. else:
  158. result_dims.append(rl.dim())
  159. raise DimensionBindError(
  160. f"rhs of setitem contains dimension {l.dim()!r} which is not in the dimension on the left "
  161. f"({tuple(result_dims)!r})"
  162. )
  163. # Match RHS tensor to result levels
  164. assert rhs_info.tensor is not None, "Cannot match levels on None tensor"
  165. matched_rhs = _match_levels(
  166. rhs_info.tensor, rhs_info.levels, iinfo.result_levels
  167. )
  168. else:
  169. matched_rhs = rhs
  170. # For advanced indexing with dimensions, we need special handling
  171. if iinfo.advanced_indexing:
  172. # Use advanced indexing - the flat_inputs already contain matched tensors
  173. tup = slice_to_tuple(iinfo.flat_inputs)
  174. if iinfo.self_tensor is None:
  175. raise RuntimeError("Cannot setitem on None tensor")
  176. torch._C.TensorBase.__setitem__(iinfo.self_tensor, tup, matched_rhs)
  177. else:
  178. # Simple copy operation
  179. if iinfo.self_tensor is None:
  180. raise RuntimeError("Cannot copy to None tensor")
  181. iinfo.self_tensor.copy_(matched_rhs)
  182. def invoke_getitem(iinfo: IndexingInfo) -> Any:
  183. if iinfo.advanced_indexing:
  184. self_tensor = iinfo.self_tensor
  185. tup = slice_to_tuple(iinfo.flat_inputs)
  186. if self_tensor is None:
  187. raise RuntimeError("Cannot getitem on None tensor")
  188. rtensor = self_tensor[tup]
  189. else:
  190. rtensor = iinfo.self_tensor # type: ignore[assignment]
  191. if rtensor is None:
  192. raise RuntimeError("Cannot getitem on None tensor")
  193. # rtensor is now guaranteed to be not None
  194. # Create a Tensor with the proper dimensions using the class method
  195. from . import Tensor
  196. return Tensor.from_positional(rtensor, iinfo.result_levels, iinfo.has_device)
  197. def getsetitem(self: Any, index: Any, tensors_have_dims: bool) -> IndexingInfo:
  198. from . import DimList # Import DimList for type checking
  199. can_call_original_getitem = not tensors_have_dims
  200. input_list = []
  201. if has_dims(index):
  202. input_list.append(index)
  203. else:
  204. is_sequence = extractIndices(index, input_list)
  205. # nothing about first class dims here, fallback to getitem
  206. if can_call_original_getitem and not is_sequence:
  207. return IndexingInfo(can_call_original=True)
  208. # Calculate how many dimensions have been indexed in order to compute the
  209. # size of ... or expand a potentially unbound dimension list.
  210. dims_indexed = 0
  211. expanding_object = -1
  212. unbound_dim_list = None
  213. dimlists = [] # Track DimList positions for later processing
  214. def check_expanding(i: int) -> None:
  215. nonlocal expanding_object
  216. if expanding_object != -1:
  217. from . import DimensionBindError
  218. raise DimensionBindError(
  219. f"at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets "
  220. f"{expanding_object} and {i}"
  221. )
  222. expanding_object = i
  223. def is_dimpack(s: Any) -> bool:
  224. from . import Dim
  225. return (
  226. isinstance(s, (tuple, list))
  227. and len(s) > 0
  228. and all(Dim.check_exact(item) for item in s)
  229. )
  230. has_dimpacks_or_none = False
  231. for i, s in enumerate(input_list):
  232. if has_dims(s):
  233. can_call_original_getitem = False
  234. dims_indexed += 1
  235. elif s is ...:
  236. check_expanding(i)
  237. elif isinstance(s, DimList):
  238. can_call_original_getitem = False
  239. if not s.is_bound:
  240. check_expanding(i)
  241. unbound_dim_list = s
  242. else:
  243. dims_indexed += len(s._dims)
  244. dimlists.append(i)
  245. elif s is None:
  246. has_dimpacks_or_none = True
  247. elif is_dimpack(s):
  248. can_call_original_getitem = False
  249. has_dimpacks_or_none = True
  250. dims_indexed += 1
  251. else:
  252. dims_indexed += 1
  253. # Early return if we can use original getitem
  254. if can_call_original_getitem:
  255. return IndexingInfo(can_call_original=True)
  256. self_info = TensorInfo.create(self, False, True)
  257. total_dims = len(self_info.levels) # Total dimensions (positional + named)
  258. if dims_indexed > total_dims:
  259. raise ValueError(
  260. f"at least {dims_indexed} indices were supplied but the tensor only has {total_dims} dimensions"
  261. )
  262. # Expand any unbound dimension list, or expand ... into individual : slices.
  263. expanding_dims = total_dims - dims_indexed
  264. if expanding_object != -1:
  265. if unbound_dim_list is not None:
  266. # Bind unbound dimension list to the expanding dimensions
  267. unbound_dim_list.bind_len(expanding_dims)
  268. else:
  269. # Expand ... into slice(None) objects
  270. no_slices = [slice(None)] * expanding_dims
  271. input_list = (
  272. input_list[:expanding_object]
  273. + no_slices
  274. + input_list[expanding_object + 1 :]
  275. )
  276. # Flatten out any dimensions stored in dimlist elements directly into the inputs
  277. # Process in reverse order to maintain indices
  278. for i in range(len(dimlists) - 1, -1, -1):
  279. idx = dimlists[i]
  280. # We added more elements to input because of ...
  281. # so we need to also adjust the index to get back to where the
  282. # dimlist existed
  283. if (
  284. unbound_dim_list is None
  285. and expanding_object != -1
  286. and idx > expanding_object
  287. ):
  288. idx += expanding_dims
  289. dl = input_list[idx]
  290. # PRIVATE here naughty
  291. input_list = input_list[:idx] + dl._dims + input_list[idx + 1 :]
  292. return getsetitem_flat(self_info, input_list, [], [], has_dimpacks_or_none)
def getsetitem_flat(
    self_info: TensorInfo,
    input_list: list,
    keys: list[DimEntry],
    values: list,
    has_dimpacks_or_none: bool,
) -> IndexingInfo:
    """
    Match a flat list of index entries against the levels of a tensor.

    Each level of ``self_info`` is paired with an index entry: the entry of
    ``values`` whose position matches the level's position in ``keys`` when
    the level is found there, otherwise the next entry of ``input_list`` for
    positional levels, otherwise the level's own dim for non-positional levels.
    ``None`` entries in ``input_list`` become new size-1, stride-0 dimensions.

    A Dim that is used exactly once becomes ``slice(None)`` and contributes a
    result level; a Dim used more than once, and every tensor index, is
    handled through advanced indexing.

    Args:
        self_info: TensorInfo of the tensor being indexed; its tensor must not be None.
        input_list: flat index entries consumed by the positional levels.
        keys: levels that are indexed explicitly.
        values: index entries parallel to ``keys``.
        has_dimpacks_or_none: when True, sizes/strides are recorded and the
            tensor is re-strided with ``as_strided`` before indexing.

    Returns:
        An IndexingInfo with ``can_call_original=False`` describing the tensor
        to index, the flat index entries and the levels of the result.

    Raises:
        RuntimeError: if ``self_info.tensor`` is None.
    """
    from . import Dim

    # Track dimension usage
    seen_dims: list[Any] = []
    seen_dims_nuses: list[int] = []

    def add_dim(dim: Any) -> None:
        # Use safe indexing to avoid triggering __torch_function__ on Dim objects
        idx = _safe_index(seen_dims, dim)
        if idx is not None:
            seen_dims_nuses[idx] += 1
        else:
            seen_dims.append(dim)
            seen_dims_nuses.append(1)

    # flat_inputs and tensor_inputs are parallel lists: for every position
    # exactly one of them holds a value and the other holds None.
    flat_inputs = []
    tensor_inputs: list[Any] = []
    device_holding_tensor = None

    def append_flat_handle(handle: Any) -> None:
        flat_inputs.append(handle)
        tensor_inputs.append(None)

    def append_tensor_input(ti: TensorInfo) -> None:
        flat_inputs.append(None)
        tensor_inputs.append(ti)
        nonlocal device_holding_tensor
        # Remember the first index tensor that has a device.
        if ti.has_device and device_holding_tensor is None:
            device_holding_tensor = ti.tensor

    # New sizes/strides, only collected when has_dimpacks_or_none is set
    # (plus the entries added for None indices and dim packs).
    nsz = []
    nsd = []
    if self_info.tensor is None:
        raise RuntimeError("Cannot get size/stride on None tensor")
    sz = self_info.tensor.size()
    sd = self_info.tensor.stride()

    def append_size(i: int) -> None:
        if has_dimpacks_or_none:
            nsz.append(sz[i])
            nsd.append(sd[i])

    # Remaining index entries; consumed from the front by rebinding to a slice.
    input_it = input_list[:]

    def parse_nones() -> None:
        # Turn every leading None entry into a new size-1, stride-0 dimension.
        nonlocal input_it
        while input_it and input_it[0] is None:
            append_flat_handle(slice(None))
            nsz.append(1)
            nsd.append(0)
            input_it = input_it[1:]

    def append_item(i: int, arg: Any) -> None:
        # Register index entry ``arg`` for tensor level ``i``.
        if Dim.check_exact(arg):
            d = arg
            # A _size of -1 marks a dim without a size yet: bind it to this level's size.
            if d._size == -1:
                d.size = sz[i]
            add_dim(d)
            append_size(i)
            append_flat_handle(arg)
            return

        info = TensorInfo.create(arg, False, False)
        if info:
            append_size(i)
            append_tensor_input(info)
            # Count every non-positional level of the index tensor as a dim use.
            for level in info.levels:
                if not level.is_positional():
                    add_dim(level.dim())
            return

        if has_dimpacks_or_none:
            if isinstance(arg, (tuple, list)) and all(Dim.check_exact(d) for d in arg):
                # dim pack
                dim_pack = list(arg)
                for d in dim_pack:
                    add_dim(d)
                    append_flat_handle(d)
                # Splits level ``i`` into one size/stride pair per dim of the pack.
                _bind_dims_to_size(sz[i], sd[i], dim_pack, nsz, nsd)
                return

        append_size(i)
        append_flat_handle(arg)

    # Match indexing expressions with tensor dimensions
    for i, level in enumerate(self_info.levels):
        # Use safe indexing to avoid triggering __torch_function__ on DimEntry comparisons
        idx = _safe_index(keys, level)
        if idx is not None:
            append_item(i, values[idx])
        else:
            if level.is_positional():
                parse_nones()
                if not input_it:
                    # No index entries left: keep this level whole.
                    append_flat_handle(slice(None))
                    append_size(i)
                else:
                    arg = input_it[0]
                    input_it = input_it[1:]
                    append_item(i, arg)
            else:
                # Non-positional levels that are not indexed explicitly keep their own dim.
                add_dim(level.dim())
                append_flat_handle(level.dim())
                append_size(i)

    # Trailing None entries after the last level still add dimensions.
    parse_nones()

    # Restride tensor if needed
    if has_dimpacks_or_none and nsz:
        if self_info.tensor is None:
            raise RuntimeError("Cannot restride None tensor")
        self_tensor = self_info.tensor.as_strided(
            nsz, nsd, self_info.tensor.storage_offset()
        )
    else:
        self_tensor = self_info.tensor

    # Determine result shape and indexing requirements
    result_levels: list[Any] = []
    index_levels = []
    tensor_insert_point = -1
    requires_getindex = False

    def mark_tensor_index() -> None:
        # The first tensor index fixes the insert point at the current number
        # of result levels; a later tensor index seen at a different number of
        # result levels resets the insert point to 0.
        nonlocal tensor_insert_point
        if tensor_insert_point == -1:
            tensor_insert_point = len(result_levels)
        elif tensor_insert_point != len(result_levels):
            tensor_insert_point = 0

    for i, inp in enumerate(flat_inputs):
        if tensor_inputs[i] is not None:
            requires_getindex = True
            mark_tensor_index()
            for level in tensor_inputs[i].levels:
                if level not in index_levels:
                    index_levels.append(level)
        elif Dim.check_exact(inp):
            d = inp
            # Use safe indexing to avoid triggering __torch_function__
            dim_idx = _safe_index(seen_dims, d)
            assert dim_idx is not None, f"Dim {d} not found in seen_dims"
            if seen_dims_nuses[dim_idx] == 1:
                # Used once: the whole level is kept and labelled with the dim.
                flat_inputs[i] = slice(None)
                result_levels.append(DimEntry(d))
            else:
                # Used more than once: index with the dim's range tensor instead.
                requires_getindex = True
                flat_inputs[i] = None
                tensor_inputs[i] = TensorInfo(
                    d._get_range(), [DimEntry(d)], False, None
                )
                if DimEntry(d) not in index_levels:
                    index_levels.append(DimEntry(d))
                mark_tensor_index()
        else:
            if inp != slice(None):
                requires_getindex = True
            # An int index produces no result level; any other entry gets a
            # placeholder positional level that is renumbered below.
            if not isinstance(inp, int):
                result_levels.append(DimEntry(-1))

    # Insert indexing dimensions at first tensor use point
    if tensor_insert_point != -1:
        for level in reversed(index_levels):
            result_levels.insert(tensor_insert_point, level)

    # Match tensors to indexing shape
    if requires_getindex:
        for i in range(len(flat_inputs)):
            if tensor_inputs[i] is not None:
                t = tensor_inputs[i].tensor
                assert t is not None, "TensorInfo should have valid tensor data"
                # Index tensors without a device are moved to the device of
                # the first index tensor that has one.
                if (
                    not tensor_inputs[i].has_device
                    and device_holding_tensor is not None
                ):
                    t = t.to(device_holding_tensor.device)
                flat_inputs[i] = _match_levels(t, tensor_inputs[i].levels, index_levels)

    # Number positional dimensions correctly
    # (counted from the end: the last positional level becomes -1, the one before -2, ...)
    seen_positionals = 0
    for i in reversed(range(len(result_levels))):
        if result_levels[i].is_positional():
            seen_positionals += 1
            result_levels[i] = DimEntry(-seen_positionals)

    return IndexingInfo(
        can_call_original=False,
        advanced_indexing=requires_getindex,
        self_tensor=self_tensor,
        flat_inputs=flat_inputs,
        result_levels=result_levels,
        has_device=self_info.has_device,
    )