gen_aoti_c_shim.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771
  1. from __future__ import annotations
  2. import difflib
  3. import os
  4. import textwrap
  5. from dataclasses import dataclass
  6. from typing import TYPE_CHECKING
  7. from torchgen.aoti.fallback_ops import aten_shimified_ops, inductor_fallback_ops
  8. from torchgen.api.types import DispatcherSignature
  9. from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
  10. from torchgen.context import method_with_native_function
  11. from torchgen.model import (
  12. Argument,
  13. BackendIndex,
  14. BaseTy,
  15. BaseType,
  16. DispatchKey,
  17. FunctionSchema,
  18. is_cuda_dispatch_key,
  19. ListType,
  20. NativeFunction,
  21. NativeFunctionsGroup,
  22. OperatorName,
  23. OptionalType,
  24. Type,
  25. Variant,
  26. )
  27. from torchgen.utils import FileManager, mapMaybe
  28. if TYPE_CHECKING:
  29. from collections.abc import Sequence
  30. from typing import Optional
  31. base_type_to_c_type = {
  32. BaseTy.Tensor: "AtenTensorHandle",
  33. BaseTy.bool: "int32_t", # Use int to pass bool
  34. BaseTy.int: "int64_t",
  35. BaseTy.SymInt: "int64_t", # Inductor-generated code won't see a SymInt
  36. BaseTy.Scalar: "double", # Use double to pass both integer and floating point
  37. BaseTy.float: "double", # TODO: how about other floating point types?
  38. BaseTy.str: "const char*",
  39. BaseTy.DeviceIndex: "int32_t",
  40. BaseTy.Layout: "int32_t", # Represent enum as int
  41. BaseTy.MemoryFormat: "int32_t", # Represent enum as int
  42. BaseTy.ScalarType: "int32_t", # Represent enum as int
  43. BaseTy.Generator: "AtenGeneratorHandle",
  44. }
  45. base_type_to_aten_type = {
  46. BaseTy.Tensor: "at::Tensor",
  47. BaseTy.bool: "bool",
  48. BaseTy.int: "int64_t",
  49. BaseTy.SymInt: "c10::SymInt",
  50. BaseTy.Scalar: "c10::Scalar",
  51. BaseTy.float: "double",
  52. BaseTy.str: "::std::string_view",
  53. BaseTy.DeviceIndex: "c10::DeviceIndex",
  54. BaseTy.Layout: "c10::Layout",
  55. BaseTy.MemoryFormat: "c10::MemoryFormat",
  56. BaseTy.ScalarType: "c10::ScalarType",
  57. BaseTy.Generator: "at::Generator",
  58. }
  59. base_type_to_callsite_expr = {
  60. BaseTy.Tensor: "resolve_tensor_dispatch_flags",
  61. BaseTy.bool: "",
  62. BaseTy.int: "",
  63. BaseTy.SymInt: "",
  64. BaseTy.Scalar: "",
  65. BaseTy.float: "",
  66. BaseTy.str: "",
  67. BaseTy.DeviceIndex: "static_cast<c10::DeviceIndex>",
  68. BaseTy.Layout: "static_cast<c10::Layout>",
  69. BaseTy.MemoryFormat: "static_cast<c10::MemoryFormat>",
  70. BaseTy.ScalarType: "static_cast<c10::ScalarType>",
  71. BaseTy.Generator: "*generator_handle_to_generator_pointer",
  72. }
# convert args to C types, names in declarations, and expressions in function bodies
def convert_arg_type_and_name(
    typ: Type,
    name: str,
    is_write: bool = False,
) -> tuple[list[str], list[str], list[str], list[str]]:
    """Lower one schema argument type to its C shim ABI representation.

    Args:
        typ: schema type of the argument. BaseType, OptionalType and ListType
            are handled; anything else raises.
        name: argument name. Extra C parameters derive their names from it:
            ``<name>_index_`` for a Device and ``<name>_len_`` for a list.
        is_write: True for output (written-to) arguments; see the comment on
            the Tensor case below. It is only consulted for a bare Tensor and
            for a Tensor list. It is NOT forwarded when recursing into the
            element type of an Optional or List.

    Returns:
        A 4-tuple ``(c_types, names, aten_types, callsite_exprs)``:
          * ``c_types`` / ``names``: parallel lists of C parameter types and
            names for the shim declaration;
          * ``aten_types``: one ATen C++ type per logical argument. This can be
            shorter than ``c_types``, because a Device or a list expands into
            two C parameters;
          * ``callsite_exprs``: one C++ expression per entry of ``aten_types``
            that rebuilds the ATen value from the C parameters.

    Raises:
        NotImplementedError: for a BaseType with no C mapping (e.g. Dimname)
            or any other Type subclass.
    """
    if isinstance(typ, BaseType):
        if typ.name in base_type_to_c_type:
            if typ.name == BaseTy.Tensor and is_write:
                # For output tensors, our normal call to resolve_tensor_dispatch_flags
                # results in an rvalue tensor, which can't be passed to at::Tensor&.
                # Override this case specifically.
                callsite_expr = [f"*tensor_handle_to_tensor_pointer({name})"]
            else:
                # Wrap the C value in the per-type conversion, or pass it
                # through unchanged when the conversion string is empty.
                callsite_expr = [
                    f"{base_type_to_callsite_expr[typ.name]}({name})"
                    if base_type_to_callsite_expr[typ.name]
                    else name
                ]
            return (
                [base_type_to_c_type[typ.name]],
                [name],
                [base_type_to_aten_type[typ.name]],
                callsite_expr,
            )
        elif typ.name == BaseTy.Device:
            # A Device crosses the ABI as two int32_t values: the device type
            # (under `name`) and the device index (under `name_index_`).
            return (
                ["int32_t", "int32_t"],
                [name, name + "_index_"],
                ["c10::Device"],
                [
                    f"c10::Device(static_cast<c10::DeviceType>({name}), static_cast<c10::DeviceIndex>({name}_index_))"
                ],
            )
        else:
            # TODO: BaseTy.Dimname, etc.
            raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        # Lower the element type first, then make it optional. The element's
        # callsite expressions are discarded and rebuilt below.
        c_types, names, aten_types, callsite_exprs = convert_arg_type_and_name(
            typ.elem, name
        )
        j = 0 # index for names
        new_aten_types = []
        new_callsite_exprs = []
        # `j` walks the C parameter list. It advances by 2 for ATen types that
        # occupy two C parameters (lists, Device) and by 1 otherwise.
        for aten_type in aten_types:
            # Use pointer to denote optional type
            c_types[j] = c_types[j] + "*"
            if aten_type.startswith("c10::ArrayRef<"):
                # ArrayRef is passed as pointer + size, but no need to add "*" to the size argument
                new_aten_types.append(f"::std::optional<{aten_type}>")
                base_type = aten_type[len("c10::ArrayRef<") : -1]
                new_callsite_exprs.append(
                    f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j + 1]})"
                )
                j += 2
            elif aten_type == "c10::Device":
                # Device is passed as device_type + device_index
                # (only the device_type parameter became a pointer above)
                new_aten_types.append("::std::optional<c10::Device>")
                new_callsite_exprs.append(
                    f"pointer_to_optional_device({names[j]}, {names[j + 1]})"
                )
                j += 2
            elif aten_type == "at::Tensor":
                # Optional tensors are resolved by the same helper as plain
                # tensors, applied to the (now pointer-typed) handle.
                new_aten_types.append(f"::std::optional<{aten_type}>")
                new_callsite_exprs.append(f"resolve_tensor_dispatch_flags({names[j]})")
                j += 1
            else:
                new_aten_types.append(f"::std::optional<{aten_type}>")
                new_callsite_exprs.append(
                    f"pointer_to_optional<{aten_type}>({names[j]})"
                )
                j += 1
        return (
            c_types,
            names,
            new_aten_types,
            new_callsite_exprs,
        )
    elif isinstance(typ, ListType):
        # Need to explicitly pass the list as pointer + length
        c_types, names, aten_types, _ = convert_arg_type_and_name(typ.elem, name)
        assert len(c_types) == 1, "ListType with unsupported element type " + repr(typ)
        # The list content should never be modified
        c_types[0] = f"const {c_types[0]}*"
        c_types.append("int64_t")
        name = names[0]
        names.append(name + "_len_")
        atype = aten_types[0]
        callsite_exprs = []
        if atype == "bool":
            # no converter from std::vector<bool> to c10::ArrayRef<bool>
            # construct std::array<bool, N> instead
            # (the compile-time size comes from the schema; `<name>_len_` is
            # still declared but not used by this expression)
            assert typ.size is not None
            callsite_exprs.append(f"pointer_to_list<{typ.size}>({name})")
        elif atype == "at::Tensor" and not is_write:
            callsite_exprs.append(
                f"resolve_tensor_list_dispatch_flags({name}, {name}_len_)"
            )
        elif atype == "::std::optional<at::Tensor>":
            # convert from std::vector<::std::optional<at::Tensor>> to c10::List<::std::optional<at::Tensor>>
            callsite_exprs.append(
                f"c10::List<{atype}>(c10::ArrayRef<{atype}>(resolve_tensor_list_dispatch_flags({name}, {name}_len_)))"
            )
        else:
            # Generic path, also used for written-to Tensor lists.
            callsite_exprs.append(f"pointer_to_list<{atype}>({name}, {name}_len_)")
        aten_types = [f"c10::ArrayRef<{t}>" for t in aten_types]
        return (
            c_types,
            names,
            aten_types,
            callsite_exprs,
        )
    raise NotImplementedError(f"Argument type {repr(typ)} not supported!")
  186. def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
  187. return [typ + " " + name for typ, name in zip(types, names)]
  188. # Generate argument declarations and callsite expressions
  189. def gen_arguments(
  190. flat_arguments: Sequence[Argument], skipped_args: set[str]
  191. ) -> tuple[list[str], list[str]]:
  192. types: list[str] = []
  193. new_names: list[str] = []
  194. callsite_exprs: list[str] = []
  195. for arg in flat_arguments:
  196. if arg.name in skipped_args:
  197. callsite_exprs.append("std::nullopt")
  198. continue
  199. new_types, names, _, new_callsite_exprs = convert_arg_type_and_name(
  200. arg.type, arg.name, arg.is_write
  201. )
  202. types.extend(new_types)
  203. new_names.extend(names)
  204. callsite_exprs.extend(new_callsite_exprs)
  205. return zip_type_and_name(types, new_names), callsite_exprs
  206. # Return values are passed out as pointer arguments because all the C shim functions
  207. # are expected to return AOTITorchError.
  208. # Generate returns as declarations and callsite expressions
  209. def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
  210. types = []
  211. names = []
  212. for idx, ret in enumerate(schema.returns):
  213. names.append(f"ret{idx}")
  214. if isinstance(ret.type, BaseType) and ret.type.name in base_type_to_c_type:
  215. types.append(base_type_to_c_type[ret.type.name] + "*")
  216. else:
  217. raise NotImplementedError(
  218. f"TODO: add support for return type {repr(ret.type)}"
  219. )
  220. def convert_return(typ: BaseType, val: str) -> str:
  221. if typ.name == BaseTy.Tensor:
  222. return f"new_tensor_handle(std::move({val}))"
  223. elif typ.name == BaseTy.SymInt:
  224. return f"{val}.expect_int()"
  225. elif typ.name == BaseTy.Scalar:
  226. return f"{val}.toDouble()"
  227. else:
  228. return val
  229. ret_pointer_can_be_null = False
  230. unambiguous_name = schema.name.unambiguous_name()
  231. for name in (
  232. "_functional_sym_constrain_range",
  233. "_scaled_dot_product_cudnn_attention",
  234. "_scaled_dot_product_efficient_attention_backward",
  235. "_scaled_dot_product_efficient_attention",
  236. "_scaled_dot_product_flash_attention",
  237. "_scaled_dot_product_fused_attention_overrideable",
  238. "_thhn_fused_lstm_cell_backward_impl",
  239. "convolution_backward",
  240. "grid_sampler_2d_backward",
  241. "grid_sampler_3d_backward",
  242. "linear_backward",
  243. ):
  244. if name in unambiguous_name:
  245. ret_pointer_can_be_null = True
  246. break
  247. callsite_exprs: list[str] = []
  248. for idx, ret in enumerate(schema.returns):
  249. tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)"
  250. assert isinstance(ret.type, BaseType)
  251. rval = convert_return(ret.type, tmp)
  252. if ret_pointer_can_be_null:
  253. callsite_exprs.append(f"if ({names[idx]}) {{ *{names[idx]} = {rval}; }}")
  254. else:
  255. callsite_exprs.append(f"*{names[idx]} = {rval};")
  256. return zip_type_and_name(types, names), callsite_exprs
# gen.py generates header first and then src, so caching the result here to avoid duplicate work
# Maps (unambiguous op name, device, backend_call) -> (declarations, definitions).
declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = {}


def gen_declaration_and_definition(
    schema: FunctionSchema,
    device: str,
    backend_call: str,
    version_info: dict[str, list[str]],
) -> tuple[str, str]:
    """Generate the C shim declaration(s) and definition(s) for one operator.

    One shim is emitted for the implicit version 1 plus one per entry in
    ``version_info``:
      * version 1 is named ``aoti_torch_<device>_<op>``;
      * version N > 1 is named ``aoti_torch_<device>_<op>_vN``.
    Older versions leave out the arguments that newer versions introduced.
    Those arguments are passed to the backend call as ``std::nullopt``
    (see gen_arguments).

    Args:
        schema: operator schema to wrap.
        device: device tag embedded in the shim name.
        backend_call: C++ callee expression placed in the shim body.
        version_info: maps ``"v<N>"`` to the argument names first added in
            version N.

    Returns:
        ``(declarations, definitions)``: the newline-joined
        ``AOTI_TORCH_EXPORT`` declarations and the matching definitions,
        latest version first. Results are memoized in
        ``declaration_definition_cache``. ``version_info`` is not part of the
        cache key.

    Raises:
        AssertionError: if a version key does not start with "v", is not an
            integer after the "v", or repeats an already-defined version
            (including the implicit v1).
        NotImplementedError: propagated from argument/return lowering for
            unsupported types.
    """
    base_name = schema.name.unambiguous_name()
    # `global` is not strictly required: the cache is only mutated, never rebound.
    global declaration_definition_cache
    if (base_name, device, backend_call) in declaration_definition_cache:
        return declaration_definition_cache[(base_name, device, backend_call)]
    # Check the validity of version_info. The format should look like
    # {"v2" : ["new_arg1"], "v3": ["new_arg2", "new_arg3"]}.
    indexed_version_info: dict[int, list[str]] = {1: []}
    for ver_str, new_args in sorted(version_info.items()):
        assert ver_str.startswith("v"), (
            f"Version number for {base_name} is {ver_str}, not starting with 'v'"
        )
        try:
            ver_id = int(ver_str[1:])
        except ValueError as e:
            raise AssertionError(
                f"Version number for {base_name} is {ver_str}, not a valid integer after 'v'"
            ) from e
        assert ver_id not in indexed_version_info, (
            f"{ver_str} for {base_name} has already been defined"
        )
        indexed_version_info[ver_id] = new_args
    declarations: list[str] = []
    definitions: list[str] = []
    # Names of arguments added by versions newer than the one being generated.
    skipped_args: set[str] = set()
    for ver_id, new_args in sorted(indexed_version_info.items(), reverse=True):
        # Iterate in the reverse order, so the latest version of an op will get generated first
        # with all the arguments included, while a set of to-be-trimmed args is carried down
        # to generate earlier version of the op.
        func_name = base_name if ver_id == 1 else f"{base_name}_v{ver_id}"
        if schema.is_out_fn():
            # out_variant has out arguments in the front, and it's ok to ignore return values
            # because C shim functions only return AOTITorchError
            args, callsite_exprs = gen_arguments(
                [*schema.arguments.out, *schema.arguments.flat_non_out], skipped_args
            )
            ret_assignments: list[str] = []
        else:
            args, callsite_exprs = gen_arguments(
                schema.arguments.flat_all, skipped_args
            )
            # ignore return values for inplace ops
            ret_declarations, ret_assignments = (
                ([], []) if schema.name.name.inplace else gen_returns(schema)
            )
            # Return out-pointers are appended after the regular parameters.
            args.extend(ret_declarations)
        declaration = textwrap.dedent(
            f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"
        )
        # Only capture the backend call's result when something is stored from it.
        tmp_result = "auto tmp_result = " if ret_assignments else ""
        indent = "\t\t"
        ret_assignments_str = (
            "\n".join(indent + r for r in ret_assignments) if ret_assignments else ""
        )
        # The body runs the backend call inside AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE,
        # then stores any results through the out-pointers.
        definition = (
            textwrap.dedent(f"""
{declaration} {{
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
{tmp_result}{backend_call}(
{", ".join(callsite_exprs)}
);
""")
            + ret_assignments_str
            + textwrap.dedent("""
});
}
""")
        )
        # Arguments first added in this version are omitted from all older versions.
        skipped_args.update(new_args)
        declarations.append(f"AOTI_TORCH_EXPORT {declaration};")
        definitions.append(definition)
    declaration_definition_cache[(base_name, device, backend_call)] = (
        "\n".join(declarations),
        "\n".join(definitions),
    )
    return declaration_definition_cache[(base_name, device, backend_call)]
  340. def gen_static_dispatch_backend_call_signature(
  341. sig: CppSignature | DispatcherSignature,
  342. f: NativeFunction,
  343. ) -> CppSignature:
  344. sig = DispatcherSignature.from_schema(f.func)
  345. cpp_sigs = CppSignatureGroup.from_native_function(
  346. f, method=False, fallback_binding=False
  347. )
  348. if sig.symint and f.func.has_symint():
  349. cpp_sig = cpp_sigs.symint_signature
  350. else:
  351. cpp_sig = cpp_sigs.signature
  352. assert cpp_sig is not None
  353. return cpp_sig
  354. def gen_static_dispatch_backend_call(
  355. f: NativeFunction,
  356. backend_index: Optional[BackendIndex] = None,
  357. ) -> str:
  358. sig = DispatcherSignature.from_schema(f.func)
  359. cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
  360. if backend_index is None:
  361. # Check if this is a symint function and if the function only has method variants
  362. if sig.symint and f.func.has_symint():
  363. has_function_variant = Variant.function in f.variants
  364. if not has_function_variant:
  365. # Functions with both function and method variants can use the at::{*}_symint version
  366. # (e.g., narrow -> at::narrow_symint), BUT
  367. # Method-only functions with symint parameters should use at::symint:: namespace
  368. # Remove the _symint suffix since at::symint:: namespace uses the base name
  369. # (e.g., new_empty -> at::symint::new_empty<c10::SymInt>)
  370. base_name = cpp_sig.name()
  371. base_name = base_name.removesuffix("_symint") # Remove "_symint" suffix
  372. return f"at::symint::{base_name}<c10::SymInt>"
  373. return f"at::{cpp_sig.name()}"
  374. else:
  375. return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
  376. def get_backend_index_for_aoti(
  377. func: NativeFunction,
  378. func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
  379. dispatch_key: Optional[DispatchKey],
  380. backend_indices: dict[DispatchKey, BackendIndex],
  381. extend_aoti_c_shim: bool,
  382. ) -> BackendIndex | None:
  383. backend_index = None
  384. if dispatch_key is None:
  385. return backend_index
  386. if backend_indices[dispatch_key].has_kernel(func) or (
  387. func.structured_delegate is not None
  388. and func.structured_delegate in func_group_mapping
  389. and backend_indices[dispatch_key].has_kernel(
  390. func_group_mapping[func.structured_delegate]
  391. )
  392. ):
  393. backend_index = backend_indices[dispatch_key]
  394. else:
  395. # for the extend out-of-tree kernels, we don't need to
  396. # duplicatly create C shim wrappers for other dispatch keys
  397. if extend_aoti_c_shim:
  398. return backend_index
  399. elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func):
  400. # We need to create C shim wrappers for CompositeExplicitAutograd kernels
  401. backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
  402. elif backend_indices[
  403. DispatchKey.CompositeExplicitAutogradNonFunctional
  404. ].has_kernel(func):
  405. # We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels
  406. backend_index = backend_indices[
  407. DispatchKey.CompositeExplicitAutogradNonFunctional
  408. ]
  409. elif backend_indices[DispatchKey.CompositeImplicitAutograd].has_kernel(func):
  410. backend_index = backend_indices[DispatchKey.CompositeImplicitAutograd]
  411. return backend_index
  412. def get_header_for_aoti(
  413. func: NativeFunction,
  414. func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
  415. dispatch_key: Optional[DispatchKey],
  416. backend_indices: dict[DispatchKey, BackendIndex],
  417. extend_aoti_c_shim: bool,
  418. ) -> str | None:
  419. backend_index = get_backend_index_for_aoti(
  420. func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim
  421. )
  422. if backend_index is None:
  423. if dispatch_key is None:
  424. return f"#include <ATen/ops/{func.root_name}.h>"
  425. return None
  426. return f"#include <ATen/ops/{func.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
  427. def get_fallback_op_name(func: NativeFunction) -> str:
  428. return (
  429. f"{func.namespace}.{func.func.name.name}.{func.func.name.overload_name}"
  430. if func.func.name.overload_name
  431. else f"{func.namespace}.{func.func.name.name}.default"
  432. )
  433. def gen_c_shim(
  434. func: NativeFunction,
  435. version_info: dict[str, list[str]],
  436. func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
  437. dispatch_key: Optional[DispatchKey],
  438. backend_indices: dict[DispatchKey, BackendIndex],
  439. header: bool,
  440. extend_aoti_c_shim: bool,
  441. ) -> str | None:
  442. backend_index = get_backend_index_for_aoti(
  443. func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim
  444. )
  445. if backend_index is None and dispatch_key is not None:
  446. return None
  447. schema = func.func
  448. device = "aten" if dispatch_key is None else dispatch_key.lower()
  449. backend_call = gen_static_dispatch_backend_call(
  450. func,
  451. backend_index,
  452. )
  453. try:
  454. if header:
  455. declaration, _ = gen_declaration_and_definition(
  456. schema, device, backend_call, version_info
  457. )
  458. return declaration
  459. else:
  460. _, definition = gen_declaration_and_definition(
  461. schema, device, backend_call, version_info
  462. )
  463. return definition
  464. except NotImplementedError:
  465. return None
  466. @dataclass(frozen=True)
  467. class ShimGenerator:
  468. inductor_fallback_ops: dict[str, dict[str, list[str]]]
  469. func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
  470. dispatch_key: Optional[DispatchKey]
  471. backend_indices: dict[DispatchKey, BackendIndex]
  472. header: bool # True to generate .h and False to generate .cpp
  473. extend_aoti_c_shim: bool
  474. @method_with_native_function
  475. def __call__(
  476. self,
  477. func: NativeFunction,
  478. ) -> str | None:
  479. version_info = self.inductor_fallback_ops[get_fallback_op_name(func)]
  480. result = gen_c_shim(
  481. func,
  482. version_info,
  483. self.func_group_mapping,
  484. self.dispatch_key,
  485. self.backend_indices,
  486. self.header,
  487. self.extend_aoti_c_shim,
  488. )
  489. return result
def gen_aoti_c_shim(
    native_functions: Sequence[NativeFunction],
    inductor_fallback_ops: dict[str, dict[str, list[str]]],
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: Optional[DispatchKey],
    backend_indices: dict[DispatchKey, BackendIndex],
    header: bool,
    extend_aoti_c_shim: bool,
    includes: str = "",
) -> str:
    """Render a complete C shim header (.h) or source (.cpp) file as text.

    Args:
        native_functions: functions to generate shims for, in output order.
        inductor_fallback_ops: fallback op name -> version info, forwarded
            to ShimGenerator.
        func_group_mapping: structured delegate -> NativeFunctionsGroup.
        dispatch_key: target backend, or None for the generic "aten" shim.
        backend_indices: BackendIndex per dispatch key.
        header: True renders the .h file, False renders the .cpp file.
        extend_aoti_c_shim: for the .cpp, include the generated header from
            the ``extend/`` subdirectory.
        includes: extra include lines placed in the AT_PER_OPERATOR_HEADERS
            branch of the .cpp; unused for the header.

    Returns:
        The file contents. The header wraps the declarations in an
        ``extern "C"`` block. The .cpp includes the generated header and then
        the definitions.
    """
    # Each native function is rendered by ShimGenerator; mapMaybe presumably
    # drops the ones for which it returned None (no shim generated).
    body = "\n".join(
        list(
            mapMaybe(
                ShimGenerator(
                    inductor_fallback_ops,
                    func_group_mapping,
                    dispatch_key,
                    backend_indices,
                    header,
                    extend_aoti_c_shim,
                ),
                native_functions,
            )
        )
    )
    device = "aten" if dispatch_key is None else dispatch_key.lower()
    # Monolithic functions header used when per-operator headers are disabled.
    include_device_functions = (
        "#include <ATen/Functions.h>"
        if dispatch_key is None
        else f"#include <ATen/{str(dispatch_key)}Functions.h>"
    )
    aten_warning = (
        (
            "\n\n// This file corresponds to the aten_shimified_ops list in torchgen/aoti/fallback_ops.py\n"
        )
        if dispatch_key is None
        else ""
    )
    # NOTE: the header text is compared byte-for-byte against checked-in files
    # (see gen_aoti_c_shim_files), so any whitespace change here is user-visible.
    warning = """
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details"""
    if header:
        return (
            warning
            + aten_warning
            + textwrap.dedent("""
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#ifdef __cplusplus
extern "C" {
#endif
""")
            + body
            + textwrap.dedent("""
#ifdef __cplusplus
} // extern "C"
#endif
""")
        )
    else:
        return (
            warning
            + aten_warning
            + textwrap.dedent(f"""
#include <torch/csrc/inductor/aoti_torch/generated/{"extend/" if extend_aoti_c_shim else ""}c_shim_{device}.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#ifndef AT_PER_OPERATOR_HEADERS
{include_device_functions}
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeImplicitAutogradFunctions.h>
#else
""")
            + includes
            + textwrap.dedent("""
#endif // AT_PER_OPERATOR_HEADERS
using namespace torch::aot_inductor;
""")
            + body
        )
  570. def gen_aoti_c_shim_files(
  571. aoti_fm: FileManager,
  572. aoti_backends: set[Optional[DispatchKey]],
  573. native_functions: Sequence[NativeFunction],
  574. backend_indices: dict[DispatchKey, BackendIndex],
  575. structured_native_functions: Sequence[NativeFunctionsGroup],
  576. extra_cuda_headers: str,
  577. extend_aoti_c_shim: bool,
  578. update_aoti_c_shim: bool,
  579. ) -> None:
  580. structured_func_group_dict = {}
  581. for func_group in structured_native_functions:
  582. for func in func_group.functions():
  583. if func.structured_delegate is not None:
  584. structured_func_group_dict[func.structured_delegate] = func_group
  585. break
  586. for dispatch_key in aoti_backends:
  587. # Use aten_shimified_ops for the aten backend, inductor_fallback_ops for others
  588. fallback_ops_dict = (
  589. aten_shimified_ops if dispatch_key is None else inductor_fallback_ops
  590. )
  591. fallbacks = {}
  592. for func in native_functions:
  593. op_name = get_fallback_op_name(func)
  594. if op_name in fallback_ops_dict:
  595. fallbacks[op_name] = func
  596. fallback_native_functions = tuple(
  597. value for _, value in sorted(fallbacks.items())
  598. )
  599. # Use "aten" as the device name when dispatch_key is Generic
  600. device_name = "aten" if dispatch_key is None else dispatch_key.lower()
  601. # header files were checked in for ABI-compatiblilty checking
  602. header_file_name = f"c_shim_{device_name}.h"
  603. new_header = gen_aoti_c_shim(
  604. fallback_native_functions,
  605. fallback_ops_dict,
  606. structured_func_group_dict,
  607. dispatch_key,
  608. backend_indices,
  609. header=True,
  610. extend_aoti_c_shim=extend_aoti_c_shim,
  611. includes="",
  612. )
  613. if update_aoti_c_shim:
  614. aoti_fm.write(
  615. header_file_name,
  616. lambda: new_header,
  617. )
  618. else:
  619. try:
  620. with open(
  621. os.path.join(aoti_fm.install_dir, header_file_name)
  622. ) as old_file:
  623. old_header = old_file.read()
  624. if old_header != new_header:
  625. diff = "\n".join(
  626. difflib.unified_diff(
  627. old_header.splitlines(),
  628. new_header.splitlines(),
  629. fromfile="expected",
  630. tofile="actual",
  631. lineterm="",
  632. )
  633. )
  634. raise RuntimeError(f"""
  635. The generated AOTInductor C shim header files have unexpectedly changed. This
  636. indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
  637. Only in a limited number of situations, this is allowed:
  638. 1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py.
  639. If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to add a new entry to
  640. existing C shim header files.
  641. 2. You added a new default argument to an existing fallback op. This is clearly a BC breaking
  642. change in the AOTInductor land. You need to annotate the new default argument in
  643. torchgen/aoti/fallback_ops.py, and then run `python torchgen/gen.py --update-aoti-c-shim` to
  644. update the C shim header files by creating different versions of the fallback op. See
  645. https://github.com/pytorch/pytorch/pull/154848 as an example.
  646. {diff}
  647. """)
  648. except FileNotFoundError:
  649. print(
  650. f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
  651. )
  652. # cpp files are always generated on-the-fly
  653. def headers_for_aoti() -> str:
  654. headers = []
  655. for func in fallback_native_functions:
  656. header = get_header_for_aoti(
  657. func,
  658. structured_func_group_dict,
  659. dispatch_key,
  660. backend_indices,
  661. extend_aoti_c_shim=extend_aoti_c_shim,
  662. )
  663. if header is not None:
  664. headers.append(header)
  665. return "\n".join(sorted(set(headers)))
  666. extra_headers = (
  667. extra_cuda_headers
  668. if dispatch_key is not None and is_cuda_dispatch_key(dispatch_key)
  669. else ""
  670. )
  671. aoti_fm.write(
  672. f"c_shim_{device_name}.cpp",
  673. lambda: gen_aoti_c_shim(
  674. fallback_native_functions,
  675. fallback_ops_dict,
  676. structured_func_group_dict,
  677. dispatch_key,
  678. backend_indices,
  679. header=False,
  680. extend_aoti_c_shim=extend_aoti_c_shim,
  681. includes=headers_for_aoti() + "\n" + extra_headers,
  682. ),
  683. )