register_dispatch_key.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016
  1. from __future__ import annotations
  2. import itertools
  3. import textwrap
  4. from dataclasses import dataclass
  5. from typing import Literal, TYPE_CHECKING
  6. from typing_extensions import assert_never
  7. import torchgen.api.cpp as cpp
  8. import torchgen.api.meta as meta
  9. import torchgen.api.structured as structured
  10. from torchgen.api.translate import translate
  11. from torchgen.api.types import (
  12. BaseCType,
  13. Binding,
  14. ConstRefCType,
  15. CppSignature,
  16. CppSignatureGroup,
  17. DispatcherSignature,
  18. Expr,
  19. kernel_signature,
  20. MutRefCType,
  21. NamedCType,
  22. NativeSignature,
  23. tensorT,
  24. )
  25. from torchgen.context import method_with_native_function, native_function_manager
  26. from torchgen.model import (
  27. Argument,
  28. BackendIndex,
  29. DeviceCheckType,
  30. DispatchKey,
  31. gets_generated_out_inplace_wrapper,
  32. is_cuda_dispatch_key,
  33. NativeFunction,
  34. NativeFunctionsGroup,
  35. SchemaKind,
  36. TensorOptionsArguments,
  37. )
  38. from torchgen.utils import mapMaybe, Target
  39. if TYPE_CHECKING:
  40. from torchgen.selective_build.selector import SelectiveBuilder
  41. def gen_registration_headers(
  42. backend_index: BackendIndex,
  43. per_operator_headers: bool,
  44. rocm: bool,
  45. ) -> list[str]:
  46. if per_operator_headers:
  47. headers = ["#include <ATen/ops/as_strided_native.h>"]
  48. else:
  49. headers = ["#include <ATen/NativeFunctions.h>"]
  50. if backend_index.dispatch_key in (DispatchKey.CPU, DispatchKey.Meta):
  51. headers.append("#include <ATen/EmptyTensor.h>")
  52. elif backend_index.dispatch_key == DispatchKey.CUDA:
  53. if rocm:
  54. headers.append("#include <ATen/hip/EmptyTensor.h>")
  55. else:
  56. headers.append("#include <ATen/cuda/EmptyTensor.h>")
  57. elif backend_index.dispatch_key == DispatchKey.MPS:
  58. headers.append("#include <ATen/mps/EmptyTensor.h>")
  59. elif backend_index.dispatch_key == DispatchKey.XPU:
  60. # XPU specific, this header resides in third_party/torch-xpu-ops
  61. headers.append("#include <ATen/xpu/EmptyTensor.h>")
  62. elif backend_index.dispatch_key == DispatchKey.MTIA:
  63. headers.append("#include <ATen/native/mtia/EmptyTensor.h>")
  64. elif per_operator_headers:
  65. headers += [
  66. "#include <ATen/ops/empty.h>",
  67. "#include <ATen/ops/empty_strided.h>",
  68. "#include <ATen/ops/_copy_from_and_resize.h>",
  69. "#include <ATen/ops/_copy_from.h>",
  70. ]
  71. else:
  72. headers.append("#include <ATen/Functions.h>")
  73. headers.append("#include <c10/macros/Macros.h>")
  74. return headers
  75. def gen_empty_impl_names(
  76. backend_index: BackendIndex,
  77. ) -> tuple[str | None, str | None]:
  78. empty_impl = None
  79. empty_strided_impl = None
  80. if backend_index.dispatch_key in (
  81. DispatchKey.Meta,
  82. DispatchKey.CPU,
  83. DispatchKey.CUDA,
  84. DispatchKey.MPS,
  85. DispatchKey.XPU,
  86. DispatchKey.MTIA,
  87. ):
  88. dispatch = str(backend_index.dispatch_key).lower()
  89. empty_impl = f"at::detail::empty_{dispatch}"
  90. empty_strided_impl = f"at::detail::empty_strided_{dispatch}"
  91. elif backend_index.dispatch_key in (
  92. DispatchKey.CompositeExplicitAutogradNonFunctional,
  93. DispatchKey.QuantizedCPU,
  94. DispatchKey.QuantizedCUDA,
  95. DispatchKey.XPU,
  96. ):
  97. empty_impl = "at::empty"
  98. empty_strided_impl = "at::empty_strided"
  99. return empty_impl, empty_strided_impl
  100. def gen_create_out_helper(backend_index: BackendIndex) -> list[str]:
  101. if backend_index.dispatch_key == DispatchKey.Meta:
  102. empty_options = "options.device(at::kMeta)"
  103. else:
  104. empty_options = "options"
  105. empty_impl, empty_strided_impl = gen_empty_impl_names(backend_index)
  106. if empty_impl is None:
  107. return []
  108. return [
  109. f"""
  110. Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
  111. if (strides.empty()) {{
  112. return {empty_impl}(sizes, {empty_options});
  113. }} else {{
  114. return {empty_strided_impl}(sizes, strides, {empty_options});
  115. }}
  116. }}
  117. """
  118. ]
  119. def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> list[str]:
  120. _, empty_strided_impl = gen_empty_impl_names(backend_index)
  121. return (
  122. []
  123. if empty_strided_impl is None
  124. else [
  125. f"""
  126. std::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
  127. if (out.strides() != strides) {{
  128. return {empty_strided_impl}(sizes, strides, options);
  129. }}
  130. return std::nullopt;
  131. }}
  132. """
  133. ]
  134. )
  135. def gen_resize_out_helper(backend_index: BackendIndex) -> list[str]:
  136. if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
  137. # The function isn't used by this key (since only functional ops have a kernel for this key),
  138. # so we need to not include it to avoid a defined-but-not-used error.
  139. return []
  140. return [
  141. """
  142. void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
  143. TORCH_CHECK(options.dtype() == out.dtype(),
  144. "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
  145. TORCH_CHECK(options.device() == out.device(),
  146. "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
  147. const bool resized = at::native::resize_output(out, sizes);
  148. // Only restride if a resize occurred; otherwise we ignore the (advisory)
  149. // strides from the meta function and directly use the output tensor's
  150. // preexisting strides
  151. if (resized) {
  152. if (!strides.empty()) {
  153. TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
  154. // TODO: avoid the redispatch here
  155. out.as_strided_(sizes, strides);
  156. } else if (options.memory_format_opt().has_value()) {
  157. out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
  158. }
  159. }
  160. }
  161. """
  162. ]
  163. def gen_check_inplace_helper(backend_index: BackendIndex) -> list[str]:
  164. return [
  165. """
  166. void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
  167. // These checks are needed on those operators that:
  168. // 1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
  169. // 2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
  170. // For other operators (e.g. 'add'), 'TensorIterator' already checks
  171. // these things separately.
  172. TORCH_CHECK(options.dtype() == self.dtype(),
  173. "Bad in-place call: ",
  174. "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
  175. TORCH_CHECK(options.device() == self.device(),
  176. "Bad in-place call: ",
  177. "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
  178. TORCH_CHECK(sizes == self.sizes(),
  179. "Bad in-place call: ",
  180. "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
  181. }
  182. """
  183. ]
  184. def gen_registration_helpers(backend_index: BackendIndex) -> list[str]:
  185. return [
  186. 'C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")',
  187. *gen_create_out_helper(backend_index),
  188. *gen_resize_out_helper(backend_index),
  189. *gen_check_inplace_helper(backend_index),
  190. *gen_maybe_create_proxy_helper(backend_index),
  191. "C10_DIAGNOSTIC_POP()",
  192. ]
  193. # Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp).
  194. #
  195. # - The primary function of this file is to register all of the
  196. # implementations for the given dispatch key to the dispatcher,
  197. # so they are available for use in PyTorch. If dispatch is
  198. # None, we generate schema (def) registrations and catchall
  199. # registrations.
  200. # - The secondary function of this file is to generate a wrapper
  201. # around functions. In CPUType these wrappers do nothing
  202. # (and should be removed), but in other cases they handle
  203. # DeviceGuard. A small extra benefit of wrappers is they
  204. # are not overloaded, so they can be used in the registration
  205. # API without having to disambiguate which overload you want
  206. # (as would be the case if you directly registered native::
  207. # functions).
  208. # - The tertiary function of this file is to generate *static*
  209. # cpp API bindings which can be used to bypass dispatcher
  210. # directly to kernels, but with user-friendly cpp-style API
  211. @dataclass(frozen=True)
  212. class RegisterDispatchKey:
  213. backend_index: BackendIndex
  214. target: Literal[
  215. Target.ANONYMOUS_DEFINITION,
  216. Target.NAMESPACED_DEFINITION,
  217. Target.NAMESPACED_DECLARATION,
  218. Target.REGISTRATION,
  219. ]
  220. # Selector object to determine which operators to generate
  221. # registration code for.
  222. selector: SelectiveBuilder
  223. # Whether or not we are actually code-genning for ROCm
  224. rocm: bool
  225. # Whether or not to generate symint registrations or not. External users
  226. # of codegen who don't care about symints can set this to false to get
  227. # non-SymInt codegen
  228. symint: bool
  229. # The class that all unstructured native functions live under. This is used to improve
  230. # compiler error messages when a kernel writer adds a native function with the wrong signature.
  231. # This is only used in unstructured kernels, since structured kernels already live in a class.
  232. # Finally, this field is currently Optional because it is only used by external backends.
  233. # It would be nice if we can add the same logic to in-tree kernels too, but that requires updating
  234. # all of the existing kernel signatures scattered across aten/src/ATen/native.
  235. class_method_name: str | None
  236. # Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering
  237. # operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher.
  238. skip_dispatcher_op_registration: bool
  239. @staticmethod
  240. def gen_device_check(
  241. type: DeviceCheckType, args: list[Argument], method_name: str
  242. ) -> str:
  243. if type == DeviceCheckType.NoCheck:
  244. return " // No device check\n"
  245. device_check = "std::optional<Device> common_device = std::nullopt;\n"
  246. device_check += "(void)common_device; // Suppress unused variable warning\n"
  247. for arg in args:
  248. # Only tensor like arguments are eligible
  249. if arg.type.is_tensor_like():
  250. device_check += f"""
  251. c10::impl::check_and_update_common_device(common_device, {arg.name}, "{method_name}", "{arg.name}");"""
  252. return device_check
  253. @method_with_native_function
  254. def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
  255. if isinstance(f, NativeFunctionsGroup):
  256. g: NativeFunctionsGroup = f
  257. # Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
  258. # gen_structured() has special logic to handle auto-generated kernels.
  259. if g.structured:
  260. return self.gen_structured(g)
  261. else:
  262. return list(
  263. mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())
  264. )
  265. elif isinstance(f, NativeFunction):
  266. r = self.gen_unstructured(f)
  267. return [] if r is None else [r]
  268. else:
  269. assert_never(f)
  270. def wrapper_kernel_sig(
  271. self, f: NativeFunction
  272. ) -> NativeSignature | DispatcherSignature:
  273. # The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
  274. return DispatcherSignature.from_schema(
  275. f.func,
  276. prefix=f"wrapper_{self.backend_index.dispatch_key}_{f.func.name.overload_name}_",
  277. symint=self.symint,
  278. )
  279. def gen_out_inplace_wrapper(
  280. self, f: NativeFunction, g: NativeFunctionsGroup | None
  281. ) -> str | None:
  282. if g is None:
  283. return None
  284. k = f.func.kind()
  285. if k is SchemaKind.inplace:
  286. copy_op = "at::_copy_from"
  287. elif k is SchemaKind.out:
  288. copy_op = "at::_copy_from_and_resize"
  289. else:
  290. raise AssertionError("gen_out_inplace_wrapper called on a functional op")
  291. sig = self.wrapper_kernel_sig(f)
  292. name = sig.name()
  293. func_res = f"{name}_tmp"
  294. return_names = cpp.return_names(f)
  295. if len(return_names) > 1:
  296. updates = "\n ".join(
  297. f"{copy_op}(std::get<{i}>({func_res}), {ret_name});"
  298. for i, ret_name in enumerate(return_names)
  299. )
  300. returns = f"{sig.returns_type().cpp_type()}({', '.join(return_names)})"
  301. elif len(return_names) == 1:
  302. ret_name = return_names[0]
  303. updates = f"{copy_op}({func_res}, {ret_name});"
  304. returns = ret_name
  305. else:
  306. assert len(f.func.arguments.out) == 1
  307. returns = ""
  308. out_arg = f.func.arguments.out[0]
  309. if out_arg.type.is_list_like():
  310. updates = f"""\
  311. for (int64_t i = 0; i < {func_res}.size(); ++i) {{
  312. {copy_op}({func_res}[i], {out_arg.name}[i]);
  313. }}"""
  314. else:
  315. updates = f"{copy_op}({func_res}, {out_arg.name});"
  316. functional_sig = self.wrapper_kernel_sig(g.functional)
  317. wrapper_name = sig.name()
  318. return f"""\
  319. {sig.defn(name=wrapper_name)} {{
  320. auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))});
  321. {updates}
  322. return {returns};
  323. }}
  324. """
  325. def gen_structured(self, g: NativeFunctionsGroup) -> list[str]:
  326. metadata = self.backend_index.get_kernel(g)
  327. if self.backend_index.dispatch_key == DispatchKey.Meta:
  328. assert not self.backend_index.has_kernel(g.out), (
  329. "Do not explicitly specify Meta dispatch key on structured "
  330. "functions, they will be automatically generated for you"
  331. )
  332. elif (
  333. self.backend_index.dispatch_key
  334. == DispatchKey.CompositeExplicitAutogradNonFunctional
  335. ):
  336. assert not self.backend_index.has_kernel(g.out), (
  337. "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured "
  338. "functions, they will be automatically generated for you"
  339. )
  340. elif metadata is None or not metadata.structured:
  341. return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()))
  342. structured_gen = StructuredRegisterDispatchKey(
  343. self.backend_index,
  344. self.target,
  345. self.selector,
  346. self.rocm,
  347. self.symint,
  348. self.class_method_name,
  349. self.skip_dispatcher_op_registration,
  350. g,
  351. )
  352. return list(mapMaybe(structured_gen.gen_one, g.functions()))
  353. def gen_unstructured(
  354. self, f: NativeFunction, g: NativeFunctionsGroup | None = None
  355. ) -> str | None:
  356. with native_function_manager(f):
  357. inplace_meta = False
  358. gets_out_inplace_wrapper = False
  359. if not self.backend_index.has_kernel(f):
  360. if (
  361. self.backend_index.dispatch_key == DispatchKey.Meta
  362. and f.func.kind() is SchemaKind.inplace
  363. and
  364. # Defer to composites for meta implementation
  365. not f.has_composite_kernel
  366. and
  367. # Inplace list operations are not supported
  368. len(f.func.returns) == 1
  369. ):
  370. inplace_meta = True
  371. elif (
  372. not self.backend_index.use_out_as_primary
  373. and g is not None
  374. and gets_generated_out_inplace_wrapper(f, g, self.backend_index)
  375. ):
  376. # We want to generate inplace/out wrappers, that don't have a kernel for the backend.
  377. gets_out_inplace_wrapper = True
  378. else:
  379. return None
  380. if f.manual_kernel_registration:
  381. return None
  382. if (
  383. self.target is Target.REGISTRATION
  384. and not self.selector.is_native_function_selected(f)
  385. ):
  386. return None
  387. sig = self.wrapper_kernel_sig(f)
  388. name = sig.name()
  389. returns_type = sig.returns_type().cpp_type()
  390. args = sig.arguments()
  391. args_str = ", ".join(a.defn() for a in args)
  392. # See Note [Direct dispatch bindings]
  393. cpp_sig_group = CppSignatureGroup.from_native_function(
  394. f, method=False, fallback_binding=False
  395. )
  396. # TODO: dedupe this with the structured codegen
  397. if self.target is Target.NAMESPACED_DECLARATION:
  398. result = ""
  399. for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
  400. result += f"TORCH_API {cpp_sig.decl()};\n"
  401. return result
  402. elif self.target is Target.NAMESPACED_DEFINITION:
  403. def generate_defn(cpp_sig: CppSignature) -> str:
  404. return f"""
  405. {cpp_sig.defn()} {{
  406. return {sig.name()}({", ".join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
  407. }}
  408. """
  409. result = ""
  410. for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
  411. result += generate_defn(cpp_sig)
  412. return result
  413. elif self.target is Target.ANONYMOUS_DEFINITION:
  414. # short circuit for inplace_meta
  415. if inplace_meta:
  416. assert f.func.arguments.self_arg is not None
  417. self_arg_name = f.func.arguments.self_arg.argument.name
  418. # TODO: handle in place on tensor list
  419. return f"""
  420. {returns_type} {name}({args_str}) {{
  421. TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
  422. "Cannot inplace into non-meta tensor with meta tensor argument");
  423. return {self_arg_name};
  424. }}
  425. """
  426. # short circuit for generated inplace/out wrappers
  427. if gets_out_inplace_wrapper:
  428. return self.gen_out_inplace_wrapper(f, g)
  429. metadata = self.backend_index.get_kernel(f)
  430. if metadata is None:
  431. return None
  432. if self.class_method_name is None:
  433. impl_name = f"{metadata.cpp_namespace}::{metadata.kernel}"
  434. else:
  435. impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"
  436. kernel_sig = kernel_signature(f, self.backend_index)
  437. args_exprs_str = ", ".join(
  438. e.expr
  439. for e in translate(
  440. sig.arguments(), kernel_sig.arguments(), method=False
  441. )
  442. )
  443. device_check = " // No device check\n"
  444. # Backends that require device guards presumably also require device checks.
  445. if self.backend_index.device_guard:
  446. device_check_args = itertools.chain(
  447. f.func.arguments.out, f.func.arguments.flat_positional
  448. )
  449. device_check = RegisterDispatchKey.gen_device_check(
  450. f.device_check, list(device_check_args), name
  451. )
  452. device_guard = "// DeviceGuard omitted" # default
  453. if f.device_guard and self.backend_index.device_guard:
  454. has_tensor_options = any(
  455. isinstance(a, TensorOptionsArguments)
  456. for a in f.func.arguments.non_out
  457. )
  458. if has_tensor_options:
  459. # kernel is creating a tensor
  460. device_guard = """
  461. const DeviceGuard device_guard(device_or_default(device));"""
  462. # CUDA requires special handling
  463. if is_cuda_dispatch_key(self.backend_index.dispatch_key):
  464. device_guard = f"globalContext().lazyInitDevice(c10::DeviceType::CUDA);\n{device_guard}"
  465. else:
  466. # kernel is operating on existing tensors
  467. # There is precedence for which argument we use to do
  468. # device guard. This describes the precedence order.
  469. self_arg = (
  470. [f.func.arguments.self_arg.argument]
  471. if f.func.arguments.self_arg is not None
  472. else []
  473. )
  474. candidate_args = itertools.chain(
  475. self_arg,
  476. f.func.arguments.out,
  477. f.func.arguments.flat_positional,
  478. )
  479. # Only tensor like arguments are eligible
  480. device_of = next(
  481. (
  482. f"{a.name}"
  483. for a in candidate_args
  484. if a.type.is_tensor_like()
  485. ),
  486. None,
  487. )
  488. if device_of is not None:
  489. device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"
  490. return f"""\
  491. namespace {{
  492. {returns_type} {name}({args_str}) {{
  493. {device_check}
  494. {device_guard}
  495. return {impl_name}({args_exprs_str});
  496. }}
  497. }} // anonymous namespace
  498. """
  499. elif self.target is Target.REGISTRATION:
  500. if f.manual_kernel_registration or self.skip_dispatcher_op_registration:
  501. return None
  502. else:
  503. payload = f"TORCH_FN({name})"
  504. return f'm.impl("{f.func.name}",\n{payload});\n'
  505. else:
  506. assert_never(self.target)
  507. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  508. #
  509. # STRUCTURED
  510. #
  511. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  512. @dataclass(frozen=True)
  513. class StructuredRegisterDispatchKey(RegisterDispatchKey):
  514. g: NativeFunctionsGroup
  515. def gen_class_set_output_functions(
  516. self, k: SchemaKind, parent_class: str, generate_super: bool
  517. ) -> str:
  518. if generate_super:
  519. set_output_super = f"{parent_class}::set_output_raw_strided(output_idx, sizes, strides, options, names);"
  520. else:
  521. set_output_super = ""
  522. def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str:
  523. return f"""
  524. void set_output_{name}(
  525. int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
  526. TensorOptions options, DimnameList names
  527. ) override {{
  528. {textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), " ")}
  529. if (!names.empty()) {{
  530. namedinference::propagate_names(outputs_[output_idx], names);
  531. }}
  532. // super must happen after, so that downstream can use maybe_get_output
  533. // to retrieve the output
  534. {textwrap.indent(set_output_super, " ")}
  535. }}
  536. """
  537. return f"""
  538. {gen_set_output_function("strided", maybe_create_proxy=True)}
  539. {gen_set_output_function("raw_strided", maybe_create_proxy=False)}
  540. """
  541. def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str:
  542. if self.backend_index.dispatch_key in [
  543. DispatchKey.CUDA,
  544. DispatchKey.MPS,
  545. DispatchKey.XPU,
  546. DispatchKey.CompositeExplicitAutogradNonFunctional,
  547. ]:
  548. maybe_set_guard = """
  549. auto current_device = guard_.current_device();
  550. if (C10_UNLIKELY(current_device.has_value())) {
  551. TORCH_INTERNAL_ASSERT(*current_device == options.device(),
  552. "structured kernels don't support multi-device outputs");
  553. } else {
  554. guard_.reset_device(options.device());
  555. }
  556. """
  557. maybe_set_guard_line = maybe_set_guard + "\n"
  558. else:
  559. maybe_set_guard_line = maybe_set_guard = ""
  560. if maybe_create_proxy:
  561. create_proxy = """
  562. auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
  563. if (C10_UNLIKELY(maybe_proxy.has_value())) {
  564. proxy_outputs_[output_idx] = std::move(maybe_proxy).value();
  565. }
  566. """
  567. else:
  568. create_proxy = ""
  569. if k is SchemaKind.functional:
  570. assert self.backend_index.dispatch_key in (
  571. DispatchKey.Meta,
  572. DispatchKey.CPU,
  573. DispatchKey.CUDA,
  574. DispatchKey.MPS,
  575. DispatchKey.XPU,
  576. DispatchKey.MTIA,
  577. DispatchKey.CompositeExplicitAutogradNonFunctional,
  578. )
  579. return f"""{maybe_set_guard_line}
  580. outputs_[output_idx] = create_out(sizes, strides, options);"""
  581. elif k is SchemaKind.inplace:
  582. return f"""{maybe_set_guard_line}
  583. const auto& out = outputs_[output_idx].get();
  584. check_inplace(out, sizes, options);
  585. {create_proxy}"""
  586. elif k is SchemaKind.out:
  587. return f"""{maybe_set_guard_line}
  588. const auto& out = outputs_[output_idx].get();
  589. resize_out(out, sizes, strides, options);
  590. {create_proxy}"""
  591. elif k is SchemaKind.mutable or k is SchemaKind.scratch:
  592. raise AssertionError(
  593. f"{k} structured operators are currently not supported"
  594. )
  595. else:
  596. assert_never(k)
  597. # returns the definition of a ctor, as well as how to construct
  598. # this class to a variable named op
  599. def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str:
  600. if k is SchemaKind.functional:
  601. return ""
  602. elif k is SchemaKind.inplace:
  603. # TODO: Make sure out argument is guaranteed to be self
  604. return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
  605. elif k is SchemaKind.out:
  606. out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
  607. out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
  608. return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
  609. elif k is SchemaKind.mutable or k is SchemaKind.scratch:
  610. raise AssertionError(
  611. f"{k} structured operators are currently not supported"
  612. )
  613. else:
  614. assert_never(k)
  615. def gen_class(
  616. self,
  617. f: NativeFunction,
  618. k: SchemaKind,
  619. *,
  620. class_name: str,
  621. parent_class: str,
  622. generate_super: bool,
  623. ) -> str:
  624. if k is SchemaKind.functional:
  625. output_type = "Tensor"
  626. output_value = "outputs_[output_idx]"
  627. proxy_field = ""
  628. elif k is SchemaKind.inplace:
  629. output_type = "std::reference_wrapper<Tensor>"
  630. output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
  631. proxy_field = f"std::array<::std::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
  632. elif k is SchemaKind.out:
  633. output_type = "std::reference_wrapper<Tensor>"
  634. output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
  635. proxy_field = f"std::array<::std::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
  636. else:
  637. raise RuntimeError(f"Unsupported SchemaKind {k}")
  638. if self.backend_index.dispatch_key == DispatchKey.CUDA:
  639. if self.rocm:
  640. guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;"
  641. else:
  642. guard_field = "c10::cuda::OptionalCUDAGuard guard_;"
  643. elif (
  644. self.backend_index.dispatch_key
  645. == DispatchKey.CompositeExplicitAutogradNonFunctional
  646. ):
  647. guard_field = "c10::OptionalDeviceGuard guard_;"
  648. elif self.backend_index.dispatch_key == DispatchKey.MPS:
  649. # TODO: Move to OptionalMPSGuard.
  650. guard_field = "c10::OptionalDeviceGuard guard_;"
  651. elif self.backend_index.dispatch_key == DispatchKey.XPU:
  652. guard_field = "c10::OptionalDeviceGuard guard_;"
  653. elif self.backend_index.dispatch_key == DispatchKey.MTIA:
  654. guard_field = "c10::OptionalDeviceGuard guard_;"
  655. else:
  656. guard_field = ""
  657. indent = " " * 4
  658. class_ctor_str = self.gen_class_ctor(k, class_name, len(f.func.returns))
  659. lines = (
  660. f"struct {class_name} final : public {parent_class} {{",
  661. f"{textwrap.indent(class_ctor_str, indent)}",
  662. f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}",
  663. " const Tensor& maybe_get_output(int64_t output_idx) override {",
  664. f" return {output_value};\n", # type: ignore[possibly-undefined] # TODO: audit
  665. " }",
  666. # type: ignore[possibly-undefined] # TODO: audit
  667. f" std::array<{output_type}, {len(f.func.returns)}> outputs_;",
  668. f"{textwrap.indent(proxy_field, indent)}", # type: ignore[possibly-undefined] # TODO: audit
  669. f"{textwrap.indent(guard_field, indent)}",
  670. "};",
  671. )
  672. return "\n".join(line for line in lines if line)
  673. @method_with_native_function
  674. def gen_one(self, f: NativeFunction) -> str | None:
  675. assert not f.manual_kernel_registration
  676. if (
  677. self.target is Target.REGISTRATION
  678. and not self.selector.is_native_function_selected(f)
  679. ):
  680. return None
  681. # TODO: Now, there is something interesting going on here. In the code below,
  682. # we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace
  683. # based on the out implementation. But in fact, out is definable by
  684. # functional too (just not very efficiently), and this is honestly the
  685. # MORE likely situation for a backend implementer. How do we pick?
  686. # Well, taking a page from Haskell type classes and default methods,
  687. # we could conceivably register a circular definition (out in terms
  688. # of functional, and functional in terms of out) and just require
  689. # someone to implement one or the other. We'd have to do a little bit
  690. # of work to not register one of these "weak" definitions unless there
  691. # is a strong definition somewhere in the DAG! So it's not implemented yet.
  692. if (
  693. self.backend_index.dispatch_key
  694. == DispatchKey.CompositeExplicitAutogradNonFunctional
  695. and f.func.kind() is SchemaKind.out
  696. ):
  697. # Never generate a default implementation for out, that's what you
  698. # have to define as a backend implementer
  699. return None
  700. # Note [Direct dispatch bindings]
  701. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  702. # Signature of the non-dispatched function we'll expose in a header
  703. # (e.g., at::cpu::add). We don't generate methods (TODO: do this
  704. # when CPUTensor class is a thing); nor do we generate fallback
  705. # bindings for manual_cpp_binding functions.
  706. cpp_sig_group = CppSignatureGroup.from_native_function(
  707. f, method=False, fallback_binding=False
  708. )
  709. # Signature of the wrapper function we'll register to the dispatcher
  710. kern = self.backend_index.get_kernel(f)
  711. sig = NativeSignature(
  712. f.func,
  713. prefix=f"wrapper_{self.backend_index.dispatch_key}_",
  714. symint=kern is not None and kern.supports_symint(),
  715. )
  716. if self.target is Target.NAMESPACED_DECLARATION:
  717. result = ""
  718. for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
  719. result += f"TORCH_API {cpp_sig.decl()};\n"
  720. return result
  721. elif self.target is Target.NAMESPACED_DEFINITION:
  722. def generate_defn(cpp_sig: CppSignature) -> str:
  723. return f"""
  724. {cpp_sig.defn()} {{
  725. return {sig.name()}({", ".join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
  726. }}
  727. """
  728. result = ""
  729. for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
  730. result += generate_defn(cpp_sig)
  731. return result
  732. elif self.target is Target.ANONYMOUS_DEFINITION:
  733. k = f.func.kind()
  734. # Construct the body of the wrapper function with signature sig
  735. sig_body = []
  736. # We'll use context to keep track of any variables we've brought
  737. # into scope while generating code
  738. context: list[Binding | Expr] = list(sig.arguments())
  739. # Initialize the class corresponding to this structured
  740. # operator; feeding it the output argument(s) if it is known
  741. if self.backend_index.dispatch_key is DispatchKey.Meta:
  742. class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
  743. parent_class = f"at::meta::structured_{meta.name(self.g)}"
  744. elif (
  745. self.backend_index.dispatch_key
  746. is DispatchKey.CompositeExplicitAutogradNonFunctional
  747. ):
  748. # TODO: dedup this branch
  749. class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
  750. parent_class = f"at::meta::structured_{meta.name(self.g)}"
  751. else:
  752. metadata = self.backend_index.get_kernel(self.g)
  753. assert metadata is not None
  754. class_name = f"structured_{metadata.kernel}_{k.name}"
  755. parent_class = f"{metadata.cpp_namespace}::structured_{metadata.kernel}"
  756. if self.backend_index.device_guard:
  757. device_check_args = itertools.chain(
  758. f.func.arguments.out, f.func.arguments.flat_positional
  759. )
  760. sig_body.append(
  761. RegisterDispatchKey.gen_device_check(
  762. f.device_check, list(device_check_args), sig.name()
  763. )
  764. )
  765. if k is SchemaKind.functional:
  766. sig_body.append(f"{class_name} op;")
  767. elif k is SchemaKind.inplace:
  768. sig_body.append(f"{class_name} op(self);")
  769. elif k is SchemaKind.out:
  770. out_args_str = ", ".join(a.name for a in f.func.arguments.out)
  771. sig_body.append(f"{class_name} op({out_args_str});")
  772. # Translate the input native arguments into structured
  773. # arguments for the meta call
  774. meta_exprs = ", ".join(
  775. e.expr
  776. for e in translate(
  777. context, structured.meta_arguments(self.g), method=False
  778. )
  779. )
  780. if self.g.out.precomputed:
  781. # If this function group has precomputed elements, the meta function
  782. # returns a struct containing them which must be saved so that it
  783. # can be unpacked when generating code to call the impl.
  784. sig_body.append(f"auto precompute = op.meta({meta_exprs});")
  785. # Put all of the contents of the precompute struct into the context
  786. # so that translate will be able to return the correct args for the
  787. # call to the impl.
  788. precomputed_values = [
  789. *self.g.out.precomputed.replace.values(),
  790. self.g.out.precomputed.add,
  791. ]
  792. for precomputed_elems in precomputed_values:
  793. context.extend(
  794. Expr(
  795. expr=f"precompute.{arg.name}",
  796. type=structured.argument_type(arg, binds=arg.name),
  797. )
  798. for arg in precomputed_elems
  799. )
  800. # Add a use of the precompute struct so FB internal compilers don't
  801. # complain that there is an unused variable.
  802. sig_body.append("(void)precompute;")
  803. else:
  804. sig_body.append(f"op.meta({meta_exprs});")
  805. # After running meta, op.outputs_ is guaranteed to be valid;
  806. # add it to the context
  807. out_args = structured.out_arguments(self.g)
  808. for i, out_arg in enumerate(out_args):
  809. assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
  810. if k is SchemaKind.out:
  811. expr = f"op.maybe_get_output({i})"
  812. else:
  813. expr = f"op.outputs_[{i}]"
  814. context.append(
  815. Expr(
  816. expr=expr,
  817. # TODO: Stop hardcoding that the output type is a Tensor. Note
  818. # that for the codegen here this is fine because outputs_ is
  819. # hardcoded to be tensor already
  820. type=NamedCType(
  821. out_arg.nctype.name, MutRefCType(BaseCType(tensorT))
  822. ),
  823. )
  824. )
  825. # With the expanded context, do the impl call (if not a meta
  826. # function)
  827. if (
  828. self.backend_index.dispatch_key
  829. == DispatchKey.CompositeExplicitAutogradNonFunctional
  830. ):
  831. # TODO: https://github.com/pytorch/pytorch/issues/53023
  832. out_sig_group = CppSignatureGroup.from_native_function(
  833. self.g.out, method=False, fallback_binding=f.manual_cpp_binding
  834. )
  835. out_sig = out_sig_group.most_faithful_signature()
  836. api_name = out_sig.name()
  837. out_exprs = ", ".join(
  838. e.expr
  839. for e in translate(context, out_sig.arguments(), method=False)
  840. )
  841. # TODO: I think this means structured won't work with method
  842. # only functions (but maybe you're saved by faithful? iunno.)
  843. # NB: Originally I wrote this as an at::redispatch call, but
  844. # I got in trouble because that meant I needed a DispatchKeySet
  845. # in the wrapper function, which meant I needed a DispatchKeySet
  846. # in the DispatchKeyFunctions declarations, but the defined API
  847. # there does NOT permit a dispatch key set. I think you can
  848. # probably unwind this by calling some function to do the TLS
  849. # fetch and get the DispatchKeySet when you don't have it, but
  850. # I didn't do it for this version
  851. sig_body.append(f"at::{api_name}({out_exprs});")
  852. elif self.backend_index.dispatch_key != DispatchKey.Meta:
  853. impl_exprs = ", ".join(
  854. e.expr
  855. for e in translate(
  856. context, structured.impl_arguments(self.g), method=False
  857. )
  858. )
  859. sig_body.append(f"op.impl({impl_exprs});")
  860. # Go over each output, and check if there is a proxy created for it.
  861. # If so, copy it over to the original output.
  862. if k is SchemaKind.out or k is SchemaKind.inplace:
  863. for i in range(len(f.func.returns)):
  864. sig_body.append(
  865. f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);"
  866. )
  867. # Destructively return the final tensors
  868. # TODO: Do this in translate instead
  869. if k is SchemaKind.functional:
  870. if len(f.func.returns) == 1:
  871. ret_expr = "std::move(op.outputs_[0])" # small optimization
  872. else:
  873. moved = ", ".join(
  874. f"std::move(op.outputs_[{i}])"
  875. for i in range(len(f.func.returns))
  876. )
  877. ret_expr = f"std::make_tuple({moved})"
  878. elif k is SchemaKind.inplace:
  879. ret_expr = "self"
  880. elif k is SchemaKind.out:
  881. if len(f.func.returns) == 1:
  882. ret_expr = f.func.arguments.out[0].name
  883. else:
  884. refs = ", ".join(a.name for a in f.func.arguments.out)
  885. ret_expr = f"std::forward_as_tuple({refs})"
  886. sig_body.append(f"return {ret_expr};") # type: ignore[possibly-undefined] # TODO: audit
  887. sig_body_str = "\n".join(sig_body)
  888. # For an overview of what this template code looks like, see
  889. # https://github.com/pytorch/rfcs/pull/9
  890. return f"""\
  891. {
  892. self.gen_class(
  893. f,
  894. k,
  895. class_name=class_name,
  896. parent_class=parent_class,
  897. generate_super=self.g.out.structured_inherits is not None,
  898. )
  899. }
  900. {sig.defn()} {{
  901. {sig_body_str}
  902. }}
  903. """
  904. elif self.target is Target.REGISTRATION:
  905. return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
  906. else:
  907. assert_never(self.target)
  908. # Silence mypy's "Missing return statement" error
  909. return None