generator.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814
  1. from __future__ import annotations
  2. import json
  3. import logging
  4. import math
  5. from typing import TYPE_CHECKING
  6. import torchgen.api.cpp as cpp
  7. from torchgen.context import native_function_manager
  8. from torchgen.model import (
  9. Argument,
  10. BackendIndex,
  11. BaseTy,
  12. BaseType,
  13. FunctionSchema,
  14. NativeFunctionsGroup,
  15. NativeFunctionsViewGroup,
  16. OptionalType,
  17. SelfArgument,
  18. TensorOptionsArguments,
  19. Type,
  20. )
  21. from torchgen.static_runtime import config
  22. if TYPE_CHECKING:
  23. from collections.abc import Sequence
  24. logger: logging.Logger = logging.getLogger()
  25. def has_alias(
  26. arguments: Sequence[Argument | SelfArgument | TensorOptionsArguments],
  27. ) -> bool:
  28. for arg in arguments:
  29. annotation = getattr(arg, "annotation", None)
  30. if not annotation:
  31. continue
  32. alias_set = getattr(annotation, "alias_set", ())
  33. if alias_set:
  34. return True
  35. return False
  36. BLOCKED_OPS = frozenset(
  37. (
  38. # non cpu ops
  39. "sparse_sampled_addmm",
  40. "hspmm",
  41. "linalg_svdvals",
  42. # sparse ops
  43. "sspaddmm",
  44. "coalesce",
  45. "_indices",
  46. "indices",
  47. "_values",
  48. "values",
  49. "crow_indices",
  50. "col_indices",
  51. # deprecated ops
  52. "floor_divide",
  53. "ger",
  54. # buggy ops
  55. "conj_physical", # P495807361
  56. "binary_cross_entropy", # P496394764
  57. "arccosh",
  58. # uncommon ops
  59. "cholesky",
  60. "lu_solve",
  61. "linalg_cholesky",
  62. "linalg_householder_product",
  63. "linalg_ldl_solve",
  64. "_compute_linear_combination",
  65. # training related ops
  66. "_make_dual",
  67. # cannot call directly
  68. "_fw_primal",
  69. # no documentation
  70. "_index_reduce",
  71. # TODO: these ones got added recently and need manual inspection
  72. "_new_zeros_with_same_feature_meta",
  73. "_conj_physical",
  74. "binary_cross_entropy_with_logits",
  75. "bincount",
  76. "conv_tbc",
  77. "copy",
  78. "_copy_from",
  79. "_copy_from_and_resize",
  80. "count_nonzero",
  81. "cudnn_affine_grid_generator",
  82. "cudnn_affine_grid_generator_backward",
  83. "cudnn_grid_sampler",
  84. "diag_embed",
  85. "embedding",
  86. "embedding_dense_backward",
  87. "_embedding_bag_dense_backward",
  88. "_embedding_bag_per_sample_weights_backward",
  89. "grid_sampler_2d",
  90. "_grid_sampler_2d_cpu_fallback",
  91. "grid_sampler_3d",
  92. "isnan",
  93. "mkldnn_linear",
  94. "median",
  95. "nanmedian",
  96. "_sparse_sparse_matmul",
  97. "batch_norm_backward_elemt",
  98. "_euclidean_dist",
  99. "pixel_shuffle",
  100. "pixel_unshuffle",
  101. "channel_shuffle",
  102. "_reshape_nested_backward",
  103. "relu",
  104. "prelu",
  105. "celu",
  106. "slice_scatter",
  107. "select_scatter",
  108. "diagonal_scatter",
  109. "sum",
  110. "_mkldnn_transpose",
  111. "_nested_tensor_from_mask",
  112. "_nested_from_padded",
  113. "_nested_tensor_size",
  114. "_nested_from_padded_and_nested_example",
  115. "_standard_gamma_grad",
  116. "_dirichlet_grad",
  117. "native_norm",
  118. "_sparse_softmax",
  119. "_sparse_softmax_backward_data",
  120. "_sparse_log_softmax",
  121. "_sparse_log_softmax_backward_data",
  122. "zero",
  123. "_sparse_addmm",
  124. "sparse_mask",
  125. "_sparse_mask_projection",
  126. "_to_dense",
  127. "_coalesce",
  128. "_coalesced",
  129. "copy_sparse_to_sparse",
  130. "to_sparse",
  131. "to_sparse_csr",
  132. "to_sparse_csc",
  133. "to_mkldnn",
  134. "quantize_per_tensor_dynamic",
  135. "quantize_per_channel",
  136. "q_per_channel_scales",
  137. "q_per_channel_zero_points",
  138. "int_repr",
  139. "_make_per_channel_quantized_tensor",
  140. "set",
  141. "lift",
  142. "lift_fresh",
  143. "lift_fresh_copy",
  144. "masked_scatter",
  145. "_masked_softmax",
  146. "_masked_softmax_backward",
  147. "put",
  148. "index_reduce",
  149. "trace",
  150. "_cholesky_solve_helper",
  151. "dist",
  152. "max",
  153. "_torch_cuda_cu_linker_symbol_op",
  154. "glu_jvp",
  155. "glu_backward_jvp",
  156. "hardswish_backward",
  157. "rrelu_with_noise_backward",
  158. "mkldnn_adaptive_avg_pool2d_backward",
  159. "_adaptive_avg_pool2d_backward",
  160. "_adaptive_avg_pool3d_backward",
  161. "isinf",
  162. "linalg_lu_solve",
  163. "linalg_vecdot",
  164. "linalg_matrix_exp",
  165. "linalg_eigvalsh",
  166. "_test_warn_in_autograd",
  167. "_test_autograd_multiple_dispatch_view",
  168. "_test_autograd_multiple_dispatch_view_copy",
  169. "_segment_reduce",
  170. "_segment_reduce_backward",
  171. "_fw_primal_copy",
  172. "_make_dual_copy",
  173. "view_as_real_copy",
  174. "view_as_complex_copy",
  175. "_conj_copy",
  176. "_neg_view_copy",
  177. "diagonal_copy",
  178. "detach_copy",
  179. "squeeze_copy",
  180. "t_copy",
  181. "unsqueeze_copy",
  182. "_indices_copy",
  183. "_values_copy",
  184. "indices_copy",
  185. "values_copy",
  186. "crow_indices_copy",
  187. "col_indices_copy",
  188. "ccol_indices",
  189. "ccol_indices_copy",
  190. "row_indices",
  191. "row_indices_copy",
  192. "unfold_copy",
  193. "alias_copy",
  194. "_triton_multi_head_attention",
  195. "special_airy_ai",
  196. "special_bessel_j0",
  197. "special_bessel_j1",
  198. "special_bessel_y0",
  199. "special_bessel_y1",
  200. "special_chebyshev_polynomial_t",
  201. "special_chebyshev_polynomial_u",
  202. "special_chebyshev_polynomial_v",
  203. "special_chebyshev_polynomial_w",
  204. "special_hermite_polynomial_h",
  205. "special_hermite_polynomial_he",
  206. "special_laguerre_polynomial_l",
  207. "special_legendre_polynomial_p",
  208. "special_modified_bessel_i0",
  209. "special_modified_bessel_i1",
  210. "special_modified_bessel_k0",
  211. "special_modified_bessel_k1",
  212. "special_scaled_modified_bessel_k0",
  213. "special_scaled_modified_bessel_k1",
  214. "special_shifted_chebyshev_polynomial_t",
  215. "special_shifted_chebyshev_polynomial_u",
  216. "special_shifted_chebyshev_polynomial_v",
  217. "special_shifted_chebyshev_polynomial_w",
  218. "special_spherical_bessel_j0",
  219. "_foobar",
  220. "_nested_tensor_strides",
  221. "_nested_tensor_storage_offsets",
  222. "_nested_get_values", # no CPU backend
  223. "_nested_get_values_copy", # no CPU backend
  224. "_nested_view_from_jagged", # testing needs to be patched
  225. "_nested_view_from_jagged_copy", # testing needs to be patched
  226. "_nested_view_from_buffer", # testing needs to be patched
  227. "_nested_view_from_buffer_copy", # testing needs to be patched
  228. "_int_mm", # testing needs to be patched
  229. "_to_sparse_csc", # testing needs to be patched
  230. "_to_sparse_csr", # testing needs to be patched
  231. "segment_reduce", # testing needs to be patched
  232. )
  233. )
  234. def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
  235. base_op_name = ""
  236. func = None
  237. if isinstance(g, NativeFunctionsViewGroup):
  238. base_op_name = g.view.root_name
  239. func = g.view.func
  240. else:
  241. base_op_name = g.out.func.name.name.base
  242. func = g.out.func
  243. if config.is_hand_written(g):
  244. logger.info("HAND WRITTEN: %s", base_op_name)
  245. return False
  246. if base_op_name in BLOCKED_OPS:
  247. logger.info("BLOCKED: %s", base_op_name)
  248. return False
  249. for arg in func.schema_order_arguments():
  250. maybe_method = ivalue_type_conversion_method(arg.type)
  251. if not maybe_method:
  252. # Type converting is unsupported yet.
  253. logger.info("NOT SUPPORTED TYPE CONVERTING: %s", func)
  254. return False
  255. if isinstance(g, NativeFunctionsViewGroup):
  256. # TODO: stop doing type tests by converting to C++ and then testing
  257. # the string, just test the dang thing directly
  258. if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type():
  259. # Returns a non-Tensor value.
  260. logger.info("NON-TENSOR RET TYPE: %s", str(func))
  261. return False
  262. return True
  263. # For out variant ops, we need to check the arguments of its functional func.
  264. for arg in g.functional.func.schema_order_arguments():
  265. maybe_method = ivalue_type_conversion_method(arg.type)
  266. if not maybe_method:
  267. # Type converting is unsupported yet.
  268. logger.info("NOT SUPPORTED TYPE CONVERTING: %s", g.functional.func)
  269. return False
  270. if not g.structured:
  271. # In case of unstructured op, we check if it has out variant implementation.
  272. # The out variant implementation satisfies the minimum requirement that it has the output tensor as the last
  273. # parameter.
  274. if (
  275. not hasattr(g, "out")
  276. or not str(func).endswith("Tensor(a!) out) -> Tensor(a!)")
  277. or not str(func.name).endswith(".out")
  278. ):
  279. return False
  280. # TODO: stop type testing by converting to C++
  281. if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type():
  282. logger.info("NON_TENSOR RET TYPE: %s", func)
  283. return False
  284. if has_alias(func.arguments.non_out):
  285. # This op may create an alias of inputs.
  286. logger.info("INPUTS ALIAS: %s", base_op_name)
  287. return False
  288. return True
  289. def ivalue_type_conversion_method(
  290. arg_type: BaseType | OptionalType | Type,
  291. ) -> tuple[bool, str] | None:
  292. """
  293. Return the method call expression of `c10::ivalue' to convert its contained value to
  294. the expected value of `arg_type` type. For example, for `arg_type` == BaseTy.Tensor,
  295. this function returns ".toTensor()", so that it can be appended to the ivalue's
  296. variable name to get the value of the expected type.
  297. """
  298. type_conversion_methods = {
  299. BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional<at::Tensor>()")),
  300. BaseTy.int: ((False, "toInt()"), (False, "toOptional<int64_t>()")),
  301. BaseTy.bool: ((False, "toBool()"), (False, "toOptional<bool>()")),
  302. BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional<at::Scalar>()")),
  303. BaseTy.ScalarType: (
  304. (False, "toScalarType()"),
  305. (False, "toOptional<at::ScalarType>()"),
  306. ),
  307. BaseTy.str: (
  308. (False, "toStringView()"),
  309. (False, "toOptional<c10::string_view>()"),
  310. (False, "toOptional<::std::string_view>()"),
  311. ),
  312. }
  313. base_ty_object = None
  314. if isinstance(arg_type, BaseType):
  315. base_ty_object = arg_type.name
  316. elif isinstance(arg_type, OptionalType):
  317. if not isinstance(arg_type.elem, BaseType):
  318. # ListType is currently unsupported.
  319. return None
  320. base_ty_object = arg_type.elem.name
  321. else:
  322. return None
  323. if base_ty_object not in type_conversion_methods:
  324. return None
  325. methods = type_conversion_methods[base_ty_object]
  326. if isinstance(arg_type, BaseType):
  327. return methods[0]
  328. return methods[1]
  329. should_use_int_tensor_ops_ = frozenset(
  330. (
  331. "bitwise_not",
  332. "bitwise_and",
  333. "bitwise_or",
  334. "bitwise_xor",
  335. "bitwise_left_shift",
  336. "bitwise_right_shift",
  337. "gcd",
  338. "lcm",
  339. "scatter",
  340. "gather",
  341. "_convert_indices_from_coo_to_csr",
  342. "_convert_indices_from_csr_to_coo",
  343. )
  344. )
  345. should_use_complex_tensor_ops_ = frozenset(("view_as_real", "imag", "_conj"))
  346. def should_use_int_tensor(op_name: str) -> bool:
  347. return op_name in should_use_int_tensor_ops_
  348. def should_use_complex_tensor(op_name: str) -> bool:
  349. return op_name in should_use_complex_tensor_ops_
  350. test_tensor_dim_ops_1_ = frozenset(
  351. (
  352. "addmv",
  353. "index_add",
  354. "_convert_indices_from_coo_to_csr",
  355. "_convert_indices_from_csr_to_coo",
  356. "nll_loss_backward",
  357. "dot",
  358. "vdot",
  359. "outer",
  360. "ger",
  361. )
  362. )
  363. test_tensor_dim_ops_2_ = frozenset(
  364. ("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation", "matrix_H", "t")
  365. )
  366. def test_tensor_dim(op_name: str) -> int:
  367. if op_name in test_tensor_dim_ops_1_:
  368. return 1
  369. if op_name in test_tensor_dim_ops_2_:
  370. return 2
  371. return 3
  372. test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}'
  373. test_tensor_shape_json: dict[str, str] = json.loads(test_tensor_shapes_string)
  374. def test_tensor_shape(op_name: str) -> str:
  375. if op_name in test_tensor_shape_json:
  376. return test_tensor_shape_json[op_name]
  377. else:
  378. return ""
  379. def test_value_expression(
  380. arg_type: BaseType | OptionalType | Type, index: int, op_name: str
  381. ) -> str:
  382. tensor_size_ex = test_tensor_shape(op_name)
  383. if tensor_size_ex == "":
  384. num_tensors = 16 if index == 0 else 64
  385. num_dim = test_tensor_dim(op_name)
  386. size_per_dim = math.ceil(num_tensors / float(num_dim))
  387. size_per_dim += size_per_dim % 2
  388. tensor_size_ex = "{{{}}}".format(",".join([f"{size_per_dim}"] * num_dim))
  389. if should_use_int_tensor(op_name):
  390. tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)"
  391. elif should_use_complex_tensor(op_name):
  392. tensor_expression = f"at::randn({tensor_size_ex}, at::kComplexFloat)"
  393. else:
  394. tensor_expression = f"at::rand({tensor_size_ex})"
  395. value_expressions = {
  396. BaseTy.Tensor: tensor_expression,
  397. BaseTy.int: "1",
  398. BaseTy.bool: "false",
  399. BaseTy.Scalar: "2",
  400. BaseTy.ScalarType: "at::ScalarType::Float",
  401. BaseTy.str: '"floor"',
  402. }
  403. base_ty_object = None
  404. if isinstance(arg_type, BaseType):
  405. base_ty_object = arg_type.name
  406. else:
  407. assert isinstance(arg_type, OptionalType) and isinstance(
  408. arg_type.elem, BaseType
  409. )
  410. base_ty_object = arg_type.elem.name
  411. assert base_ty_object in value_expressions, "not expected type"
  412. value_expression = value_expressions[base_ty_object]
  413. return value_expression
  414. def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str:
  415. assert not schema.is_out_fn()
  416. schema_name = schema.name.name.base
  417. arg_map = {}
  418. for arg in schema.schema_order_arguments():
  419. test_value_exp = test_value_expression(arg.type, index, schema_name)
  420. arg_map[arg.name] = test_value_exp
  421. config.override_test_values(arg_map, schema_name, index)
  422. arg_populations = []
  423. for arg_name, arg_value in arg_map.items():
  424. arg_populations.append(f"auto {arg_name}{index} = {arg_value}")
  425. return ";\n ".join(arg_populations) + ";"
  426. def generate_test_value_names(schema: FunctionSchema, index: int) -> str:
  427. assert not schema.is_out_fn()
  428. return ",".join(f"{arg.name}{index}" for arg in schema.schema_order_arguments())
  429. generate_test_ir_arguments_base_ty_to_type_str_ = {
  430. BaseTy.Tensor: "Tensor",
  431. BaseTy.int: "int",
  432. BaseTy.float: "float",
  433. BaseTy.str: "str",
  434. BaseTy.Scalar: "int",
  435. BaseTy.ScalarType: "int",
  436. BaseTy.bool: "bool",
  437. }
  438. def generate_test_ir_arguments(
  439. schema: FunctionSchema,
  440. ) -> list[tuple[str, str | None]]:
  441. def ir_argument(arg: Argument) -> tuple[str, str | None]:
  442. t = arg.type
  443. add_optional = False
  444. if isinstance(t, OptionalType):
  445. t = t.elem
  446. add_optional = True
  447. assert isinstance(t, BaseType)
  448. type_str = None
  449. if t.name in generate_test_ir_arguments_base_ty_to_type_str_:
  450. type_str = generate_test_ir_arguments_base_ty_to_type_str_[t.name]
  451. if type_str and add_optional:
  452. type_str = f"{type_str}?"
  453. return ("%" + arg.name, type_str)
  454. return [ir_argument(arg) for arg in schema.schema_order_arguments()]
  455. def generate_arg_extraction(schema: FunctionSchema) -> str:
  456. arg_populations = []
  457. for i, arg in enumerate(schema.schema_order_arguments()):
  458. maybe_method = ivalue_type_conversion_method(arg.type)
  459. assert maybe_method
  460. is_reference, type_conversion_method = maybe_method
  461. reference = "&" if is_reference else ""
  462. arg_populations.append(
  463. f"const auto{reference} {arg.name} = p_node->Input({i}).{type_conversion_method}"
  464. )
  465. return ";\n ".join(arg_populations) + ";"
  466. def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
  467. kernel = backend_index.get_kernel(g.functional)
  468. if g.structured or kernel is None:
  469. return cpp.name(g.functional.func)
  470. return kernel.kernel
  471. def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
  472. kernel = backend_index.get_kernel(g.out)
  473. if g.structured or kernel is None:
  474. return cpp.name(g.out.func)
  475. return kernel.kernel
  476. def generate_non_out_variant_call(
  477. g: NativeFunctionsGroup, backend_index: BackendIndex
  478. ) -> str:
  479. schema = g.functional.func
  480. assert not schema.is_out_fn()
  481. kernel_name = get_kernel_name(g, backend_index)
  482. arg_names = (arg.name for arg in schema.schema_order_arguments())
  483. namespace_name = "cpu" if g.structured else "native"
  484. return f"at::{namespace_name}::{kernel_name}({','.join(arg_names)})"
  485. def generate_call_to_view_ops(
  486. g: NativeFunctionsViewGroup, backend_index: BackendIndex
  487. ) -> str:
  488. schema = g.view.func
  489. kernel_name = cpp.name(schema)
  490. kernel = backend_index.get_kernel(g.view)
  491. if kernel:
  492. kernel_name = kernel.kernel
  493. arg_names = (arg.name for arg in schema.schema_order_arguments())
  494. namespace_name = "native"
  495. return f"at::{namespace_name}::{kernel_name}({','.join(arg_names)})"
  496. def generate_out_variant_call(
  497. g: NativeFunctionsGroup, backend_index: BackendIndex
  498. ) -> str:
  499. schema = g.out.func
  500. assert schema.is_out_fn()
  501. arg_names = []
  502. kernel_name = get_out_kernel_name(g, backend_index)
  503. if g.structured:
  504. # structured op starts with the output tensor argument.
  505. arg_names = [out_arg.name for out_arg in schema.arguments.out]
  506. else:
  507. arg_names = []
  508. for arg in schema.arguments.non_out:
  509. if isinstance(arg, SelfArgument):
  510. arg_names.append(arg.argument.name)
  511. else:
  512. assert isinstance(arg, Argument)
  513. arg_names.append(arg.name)
  514. if not g.structured:
  515. assert len(schema.arguments.out) == 1
  516. arg_names.append(schema.arguments.out[0].name)
  517. cpp_arg_names = ",".join(arg_names)
  518. namespace_name = "cpu" if g.structured else "native"
  519. return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})"
  520. no_memory_resize_ops = frozenset(
  521. (
  522. "isin.Scalar_Tensor",
  523. "index_add",
  524. "dot",
  525. "vdot",
  526. "nuclear_norm",
  527. "histc",
  528. "l1_loss",
  529. "multi_margin_loss",
  530. "multilabel_margin_loss",
  531. "nll_loss",
  532. "nll_loss2d",
  533. "prod",
  534. )
  535. )
  536. def should_check_resize(schema: FunctionSchema) -> bool:
  537. schema_str = str(schema)
  538. type_variant_op_name = schema_str[: schema_str.find("(")]
  539. return type_variant_op_name not in no_memory_resize_ops
  540. def op_name_from_group(g: NativeFunctionsGroup) -> str:
  541. return g.functional.func.name.name.base
  542. class GenOpDispatcher:
  543. def out_variant(
  544. self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex
  545. ) -> str:
  546. if not groups:
  547. return ""
  548. generated_type_variants = []
  549. for g in groups:
  550. with native_function_manager(g):
  551. assert is_supported(g)
  552. assert isinstance(g, NativeFunctionsGroup)
  553. generated_type_variant = self.out_variant_op_generator(g, backend_index)
  554. generated_type_variants.append(generated_type_variant)
  555. op_name = op_name_from_group(groups[0])
  556. body = "\n".join(generated_type_variants)
  557. generated = f"""
  558. REGISTER_OPERATOR_FUNCTOR(
  559. aten::{op_name},
  560. aten_{op_name},
  561. [](Node* n) -> SROperator {{
  562. {body}
  563. LogAndDumpSchema(n);
  564. return nullptr;
  565. }})
  566. """
  567. return generated
  568. def view(
  569. self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex
  570. ) -> str:
  571. if not groups:
  572. return ""
  573. generated_type_variants = []
  574. for g in groups:
  575. with native_function_manager(g):
  576. assert is_supported(g)
  577. assert isinstance(g, NativeFunctionsViewGroup)
  578. generated_type_variant = self.view_op_generator(g, backend_index)
  579. generated_type_variants.append(generated_type_variant)
  580. op_name = config.func_name_base_str(groups[0])
  581. body = "\n".join(generated_type_variants)
  582. generated = f"""
  583. REGISTER_NATIVE_OPERATOR_FUNCTOR(
  584. aten::{op_name},
  585. aten_{op_name},
  586. [](Node* n) -> SROperator {{
  587. {body}
  588. LogAndDumpSchema(n);
  589. return nullptr;
  590. }});
  591. """
  592. return generated
  593. def out_variant_op_generator(
  594. self, g: NativeFunctionsGroup, backend_index: BackendIndex
  595. ) -> str:
  596. functional = g.functional
  597. schema = str(functional.func)
  598. populated_argument = generate_arg_extraction(g.functional.func)
  599. functional_variant_call = generate_non_out_variant_call(g, backend_index)
  600. assert len(g.out.func.arguments.out) == 1
  601. out_variable_name = str(g.out.func.arguments.out[0].name)
  602. out_variant_call = generate_out_variant_call(g, backend_index)
  603. generated = f"""
  604. if (n->matches(torch::schema("aten::{schema}"))) {{
  605. return [](ProcessedNode* p_node) {{
  606. {populated_argument}
  607. if (p_node->Output(0).isNone()) {{
  608. p_node->Output(0) = {functional_variant_call};
  609. return;
  610. }}
  611. auto& {out_variable_name} = p_node->Output(0).toTensor();
  612. fastResizeToZero({out_variable_name});
  613. {out_variant_call};
  614. }};
  615. }}"""
  616. return generated
  617. def view_op_generator(
  618. self, g: NativeFunctionsViewGroup, backend_index: BackendIndex
  619. ) -> str:
  620. schema = str(g.view.func)
  621. populated_argument = generate_arg_extraction(g.view.func)
  622. functional_variant_call = generate_call_to_view_ops(g, backend_index)
  623. generated = f"""
  624. if (n->matches(torch::schema("aten::{schema}"))) {{
  625. return [](ProcessedNode* p_node) {{
  626. {populated_argument}
  627. p_node->Output(0) = {functional_variant_call};
  628. }};
  629. }}"""
  630. return generated
  631. class GenOpTestCase:
  632. def out_variant(self, groups: Sequence[NativeFunctionsGroup]) -> str:
  633. if not groups:
  634. return ""
  635. generated_type_variants = []
  636. for g in groups:
  637. with native_function_manager(g):
  638. assert is_supported(g)
  639. assert isinstance(g, NativeFunctionsGroup)
  640. generated_type_variant = self.out_variant_op_test_case_generator(g)
  641. generated_type_variants.append(generated_type_variant)
  642. return "\n".join(generated_type_variants)
  643. def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str:
  644. if not groups:
  645. return ""
  646. generated_type_variants = []
  647. for g in groups:
  648. with native_function_manager(g):
  649. assert is_supported(g)
  650. assert isinstance(g, NativeFunctionsViewGroup)
  651. generated_type_variant = self.view_op_test_case_generator(g)
  652. generated_type_variants.append(generated_type_variant)
  653. return "\n".join(generated_type_variants)
  654. def out_variant_op_test_case_generator(self, g: NativeFunctionsGroup) -> str:
  655. schema = g.functional.func
  656. schema_str = str(schema)
  657. assert schema_str.find("(") > 0
  658. type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
  659. op_name = op_name_from_group(g)
  660. assert type_variant_op_name.startswith(op_name)
  661. arg_types = generate_test_ir_arguments(schema)
  662. arg_declarations = ", ".join(
  663. (
  664. arg_name if arg_type is None else f"{arg_name}: {arg_type}"
  665. for arg_name, arg_type in arg_types
  666. )
  667. )
  668. arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
  669. assert (
  670. len(schema.returns) == 1
  671. and isinstance(schema.returns[0].type, BaseType)
  672. and schema.returns[0].type.name is BaseTy.Tensor
  673. )
  674. test_value_definitions = generate_test_value_definitions(schema, 0)
  675. test_value_names = generate_test_value_names(schema, 0)
  676. test_value_definitions2 = generate_test_value_definitions(schema, 1)
  677. test_value_names2 = generate_test_value_names(schema, 1)
  678. check_resize = "true" if should_check_resize(schema) else "false"
  679. generated = f"""
  680. TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
  681. const std::string script = R"IR(
  682. graph({arg_declarations}):
  683. %bias: None = prim::Constant()
  684. %ret = aten::{op_name}({arg_names})
  685. %cloned = aten::clone(%ret, %bias)
  686. return (%cloned)
  687. )IR";
  688. {test_value_definitions}
  689. std::vector<IValue> args{{{test_value_names}}};
  690. testStaticRuntime(script, args, {{}}, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});
  691. {test_value_definitions2}
  692. std::vector<IValue> args2{{{test_value_names2}}};
  693. testStaticRuntime(script, args, args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});
  694. }}
  695. """
  696. return generated
  697. def view_op_test_case_generator(self, g: NativeFunctionsViewGroup) -> str:
  698. schema = g.view.func
  699. schema_str = str(schema)
  700. assert schema_str.find("(") > 0
  701. type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
  702. op_name = g.view.root_name
  703. assert type_variant_op_name.startswith(op_name)
  704. arg_types = generate_test_ir_arguments(schema)
  705. arg_declarations = ", ".join(
  706. (
  707. arg_name if arg_type is None else f"{arg_name}: {arg_type}"
  708. for arg_name, arg_type in arg_types
  709. )
  710. )
  711. arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
  712. assert (
  713. len(schema.returns) == 1
  714. and isinstance(schema.returns[0].type, BaseType)
  715. and schema.returns[0].type.name is BaseTy.Tensor
  716. )
  717. test_value_definitions = generate_test_value_definitions(schema, 0)
  718. test_value_names = generate_test_value_names(schema, 0)
  719. generated = f"""
  720. TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
  721. const std::string script = R"IR(
  722. graph({arg_declarations}):
  723. %bias: None = prim::Constant()
  724. %ret = aten::{op_name}({arg_names})
  725. %cloned = aten::clone(%ret, %bias)
  726. return (%cloned)
  727. )IR";
  728. {test_value_definitions}
  729. std::vector<IValue> args{{{test_value_names}}};
  730. testStaticRuntime(script, args);
  731. }}
  732. """
  733. return generated