native_function_generation.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651
  1. from __future__ import annotations
  2. import string
  3. from collections import defaultdict
  4. from typing import TYPE_CHECKING
  5. import torchgen.api.dispatcher as dispatcher
  6. from torchgen.api.translate import translate
  7. from torchgen.api.types import Binding, DispatcherSignature, Expr
  8. from torchgen.context import with_native_function
  9. from torchgen.model import (
  10. Annotation,
  11. Argument,
  12. BackendIndex,
  13. BackendMetadata,
  14. BaseOperatorName,
  15. BaseTy,
  16. BaseType,
  17. DEFAULT_KERNEL_NAMESPACE,
  18. DeviceCheckType,
  19. DispatchKey,
  20. FunctionSchema,
  21. NativeFunction,
  22. NativeFunctionsGroup,
  23. OperatorName,
  24. Return,
  25. SchemaKind,
  26. Variant,
  27. )
  28. from torchgen.utils import concatMap
  29. if TYPE_CHECKING:
  30. from collections.abc import Sequence
  31. # See Note: [Out ops with functional variants that don't get grouped properly]
  32. OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
  33. # This has a functional variant, but it's currently marked private.
  34. # This function should be marked private as well (*_backward ops aren't exposed to python anyway).
  35. "adaptive_avg_pool3d_backward.grad_input",
  36. # There's a functional variant, _slow_conv2d_backward.output_mask, that isn't grouped properly.
  37. # Maybe we can kill this operator in favor of convolution_backward?
  38. "_slow_conv2d_backward.grad_input",
  39. ]
  40. # See Note: [Mutable ops that cannot get an out variant]
  41. MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
  42. # should be out=?
  43. "_cummax_helper",
  44. # should be out=?
  45. "_cummin_helper",
  46. ]
  47. # All of these operators don't have any tensor like returns
  48. FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
  49. "_assert_async", # no return
  50. "_assert_async.msg", # no return
  51. "_assert_tensor_metadata", # no return
  52. "_cslt_sparse_mm_search", # returns an int
  53. "_assert_scalar", # no return
  54. "_dimI", # returns an int
  55. "_dimV", # returns an int
  56. "_has_same_storage_numel", # returns a boolean
  57. "_linalg_check_errors", # no return
  58. "_local_scalar_dense", # returns a Scalar
  59. "_nested_tensor_from_mask_left_aligned", # returns a boolean
  60. "_nnz", # returns an int
  61. "_use_cudnn_ctc_loss", # returns a boolean
  62. "_use_cudnn_ctc_loss.Tensor", # returns a boolean
  63. "_validate_compressed_sparse_indices", # no return
  64. "allclose", # returns a boolean
  65. "dense_dim", # returns an int
  66. "equal", # returns a boolean
  67. "is_coalesced", # returns an boolean
  68. "is_pinned", # returns a boolean
  69. "is_same_size", # returns a boolean
  70. "is_set_to", # returns a boolean
  71. "q_per_channel_axis", # returns an int
  72. "q_scale", # returns a float
  73. "q_zero_point", # returns an int
  74. "qscheme", # returns a QScheme
  75. "record_stream", # no return
  76. "sparse_dim", # returns an int
  77. "sym_constrain_range", # no return
  78. "sym_constrain_range_for_size", # no return
  79. "_nested_tensor_storage_offsets", # returns a vector of ints
  80. "_chunk_grad_outputs_efficient_attention", # returns a bool
  81. "_fused_sdp_choice", # returns an int
  82. "_print", # no return
  83. "_sink_tokens", # no return
  84. "_nested_get_ragged_idx", # returns an int
  85. ]
  86. INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
  87. # polygamma and polygamma.out both exist, but have a
  88. # pre-self arg (while polygamma_ does not)
  89. # We should either fix this schema so it can be grouped properly,
  90. # or allow the codegen to generate new functional/out= NativeFunctions for this op
  91. # (which would require changing its overload name to prevent overload ambiguity).
  92. "polygamma_"
  93. ]
  94. # Groups "similar" NativeFunctions together
  95. # example add.Tensor, add_.Tensor, add.out
  96. # "similar" NativeFunctions are all expected to have an identical `signature()`,
  97. # But have differing SchemaKinds.
  98. def pre_group_native_functions(
  99. native_functions: Sequence[NativeFunction],
  100. ) -> dict[FunctionSchema, dict[SchemaKind, NativeFunction]]:
  101. pre_grouped_native_functions: dict[
  102. FunctionSchema, dict[SchemaKind, NativeFunction]
  103. ] = defaultdict(dict)
  104. for f in native_functions:
  105. d = pre_grouped_native_functions[f.func.signature()]
  106. assert f.func.kind() not in d
  107. d[f.func.kind()] = f
  108. return pre_grouped_native_functions
  109. # Returns the out variant overload name given a base function overload name
  110. def get_expected_out_variant_overload_name(overload_name: str | None) -> str:
  111. return "out" if not overload_name else f"{overload_name}_out"
  112. # Helper function: given an inplace FunctionSchema, generate its corresponding out= variant
  113. # Example before:
  114. # _add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
  115. # Example after:
  116. # _add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out)
  117. def self_to_out_signature(func: FunctionSchema) -> FunctionSchema:
  118. # Generating an out= schema from an inplace schema.
  119. assert func.kind() == SchemaKind.inplace
  120. assert func.arguments.self_arg is not None
  121. # The new out= schema has:
  122. # - a new out argument with the same type as "func" (but with a mutable annotation)
  123. # - The returns (if any) now alias the out= argument instead of "func"
  124. # - an "out" overload name
  125. return FunctionSchema(
  126. name=func.name.remove_inplace().with_overload(
  127. get_expected_out_variant_overload_name(func.name.overload_name)
  128. ),
  129. arguments=func.arguments.remove_self_annotation().with_out_args(
  130. [
  131. Argument(
  132. name="out",
  133. type=func.arguments.self_arg.argument.type,
  134. default=None,
  135. annotation=func.arguments.self_arg.argument.annotation,
  136. )
  137. ]
  138. ),
  139. returns=func.returns,
  140. )
  141. # Helper function: given a functional FunctionSchema, generate its corresponding out= variant
  142. # Example before:
  143. # _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
  144. # bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
  145. # Example after:
  146. # _to_copy._out(Tensor self, *, bool non_blocking=False, MemoryFormat? memory_format=None,
  147. # Tensor(a!) out) -> Tensor(a!)
  148. def functional_to_out_signature(func: FunctionSchema) -> FunctionSchema:
  149. # Generating an out= schema from a functional schema.
  150. assert func.kind() == SchemaKind.functional
  151. new_returns, new_out_args = generate_out_args_from_schema(func)
  152. # The new out= schema has:
  153. # - one or more new out argument(s) with the same type as returns (but with a mutable annotation)
  154. # - The returns now alias the out= arguments
  155. # - an "_out" overload name
  156. return FunctionSchema(
  157. name=func.name.with_overload(
  158. get_expected_out_variant_overload_name(func.name.overload_name)
  159. ),
  160. arguments=func.arguments.signature().with_out_args(
  161. new_out_args,
  162. ),
  163. returns=tuple(new_returns),
  164. )
  165. # Helper function: given a function schema, generate corresponding out arguments, also the updated return annotations.
  166. def generate_out_args_from_schema(
  167. func: FunctionSchema,
  168. ) -> tuple[list[Return], list[Argument]]:
  169. # More of a sanity check - our existing restrictions on schemas should enforce that
  170. # mutable schema kinds never return their mutable arguments.
  171. assert not any(
  172. r.annotation is not None and r.annotation.is_write for r in func.returns
  173. )
  174. tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()]
  175. assert len(tensorlike_rets) > 0
  176. used_annotations = concatMap(
  177. lambda a: [] if a.annotation is None else a.annotation.alias_set,
  178. func.arguments.flat_all,
  179. )
  180. valid_annotations = [x for x in string.ascii_lowercase if x not in used_annotations]
  181. all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns)
  182. new_out_args: list[Argument] = []
  183. # The end result of new_returns is that:
  184. # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
  185. # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
  186. new_returns: list[Return] = []
  187. for i, r in enumerate(func.returns):
  188. if r.type.is_tensor_like():
  189. new_out = Argument(
  190. name="out" if len(func.returns) == 1 else f"out{i}",
  191. type=r.type,
  192. default=None,
  193. annotation=Annotation.parse(f"{valid_annotations[i]}!"),
  194. )
  195. new_out_args.append(new_out)
  196. if all_rets_are_tensors:
  197. # The convention for out= schemas is that they only return their out arguments
  198. # if the return is a plain Tensor (or if it's a tuple of plain Tensors)
  199. new_ret = Return(
  200. name=None, type=new_out.type, annotation=new_out.annotation
  201. )
  202. new_returns.append(new_ret)
  203. else:
  204. new_returns.append(r)
  205. return new_returns, new_out_args
  206. # Helper function: given a mutable FunctionSchema, generate its corresponding out= variant
  207. # Example before:
  208. # _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950
  209. # Example after:
  210. # _fused_moving_avg_obs_fq_helper._out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!)) # noqa: B950
  211. def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
  212. # Generating an out= schema from a mutable schema.
  213. assert func.kind() == SchemaKind.mutable
  214. # The new out= schema has:
  215. # - Any non-aliased tensor-like returns are converted to mutable, aliased out= arguments
  216. # (if the argument is a tensor then we also return it for method chaining,
  217. # otherwise we return nothing)
  218. # - an "out" overload name
  219. #
  220. # Note that:
  221. # (1) This also means that we can *only* generate an out= variant from a mutable schema
  222. # if the mutable schema has at least one tensor-like non-aliasing return.
  223. # (2) The generated out= variant still has mutable positional arguments,
  224. # but if necessary we could probably add another out= variant that also
  225. # functionalizes the mutable arguments (a functional_out variant)
  226. new_returns, new_out_args = generate_out_args_from_schema(func)
  227. return FunctionSchema(
  228. name=func.name.remove_inplace().with_overload(
  229. get_expected_out_variant_overload_name(func.name.overload_name)
  230. ),
  231. arguments=func.arguments.with_out_args(new_out_args),
  232. returns=tuple(new_returns),
  233. )
  234. # This function, given function of one SchemaKind, as well as a target SchemaKind,
  235. # generates a new NativeFunction with the same properties, but using the target SchemaKind.
  236. # We only actually generate functions for either functional or out= SchemaKinds.
  237. # This function returns a tuple, with:
  238. # - The generated NativeFunction
  239. # - a dictionary of `BackendIndex` objects, describing which dispatch keys
  240. # we will generate kernels for, for the new NativeFunction.
  241. # Details are in the function, but we only generate composite kernels (in some cases) today.
  242. def generate_function(
  243. f: NativeFunction, k: SchemaKind
  244. ) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
  245. from torchgen.api import cpp
  246. if k == SchemaKind.functional:
  247. assert f.func.kind() != SchemaKind.functional
  248. # The new "functional" NativeFunction has:
  249. # - any mutable arguments have been converted into (immutable) returns.
  250. # (if a mutable argument was not also a return, it gets converted to one)
  251. # - "_functional" appended to the base name, ONLY IF this op has a mutable variant.
  252. # See Note [Overload Ambiguity With Functional Variants]
  253. # The default grouping logic in signature() actually already does this,
  254. # so we can piggy-back off it (but we still want return names)
  255. func = f.func.signature(keep_return_names=True).with_name(
  256. OperatorName(
  257. name=BaseOperatorName(
  258. base=f.func.name.name.base,
  259. inplace=False,
  260. dunder_method=f.func.name.name.dunder_method,
  261. # See Note [Overload Ambiguity With Functional Variants]
  262. functional_overload=f.func.kind() == SchemaKind.mutable,
  263. ),
  264. overload_name=f.func.name.overload_name,
  265. )
  266. )
  267. elif k == SchemaKind.out:
  268. # We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily,
  269. # but at least today, there is no good reason to actually use them.
  270. # we'll generate a dispatcher entry for them, but won't actually register any kernels for them.
  271. if f.func.kind() == SchemaKind.inplace:
  272. func = self_to_out_signature(f.func)
  273. elif f.func.kind() == SchemaKind.mutable:
  274. func = mutable_to_out_signature(f.func)
  275. elif f.func.kind() == SchemaKind.functional:
  276. func = functional_to_out_signature(f.func)
  277. else:
  278. raise AssertionError(
  279. "We only bother generating out= functions from either inplace or mutable or functional variants"
  280. )
  281. else:
  282. raise AssertionError(
  283. "We currently only generate either functional or out= NativeFunctions"
  284. )
  285. # Generated kernel naming convention for out: <op_name>_<overload_name>. The reason for this is to
  286. # disambiguate operator with the same name but different overload name, e.g., `randn.names_out` and
  287. # `randn.generator_with_names_out`.
  288. kernel_name = (
  289. func.name.unambiguous_name()
  290. if func.kind() == SchemaKind.out
  291. else cpp.name(func)
  292. )
  293. if f.func.has_symint():
  294. kernel_name += "_symint"
  295. backend_metadata = {
  296. DispatchKey.CompositeExplicitAutograd: {
  297. func.name: BackendMetadata(
  298. kernel=kernel_name,
  299. structured=False,
  300. cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
  301. )
  302. }
  303. }
  304. tags = {"generated"} | set(
  305. f.tags & {"nondeterministic_seeded", "view_copy", "pt2_compliant_tag"}
  306. )
  307. return (
  308. NativeFunction(
  309. func=func,
  310. use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
  311. # These generated fn's aren't meant to be user friendly- don't generate methods.
  312. variants={Variant.function},
  313. structured=False,
  314. structured_delegate=None,
  315. structured_inherits=None,
  316. precomputed=None,
  317. autogen=[],
  318. ufunc_inner_loop={},
  319. manual_kernel_registration=False,
  320. manual_cpp_binding=False,
  321. python_module=None,
  322. category_override=None,
  323. device_guard=False,
  324. device_check=DeviceCheckType.NoCheck,
  325. loc=f.loc,
  326. cpp_no_default_args=set(),
  327. is_abstract=f.is_abstract,
  328. has_composite_implicit_autograd_kernel=False,
  329. has_composite_implicit_autograd_nested_tensor_kernel=False,
  330. has_composite_explicit_autograd_kernel=True,
  331. has_composite_explicit_autograd_non_functional_kernel=False,
  332. # Every generated NativeFunction gets a "generated" tag, so it's easy to tell
  333. # which NativeFunction objects did not come directly from native_functions.yaml.
  334. tags=tags,
  335. namespace=f.namespace,
  336. ),
  337. backend_metadata,
  338. )
  339. # This function is responsible for adding generated NativeFunctions which don't appear
  340. # explicitly in the codegen.
  341. # You can inspect the full list of NativeFunctions yourself with the torchgen package, by running
  342. # torchgen.parse_native_yaml("aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml")
  343. # (Maybe we should make a friendly API for this)
  344. #
  345. # Note: this function *mutates* its two inputs,
  346. # adding the new NativeFunctions / BackendMetadata to them
  347. def add_generated_native_functions(
  348. rs: list[NativeFunction],
  349. indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
  350. ) -> None:
  351. # The main code for generating new NativeFunctions
  352. # First we group of NativeFunctions by schema kind,
  353. # then we detect which ones are missing and generate them.
  354. pre_grouped_native_functions = pre_group_native_functions(rs)
  355. for d in pre_grouped_native_functions.values():
  356. has_functional = SchemaKind.functional in d
  357. has_inplace = SchemaKind.inplace in d
  358. has_mutable = SchemaKind.mutable in d
  359. has_out = SchemaKind.out in d
  360. is_core = any("core" in variant.tags for variant in d.values())
  361. # We automatically generate a few native functions that don't exist in the yaml, for a few reasons:
  362. # (1) If an operator has an inplace/out= variant but no functional variant, we can generate
  363. # a simple functional variant that the functionalization pass can consume.
  364. # (2) If an operator has an inplace or functional but no out= variant, we generate an out=
  365. # variant, mostly so we can easily pair up functions into NativeFunctionsGroup,
  366. # while maintaining the constraint that the out= variant is "required".
  367. if has_mutable or has_inplace or has_out or has_functional:
  368. # Don't bother generating functions trio's for native functions that bypass the dispatcher.
  369. are_manual = all(f.manual_cpp_binding for f in d.values())
  370. # Don't bother generating functional + out= variants for view operators
  371. # set_ is technically an inplace_view, but for now it is treated
  372. # as a normal inplace op in the codegen
  373. has_view_ops = any(
  374. f.is_view_op and str(f.func.name.name) != "set_" for f in d.values()
  375. )
  376. # Don't generate the other variants for non-core CompositeImplicitAutograd operators.
  377. # We could probably do this, but the main benefit of generating the function triplets
  378. # is for transforms that need them, and transforms don't need to act directly
  379. # on CompositeImplicitAutograd operators (since we let them decompose).
  380. are_composite_implicit = all(
  381. f.has_composite_implicit_autograd_kernel for f in d.values()
  382. )
  383. if are_manual or has_view_ops or are_composite_implicit and not is_core:
  384. continue
  385. if has_out and len(d.values()) == 1:
  386. # Note: [Out ops with functional variants that don't get grouped properly]
  387. # In theory we could validly have an out= operator in native_functions.yaml
  388. # that has no other variants.
  389. # But today, all of the operators where that's the case actually do have
  390. # functional variants, that we are just unable to pair up properly.
  391. # I think banning this all together is probably safer
  392. # (you can always add a functional variant yourself if you want to add a new out= operator).
  393. #
  394. # We should probably fix the existing cases; this check is to prevent us from adding more over time.
  395. if (
  396. str(d[SchemaKind.out].func.name)
  397. not in OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
  398. ):
  399. raise AssertionError(
  400. f"Found an out= operator that we could not find any other variants of: {str(d[SchemaKind.out].func)}"
  401. )
  402. continue
  403. # Some inplace ops that have problematic schemas (that we should fix), which prevent us
  404. # from generating out= and functional variants
  405. if (
  406. has_inplace
  407. and str(d[SchemaKind.inplace].func.name)
  408. in INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
  409. ):
  410. continue
  411. base_fn = (
  412. d[SchemaKind.mutable]
  413. if has_mutable
  414. else d[SchemaKind.inplace]
  415. if has_inplace
  416. else d[SchemaKind.out]
  417. if has_out
  418. else d[SchemaKind.functional]
  419. )
  420. # Note: [Mutable ops that cannot get an out variant]
  421. # We can only generate an out= variant if either:
  422. # - the original function has tensor-like returns (since we can convert them to out kwargs)
  423. # - or it's inplace (since we can convert `self` to an out kwarg)
  424. # There are only two functions that don't fit this criteria today though,
  425. # and they both look like they should be fixed to be out= variants,
  426. # so if feels safer to ban this schema all-together
  427. base_fn_valid = base_fn.func.kind() == SchemaKind.inplace or any(
  428. r.type.is_tensor_like() for r in base_fn.func.returns
  429. )
  430. # Note: [Loosen the assertion that all functional should have out variant]
  431. # By design all functional operators should have our variants. The needs_out check
  432. # is loosening this requirement, changing it to only generate out variant if there's
  433. # an `autogen` block in the native function, in the long run it should be removed.
  434. # FIXME: Remove this after figuring out CI job failures related to min, max, mean
  435. needs_out = any("out" in str(op_name) for op_name in base_fn.autogen)
  436. gets_out_variant = not has_out and base_fn_valid and needs_out
  437. if not has_out and not base_fn_valid:
  438. if (
  439. str(base_fn.func.name)
  440. not in MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
  441. and str(base_fn.func.name)
  442. not in FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
  443. ):
  444. raise AssertionError(
  445. f"""Found an operator that we could not generate an out= variant for: {str(base_fn.func)}.
  446. This type of operators don't have tensor-like return, making it difficult to generate a proper out= variant. If
  447. out= variant is not needed, please add the function name into FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT list."""
  448. )
  449. # Generate an out= variant
  450. if gets_out_variant:
  451. fn, metadata = generate_function(base_fn, SchemaKind.out)
  452. d[SchemaKind.out] = fn
  453. BackendIndex.grow_index(indices, metadata)
  454. rs.append(fn)
  455. # Generate a functional variant, but only do it if the operator got an out= variant
  456. # (Functional variants are only useful if we can group up the variants,
  457. # which we can only do if they have an out= variant)
  458. if not has_functional and (has_out or gets_out_variant):
  459. fn, metadata = generate_function(base_fn, SchemaKind.functional)
  460. d[SchemaKind.functional] = fn
  461. BackendIndex.grow_index(indices, metadata)
  462. rs.append(fn)
  463. def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
  464. assert len(rets) == len(names)
  465. if len(rets) == 0:
  466. return ""
  467. elif len(rets) == 1:
  468. return f"return {names[0]};"
  469. else:
  470. return f"return {dispatcher.returns_type(rets).cpp_type()}({', '.join(names)});"
  471. # Given a function, and the name of a variable corresponding to the output of that function,
  472. # gather up all of the individual returns that are not aliased
  473. def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> list[str]:
  474. aliased_rets = func.aliased_return_names()
  475. non_aliased_names = []
  476. is_out_var_a_tuple = len(func.returns) > 1
  477. for i, r in enumerate(aliased_rets):
  478. if r is None:
  479. non_aliased_names.append(
  480. f"std::get<{i}>({out_var})" if is_out_var_a_tuple else out_var
  481. )
  482. return non_aliased_names
  483. # Generates functional kernels in terms of their inplace.mutable counterparts.
  484. # We only do this for "generated" NativeFunctions
  485. @with_native_function
  486. def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None:
  487. # We should only be generating these for code-generated NativeFunctions
  488. if "generated" not in g.functional.tags:
  489. return None
  490. # And we always write the kernel for a generated op in terms of a non-generated op.
  491. if g.inplace is not None and "generated" not in g.inplace.tags:
  492. target_f = g.inplace
  493. elif g.mutable is not None and "generated" not in g.mutable.tags:
  494. target_f = g.mutable
  495. else:
  496. # We should be guaranteed to have a valid inplace/mutable variant to call into.
  497. # See Note: [Mutable Ops Not Using Functionalization]
  498. raise AssertionError(str(g.functional.func))
  499. sig = DispatcherSignature(g.functional.func)
  500. target_sig = DispatcherSignature(target_f.func)
  501. context: list[Binding | Expr] = []
  502. clone_mutable_inputs = []
  503. cloned_return_names = []
  504. # We can't just directly pass all of the arguments from the functional op into the mutating op.
  505. # We need to check for which inputs to the mutating operator are mutable,
  506. # and clone those inputs first.
  507. for a_curr, a_tgt in zip(
  508. dispatcher.jit_arguments(g.functional.func),
  509. dispatcher.jit_arguments(target_f.func),
  510. ):
  511. if a_tgt.annotation is not None and a_tgt.annotation.is_write:
  512. clone_mutable_inputs.append(
  513. f"auto {a_curr.name}_clone = clone_arg({a_curr.name});"
  514. )
  515. context.append(
  516. Expr(
  517. expr=f"{a_curr.name}_clone",
  518. type=dispatcher.argument_type(a_curr, binds=a_curr.name),
  519. )
  520. )
  521. # Invariant: mutable arguments on the inner mutable op are always returns on the functional op.
  522. cloned_return_names.append(f"{a_curr.name}_clone")
  523. else:
  524. context.append(dispatcher.argument(a_curr))
  525. exprs = ", ".join([e.expr for e in translate(context, target_sig.arguments())])
  526. out_name = "output"
  527. maybe_assign = f"auto {out_name} = " if len(target_f.func.returns) > 0 else ""
  528. inner_return_names = gather_nonaliased_inner_rets(target_f.func, out_name)
  529. ret_str = return_str(
  530. g.functional.func.returns, inner_return_names + cloned_return_names
  531. )
  532. clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
  533. return f"""
  534. {sig.defn(name=sig.name() + ("_symint" if g.out.func.has_symint() else ""))} {{
  535. {clone_mutable_inputs_str}
  536. {maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs});
  537. {ret_str}
  538. }}
  539. """
  540. # Generates out= kernels in terms of their functional counterparts.
  541. # We only do this for "generated" NativeFunctions
  542. @with_native_function
  543. def gen_composite_out_kernel(g: NativeFunctionsGroup) -> str | None:
  544. # We should only be generating these for code-generated NativeFunctions
  545. if "generated" not in g.out.tags:
  546. return None
  547. # And we always write the kernel for the out= op in terms of the functional.
  548. # Note that the functional op might have also been generated, but we don't have to
  549. # worry about cycles, because the generated functional kernels are always implemented
  550. # in terms of non-generated kernels (see gen_composite_functional_kernel).
  551. sig = DispatcherSignature(g.out.func)
  552. target_sig = DispatcherSignature(g.functional.func)
  553. exprs = ", ".join(
  554. [e.expr for e in translate(sig.arguments(), target_sig.arguments())]
  555. )
  556. copy_outs = []
  557. out_name = "tmp_output"
  558. for i, out_arg in enumerate(g.out.func.arguments.out):
  559. functional_return_name = (
  560. out_name
  561. if len(g.functional.func.returns) == 1
  562. else f"std::get<{i}>({out_name})"
  563. )
  564. copy_outs.append(
  565. f"""\
  566. resize_out_helper({out_arg.name}, {functional_return_name});
  567. copy_arg({out_arg.name}, {functional_return_name});"""
  568. )
  569. rets = []
  570. # For each return arg in the calling (out=) operator,
  571. # If it corresponds to an aliased input, return the input.
  572. # Otherwise, return the corresponding output from calling the functional operator.
  573. for i, ret_name in enumerate(g.out.func.aliased_return_names()):
  574. if ret_name is not None:
  575. rets.append(ret_name)
  576. else:
  577. functional_return_name = (
  578. out_name
  579. if len(g.functional.func.returns) == 1
  580. else f"std::get<{i}>({out_name})"
  581. )
  582. rets.append(functional_return_name)
  583. copy_outs_str = "\n".join(copy_outs)
  584. # Kernel name needs to follow the naming convention defined in `generate_function()`
  585. return f"""
  586. {sig.defn(name=g.out.func.name.unambiguous_name() + ("_symint" if g.out.func.has_symint() else ""))} {{
  587. auto {out_name} = at::_ops::{g.functional.func.name.unambiguous_name()}::call({exprs});
  588. {copy_outs_str}
  589. {return_str(g.out.func.returns, rets)}
  590. }}
  591. """