decomposition.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174
  1. # mypy: allow-untyped-decorators
  2. import functools
  3. import logging
  4. import math
  5. import operator
  6. import sys
  7. import typing
  8. from typing import Any, Callable, Optional, TypeVar, Union
  9. from typing_extensions import ParamSpec, TypeAlias
  10. import torch
  11. import torch._decomp as decomp
  12. import torch._prims_common as utils
  13. import torch.ao.quantization.fx._decomposed
  14. from torch._decomp import (
  15. core_aten_decompositions,
  16. get_decompositions,
  17. remove_decompositions,
  18. )
  19. from torch._decomp.decompositions import (
  20. _grid_sampler_2d as decomp_grid_sampler_2d,
  21. _index_add,
  22. embedding_dense_backward as decomp_embedding_dense_backward,
  23. pw_cast_for_opmath,
  24. pw_cast_for_opmath_non_tensor_args,
  25. )
  26. from torch._decomp.decompositions_for_rng import extra_random_decomps
  27. from torch._dynamo.utils import counters
  28. from torch._environment import is_fbcode
  29. from torch._higher_order_ops.out_dtype import out_dtype
  30. from torch._inductor.utils import pad_listlike
  31. from torch._prims_common import (
  32. elementwise_dtypes,
  33. ELEMENTWISE_TYPE_PROMOTION_KIND,
  34. type_to_dtype,
  35. )
  36. from torch.fx.experimental.symbolic_shapes import guard_or_false, statically_known_true
  37. from . import config, inductor_prims
  38. from .utils import (
  39. is_gpu,
  40. needs_fallback_due_to_atomic_add_limitations,
  41. use_scatter_fallback,
  42. )
# Type variables that let ``register_decomposition`` keep the signature of the
# function it decorates (see its ``Callable[[Callable[_P, _T]], ...]`` return type).
_T = TypeVar("_T")
_P = ParamSpec("_P")
# Operator objects accepted by ``register_decomposition``: an ``OperatorBase``
# (e.g. an OpOverload or higher-order op) or an ``OpOverloadPacket``.
_GenericOperator: TypeAlias = Union[
    torch._ops.OperatorBase, torch._ops.OpOverloadPacket
]
log = logging.getLogger(__name__)
# Shorthands for the operator namespaces referenced throughout this module.
aten = torch.ops.aten
prims = torch.ops.prims
quantized = torch.ops.quantized
_quantized = torch.ops._quantized
quantized_decomposed = torch.ops.quantized_decomposed
# Decompositions looked up via torch._decomp.get_decompositions for these ops;
# Inductor uses them in addition to the core ATen set.
inductor_decompositions = get_decompositions(
    [
        aten._adaptive_avg_pool2d_backward,
        aten.index_select,
        aten.addmv,
        aten.arange,
        aten.bitwise_and_,
        aten.bitwise_or_,
        aten.clamp_min_,
        aten.dist,
        aten.elu,
        aten.empty_like,
        aten.flip,
        aten.gelu,
        aten.hardtanh,
        aten.lcm,
        aten.leaky_relu,
        aten.linalg_vector_norm,
        aten._log_softmax,
        aten.max_pool2d_with_indices_backward,
        aten._native_batch_norm_legit,
        aten._native_batch_norm_legit_functional,
        aten._native_batch_norm_legit_no_training,
        aten._batch_norm_with_update,
        aten._batch_norm_with_update_functional,
        aten._batch_norm_no_update,
        aten.batch_norm_backward,
        aten.native_batch_norm,
        aten.native_group_norm,
        aten.native_layer_norm,
        aten.nll_loss2d_backward,
        aten.permute_copy,
        aten.rrelu_with_noise_backward,
        aten._softmax,
        aten.sin_,
        aten.sqrt_,
        out_dtype,
        aten._to_copy,
        aten.tril_indices,
        aten.triu_indices,
        aten.unbind_copy.int,
        aten.upsample_bilinear2d.vec,
        quantized.linear_dynamic_fp16_unpacked_weight,
        _quantized.wrapped_quantized_linear,
    ]
)
# Inductor's decomposition table: the core ATen decompositions merged with the
# ones above. On a key collision the ``inductor_decompositions`` entry wins
# because it is unpacked last. ``register_decomposition`` adds to this dict.
decompositions = {**core_aten_decompositions(), **inductor_decompositions}
# Remove unwanted decompositions included via the core ATen decompositions from
# the Inductor decomp table.
decomps_to_exclude: list[Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]] = [
    aten._unsafe_index,
    aten._unsafe_masked_index,
    aten._unsafe_masked_index_put_accumulate,
    aten._scaled_dot_product_flash_attention_for_cpu.default,  # See comments in torch/_decomp/decompositions.py
    aten._softmax_backward_data,
    aten.clamp_max,
    aten.clamp_min,
    aten.embedding_dense_backward,  # we fall back on xpu
    aten.index_add,  # we conditionally call this decomp
    aten.glu,  # inductor lowers this directly
    aten.select_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
    aten.slice_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
    aten.split.Tensor,  # inductor lowers this directly
    aten.squeeze,  # inductor lowers this directly
    aten.sum,  # inductor lowers this directly
    aten.unbind,  # inductor lowers this directly
    aten.baddbmm,  # upcasts to fp32, perf issue
]
remove_decompositions(decompositions, decomps_to_exclude)
  123. def register_decomposition(
  124. ops: Union[_GenericOperator, list[_GenericOperator]],
  125. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  126. for op in ops if isinstance(ops, list) else [ops]:
  127. if op in decompositions:
  128. log.warning("duplicate decomp: %s", ops)
  129. return decomp.register_decomposition(ops, decompositions)
  130. @register_decomposition([aten.embedding_dense_backward])
  131. def _embedding_dense_backward(
  132. grad_output: torch.Tensor,
  133. indices: torch.Tensor,
  134. num_weights: int,
  135. padding_idx: int,
  136. scale_grad_by_freq: bool,
  137. ) -> torch.Tensor:
  138. # TODO: check if XE4 still need this fallback
  139. # check torch.xpu.get_device_properties(grad_output.device).architecture
  140. if grad_output.is_xpu:
  141. return NotImplemented
  142. # We can write a util function to update decomp table if we have more ops to fallback.
  143. return decomp_embedding_dense_backward(
  144. grad_output, indices, num_weights, padding_idx, scale_grad_by_freq
  145. )
  146. @register_decomposition([aten.sym_constrain_range_for_size.default])
  147. def sym_constrain_range_for_size(
  148. symbol: torch.SymInt,
  149. *,
  150. min: Optional[torch.types.Number] = None,
  151. max: Optional[torch.types.Number] = None,
  152. ) -> None:
  153. return
  154. @register_decomposition([aten.clamp])
  155. @pw_cast_for_opmath_non_tensor_args
  156. def clamp(
  157. x: torch.Tensor,
  158. min: Optional[torch.types.Number] = None,
  159. max: Optional[torch.types.Number] = None,
  160. ) -> torch.Tensor:
  161. if min is not None:
  162. x = x.clamp_min(min)
  163. if max is not None:
  164. x = x.clamp_max(max)
  165. return x
  166. @register_decomposition([aten.full])
  167. def full(
  168. size: list[Union[int, torch.SymInt]],
  169. fill_value: torch.types.Number,
  170. **kwargs: Any,
  171. ) -> torch.Tensor:
  172. dtype = kwargs.get("dtype")
  173. if dtype is None:
  174. kwargs["dtype"] = type_to_dtype(type(fill_value))
  175. return torch.full(size, fill_value, **kwargs)
  176. return NotImplemented
  177. @register_decomposition([aten.index_add])
  178. def index_add(
  179. x: torch.Tensor,
  180. dim: int,
  181. index: torch.Tensor,
  182. tensor: torch.Tensor,
  183. *,
  184. alpha: torch.types.Number = 1,
  185. ) -> torch.Tensor:
  186. # If we are not in fbcode and dtype is bfloat16
  187. # fallback to index_add kernel
  188. # see https://github.com/pytorch/pytorch/issues/137425 for details
  189. if not is_fbcode() and x.dtype == torch.bfloat16:
  190. return NotImplemented
  191. else:
  192. return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha)
  193. # Not really sure how to put this into the main library. PrimTorch wants
  194. # empty_permuted to go to the prim, and typically users don't really want
  195. # to decompose to empty_strided (but inductor is OK with it, because we are
  196. # cool with strides and everything goes to empty_strided)
  197. @register_decomposition([aten.empty_permuted.default])
  198. def empty_permuted(
  199. size: list[Union[int, torch.SymInt]],
  200. physical_layout: list[int],
  201. **kwargs: Any,
  202. ) -> torch.Tensor:
  203. perm = [0] * len(size)
  204. for p, l in enumerate(physical_layout):
  205. perm[l] = p
  206. return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm)
  207. @register_decomposition([aten.convolution_backward])
  208. def convolution_backward(
  209. grad_output: torch.Tensor,
  210. input: torch.Tensor,
  211. weight: torch.Tensor,
  212. bias_sizes: list[int],
  213. stride: Union[int, list[int]],
  214. padding: Union[int, list[int]],
  215. dilation: Union[int, list[int]],
  216. transposed: bool,
  217. output_padding: list[int],
  218. groups: int,
  219. output_mask: list[bool],
  220. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  221. if not output_mask[2] or not is_gpu(grad_output.device.type):
  222. return NotImplemented
  223. grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))
  224. grad_inp, grad_weight, _ = aten.convolution_backward(
  225. grad_output,
  226. input,
  227. weight,
  228. bias_sizes,
  229. stride,
  230. padding,
  231. dilation,
  232. transposed,
  233. output_padding,
  234. groups,
  235. [output_mask[0], output_mask[1], False],
  236. )
  237. return (grad_inp, grad_weight, grad_bias)
  238. @register_decomposition([aten.round.decimals])
  239. def round_dec(x: torch.Tensor, decimals: int = 0) -> torch.Tensor:
  240. ten_pow_decimals = 10.0**decimals
  241. return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals)
  242. @register_decomposition([aten.bmm])
  243. @pw_cast_for_opmath
  244. def bmm(
  245. self: torch.Tensor,
  246. batch2: torch.Tensor,
  247. out_dtype: Optional[torch.dtype] = None,
  248. ) -> torch.Tensor:
  249. # TODO: Re-enable for mps once our reductions are performant enough
  250. # (https://github.com/pytorch/pytorch/issues/150121)
  251. if config.coordinate_descent_tuning and self.device.type not in ["cpu", "mps"]:
  252. if statically_known_true(self.shape[1] == 1) or statically_known_true(
  253. batch2.shape[2] == 1
  254. ):
  255. out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2)
  256. return out
  257. if self.device.type == "cpu":
  258. if statically_known_true(self.size(1) == 1) and statically_known_true(
  259. batch2.size(-1) == 1
  260. ):
  261. counters["inductor"]["decompose_bmm"] += 1
  262. return torch.sum(
  263. self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True
  264. ).unsqueeze(1)
  265. return NotImplemented
  266. @register_decomposition([aten.addmm])
  267. @pw_cast_for_opmath
  268. def addmm(
  269. self: torch.Tensor,
  270. mat1: torch.Tensor,
  271. mat2: torch.Tensor,
  272. out_dtype: Optional[torch.dtype] = None,
  273. beta: torch.types.Number = 1,
  274. alpha: torch.types.Number = 1,
  275. ) -> torch.Tensor:
  276. if self.device.type == "cpu":
  277. if statically_known_true(mat1.size(0) == 1) and statically_known_true(
  278. mat2.size(-1) == 1
  279. ):
  280. counters["inductor"]["decompose_addmm"] += 1
  281. out = torch.sum(
  282. mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True
  283. ).unsqueeze(0)
  284. return alpha * out + beta * self
  285. if (
  286. statically_known_true(mat1.size(0) == 1)
  287. and guard_or_false(mat2.size(0) <= 16)
  288. and guard_or_false(mat2.size(1) <= 16)
  289. ):
  290. counters["inductor"]["decompose_addmm"] += 1
  291. out = (mat1.T * mat2).sum(dim=0, keepdim=True)
  292. return alpha * out + beta * self
  293. return NotImplemented
  294. @register_decomposition([aten.mm])
  295. @pw_cast_for_opmath
  296. def mm(
  297. self: torch.Tensor,
  298. input2: torch.Tensor,
  299. out_dtype: Optional[torch.dtype] = None,
  300. ) -> torch.Tensor:
  301. # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
  302. # todo: Look into why and fix it (hopefully)
  303. # TODO: Re-enable for mps once our reductions are performant enough
  304. # (https://github.com/pytorch/pytorch/issues/150121)
  305. if config.coordinate_descent_tuning and self.device.type not in ["cpu", "mps"]:
  306. if statically_known_true(self.shape[0] == 1) or statically_known_true(
  307. input2.shape[1] == 1
  308. ):
  309. return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)
  310. if self.device.type == "cpu":
  311. if (
  312. statically_known_true(self.size(-1) == 1)
  313. and statically_known_true(self.size(0) > 0)
  314. and statically_known_true(input2.size(0) == 1)
  315. and (self.dtype == input2.dtype)
  316. and guard_or_false((torch.numel(self) + torch.numel(input2)) <= 32)
  317. ):
  318. counters["inductor"]["decompose_mm"] += 1
  319. return self * input2
  320. if statically_known_true(self.size(0) == 1) and statically_known_true(
  321. input2.size(-1) == 1
  322. ):
  323. counters["inductor"]["decompose_mm"] += 1
  324. return torch.sum(
  325. self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True
  326. ).unsqueeze(0)
  327. return NotImplemented
  328. # This pass does two things:
  329. # - Eliminate cat when there is only one tensor input
  330. # - Normalize cat calls, so that legacy empty 1-D tensors are removed (NB: we
  331. # don't remove ALL empty tensors, only the naughty ones)
  332. @register_decomposition([aten.cat.default])
  333. def cat(
  334. tensors: list[torch.Tensor],
  335. dim: int = 0,
  336. ) -> torch.Tensor:
  337. def non_empty_tensor(x: torch.Tensor) -> bool:
  338. # For better or worse, this is a valid cat:
  339. #
  340. # torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)])
  341. #
  342. # We'd like to eliminate naughtiness like this for downstream passes
  343. # like split_cat. The easiest way is to just drop such inputs
  344. # (guarding that they are non-zero).
  345. #
  346. # Is it permissible for this filtering to be size-oblivious? A case
  347. # where this could matter is cat([(2, 2), (u0,)], dim=0); if u0
  348. # happened to be zero, we would have liked to have filtered it out.
  349. # But actually, the ONLY way this could have passed is if u0 == 0,
  350. # so by the time we get here we have already installed a deferred
  351. # runtime assert forcing u0 to be zero. So if this hasn't happened,
  352. # we know that the unbacked SymInt has appropriate size and there are
  353. # no problems.
  354. if len(x.shape) == 1 and guard_or_false(x.shape[0] == 0):
  355. return False
  356. if dim < len(x.shape) and guard_or_false(x.shape[dim] == 0):
  357. return False
  358. return True
  359. filtered_tensors = list(filter(non_empty_tensor, tensors))
  360. if len(filtered_tensors) == 1:
  361. # check dtype promotion
  362. promoted_dtype = elementwise_dtypes(
  363. *tensors,
  364. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  365. )[1]
  366. filtered_t = filtered_tensors[0]
  367. return (
  368. filtered_t.clone()
  369. if promoted_dtype == filtered_t.dtype
  370. else filtered_t.to(dtype=promoted_dtype)
  371. )
  372. elif 1 < len(filtered_tensors) < len(tensors):
  373. # on the first call, when we remove empty tensors, we redispatch recursively
  374. return aten.cat.default(filtered_tensors, dim)
  375. # optimization, avoid concat for single, repeated input
  376. if len(filtered_tensors) > 1 and all(
  377. t is filtered_tensors[0] for t in filtered_tensors
  378. ):
  379. inp = filtered_tensors[0]
  380. shape = list(inp.shape)
  381. dim = dim + len(inp.shape) if dim < 0 else dim
  382. shape.insert(dim, len(filtered_tensors))
  383. return inp.unsqueeze(dim).expand(*shape).flatten(dim, dim + 1).clone()
  384. # when no 'filtering' has occurred, we raise to prevent infinite recursion (no more decomposition needed)
  385. return NotImplemented
  386. @register_decomposition([aten.angle])
  387. def angle(x: torch.Tensor) -> torch.Tensor:
  388. if x.is_complex():
  389. return torch.where(
  390. torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real)
  391. )
  392. # when x is real number
  393. # if x >= 0, return 0
  394. # if x < 0, return pi
  395. # if x is nan, return nan
  396. _, dtype = elementwise_dtypes(
  397. x,
  398. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  399. )
  400. pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device)
  401. ret = torch.where(x < 0, pi, 0.0)
  402. return torch.where(torch.isnan(x), float("nan"), ret)
  403. @register_decomposition([aten.add])
  404. def add(
  405. x: torch.Tensor,
  406. y: torch.Tensor,
  407. *,
  408. alpha: Optional[torch.types.Number] = None,
  409. ) -> torch.Tensor:
  410. # Require both x and y to be complex tensors.
  411. x_is_complex_tensor = torch.is_tensor(x) and x.is_complex()
  412. y_is_complex_tensor = torch.is_tensor(y) and y.is_complex()
  413. if not x_is_complex_tensor or not y_is_complex_tensor:
  414. return NotImplemented
  415. output_size_zero = False
  416. if x.ndim == 0 and y.ndim == 0:
  417. output_size_zero = True
  418. if x.ndim == 0:
  419. x = x.reshape(1)
  420. if y.ndim == 0:
  421. y = y.reshape(1)
  422. z = y
  423. if alpha is not None:
  424. z = alpha * y
  425. complex_type = torch.promote_types(x.dtype, y.dtype)
  426. # For complex typed `x`, `x.view(x.real.dtype)` doubles the last dimension and can cause problem
  427. # when broadcasting the add.
  428. def reshape_tensor_complex(tensor: torch.Tensor) -> torch.Tensor:
  429. """Reshape tensor from [*initial_dims, last_dim] to *initial_dims, last_dim/2, 2]"""
  430. # Get the current shape of the tensor
  431. *initial_dims, last_dim = tensor.shape
  432. # Check if the last dimension is even. We should never reach here since `x.view(x.real.dtype)`
  433. # doubles the last dimension for complex numbers.
  434. if last_dim % 2 != 0:
  435. raise AssertionError(
  436. "The size of the last dimension must be even to reshape it to [..., last_dim/2, 2]"
  437. )
  438. # Reshape the tensor
  439. new_shape = (*initial_dims, last_dim // 2, 2)
  440. reshaped_tensor = tensor.view(new_shape)
  441. return reshaped_tensor
  442. # Manually resolve complex tensors, as .is_conj() is unreliable after cloning during compilation.
  443. x = x + 0
  444. z = z + 0
  445. x_reshaped = reshape_tensor_complex(x.view(x.real.dtype))
  446. z_reshaped = reshape_tensor_complex(z.view(y.real.dtype))
  447. result = torch.flatten(x_reshaped + z_reshaped, start_dim=-2).view(complex_type)
  448. if output_size_zero:
  449. return result[0]
  450. return result
  451. @register_decomposition([aten.conj_physical])
  452. def conj_physical(self: torch.Tensor) -> torch.Tensor:
  453. if self.is_complex():
  454. return NotImplemented
  455. return self
  456. @register_decomposition([aten.lift, aten.detach_])
  457. def lift(self: torch.Tensor) -> torch.Tensor:
  458. return self
  459. @register_decomposition([aten.fmin, prims.fmin])
  460. def fmin(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
  461. return torch.where(torch.isnan(other) | (other > self), self, other)
  462. @register_decomposition([aten.fmax, prims.fmax])
  463. def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
  464. return torch.where(torch.isnan(other) | (other < self), self, other)
  465. @register_decomposition(aten.amax)
  466. def amax(
  467. self: torch.Tensor,
  468. dim: Optional[int] = None,
  469. keepdim: bool = False,
  470. ) -> torch.Tensor:
  471. if self.dtype == torch.bool:
  472. return torch.any(self, dim=dim, keepdim=keepdim)
  473. return NotImplemented
  474. @register_decomposition(aten.amin)
  475. def amin(
  476. self: torch.Tensor,
  477. dim: Optional[int] = None,
  478. keepdim: bool = False,
  479. ) -> torch.Tensor:
  480. if self.dtype == torch.bool:
  481. return torch.all(self, dim=dim, keepdim=keepdim)
  482. return NotImplemented
  483. @register_decomposition([aten.narrow_copy])
  484. def narrow_copy(
  485. self: torch.Tensor,
  486. dim: int,
  487. start: int,
  488. length: int,
  489. ) -> torch.Tensor:
  490. return torch.narrow(self, dim, start, length).clone()
  491. @register_decomposition([aten.view_copy.default])
  492. def view_copy_default(
  493. self: torch.Tensor,
  494. size: list[Union[int, torch.SymInt]],
  495. ) -> torch.Tensor:
  496. return aten.view(self, size).clone()
  497. @register_decomposition([aten.view_copy.dtype])
  498. def view_copy_dtype(
  499. self: torch.Tensor,
  500. dtype: torch.dtype,
  501. ) -> torch.Tensor:
  502. return self.to(dtype).clone()
  503. def _get_shape_permutation_like(
  504. self: torch.Tensor,
  505. ) -> tuple[utils.ShapeType, utils.StrideType]:
  506. physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(self)
  507. shape = [self.shape[l] for l in physical_layout]
  508. permutation = [0] * len(shape)
  509. for p, l in enumerate(physical_layout):
  510. permutation[l] = p
  511. return (shape, permutation)
  512. @register_decomposition(aten.full_like)
  513. def full_like(
  514. self: torch.Tensor,
  515. fill_value: Union[int, float],
  516. *,
  517. dtype: Optional[torch.dtype] = None,
  518. layout: Optional[torch.layout] = None,
  519. device: Optional[torch.device] = None,
  520. pin_memory: bool = False,
  521. requires_grad: bool = False,
  522. memory_format: torch.memory_format = torch.preserve_format,
  523. ) -> torch.Tensor:
  524. dtype = self.dtype if dtype is None else dtype
  525. layout = self.layout if layout is None else layout
  526. device = self.device if device is None else device
  527. if memory_format != torch.preserve_format:
  528. result = torch.full(
  529. self.shape,
  530. fill_value,
  531. dtype=dtype,
  532. layout=layout,
  533. device=device,
  534. pin_memory=pin_memory,
  535. requires_grad=requires_grad,
  536. )
  537. return result.to(memory_format=memory_format)
  538. else:
  539. assert layout == torch.strided
  540. shape, permutation = _get_shape_permutation_like(self)
  541. result = torch.full(
  542. shape,
  543. fill_value,
  544. dtype=dtype,
  545. layout=layout,
  546. device=device,
  547. pin_memory=pin_memory,
  548. requires_grad=requires_grad,
  549. )
  550. if permutation == list(range(len(permutation))):
  551. return result
  552. return result.permute(permutation).clone()
  553. def _rand_like(
  554. rand_fn: Callable[..., torch.Tensor],
  555. self: torch.Tensor,
  556. *,
  557. dtype: Optional[torch.dtype] = None,
  558. device: Optional[torch.device] = None,
  559. memory_format: torch.memory_format = torch.preserve_format,
  560. **kwargs: Any,
  561. ) -> torch.Tensor:
  562. dtype = self.dtype if dtype is None else dtype
  563. device = self.device if device is None else device
  564. if memory_format != torch.preserve_format:
  565. return rand_fn(
  566. self.shape,
  567. dtype=dtype,
  568. device=device,
  569. **kwargs,
  570. ).to(memory_format=memory_format)
  571. shape, permutation = _get_shape_permutation_like(self)
  572. result = rand_fn(
  573. shape,
  574. dtype=dtype,
  575. device=device,
  576. **kwargs,
  577. )
  578. if permutation == list(range(len(permutation))):
  579. return result
  580. return result.permute(permutation).clone()
  581. @register_decomposition(aten.rand_like)
  582. def rand_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor:
  583. return _rand_like(torch.rand, self, **kwargs)
  584. @register_decomposition(aten.randn_like)
  585. def randn_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor:
  586. return _rand_like(torch.randn, self, **kwargs)
  587. @register_decomposition(aten.randint_like.default)
  588. def randint_like(self: torch.Tensor, high: int, **kwargs: Any) -> torch.Tensor:
  589. return _rand_like(functools.partial(aten.randint.low, 0, high), self, **kwargs)
  590. @register_decomposition(aten.randint_like.low_dtype)
  591. def randint_like_low(
  592. self: torch.Tensor, low: int, high: int, **kwargs: Any
  593. ) -> torch.Tensor:
  594. return _rand_like(functools.partial(aten.randint.low, low, high), self, **kwargs)
  595. @register_decomposition(aten.randint.default)
  596. def randint(
  597. high: int,
  598. size: list[Union[int, torch.SymInt]],
  599. **kwargs: Any,
  600. ) -> torch.Tensor:
  601. return aten.randint.low(0, high, size, **kwargs)
  602. @register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default)
  603. def linear_dynamic_fp16_unpacked_weight(
  604. input: torch.Tensor,
  605. weight: torch.Tensor,
  606. bias: Optional[torch.Tensor] = None,
  607. ) -> torch.Tensor:
  608. packed_weight = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(weight)
  609. return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(
  610. input, packed_weight, bias, weight.size()[0]
  611. )
  612. @register_decomposition(_quantized.wrapped_quantized_linear.default)
  613. def wrapped_quantized_linear(
  614. input: torch.Tensor,
  615. input_scale: torch.Tensor,
  616. input_zero_point: torch.Tensor,
  617. weight: torch.Tensor,
  618. weight_scale: torch.Tensor,
  619. weight_zero_point: torch.Tensor,
  620. bias: torch.Tensor,
  621. out_scale: torch.Tensor,
  622. out_zero_point: torch.Tensor,
  623. out_channel: int,
  624. ) -> torch.Tensor:
  625. packed_weight = torch.ops._quantized._wrapped_linear_prepack(
  626. weight, weight_scale, weight_zero_point, bias
  627. )
  628. return torch.ops._quantized._wrapped_quantized_linear_prepacked(
  629. input,
  630. input_scale,
  631. input_zero_point,
  632. packed_weight,
  633. out_scale,
  634. out_zero_point,
  635. out_channel,
  636. )
  637. @register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
  638. def q_embedding_bag_byte_unpack_decomp(packed: torch.Tensor) -> torch.Tensor:
  639. def bitcast_u8_to_f32(u8: torch.Tensor) -> torch.Tensor:
  640. x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3))
  641. if sys.byteorder == "little":
  642. return (x + (y << 8) + (z << 16) + (w << 24)).view(torch.float32)[..., None]
  643. else:
  644. return ((x << 24) + (y << 16) + (z << 8) + w).view(torch.float32)[..., None]
  645. scales = bitcast_u8_to_f32(packed[..., -8:-4])
  646. offsets = bitcast_u8_to_f32(packed[..., -4:])
  647. return packed[..., :-8].to(torch.float32) * scales + offsets
  648. @register_decomposition([aten.grid_sampler_2d])
  649. @pw_cast_for_opmath
  650. def grid_sampler_2d(
  651. a: torch.Tensor,
  652. grid: torch.Tensor,
  653. interpolation_mode: int = 0,
  654. padding_mode: int = 0,
  655. align_corners: bool = False,
  656. ) -> torch.Tensor:
  657. # We do not expand the grid (_expand_grid=False) on cpu for performance reasons
  658. # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x
  659. # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2)
  660. # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first.
  661. # Thus we apply this hack to not expand the grid for this case.
  662. _expand_grid = not (
  663. a.device == torch.device("cpu")
  664. and interpolation_mode == 0
  665. and a.is_contiguous(memory_format=torch.contiguous_format)
  666. )
  667. output = decomp_grid_sampler_2d(
  668. a,
  669. grid=grid,
  670. interpolation_mode=interpolation_mode,
  671. padding_mode=padding_mode,
  672. align_corners=align_corners,
  673. _expand_grid=_expand_grid,
  674. )
  675. return output
  676. @register_decomposition(aten._foreach_addcmul.Scalar)
  677. def _foreach_addcmul_scalar(
  678. self: list[torch.Tensor],
  679. left_tensors: list[torch.Tensor],
  680. right_tensors: list[torch.Tensor],
  681. scalar: float = 1,
  682. ) -> list[torch.Tensor]:
  683. return aten._foreach_add.List(
  684. self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar
  685. )
  686. @register_decomposition(aten._foreach_addcdiv.Scalar)
  687. def _foreach_addcdiv_scalar(
  688. self: list[torch.Tensor],
  689. left_tensors: list[torch.Tensor],
  690. right_tensors: list[torch.Tensor],
  691. scalar: float = 1,
  692. ) -> list[torch.Tensor]:
  693. return aten._foreach_add.List(
  694. self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar
  695. )
  696. @register_decomposition(aten._foreach_lerp.Scalar)
  697. def _foreach_lerp_scalar(
  698. start_tensors: list[torch.Tensor],
  699. end_tensors: list[torch.Tensor],
  700. weight: torch.types.Number,
  701. ) -> list[torch.Tensor]:
  702. return aten._foreach_add.List(
  703. start_tensors,
  704. aten._foreach_mul.Scalar(
  705. aten._foreach_sub.List(end_tensors, start_tensors), weight
  706. ),
  707. )
  708. @register_decomposition(aten._foreach_lerp.ScalarList)
  709. def _foreach_lerp_scalarlist(
  710. start_tensors: list[torch.Tensor],
  711. end_tensors: list[torch.Tensor],
  712. scalars: list[torch.types.Number],
  713. ) -> list[torch.Tensor]:
  714. return aten._foreach_add.List(
  715. start_tensors,
  716. aten._foreach_mul.ScalarList(
  717. aten._foreach_sub.List(end_tensors, start_tensors), scalars
  718. ),
  719. )
  720. @aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
  721. @register_decomposition(aten.miopen_batch_norm)
  722. def miopen_batch_norm(
  723. input: torch.Tensor,
  724. weight: torch.Tensor,
  725. bias: typing.Optional[torch.Tensor],
  726. running_mean: typing.Optional[torch.Tensor],
  727. running_var: typing.Optional[torch.Tensor],
  728. training: bool,
  729. exponential_average_factor: float,
  730. epsilon: float,
  731. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  732. a, b, c = aten.native_batch_norm(
  733. input,
  734. weight,
  735. bias,
  736. running_mean,
  737. running_var,
  738. training,
  739. exponential_average_factor,
  740. epsilon,
  741. )
  742. if training:
  743. return (a, b, c)
  744. return (
  745. a,
  746. weight.new_zeros((0,)),
  747. weight.new_zeros((0,)),
  748. )
  749. @functools.cache
  750. def fast_random_decomps() -> dict[Any, Callable[..., Any]]:
  751. return {**decompositions, **extra_random_decomps}
  752. # TODO(aakhundov): replace this (and the above) Any by more
  753. # specific type and fix all the cascading mypy errors
  754. def select_decomp_table() -> dict[Any, Callable[..., Any]]:
  755. """decomps can change based on config"""
  756. if config.fallback_random:
  757. return decompositions
  758. return fast_random_decomps()
  759. @register_decomposition(aten.masked_scatter)
  760. def masked_scatter(
  761. self: torch.Tensor,
  762. mask: torch.Tensor,
  763. source: torch.Tensor,
  764. ) -> torch.Tensor:
  765. from .codegen.common import BackendFeature, has_backend_feature
  766. if has_backend_feature(self.device, BackendFeature.MASKED_SCATTER_WITH_INDEX):
  767. # This two-step algorithm is the same as eager CUDA, for eager CPU we
  768. # use a 1-shot serial iteration.
  769. self, mask = aten.broadcast_tensors([self, mask])
  770. source_idx = mask.reshape(-1).cumsum(0) - 1
  771. self_flat, mask_flat, source_flat = (x.flatten() for x in (self, mask, source))
  772. result = aten._unsafe_masked_index(source_flat, mask_flat, [source_idx], 0)
  773. return torch.where(mask_flat, result, self_flat).view(self.shape)
  774. return NotImplemented
  775. @register_decomposition(quantized_decomposed.choose_qparams.tensor)
  776. def choose_qparams_tensor(
  777. input: torch.Tensor,
  778. quant_min: int,
  779. quant_max: int,
  780. eps: float,
  781. dtype: torch.dtype,
  782. ) -> tuple[torch.Tensor, torch.Tensor]:
  783. min_val, max_val = torch.aminmax(input)
  784. scale = (max_val - min_val) / float(quant_max - quant_min)
  785. scale = torch.max(scale, torch.Tensor([eps]))
  786. zero_point = quant_min - torch.round(min_val / scale).to(torch.int)
  787. zero_point = torch.clamp(zero_point, quant_min, quant_max)
  788. return scale.to(torch.float64), zero_point.to(torch.int64)
  789. @register_decomposition(aten.put)
  790. def put(
  791. self: torch.Tensor,
  792. index: torch.Tensor,
  793. source: torch.Tensor,
  794. accumulate: bool = False,
  795. ) -> torch.Tensor:
  796. flattened = self.flatten()
  797. flattened = torch.index_put(
  798. flattened, [index], source.reshape(index.shape), accumulate
  799. )
  800. return flattened.reshape(self.shape)
  801. @register_decomposition(aten.put_)
  802. def put_(
  803. self: torch.Tensor,
  804. index: torch.Tensor,
  805. source: torch.Tensor,
  806. accumulate: bool = False,
  807. ) -> torch.Tensor:
  808. out = aten.put(self, index, source, accumulate=accumulate)
  809. return self.copy_(out)
  810. @register_decomposition(aten._softmax_backward_data.default)
  811. @pw_cast_for_opmath
  812. def _softmax_backward_data(
  813. grad_output: torch.Tensor,
  814. output: torch.Tensor,
  815. dim: int,
  816. input_dtype: torch.dtype,
  817. ) -> torch.Tensor:
  818. new_grad_output = grad_output * output
  819. sum_new_grad = torch.sum(new_grad_output, dim=dim, keepdim=True)
  820. # grad_input = new_grad_output - output * sum_new_grad
  821. grad_input = inductor_prims.fma(-output, sum_new_grad, new_grad_output)
  822. # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor
  823. # if grad_output.device == torch.device("cpu"):
  824. # return grad_input.contiguous()
  825. if grad_output.dtype != input_dtype:
  826. grad_input = grad_input.to(input_dtype)
  827. return grad_input.contiguous()
  828. @register_decomposition(aten.index_reduce)
  829. def index_reduce(
  830. self: torch.Tensor,
  831. dim: int,
  832. index: torch.Tensor,
  833. src: torch.Tensor,
  834. reduction_type: str,
  835. *,
  836. include_self: bool = True,
  837. ) -> torch.Tensor:
  838. if reduction_type == "mean" and not needs_fallback_due_to_atomic_add_limitations(
  839. self.dtype
  840. ):
  841. true_division = self.dtype.is_floating_point or self.dtype.is_complex
  842. ones = torch.ones_like(src)
  843. if include_self:
  844. out = self
  845. counts = torch.ones_like(self).index_add(dim, index, ones)
  846. else:
  847. out = self.index_fill(dim, index, 0)
  848. counts = torch.zeros_like(self).index_add(dim, index, ones)
  849. counts = counts.masked_fill(counts < 1, 1)
  850. out = out.index_add(dim, index, src)
  851. return out / counts if true_division else out // counts
  852. if use_scatter_fallback(
  853. aten.scatter_reduce_.two,
  854. reduction_type,
  855. self.dtype,
  856. src.dtype,
  857. src.device.type,
  858. True,
  859. ):
  860. return NotImplemented
  861. repeats = self.shape[dim + 1 :].numel() * self.shape[:dim].numel()
  862. index_shape = (index.numel(), *self.shape[dim + 1 :], *self.shape[:dim])
  863. perm = (*range(self.ndim - dim, self.ndim), 0, *range(1, self.ndim - dim))
  864. scatter_index = (
  865. index.to(torch.int64)
  866. .repeat_interleave(repeats)
  867. .reshape(index_shape)
  868. .permute(perm)
  869. )
  870. return self.scatter_reduce(
  871. dim,
  872. scatter_index,
  873. src,
  874. reduction_type,
  875. include_self=include_self,
  876. )
  877. def _max_pool_with_indices(
  878. x: torch.Tensor,
  879. kernel_size: list[int],
  880. stride: Optional[Union[int, list[int]]],
  881. padding: Union[int, list[int]],
  882. dilation: Union[int, list[int]],
  883. ceil_mode: bool,
  884. dim: int,
  885. ) -> tuple[torch.Tensor, torch.Tensor]:
  886. if dilation == 1:
  887. dilation = [1] * dim
  888. if padding == 0:
  889. padding = [0] * dim
  890. if not stride:
  891. stride = kernel_size
  892. kernel_size = pad_listlike(kernel_size, dim)
  893. dilation = pad_listlike(dilation, dim)
  894. padding = pad_listlike(padding, dim)
  895. stride = pad_listlike(stride, dim)
  896. window_size = functools.reduce(operator.mul, kernel_size)
  897. # We fallback when using non-default dilation or when the window size is too large
  898. if (
  899. torch._inductor.lowering.should_fallback_max_pool_with_indices(
  900. kernel_size, n_dim=dim
  901. )
  902. or window_size > torch.iinfo(torch.int8).max
  903. ):
  904. return NotImplemented
  905. vals, offsets = prims._low_memory_max_pool_with_offsets(
  906. x,
  907. kernel_size,
  908. stride,
  909. padding,
  910. dilation,
  911. ceil_mode,
  912. )
  913. indices = prims._low_memory_max_pool_offsets_to_indices(
  914. offsets,
  915. kernel_size,
  916. x.shape[-dim:],
  917. stride,
  918. padding,
  919. dilation,
  920. )
  921. return vals, indices
  922. @register_decomposition(aten.max_pool2d_with_indices)
  923. def max_pool2d_with_indices(
  924. x: torch.Tensor,
  925. kernel_size: list[int],
  926. stride: Optional[Union[int, list[int]]] = None,
  927. padding: Union[int, list[int]] = 0,
  928. dilation: Union[int, list[int]] = 1,
  929. ceil_mode: bool = False,
  930. ) -> tuple[torch.Tensor, torch.Tensor]:
  931. return _max_pool_with_indices(
  932. x, kernel_size, stride, padding, dilation, ceil_mode, dim=2
  933. )
  934. @register_decomposition(aten.max_pool3d_with_indices)
  935. def max_pool3d_with_indices(
  936. x: torch.Tensor,
  937. kernel_size: list[int],
  938. stride: Optional[Union[int, list[int]]] = None,
  939. padding: Union[int, list[int]] = 0,
  940. dilation: Union[int, list[int]] = 1,
  941. ceil_mode: bool = False,
  942. ) -> tuple[torch.Tensor, torch.Tensor]:
  943. return _max_pool_with_indices(
  944. x, kernel_size, stride, padding, dilation, ceil_mode, dim=3
  945. )
  946. @register_decomposition(aten.adaptive_max_pool2d)
  947. def adaptive_max_pool2d(
  948. x: torch.Tensor, output_size: list[int]
  949. ) -> tuple[torch.Tensor, torch.Tensor]:
  950. *batch, h_in, w_in = x.shape
  951. h_out, w_out = output_size
  952. if h_out == 0 or w_out == 0:
  953. o_size = [*batch, h_out, w_out]
  954. return x.new_empty(o_size), x.new_empty(o_size, dtype=torch.int64)
  955. if h_in % h_out == 0 and w_in % w_out == 0:
  956. kernel_size = [h_in // h_out, w_in // w_out]
  957. return aten.max_pool2d_with_indices(x, kernel_size)
  958. return NotImplemented
  959. @register_decomposition(aten.searchsorted.Scalar)
  960. def searchsorted_scalar(
  961. sorted_sequence: torch.Tensor,
  962. self: torch.types.Number,
  963. *,
  964. out_int32: bool = False,
  965. right: bool = False,
  966. side: Optional[str] = None,
  967. sorter: Optional[torch.Tensor] = None,
  968. ) -> torch.Tensor:
  969. return aten.searchsorted(
  970. sorted_sequence,
  971. torch.tensor([self], device=sorted_sequence.device),
  972. out_int32=out_int32,
  973. right=right,
  974. side=side,
  975. sorter=sorter,
  976. )[0]
  977. @register_decomposition(aten.rrelu_with_noise_functional)
  978. def rrelu_with_noise_functional(
  979. self: torch.Tensor,
  980. noise: torch.Tensor,
  981. lower: float = 0.125,
  982. upper: float = 0.3333333333333333,
  983. training: bool = False,
  984. generator: Optional[torch.Generator] = None,
  985. ) -> tuple[torch.Tensor, torch.Tensor]:
  986. if training:
  987. not_positive = self <= 0
  988. r = aten.uniform(self, lower, upper, generator=generator)
  989. output = torch.where(not_positive, self * r, self)
  990. noise_out = torch.where(not_positive, r, 1)
  991. return output, noise_out
  992. else:
  993. negative_slope = (lower + upper) / 2
  994. return aten.leaky_relu(self, negative_slope), torch.Tensor()
  995. @register_decomposition(aten.repeat_interleave.Tensor)
  996. def repeat_interleave_Tensor(
  997. repeat: torch.Tensor,
  998. output_size: Optional[int] = None,
  999. ) -> torch.Tensor:
  1000. if config.triton.autotune_at_compile_time:
  1001. # We can't compile-time auto-tune this because
  1002. # it expects specific data in `repeat`
  1003. return NotImplemented
  1004. if output_size is None or type(output_size) is not int:
  1005. return NotImplemented
  1006. if repeat.device.type == "mps":
  1007. return NotImplemented
  1008. assert repeat.dtype in [torch.int32, torch.int64]
  1009. assert repeat.ndim == 1
  1010. cumsum = repeat.cumsum(0)
  1011. pos = torch.arange(output_size, device=repeat.device)
  1012. return torch.searchsorted(
  1013. cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True
  1014. )