gen_lazy_tensor.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. from __future__ import annotations
  2. import argparse
  3. import os
  4. from collections import namedtuple
  5. from pathlib import Path
  6. from typing import Any, Callable, TYPE_CHECKING
  7. import yaml
  8. import torchgen.dest as dest
  9. from torchgen.api.lazy import setValueT
  10. from torchgen.api.types import BaseCppType
  11. from torchgen.dest.lazy_ir import GenLazyIR, GenLazyNativeFuncDefinition, GenTSLazyIR
  12. from torchgen.gen import get_grouped_native_functions, parse_native_yaml
  13. from torchgen.gen_backend_stubs import (
  14. error_on_missing_kernels,
  15. gen_dispatcher_registrations,
  16. gen_dispatchkey_nativefunc_headers,
  17. parse_backend_yaml,
  18. )
  19. from torchgen.model import NativeFunction, NativeFunctionsGroup, OperatorName
  20. from torchgen.selective_build.selector import SelectiveBuilder
  21. from torchgen.utils import FileManager, NamespaceHelper
  22. from torchgen.yaml_utils import YamlLoader
  23. if TYPE_CHECKING:
  24. from collections.abc import Iterable, Iterator, Sequence
  25. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  26. #
  27. # Lazy Tensor Codegen
  28. #
  29. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  30. # Overview
  31. # ~~~~~~~~
  32. #
  33. # This codegen script builds on existing data models and helpers used
  34. # by all ATen backends, and adds new functionality specific to lazy
  35. # tensor backends.
  36. #
  37. # Inputs:
  38. # - <backend>_native_functions.yaml: controls which operators are
  39. # supported by the backend.
  40. #
  41. # Outputs:
  42. # (for all backends)
  43. # <DispatchKey>Ir.h defines Lazy IR classes to be constructed during tracing
  44. # - opt-in: also generate 'lowering' methods for the TorchScript backend only
  45. # <DispatchKey>NativeFunctions.cpp defines implementations of native functions which perform lazy tracing
  46. # - opt-in: 'full_codegen' section of backend yaml; 'supported' section omits these implementations
  47. # <DispatchKey>NativeFunctions.h declares implementations of native functions for both 'supported' and 'full_codegen'
  48. # ops
  49. #
  50. # Register<DispatchKey>.cpp registers all op implementations with the dispatcher
  51. # RegisterAutograd<DispatchKey>.cpp registers all autograd implementations with the dispatcher
  52. #
  53. # Validation Helpers:
  54. # - Shape Inference: errs if any ops in backend yaml require shape inference not provided by meta kernels or
  55. # implementations in torch/csrc/lazy/core/shape_inference.*
  56. # - native function impls: errs if any 'supported' ops do not have an implementation defined in the backend
  57. # (non-codegen) implementation file
  58. #
  59. #
  60. # About the Data Model
  61. # ~~~~~~~~~~~~~~~~~~~~
  62. #
  63. # Modeled after ATen codegen, the first step is to parse yaml and build a data model for the operators
  64. # we care about. In this case, the <backend>_native_functions yaml defines a subset of the core operators
  65. # (defined in more detail in the main native_functions.yaml), which will be supported by your backend.
  66. # Backends can list ops in two categories:
  67. # - `supported` ops require hand-implementations but still get codegenned declarations and registrations
  68. # - `full_codegen` ops get implementations (and IR classes) generated too
  69. #
# Each native function is modeled as an object with a schema, and each schema has objects representing its
# arguments. Much of the codegen is manipulation of the arguments and their types. For example, lazy tensor
# backends need to transform 'at::Tensor' arguments into 'lazy::Value' objects, as well as replacing reference
# types (stringref) with actual string objects, and this is done by manipulating the data model objects.
# - see api/lazy.py for the lazy data model
  75. #
# Once the data model is set up, the rest of this script processes a number of templates for the output C++
# files and fills in the template values using helpers in `dest/lazy_ir.py` and `dest/lazy_ts_lowering.py`.
# These helpers mostly iterate over functions and their arguments, outputting different C++ snippets.
  79. #
  80. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# Record type bundling the values obtained from an external backend's yaml:
# (backend_key, autograd_key, cpp_namespace, backend_indices, full_codegen).
# `backend_indices` is the BackendIndex mapping updated with an entry for the
# backend's dispatch key.
# NOTE(review): nothing in the code visible in this file constructs this tuple;
# run_gen_lazy_tensor consumes the result of gen_backend_stubs.parse_backend_yaml
# instead — confirm whether this definition is still needed.
ParsedExternalYaml = namedtuple(
    "ParsedExternalYaml",
    ["backend_key", "autograd_key", "cpp_namespace", "backend_indices", "full_codegen"],
)
  87. def parse_native_functions_keys(
  88. backend_yaml_path: str,
  89. grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
  90. ) -> tuple[list[OperatorName], list[Any], list[OperatorName]]:
  91. with open(backend_yaml_path) as f:
  92. yaml_values = yaml.load(f, Loader=YamlLoader)
  93. assert isinstance(yaml_values, dict)
  94. full_codegen = yaml_values.pop("full_codegen", [])
  95. non_native = yaml_values.pop("non_native", [])
  96. ir_gen = yaml_values.pop("ir_gen", [])
  97. assert isinstance(full_codegen, list)
  98. assert isinstance(non_native, list)
  99. assert isinstance(ir_gen, list)
  100. full_codegen_opnames = [OperatorName.parse(name) for name in full_codegen]
  101. ir_gen_opnames = [OperatorName.parse(name) for name in ir_gen]
  102. return full_codegen_opnames, non_native, ir_gen_opnames
  103. def validate_shape_inference_header(
  104. shape_inference_hdr: str, expected_shape_infr_decls: list[str]
  105. ) -> None:
  106. try:
  107. with open(shape_inference_hdr) as f:
  108. shape_infr_decls = f.read()
  109. shape_infr_decl_lines = set(shape_infr_decls.split("\n"))
  110. except OSError as e:
  111. raise AssertionError(
  112. f"Unable to read from the specified shape_inference_hdr file: {shape_inference_hdr}"
  113. ) from e
  114. # TODO(whc) add a check for shape inference functions that have meta kernels implement and should be retired.
  115. missing_decls = [
  116. decl for decl in expected_shape_infr_decls if decl not in shape_infr_decl_lines
  117. ]
  118. if missing_decls:
  119. raise Exception( # noqa: TRY002
  120. f"""Missing shape inference function.\n
  121. Please add declare this function in {shape_inference_hdr}:\n
  122. and implement it in the corresponding shape_inference.cpp file.\n
  123. {os.linesep.join(missing_decls)}"""
  124. )
# Some helper functions for the codegen.
# NOTE(review): the whitespace inside the C++ literal below is reproduced as it
# appears in the reviewed copy of this file — confirm against the canonical source.
def get_ltc_helper_fns() -> str:
    """Return C++ source text defining three `to_meta` overloads.

    The returned snippet contains overloads for `const at::Tensor&`,
    `const std::optional<at::Tensor>&` and `at::ITensorListRef`. In this file
    it is supplied as the "helper_fns" template value when the
    `<backend_key>NativeFunctions.cpp` file is written.
    """
    return """\
at::Tensor to_meta(const at::Tensor& tensor) {
// undefined tensors can't be converted to the meta device, since they don't have sizes/strides
if (!tensor.defined()) return tensor;
auto out = at::native::empty_strided_meta_symint(tensor.sym_sizes(), tensor.sym_strides(), \
/*dtype=*/tensor.scalar_type(), /*layout=*/tensor.layout(), \
/*device=*/c10::Device(c10::kMeta), /*pin_memory=*/std::nullopt);
// needs to handle wrapped numbers, so dtype promotion works properly.
if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
out.unsafeGetTensorImpl()->set_wrapped_number(true);
}
return out;
}
std::optional<at::Tensor> to_meta(const std::optional<at::Tensor>& tensor) {
if (tensor.has_value()) {
return to_meta(*tensor);
}
return std::nullopt;
}
std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
std::vector<at::Tensor> outs;
outs.reserve(t_list.size());
for (const auto& tensor : t_list) {
outs.push_back(to_meta(tensor));
}
return outs;
}
"""
# Namespace of default values. The same constants are used as the argparse
# defaults in main() and as the keyword defaults of run_gen_lazy_tensor(), so
# the CLI and the programmatic entry point stay in agreement.
class default_args:
    # Name of the Lazy IR Node base class, and the optional header defining it.
    node_base: str = "Node"
    node_base_hdr: str | None = None
    # Header that is checked for the expected shape inference declarations.
    shape_inference_hdr: str = "torch/csrc/lazy/core/shape_inference.h"
    # Lazy tensor class name and the header that defines it.
    tensor_class: str = "torch::lazy::LazyTensor"
    tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h"
    # Generator classes used for the IR node classes and for the native function definitions.
    lazy_ir_generator: type[GenLazyIR] = GenLazyIR
    native_func_definition_generator: type[GenLazyNativeFuncDefinition] = (
        GenLazyNativeFuncDefinition
    )
    # Name of the backend to generate.
    backend_name: str = "TorchScript"
  166. def main() -> None:
  167. parser = argparse.ArgumentParser(description="Generate Lazy Tensor backend files")
  168. parser.add_argument(
  169. "-s",
  170. "--source-yaml",
  171. "--source_yaml",
  172. help="path to source yaml file containing operator external definitions",
  173. )
  174. parser.add_argument("-o", "--output-dir", "--output_dir", help="output directory")
  175. parser.add_argument(
  176. "--dry-run", "--dry_run", type=bool, default=False, help="output directory"
  177. )
  178. parser.add_argument(
  179. "--impl-path",
  180. "--impl_path",
  181. type=str,
  182. default=None,
  183. help="path to the source C++ file containing kernel definitions",
  184. )
  185. parser.add_argument(
  186. "--gen-ts-lowerings",
  187. "--gen_ts_lowerings",
  188. action="store_true",
  189. help="Generate TorchScript lowerings in addition to Lazy IR and NativeFunctions",
  190. )
  191. parser.add_argument(
  192. "--node-base",
  193. "--node_base",
  194. type=str,
  195. default=default_args.node_base,
  196. help="Name of backend specific custom Lazy IR Node base class",
  197. )
  198. parser.add_argument(
  199. "--node-base-hdr",
  200. "--node_base_hdr",
  201. type=str,
  202. default=default_args.node_base_hdr,
  203. help="Path to header file defining custom Lazy IR Node base class",
  204. )
  205. parser.add_argument(
  206. "--shape-inference-hdr",
  207. "--shape_inference_hdr",
  208. type=str,
  209. default=default_args.shape_inference_hdr,
  210. help="Path to header file defining custom Lazy shape inference functions",
  211. )
  212. parser.add_argument(
  213. "--tensor-class",
  214. "--tensor_class",
  215. type=str,
  216. default=default_args.tensor_class,
  217. help="Name of backend specific custom Lazy Tensor class",
  218. )
  219. parser.add_argument(
  220. "--tensor-class-hdr",
  221. "--tensor_class_hdr",
  222. type=str,
  223. default=default_args.tensor_class_hdr,
  224. help="Path to header file defining custom Lazy Tensor class",
  225. )
  226. parser.add_argument(
  227. "--backend-name",
  228. "--backend_name",
  229. type=str,
  230. default=default_args.backend_name,
  231. help="Name of the backend to generate",
  232. )
  233. options = parser.parse_args()
  234. # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
  235. torch_root = Path(__file__).absolute().parents[2]
  236. aten_path = str(torch_root / "aten" / "src" / "ATen")
  237. lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator
  238. if options.gen_ts_lowerings:
  239. lazy_ir_generator = GenTSLazyIR
  240. native_func_definition_generator: type[GenLazyNativeFuncDefinition] = (
  241. default_args.native_func_definition_generator
  242. )
  243. run_gen_lazy_tensor(
  244. aten_path,
  245. options.source_yaml,
  246. options.output_dir,
  247. options.dry_run,
  248. options.impl_path,
  249. options.node_base,
  250. options.node_base_hdr,
  251. options.tensor_class,
  252. options.tensor_class_hdr,
  253. options.shape_inference_hdr,
  254. lazy_ir_generator,
  255. native_func_definition_generator,
  256. options.backend_name,
  257. )
def run_gen_lazy_tensor(
    aten_path: str,
    source_yaml: str,
    output_dir: str,
    dry_run: bool,
    impl_path: str | None,
    node_base: str = default_args.node_base,
    node_base_hdr: str | None = default_args.node_base_hdr,
    tensor_class: str = default_args.tensor_class,
    tensor_class_hdr: str = default_args.tensor_class_hdr,
    shape_inference_hdr: str = default_args.shape_inference_hdr,
    lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator,
    native_func_definition_generator: type[
        GenLazyNativeFuncDefinition
    ] = default_args.native_func_definition_generator,
    # build_in_tree is true for TS backend and affects include paths
    build_in_tree: bool = False,
    # per_operator_headers changes whether ATen/Functions.h or individual operator headers are used
    # it must match how ATen was built
    per_operator_headers: bool = False,
    backend_name: str = default_args.backend_name,
    gen_forced_fallback_code: bool = False,
    use_lazy_shape: bool = True,
    # the following arguments are temporary customization points for xla backend migration.
    # do not rely on them otherwise, they should be removed once migration is complete
    backend_namespace: str = "torch::lazy",
    get_tensorlist: str = "GetTensorList",
    get_tensor_or_wrap_number: str = "GetLtcTensorOrCreateForWrappedNumber",
    try_get_tensor: str = "TryGetLtcTensor",
    metrics_counter: str = 'TORCH_LAZY_FN_COUNTER("lazy::")',
    create_tensor: str = "LazyTensor::Create",
    create_from_first_tensor: bool = False,
    create_aten_from_ltc_tensor: str = "torch::lazy::CreateAtenFromLtcTensor",
    tuple_aten_from_ltc_tensors: str = "torch::lazy::TupleAtenFromLtcTensors",
    lazy_value_class: str = "torch::lazy::Value",
    lazy_tensor_ptr: str = "LazyTensorPtr",
    get_device_fn: str = "torch::lazy::GetBackendDevice",
) -> None:
    """Generate the lazy tensor backend files described by a backend yaml.

    Steps, in order:
      1. Register the lazy value C++ type (`lazy_value_class`) via `setValueT`.
      2. Parse `native/native_functions.yaml` and `native/tags.yaml` under
         `aten_path`, group the native functions and sort them by base name.
      3. Parse `source_yaml` for the backend key, autograd key, C++ namespace
         and backend indices, plus the `full_codegen`, `non_native` and
         `ir_gen` op lists.
      4. If `impl_path` is given, call `error_on_missing_kernels`; if
         `shape_inference_hdr` is not None, check that it declares every
         expected shape inference function.
      5. Emit the native function headers, the dispatcher registrations (for
         the backend key and, when present, the autograd key), and then write
         `<backend_key>NativeFunctions.cpp`, `LazyIr.h` and `LazyNonNativeIr.h`
         from the templates in `<aten_path>/templates`.

    Args:
        aten_path: ATen source directory; the yaml inputs and the `templates`
            directory are resolved relative to it.
        source_yaml: path of the backend yaml file.
        output_dir: install directory handed to the FileManager; also used to
            build the include paths of the generated headers.
        dry_run: forwarded to the FileManager.
        impl_path: optional path passed to `error_on_missing_kernels`; the
            check is skipped when None.
        node_base: forwarded to `lazy_ir_generator`.
        node_base_hdr: when set, added to the includes of `LazyIr.h` and
            `LazyNonNativeIr.h`.
        tensor_class: forwarded to the shape inference and native function
            definition generators.
        tensor_class_hdr: included by the generated NativeFunctions.cpp.
        shape_inference_hdr: header validated for expected declarations and
            included by the generated NativeFunctions.cpp.
        lazy_ir_generator: class instantiated to produce IR declarations.
        native_func_definition_generator: class instantiated to produce the
            native function definitions.
        gen_forced_fallback_code: when true, additionally includes
            `torch/csrc/lazy/ts_backend/ts_eager_fallback.h`; also forwarded
            to `native_func_definition_generator`.
        lazy_value_class: fully qualified C++ name; split on "::" into a
            namespace and a class name.

    The remaining keyword arguments are forwarded unchanged to the generator
    or registration helpers; see the inline comments in the signature.

    Raises:
        AssertionError: if the backend yaml yields no backend key or no native
            function class name.
        Whatever the validation helpers raise (for example on missing shape
        inference declarations).
    """
    lv_tokens = lazy_value_class.split("::")
    lv_class = lv_tokens[-1]
    lv_ns = "::".join(lv_tokens[:-1])
    setValueT(BaseCppType(lv_ns, lv_class))
    template_dir = os.path.join(aten_path, "templates")

    def make_file_manager(install_dir: str) -> FileManager:
        # All file managers share the same template directory and dry_run setting.
        return FileManager(
            install_dir=install_dir, template_dir=template_dir, dry_run=dry_run
        )

    fm = make_file_manager(output_dir)

    native_yaml_path = os.path.join(aten_path, "native/native_functions.yaml")
    tags_yaml_path = os.path.join(aten_path, "native/tags.yaml")
    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
    native_functions, backend_indices = (
        parsed_yaml.native_functions,
        parsed_yaml.backend_indices,
    )
    grouped_native_functions = get_grouped_native_functions(native_functions)

    def sort_native_function(f: NativeFunctionsGroup | NativeFunction) -> str:
        """
        We sort the native function because of the note in concat_map_codegen.
        TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly.
        """
        func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
        return str(func.name.name)

    grouped_native_functions = sorted(
        grouped_native_functions, key=sort_native_function
    )

    # From here on backend_indices refers to the mapping returned by
    # parse_backend_yaml, which replaces the one from parse_native_yaml.
    parsed_backend_yaml = parse_backend_yaml(
        source_yaml, grouped_native_functions, backend_indices
    )
    backend_key = parsed_backend_yaml.backend_key
    autograd_key = parsed_backend_yaml.autograd_key
    cpp_namespace = parsed_backend_yaml.cpp_namespace
    backend_indices = parsed_backend_yaml.backend_indices
    # the following 3 keys are all processed differently
    # for full_codegen, we generate IR, kernels, etc
    # for ir_gen, we generate only IR
    # non_native is used to register kernels not declared in
    # native_functions.yaml
    full_codegen, non_native, ir_gen = parse_native_functions_keys(
        source_yaml, grouped_native_functions
    )

    def concat_map_codegen(
        func: Callable[[NativeFunction], Sequence[str]],
        xs: Iterable[NativeFunctionsGroup | NativeFunction],
        ops_list: list[OperatorName] = full_codegen,
    ) -> Iterator[str]:
        """
        We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we
        only code-gen additional entries for the inplace variant for the native functions.
        """

        for x in xs:
            fs = list(x.functions()) if isinstance(x, NativeFunctionsGroup) else [x]
            for f in fs:
                # Only functions whose operator name is listed in ops_list are generated.
                if f.func.name in ops_list:
                    yield from func(f)

    selector = SelectiveBuilder.get_nop_selector()

    assert backend_key is not None
    class_name = backend_indices[backend_key].native_function_class_name()

    if impl_path is not None:
        error_on_missing_kernels(
            native_functions,
            backend_indices,
            backend_key,
            autograd_key,
            class_name,
            impl_path,
            full_codegen,
        )

    """ Validate Shape Inference Definitions
    Generated lazy native functions all perform shape inference, by first using a meta:: kernel
    if available for that op, and otherwise using a 'compute_shape_{op}' function instead. The generator
    knows the call signature for compute_shape_{op} because it matches the nativefunction (and meta::) signature,
    so it just has to check whether the op is structured and generate a call for one or the other. It's up to the dev
    to supply the missing compute_shape_{op} function, but the codegen at least warns you about this and provides
    the expected signature which can be copy-pasted into shape_inference.h.
    compute_shape_{op} functions are handwritten and should be replaced over time as ops get ported
    to structured kernels.
    See torch/csrc/lazy/core/shape_inference.cpp #READ THIS! for more information.
    """
    if shape_inference_hdr is not None:
        expected_shape_infr_decls = list(
            concat_map_codegen(
                dest.GenLazyShapeInferenceDefinition(
                    backend_indices[backend_key], tensor_class
                ),
                grouped_native_functions,
            )
        )

        validate_shape_inference_header(shape_inference_hdr, expected_shape_infr_decls)
    assert class_name is not None

    # Generate nativefunction declarations
    # Note, eager registrations is set to False for the lazy TS backend as another LTC backend
    # may want to register their own lazy kernels instead of registering the TS ones.
    # The registration will lazily happen when init_ts_backend is called.
    gen_dispatchkey_nativefunc_headers(
        fm,
        class_name,
        cpp_namespace,
        backend_indices,
        grouped_native_functions,
        backend_key,
        autograd_key,
        backend_name,
    )

    # Generate Dispatcher registrations which hook up the nativefunctions
    for dispatch_key in (
        [backend_key] if autograd_key is None else [backend_key, autograd_key]
    ):
        gen_dispatcher_registrations(
            fm,
            output_dir,
            class_name,
            backend_indices,
            grouped_native_functions,
            backend_key,
            dispatch_key,
            selector,
            build_in_tree=build_in_tree,
            per_operator_headers=per_operator_headers,
            backend_name=backend_name,
            eager_registration=False,
        )

    # Generate native function impls that build IR nodes
    ns_helper = NamespaceHelper(cpp_namespace)
    fm.write_with_template(
        f"{backend_key}NativeFunctions.cpp",
        "DispatchKeyNativeFunctions.cpp",
        lambda: {
            "includes": [
                f"#include <{path}>"
                for path in [
                    tensor_class_hdr,
                    shape_inference_hdr,
                    "ATen/Functions.h",
                    "ATen/native/TensorConversions.h",
                    "ATen/NativeFunctions.h",
                    "ATen/CompositeExplicitAutogradNonFunctionalFunctions.h",
                    "ATen/MetaFunctions.h",
                    "ATen/Operators.h",
                    "ATen/native/CPUFallback.h",
                    "torch/csrc/lazy/core/ir_builder.h",
                    "torch/csrc/lazy/core/lazy_graph_executor.h",
                    "torch/csrc/lazy/core/metrics.h",
                    "torch/csrc/lazy/core/shape.h",
                    f"{output_dir}/{backend_key}NativeFunctions.h",
                    f"{output_dir}/LazyIr.h",
                ]
                + (
                    ["torch/csrc/lazy/ts_backend/ts_eager_fallback.h"]
                    if gen_forced_fallback_code
                    else []
                )
            ],
            "helper_fns": get_ltc_helper_fns(),
            "native_functions_include": "",
            "namespace_prologue": ns_helper.prologue,
            "namespace_epilogue": ns_helper.epilogue,
            # Uses the default ops_list, i.e. only the full_codegen operators.
            "native_function_definitions": list(
                concat_map_codegen(
                    native_func_definition_generator(
                        f"{backend_key}NativeFunctions",
                        backend_indices[backend_key],
                        tensor_class,
                        gen_forced_fallback_code,
                        backend_namespace,
                        get_tensorlist,
                        get_tensor_or_wrap_number,
                        try_get_tensor,
                        metrics_counter,
                        create_tensor,
                        create_from_first_tensor,
                        create_aten_from_ltc_tensor,
                        tuple_aten_from_ltc_tensors,
                        lazy_tensor_ptr,
                        get_device_fn,
                    ),
                    grouped_native_functions,
                )
            ),
        },
    )
    # Generate IR node classes
    lazy_ir_obj = lazy_ir_generator(
        backend_indices[backend_key], backend_name, node_base, use_lazy_shape
    )

    fm.write_with_template(
        "LazyIr.h",
        "LazyIr.h",
        lambda: {
            "lazy_ir_sysinc": [
                f"#include <{path}>"
                for path in [
                    "ATen/core/Formatting.h",
                    "c10/core/ScalarType.h",
                    "torch/csrc/lazy/core/hash.h",
                    "torch/csrc/lazy/core/ir.h",
                    "torch/csrc/lazy/core/shape.h",
                    "optional",
                    "vector",
                ]
            ],
            "lazy_ir_inc": [f'#include "{node_base_hdr}"']
            if node_base_hdr is not None
            else [],
            # IR declarations cover both the full_codegen and the ir_gen operators.
            "ir_declarations": list(
                concat_map_codegen(
                    lazy_ir_obj, grouped_native_functions, full_codegen + ir_gen
                )
            ),
            "namespace_prologue": ns_helper.prologue,
            "namespace_epilogue": ns_helper.epilogue,
        },
    )

    # Generate Non Native IR Node classes
    fm.write_with_template(
        "LazyNonNativeIr.h",
        "LazyNonNativeIr.h",
        lambda: {
            "lazy_non_native_ir_inc": [
                f"#include <{path}>"
                for path in [
                    "torch/csrc/lazy/core/ir.h",
                    "torch/csrc/lazy/core/ir_builder.h",
                    "torch/csrc/lazy/core/internal_ops/ltc_ops.h",
                    "torch/csrc/lazy/core/shape_inference.h",
                ]
                + ([node_base_hdr] if node_base_hdr else [])
                if path
            ],
            "non_native_ir_nodes": dest.generate_non_native_lazy_ir_nodes(
                non_native, lazy_ir_obj
            ),
            "namespace_prologue": ns_helper.prologue,
            "namespace_epilogue": ns_helper.epilogue,
        },
    )
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()