| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
777277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032 |
- from __future__ import annotations
- import argparse
- import functools
- import json
- import keyword
- import os
- from collections import defaultdict, namedtuple, OrderedDict
- from dataclasses import dataclass, field
- from pathlib import Path
- from typing import Any, Literal, TYPE_CHECKING, TypeVar
- from typing_extensions import assert_never
- import yaml
- import torchgen.api.dispatcher as dispatcher
- import torchgen.api.meta as meta
- import torchgen.api.native as native
- import torchgen.api.structured as structured
- import torchgen.dest as dest
- from torchgen.api import cpp
- from torchgen.api.translate import translate
- from torchgen.api.types import (
- Binding,
- CppSignature,
- CppSignatureGroup,
- DispatcherSignature,
- NamedCType,
- NativeSignature,
- SpecialArgName,
- )
- from torchgen.context import (
- method_with_native_function,
- native_function_manager,
- with_native_function,
- with_native_function_and_indices,
- )
- from torchgen.gen_aoti_c_shim import (
- gen_aoti_c_shim_files,
- gen_static_dispatch_backend_call_signature,
- )
- from torchgen.gen_functionalization_type import (
- gen_functionalization_definition,
- gen_functionalization_registration,
- gen_functionalization_view_inverse_declaration,
- gen_functionalization_view_meta_classes_decl,
- gen_functionalization_view_meta_classes_impl,
- GenCompositeViewCopyKernel,
- )
- from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
- from torchgen.model import (
- Argument,
- BackendIndex,
- BackendMetadata,
- BaseOperatorName,
- DEFAULT_KERNEL_NAMESPACE,
- dispatch_device_map,
- DispatchKey,
- FRAGMENT_NAMESPACES,
- FunctionSchema,
- is_cuda_dispatch_key,
- is_generic_dispatch_key,
- is_ufunc_dispatch_key,
- is_xpu_dispatch_key,
- Location,
- NativeFunction,
- NativeFunctionsGroup,
- NativeFunctionsViewGroup,
- OperatorName,
- OptionalType,
- SchemaKind,
- SelfArgument,
- STRUCTURED_DISPATCH_KEYS,
- TensorOptionsArguments,
- Type,
- Variant,
- ViewSchemaKind,
- )
- from torchgen.native_function_generation import (
- add_generated_native_functions,
- gen_composite_functional_kernel,
- gen_composite_out_kernel,
- pre_group_native_functions,
- )
- from torchgen.selective_build.selector import SelectiveBuilder
- from torchgen.utils import (
- concatMap,
- context,
- FileManager,
- make_file_manager,
- mapMaybe,
- NamespaceHelper,
- Target,
- )
- from torchgen.yaml_utils import YamlDumper, YamlLoader
- if TYPE_CHECKING:
- from collections.abc import Callable, Sequence
- T = TypeVar("T")
- # Welcome to the ATen code generator v2! The ATen code generator is
- # responsible for parsing native_functions.yaml and then generating
- # various generated files (e.g., TypeDefault.cpp) based on the operators
- # defined in this file. This means that the code generator knows how to
- # parse function schema, and then translate this into various C++ types
- # and boilerplate code.
- #
- # Some things to know about this file when you modify it:
- #
- # - This file has STRICT mypy typechecking. Typecheck it with
- # `mypy --config mypy-strict.ini` in the root source directory
- #
- # - Most of the heavy lifting lives in external modules:
- # - 'model' has the data model for native_functions.yaml. The classes
- # in that file represent what you see when you look at
- # native_functions.yaml
- # - 'api' has conversions for how to translate JIT schema into
- # the various C++ APIs that the codegen interacts with. There
- # are in fact THREE different C++ APIs: the public C++ API,
- # the dispatcher API, and the legacy dispatcher API. See each
- # of these respective files for more information
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # HELPER FUNCTIONS
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- # A custom loader for YAML to let us also keep track of line numbers
- # of each entry in the YAML file
class LineLoader(YamlLoader):
    """YAML loader that stamps every constructed mapping with its source line.

    The line is stored under the extra key ``"__line__"`` so later stages can
    report where in the YAML file an entry came from.
    """

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        result = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
        # Shift by one so that the recorded line numbering starts at 1.
        result["__line__"] = 1 + node.start_mark.line
        return result
# Result of parsing native_functions.yaml: the sequence of NativeFunctions
# together with the per-dispatch-key backend indices.
ParsedYaml = namedtuple("ParsedYaml", "native_functions backend_indices")

# Module-level memo tables, keyed by the path of the YAML file that was parsed.
_GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}
def file_manager_from_dispatch_key(
    dispatch_key: DispatchKey,
    device_fms: dict[str, FileManager],
    default_fm: FileManager,
) -> FileManager:
    """Pick the FileManager that matches the device of `dispatch_key`.

    The first predicate in `dispatch_device_map` that accepts `dispatch_key`
    selects the device name (the empty string is used when none matches).
    That name is then looked up in `device_fms`; `default_fm` is returned
    when there is no entry for it.
    """
    device_name = ""
    for accepts, device in dispatch_device_map.items():
        if accepts(dispatch_key):
            device_name = device
            break
    return device_fms.get(device_name, default_fm)
def parse_native_yaml_struct(
    es: object,
    valid_tags: set[str],
    ignore_keys: set[DispatchKey] | None = None,
    path: str = "<stdin>",
    skip_native_fns_gen: bool = False,
) -> ParsedYaml:
    """Turn already-loaded native_functions.yaml data into a ParsedYaml.

    Args:
        es: the loaded YAML document; must be a list of dicts, each carrying
            an integer "__line__" entry and a "func" entry.
        valid_tags: forwarded to NativeFunction.from_yaml.
        ignore_keys: forwarded to NativeFunction.from_yaml.
        path: file name used when building Locations for error context.
        skip_native_fns_gen: when True, add_generated_native_functions is
            not run on the parsed functions.

    Returns:
        ParsedYaml(native_functions, backend_indices), where backend_indices
        is a defaultdict that yields an empty Undefined-key BackendIndex for
        dispatch keys without any kernels.
    """
    assert isinstance(es, list)
    native_functions: list[NativeFunction] = []
    kernels_by_key: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = (
        defaultdict(dict)
    )
    for entry in es:
        assert isinstance(entry, dict), f"expected to be dict: {entry}"
        assert isinstance(entry.get("__line__"), int), entry
        loc = Location(path, entry["__line__"])
        funcs = entry.get("func")
        assert funcs is not None, f"missed 'func' in {entry}"
        with context(lambda: f"in {loc}:\n {funcs}"):
            parsed, metadata = NativeFunction.from_yaml(
                entry, loc, valid_tags, ignore_keys
            )
            native_functions.append(parsed)
            BackendIndex.grow_index(kernels_by_key, metadata)

    # Cross-function validation runs on the functions exactly as written in
    # the YAML, before any generated functions are added.
    error_check_native_functions(native_functions)

    def _empty_undefined_index() -> BackendIndex:
        return BackendIndex(
            dispatch_key=DispatchKey.Undefined,
            use_out_as_primary=True,
            external=False,
            device_guard=False,
            # I'm actually not sure about this; undefined could be hit on
            # empty TensorList, hypothetically that could have sizes in it
            index={},
        )

    # Default dict is to prevent the codegen from barfing when we have a
    # dispatch key that has no kernels yet.
    indices: dict[DispatchKey, BackendIndex] = defaultdict(_empty_undefined_index)

    if not skip_native_fns_gen:
        add_generated_native_functions(native_functions, kernels_by_key)

    for key, kernels in kernels_by_key.items():
        # Only cuda-like devices in tree require device guards
        needs_guard = is_cuda_dispatch_key(key) or is_xpu_dispatch_key(key)
        # All structured in-tree operators are implemented in terms of their out operator.
        indices[key] = BackendIndex(
            dispatch_key=key,
            use_out_as_primary=True,
            external=False,
            device_guard=needs_guard,
            index=kernels,
        )
    return ParsedYaml(native_functions, indices)
def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
    """Collect the tag names from already-loaded tags.yaml data.

    Args:
        es: the loaded YAML document; must be a list of mappings that each
            carry an integer "__line__" entry, a "tag" and a "desc".
        path: file name used when building Locations for error context.

    Returns:
        The set of all tag names found.
    """
    assert isinstance(es, list)
    tag_names: set[str] = set()
    for entry in es:
        assert isinstance(entry.get("__line__"), int), entry
        loc = Location(path, entry["__line__"])
        tags = entry.get("tag")
        with context(lambda: f"in {loc}:\n {tags}"):
            # A missing "tag" key is an error (KeyError), reported with the
            # location context above.
            name = entry["tag"]
            desc = entry.get("desc", "")
            # ensure that each tag has a non-empty description
            assert desc != ""
            tag_names.add(name)
    return tag_names
@functools.cache
def parse_tags_yaml(path: str) -> set[str]:
    """Load the tags YAML file at `path` and return the set of tag names.

    Results are memoized per path, both by functools.cache and in the
    module-level _GLOBAL_PARSE_TAGS_YAML_CACHE dict.
    """
    if path in _GLOBAL_PARSE_TAGS_YAML_CACHE:
        return _GLOBAL_PARSE_TAGS_YAML_CACHE[path]
    with open(path) as f:
        loaded = yaml.load(f, Loader=LineLoader)
    tag_names = parse_tags_yaml_struct(loaded, path=path)
    _GLOBAL_PARSE_TAGS_YAML_CACHE[path] = tag_names
    return tag_names
def parse_native_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: set[DispatchKey] | None = None,
    *,
    skip_native_fns_gen: bool = False,
    loaded_yaml: object | None = None,
) -> ParsedYaml:
    """Parse native_functions.yaml (and its tags file) into a ParsedYaml.

    Args:
        path: location of native_functions.yaml; also the cache key.
        tags_yaml_path: location of the tags YAML used to obtain valid tags.
        ignore_keys: forwarded to parse_native_yaml_struct.
        skip_native_fns_gen: forwarded to parse_native_yaml_struct.
        loaded_yaml: pre-loaded YAML data; when given, `path` is not opened.

    NOTE: the result is cached by `path` alone, so on a cache hit the other
    arguments are not consulted.
    """
    cached = _GLOBAL_PARSE_NATIVE_YAML_CACHE.get(path)
    if cached is not None:
        return cached

    valid_tags = parse_tags_yaml(tags_yaml_path)

    # if a loaded yaml is provided, use that instead of reading from path
    if loaded_yaml is not None:
        es = loaded_yaml
    else:
        with open(path) as f:
            es = yaml.load(f, Loader=LineLoader)

    parsed = parse_native_yaml_struct(
        es,
        valid_tags,
        ignore_keys,
        path=path,
        skip_native_fns_gen=skip_native_fns_gen,
    )
    _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parsed
    return parsed
- # Some assertions are already performed during parsing, but those are only within a single NativeFunction.
- # Assertions here are meant to be performed across NativeFunctions.
def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
    """Validate invariants that span more than one NativeFunction.

    For every function in `funcs` this checks that:
      - a `structured_delegate` names an operator that exists in `funcs`
        and is itself marked structured;
      - no argument is named after a Python reserved keyword, apart from a
        short list of pre-existing exceptions;
      - an op tagged `inplace_view` (other than resize_, resize_as_ and the
        set_ overloads) has an inplace base name and there is at least one
        op with the corresponding out-of-place base name.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    func_map: dict[OperatorName, NativeFunction] = {}
    base_func_map: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
    for f in funcs:
        func_map[f.func.name] = f
        base_func_map[f.func.name.name].append(f)

    # Both lookup tables are loop-invariant, so build them once up front.
    reserved_keywords = frozenset(keyword.kwlist)
    # List of pre-existing operators that are known to have reserved keywords
    # Exclusion list is used to suppress the assertion for these operators
    keyword_exclusions = frozenset(
        {
            ("_has_compatible_shallow_copy_type", "from"),
            ("random_.from", "from"),
            ("uniform_", "from"),
        }
    )
    # inplace_view ops that are exempt from the naming/pairing check below.
    inplace_view_exempt = ("resize_", "resize_as_")

    for f in funcs:
        if f.structured_delegate is not None:
            delegate_func = func_map.get(f.structured_delegate)
            assert delegate_func is not None, (
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but {f.structured_delegate} is missing."
            )
            assert delegate_func.structured, (
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
                f"Consider adding 'structured=True' to the delegated operator"
            )

        # Check for reserved Python keywords
        for arg in f.func.arguments.flat_all:
            if (
                arg.name in reserved_keywords
                and (str(f.func.name), arg.name) not in keyword_exclusions
            ):
                raise AssertionError(
                    f"Argument name '{arg.name}' in function '{f.func.name}' is a reserved Python keyword."
                )

        # See Note [resize_ in Functionalization]
        # resize_() is technically an inplace view op (and therefore needs the tag),
        # but it would be overkill to add a true "view" variant of resize.
        # Instead, resize_() gets special treatment in functionalization,
        # and we have a resize() op that is non-aliasing + functional.
        if (
            "inplace_view" in f.tags
            and str(f.func.name) not in inplace_view_exempt
            and str(f.func.name.name) != "set_"
        ):
            base_name = f.func.name.name
            assert base_name.inplace, (
                f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming "
                "convention for inplace ops - the codegen expects the base name to have a trailing underscore. "
            )
            out_of_place_base_name = BaseOperatorName(
                base_name.base, False, base_name.dunder_method
            )
            # .get() avoids inserting an empty list into the defaultdict.
            # The message names the out-of-place op that is missing, not the
            # inplace op being checked.
            assert base_func_map.get(out_of_place_base_name), (
                f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding "
                f"out-of-place view op with the name '{out_of_place_base_name}' and matching schema, but it didn't find one. "
            )
# Characters that must be escaped inside a C++ string literal, mapped to their
# escape sequences. Applied in a single pass by cpp_string.
_CPP_STRING_ESCAPES = str.maketrans(
    {
        "\\": "\\\\",
        '"': '\\"',
        "\a": "\\a",
        "\b": "\\b",
        "\f": "\\f",
        "\n": "\\n",
        "\r": "\\r",
        "\v": "\\v",
        "\t": "\\t",
    }
)


def cpp_string(s: str) -> str:
    """Convert a python string into a c++ string literal

    Backslashes, double quotes and the control characters \\a, \\b, \\f,
    \\n, \\r, \\v and \\t are replaced by their C++ escape sequences, and
    the result is wrapped in double quotes. Carriage returns are escaped
    too, so that a raw line terminator never ends up inside the literal.
    """
    return f'"{s.translate(_CPP_STRING_ESCAPES)}"'
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # C++ CODE GENERATION
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- # Most functions in this section are curried: they consist of a function
- # that takes some parameters (e.g., what is to be generated) which itself
- # returns a function that actually maps NativeFunction to the code
- # to be generated. This pattern makes it convenient to use map, concatMap
- # and similar functional combinators.
def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]:
    """List the dispatch keys that take part in static dispatch.

    Returns an empty list when no backends are given; otherwise the dispatch
    key of every backend, in order, followed by the four composite keys.
    """
    if not backends:
        return []
    keys = [backend.dispatch_key for backend in backends]
    keys.extend(
        (
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.CompositeImplicitAutogradNestedTensor,
            DispatchKey.CompositeExplicitAutograd,
            DispatchKey.CompositeExplicitAutogradNonFunctional,
        )
    )
    return keys
def get_static_dispatch_backend(
    f: NativeFunction, backend_index: BackendIndex
) -> DispatchKey | None:
    """Choose the dispatch key that `f` statically dispatches to.

    Returns the key of `backend_index` when `f` has a structured delegate or
    a kernel in that backend; otherwise the first composite key for which
    `f` has a kernel; otherwise None.
    """
    # TODO: for ops with structured_delegate it should check the dispatch table of
    # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
    # so we always dispatch to the `backend`, but this could be wrong when we
    # migrate math/default_backend ops to use structured delegate.
    if f.structured_delegate is not None or backend_index.has_kernel(f):
        return backend_index.dispatch_key
    if f.has_composite_explicit_autograd_kernel:
        return DispatchKey.CompositeExplicitAutograd
    if f.has_composite_explicit_autograd_non_functional_kernel:
        return DispatchKey.CompositeExplicitAutogradNonFunctional
    if f.has_composite_implicit_autograd_kernel:
        return DispatchKey.CompositeImplicitAutograd
    if f.has_composite_implicit_autograd_nested_tensor_kernel:
        return DispatchKey.CompositeImplicitAutogradNestedTensor
    return None
def static_dispatch_ops_header(
    f: NativeFunction, backend_index: list[BackendIndex]
) -> str | None:
    """Build the per-operator `#include` lines needed for static dispatch.

    Returns None when `backend_index` is None or `f` uses manual kernel
    registration. Otherwise returns one include line per backend for which
    get_static_dispatch_backend yields a key, joined by newlines (the empty
    string when there are none).
    """
    if backend_index is None or f.manual_kernel_registration:
        return None
    resolved_keys = (get_static_dispatch_backend(f, index) for index in backend_index)
    includes = [
        f"#include <ATen/ops/{f.root_name}_{key.lower()}_dispatch.h>"
        for key in resolved_keys
        if key is not None
    ]
    return "\n".join(includes)
def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]:
    """Return one `<ATen/{key}Functions.h>` include line per static dispatch key."""
    headers: list[str] = []
    for dispatch_key in static_dispatch_keys(backends):
        headers.append(f"#include <ATen/{dispatch_key}Functions.h>")
    return headers
- # Translates arguments of `sig` to CppSignature bindings.
- # Note that we have a special case for `memory_format` argument and this case is not covered by
- # tools.codegen.api.translate() yet as its application is limited to static dispatch.
def translate_args(
    sig: CppSignature | DispatcherSignature,
    cpp_sig: CppSignature,
) -> str:
    """Render the argument expressions for calling `cpp_sig` from within `sig`.

    The arguments of `sig` are the available bindings and the arguments of
    `cpp_sig` are the goals; the translated expressions are returned as a
    single comma-separated string.
    """

    # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
    def add_spl_memory_format_binding(input_bindings: list[Binding]) -> list[Binding]:
        return [
            Binding(
                nctype=NamedCType(
                    SpecialArgName.possibly_redundant_memory_format,
                    binding.nctype.type,
                ),
                name=binding.name,
                default=binding.default,
                argument=binding.argument,
            )
            if binding.name == "memory_format"
            else binding
            for binding in input_bindings
        ]

    src_bindings = list(sig.arguments())
    goal_bindings = list(cpp_sig.arguments())
    # When an argument of the CPP signature has the
    # SpecialArgName.possibly_redundant_memory_format NCType, give the
    # memory_format bindings of the source signature the same NCType as well
    wants_special_memory_format = any(
        goal.nctype.name == SpecialArgName.possibly_redundant_memory_format
        for goal in goal_bindings
    )
    if wants_special_memory_format:
        src_bindings = add_spl_memory_format_binding(src_bindings)
    translated = translate(src_bindings, goal_bindings)
    return ", ".join(item.expr for item in translated)
def generate_static_dispatch_backend_call(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_index: BackendIndex,
) -> str:
    """Emit the C++ `return` statement that calls `f`'s kernel in one backend.

    The call is qualified with the kernel's C++ namespace (with any
    "::native" part removed) followed by the lower-cased dispatch key of
    `backend_index`.
    """
    cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
    name = cpp_sig.name()
    exprs = translate_args(sig, cpp_sig)
    backend_metadata = backend_index.get_kernel(f)
    # Fall back to the default namespace when the backend has no kernel
    # metadata for `f`, or the metadata carries no namespace.
    kernel_ns = DEFAULT_KERNEL_NAMESPACE
    if backend_metadata and backend_metadata.cpp_namespace:
        kernel_ns = backend_metadata.cpp_namespace
    ns = kernel_ns.replace("::native", "")
    backend_ns = backend_index.dispatch_key.lower()
    return f"return {ns}::{backend_ns}::{name}({exprs});"
def generate_static_dispatch_fallback_call(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_indices: list[BackendIndex],
) -> str:
    """Emit the C++ statement used when no backend-specific kernel applies.

    If `f` has a composite kernel, the statement returns a call into the
    namespace of the first matching composite dispatch key. Otherwise it is
    a failing TORCH_CHECK whose message lists the dispatch keys of
    `backend_indices`.
    """
    cpp_sigs = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    if sig.symint and f.func.has_symint():
        cpp_sig = cpp_sigs.symint_signature
    else:
        cpp_sig = cpp_sigs.signature
    assert cpp_sig is not None
    name = cpp_sig.name()
    exprs = translate_args(sig, cpp_sig)
    ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")

    if f.has_composite_explicit_autograd_kernel:
        composite_key = DispatchKey.CompositeExplicitAutograd
    elif f.has_composite_explicit_autograd_non_functional_kernel:
        composite_key = DispatchKey.CompositeExplicitAutogradNonFunctional
    elif f.has_composite_implicit_autograd_kernel:
        composite_key = DispatchKey.CompositeImplicitAutograd
    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
        composite_key = DispatchKey.CompositeImplicitAutogradNestedTensor
    else:
        # Keep a space between "for" and the key list, and none before the
        # closing quote, so the message reads "... for CPU, CUDA".
        supported = ", ".join(str(index.dispatch_key) for index in backend_indices)
        return (
            f'TORCH_CHECK(false, "Static dispatch does not support {name} '
            f'for {supported}");'
        )
    return f"return {ns}::{composite_key.lower()}::{name}({exprs});"
def static_dispatch(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_indices: list[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find the corresponding backend and dispatch to it. If more than one
    backend applies, emit a switch that picks the backend from the dispatch key computed from the inputs.
    Arguments:
        sig: A CppSignature or DispatcherSignature for this native function we want to use.
        f: NativeFunction to generate static dispatch.
        backend_indices: All available backends.
    Return:
        C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);"
        The empty string is returned when there are no backends or `f` uses manual kernel registration.
    """
    if not backend_indices or f.manual_kernel_registration:
        return ""

    # Backends that either have a kernel for `f`, or are structured dispatch
    # keys while `f` has a structured delegate.
    candidates = [
        backend
        for backend in backend_indices
        if backend.has_kernel(f)
        or (
            f.structured_delegate is not None
            and backend.dispatch_key in STRUCTURED_DISPATCH_KEYS
        )
    ]
    if not candidates:
        return generate_static_dispatch_fallback_call(sig, f, backend_indices)
    if len(candidates) == 1:
        return generate_static_dispatch_backend_call(sig, f, candidates[0])

    # Several candidate backends: compute the dispatch key from the inputs.
    tensor_args = ", ".join(
        a.name
        for a in sig.arguments()
        if isinstance(a.argument, SelfArgument)
        or (isinstance(a.argument, Argument) and a.argument.type.is_tensor_like())
    )
    tensor_opts = f.func.arguments.tensor_options

    subexprs: list[str] = []
    if tensor_opts is not None:
        subexprs.append(
            "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
        )
    if tensor_args != "":
        subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})")

    stmts = [
        f"""DispatchKeySet _dk_set = {" | ".join(subexprs)};""",
        "DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);",
    ]

    dispatch_code = []
    for backend in candidates:
        dispatch_code.append(f"""case DispatchKey::{backend.dispatch_key}:""")
        dispatch_code.append(
            f"""\t{generate_static_dispatch_backend_call(sig, f, backend)};"""
        )

    fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices)
    connector = "\n\t\t"

    return f"""
    {connector.join(stmts)}
    switch (_dk) {{
        {connector.join(dispatch_code)}
        default:
            {fallback}
    }}
    """
# Generates RegisterSchema.cpp. Depending on the selector, either
# all schemas are registered, or only some are (in the case of
# selective build)
@dataclass(frozen=True)
class RegisterSchema:
    """Callable that renders the `m.def(...)` registration for one operator.

    Returns None for operators the selector rejects. Each distinct tag list
    is emitted once as a `tags_<n>` vector and reused by later operators
    that carry the same tags.
    """

    selector: SelectiveBuilder
    # Maps a rendered tag-list initializer to the numeric suffix of the
    # `tags_<n>` vector that was emitted for it; filled in by __call__.
    known_tags: dict[str, int] = field(default_factory=dict)

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if not self.selector.is_native_function_selected(f):
            return None
        schema_literal = cpp_string(str(f.func))
        tag_exprs = [f"at::Tag::{tag}" for tag in sorted(f.tags)]
        if not tag_exprs:
            return f"m.def({schema_literal}, {{}});\n"
        tags = "{" + ", ".join(tag_exprs) + "}"
        declaration = ""
        idx = self.known_tags.get(tags)
        if idx is None:
            # First time this tag list is seen: declare its vector.
            idx = len(self.known_tags)
            self.known_tags[tags] = idx
            declaration = f"const std::vector<at::Tag> tags_{idx} = {tags};\n"
        return f"{declaration}m.def({schema_literal}, tags_{idx});\n"
- # Generates Operators.h and Operators.cpp.
- # These provide macros that, given an operator and overload name, allow users
- # to access an "un-overloaded" function version of the operator. This
- # is useful for extension writers who (1) want to decltype the operator
- # and (2) don't want to worry about method-only operators.
@dataclass(frozen=True)
class ComputeOperators:
    """Callable that renders the C++ text of one operator's ``at::_ops`` entry.

    For ``Target.DECLARATION`` the result is a struct declaration; for
    ``Target.DEFINITION`` it is a typed-handle factory followed by the
    definitions of ``call`` and ``redispatch``.
    """

    # Selects which of the two outputs __call__ produces.
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    # When non-empty, the body of the generated call() is produced by
    # static_dispatch() instead of forwarding to the typed operator handle.
    static_dispatch_backend_indices: list[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        """Return the generated C++ source for ``f`` according to ``self.target``."""
        sig = DispatcherSignature.from_schema(f.func)
        # The overload-disambiguated name is used as the C++ struct name.
        name = f.func.name.unambiguous_name()

        if self.target is Target.DECLARATION:
            # Note [The ATen Operators API]
            # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
            # metadata about each operator + entry points into the Dispatcher.
            # The C++ function, method, and redispatch API's are all implemented as wrappers
            # into various bits of the structs defined here.
            #
            # Important characteristics about the Operators API:
            # (1) It follows the Dispatcher API.
            #     This is kind of necessary to avoid overhead.
            #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
            #     would need to wrap their arguments into TensorOptions only to unwrap them again.
            # (2) Overload names are disambiguated.
            #     This is helpful for pytorch extenders who would like to decltype() an aten operator,
            #     that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
            # (3) No argument defaulting is allowed.
            #     This is more of an implementation detail to avoid #include cycles,
            #     since TensorBody.h (which defines the Tensor class) needs to include this file.
            # (4) manual_cpp_bindings and faithful names are not included in the API.
            #     This applies to stuff like __dispatch__is_complex(), and add_outf().
            #     These aren't "real aten ops", they're just additional functions provided by the C++ API.
            #     They're implemented as wrappers in Functions.h that call into the actual operators
            #     defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
            #     This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
            return f"""
struct TORCH_API {name} {{
  using schema = {sig.type()};
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  static constexpr const char* name = "aten::{f.func.name.name}";
  static constexpr const char* overload_name = "{f.func.name.overload_name}";
  static constexpr const char* schema_str = {cpp_string(str(f.func))};
  static {sig.defn(name="call", is_redispatching_fn=False)};
  static {sig.defn(name="redispatch", is_redispatching_fn=True)};
}};"""

        elif self.target is Target.DEFINITION:
            # Helper that looks the operator up in the dispatcher by name and
            # overload name and returns a typed handle for it.
            defns = f"""
// aten::{f.func}
static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{
  return c10::Dispatcher::singleton()
      .findSchemaOrThrow({name}::name, {name}::overload_name)
      .typed<{name}::schema>();
}}
"""
            # Emit call() first, then redispatch(); the latter additionally
            # forwards a leading dispatchKeySet argument.
            for is_redispatching_fn in [False, True]:
                if is_redispatching_fn:
                    dispatcher_exprs_str = ", ".join(
                        ["dispatchKeySet"] + [a.name for a in sig.arguments()]
                    )
                    method_base = "redispatch"
                else:
                    dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
                    method_base = "call"

                dispatcher_call = method_base
                method_name = f"{name}::{method_base}"

                # Default body: forward to the same-named method on a
                # function-local static typed handle.
                fn_body = f"""
    static auto op = create_{name}_typed_handle();
    return op.{dispatcher_call}({dispatcher_exprs_str});"""

                if (
                    not is_redispatching_fn
                    and len(self.static_dispatch_backend_indices) > 0
                ):
                    # call() should go through static dispatch
                    fn_body = static_dispatch(
                        sig, f, backend_indices=self.static_dispatch_backend_indices
                    )
                defns += f"""
// aten::{f.func}
{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
    {fn_body}
}}
"""
            return defns
        else:
            assert_never(self.target)
- # Generates Functions.h, which provides the functional public C++ API,
- # and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    """Callable that renders the inline free-function C++ wrappers for one operator.

    Every generated wrapper forwards to ``at::_ops::<unambiguous name>::call``.
    """

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        """Return the concatenated wrapper source for every C++ signature of ``f``.

        The result is the empty string when ``f`` has no function variant and
        its schema has no SymInt.
        """
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )
        has_symint = f.func.has_symint()

        result = ""
        for sig in sig_group.signatures():
            # See Note [The ATen Operators API]
            target_sig = DispatcherSignature.from_schema(f.func)
            # Expressions that adapt this C++ signature's arguments to the
            # dispatcher signature expected by at::_ops::...::call.
            exprs = translate(sig.arguments(), target_sig.arguments())
            exprs_str = ", ".join([e.expr for e in exprs])

            # Integer-like type the symint:: template below is constrained to.
            if sig.symint:
                intlike_t = "c10::SymInt"
            else:
                intlike_t = "int64_t"

            if Variant.function in f.variants:
                result += f"""
// aten::{f.func}
inline {sig.decl()} {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}"""

            # The template function can be used from template situations
            # where you want to switch between the symint or not version
            # depending on a template argument
            #
            # NB: we ALWAYS generate this even for methods.  But we put it in
            # this header so it can take advantage of per-op headers
            if has_symint:
                result += f"""
namespace symint {{
  template <typename T, typename = std::enable_if_t<std::is_same_v<T, {intlike_t}>>>
  {sig.decl(suppress_symint_suffix=True)} {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
  }}
}}
"""
        return result
- # Generates TensorBody.h. This file provides the object-oriented (method-based)
- # public C++ API, and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeTensorMethod:
    """Callable that renders the ``Tensor`` method declarations or definitions for one operator."""

    # Selects between `<decl> const;` declarations and inline `Tensor::` definitions.
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    # Not read by __call__ in this class.
    static_dispatch_backend_indices: list[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        """Return the generated method source for ``f``, or ``None`` if ``f`` has no method variant."""
        if Variant.method not in f.variants:
            return None

        # Method variants are expected to be non-out functions with a self argument.
        assert not f.func.is_out_fn()
        assert f.func.arguments.self_arg is not None

        sig_group = CppSignatureGroup.from_native_function(
            f, method=True, fallback_binding=f.manual_cpp_binding
        )

        if self.target is Target.DECLARATION:
            # One const-qualified member declaration per C++ signature.
            result = ""
            for sig in sig_group.signatures():
                result += f"{sig.decl()} const;\n"
            return result

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        result = ""

        for sig in sig_group.signatures():
            target_sig = DispatcherSignature.from_schema(f.func)
            # method=True is passed so translate() treats the bindings as
            # belonging to a method signature.
            exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
            exprs_str = ", ".join([e.expr for e in exprs])

            result += f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""

        return result
- # Generates RedispatchFunctions.h.
- # This is similar to the C++ API defined in Functions.h, but provides access
- # to the dispatcher's redispatch API.
@dataclass(frozen=True)
class ComputeRedispatchFunction:
    """Callable that renders the inline redispatch wrappers for one operator.

    Each wrapper forwards ``dispatchKeySet`` plus the translated arguments to
    ``at::_ops::<unambiguous name>::redispatch``.
    """

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        """Return the concatenated redispatch wrapper source for every C++ signature of ``f``."""
        # We unconditionally generate function variants of the redispatch API.
        # This is mainly because we can namespace functions separately, but not methods.
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )

        result = ""
        for sig in sig_group.signatures():
            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments())
            # The key set is always forwarded as the first argument.
            exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs])

            result += f"""
// aten::{f.func}
inline {sig.decl(is_redispatching_fn=True)} {{
    return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str});
}}
"""

        return result
- # Generates ATenOpList.cpp, a runtime accessible list of all aten
- # operators.
- # TODO: This was historically used to help some JIT interop code
- # figure out whether or not to treat aten namespace'd operators
- # one way or another, we should reevaluate if this is actually needed.
@with_native_function
def compute_aten_op(f: NativeFunction) -> str:
    """Return the C++ initializer ``{"aten::<name>", "<overload>"},`` for ``f``."""
    op_name = f.func.name
    base, overload = op_name.name, op_name.overload_name
    return f'{{"aten::{base}", "{overload}"}},'
- # Generates MetaFunctions.h
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> str | None:
    """Render the ``structured_<name>`` C++ struct declaration for a structured group.

    Returns ``None`` when ``g`` is not structured. When the out variant declares
    precomputed elements, the struct additionally contains a templated
    ``precompute_out`` helper struct and a ``meta_return_ty`` alias that becomes
    the return type of ``meta``; otherwise ``meta`` returns ``void``.
    """
    if not g.structured:
        return None
    with native_function_manager(g.out):
        name = meta.name(g)
        args = structured.meta_arguments(g)
        args_str = ", ".join(a.decl() for a in args)
        # Fall back to the default base class when the out variant names none.
        parent_class = g.out.structured_inherits
        if parent_class is None:
            parent_class = "at::impl::MetaBase"
        meta_return = "void"
        # g.structured is already known to be true here (see the early return
        # above), so this always reads g.out.precomputed.
        precomputed = g.out.precomputed if g.structured else None

        if precomputed:
            # Generate the template declaration with one bool parameter for each
            # precomputed element. Each parameter is true if the corresponding (in
            # terms of position) precomputed element has been set.
            precomputed_values = [*precomputed.replace.values(), precomputed.add]
            precomputed_elements = [
                elem for replace_list in precomputed_values for elem in replace_list
            ]
            precomputed_template_parameters = [
                elem.name.upper() for elem in precomputed_elements
            ]
            precomputed_template_params_str = ", ".join(
                f"bool {param} = false" for param in precomputed_template_parameters
            )
            precompute_template_decl = f"template <{precomputed_template_params_str}>"

            # Generate a string containing declarations of all precomputed elements.
            precomputed_elements_with_cpp_types = [
                structured.argument_type(elem, binds=elem.name)
                for elem in precomputed_elements
            ]

            precomputed_elements_decl = ";\n".join(
                f"{elem.cpp_type(strip_ref=True)} {elem.name}"
                for elem in precomputed_elements_with_cpp_types
            )

            # Generate "setter" methods for each precomputed element. Each method will return
            # a new instance of precompute_out with the template parameter that corresponds to
            # the member set by the method to true (to indicate that it has been set).
            setter_methods = []
            for i, elem in enumerate(precomputed_elements):
                # Generate the signature. The return type will be the same
                # as the type of `this` but with the template parameter
                # corresponding to the element set by this method set to true.
                # The assert generated below will ensure that this template
                # parameter is false on the type of `this`.
                return_ty_templates = ", ".join(
                    precomputed_template_parameters[:i]
                    + ["true"]
                    + precomputed_template_parameters[i + 1 :]
                )
                return_ty = f"precompute_out<{return_ty_templates}>"
                elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(
                    strip_ref=True
                )
                signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)"

                # Generate an assert which checks that the
                # template parameter corresponding to the precomputed
                # element that is set by this method is false on the
                # class corresponding to the object that `this` points to.
                # This ensures that each element can be set only once.
                assert_msg = f'"{elem.name} already set"'
                assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});"

                # Generate the new object construction block. All state
                # except the element that this method sets is copied from the
                # object that `this` points to. The value for the element that
                # the method sets is taken from a method parameter.
                construction_stmts = []
                construction_stmts.append(f"{return_ty} ret;")

                # NOTE: this inner loop rebinds `elem`. Everything that needs
                # the outer element (signature, assert_stmt) was built above.
                for j, elem in enumerate(precomputed_elements):
                    if i == j:
                        construction_stmts.append(f"ret.{elem.name} = value;")
                    else:
                        construction_stmts.append(
                            f"ret.{elem.name} = this->{elem.name};"
                        )

                construction_stmts.append("return ret;")
                construction_block = "\n".join(construction_stmts)

                setter_methods.append(
                    f"""
                    {signature} {{
                        {assert_stmt}
                        {construction_block}
                    }}
                """
                )
            setter_methods_decl = "\n".join(setter_methods)

            # Meta should return an instance of the struct containing the precomputed elements.
            meta_return_template_params = ", ".join(
                ["true"] * len(precomputed_template_parameters)
            )
            # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return
            # type (which has a variable number of template parameters).
            meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;"
            meta_return = "meta_return_ty"
            precomputed_decl = f"""
                {precompute_template_decl}
                struct TORCH_API precompute_out {{
                    {setter_methods_decl}
                    {precomputed_elements_decl};
            }};"""
        else:
            # Nothing precomputed: no helper struct, and meta() returns void.
            meta_return_typedef = ""
            precomputed_decl = ""

        return f"""\
struct TORCH_API structured_{name} : public {parent_class} {{
    {precomputed_decl}
    {meta_return_typedef}
    {meta_return} meta({args_str});
}};
"""
def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool:
    """Decide whether a BackendSelect kernel should be generated for ``f``.

    Operators whose base name ends in ``_like`` or starts with ``new_``, and
    operators without tensor-options arguments, never need one. For everything
    else the decision is delegated to the selector.
    """
    op_name = str(f.func.name.name)
    excluded_by_name = op_name.startswith("new_") or op_name.endswith("_like")
    if excluded_by_name or f.func.arguments.tensor_options is None:
        return False
    return selector.is_native_function_selected(f)
- # Generates RegisterBackendSelect.cpp, a series of kernels which provide
- # specialized computation of dispatch key for operator signatures which cannot
- # be easily done automatically using templating.
@dataclass(frozen=True)
class ComputeBackendSelect:
    """Callable that renders the BackendSelect kernel or its registration for one operator.

    Operators for which ``needs_backend_select`` is false produce ``None``.
    """

    # DEFINITION emits the kernel function; REGISTRATION emits its m.impl(...) line.
    target: Literal[Target.DEFINITION, Target.REGISTRATION]

    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        """Return the generated C++ source for ``f``, or ``None`` if it needs no BackendSelect kernel."""
        if not needs_backend_select(f, self.selector):
            return None

        name = native.name(f.func)
        # BackendSelect can go to Meta, so it must preserve symints
        native_sig = NativeSignature(f.func, symint=True)

        # Tensor-like arguments; their dispatch keys are folded into the
        # computed key set in the DEFINITION branch below.
        native_tensor_args = [
            a
            for a in native_sig.arguments()
            if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
        ]

        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        # The generated kernel is declared with the dispatcher signature.
        sig: NativeSignature | DispatcherSignature
        sig = dispatcher_sig
        dispatcher_exprs = dispatcher_sig.exprs()
        # C++ expression deriving a dispatch key from the dtype/layout/device arguments.
        dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"

        if self.target is Target.DEFINITION:
            # I don't think there's actually a good reason to generate
            # these two cases differently
            # The first case could probably be improved though- it calls computeDispatchKeySet(),
            # which looks at TLS dispatch keys- there should not be any by the time we reach backend select.
            if native_tensor_args:
                assert f.func.arguments.has_tensor_arg()
                tensor_args = ", ".join(a.name for a in native_tensor_args)
                compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
            else:
                assert not f.func.arguments.has_tensor_arg()
                compute_dk = (
                    f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
                )
            # The kernel computes _dk and redispatches the operator with it.
            return f"""\
// aten::{f.func}
C10_ALWAYS_INLINE
{sig.defn(name)} {{
  {compute_dk}
  return at::_ops::{f.func.name.unambiguous_name()}::redispatch(
      _dk, {", ".join(a.expr for a in dispatcher_exprs)});
}}
"""
        elif self.target is Target.REGISTRATION:
            return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
        else:
            assert_never(self.target)
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # YAML CODE GENERATION
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def format_yaml(data: object) -> str:
    """Serialize ``data`` to block-style YAML text using ``YamlDumper``.

    Anchors/aliases are disabled, ``OrderedDict`` is emitted as a plain
    mapping, and optional line wrapping is turned off. Note that this
    configures the ``YamlDumper`` class itself on every call.
    """

    # Support serializing OrderedDict
    def represent_ordered_dict(dumper: Any, mapping: Any) -> Any:
        return dumper.represent_dict(mapping.items())

    # Ignore alias in Dumper
    YamlDumper.ignore_aliases = lambda self, _node: True  # type: ignore[assignment]
    YamlDumper.add_representer(OrderedDict, represent_ordered_dict)  # type: ignore[no-untyped-call]

    # Some yaml parsers (e.g. Haskell's) don't understand line breaks.
    # width=1e9 turns off optional line breaks and improves
    # the portability of the outputted yaml.
    rendered = yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9)  # type: ignore[no-any-return, call-overload]
    return rendered
- # For some reason, some defaults we write to YAML are written as native
- # YAML objects, rather than doing them uniformly as strings. This
- # function detects those cases and converts them into native Python
- # objects.
def pythonify_default(s: str) -> object:
    """Convert a default-value string into a native Python object where possible.

    ``"true"``/``"false"`` become booleans; otherwise the string is parsed as
    an ``int``, then as a ``float``; if neither parse succeeds the string is
    returned unchanged.
    """
    literal_bools = {"true": True, "false": False}
    if s in literal_bools:
        return literal_bools[s]
    # Try the narrowest numeric type first so "1" stays an int.
    for convert in (int, float):
        try:
            return convert(s)
        except ValueError:
            continue
    return s
- # What is a dynamic type? Over time, the semantic meaning of
- # dynamic type has degraded to meaninglessness (in the old days,
- # it captured dtype-ness of types, but that has gone away with
- # the removal of TH). These days, it's mostly the same thing as
- # the C++ API argument type, except that Tensor and Tensor?
- # arguments simply present as Tensor.
- #
- # TODO: Get rid of dynamic_type, after getting tools/autograd
- # to use the new codegen framework
def dynamic_type(t: Type) -> str:
    """Return the legacy "dynamic type" string reported for ``t`` in the YAML output.

    Optional wrappers are stripped first. A plain ``Tensor`` is reported as
    ``at::Tensor``; every other type is reported as its non-symint C++
    argument type.
    """
    inner = t
    while isinstance(inner, OptionalType):
        inner = inner.elem

    # Note we don't use is_tensor_like() here because it would
    # also include Tensor[]
    if str(inner) == "Tensor":
        return "at::Tensor"

    # This is a legacy concept, so never report SymInt
    ctype = cpp.argumenttype_type(
        inner, mutable=False, binds="__placeholder__", symint=False
    )
    return ctype.cpp_type()
def compute_method_of_yaml(variants: set[Variant]) -> list[str]:
    """Return the ``method_of`` list for the given variants.

    The list always starts with ``"Type"``, followed by ``"Tensor"`` when the
    method variant is present and ``"namespace"`` when the function variant is
    present, in that fixed order.
    """
    # The pairs are spelled out explicitly so that Tensor and namespace
    # always land in the list in the right order.
    ordered_labels = (
        (Variant.method, "Tensor"),
        (Variant.function, "namespace"),
    )
    return ["Type"] + [
        label for variant, label in ordered_labels if variant in variants
    ]
def compute_returns_yaml(
    f: NativeFunction,
) -> tuple[list[dict[str, str]], dict[str, str]]:
    """Build the ``returns`` YAML entries for ``f``.

    Returns a pair: the list of per-return dicts, and a mapping from out
    argument name to the corresponding return's field name (populated only for
    named returns of out functions).
    """
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
    # schema:
    #
    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
    #
    # There is something very odd about this schema: it is an out
    # variant of the function (that is to say, it will convert into
    # at::lstsq_out() in the C++ API), but the names of the output
    # return arguments don't match the keyword argument names of
    # the inputs.  It TURNS OUT that in this situation, the historical
    # Declarations.yaml we want to output is this (abbreviated to
    # only show relevant fields):
    #
    #   arguments:
    #     ...
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #     ...
    #
    #   returns:
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #
    # The name of the return fields is stored in 'field_name', and the
    # name of the arguments is stored in 'name'.  So when we process
    # arguments, we need a way to get at the corresponding return.  At
    # the moment, this is most conveniently done by constructing a
    # mapping from name (the argument concept) to field_name (the
    # return concept) while processing return arguments, since we don't
    # directly maintain this correspondence in the modeling of function
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: dict[str, str] = {}

    # Compute the returns field of the YAML entry
    cpp_names = cpp.return_names(f)
    returns = []
    for idx, (ret, cpp_name) in enumerate(zip(f.func.returns, cpp_names)):
        entry = {
            "dynamic_type": dynamic_type(ret.type),
            "name": cpp_name,
            # legacy, report ints
            "type": cpp.return_type(ret, symint=False).cpp_type(),
        }

        if ret.name:
            # See Note [name and field_name]
            entry["field_name"] = ret.name
            if f.func.is_out_fn():
                out_arg_name = f.func.arguments.out[idx].name
                name_to_field_name[out_arg_name] = ret.name

        returns.append(entry)

    return returns, name_to_field_name
- # arguments in yaml roughly corresponds to the public C++ API
def compute_cpp_argument_yaml(
    cpp_a: Binding,
    *,
    schema_order: bool,
    kwarg_only_set: set[str],
    out_arg_set: set[str],
    name_to_field_name: dict[str, str],
) -> object:
    """Build the YAML entry for one C++ API binding.

    A ``TensorOptionsArguments`` binding is described inline as a kwarg-only
    ``at::TensorOptions`` argument. A plain ``Argument`` is delegated to
    ``compute_argument_yaml``. A ``SelfArgument`` binding is not expected here
    and raises ``AssertionError``.
    """
    wrapped = cpp_a.argument

    if isinstance(wrapped, TensorOptionsArguments):
        options_entry: dict[str, object] = {
            "annotation": None,
            "dynamic_type": "at::TensorOptions",
            "is_nullable": False,
            "name": cpp_a.name,
            "type": cpp_a.type,
            "kwarg_only": True,
        }
        if cpp_a.default is not None:
            options_entry["default"] = cpp_a.default
        return options_entry

    if isinstance(wrapped, SelfArgument):
        raise AssertionError

    if isinstance(wrapped, Argument):
        return compute_argument_yaml(
            wrapped,
            schema_order=schema_order,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )

    # Any other kind of binding yields no entry.
    return None
def compute_argument_yaml(
    a: Argument,
    *,
    schema_order: bool,
    kwarg_only_set: set[str],
    out_arg_set: set[str],
    name_to_field_name: dict[str, str],
) -> object:
    """Build the YAML entry (a dict) describing one schema argument.

    Optional keys (``default``, ``kwarg_only``, ``output``/``allocate``,
    ``field_name``, ``size``) are added only when they apply to ``a``.
    """
    annotation = str(a.annotation) if a.annotation else None
    # legacy, report ints
    cpp_type = cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type()
    arg: dict[str, object] = {
        "annotation": annotation,
        "dynamic_type": dynamic_type(a.type),
        "is_nullable": a.type.is_nullable(),
        "name": a.name,
        "type": cpp_type,
    }

    if a.default is not None:
        default_expr = cpp.default_expr(a.default, a.type, symint=False)
        arg["default"] = pythonify_default(default_expr)

    if a.name in kwarg_only_set:
        arg["kwarg_only"] = True

    if a.name in out_arg_set:
        arg["output"] = True
        arg["allocate"] = True
        # See Note [name and field_name]
        if a.name in name_to_field_name:
            arg["field_name"] = name_to_field_name[a.name]

    # Historically, booleans don't get their size recorded, because it
    # is already built into the cpp type (e.g., std::array<bool, 4>)
    list_type = a.type.is_list_like()
    if list_type is not None and list_type.size is not None:
        if str(list_type.elem) != "bool":
            arg["size"] = list_type.size

    return arg
@with_native_function
def compute_declaration_yaml(f: NativeFunction) -> object:
    """Build the ordered YAML declaration entry for the native function ``f``.

    The entry combines the C++-API argument list, the schema-order argument
    list and signature, the returns, and assorted per-operator flags.
    """
    schema = f.func
    returns, name_to_field_name = compute_returns_yaml(f)

    # These sets are used to conveniently test if an argument is a
    # kwarg-only or out argument
    kwarg_only_set = {arg.name for arg in schema.arguments.flat_kwarg_only}
    out_arg_set = {arg.name for arg in schema.arguments.out}

    # Arguments as they appear in the public C++ signature.
    cpp_args = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    ).signature.arguments()
    arguments = [
        compute_cpp_argument_yaml(
            binding,
            schema_order=False,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
        for binding in cpp_args
    ]

    # Arguments in the order they appear in the schema.
    jit_args = list(schema.schema_order_arguments())
    schema_order_arguments = [
        compute_argument_yaml(
            jit_arg,
            schema_order=True,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
        for jit_arg in jit_args
    ]

    schema_order_types = [
        # NB: method here doesn't matter
        binding.type
        for jit_arg in jit_args
        for binding in cpp.argument(
            jit_arg,
            method=False,
            cpp_no_default_args=set(),
            faithful=False,
            symint=False,
            has_tensor_options=False,
        )
    ]

    # legacy, report ints
    returns_cpp_type = cpp.returns_type(schema.returns, symint=False).cpp_type()
    schema_order_cpp_signature = (
        f"{returns_cpp_type} ({', '.join(schema_order_types)})"
    )

    takes_tensor_options = any(
        isinstance(binding.argument, TensorOptionsArguments) for binding in cpp_args
    )
    is_factory_method = takes_tensor_options and Variant.method not in f.variants

    entry: OrderedDict[str, object] = OrderedDict()
    entry["name"] = cpp.name(schema)
    entry["operator_name"] = str(schema.name.name)
    entry["overload_name"] = str(schema.name.overload_name)
    entry["manual_kernel_registration"] = f.manual_kernel_registration
    entry["category_override"] = (
        "" if f.category_override is None else f.category_override
    )
    entry["schema_string"] = f"aten::{schema}"
    entry["arguments"] = arguments
    entry["schema_order_cpp_signature"] = schema_order_cpp_signature
    entry["schema_order_arguments"] = schema_order_arguments
    entry["method_of"] = compute_method_of_yaml(f.variants)
    entry["mode"] = "native"
    entry["python_module"] = "" if f.python_module is None else f.python_module
    entry["returns"] = returns
    entry["inplace"] = schema.name.name.inplace
    entry["is_factory_method"] = is_factory_method
    entry["abstract"] = f.is_abstract
    entry["device_guard"] = f.device_guard
    entry["with_gil"] = False
    entry["deprecated"] = False
    entry["has_math_kernel"] = f.has_composite_implicit_autograd_kernel
    return entry
- # See Note [Auto generated composite kernels]
def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
    """Return True if ``f`` is structured (or delegates to a structured op)
    and its schema kind is functional or inplace.
    """
    if not f.structured and f.structured_delegate is None:
        return False
    return f.func.kind() in (SchemaKind.functional, SchemaKind.inplace)
@with_native_function_and_indices
def compute_registration_declarations(
    f: NativeFunction, backend_indices: dict[DispatchKey, BackendIndex]
) -> str:
    """Render one line of registration declarations for ``f``.

    The line is the dispatcher-style C++ declaration of ``f`` (argument
    defaults stripped), followed by a ``//`` comment holding a JSON object
    with the operator's schema and its ``dispatch`` and ``default`` flags.
    """
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(f.func.returns).cpp_type()
    args = dispatcher.arguments(f.func)
    args_str = ", ".join(a.no_default().decl() for a in args)

    # Dispatch keys whose backend index has a kernel for f. Computed once;
    # it is compared against both composite-only key sets below.
    kernel_keys = {k for k, v in backend_indices.items() if v.has_kernel(f)}
    composite_only_key_sets = (
        {DispatchKey.CompositeImplicitAutograd},
        {
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.CompositeImplicitAutogradNestedTensor,
        },
    )

    comment_data: dict[str, str] = {
        "schema": f"aten::{f.func}",
        # TODO: What exactly is the semantics of the 'dispatch' field?
        # As written: "True" unless the kernels are exactly one of the
        # composite-implicit-only key sets.
        "dispatch": str(kernel_keys not in composite_only_key_sets),
        "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)),
    }
    return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
"""
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # RUN IT ALL
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def get_custom_build_selector(
    provided_op_registration_allowlist: list[str] | None,
    op_selection_yaml_path: str | None,
) -> SelectiveBuilder:
    """Create the ``SelectiveBuilder`` for a custom build.

    At most one of the two inputs may be given. An allowlist takes the legacy
    allow-list path, a YAML path is loaded from disk, and with neither a no-op
    selector is returned.
    """
    both_given = (
        provided_op_registration_allowlist is not None
        and op_selection_yaml_path is not None
    )
    assert not both_given, (
        "Both provided_op_registration_allowlist and "
        "op_selection_yaml_path can NOT be provided at the "
        "same time."
    )

    if provided_op_registration_allowlist is not None:
        return SelectiveBuilder.from_legacy_op_registration_allow_list(
            set(provided_op_registration_allowlist),
            True,
            False,
        )
    if op_selection_yaml_path is not None:
        return SelectiveBuilder.from_yaml_path(op_selection_yaml_path)
    return SelectiveBuilder.get_nop_selector()
def get_grouped_by_view_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsViewGroup]:
    """Group functions that share a view signature into ``NativeFunctionsViewGroup`` objects.

    Functions are bucketed by ``view_signature()``. A bucket that contains an
    aliasing view op becomes one view group (together with its optional
    inplace and functional "copy" companions); everything else is passed
    through ungrouped.
    """

    def maybe_create_view_group(
        d: dict[ViewSchemaKind | SchemaKind, NativeFunction],
    ) -> list[NativeFunction | NativeFunctionsViewGroup]:
        grouped: list[NativeFunction | NativeFunctionsViewGroup] = []
        if ViewSchemaKind.aliasing in d:
            # Members of the group are removed from `d` as they are claimed.
            view = d.pop(ViewSchemaKind.aliasing)
            view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
            view_copy = d.pop(SchemaKind.functional, None)
            group = NativeFunctionsViewGroup(
                view=view,
                view_copy=view_copy,
                view_inplace=view_inplace,
            )
            grouped.append(group)
        # Take the remaining functions that weren't part of the view group
        # and emit them separately
        grouped.extend(d.values())
        return grouped

    grouped_by_views: dict[
        FunctionSchema, dict[SchemaKind | ViewSchemaKind, NativeFunction]
    ] = defaultdict(dict)
    for f in native_functions:
        bucket = grouped_by_views[f.func.view_signature()]
        view_kind: ViewSchemaKind = f.view_schema_kind
        # We need to group up ops relevant to the same "view", consisting of:
        # view op (ViewSchemaKind.aliasing)
        # view_inplace op (ViewSchemaKind.aliasing_inplace)
        # view_copy op (SchemaKind.functional)
        if view_kind != ViewSchemaKind.non_aliasing:
            assert view_kind not in bucket, (
                f"{view_kind} already in {bucket.keys()}"
            )
            bucket[view_kind] = f
        else:
            # Non-view ops are keyed by their ordinary schema kind instead.
            kind = f.func.kind()
            assert kind not in bucket
            bucket[kind] = f

    return list(concatMap(maybe_create_view_group, grouped_by_views.values()))
def get_grouped_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsGroup]:
    """Collapse pre-grouped native functions into ``NativeFunctionsGroup`` objects where possible.

    Each pre-group that ``NativeFunctionsGroup.from_dict`` can turn into a
    group is emitted as that single group; otherwise its member functions are
    emitted individually.
    """

    def flatten_pre_group(
        d: dict[SchemaKind, NativeFunction],
    ) -> Sequence[NativeFunction | NativeFunctionsGroup]:
        group = NativeFunctionsGroup.from_dict(d)
        if group is not None:
            return [group]
        # Invariant: any NativeFunctions that are code-generated
        # should have been grouped into NativeFunctionsGroup objects
        assert not any("generated" in f.tags for f in d.values())
        return list(d.values())

    pre_grouped = pre_group_native_functions(native_functions)
    # TODO: how come ValuesView isn't a Sequence lol
    pre_groups = list(pre_grouped.values())
    return list(concatMap(flatten_pre_group, pre_groups))
def get_ns_grouped_kernels(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: dict[DispatchKey, BackendIndex],
    native_function_decl_gen: Callable[
        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
    ] = dest.compute_native_function_declaration,
) -> dict[str, list[str]]:
    """Collect generated kernel declarations, keyed by C++ namespace.

    For every function and every backend index, the declarations produced by
    ``native_function_decl_gen`` are appended under the namespace of that
    backend's kernel, or under ``DEFAULT_KERNEL_NAMESPACE`` when the backend
    has no kernel for the function. A function's kernels must all live in the
    same namespace.
    """
    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
    for f in grouped_native_functions:
        seen_namespaces = set()
        owning_keys = set()
        for dispatch_key, backend_idx in backend_indices.items():
            metadata = backend_idx.get_kernel(f)
            if not metadata:
                namespace = DEFAULT_KERNEL_NAMESPACE
            else:
                namespace = metadata.cpp_namespace
                owning_keys.add(dispatch_key)
                seen_namespaces.add(namespace)
            assert len(seen_namespaces) <= 1, (
                f"Codegen only supports one namespace per operator, got {seen_namespaces} from {owning_keys}"
            )
            declarations = native_function_decl_gen(f, backend_idx)
            ns_grouped_kernels[namespace].extend(declarations)
    return ns_grouped_kernels
def get_native_function_declarations_from_ns_grouped_kernels(
    *,
    ns_grouped_kernels: dict[str, list[str]],
) -> list[str]:
    """Render namespace-grouped kernel declarations into a flat list of source lines.

    For each namespace, the de-duplicated declarations are wrapped between the
    ``NamespaceHelper`` prologue and epilogue; the resulting text is split on
    newlines and appended to the output.
    """
    declarations: list[str] = []
    # Bound to a name so it can be used inside the f-string expression below.
    newline = "\n"
    for namespace, kernels in ns_grouped_kernels.items():
        ns_helper = NamespaceHelper(
            namespace_str=namespace,
            entity_name="",
            max_level=4,
        )
        # De-duplicate kernel names while keeping first-seen order. Backends are
        # allowed to repeat kernel names; only generate the declaration once!
        ordered_kernels = list(OrderedDict.fromkeys(kernels))
        declarations.extend(
            f"""
{ns_helper.prologue}
{newline.join(ordered_kernels)}
{ns_helper.epilogue}
        """.split(newline)
        )
    return declarations
- # Return native function declarations grouped by their namespaces.
def get_native_function_declarations(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: dict[DispatchKey, BackendIndex],
    native_function_decl_gen: Callable[
        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
    ] = dest.compute_native_function_declaration,
) -> list[str]:
    """
    Generate kernel declarations, in `NativeFunction(s).h`.
    :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`.
    :param backend_indices: kernel collections grouped by dispatch key.
    :param native_function_decl_gen: callable to generate kernel declaration for each `NativeFunction`.
    :return: the source lines of all declarations, wrapped per namespace and split on newlines.
    """
    # First bucket the declarations by namespace, then render each bucket.
    return get_native_function_declarations_from_ns_grouped_kernels(
        ns_grouped_kernels=get_ns_grouped_kernels(
            grouped_native_functions=grouped_native_functions,
            backend_indices=backend_indices,
            native_function_decl_gen=native_function_decl_gen,
        )
    )
def get_kernel_namespace(
    *, f: NativeFunction | NativeFunctionsGroup, backend_idx: BackendIndex
) -> str:
    """Return the C++ namespace of the kernel that ``backend_idx`` has for ``f``.

    Falls back to ``DEFAULT_KERNEL_NAMESPACE`` when the backend has no kernel
    for ``f``. A kernel namespace is required to contain ``"::native"``.
    """
    metadata = backend_idx.get_kernel(f)
    if not metadata:
        return DEFAULT_KERNEL_NAMESPACE

    kernel_namespace = metadata.cpp_namespace
    assert "::native" in kernel_namespace, (
        f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} "
        f"with dispatch key {backend_idx.dispatch_key}"
        f" has a namespace {metadata.cpp_namespace} and it's not ending with '::native'."
    )
    return kernel_namespace
# Return native function definitions grouped by dispatch key and custom namespace.
# Used in RegisterDispatchKey.cpp and etc.
def get_native_function_definitions(
    *,
    fm: FileManager,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
    skip_dispatcher_op_registration: bool,
    gen_dispatch_helpers: bool,
) -> list[str]:
    """
    Generate kernel definitions and dispatcher registrations for one dispatch key.

    Each function is run through three `dest.RegisterDispatchKey` generators
    (namespaced definition, anonymous definition, registration). The results
    are bucketed by kernel namespace (the kernel's C++ namespace with
    `::native` removed) and each non-empty bucket is rendered through the
    `RegisterDispatchDefinitions.ini` template.

    :param fm: file manager used to render the template.
    :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionsGroup`.
    :param dispatch_key: dispatch key the `TORCH_LIBRARY_IMPL` blocks are emitted for.
    :param backend_idx: kernel collection for `dispatch_key`.
    :param selector: selective-build filter forwarded to the generators.
    :param rocm: forwarded to the generators.
    :param symint: forwarded to the generators.
    :param skip_dispatcher_op_registration: forwarded to the generators; when true the
        rendered registration section is left empty.
    :param gen_dispatch_helpers: accepted for interface compatibility; not read here.
    :return: the rendered template output for every kernel namespace, split by newline.
    """
    definitions: list[str] = []
    ns_definitions: dict[str, list[str]] = defaultdict(list)
    anonymous_definitions: dict[str, list[str]] = defaultdict(list)
    # kernel namespace -> operator namespace -> registration lines.
    # The nested defaultdict lets several operator namespaces accumulate under
    # one kernel namespace. Re-creating the inner mapping whenever a new
    # operator namespace shows up would discard what was already collected.
    registrations: dict[str, dict[str, list[str]]] = defaultdict(
        lambda: defaultdict(list)
    )
    newline = "\n"

    def make_generator(target: Target) -> dest.RegisterDispatchKey:
        # The three generators differ only in their target.
        return dest.RegisterDispatchKey(
            backend_idx,
            target,
            selector,
            rocm=rocm,
            symint=symint,
            class_method_name=None,
            skip_dispatcher_op_registration=skip_dispatcher_op_registration,
        )

    ns_gen = make_generator(Target.NAMESPACED_DEFINITION)
    anonymous_gen = make_generator(Target.ANONYMOUS_DEFINITION)
    reg_gen = make_generator(Target.REGISTRATION)

    for f in grouped_native_functions:
        kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
            "::native", ""
        )

        ns_definitions[kernel_namespace].extend(ns_gen(f))
        anonymous_definitions[kernel_namespace].extend(anonymous_gen(f))

        namespace = (
            f.namespace if isinstance(f, NativeFunction) else f.functional.namespace
        )
        registrations[kernel_namespace][namespace].extend(reg_gen(f))

    for kernel_namespace, namespaced_definitions in ns_definitions.items():
        if not namespaced_definitions:
            continue

        ns_helper = NamespaceHelper(namespace_str=kernel_namespace)
        # One TORCH_LIBRARY_IMPL block per operator namespace that has
        # at least one registration line.
        registration_body = "".join(
            f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{newline.join(registration_lines)}
}}"""
            for namespace, registration_lines in registrations[
                kernel_namespace
            ].items()
            if registration_lines
        )

        definitions.extend(
            fm.substitute_with_template(
                "RegisterDispatchDefinitions.ini",
                lambda: {
                    "ns_prologue": ns_helper.prologue,
                    "ns_epilogue": ns_helper.epilogue,
                    "dispatch_anonymous_definitions": anonymous_definitions[
                        kernel_namespace
                    ],
                    "static_init_dispatch_registrations": ""
                    if skip_dispatcher_op_registration
                    else registration_body,
                    "deferred_dispatch_registrations": "",
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
                },
            ).split(newline)
        )

    return definitions
# Return native function declarations grouped by dispatch key and custom namespace.
# Used in CPUFunctions_inl.h and etc.
def get_namespaced_declaration(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
) -> list[str]:
    """
    Generate namespaced kernel declarations for one dispatch key.

    Declarations come from a `dest.RegisterDispatchKey` generator with target
    `NAMESPACED_DECLARATION`. They are bucketed by the kernel's C++ namespace
    with its `::native` component swapped for the lower-cased dispatch key,
    de-duplicated in first-seen order, and wrapped in the namespace
    prologue/epilogue from `NamespaceHelper`.

    :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionsGroup`.
    :param dispatch_key: dispatch key whose lower-cased form names the target namespace.
    :param backend_idx: kernel collection for `dispatch_key`.
    :param selector: selective-build filter forwarded to the generator.
    :param rocm: forwarded to the generator.
    :param symint: forwarded to the generator.
    :return: the declarations of every non-empty namespace, split by newline.
    """
    declarations: list[str] = []
    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
    newline = "\n"
    func = dest.RegisterDispatchKey(
        backend_idx,
        Target.NAMESPACED_DECLARATION,
        selector,
        rocm=rocm,
        class_method_name=None,
        skip_dispatcher_op_registration=False,
        symint=symint,
    )
    # Loop-invariant: computed once instead of once per function.
    dispatch_namespace = dispatch_key.lower()
    for f in grouped_native_functions:
        # Anchor on "::native" (as get_native_function_definitions does) so an
        # unrelated "native" substring elsewhere in a custom namespace is left
        # untouched.
        namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
            "::native", f"::{dispatch_namespace}"
        )

        ns_grouped_kernels[namespace].extend(func(f))

    for namespace, kernels in ns_grouped_kernels.items():
        if not kernels:
            continue
        ns_helper = NamespaceHelper(
            namespace_str=namespace, entity_name="", max_level=3
        )

        # Drop duplicate declarations while keeping first-seen order.
        ordered_kernels = list(OrderedDict.fromkeys(kernels))
        declarations.extend(
            f"""
{ns_helper.prologue}
{newline.join(ordered_kernels)}
{ns_helper.epilogue}
""".split(newline)
        )
    return declarations
# Return native function schema registration code for aten and other namespaces.
def get_native_function_schema_registrations(
    *,
    native_functions: Sequence[NativeFunction],
    schema_selector: SelectiveBuilder,
) -> tuple[list[str], str]:
    """
    Build the schema registration code for `native_functions`.

    :param native_functions: functions whose schemas should be registered.
    :param schema_selector: selector handed to `RegisterSchema`.
    :return: a pair `(aten_lines, other_code)`. `aten_lines` holds the
        registration lines of the `aten` namespace; `other_code` is the
        concatenation of one `TORCH_LIBRARY` / `TORCH_LIBRARY_FRAGMENT` block
        per remaining namespace.
    """
    # Bucket the functions by operator namespace, keeping first-seen order.
    funcs_by_namespace: dict[str, list[NativeFunction]] = {}
    for fn in native_functions:
        funcs_by_namespace.setdefault(fn.namespace, []).append(fn)

    aten_schema_registrations: list[str] = []
    library_blocks: list[str] = []
    for namespace, funcs in funcs_by_namespace.items():
        body_lines = list(mapMaybe(RegisterSchema(schema_selector), funcs))

        # NB: we have to separate aten namespace registration from other namespaces,
        # because in the template we hardcoded an operator for ATen already.
        if namespace == "aten":
            aten_schema_registrations = body_lines
            continue

        # if the namespace is predefined, we should define a library fragment
        # instead of a new library
        if namespace in FRAGMENT_NAMESPACES:
            torch_library_macro = "TORCH_LIBRARY_FRAGMENT"
        else:
            torch_library_macro = "TORCH_LIBRARY"

        joined_body = "\t".join(body_lines)
        library_blocks.append(
            f"""
{torch_library_macro}({namespace}, m) {{
{joined_body}
}};"""
        )

    return (aten_schema_registrations, "".join(library_blocks))
def gen_aggregated_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    device_fms: dict[str, FileManager],
    functions_keys: set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    """
    Write the aggregated (one file per category) operator headers.

    Through `cpu_fm` this writes `NativeMetaFunctions.h`, `MethodOperators.h`,
    `Operators.h`, `Functions.h` and `NativeFunctions.h`. For every dispatch
    key in `functions_keys` it additionally writes `{dispatch_key}Functions.h`
    and `{dispatch_key}Functions_inl.h` through the file manager chosen by
    `file_manager_from_dispatch_key`.

    :param native_functions: all native functions.
    :param grouped_native_functions: the same functions, grouped where applicable.
    :param structured_native_functions: groups used for the meta-function declarations.
    :param static_dispatch_idx: backend indices forwarded as static dispatch backends.
    :param selector: selective-build filter for the per-dispatch-key declarations.
    :param backend_indices: kernel collections grouped by dispatch key.
    :param cpu_fm: file manager for the aggregated headers.
    :param device_fms: file managers passed to `file_manager_from_dispatch_key`.
    :param functions_keys: dispatch keys that get `*Functions*.h` headers.
    :param dispatch_keys: dispatch keys to iterate over.
    :param rocm: forwarded to `get_namespaced_declaration`.
    """
    # Buck doesn't support dynamic output files, so we aggregate all operator
    # headers into a single file
    cpu_fm.write(
        "NativeMetaFunctions.h",
        lambda: {
            "NativeMetaFunctions_includes": [],
            "NativeMetaFunctions_declarations": list(
                mapMaybe(compute_meta_function_declaration, structured_native_functions)
            ),
        },
    )
    # Partition on the variant predicate directly. Testing list membership
    # (`fn not in method_native_functions`) costs a linear scan per function.
    method_native_functions = [
        fn for fn in native_functions if Variant.method in fn.variants
    ]
    non_method_native_functions = [
        fn for fn in native_functions if Variant.method not in fn.variants
    ]
    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": [],
            "MethodOperators_declarations": list(
                mapMaybe(
                    ComputeOperators(
                        Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    method_native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "Operators.h",
        lambda: {
            "Operators_includes": ["#include <ATen/MethodOperators.h>"],
            "Operators_declarations": list(
                mapMaybe(
                    ComputeOperators(
                        Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    non_method_native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "Functions.h",
        lambda: {
            "static_dispatch_extra_headers": static_dispatch_extra_headers(
                static_dispatch_idx
            ),
            "Functions_includes": ["#include <ATen/Operators.h>"],
            "Functions_declarations": list(
                mapMaybe(
                    ComputeFunction(),
                    native_functions,
                )
            ),
        },
    )

    declarations = get_native_function_declarations(
        grouped_native_functions=grouped_native_functions,
        backend_indices=backend_indices,
    )
    cpu_fm.write(
        "NativeFunctions.h",
        lambda: {
            "NativeFunctions_includes": ["#include <ATen/NativeMetaFunctions.h>"],
            "NativeFunctions_declarations": declarations,
        },
    )

    for dispatch_key in dispatch_keys:
        fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
        if dispatch_key in functions_keys:
            inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"

            fm.write_with_template(
                f"{dispatch_key}Functions.h",
                "DispatchKeyFunctions.h",
                lambda: {
                    "dispatch_key": str(dispatch_key),
                    "inline_headers": inl_headers,
                },
            )
            fm.write_with_template(
                f"{dispatch_key}Functions_inl.h",
                "DispatchKeyFunctions_inl.h",
                lambda: {
                    "DispatchKeyFunctions_inl_includes": [],
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_declarations": get_namespaced_declaration(
                        grouped_native_functions=grouped_native_functions,
                        dispatch_key=dispatch_key,
                        backend_idx=backend_indices[dispatch_key],
                        selector=selector,
                        rocm=rocm,
                        symint=True,
                    ),
                },
            )

        # Drop the per-iteration binding so it cannot be reused by accident.
        del fm
def gen_per_operator_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    device_fms: dict[str, FileManager],
    ops_fm: FileManager,
    functions_keys: set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    """
    Write one set of headers per operator root name, plus thin aggregate headers.

    For every root name, `ops_fm` receives `{name}_ops.h`, `{name}.h`,
    `{name}_native.h` and, when the name has a structured group,
    `{name}_meta.h`. `cpu_fm` receives `Functions.h`, `Operators.h`,
    `NativeMetaFunctions.h`, `NativeFunctions.h` and `MethodOperators.h`,
    which only include the per-operator headers. For every dispatch key in
    `functions_keys`, `{name}_{dispatch_namespace}_dispatch.h` is written for
    each root name with at least one declaration, followed by
    `{dispatch_key}Functions.h` and `{dispatch_key}Functions_inl.h`.

    :param native_functions: all native functions.
    :param grouped_native_functions: the same functions, grouped where applicable.
    :param static_dispatch_idx: backend indices forwarded as static dispatch backends.
    :param selector: selective-build filter for the per-dispatch-key declarations.
    :param backend_indices: kernel collections grouped by dispatch key.
    :param cpu_fm: file manager for the aggregate headers.
    :param device_fms: file managers passed to `file_manager_from_dispatch_key`.
    :param ops_fm: file manager for the per-operator headers.
    :param functions_keys: dispatch keys that get `*Functions*.h` headers.
    :param dispatch_keys: dispatch keys to iterate over.
    :param rocm: forwarded to `dest.RegisterDispatchKey`.
    """
    # For CMake builds, split operator declarations into separate headers in
    # the ATen/ops folder to split up header dependencies
    functions_by_root_name: dict[str, list[NativeFunction]] = defaultdict(list)
    for fn in native_functions:
        functions_by_root_name[fn.root_name].append(fn)

    grouped_functions_by_root_name: dict[
        str, list[NativeFunction | NativeFunctionsGroup]
    ] = defaultdict(list)
    for group in grouped_native_functions:
        grouped_functions_by_root_name[group.root_name].append(group)

    # The mapping is not modified past this point, so the sorted root names
    # are computed once and shared by all aggregate headers.
    sorted_root_names = sorted(functions_by_root_name)

    for name, functions in functions_by_root_name.items():
        ops_fm.write_with_template(
            f"{name}_ops.h",
            "Operator.h",
            lambda: {
                "declarations": list(
                    mapMaybe(
                        ComputeOperators(
                            Target.DECLARATION,
                            static_dispatch_backend_indices=static_dispatch_idx,
                        ),
                        functions,
                    )
                ),
            },
        )

        ops_fm.write_with_template(
            f"{name}.h",
            "Function.h",
            lambda: {
                "static_dispatch_ops_headers": list(
                    mapMaybe(
                        lambda fn: static_dispatch_ops_header(
                            fn, backend_index=static_dispatch_idx
                        ),
                        functions,
                    )
                ),
                "operator_includes": f"#include <ATen/ops/{name}_ops.h>",
                "function_definitions": list(
                    mapMaybe(
                        ComputeFunction(),
                        functions,
                    )
                ),
            },
        )

        # `.get` (not indexing) so a missing name does not insert an empty
        # entry into the defaultdict.
        grouped_functions = grouped_functions_by_root_name.get(name, [])
        structured_functions = [
            fn
            for fn in grouped_functions
            if isinstance(fn, NativeFunctionsGroup) and fn.structured
        ]
        is_structured = bool(structured_functions)

        if is_structured:
            ops_fm.write_with_template(
                f"{name}_meta.h",
                "NativeMetaFunction.h",
                lambda: {
                    "meta_function_declarations": list(
                        mapMaybe(
                            compute_meta_function_declaration, structured_functions
                        )
                    ),
                },
            )
        declarations = get_native_function_declarations(
            grouped_native_functions=grouped_functions,
            backend_indices=backend_indices,
            native_function_decl_gen=dest.compute_native_function_declaration,
        )
        ops_fm.write_with_template(
            f"{name}_native.h",
            "NativeFunction.h",
            lambda: {
                "extra_includes": (
                    f"#include <ATen/ops/{name}_meta.h>" if is_structured else []
                ),
                "native_function_declarations": declarations,
            },
        )

    # Aggregate headers: nothing but includes of the per-operator headers.
    for category, suffix in [
        ("Functions", ""),
        ("Operators", "_ops"),
        ("NativeMetaFunctions", "_meta"),
        ("NativeFunctions", "_native"),
    ]:
        cpu_fm.write(
            f"{category}.h",
            lambda: {
                f"{category}_includes": [
                    f"#include <ATen/ops/{name}{suffix}.h>"
                    for name in sorted_root_names
                ],
                f"{category}_declarations": [],
            },
        )

    for dispatch_key in dispatch_keys:
        if dispatch_key not in functions_keys:
            continue

        dispatch_namespace = dispatch_key.lower()
        # Built once per dispatch key; none of its arguments depend on the
        # root name being processed.
        declaration_gen = dest.RegisterDispatchKey(
            backend_indices[dispatch_key],
            Target.NAMESPACED_DECLARATION,
            selector,
            rocm=rocm,
            symint=True,
            class_method_name=None,
            skip_dispatcher_op_registration=False,
        )
        dispatch_names = []

        for name in functions_by_root_name:
            grouped_functions = grouped_functions_by_root_name.get(name, [])
            declarations = list(concatMap(declaration_gen, grouped_functions))

            # Root names with no declaration for this key get no header.
            if not declarations:
                continue

            dispatch_names.append(name)
            ops_fm.write_with_template(
                f"{name}_{dispatch_namespace}_dispatch.h",
                "DispatchKeyFunction.h",
                lambda: {
                    "dispatch_namespace": dispatch_namespace,
                    "dispatch_namespaced_declarations": declarations,
                },
            )

        fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
        inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"

        fm.write_with_template(
            f"{dispatch_key}Functions.h",
            "DispatchKeyFunctions.h",
            lambda: {
                "dispatch_key": str(dispatch_key),
                "inline_headers": inl_headers,
            },
        )
        fm.write_with_template(
            f"{dispatch_key}Functions_inl.h",
            "DispatchKeyFunctions_inl.h",
            lambda: {
                "dispatch_namespace": dispatch_namespace,
                "DispatchKeyFunctions_inl_includes": [
                    f"#include <ATen/ops/{name}_{dispatch_namespace}_dispatch.h>"
                    for name in sorted(dispatch_names)
                ],
                "dispatch_namespaced_declarations": [],
            },
        )
        del fm

    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": sorted(
                f"#include <ATen/ops/{name}_ops.h>"
                for name, functions in functions_by_root_name.items()
                if any(Variant.method in fn.variants for fn in functions)
            ),
            "MethodOperators_declarations": [],
        },
    )
def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    valid_tags: set[str],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    core_fm: FileManager,
    cpu_fm: FileManager,
    device_fms: dict[str, FileManager],
    ops_fm: FileManager,
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: set[DispatchKey],
    rocm: bool,
    per_operator_headers: bool,
) -> None:
    """
    Write every generated header.

    Operator headers are produced by `gen_per_operator_headers` or
    `gen_aggregated_headers`, depending on `per_operator_headers`. Afterwards
    the following files are written in this order: `TensorBody.h`,
    `RedispatchFunctions.h`, `RegistrationDeclarations.h`,
    `VmapGeneratedPlumbing.h`, `aten_interned_strings.h` and `enum_tag.h`.

    :param native_functions: all native functions.
    :param valid_tags: tag names emitted (sorted) into `enum_tag.h`.
    :param per_operator_headers: choose per-operator over aggregated headers.

    The remaining parameters are forwarded unchanged to the operator-header
    generator that is selected.
    """
    if not per_operator_headers:
        gen_aggregated_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            device_fms=device_fms,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )
    else:
        gen_per_operator_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            device_fms=device_fms,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )

    def tensor_methods(target: Target) -> list[str]:
        # Tensor method code for every native function, for the given target.
        generator = ComputeTensorMethod(
            target=target,
            static_dispatch_backend_indices=static_dispatch_idx,
        )
        return list(mapMaybe(generator, native_functions))

    core_fm.write(
        "TensorBody.h",
        lambda: {
            "tensor_method_declarations": tensor_methods(Target.DECLARATION),
            "tensor_method_definitions": tensor_methods(Target.DEFINITION),
        },
    )

    cpu_fm.write(
        "RedispatchFunctions.h",
        lambda: {
            "function_redispatch_definitions": list(
                mapMaybe(ComputeRedispatchFunction(), native_functions)
            ),
        },
    )

    cpu_fm.write(
        "RegistrationDeclarations.h",
        lambda: {
            "registration_declarations": [
                compute_registration_declarations(fn, backend_indices)
                for fn in native_functions
            ],
        },
    )

    cpu_fm.write(
        "VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions)
    )

    def gen_aten_interned_strings() -> dict[str, str]:
        # These are keywords in C++, so aren't valid symbol names
        # https://en.cppreference.com/w/cpp/language/operator_alternative
        cpp_alternative_tokens = {
            "and",
            "and_eq",
            "bitand",
            "bitor",
            "compl",
            "not",
            "not_eq",
            "or",
            "or_eq",
            "xor",
            "xor_eq",
        }

        arg_names: set[str] = set()  # All function argument names
        op_names = set()  # All ATen function names
        for native_fn in native_functions:
            op_name = native_fn.func.name.name
            op_names.add(str(op_name))
            # Some operators don't have a functional variant but we still create a
            # symbol without the underscore
            op_names.add(op_name.base)
            arg_names.update(
                arg.name for arg in native_fn.func.schema_order_arguments()
            )

        aten_entries = [
            f"_(aten, {name})" for name in sorted(op_names - cpp_alternative_tokens)
        ]
        attr_entries = [f"_(attr, {name})" for name in sorted(arg_names)]
        return {
            "aten_symbols": " \\\n".join(aten_entries),
            "attr_symbols": " \\\n".join(attr_entries),
        }

    core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)

    def gen_tags_enum() -> dict[str, str]:
        enum_body = ",\n".join(sorted(valid_tags))
        return {"enum_of_valid_tags": enum_body}

    core_fm.write("enum_tag.h", gen_tags_enum)
def gen_source_files(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    view_groups: Sequence[NativeFunctionsViewGroup],
    selector: SelectiveBuilder,
    static_dispatch_idx: list[BackendIndex],
    backend_indices: dict[DispatchKey, BackendIndex],
    aoti_fm: FileManager,
    core_fm: FileManager,
    cpu_vec_fm: FileManager,
    cpu_fm: FileManager,
    device_fms: dict[str, FileManager],
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: set[DispatchKey],
    rocm: bool,
    force_schema_registration: bool,
    per_operator_headers: bool,
    skip_dispatcher_op_registration: bool,
    update_aoti_c_shim: bool,
    aoti_backends: set[DispatchKey | None],
    extend_aoti_c_shim: bool,
) -> None:
    """
    Write the generated source files.

    For every dispatch key this writes the sharded `Register{dispatch_key}.cpp`
    and, for structured groups with a ufunc inner loop on a ufunc dispatch key,
    the `UfuncCPU_*.cpp` / `UfuncCPUKernel_*.cpp` / `UfuncCUDA_*.cu` files.
    It then calls `gen_aoti_c_shim_files` and writes, in this order:
    `RegisterBackendSelect.cpp`, `RegisterSchema.cpp`, `Operators.cpp`
    (sharded), `Functions.cpp`, `TensorMethods.cpp`, `ATenOpList.cpp`,
    `RegisterFunctionalization.cpp` (sharded), `FunctionalInverses.h`,
    `ViewMetaClasses.h`, `ViewMetaClasses.cpp` and
    `CompositeViewCopyKernels.cpp`.

    :param per_operator_headers: selects which headers the generated
        `Register*.cpp` files include (per-operator or aggregated).
    :param force_schema_registration: when true, schemas are registered with
        `SelectiveBuilder.get_nop_selector()` instead of `selector`.
    :param skip_dispatcher_op_registration: when true, `RegisterSchema.cpp`
        receives empty registration lists; the flag is also forwarded to
        `get_native_function_definitions`.
    """
    # Extra headers for CUDA dispatch keys; swapped for the HIP equivalents
    # when building for ROCm.
    extra_cuda_headers = """\
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <ATen/cuda/CUDADevice.h>
#include <ATen/cuda/CUDAContext.h>"""
    if rocm:
        extra_cuda_headers = """\
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/ATenHIPGeneral.h>
#include <ATen/hip/HIPDevice.h>
#include <ATen/hip/HIPContext.h>"""

    for dispatch_key in dispatch_keys:
        fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
        # NOTE: both `operator_headers` variants are closures. The per-operator
        # one reads `backend_index` and `dispatch_namespace`, which are only
        # assigned further down in this loop body, before the function is
        # first called while building `register_dispatch_key_base_env`.
        if per_operator_headers:

            def operator_headers() -> list[str]:
                headers = []
                for g in grouped_native_functions:
                    is_registered = False
                    if backend_index.has_kernel(g):
                        is_registered = True
                    # The above has_kernel test on a group will only test for
                    # the existence of out dispatch, because that's how
                    # structured kernels work. But sometimes functions can be
                    # grouped but not be structured, and then you need to check
                    # each individual piece, as they may have manual dispatch
                    # entries.
                    elif isinstance(g, NativeFunctionsGroup) and any(
                        backend_index.has_kernel(fn) for fn in g.functions()
                    ):
                        is_registered = True
                    # TODO: this condition is a bit questionable
                    # (It has to do with the fact that structured kernels get generated kernels
                    # to the Meta + CompositeExplicitAutogradNonFunctional keys).
                    elif g.structured and dispatch_key in (
                        DispatchKey.Meta,
                        DispatchKey.CompositeExplicitAutogradNonFunctional,
                    ):
                        is_registered = True
                    if not is_registered:
                        continue

                    headers.append(f"#include <ATen/ops/{g.root_name}_native.h>")
                    if (
                        dispatch_key
                        == DispatchKey.CompositeExplicitAutogradNonFunctional
                    ):
                        headers.append(f"#include <ATen/ops/{g.root_name}.h>")
                    if dispatch_key in functions_keys:
                        headers.append(
                            f"#include <ATen/ops/{g.root_name}_{dispatch_namespace}_dispatch.h>"
                        )

                # De-duplicate and return the includes in sorted order.
                return sorted(set(headers))

        else:

            def operator_headers() -> list[str]:
                headers = ["#include <ATen/NativeFunctions.h>"]
                if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
                    headers.append("#include <ATen/Functions.h>")
                if dispatch_key in functions_keys:
                    headers.append(f"#include <ATen/{dispatch_key!s}Functions.h>")
                return headers

        backend_index = backend_indices[dispatch_key]
        # Functions bucketed by operator namespace. NOTE(review): this mapping
        # is not read again anywhere in this function.
        ns_grouped_native_functions = defaultdict(list)
        for grouped_native_function in grouped_native_functions:
            namespace = (
                grouped_native_function.namespace
                if isinstance(grouped_native_function, NativeFunction)
                else grouped_native_function.functional.namespace
            )
            ns_grouped_native_functions[namespace].append(grouped_native_function)

        dispatch_namespace = str(dispatch_key).lower()

        # CompositeImplicitAutogradNestedTensor does not currently use the generated
        # helpers, so compilation would fail when the `-Werror=unused-function` flag is set
        gen_dispatch_helpers: bool = (
            dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
        )

        # Template values passed as `base_env` to the sharded write below.
        register_dispatch_key_base_env = {
            "extra_cuda_headers": extra_cuda_headers
            if is_cuda_dispatch_key(dispatch_key)
            else "",
            "external_backend_headers": "",
            "dispatch_headers": dest.gen_registration_headers(
                backend_index, per_operator_headers, rocm
            ),
            # ops_headers *could* be sharded, but doesn't seem necessary?
            "ops_headers": operator_headers(),
            "dispatch_helpers": (
                dest.gen_registration_helpers(backend_index)
                if gen_dispatch_helpers
                else []
            ),
        }

        # Per-item environment: the definitions generated for a single
        # function or group.
        def register_dispatch_key_env_callable(
            gnf: NativeFunction | NativeFunctionsGroup,
        ) -> dict[str, list[str]]:
            return {
                "dispatch_definitions": get_native_function_definitions(
                    fm=fm,  # noqa: F821
                    grouped_native_functions=[gnf],
                    dispatch_key=dispatch_key,
                    backend_idx=backend_index,
                    selector=selector,
                    rocm=rocm,
                    symint=True,
                    skip_dispatcher_op_registration=skip_dispatcher_op_registration,
                    gen_dispatch_helpers=gen_dispatch_helpers,
                )
            }

        # The CPU key is split over 4 shards; every other key uses 1.
        fm.write_sharded_with_template(
            f"Register{dispatch_key}.cpp",
            "RegisterDispatchKey.cpp",
            grouped_native_functions,
            key_fn=lambda x: x.root_name,
            env_callable=register_dispatch_key_env_callable,
            num_shards=4 if dispatch_key == DispatchKey.CPU else 1,
            base_env=register_dispatch_key_base_env,
            sharded_keys={"dispatch_definitions"},
        )

        # Ufunc kernels: only for structured groups whose out variant has a
        # ufunc inner loop, and only on ufunc dispatch keys.
        for g in structured_native_functions:
            if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key):
                continue
            name = g.functional.func.name.name
            if dispatch_key is DispatchKey.CPU:
                assert fm is cpu_fm
                fm.write_with_template(
                    f"UfuncCPU_{name}.cpp",
                    "UfuncCPU.cpp",
                    lambda: {
                        "meta_declaration": compute_meta_function_declaration(g),
                        "native_declaration": dest.compute_native_function_declaration(
                            g, backend_indices[dispatch_key]
                        ),
                        "native_definitions": dest.compute_ufunc_cpu(g),
                    },
                )
                cpu_vec_fm.write_with_template(
                    f"UfuncCPUKernel_{name}.cpp",
                    "UfuncCPUKernel.cpp",
                    lambda: {
                        "name": name,
                        "native_definitions": dest.compute_ufunc_cpu_kernel(g),
                    },
                )
            elif dispatch_key is DispatchKey.CUDA:
                cuda_headers = "#include <ATen/native/cuda/Loops.cuh>"
                if rocm:
                    cuda_headers = "#include <ATen/native/hip/Loops.cuh>"
                fm.write_with_template(
                    f"UfuncCUDA_{name}.cu",
                    "UfuncCUDA.cu",
                    lambda: {
                        "name": name,
                        "cuda_headers": cuda_headers,
                        "meta_declaration": compute_meta_function_declaration(g),
                        "native_declaration": dest.compute_native_function_declaration(
                            g, backend_indices[dispatch_key]
                        ),
                        "native_definitions": dest.compute_ufunc_cuda(g),
                    },
                )
            else:
                raise AssertionError(f"unrecognized {dispatch_key} for ufunc")

        # Drop the per-iteration binding so it cannot be reused by accident.
        del fm

    gen_aoti_c_shim_files(
        aoti_fm=aoti_fm,
        aoti_backends=aoti_backends,
        native_functions=native_functions,
        backend_indices=backend_indices,
        structured_native_functions=structured_native_functions,
        extra_cuda_headers=extra_cuda_headers,
        update_aoti_c_shim=update_aoti_c_shim,
        extend_aoti_c_shim=extend_aoti_c_shim,
    )

    # BackendSelect is generated specially
    def gen_backend_select() -> dict[str, list[str]]:
        relevant_fns = [
            fn for fn in native_functions if needs_backend_select(fn, selector)
        ]
        return {
            "ops_headers": [
                f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in relevant_fns
            ],
            "backend_select_method_definitions": list(
                mapMaybe(
                    ComputeBackendSelect(Target.DEFINITION, selector), relevant_fns
                )
            ),
            "backend_select_function_registrations": list(
                mapMaybe(
                    ComputeBackendSelect(Target.REGISTRATION, selector), relevant_fns
                )
            ),
        }

    cpu_fm.write("RegisterBackendSelect.cpp", gen_backend_select)

    # Forced schema registration swaps in the no-op selector.
    schema_selector = selector
    if force_schema_registration:
        schema_selector = SelectiveBuilder.get_nop_selector()

    (
        aten_schema_registrations,
        schema_registrations,
    ) = get_native_function_schema_registrations(
        native_functions=native_functions, schema_selector=schema_selector
    )
    cpu_fm.write(
        "RegisterSchema.cpp",
        lambda: {
            "aten_schema_registrations": []
            if skip_dispatcher_op_registration
            else aten_schema_registrations,
            "schema_registrations": []
            if skip_dispatcher_op_registration
            else schema_registrations,
        },
    )

    # Sharding key shared by Operators.cpp and RegisterFunctionalization.cpp.
    def key_func(
        fn: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
    ) -> str:
        return fn.root_name

    cpu_fm.write_sharded(
        "Operators.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {
            "operator_headers": [f"#include <ATen/ops/{fn.root_name}.h>"],
            "definitions": [
                ComputeOperators(
                    Target.DEFINITION,
                    static_dispatch_backend_indices=static_dispatch_idx,
                )(fn)
            ],
        },
        base_env={
            "static_dispatch_extra_headers": static_dispatch_extra_headers(
                static_dispatch_idx
            ),
        },
        num_shards=5,
        sharded_keys={
            "operator_headers",
            "definitions",
            "static_dispatch_extra_headers",
        },
    )

    # `dict` is passed as the environment callable, i.e. an empty environment.
    cpu_fm.write("Functions.cpp", dict)

    core_fm.write("TensorMethods.cpp", dict)

    core_fm.write(
        "ATenOpList.cpp",
        lambda: {
            "aten_ops": list(mapMaybe(compute_aten_op, native_functions)),
        },
    )

    # Return the `_native.h` / `_ops.h` includes needed for one function,
    # function group or view group.
    def gen_op_headers(
        g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
    ) -> list[str]:
        if isinstance(g, NativeFunctionsViewGroup):
            # view ops always get a functionalization kernel
            headers = [
                f"#include <ATen/ops/{g.view.root_name}_native.h>",
                f"#include <ATen/ops/{g.view.root_name}_ops.h>",
            ]
            if g.view_copy is not None:
                headers += [
                    f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
                    f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
                ]
            return headers
        elif isinstance(g, NativeFunctionsGroup):
            headers = [
                f"#include <ATen/ops/{g.functional.root_name}_native.h>",
                f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
                f"#include <ATen/ops/{g.out.root_name}_native.h>",
                f"#include <ATen/ops/{g.out.root_name}_ops.h>",
            ]
            if g.inplace is not None:
                headers += [
                    f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
                    f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
                ]
            if g.mutable is not None:
                headers += [
                    f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
                    f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
                ]
            return headers
        else:
            return [
                f"#include <ATen/ops/{g.root_name}_native.h>",
                f"#include <ATen/ops/{g.root_name}_ops.h>",
            ]

    # Per-item environment for the sharded RegisterFunctionalization.cpp.
    def functionalization_env_callable(
        g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
    ) -> dict[str, list[str]]:
        return {
            "ops_headers": gen_op_headers(g),
            "func_definitions": gen_functionalization_definition(
                selector,
                g,
            ),
            "func_registrations": gen_functionalization_registration(
                selector,
                g,
                backend_indices[DispatchKey.CompositeImplicitAutograd],
            ),
        }

    all_groups: list[
        NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup
    ] = list(structured_native_functions) + list(
        view_groups  # type: ignore[assignment, arg-type, operator]
    )
    # Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly.
    # The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because:
    # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
    # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
    #     Although this could go away long-term if we add a dedicated dispatch key for decompositions.
    structured_map: dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
    }
    view_map: dict[OperatorName, NativeFunction] = {
        f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
    }
    # Add every function that belongs to neither a structured group nor a
    # view group.
    all_groups.extend(
        f
        for f in native_functions
        if f.func.name not in structured_map and f.func.name not in view_map
    )

    cpu_fm.write_sharded(
        "RegisterFunctionalization.cpp",
        all_groups,
        key_fn=key_func,
        env_callable=functionalization_env_callable,
        num_shards=4,
        sharded_keys={
            "ops_headers",
            "func_definitions",
            "func_registrations",
            "func_add_back_views_definitions",
            "func_add_back_views_registrations",
        },
    )

    cpu_fm.write(
        "FunctionalInverses.h",
        lambda: {
            "view_inverse_declarations": list(
                mapMaybe(
                    lambda g: gen_functionalization_view_inverse_declaration(
                        selector, g
                    ),
                    view_groups,
                )
            )
        },
    )

    cpu_fm.write(
        "ViewMetaClasses.h",
        lambda: {
            "view_meta_declarations": list(
                concatMap(
                    lambda g: gen_functionalization_view_meta_classes_decl(selector, g),
                    view_groups,
                )
            )
        },
    )

    cpu_fm.write(
        "ViewMetaClasses.cpp",
        lambda: {
            "view_meta_implementations": list(
                concatMap(
                    lambda g: gen_functionalization_view_meta_classes_impl(selector, g),
                    view_groups,
                )
            ),
            "op_headers": list(concatMap(gen_op_headers, view_groups)),
        },
    )

    # Note [view_copy NativeFunctions]
    # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
    # needs to have a corresponding non-aliasing {view}_copy variant.
    # Backends that use functionalization and don't know how to handle aliasing ops
    # are expected to implement kernels for these {view}_copy kernels instead.
    # The code for {view}_copy operators in core is pretty boilerplate-heavy however,
    # so we codegen the following:
    # (1) A CompositeExplicitAutogradNonFunctional kernel for every {view}_copy operator.
    #     These are never explicitly invoked by the functionalization pass,
    #     but they could theoretically be called from user code (I added these kernels for completeness,
    #     since the ops are part of the public API).
    # (2) A derivative formula for every {view}_copy operator
    #     {view}_copy operators can reuse the same derivative formulas as their {view} op counterparts,
    #     so rather than stamping all of the entries out in derivatives.yaml,
    #     we codegen them in.
    #     This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry.
    cpu_fm.write(
        "CompositeViewCopyKernels.cpp",
        lambda: {
            # One newline-joined include string per view group, followed by
            # one per structured group.
            "ops_headers": [
                "\n".join(
                    f"#include <ATen/ops/{f.root_name}_ops.h>\n"
                    # NB: this include is important as it ensures we
                    # set the visibility on generated view_copy kernels
                    # correctly
                    f"#include <ATen/ops/{f.root_name}_native.h>"
                    for f in (
                        [g.view] if g.view_copy is None else [g.view, g.view_copy]
                    )
                )
                for g in view_groups
            ]
            + [
                "\n".join(
                    f"#include <ATen/ops/{f.root_name}_ops.h>\n"
                    # NB: this include is also important for correct visibility
                    f"#include <ATen/ops/{f.root_name}_native.h>"
                    for f in [g.inplace, g.mutable, g.functional]
                    if f is not None and "generated" not in f.tags
                )
                for g in structured_native_functions
            ],
            "CompositeViewCopyKernel_Definitions": list(
                mapMaybe(
                    GenCompositeViewCopyKernel(
                        backend_indices[
                            DispatchKey.CompositeExplicitAutogradNonFunctional
                        ]
                    ),
                    view_groups,
                )
            ),
            "GeneratedCompositeFunctional_Definitions": list(
                mapMaybe(
                    gen_composite_functional_kernel,
                    structured_native_functions,
                )
            ),
            "GeneratedCompositeOut_Definitions": list(
                mapMaybe(
                    gen_composite_out_kernel,
                    structured_native_functions,
                )
            ),
        },
    )
def gen_declarations_yaml(
    cpu_fm: FileManager, native_functions: Sequence[NativeFunction]
) -> None:
    """
    Write `Declarations.yaml` through `cpu_fm`.

    The file content is `format_yaml` applied to the list of
    `compute_declaration_yaml(f)` results, one per native function, in order.
    """
    output_name = "Declarations.yaml"
    cpu_fm.write(
        output_name,
        lambda: format_yaml(list(map(compute_declaration_yaml, native_functions))),
    )
def get_torchgen_root() -> Path:
    """
    Return the resolved directory that contains this module.

    If you're depending on torchgen out-of-tree, you can use the root to figure
    out the path to native_functions.yaml
    """
    module_dir = Path(__file__).parent
    return module_dir.resolve()
def main() -> None:
    """Command-line entry point: parse options and generate ATen files.

    Reads ``native/native_functions.yaml`` and ``native/tags.yaml`` below
    ``--source-path`` and, depending on ``--generate``, calls
    ``gen_source_files``, ``gen_headers`` and ``gen_declarations_yaml``.
    When ``--output-dependencies`` is given, ``write_outputs`` is also called
    on each file manager with a prefixed variant of that path.

    Note: dispatch keys of backends that are not enabled (MPS, XPU, MTIA) are
    removed from ``torchgen.model.dispatch_keys`` *in place*, so the removal
    is visible to every other holder of that list object.
    """
    parser = argparse.ArgumentParser(description="Generate ATen source files")
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for ATen",
        default="aten/src/ATen",
    )
    parser.add_argument(
        "-o",
        "--output-dependencies",
        help="output a list of dependencies into the given file and exit",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="run without writing any files (still updates outputs)",
    )
    parser.add_argument(
        "--per-operator-headers",
        action="store_true",
        help="generate separate headers per operator in ATen/ops",
    )
    parser.add_argument(
        "-d",
        "--install-dir",
        "--install_dir",
        help="output directory",
        default="build/aten/src/ATen",
    )
    parser.add_argument(
        "--aoti-install-dir",
        "--aoti_install_dir",
        help="output directory for AOTInductor shim",
        default="torch/csrc/inductor/aoti_torch/generated",
    )
    parser.add_argument(
        "--rocm",
        action="store_true",
        help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
    )
    parser.add_argument(
        "--mps",
        action="store_true",
        help="Generate MPS registration code when set",
    )
    parser.add_argument(
        "--xpu",
        action="store_true",
        help="Generate XPU registration code when set",
    )
    parser.add_argument(
        "--mtia",
        action="store_true",
        help="Generate MTIA registration code when set",
    )

    # TODO: --op-registration-whitelist will be removed when all call-sites
    # for gen.py are moved over to using the operator YAML file for mobile
    # custom build.
    parser.add_argument(
        "--op-registration-whitelist",
        "--op_registration_whitelist",
        nargs="*",
        help="filter op registrations by the whitelist (if set); "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )
    parser.add_argument(
        "--op-selection-yaml-path",
        "--op_selection_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "that contains the information about the set of selected operators "
        "and their categories (training, ...). Each operator is either a "
        "full operator name with overload or just a bare operator name. "
        "The operator names also contain the namespace prefix (e.g. aten::)",
    )
    parser.add_argument(
        "--backend-whitelist",
        "--backend_whitelist",
        nargs="*",
        help="filter dispatch backend by the whitelist (if set), "
        "e.g.: CPU CUDA QuantizedCPU ...",
    )
    parser.add_argument(
        "--static-dispatch-backend",
        "--static_dispatch_backend",
        nargs="*",
        help="generate static dispatch code for the specific backend (if set)",
    )
    parser.add_argument(
        "--skip-dispatcher-op-registration",
        "--skip_dispatcher_op_registration",
        action="store_true",
        help="Avoid registering operators into the dispatcher.",
    )
    parser.add_argument(
        "--force-schema-registration",
        "--force_schema_registration",
        action="store_true",
        # Adjacent literals are concatenated verbatim, so each fragment must
        # carry its own trailing space.
        help="force it to generate schema-only registrations for all ops, including "
        "those that are not listed on --op-registration-whitelist",
    )
    parser.add_argument(
        "--generate",
        type=str,
        nargs="*",
        choices=["headers", "sources", "declarations_yaml"],
        default=["headers", "sources", "declarations_yaml"],
        help="Generate only a subset of files",
    )
    parser.add_argument(
        "--update-aoti-c-shim",
        action="store_true",
        help="Update AOTInductor C shim after adding an entry to inductor_fallback_ops in torchgen/aoti/fallback_ops.py. "
        "WARNING: Do not use this unless you are sure what you are doing!!!",
    )
    parser.add_argument(
        "--extend-aoti-c-shim",
        action="store_true",
        help="This Flag indicates the generation of c shims for out-of-tree ATen ops, "
        "which is an extension to the In-tree ATen op c shims. This flag needs to be combined with "
        "--source-path=<out-of-tree native_functions.yaml> "
        "--aoti-install-dir=<in-tree aoti_install_dir>/extend "
        "(default is torch/csrc/inductor/aoti_torch/generated/extend). "
        "WARNING: Do not use this unless you are sure what you are doing!!!",
    )

    options = parser.parse_args()

    selector = get_custom_build_selector(
        options.op_registration_whitelist,
        options.op_selection_yaml_path,
    )

    native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
    tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")

    from torchgen.model import dispatch_keys

    # Only a limited set of dispatch keys get CPUFunctions.h headers generated
    # for them; this is the set
    functions_keys = {
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.Meta,
        DispatchKey.MTIA,
    }

    aoti_backends = {
        DispatchKey.CPU,
        DispatchKey.CUDA,
        # None will generate the aten shim based on aten_shimified_ops
        # which does not bypass the dispatcher
        None,
    }

    # TODO: stop generating CUDA kernels for non-CUDA builds
    ignore_keys: set[DispatchKey] = set()

    MPS_KEYS = {DispatchKey.MPS, DispatchKey.SparseMPS, DispatchKey.SparseCsrMPS}
    if options.mps or options.update_aoti_c_shim:
        functions_keys.update(MPS_KEYS)
        aoti_backends.add(DispatchKey.MPS)
    else:
        ignore_keys.update(MPS_KEYS)
        # Slice assignment mutates the imported list object itself rather
        # than rebinding the local name.
        dispatch_keys[:] = [k for k in dispatch_keys if k not in MPS_KEYS]

    if options.xpu or options.update_aoti_c_shim:
        functions_keys.add(DispatchKey.XPU)
        aoti_backends.add(DispatchKey.XPU)
    else:
        ignore_keys.add(DispatchKey.XPU)

        if DispatchKey.XPU in dispatch_keys:
            dispatch_keys.remove(DispatchKey.XPU)

    if not options.mtia:
        ignore_keys.add(DispatchKey.MTIA)

        if DispatchKey.MTIA in dispatch_keys:
            dispatch_keys.remove(DispatchKey.MTIA)

    if options.backend_whitelist:
        # From here on ``dispatch_keys`` is a filtered local copy; the
        # imported list is not modified by this step.
        dispatch_keys = [
            k
            for k in dispatch_keys
            if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
        ]

    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
    valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
    native_functions, backend_indices = (
        parsed_yaml.native_functions,
        parsed_yaml.backend_indices,
    )

    grouped_native_functions = get_grouped_native_functions(native_functions)

    structured_native_functions = [
        g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
    ]
    native_functions_with_view_groups = get_grouped_by_view_native_functions(
        native_functions
    )
    view_groups = [
        g
        for g in native_functions_with_view_groups
        if isinstance(g, NativeFunctionsViewGroup)
    ]

    # NB: It is mandatory to NOT use os.path.join here, as the install directory
    # will eventually be ingested by cmake, which does not respect Windows style
    # path slashes. If you switch this to use os.path.join, you'll get an error
    # like:
    #
    #   Syntax error in cmake code when parsing string
    #
    #     C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h
    #
    #   Invalid character escape '\c'.
    core_install_dir = f"{options.install_dir}/core"
    Path(core_install_dir).mkdir(parents=True, exist_ok=True)
    ops_install_dir = f"{options.install_dir}/ops"
    Path(ops_install_dir).mkdir(parents=True, exist_ok=True)

    aoti_install_dir = options.aoti_install_dir
    Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)

    core_fm = make_file_manager(options=options, install_dir=core_install_dir)
    cpu_fm = make_file_manager(options=options)
    cpu_vec_fm = make_file_manager(options=options)
    cuda_fm = make_file_manager(options=options)
    ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
    aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir)
    device_fms = {"cuda": cuda_fm}
    if options.xpu:
        device_fms["xpu"] = make_file_manager(options=options)

    static_dispatch_idx: list[BackendIndex] = []
    if options.static_dispatch_backend:
        # Parse every requested backend once and reuse the parsed keys for
        # both the index lookup and the set of keys that get headers.
        static_dispatch_keys = [
            DispatchKey.parse(key) for key in options.static_dispatch_backend
        ]
        static_dispatch_idx = [backend_indices[k] for k in static_dispatch_keys]
        functions_keys.update(static_dispatch_keys)

    if "sources" in options.generate:
        gen_source_files(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            view_groups=view_groups,
            selector=selector,
            static_dispatch_idx=static_dispatch_idx,
            backend_indices=backend_indices,
            aoti_fm=aoti_fm,
            core_fm=core_fm,
            cpu_vec_fm=cpu_vec_fm,
            cpu_fm=cpu_fm,
            device_fms=device_fms,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=options.rocm,
            force_schema_registration=options.force_schema_registration,
            per_operator_headers=options.per_operator_headers,
            skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
            update_aoti_c_shim=options.update_aoti_c_shim,
            aoti_backends=aoti_backends,
            extend_aoti_c_shim=options.extend_aoti_c_shim,
        )

    if "headers" in options.generate:
        gen_headers(
            native_functions=native_functions,
            valid_tags=valid_tags,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            core_fm=core_fm,
            cpu_fm=cpu_fm,
            device_fms=device_fms,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=options.rocm,
            per_operator_headers=options.per_operator_headers,
        )

    if "declarations_yaml" in options.generate:
        gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)

    if options.output_dependencies:
        depfile_path = Path(options.output_dependencies).resolve()
        depfile_name = depfile_path.name
        depfile_stem = depfile_path.stem

        # Each file manager gets its own prefixed variable name and file,
        # placed in the same directory as the requested dependency file.
        for fm, prefix in [
            (cpu_fm, ""),
            (cpu_vec_fm, "cpu_vec_"),
            (core_fm, "core_"),
            (ops_fm, "ops_"),
        ] + [(device_fm, f"{device}_") for device, device_fm in device_fms.items()]:
            varname = prefix + depfile_stem
            path = depfile_path.parent / (prefix + depfile_name)
            fm.write_outputs(varname, str(path))
# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|