| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016 |
- from __future__ import annotations
- import itertools
- import textwrap
- from dataclasses import dataclass
- from typing import Literal, TYPE_CHECKING
- from typing_extensions import assert_never
- import torchgen.api.cpp as cpp
- import torchgen.api.meta as meta
- import torchgen.api.structured as structured
- from torchgen.api.translate import translate
- from torchgen.api.types import (
- BaseCType,
- Binding,
- ConstRefCType,
- CppSignature,
- CppSignatureGroup,
- DispatcherSignature,
- Expr,
- kernel_signature,
- MutRefCType,
- NamedCType,
- NativeSignature,
- tensorT,
- )
- from torchgen.context import method_with_native_function, native_function_manager
- from torchgen.model import (
- Argument,
- BackendIndex,
- DeviceCheckType,
- DispatchKey,
- gets_generated_out_inplace_wrapper,
- is_cuda_dispatch_key,
- NativeFunction,
- NativeFunctionsGroup,
- SchemaKind,
- TensorOptionsArguments,
- )
- from torchgen.utils import mapMaybe, Target
- if TYPE_CHECKING:
- from torchgen.selective_build.selector import SelectiveBuilder
- def gen_registration_headers(
- backend_index: BackendIndex,
- per_operator_headers: bool,
- rocm: bool,
- ) -> list[str]:
- if per_operator_headers:
- headers = ["#include <ATen/ops/as_strided_native.h>"]
- else:
- headers = ["#include <ATen/NativeFunctions.h>"]
- if backend_index.dispatch_key in (DispatchKey.CPU, DispatchKey.Meta):
- headers.append("#include <ATen/EmptyTensor.h>")
- elif backend_index.dispatch_key == DispatchKey.CUDA:
- if rocm:
- headers.append("#include <ATen/hip/EmptyTensor.h>")
- else:
- headers.append("#include <ATen/cuda/EmptyTensor.h>")
- elif backend_index.dispatch_key == DispatchKey.MPS:
- headers.append("#include <ATen/mps/EmptyTensor.h>")
- elif backend_index.dispatch_key == DispatchKey.XPU:
- # XPU specific, this header resides in third_party/torch-xpu-ops
- headers.append("#include <ATen/xpu/EmptyTensor.h>")
- elif backend_index.dispatch_key == DispatchKey.MTIA:
- headers.append("#include <ATen/native/mtia/EmptyTensor.h>")
- elif per_operator_headers:
- headers += [
- "#include <ATen/ops/empty.h>",
- "#include <ATen/ops/empty_strided.h>",
- "#include <ATen/ops/_copy_from_and_resize.h>",
- "#include <ATen/ops/_copy_from.h>",
- ]
- else:
- headers.append("#include <ATen/Functions.h>")
- headers.append("#include <c10/macros/Macros.h>")
- return headers
- def gen_empty_impl_names(
- backend_index: BackendIndex,
- ) -> tuple[str | None, str | None]:
- empty_impl = None
- empty_strided_impl = None
- if backend_index.dispatch_key in (
- DispatchKey.Meta,
- DispatchKey.CPU,
- DispatchKey.CUDA,
- DispatchKey.MPS,
- DispatchKey.XPU,
- DispatchKey.MTIA,
- ):
- dispatch = str(backend_index.dispatch_key).lower()
- empty_impl = f"at::detail::empty_{dispatch}"
- empty_strided_impl = f"at::detail::empty_strided_{dispatch}"
- elif backend_index.dispatch_key in (
- DispatchKey.CompositeExplicitAutogradNonFunctional,
- DispatchKey.QuantizedCPU,
- DispatchKey.QuantizedCUDA,
- DispatchKey.XPU,
- ):
- empty_impl = "at::empty"
- empty_strided_impl = "at::empty_strided"
- return empty_impl, empty_strided_impl
- def gen_create_out_helper(backend_index: BackendIndex) -> list[str]:
- if backend_index.dispatch_key == DispatchKey.Meta:
- empty_options = "options.device(at::kMeta)"
- else:
- empty_options = "options"
- empty_impl, empty_strided_impl = gen_empty_impl_names(backend_index)
- if empty_impl is None:
- return []
- return [
- f"""
- Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
- if (strides.empty()) {{
- return {empty_impl}(sizes, {empty_options});
- }} else {{
- return {empty_strided_impl}(sizes, strides, {empty_options});
- }}
- }}
- """
- ]
- def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> list[str]:
- _, empty_strided_impl = gen_empty_impl_names(backend_index)
- return (
- []
- if empty_strided_impl is None
- else [
- f"""
- std::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
- if (out.strides() != strides) {{
- return {empty_strided_impl}(sizes, strides, options);
- }}
- return std::nullopt;
- }}
- """
- ]
- )
- def gen_resize_out_helper(backend_index: BackendIndex) -> list[str]:
- if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
- # The function isn't used by this key (since only functional ops have a kernel for this key),
- # so we need to not include it to avoid a defined-but-not-used error.
- return []
- return [
- """
- void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
- TORCH_CHECK(options.dtype() == out.dtype(),
- "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
- TORCH_CHECK(options.device() == out.device(),
- "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
- const bool resized = at::native::resize_output(out, sizes);
- // Only restride if a resize occurred; otherwise we ignore the (advisory)
- // strides from the meta function and directly use the output tensor's
- // preexisting strides
- if (resized) {
- if (!strides.empty()) {
- TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
- // TODO: avoid the redispatch here
- out.as_strided_(sizes, strides);
- } else if (options.memory_format_opt().has_value()) {
- out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
- }
- }
- }
- """
- ]
- def gen_check_inplace_helper(backend_index: BackendIndex) -> list[str]:
- return [
- """
- void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
- // These checks are needed on those operators that:
- // 1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
- // 2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
- // For other operators (e.g. 'add'), 'TensorIterator' already checks
- // these things separately.
- TORCH_CHECK(options.dtype() == self.dtype(),
- "Bad in-place call: ",
- "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
- TORCH_CHECK(options.device() == self.device(),
- "Bad in-place call: ",
- "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
- TORCH_CHECK(sizes == self.sizes(),
- "Bad in-place call: ",
- "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
- }
- """
- ]
- def gen_registration_helpers(backend_index: BackendIndex) -> list[str]:
- return [
- 'C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")',
- *gen_create_out_helper(backend_index),
- *gen_resize_out_helper(backend_index),
- *gen_check_inplace_helper(backend_index),
- *gen_maybe_create_proxy_helper(backend_index),
- "C10_DIAGNOSTIC_POP()",
- ]
- # Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp).
- #
- # - The primary function of this file is to register all of the
- # implementations for the given dispatch key to the dispatcher,
- # so they are available for use in PyTorch. If dispatch is
- # None, we generate schema (def) registrations and catchall
- # registrations.
- # - The secondary function of this file is to generate a wrapper
- # around functions. In CPUType these wrappers do nothing
- # (and should be removed), but in other cases they handle
- # DeviceGuard. A small extra benefit of wrappers is they
- # are not overloaded, so they can be used in the registration
- # API without having to disambiguate which overload you want
- # (as would be the case if you directly registered native::
- # functions).
- # - The tertiary function of this file is to generate *static*
- # cpp API bindings which can be used to bypass dispatcher
- # directly to kernels, but with user-friendly cpp-style API
- @dataclass(frozen=True)
- class RegisterDispatchKey:
- backend_index: BackendIndex
- target: Literal[
- Target.ANONYMOUS_DEFINITION,
- Target.NAMESPACED_DEFINITION,
- Target.NAMESPACED_DECLARATION,
- Target.REGISTRATION,
- ]
- # Selector object to determine which operators to generate
- # registration code for.
- selector: SelectiveBuilder
- # Whether or not we are actually code-genning for ROCm
- rocm: bool
- # Whether or not to generate symint registrations or not. External users
- # of codegen who don't care about symints can set this to false to get
- # non-SymInt codegen
- symint: bool
- # The class that all unstructured native functions live under. This is used to improve
- # compiler error messages when a kernel writer adds a native function with the wrong signature.
- # This is only used in unstructured kernels, since structured kernels already live in a class.
- # Finally, this field is currently Optional because it is only used by external backends.
- # It would be nice if we can add the same logic to in-tree kernels too, but that requires updating
- # all of the existing kernel signatures scattered across aten/src/ATen/native.
- class_method_name: str | None
- # Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering
- # operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher.
- skip_dispatcher_op_registration: bool
- @staticmethod
- def gen_device_check(
- type: DeviceCheckType, args: list[Argument], method_name: str
- ) -> str:
- if type == DeviceCheckType.NoCheck:
- return " // No device check\n"
- device_check = "std::optional<Device> common_device = std::nullopt;\n"
- device_check += "(void)common_device; // Suppress unused variable warning\n"
- for arg in args:
- # Only tensor like arguments are eligible
- if arg.type.is_tensor_like():
- device_check += f"""
- c10::impl::check_and_update_common_device(common_device, {arg.name}, "{method_name}", "{arg.name}");"""
- return device_check
- @method_with_native_function
- def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
- if isinstance(f, NativeFunctionsGroup):
- g: NativeFunctionsGroup = f
- # Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
- # gen_structured() has special logic to handle auto-generated kernels.
- if g.structured:
- return self.gen_structured(g)
- else:
- return list(
- mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())
- )
- elif isinstance(f, NativeFunction):
- r = self.gen_unstructured(f)
- return [] if r is None else [r]
- else:
- assert_never(f)
- def wrapper_kernel_sig(
- self, f: NativeFunction
- ) -> NativeSignature | DispatcherSignature:
- # The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
- return DispatcherSignature.from_schema(
- f.func,
- prefix=f"wrapper_{self.backend_index.dispatch_key}_{f.func.name.overload_name}_",
- symint=self.symint,
- )
- def gen_out_inplace_wrapper(
- self, f: NativeFunction, g: NativeFunctionsGroup | None
- ) -> str | None:
- if g is None:
- return None
- k = f.func.kind()
- if k is SchemaKind.inplace:
- copy_op = "at::_copy_from"
- elif k is SchemaKind.out:
- copy_op = "at::_copy_from_and_resize"
- else:
- raise AssertionError("gen_out_inplace_wrapper called on a functional op")
- sig = self.wrapper_kernel_sig(f)
- name = sig.name()
- func_res = f"{name}_tmp"
- return_names = cpp.return_names(f)
- if len(return_names) > 1:
- updates = "\n ".join(
- f"{copy_op}(std::get<{i}>({func_res}), {ret_name});"
- for i, ret_name in enumerate(return_names)
- )
- returns = f"{sig.returns_type().cpp_type()}({', '.join(return_names)})"
- elif len(return_names) == 1:
- ret_name = return_names[0]
- updates = f"{copy_op}({func_res}, {ret_name});"
- returns = ret_name
- else:
- assert len(f.func.arguments.out) == 1
- returns = ""
- out_arg = f.func.arguments.out[0]
- if out_arg.type.is_list_like():
- updates = f"""\
- for (int64_t i = 0; i < {func_res}.size(); ++i) {{
- {copy_op}({func_res}[i], {out_arg.name}[i]);
- }}"""
- else:
- updates = f"{copy_op}({func_res}, {out_arg.name});"
- functional_sig = self.wrapper_kernel_sig(g.functional)
- wrapper_name = sig.name()
- return f"""\
- {sig.defn(name=wrapper_name)} {{
- auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))});
- {updates}
- return {returns};
- }}
- """
- def gen_structured(self, g: NativeFunctionsGroup) -> list[str]:
- metadata = self.backend_index.get_kernel(g)
- if self.backend_index.dispatch_key == DispatchKey.Meta:
- assert not self.backend_index.has_kernel(g.out), (
- "Do not explicitly specify Meta dispatch key on structured "
- "functions, they will be automatically generated for you"
- )
- elif (
- self.backend_index.dispatch_key
- == DispatchKey.CompositeExplicitAutogradNonFunctional
- ):
- assert not self.backend_index.has_kernel(g.out), (
- "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured "
- "functions, they will be automatically generated for you"
- )
- elif metadata is None or not metadata.structured:
- return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()))
- structured_gen = StructuredRegisterDispatchKey(
- self.backend_index,
- self.target,
- self.selector,
- self.rocm,
- self.symint,
- self.class_method_name,
- self.skip_dispatcher_op_registration,
- g,
- )
- return list(mapMaybe(structured_gen.gen_one, g.functions()))
- def gen_unstructured(
- self, f: NativeFunction, g: NativeFunctionsGroup | None = None
- ) -> str | None:
- with native_function_manager(f):
- inplace_meta = False
- gets_out_inplace_wrapper = False
- if not self.backend_index.has_kernel(f):
- if (
- self.backend_index.dispatch_key == DispatchKey.Meta
- and f.func.kind() is SchemaKind.inplace
- and
- # Defer to composites for meta implementation
- not f.has_composite_kernel
- and
- # Inplace list operations are not supported
- len(f.func.returns) == 1
- ):
- inplace_meta = True
- elif (
- not self.backend_index.use_out_as_primary
- and g is not None
- and gets_generated_out_inplace_wrapper(f, g, self.backend_index)
- ):
- # We want to generate inplace/out wrappers, that don't have a kernel for the backend.
- gets_out_inplace_wrapper = True
- else:
- return None
- if f.manual_kernel_registration:
- return None
- if (
- self.target is Target.REGISTRATION
- and not self.selector.is_native_function_selected(f)
- ):
- return None
- sig = self.wrapper_kernel_sig(f)
- name = sig.name()
- returns_type = sig.returns_type().cpp_type()
- args = sig.arguments()
- args_str = ", ".join(a.defn() for a in args)
- # See Note [Direct dispatch bindings]
- cpp_sig_group = CppSignatureGroup.from_native_function(
- f, method=False, fallback_binding=False
- )
- # TODO: dedupe this with the structured codegen
- if self.target is Target.NAMESPACED_DECLARATION:
- result = ""
- for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
- result += f"TORCH_API {cpp_sig.decl()};\n"
- return result
- elif self.target is Target.NAMESPACED_DEFINITION:
- def generate_defn(cpp_sig: CppSignature) -> str:
- return f"""
- {cpp_sig.defn()} {{
- return {sig.name()}({", ".join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
- }}
- """
- result = ""
- for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
- result += generate_defn(cpp_sig)
- return result
- elif self.target is Target.ANONYMOUS_DEFINITION:
- # short circuit for inplace_meta
- if inplace_meta:
- assert f.func.arguments.self_arg is not None
- self_arg_name = f.func.arguments.self_arg.argument.name
- # TODO: handle in place on tensor list
- return f"""
- {returns_type} {name}({args_str}) {{
- TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
- "Cannot inplace into non-meta tensor with meta tensor argument");
- return {self_arg_name};
- }}
- """
- # short circuit for generated inplace/out wrappers
- if gets_out_inplace_wrapper:
- return self.gen_out_inplace_wrapper(f, g)
- metadata = self.backend_index.get_kernel(f)
- if metadata is None:
- return None
- if self.class_method_name is None:
- impl_name = f"{metadata.cpp_namespace}::{metadata.kernel}"
- else:
- impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"
- kernel_sig = kernel_signature(f, self.backend_index)
- args_exprs_str = ", ".join(
- e.expr
- for e in translate(
- sig.arguments(), kernel_sig.arguments(), method=False
- )
- )
- device_check = " // No device check\n"
- # Backends that require device guards presumably also require device checks.
- if self.backend_index.device_guard:
- device_check_args = itertools.chain(
- f.func.arguments.out, f.func.arguments.flat_positional
- )
- device_check = RegisterDispatchKey.gen_device_check(
- f.device_check, list(device_check_args), name
- )
- device_guard = "// DeviceGuard omitted" # default
- if f.device_guard and self.backend_index.device_guard:
- has_tensor_options = any(
- isinstance(a, TensorOptionsArguments)
- for a in f.func.arguments.non_out
- )
- if has_tensor_options:
- # kernel is creating a tensor
- device_guard = """
- const DeviceGuard device_guard(device_or_default(device));"""
- # CUDA requires special handling
- if is_cuda_dispatch_key(self.backend_index.dispatch_key):
- device_guard = f"globalContext().lazyInitDevice(c10::DeviceType::CUDA);\n{device_guard}"
- else:
- # kernel is operating on existing tensors
- # There is precedence for which argument we use to do
- # device guard. This describes the precedence order.
- self_arg = (
- [f.func.arguments.self_arg.argument]
- if f.func.arguments.self_arg is not None
- else []
- )
- candidate_args = itertools.chain(
- self_arg,
- f.func.arguments.out,
- f.func.arguments.flat_positional,
- )
- # Only tensor like arguments are eligible
- device_of = next(
- (
- f"{a.name}"
- for a in candidate_args
- if a.type.is_tensor_like()
- ),
- None,
- )
- if device_of is not None:
- device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"
- return f"""\
- namespace {{
- {returns_type} {name}({args_str}) {{
- {device_check}
- {device_guard}
- return {impl_name}({args_exprs_str});
- }}
- }} // anonymous namespace
- """
- elif self.target is Target.REGISTRATION:
- if f.manual_kernel_registration or self.skip_dispatcher_op_registration:
- return None
- else:
- payload = f"TORCH_FN({name})"
- return f'm.impl("{f.func.name}",\n{payload});\n'
- else:
- assert_never(self.target)
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # STRUCTURED
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- @dataclass(frozen=True)
- class StructuredRegisterDispatchKey(RegisterDispatchKey):
- g: NativeFunctionsGroup
- def gen_class_set_output_functions(
- self, k: SchemaKind, parent_class: str, generate_super: bool
- ) -> str:
- if generate_super:
- set_output_super = f"{parent_class}::set_output_raw_strided(output_idx, sizes, strides, options, names);"
- else:
- set_output_super = ""
- def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str:
- return f"""
- void set_output_{name}(
- int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
- TensorOptions options, DimnameList names
- ) override {{
- {textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), " ")}
- if (!names.empty()) {{
- namedinference::propagate_names(outputs_[output_idx], names);
- }}
- // super must happen after, so that downstream can use maybe_get_output
- // to retrieve the output
- {textwrap.indent(set_output_super, " ")}
- }}
- """
- return f"""
- {gen_set_output_function("strided", maybe_create_proxy=True)}
- {gen_set_output_function("raw_strided", maybe_create_proxy=False)}
- """
- def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str:
- if self.backend_index.dispatch_key in [
- DispatchKey.CUDA,
- DispatchKey.MPS,
- DispatchKey.XPU,
- DispatchKey.CompositeExplicitAutogradNonFunctional,
- ]:
- maybe_set_guard = """
- auto current_device = guard_.current_device();
- if (C10_UNLIKELY(current_device.has_value())) {
- TORCH_INTERNAL_ASSERT(*current_device == options.device(),
- "structured kernels don't support multi-device outputs");
- } else {
- guard_.reset_device(options.device());
- }
- """
- maybe_set_guard_line = maybe_set_guard + "\n"
- else:
- maybe_set_guard_line = maybe_set_guard = ""
- if maybe_create_proxy:
- create_proxy = """
- auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
- if (C10_UNLIKELY(maybe_proxy.has_value())) {
- proxy_outputs_[output_idx] = std::move(maybe_proxy).value();
- }
- """
- else:
- create_proxy = ""
- if k is SchemaKind.functional:
- assert self.backend_index.dispatch_key in (
- DispatchKey.Meta,
- DispatchKey.CPU,
- DispatchKey.CUDA,
- DispatchKey.MPS,
- DispatchKey.XPU,
- DispatchKey.MTIA,
- DispatchKey.CompositeExplicitAutogradNonFunctional,
- )
- return f"""{maybe_set_guard_line}
- outputs_[output_idx] = create_out(sizes, strides, options);"""
- elif k is SchemaKind.inplace:
- return f"""{maybe_set_guard_line}
- const auto& out = outputs_[output_idx].get();
- check_inplace(out, sizes, options);
- {create_proxy}"""
- elif k is SchemaKind.out:
- return f"""{maybe_set_guard_line}
- const auto& out = outputs_[output_idx].get();
- resize_out(out, sizes, strides, options);
- {create_proxy}"""
- elif k is SchemaKind.mutable or k is SchemaKind.scratch:
- raise AssertionError(
- f"{k} structured operators are currently not supported"
- )
- else:
- assert_never(k)
- # returns the definition of a ctor, as well as how to construct
- # this class to a variable named op
- def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str:
- if k is SchemaKind.functional:
- return ""
- elif k is SchemaKind.inplace:
- # TODO: Make sure out argument is guaranteed to be self
- return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
- elif k is SchemaKind.out:
- out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
- out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
- return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
- elif k is SchemaKind.mutable or k is SchemaKind.scratch:
- raise AssertionError(
- f"{k} structured operators are currently not supported"
- )
- else:
- assert_never(k)
- def gen_class(
- self,
- f: NativeFunction,
- k: SchemaKind,
- *,
- class_name: str,
- parent_class: str,
- generate_super: bool,
- ) -> str:
- if k is SchemaKind.functional:
- output_type = "Tensor"
- output_value = "outputs_[output_idx]"
- proxy_field = ""
- elif k is SchemaKind.inplace:
- output_type = "std::reference_wrapper<Tensor>"
- output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
- proxy_field = f"std::array<::std::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
- elif k is SchemaKind.out:
- output_type = "std::reference_wrapper<Tensor>"
- output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
- proxy_field = f"std::array<::std::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
- else:
- raise RuntimeError(f"Unsupported SchemaKind {k}")
- if self.backend_index.dispatch_key == DispatchKey.CUDA:
- if self.rocm:
- guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;"
- else:
- guard_field = "c10::cuda::OptionalCUDAGuard guard_;"
- elif (
- self.backend_index.dispatch_key
- == DispatchKey.CompositeExplicitAutogradNonFunctional
- ):
- guard_field = "c10::OptionalDeviceGuard guard_;"
- elif self.backend_index.dispatch_key == DispatchKey.MPS:
- # TODO: Move to OptionalMPSGuard.
- guard_field = "c10::OptionalDeviceGuard guard_;"
- elif self.backend_index.dispatch_key == DispatchKey.XPU:
- guard_field = "c10::OptionalDeviceGuard guard_;"
- elif self.backend_index.dispatch_key == DispatchKey.MTIA:
- guard_field = "c10::OptionalDeviceGuard guard_;"
- else:
- guard_field = ""
- indent = " " * 4
- class_ctor_str = self.gen_class_ctor(k, class_name, len(f.func.returns))
- lines = (
- f"struct {class_name} final : public {parent_class} {{",
- f"{textwrap.indent(class_ctor_str, indent)}",
- f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}",
- " const Tensor& maybe_get_output(int64_t output_idx) override {",
- f" return {output_value};\n", # type: ignore[possibly-undefined] # TODO: audit
- " }",
- # type: ignore[possibly-undefined] # TODO: audit
- f" std::array<{output_type}, {len(f.func.returns)}> outputs_;",
- f"{textwrap.indent(proxy_field, indent)}", # type: ignore[possibly-undefined] # TODO: audit
- f"{textwrap.indent(guard_field, indent)}",
- "};",
- )
- return "\n".join(line for line in lines if line)
- @method_with_native_function
- def gen_one(self, f: NativeFunction) -> str | None:
- assert not f.manual_kernel_registration
- if (
- self.target is Target.REGISTRATION
- and not self.selector.is_native_function_selected(f)
- ):
- return None
- # TODO: Now, there is something interesting going on here. In the code below,
- # we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace
- # based on the out implementation. But in fact, out is definable by
- # functional too (just not very efficiently), and this is honestly the
- # MORE likely situation for a backend implementer. How do we pick?
- # Well, taking a page from Haskell type classes and default methods,
- # we could conceivably register a circular definition (out in terms
- # of functional, and functional in terms of out) and just require
- # someone to implement one or the other. We'd have to do a little bit
- # of work to not register one of these "weak" definitions unless there
- # is a strong definition somewhere in the DAG! So it's not implemented yet.
- if (
- self.backend_index.dispatch_key
- == DispatchKey.CompositeExplicitAutogradNonFunctional
- and f.func.kind() is SchemaKind.out
- ):
- # Never generate a default implementation for out, that's what you
- # have to define as a backend implementer
- return None
- # Note [Direct dispatch bindings]
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- # Signature of the non-dispatched function we'll expose in a header
- # (e.g., at::cpu::add). We don't generate methods (TODO: do this
- # when CPUTensor class is a thing); nor do we generate fallback
- # bindings for manual_cpp_binding functions.
- cpp_sig_group = CppSignatureGroup.from_native_function(
- f, method=False, fallback_binding=False
- )
- # Signature of the wrapper function we'll register to the dispatcher
- kern = self.backend_index.get_kernel(f)
- sig = NativeSignature(
- f.func,
- prefix=f"wrapper_{self.backend_index.dispatch_key}_",
- symint=kern is not None and kern.supports_symint(),
- )
- if self.target is Target.NAMESPACED_DECLARATION:
- result = ""
- for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
- result += f"TORCH_API {cpp_sig.decl()};\n"
- return result
- elif self.target is Target.NAMESPACED_DEFINITION:
- def generate_defn(cpp_sig: CppSignature) -> str:
- return f"""
- {cpp_sig.defn()} {{
- return {sig.name()}({", ".join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
- }}
- """
- result = ""
- for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
- result += generate_defn(cpp_sig)
- return result
- elif self.target is Target.ANONYMOUS_DEFINITION:
- k = f.func.kind()
- # Construct the body of the wrapper function with signature sig
- sig_body = []
- # We'll use context to keep track of any variables we've brought
- # into scope while generating code
- context: list[Binding | Expr] = list(sig.arguments())
- # Initialize the class corresponding to this structured
- # operator; feeding it the output argument(s) if it is known
- if self.backend_index.dispatch_key is DispatchKey.Meta:
- class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
- parent_class = f"at::meta::structured_{meta.name(self.g)}"
- elif (
- self.backend_index.dispatch_key
- is DispatchKey.CompositeExplicitAutogradNonFunctional
- ):
- # TODO: dedup this branch
- class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
- parent_class = f"at::meta::structured_{meta.name(self.g)}"
- else:
- metadata = self.backend_index.get_kernel(self.g)
- assert metadata is not None
- class_name = f"structured_{metadata.kernel}_{k.name}"
- parent_class = f"{metadata.cpp_namespace}::structured_{metadata.kernel}"
- if self.backend_index.device_guard:
- device_check_args = itertools.chain(
- f.func.arguments.out, f.func.arguments.flat_positional
- )
- sig_body.append(
- RegisterDispatchKey.gen_device_check(
- f.device_check, list(device_check_args), sig.name()
- )
- )
- if k is SchemaKind.functional:
- sig_body.append(f"{class_name} op;")
- elif k is SchemaKind.inplace:
- sig_body.append(f"{class_name} op(self);")
- elif k is SchemaKind.out:
- out_args_str = ", ".join(a.name for a in f.func.arguments.out)
- sig_body.append(f"{class_name} op({out_args_str});")
- # Translate the input native arguments into structured
- # arguments for the meta call
- meta_exprs = ", ".join(
- e.expr
- for e in translate(
- context, structured.meta_arguments(self.g), method=False
- )
- )
- if self.g.out.precomputed:
- # If this function group has precomputed elements, the meta function
- # returns a struct containing them which must be saved so that it
- # can be unpacked when generating code to call the impl.
- sig_body.append(f"auto precompute = op.meta({meta_exprs});")
- # Put all of the contents of the precompute struct into the context
- # so that translate will be able to return the correct args for the
- # call to the impl.
- precomputed_values = [
- *self.g.out.precomputed.replace.values(),
- self.g.out.precomputed.add,
- ]
- for precomputed_elems in precomputed_values:
- context.extend(
- Expr(
- expr=f"precompute.{arg.name}",
- type=structured.argument_type(arg, binds=arg.name),
- )
- for arg in precomputed_elems
- )
- # Add a use of the precompute struct so FB internal compilers don't
- # complain that there is an unused variable.
- sig_body.append("(void)precompute;")
- else:
- sig_body.append(f"op.meta({meta_exprs});")
- # After running meta, op.outputs_ is guaranteed to be valid;
- # add it to the context
- out_args = structured.out_arguments(self.g)
- for i, out_arg in enumerate(out_args):
- assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
- if k is SchemaKind.out:
- expr = f"op.maybe_get_output({i})"
- else:
- expr = f"op.outputs_[{i}]"
- context.append(
- Expr(
- expr=expr,
- # TODO: Stop hardcoding that the output type is a Tensor. Note
- # that for the codegen here this is fine because outputs_ is
- # hardcoded to be tensor already
- type=NamedCType(
- out_arg.nctype.name, MutRefCType(BaseCType(tensorT))
- ),
- )
- )
- # With the expanded context, do the impl call (if not a meta
- # function)
- if (
- self.backend_index.dispatch_key
- == DispatchKey.CompositeExplicitAutogradNonFunctional
- ):
- # TODO: https://github.com/pytorch/pytorch/issues/53023
- out_sig_group = CppSignatureGroup.from_native_function(
- self.g.out, method=False, fallback_binding=f.manual_cpp_binding
- )
- out_sig = out_sig_group.most_faithful_signature()
- api_name = out_sig.name()
- out_exprs = ", ".join(
- e.expr
- for e in translate(context, out_sig.arguments(), method=False)
- )
- # TODO: I think this means structured won't work with method
- # only functions (but maybe you're saved by faithful? iunno.)
- # NB: Originally I wrote this as an at::redispatch call, but
- # I got in trouble because that meant I needed a DispatchKeySet
- # in the wrapper function, which meant I needed a DispatchKeySet
- # in the DispatchKeyFunctions declarations, but the defined API
- # there does NOT permit a dispatch key set. I think you can
- # probably unwind this by calling some function to do the TLS
- # fetch and get the DispatchKeySet when you don't have it, but
- # I didn't do it for this version
- sig_body.append(f"at::{api_name}({out_exprs});")
- elif self.backend_index.dispatch_key != DispatchKey.Meta:
- impl_exprs = ", ".join(
- e.expr
- for e in translate(
- context, structured.impl_arguments(self.g), method=False
- )
- )
- sig_body.append(f"op.impl({impl_exprs});")
- # Go over each output, and check if there is a proxy created for it.
- # If so, copy it over to the original output.
- if k is SchemaKind.out or k is SchemaKind.inplace:
- for i in range(len(f.func.returns)):
- sig_body.append(
- f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);"
- )
- # Destructively return the final tensors
- # TODO: Do this in translate instead
- if k is SchemaKind.functional:
- if len(f.func.returns) == 1:
- ret_expr = "std::move(op.outputs_[0])" # small optimization
- else:
- moved = ", ".join(
- f"std::move(op.outputs_[{i}])"
- for i in range(len(f.func.returns))
- )
- ret_expr = f"std::make_tuple({moved})"
- elif k is SchemaKind.inplace:
- ret_expr = "self"
- elif k is SchemaKind.out:
- if len(f.func.returns) == 1:
- ret_expr = f.func.arguments.out[0].name
- else:
- refs = ", ".join(a.name for a in f.func.arguments.out)
- ret_expr = f"std::forward_as_tuple({refs})"
- sig_body.append(f"return {ret_expr};") # type: ignore[possibly-undefined] # TODO: audit
- sig_body_str = "\n".join(sig_body)
- # For an overview of what this template code looks like, see
- # https://github.com/pytorch/rfcs/pull/9
- return f"""\
- {
- self.gen_class(
- f,
- k,
- class_name=class_name,
- parent_class=parent_class,
- generate_super=self.g.out.structured_inherits is not None,
- )
- }
- {sig.defn()} {{
- {sig_body_str}
- }}
- """
- elif self.target is Target.REGISTRATION:
- return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
- else:
- assert_never(self.target)
- # Silence mypy's "Missing return statement" error
- return None
|