python.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from typing import TYPE_CHECKING
  4. from torchgen.api import cpp
  5. from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
  6. from torchgen.gen import pythonify_default
  7. from torchgen.model import (
  8. Argument,
  9. BaseTy,
  10. BaseType,
  11. FunctionSchema,
  12. ListType,
  13. NativeFunction,
  14. OptionalType,
  15. Return,
  16. Type,
  17. Variant,
  18. )
  19. if TYPE_CHECKING:
  20. from collections.abc import Iterable, Sequence
  21. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  22. #
  23. # Data Models
  24. #
  25. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  26. #
  27. # [Notes] python binding codegen
  28. #
  29. # The Python binding codegen produces code that takes the input list of
  30. # PyObjects, finds the matching ATen C++ function using PythonArgParser,
  31. # converts the PyObjects into C++ types and calls the ATen C++ function:
  32. #
  33. #    +--------+  parsing   +------------------------+  binding   +-----------------------+
  34. #    | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch |
  35. #    +--------+            +------------------------+            +-----------------------+
  36. #
  37. # The following examples demonstrate the data models the Python binding
  38. # codegen needs to deal with and the tasks it needs to accomplish. It
  39. # helps explain the purpose of the new data types introduced below.
  40. #
  41. # - Function Schema (source of truth)
  42. #
  43. # aten::empty.names(int[] size, *, Dimname[]? names,
  44. # ScalarType? dtype=None, Layout? layout=None,
  45. # Device? device=None, bool? pin_memory=None,
  46. # MemoryFormat? memory_format=None) -> Tensor
  47. #
  48. # - Python Signature
  49. #
  50. # It's used to generate the input schema string for PythonArgParser.
  51. # Note: TensorOptions fields are reordered and the additional
  52. # 'requires_grad' field is added:
  53. #
  54. # empty(IntArrayRef size, *, DimnameList? names,
  55. # MemoryFormat? memory_format=None, ScalarType dtype=None,
  56. # Layout layout=torch.strided, Device device=None,
  57. # bool pin_memory=False, bool requires_grad=False)
  58. #
  59. # - C++ Signature
  60. #
  61. # It's used to generate C++ lambda formals & dispatch call.
  62. # Note: the scattered TensorOptions fields are packed into 'options'.
  63. #
  64. # auto dispatch_empty =
  65. # [](IntArrayRef size, std::optional<DimnameList> names,
  66. # const TensorOptions & options,
  67. # std::optional<MemoryFormat> memory_format) -> Tensor {
  68. # pybind11::gil_scoped_release no_gil;
  69. # return torch::empty(size, names, options, memory_format);
  70. # };
  71. #
  72. # - Binding between Python Arguments and C++ Arguments
  73. #
  74. # Given a set of Python Arguments in scope, we need to produce the
  75. # binding expressions that translate the Python API into C++ API:
  76. #
  77. #    Python Args                  Cpp Args         Binding Exprs
  78. #    -----------------------------------------------------------------
  79. #    0: size                      size             '_r.intlist(0)'
  80. #    1: names                     names            'names' [special init]
  81. #    2: memory_format -------+
  82. #    3: dtype         -----+-|--> options          'options' [special packing]
  83. #    4: layout            /  |
  84. #    5: device           /   +--> memory_format    '_r.memoryformatOptional(2)'
  85. #    6: pin_memory      /
  86. #    7: requires_grad -+
  87. #
  88. # So the full dispatch expression would look like:
  89. #
  90. # dispatch_empty(_r.intlist(0), names, options,
  91. # _r.memoryformatOptional(2))
  92. #
  93. # Where does 'names' come from? It involves special local init:
  94. #
  95. # auto __names = _r.toDimnameListOptional(1);
  96. # std::optional<DimnameList> names =
  97. # __names ? std::make_optional(DimnameList(__names.value()))
  98. # : std::nullopt;
  99. #
  100. # Where does 'options' come from? It involves special local init
  101. # for TensorOptions. Note that Python side has the additional
  102. # 'requires_grad' field:
  103. #
  104. # const auto options = TensorOptions()
  105. # .dtype(_r.scalartype(3))
  106. # .device(_r.device(5))
  107. # .layout(_r.layoutOptional(4))
  108. # .requires_grad(_r.toBool(7))
  109. # .pinned_memory(_r.toBool(6));
  110. #
  111. # In some other cases one Python Argument can map to multiple C++
  112. # Arguments. For example:
  113. #
  114. # aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False)
  115. # -> (Tensor values, Tensor indices)
  116. #
  117. #    Python Args                  Cpp Args           Binding Exprs
  118. #    ---------------------------------------------------------------------
  119. #                          +----> max                'out[0]'
  120. #                         /-----> max_values         'out[1]'
  121. #    0: input            /        self               '_r.tensor(0)'
  122. #    1: dim             /         dim                '_r.dimname(1)'
  123. #    2: keepdim        /          keepdim            '_r.toBool(2)'
  124. #    3: out      -----+           [local init] out   '_r.tensorlist_n<2>(3)'
  125. #
  126. # As demonstrated above, the binding can involve reordering,
  127. # packing, unpacking and special local inits.
  128. #
  129. #
  130. # Let's look at a concrete example:
  131. #
  132. #    static PythonArgParser parser({
  133. #      "abs(Tensor input, *, Tensor out=None)",
  134. #      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  135. #       ^
  136. #       +--- Python Schema, represented by PythonSignature and PythonArgument
  137. #
  138. #    }, /*traceable=*/true);
  139. #
  140. #    ParsedArgs<2> parsed_args;
  141. #    auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
  142. #
  143. #    ...
  144. #
  145. #    if (_r.isNone(1)) {
  146. #        ~~~~~~~~~~~~  <--- Scattered PythonArgParser output (arg name = 'out')
  147. #                           represented by PythonArgParserOutputExpr
  148. #
  149. #      // aten::abs(Tensor self) -> Tensor
  150. #      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  151. #       ^
  152. #       +--- NativeFunction schema, base version
  153. #
  154. #      auto dispatch_abs = [](const Tensor & self) -> Tensor {
  155. #                          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  156. #                           ^
  157. #                           +--- dispatch_lambda_args / dispatch_lambda_return_str
  158. #                                generated from NativeFunction / CppSignature
  159. #                                (deprecated PythonSignature is special)
  160. #                                arguments are represented by DispatchLambdaArgument
  161. #
  162. #        pybind11::gil_scoped_release no_gil;
  163. #        return self.abs();
  164. #               ~~~~~~~~~~~  <--- cpp_dispatch_target / cpp_dispatch_exprs
  165. #                                 generated from NativeFunction / CppSignature
  166. #      };
  167. #      return wrap(dispatch_abs(_r.tensor(0)));
  168. #                               ~~~~~~~~~~~~~
  169. #                                ^
  170. #                                +--- dispatch_lambda_exprs
  171. #                                     binding PythonArgParserOutputExpr (python args)
  172. #                                     and DispatchLambdaArgument (c++ args)
  173. #
  174. #    } else {
  175. #      // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  176. #      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  177. #       ^
  178. #       +--- NativeFunction schema, out-variant
  179. #
  180. #      auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor {
  181. #        pybind11::gil_scoped_release no_gil;
  182. #        return at::abs_out(out, self);
  183. #      };
  184. #      return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0)));
  185. #    }
  186. #
  187. #
  188. # [Notes] python interface codegen
  189. # The python dataclasses below are used to generate both python binding code
  190. # and pyi type hint signatures.
  191. # In theory these two should look very similar, but there are a number of differences
  192. # in how pyi signatures vs. python_arg_parser signatures are generated.
  193. # These differences have been encapsulated in signature_str() vs. signature_str_pyi()
  194. # to display the full signatures, and argument_str() vs. argument_str_pyi() to display arguments.
  195. # For example, only pyi signatures include return types.
  196. def format_function_signature(
  197. name: str, arguments: Iterable[str] = (), return_type: str | None = None
  198. ) -> str:
  199. if not isinstance(arguments, (list, tuple)):
  200. arguments = tuple(arguments)
  201. return_type = f" -> {return_type}" if return_type is not None else ""
  202. sig = f"def {name}({', '.join(arguments)}){return_type}: ..."
  203. if len(sig) <= 80 or len(arguments) == 0 or tuple(arguments) == ("self",):
  204. return sig
  205. lines = [
  206. f"def {name}(",
  207. *(f" {arg}," for arg in arguments),
  208. f"){return_type}: ...",
  209. ]
  210. sig = "\n".join(lines)
  211. if all(len(line) <= 80 for line in lines):
  212. return sig
  213. # ruff format bug for compound statements: https://github.com/astral-sh/ruff/issues/18658
  214. # use `skip` instead of `on` + `off`
  215. return sig.removesuffix(" ...") + " # fmt: skip\n ..."
  216. @dataclass(frozen=True)
  217. class PythonReturns:
  218. returns: tuple[Return, ...]
  219. @dataclass(frozen=True)
  220. class PythonArgument:
  221. name: str
  222. type: Type
  223. default: str | None
  224. # Used to generate the default init expr for some PythonArgParser outputs, e.g.:
  225. #
  226. # _r.layoutWithDefault(3, layout_from_backend(self.options().backend())))
  227. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  228. # ^
  229. # +--- default_init str
  230. default_init: str | None
  231. # Compute argument formal for python argument parsing.
  232. # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
  233. def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
  234. type_str = (
  235. argument_type_str(self.type, symint=symint)
  236. .replace("const ", "")
  237. .replace(" &", "")
  238. )
  239. name = self.name
  240. # s/self/input/ outside method bindings
  241. # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
  242. # for the parse string
  243. if name == "self" and type_str in ["Tensor", "Number"] and not method:
  244. name = "input"
  245. # add default
  246. if self.default is not None:
  247. default = {
  248. "nullptr": "None",
  249. "::std::nullopt": "None",
  250. "std::nullopt": "None",
  251. "{}": "None",
  252. }.get(self.default, self.default)
  253. return f"{type_str} {name}={default}"
  254. else:
  255. return f"{type_str} {name}"
  256. def argument_str_pyi(
  257. self, *, method: bool = False, deprecated: bool = False
  258. ) -> str:
  259. type_str = argument_type_str_pyi(self.type)
  260. name = self.name
  261. # s/self/input/ outside method bindings
  262. # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
  263. # for the parse string
  264. if name == "self" and type_str == "Tensor" and not method and not deprecated:
  265. name = "input"
  266. if name == "from": # from is a Python keyword...
  267. name += "_"
  268. # pyi merges the _out and functional variants into the same signature, with an optional out arg
  269. if name == "out" and type_str == "Tensor" and not deprecated:
  270. type_str = f"{type_str} | None".replace(" | None | None", " | None")
  271. # pyi deprecated signatures don't get defaults for their out arg
  272. treat_as_no_default = (
  273. deprecated
  274. and isinstance(self, PythonOutArgument)
  275. and self.default == "None"
  276. )
  277. # add default
  278. if self.default is not None and not treat_as_no_default:
  279. if (
  280. isinstance(self.type, ListType)
  281. and self.type.elem == BaseType(BaseTy.int)
  282. and self.default.startswith("{")
  283. and self.default.endswith("}")
  284. ):
  285. default = (
  286. "(" + ", ".join(map(str.strip, self.default[1:-1].split(","))) + ")"
  287. )
  288. else:
  289. default = {
  290. "nullptr": "None",
  291. "::std::nullopt": "None",
  292. "std::nullopt": "None",
  293. "{}": "None",
  294. "c10::MemoryFormat::Contiguous": "contiguous_format",
  295. "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
  296. }.get(self.default, self.default)
  297. return f"{name}: {type_str} = {default}"
  298. else:
  299. return f"{name}: {type_str}"
  300. @dataclass(frozen=True)
  301. class PythonOutArgument(PythonArgument):
  302. # In Python signature multiple output fields are packed into one 'out' argument.
  303. # When binding to C++, it's first binded to a local 'out' variable:
  304. # 'auto out = _r.tensorlist_n<2>(2);',
  305. # then binded to scattered C++ output arguments as 'out[0]', 'out[1]', and etc.
  306. # TODO: maybe don't need keep scattered out fields for python signature?
  307. outputs: tuple[PythonArgument, ...]
  308. @staticmethod
  309. def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
  310. if not outputs:
  311. return None
  312. size = len(outputs)
  313. if size == 1:
  314. return PythonOutArgument(
  315. name=outputs[0].name,
  316. type=outputs[0].type,
  317. default="None",
  318. default_init=None,
  319. outputs=outputs,
  320. )
  321. elif size > 1:
  322. if any(not a.type.is_tensor_like() for a in outputs):
  323. raise RuntimeError(f"Unsupported output type: {outputs}")
  324. return PythonOutArgument(
  325. name="out",
  326. # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None?
  327. type=ListType(BaseType(BaseTy.Tensor), size),
  328. default="None",
  329. default_init=None,
  330. outputs=outputs,
  331. )
  332. raise AssertionError(r"Unexpected PythonOutArgument size")
  333. @dataclass(frozen=True)
  334. class PythonSignature:
  335. # Base operator name, without inplace/outplace suffix.
  336. name: str
  337. # Positional arguments.
  338. # TODO: create a dedicated SelfArgument type for 'self'?
  339. input_args: tuple[PythonArgument, ...]
  340. # Keyword arguments excluding the 'out' argument and scattered kwargs belonging
  341. # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
  342. input_kwargs: tuple[PythonArgument, ...]
  343. output_args: PythonOutArgument | None
  344. # Return types, which are only used by pyi
  345. returns: PythonReturns
  346. # These are scattered kwargs arguments belonging to TensorOptions.
  347. # When binding to C++, they are packed into a TensorOptions object 'options'.
  348. # It's possible that the C++ signature doesn't take TensorOptions object (e.g.
  349. # for out variant), in which case they will be used as scattered fields without
  350. # being packed into 'options'.
  351. # TODO: maybe create a PythonTensorOptionsArgument?
  352. tensor_options_args: tuple[PythonArgument, ...]
  353. # method or function signature?
  354. method: bool
  355. @property
  356. def deprecated(self) -> bool:
  357. return False
  358. def arguments(
  359. self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
  360. ) -> tuple[PythonArgument | PythonOutArgument, ...]:
  361. result: list[PythonArgument | PythonOutArgument] = []
  362. result.extend(self.input_args)
  363. result.extend(self.input_kwargs)
  364. if self.output_args is not None and not skip_outputs:
  365. result.append(self.output_args)
  366. if not skip_tensor_options:
  367. result.extend(self.tensor_options_args)
  368. return tuple(result)
  369. def arguments_count(self) -> int:
  370. return len(self.arguments())
  371. def output_idx(self) -> int:
  372. return len(self.input_args) + len(self.input_kwargs)
  373. # [old codegen] Compute the Python function signature for argument parsing,
  374. # as specified in torch/csrc/utils/python_arg_parser.h. WARNING:
  375. # this is NOT the same type signature as specified by PEP 484
  376. # as understood by mypy; our format was independently developed
  377. # and has some quirks to make it more suitable specifically
  378. # for error parsing.
  379. #
  380. # For a translation to mypy-valid type signatures, see
  381. # signature_str_pyi().
  382. def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
  383. args = self.arguments(skip_outputs=skip_outputs)
  384. schema_formals: list[str] = [
  385. a.argument_str(method=self.method, symint=symint) for a in args
  386. ]
  387. positional_argc = len(self.input_args)
  388. if len(schema_formals) > positional_argc:
  389. schema_formals.insert(positional_argc, "*")
  390. return f"{self.name}({', '.join(schema_formals)})"
  391. def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
  392. args = self.arguments(skip_outputs=skip_outputs)
  393. schema_formals: list[str] = [
  394. a.argument_str_pyi(method=self.method) for a in args
  395. ]
  396. positional_argc = len(self.input_args)
  397. if len(schema_formals) > positional_argc:
  398. schema_formals.insert(positional_argc, "*")
  399. # only pyi signatures include returns
  400. returns_str = returns_str_pyi(self)
  401. # pyi also includes self (with no typing/defaults) for methods
  402. if self.method:
  403. schema_formals.insert(0, "self")
  404. return format_function_signature(self.name, schema_formals, returns_str)
  405. def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
  406. # only pyi uses vararg signatures
  407. args = self.arguments(skip_outputs=skip_outputs)
  408. schema_formals: list[str] = [
  409. a.argument_str_pyi(method=self.method) for a in args
  410. ]
  411. # vararg only applies to pyi signatures. vararg variants are not generated for all signatures
  412. num_args = self.arguments_count()
  413. if num_args == 0:
  414. return None
  415. num_positionalargs = len(self.input_args)
  416. vararg_type = args[0].type
  417. if not (
  418. isinstance(vararg_type, ListType)
  419. and str(vararg_type.elem) in ["int", "SymInt"]
  420. and num_positionalargs == 1
  421. ):
  422. return None
  423. # Below are the major changes in vararg vs. regular pyi signatures
  424. # vararg signatures also omit the asterix
  425. assert isinstance(vararg_type, ListType)
  426. schema_formals[0] = (
  427. "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem)
  428. )
  429. returns_str = returns_str_pyi(self)
  430. # pyi also includes self (with no typing/defaults) for methods
  431. if self.method:
  432. schema_formals.insert(0, "self")
  433. return format_function_signature(self.name, schema_formals, returns_str)
  434. # The deprecated python signature involves some special logic, so create a
  435. # dedicated data model to store these extra properties.
  436. @dataclass(frozen=True)
  437. class PythonSignatureDeprecated(PythonSignature):
  438. # Schema for the deprecated function
  439. deprecated_schema: FunctionSchema
  440. # The deprecated signature might miss some arguments that the corresponding
  441. # C++ signature expects. We need store the constant default values to pass in.
  442. # For example:
  443. # [deprecate signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)
  444. # [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
  445. # [func call]: self.addmm(mat1, mat2, beta, 1)
  446. # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
  447. deprecated_args_exprs: tuple[str, ...]
  448. @property
  449. def deprecated(self) -> bool:
  450. return True
  451. def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
  452. return (
  453. PythonSignature.signature_str(
  454. self, skip_outputs=skip_outputs, symint=symint
  455. )
  456. + "|deprecated"
  457. )
  458. def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
  459. args = self.arguments(skip_outputs=skip_outputs)
  460. schema_formals: list[str] = [
  461. a.argument_str_pyi(method=self.method, deprecated=True) for a in args
  462. ]
  463. positional_argc = len(self.input_args)
  464. if len(schema_formals) > positional_argc:
  465. schema_formals.insert(positional_argc, "*")
  466. returns_str = returns_str_pyi(self)
  467. return format_function_signature(self.name, schema_formals, returns_str)
  468. def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
  469. # the codegen doesn't include vararg variants for deprecated signatures
  470. return None
  471. # This struct is used to hold the PythonSignature and its corresponding
  472. # NativeFunction BEFORE grouping base and out-variant functions.
  473. # Why not store NativeFunction in PythonSignature or construct PythonSignature
  474. # from NativeFunction? Because they are not 1-1 mapped.
  475. # One native function could have both deprecated and non-deprecated python
  476. # signatures - NativeFunction doesn't contain information to construct the
  477. # deprecated python signature.
  478. # One python signature is used to handle both the base and the out-variant
  479. # function - see 'PythonSignatureGroup'.
  480. @dataclass(frozen=True)
  481. class PythonSignatureNativeFunctionPair:
  482. signature: PythonSignature
  483. function: NativeFunction
  484. # We merge pairs of functions with signatures that are equivalent mod
  485. # output arguments, and use a single entry in the python_arg_parser sig
  486. # list for both (output arguments become optional).
  487. @dataclass(frozen=True)
  488. class PythonSignatureGroup:
  489. # The signature used for Python argument parsing. The outplace signature
  490. # is preferred if exists, because it can be used to parse inputs for both
  491. # the out-place variant and the base version (with output omitted).
  492. signature: PythonSignature
  493. # The regular ATen declaration (e.g. conv2d)
  494. base: NativeFunction
  495. # The out variant (e.g. conv2d_out)
  496. outplace: NativeFunction | None
  497. @classmethod
  498. def from_pairs(
  499. cls,
  500. functional: PythonSignatureNativeFunctionPair,
  501. out: PythonSignatureNativeFunctionPair | None,
  502. ) -> PythonSignatureGroup:
  503. if out is None:
  504. return PythonSignatureGroup(
  505. signature=functional.signature,
  506. base=functional.function,
  507. outplace=None,
  508. )
  509. # prefer the signature with optional out=... arguments because it's the
  510. # superset that can be used to parse input for both base and outplace.
  511. signature_kwargs = out.signature.__dict__.copy()
  512. # Out overloads in C++ don't have TensorOptions arguments,
  513. # so take these from the functional variant
  514. signature_kwargs["tensor_options_args"] = (
  515. functional.signature.tensor_options_args
  516. )
  517. return PythonSignatureGroup(
  518. signature=type(out.signature)(**signature_kwargs),
  519. base=functional.function,
  520. outplace=out.function,
  521. )
  522. # C++ function dispatch is wrapped in a lambda function. The lambda function
  523. # has almost the same signature as the C++ function, only with some small
  524. # variants - see details below.
  525. # This data model is used to represent arguments of the lambda function
  526. # signature.
  527. @dataclass(frozen=True)
  528. class DispatchLambdaArgument:
  529. name: str
  530. type_str: str
  531. is_out_arg: bool
  532. # To pass PyObjects arguments to C++ function (via the lambda wrapper),
  533. # we need first convert PyObjects into simple C++ objects. This work
  534. # is done by PythonArgParser.
  535. # This data model is used to represent the output of PythonArgParser.
  536. # It has 1-1 mapping with PythonArgument in PythonSignature.
  537. @dataclass(frozen=True)
  538. class PythonArgParserOutputExpr:
  539. # argument name
  540. name: str
  541. # RHS expression to reference PythonArgParser output.
  542. expr: str
  543. # In some special cases we need create different expr, e.g.:
  544. # '_r.isNone(1)' instead of '_r.tensor(1)'.
  545. index: int
  546. # The python argument it maps to.
  547. argument: PythonArgument
  548. @property
  549. def is_none_expr(self) -> str:
  550. return f"_r.isNone({self.index})"
  551. # To pass PythonArgParser output to the lambda wrapper, we need bind
  552. # PythonArgParserOutputExpr to DispatchLambdaArgument.
  553. # They are not always 1-1 mapped, e.g. scattered TensorOptions fields
  554. # need be packed into a TensorOptions object, which is the argument
  555. # that the lambda function wrapper takes.
  556. @dataclass(frozen=True)
  557. class DispatchLambdaArgumentExprs:
  558. # The exprs that provide the binding for lambda arguments, e.g.:
  559. #
  560. # 'self' -> '_r.tensor(0)'
  561. # 'min' -> 'out[0]' / 'min_indices' -> 'out[1]'
  562. # 'options' -> 'options'
  563. #
  564. # It has 1-1 mapping with DispatchLambdaArgument.
  565. exprs: Sequence[str]
  566. # Special local inits, which might introduce new variables that
  567. # the 'exprs' above reference, e.g.:
  568. #
  569. # 'auto out = _r.tensorlist_n<2>(2);'
  570. #
  571. inits: Sequence[str]
  572. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  573. #
  574. # Helper Functions
  575. #
  576. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  577. def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
  578. return CppSignatureGroup.from_native_function(f, method=method).signature
  579. def has_tensor_options(f: NativeFunction) -> bool:
  580. return f.func.arguments.tensor_options is not None
  581. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  582. #
  583. # Python Signature
  584. #
  585. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  586. # 'simple_type' was introduced by the old codegen, which is slightly
  587. # different from the python schema type, e.g.: doesn't have '?' suffix
  588. # for optional Tensor/TensorList; doesn't have '[size]' suffix for list type.
  589. def argument_type_str(
  590. t: Type, *, simple_type: bool = False, symint: bool = True
  591. ) -> str:
  592. if isinstance(t, BaseType):
  593. if t.name == BaseTy.int:
  594. return "int64_t"
  595. elif t.name == BaseTy.float:
  596. return "double"
  597. elif t.name == BaseTy.str:
  598. return "c10::string_view"
  599. elif t.name in [
  600. BaseTy.Tensor,
  601. BaseTy.bool,
  602. BaseTy.QScheme,
  603. BaseTy.Scalar,
  604. BaseTy.ScalarType,
  605. BaseTy.Generator,
  606. BaseTy.Storage,
  607. BaseTy.Layout,
  608. BaseTy.Device,
  609. BaseTy.DeviceIndex,
  610. BaseTy.MemoryFormat,
  611. BaseTy.Dimname,
  612. BaseTy.Stream,
  613. BaseTy.SymInt,
  614. ]:
  615. # These python schema type names line up with their function schema names
  616. return t.name.name
  617. elif isinstance(t, OptionalType):
  618. elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
  619. return f"{elem}?"
  620. elif isinstance(t, ListType):
  621. size = t.size if not simple_type else None
  622. if str(t.elem) == "bool":
  623. assert t.size is not None
  624. return f"::std::array<bool,{t.size}>"
  625. elif str(t.elem) == "int":
  626. return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
  627. elif str(t.elem) == "SymInt":
  628. if symint:
  629. return (
  630. f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef"
  631. )
  632. else:
  633. return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
  634. elif str(t.elem) == "Tensor":
  635. return f"TensorList[{size}]" if size is not None else "TensorList"
  636. elif str(t.elem) == "Scalar":
  637. return f"ScalarList[{size}]" if size is not None else "ScalarList"
  638. elif str(t.elem) == "Tensor?":
  639. if simple_type:
  640. return "c10::List<::std::optional<Tensor>>"
  641. else:
  642. return "const c10::List<::std::optional<Tensor>> &"
  643. elif str(t.elem) == "Dimname":
  644. return f"DimnameList[{size}]" if size is not None else "DimnameList"
  645. elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
  646. return f"ArrayRef<{elem}>"
  647. raise RuntimeError(f"unrecognized type {repr(t)}")
  648. def argument_type_size(t: Type) -> int | None:
  649. l = t.is_list_like()
  650. if l is not None and str(l.elem) != "bool":
  651. return l.size
  652. else:
  653. return None
  654. def argument(a: Argument) -> PythonArgument:
  655. return PythonArgument(
  656. name=a.name,
  657. type=a.type,
  658. # TODO: directly translate a.default to python default
  659. default=(
  660. str(pythonify_default(cpp.default_expr(a.default, a.type, symint=False)))
  661. if a.default is not None
  662. else None
  663. ),
  664. default_init=None,
  665. )
  666. # Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen
  667. def signature(
  668. f: NativeFunction, *, method: bool = False, pyi: bool = False
  669. ) -> PythonSignature:
  670. return signature_from_schema(
  671. f.func, category_override=f.category_override, method=method, pyi=pyi
  672. )
  673. def signature_from_schema(
  674. func: FunctionSchema,
  675. *,
  676. category_override: str | None,
  677. method: bool = False,
  678. pyi: bool = False,
  679. ) -> PythonSignature:
  680. args: list[Argument] = []
  681. args.extend(func.arguments.pre_self_positional)
  682. # Skip SelfArgument if this is method.
  683. if not method and func.arguments.self_arg is not None:
  684. args.append(func.arguments.self_arg.argument)
  685. args.extend(func.arguments.post_self_positional)
  686. args.extend(func.arguments.pre_tensor_options_kwarg_only)
  687. # Skip TensorOptionsArguments. Python side TensorOptions
  688. # arguments are created based on different rules - see below.
  689. args.extend(func.arguments.post_tensor_options_kwarg_only)
  690. args.extend(func.arguments.out)
  691. input_arg_set = {a.name for a in func.arguments.flat_positional}
  692. kwarg_only_set = {a.name for a in func.arguments.flat_kwarg_only}
  693. out_arg_set = {a.name for a in func.arguments.out}
  694. input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args)))
  695. input_kwargs = tuple(
  696. map(argument, filter(lambda a: a.name in kwarg_only_set, args))
  697. )
  698. outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args)))
  699. # Reintroduce the scattered fields of TensorOptions for Python.
  700. # Compared to the cpp counterpart, the python arguments have new property
  701. # (default_init) and a new argument 'requires_grad', which require some
  702. # special handlings.
  703. # [old codegen] TODO: because these aren't guaranteed to be 100% faithful
  704. # to the original versions in the yaml, this recreation is a potential
  705. # source of drift between eager and JIT. Pull this logic out to a shared place.
  706. has_tensor_input_arg = any(
  707. a.type.is_tensor_like() for a in func.arguments.flat_non_out
  708. )
  709. if any(a.name == "requires_grad" for a in func.schema_order_arguments()):
  710. raise ValueError(
  711. "argument named requires_grad is reserved, should not explicitly add it in the schema"
  712. )
  713. # [old codegen] this probably won't work if one of the returns is not a tensor,
  714. # but it will produce a compile-time error that is obvious.
  715. has_tensor_return = any(r.type.is_tensor_like() for r in func.returns)
  716. name: str = cpp.name(func)
  717. is_factory_function = category_override == "factory" or (
  718. has_tensor_return and not has_tensor_input_arg
  719. )
  720. is_like_or_new_function = (
  721. category_override in ("new", "like")
  722. or name.startswith("new_")
  723. or name.endswith("_like")
  724. )
  725. is_dummy_function = category_override == "dummy"
  726. tensor_options_args: list[PythonArgument] = []
  727. if (is_factory_function or is_like_or_new_function) and not is_dummy_function:
  728. def topt_default_init(name: str) -> str | None:
  729. topt_args = func.arguments.tensor_options
  730. if topt_args is None:
  731. return None
  732. a = getattr(topt_args, name)
  733. if a.default is None or a.default == "None":
  734. return None
  735. return cpp.default_expr(a.default, a.type, symint=False)
  736. tensor_options_args.append(
  737. PythonArgument(
  738. name="dtype",
  739. type=OptionalType(BaseType(BaseTy.ScalarType)),
  740. default="None",
  741. default_init=(
  742. None if is_like_or_new_function else topt_default_init("dtype")
  743. ),
  744. )
  745. )
  746. tensor_options_args.append(
  747. PythonArgument(
  748. name="layout",
  749. type=OptionalType(BaseType(BaseTy.Layout)),
  750. default="None",
  751. default_init=(
  752. None if is_like_or_new_function else topt_default_init("layout")
  753. ),
  754. )
  755. )
  756. tensor_options_args.append(
  757. PythonArgument(
  758. name="device",
  759. type=OptionalType(BaseType(BaseTy.Device)),
  760. default="None",
  761. default_init=(
  762. None
  763. if is_like_or_new_function
  764. else (
  765. topt_default_init("device")
  766. or "torch::tensors::get_default_device()"
  767. )
  768. ),
  769. )
  770. )
  771. tensor_options_args.append(
  772. PythonArgument(
  773. name="pin_memory",
  774. type=OptionalType(BaseType(BaseTy.bool)),
  775. default="False",
  776. default_init=None,
  777. )
  778. )
  779. tensor_options_args.append(
  780. PythonArgument(
  781. name="requires_grad",
  782. type=OptionalType(BaseType(BaseTy.bool)),
  783. default="False",
  784. default_init=None,
  785. )
  786. )
  787. returns = PythonReturns(returns=func.returns)
  788. return PythonSignature(
  789. name=str(func.name.name),
  790. input_args=input_args,
  791. input_kwargs=input_kwargs,
  792. output_args=PythonOutArgument.from_outputs(outputs),
  793. tensor_options_args=tuple(tensor_options_args),
  794. returns=returns,
  795. method=method,
  796. )
  797. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  798. #
  799. # Python Interface
  800. #
  801. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  802. def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
  803. if len(returns) <= 1 or all(r.name is None for r in returns):
  804. return []
  805. else:
  806. if any(r.name is None for r in returns):
  807. # When building on Windows, `PyStructSequence_UnnamedField` could not be
  808. # resolved by the linker for some reason, which cause error in building:
  809. #
  810. # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
  811. # PyStructSequence_UnnamedField
  812. #
  813. # Thus, at this point in time, we do not support unnamed
  814. # fields in structseq; you must either name all fields,
  815. # or none of them.
  816. raise ValueError("Unnamed field is not supported by codegen")
  817. return [str(r.name) for r in returns]
  818. def argument_type_str_pyi(t: Type) -> str:
  819. add_optional = False
  820. if isinstance(t, OptionalType):
  821. t = t.elem
  822. add_optional = True
  823. ret = ""
  824. if isinstance(t, BaseType):
  825. if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
  826. ret = "_int"
  827. if t.name == BaseTy.SymInt:
  828. ret = "_int | SymInt"
  829. elif t.name == BaseTy.float:
  830. ret = "_float"
  831. elif t.name == BaseTy.str:
  832. ret = "str"
  833. elif t.name == BaseTy.Scalar:
  834. ret = "Number | _complex"
  835. elif t.name == BaseTy.ScalarType:
  836. ret = "_dtype"
  837. elif t.name == BaseTy.bool:
  838. ret = "_bool"
  839. elif t.name == BaseTy.QScheme:
  840. ret = "_qscheme"
  841. elif t.name == BaseTy.Layout:
  842. ret = "_layout"
  843. elif t.name == BaseTy.Device:
  844. ret = "DeviceLikeType | None"
  845. elif t.name == BaseTy.MemoryFormat:
  846. ret = "memory_format"
  847. elif t.name == BaseTy.Dimname:
  848. ret = "str | EllipsisType | None"
  849. elif t.name == BaseTy.Storage:
  850. ret = "Storage | UntypedStorage"
  851. elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream]:
  852. # These python schema type names line up with their function schema names
  853. ret = t.name.name
  854. elif isinstance(t, ListType):
  855. if str(t.elem) == "int":
  856. ret = "_int | _size" if t.size is not None else "_size"
  857. elif t.is_tensor_like():
  858. # TODO: this doesn't seem right...
  859. # Tensor?[] currently translates to tuple[Tensor, ...] | list[Tensor] | None
  860. # It should probably translate to tuple[Tensor | None, ...] | list[Tensor | None]
  861. add_optional = True
  862. ret = (
  863. "Tensor | tuple[Tensor, ...] | list[Tensor]"
  864. if t.size is not None
  865. else "tuple[Tensor, ...] | list[Tensor]"
  866. )
  867. elif str(t.elem) == "float":
  868. ret = "Sequence[_float]"
  869. elif str(t.elem) == "SymInt" and t.size is not None:
  870. elem = argument_type_str_pyi(t.elem)
  871. ret = f"{elem} | Sequence[{elem}]"
  872. else:
  873. elem = argument_type_str_pyi(t.elem)
  874. ret = f"Sequence[{elem}]"
  875. else:
  876. raise RuntimeError(f"unrecognized type {repr(t)}")
  877. if add_optional:
  878. ret = f"{ret} | None".replace(" | None | None", " | None")
  879. return ret
  880. def return_type_str_pyi(t: Type) -> str:
  881. # Where arguments are open to accepting Union, return types should return
  882. # concrete types
  883. if isinstance(t, OptionalType):
  884. inner = return_type_str_pyi(t.elem)
  885. return f"{inner} | None".replace(" | None | None", " | None")
  886. if isinstance(t, BaseType):
  887. if t.name == BaseTy.Device:
  888. return "_device"
  889. elif t.name == BaseTy.Dimname:
  890. return "str | None"
  891. else:
  892. return argument_type_str_pyi(t)
  893. if isinstance(t, ListType):
  894. inner = return_type_str_pyi(t.elem)
  895. return f"tuple[{inner}, ...]"
  896. return argument_type_str_pyi(t)
  897. def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
  898. python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
  899. structseq_name = signature.name
  900. field_names = structseq_fieldnames(signature.returns.returns)
  901. if field_names:
  902. # These types are structseq objects which act like named NamedTuples, but
  903. # the constructor acts like the constructor of tuple. Using typing.NamedTuple
  904. # does not allow us to override __init__.
  905. seq_type = f"tuple[{', '.join(python_returns)}]"
  906. structseq_def_lines = [
  907. f"class {structseq_name}({seq_type}): # fmt: skip",
  908. ]
  909. for name, ret_type in zip(field_names, python_returns):
  910. structseq_def_lines.extend(
  911. [
  912. " @property",
  913. f" def {name}(self) -> {ret_type}: ...",
  914. ]
  915. )
  916. structseq_def_lines.extend(
  917. [
  918. " def __new__(",
  919. " cls,",
  920. f" sequence: {seq_type},",
  921. " ) -> Self: # fmt: skip",
  922. " ...",
  923. f" n_fields: Final[_int] = {len(field_names)}",
  924. f" n_sequence_fields: Final[_int] = {len(field_names)}",
  925. " n_unnamed_fields: Final[_int] = 0",
  926. " def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing",
  927. "", # add an extra newline
  928. ]
  929. )
  930. structseq_def = "\n".join(structseq_def_lines)
  931. # Example:
  932. # structseq_def = (
  933. # "class max(tuple[Tensor, Tensor]): # fmt: skip\n"
  934. # " @property\n"
  935. # " def values(self) -> Tensor: ...\n"
  936. # " @property\n"
  937. # " def indices(self) -> Tensor: ...\n"
  938. # " def __new__(\n"
  939. # " cls,\n"
  940. # " sequence: tuple[Tensor, Tensor],\n"
  941. # " ) -> Self: # fmt: skip\n"
  942. # " ...\n"
  943. # " n_fields: Final[_int] = 2",
  944. # " n_sequence_fields: Final[_int] = 2",
  945. # " n_unnamed_fields: Final[_int] = 0",
  946. # " def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing",
  947. # )
  948. return structseq_name, structseq_def
  949. return None
  950. def returns_str_pyi(signature: PythonSignature) -> str:
  951. field_names = structseq_fieldnames(signature.returns.returns)
  952. if field_names:
  953. return f"torch.return_types.{signature.name}"
  954. python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
  955. if len(python_returns) > 1:
  956. return "tuple[" + ", ".join(python_returns) + "]"
  957. if len(python_returns) == 1:
  958. return python_returns[0]
  959. return "None"
  960. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  961. #
  962. # C++ Function Dispatch
  963. #
  964. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  965. # This section provides APIs to generate the code that does C++ function
  966. # dispatch. The C++ function call is wrapped by a lambda function.
  967. # For example:
  968. #
  969. # // aten::selu_(Tensor(a!) self) -> Tensor(a!)
  970. # auto dispatch_selu_ = [](Tensor self) -> Tensor {
  971. # pybind11::gil_scoped_release no_gil;
  972. # return at::selu_(self);
  973. # };
  974. #
  975. # The lambda function's signature follows the C++ signature in common
  976. # cases, e.g.:
  977. #
  978. # // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  979. # [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
  980. #
# For the out variant, the 'out' argument's type is changed from 'Tensor &'
# to 'Tensor'. This is because the lambda is called with the PythonArgParser
# output '_r.tensor(3)', which is a stack-allocated temporary object and must
# therefore be passed by value. Also see comments in 'dispatch_lambda_return_str()'.
  985. #
  986. # // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  987. # [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
  988. #
  989. # For multi-output case it can keep using reference type because the
  990. # PythonArgParser output has been unpacked to local variables, e.g.:
  991. #
  992. # // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *,
  993. # // Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
  994. # [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple<Tensor,Tensor>
  995. #
  996. # For deprecated python signature, it should follow deprecated python arg order.
  997. # TODO: This is to keep same byte-for-byte result as the old codegen - maybe unnecessary?
  998. def dispatch_lambda_args(
  999. ps: PythonSignature, f: NativeFunction, symint: bool = True
  1000. ) -> tuple[DispatchLambdaArgument, ...]:
  1001. if isinstance(ps, PythonSignatureDeprecated):
  1002. schema = ps.deprecated_schema
  1003. else:
  1004. schema = f.func
  1005. # Start with cpp arguments - dispatch lambda signature always include 'self'
  1006. cpp_args = cpp.arguments(
  1007. arguments=schema.arguments,
  1008. faithful=False,
  1009. symint=symint,
  1010. method=False,
  1011. cpp_no_default_args=f.cpp_no_default_args,
  1012. )
  1013. out_args: set[str] = {a.name for a in schema.arguments.out}
  1014. # Convert from cpp argument to lambda argument
  1015. def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
  1016. type_str = cpp_arg.type
  1017. is_out_arg = cpp_arg.name in out_args
  1018. if ps.method and cpp_arg.name == "self":
  1019. # For method's 'self', we can use 'const Tensor &' and simply ignore mutability!
  1020. type_str = "const at::Tensor &"
  1021. else:
  1022. # For other cases we need prevent dangling refs to temps (unless it's
  1023. # unpacked scattered output)
  1024. # The reason is explained in the comments above and in 'dispatch_lambda_return_str()'.
  1025. # TODO: avoid this special handling?
  1026. ensure_temp_safe = len(out_args) <= 1 or not is_out_arg
  1027. if ensure_temp_safe:
  1028. type_str = {
  1029. "at::Tensor &": "at::Tensor",
  1030. }.get(type_str, type_str)
  1031. return DispatchLambdaArgument(
  1032. name=cpp_arg.name,
  1033. type_str=type_str,
  1034. is_out_arg=is_out_arg,
  1035. )
  1036. return tuple(map(dispatch_lambda_arg, cpp_args))
  1037. # [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean
  1038. # it's enough to just extend the list here. Before you do this, make sure
  1039. # to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
  1040. SUPPORTED_RETURN_TYPES = {
  1041. "at::Tensor",
  1042. "::std::tuple<at::Tensor,at::Tensor>",
  1043. "::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
  1044. "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
  1045. "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
  1046. "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
  1047. "::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
  1048. "::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
  1049. "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",
  1050. "::std::tuple<at::Tensor,at::Tensor,double,at::Tensor,int64_t>",
  1051. "::std::tuple<double,int64_t>",
  1052. "::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
  1053. "::std::vector<at::Tensor>",
  1054. # Needed for flash attention forw/backward
  1055. "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor>",
  1056. "at::Scalar",
  1057. "bool",
  1058. "int64_t",
  1059. "void*",
  1060. "void",
  1061. "at::QScheme",
  1062. "double",
  1063. "at::IntArrayRef",
  1064. "at::ScalarType",
  1065. "at::Stream",
  1066. }
  1067. def dispatch_lambda_return_str(f: NativeFunction) -> str:
  1068. # [old codegen] Remove type annotation (e.g. 'Tensor' rather than 'Tensor &')
  1069. # because the dispatch lambdas take mutable arguments *by value*, not
  1070. # by reference. If you then return a reference to such an argument, you
  1071. # will now have a pointer to a dangling stack entry. Not good.
  1072. #
  1073. # You want:
  1074. #
  1075. # auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
  1076. # ^^^^^^
  1077. #
  1078. # *not*
  1079. #
  1080. # auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
  1081. # ^^^^^^^
  1082. #
  1083. # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
  1084. # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
  1085. # mutable reference to temporary. Maybe we could assign it to a
  1086. # variable itself.)
  1087. returns_without_annotation = tuple(
  1088. Return(r.name, r.type, None) for r in f.func.returns
  1089. )
  1090. return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type()
  1091. if return_str not in SUPPORTED_RETURN_TYPES:
  1092. raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}")
  1093. return return_str
  1094. def cpp_dispatch_target(f: NativeFunction) -> str:
  1095. symint = f.func.has_symint()
  1096. name = cpp.name(f.func, symint_overload=symint)
  1097. if Variant.method in f.variants:
  1098. return f"self.{name}"
  1099. if Variant.function in f.variants:
  1100. if has_tensor_options(f) or f.func.name.name.base.endswith("_like"):
  1101. namespace = "torch"
  1102. else:
  1103. namespace = "at"
  1104. return f"{namespace}::{name}"
  1105. raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}")
  1106. def cpp_dispatch_exprs(
  1107. f: NativeFunction,
  1108. *,
  1109. python_signature: PythonSignature | None = None,
  1110. ) -> tuple[str, ...]:
  1111. cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()
  1112. exprs: tuple[str, ...] = ()
  1113. if not isinstance(python_signature, PythonSignatureDeprecated):
  1114. # By default the exprs are consistent with the C++ signature.
  1115. exprs = tuple(a.name for a in cpp_args)
  1116. else:
  1117. # For deprecated python signature we may need fill in some constants.
  1118. exprs = tuple(
  1119. filter(
  1120. lambda n: n != "out" or f.func.is_out_fn(),
  1121. python_signature.deprecated_args_exprs,
  1122. )
  1123. )
  1124. if Variant.method in f.variants:
  1125. exprs = tuple(filter("self".__ne__, exprs))
  1126. return exprs
  1127. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  1128. #
  1129. # Python / C++ Args Binding
  1130. #
  1131. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  1132. # We explicitly enumerate the PythonArgParser unpacking methods for all
  1133. # supported types. This might be more verbose than necessary, partially
  1134. # because of the irregularity of unpacking method naming, partially
  1135. # because we want to mimic the old codegen behavior - to reject
  1136. # unexpected and/or unsupported cases which the old codegen rejects.
# For certain cases it is intentionally more restrictive than necessary,
# e.g. it doesn't accept doublelist with definite size.
  1139. def arg_parser_unpack_method(
  1140. t: Type, default: str | None, default_init: str | None, *, symint: bool = True
  1141. ) -> str:
  1142. has_default_init = default_init is not None
  1143. if has_default_init and str(t) not in (
  1144. "ScalarType?",
  1145. "ScalarType",
  1146. "Device",
  1147. "Device?",
  1148. "Layout",
  1149. "Layout?",
  1150. "bool",
  1151. "bool?",
  1152. ):
  1153. raise RuntimeError(f"type '{t}' does not supported unpacking with default")
  1154. if isinstance(t, BaseType):
  1155. if t.name in [
  1156. BaseTy.Tensor,
  1157. BaseTy.Stream,
  1158. BaseTy.Storage,
  1159. BaseTy.Scalar,
  1160. BaseTy.Dimname,
  1161. ]:
  1162. # These unpack methods line up with their schema names
  1163. return t.name.name.lower()
  1164. elif t.name == BaseTy.ScalarType:
  1165. return "scalartypeWithDefault" if has_default_init else "scalartype"
  1166. elif t.name == BaseTy.Device:
  1167. return "deviceWithDefault" if has_default_init else "device"
  1168. elif t.name == BaseTy.DeviceIndex:
  1169. return "toInt64"
  1170. elif t.name == BaseTy.int:
  1171. return "toInt64"
  1172. elif t.name == BaseTy.SymInt:
  1173. return "toSymInt" if symint else "toInt64"
  1174. elif t.name == BaseTy.bool:
  1175. return "toBoolWithDefault" if has_default_init else "toBool"
  1176. elif t.name == BaseTy.float:
  1177. return "toDouble"
  1178. elif t.name == BaseTy.str:
  1179. return "stringView"
  1180. elif t.name == BaseTy.Layout:
  1181. return "layoutWithDefault" if has_default_init else "layout"
  1182. elif t.name == BaseTy.MemoryFormat:
  1183. return "memoryformat"
  1184. elif isinstance(t, OptionalType):
  1185. if str(t.elem) == "Tensor":
  1186. return "optionalTensor"
  1187. elif str(t.elem) == "Generator":
  1188. return "generator"
  1189. elif str(t.elem) == "Dimname[]":
  1190. return "toDimnameListOptional"
  1191. elif not has_default_init and default in (
  1192. None,
  1193. "None",
  1194. "::std::nullopt",
  1195. "std::nullopt",
  1196. ):
  1197. # If default is None: append 'Optional' to elem's unpacking method
  1198. return (
  1199. arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
  1200. )
  1201. else:
  1202. # Otherwise, load as underlying type with default
  1203. return arg_parser_unpack_method(
  1204. t.elem, default, default_init, symint=symint
  1205. )
  1206. elif isinstance(t, ListType):
  1207. if str(t.elem) == "Tensor":
  1208. # accept and use definite size
  1209. return f"tensorlist_n<{t.size}>" if t.size is not None else "tensorlist"
  1210. elif str(t.elem) == "Tensor?":
  1211. return "list_of_optional_tensors"
  1212. elif str(t.elem) == "Dimname":
  1213. # accept definite size
  1214. return "dimnamelist"
  1215. elif str(t.elem) == "int":
  1216. # accept definite size
  1217. return "intlist"
  1218. elif str(t.elem) == "float":
  1219. return "doublelist"
  1220. elif str(t.elem) == "SymInt":
  1221. # accept definite size
  1222. return "symintlist" if symint else "intlist"
  1223. elif str(t.elem) == "Scalar":
  1224. return "scalarlist"
  1225. raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")
  1226. # Return RHS expression for python argument using PythonArgParser output.
  1227. # e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)'
  1228. def arg_parser_output_expr(
  1229. arg_index: int, a: PythonArgument, *, symint: bool = True
  1230. ) -> PythonArgParserOutputExpr:
  1231. has_default = a.default_init is not None
  1232. unpack_method = arg_parser_unpack_method(
  1233. t=a.type, default=a.default, default_init=a.default_init, symint=symint
  1234. )
  1235. default = f", {a.default_init}" if has_default else ""
  1236. expr = f"_r.{unpack_method}({arg_index}{default})"
  1237. return PythonArgParserOutputExpr(
  1238. name=a.name,
  1239. expr=expr,
  1240. index=arg_index,
  1241. argument=a,
  1242. )
  1243. # Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
  1244. def arg_parser_output_exprs(
  1245. ps: PythonSignature, f: NativeFunction, *, symint: bool = True
  1246. ) -> dict[str, PythonArgParserOutputExpr]:
  1247. return {
  1248. e.name: e
  1249. for i, a in enumerate(ps.arguments())
  1250. for e in (arg_parser_output_expr(i, a, symint=symint),)
  1251. }
  1252. # argument name to type for scattered tensor options fields
  1253. TENSOR_OPTIONS_FIELDS = {
  1254. "dtype": "ScalarType?",
  1255. "device": "Device?",
  1256. "layout": "Layout?",
  1257. "pin_memory": "bool?",
  1258. "requires_grad": "bool?",
  1259. }
  1260. # bind arg parser outputs (python args) with dispatch lambda arguments (c++ args).
  1261. def dispatch_lambda_exprs(
  1262. ps: PythonSignature, f: NativeFunction, *, symint: bool = True
  1263. ) -> DispatchLambdaArgumentExprs:
  1264. # This method is to bind 'arg_parser_outputs' and 'lambda_args' by producing
  1265. # 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser
  1266. # outputs.
  1267. arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
  1268. lambda_args = dispatch_lambda_args(ps, f, symint=symint)
  1269. inits: list[str] = []
  1270. lambda_args_exprs: dict[str, str] = {}
  1271. has_toptions = has_tensor_options(f)
  1272. # 1. special inits/unpacking to provide binding exprs for lambda arguments.
  1273. for a in ps.arguments(skip_tensor_options=True):
  1274. name = a.name
  1275. arg_parser_expr = arg_parser_outputs[a.name].expr
  1276. if has_toptions and name == "self":
  1277. # TODO: why this needs to be special case?
  1278. inits.extend(
  1279. [
  1280. f"auto self = {arg_parser_expr};",
  1281. ]
  1282. )
  1283. lambda_args_exprs[name] = name
  1284. elif (
  1285. isinstance(a, PythonOutArgument)
  1286. and len(a.outputs) > 1
  1287. and f.func.is_out_fn()
  1288. ):
  1289. inits.extend(
  1290. [
  1291. f"auto out = {arg_parser_expr};",
  1292. ]
  1293. )
  1294. for i, out_arg in enumerate(a.outputs):
  1295. lambda_args_exprs[out_arg.name] = f"out[{i}]"
  1296. elif str(a.type) == "Dimname[]?":
  1297. # [old codegen]
  1298. # TODO: make this part of something more general, or get rid of it.
  1299. # optional<ArrayRef<T>> are special. The PythonArgParser returns an
  1300. # optional<vector<T>>, which cannot be implicitly converted to
  1301. # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
  1302. inits.extend(
  1303. [
  1304. f"auto __{name} = {arg_parser_expr};",
  1305. f"::std::optional<DimnameList> {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;", # noqa: B950
  1306. ]
  1307. )
  1308. lambda_args_exprs[name] = name
  1309. else:
  1310. # default case - directly using PythonArgParser output expr
  1311. lambda_args_exprs[name] = arg_parser_expr
  1312. # method's self is passed directly to python binding, rather than parsed
  1313. if ps.method:
  1314. lambda_args_exprs["self"] = "self"
  1315. # 2. special packing/checking for TensorOptions.
  1316. tensor_options_args_names = [a.name for a in ps.tensor_options_args]
  1317. if has_toptions:
  1318. if f.func.is_out_fn():
  1319. raise RuntimeError(f"{f.func}: tensor options with output arg")
  1320. for a in ps.tensor_options_args:
  1321. if a.name not in TENSOR_OPTIONS_FIELDS:
  1322. raise RuntimeError(
  1323. f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments"
  1324. )
  1325. if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name):
  1326. raise RuntimeError(
  1327. f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
  1328. )
  1329. if not all(a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS):
  1330. raise RuntimeError(
  1331. f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
  1332. )
  1333. inits.append(
  1334. f"""\
  1335. const auto options = TensorOptions()
  1336. .dtype({arg_parser_outputs["dtype"].expr})
  1337. .device({arg_parser_outputs["device"].expr})
  1338. .layout({arg_parser_outputs["layout"].expr})
  1339. .requires_grad({arg_parser_outputs["requires_grad"].expr})
  1340. .pinned_memory({arg_parser_outputs["pin_memory"].expr});
  1341. torch::utils::maybe_initialize_device(options);
  1342. """
  1343. )
  1344. lambda_args_exprs["options"] = "options"
  1345. # 3. special case - access scattered TensorOptions fields without packing
  1346. # TODO: maybe move to the generator side as it's not related to binding.
  1347. if not has_toptions and tensor_options_args_names:
  1348. if "dtype" in tensor_options_args_names:
  1349. # we're an output-arg variant, check these args against output tensor
  1350. if not f.func.is_out_fn():
  1351. raise RuntimeError(
  1352. f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments}"
  1353. )
  1354. if not all(a in tensor_options_args_names for a in ("layout", "device")):
  1355. raise RuntimeError(
  1356. f"{f.func}: incomplete tensor options for output check"
  1357. )
  1358. inits.append(
  1359. f"""\
  1360. check_out_type_matches({arg_parser_outputs["out"].expr}, {arg_parser_outputs["dtype"].expr},
  1361. {arg_parser_outputs["dtype"].is_none_expr}, {arg_parser_outputs["layout"].expr},
  1362. {arg_parser_outputs["device"].expr}, {arg_parser_outputs["device"].is_none_expr});
  1363. """
  1364. )
  1365. # we'll set requires_grad on outgoing tensor
  1366. if "requires_grad" not in tensor_options_args_names:
  1367. raise RuntimeError(
  1368. f'{f.func}: expected "requires_grad" in tensor_options_args absent, but found [{tensor_options_args_names}]'
  1369. )
  1370. return DispatchLambdaArgumentExprs(
  1371. exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args),
  1372. inits=inits,
  1373. )