cpp.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING
  3. from typing_extensions import assert_never
  4. from torchgen import local
  5. from torchgen.api.types import (
  6. ArgName,
  7. ArrayCType,
  8. ArrayRefCType,
  9. BaseCType,
  10. BaseTypeToCppMapping,
  11. Binding,
  12. boolT,
  13. ConstRefCType,
  14. CType,
  15. dimnameListT,
  16. intArrayRefT,
  17. iTensorListRefT,
  18. ListCType,
  19. longT,
  20. MutRefCType,
  21. NamedCType,
  22. OptionalCType,
  23. optionalIntArrayRefT,
  24. optionalSymIntArrayRefT,
  25. scalarT,
  26. SpecialArgName,
  27. symIntArrayRefT,
  28. SymIntT,
  29. tensorListT,
  30. tensorOptionsT,
  31. tensorT,
  32. TupleCType,
  33. VectorCType,
  34. voidT,
  35. )
  36. from torchgen.model import (
  37. Argument,
  38. Arguments,
  39. BaseTy,
  40. BaseType,
  41. FunctionSchema,
  42. ListType,
  43. NativeFunction,
  44. OptionalType,
  45. Return,
  46. SelfArgument,
  47. TensorOptionsArguments,
  48. Type,
  49. )
  50. if TYPE_CHECKING:
  51. from collections.abc import Sequence
  52. # This file describes the translation of JIT schema to the public C++
  53. # API, which is what people use when they call functions like at::add.
  54. #
  55. # Prominent characteristics of the C++ API:
  56. #
  57. # - dtype, layout, device and pin_memory are collected into
  58. # a single C++ type TensorOptions (the native functions API
  59. # also has this, but tensor options is really most relevant
  60. # for the C++ API; it makes calling kwarg factory functions
  61. # pleasant)
  62. #
  63. # - defaulting lives here (in fact, the dispatcher is completely
  64. # oblivious of defaults!)
  65. #
  66. # BTW: policy on name collisions: we try not to have types with
  67. # collisions, but functions are fair game to collide
  68. def name(
  69. func: FunctionSchema,
  70. *,
  71. faithful_name_for_out_overloads: bool = False,
  72. symint_overload: bool = False,
  73. ) -> str:
  74. name = str(func.name.name)
  75. if symint_overload:
  76. name += "_symint"
  77. if func.is_out_fn():
  78. if faithful_name_for_out_overloads:
  79. name += "_outf"
  80. else:
  81. name += "_out"
  82. return name
  83. # Translation of "value types" in JIT schema to C++ API type. Value
  84. # types look the same no matter if they are argument types or return
  85. # types. Returns None if the type in question is not a value type.
  86. def valuetype_type(
  87. t: Type,
  88. *,
  89. binds: ArgName,
  90. mutable: bool = True,
  91. symint: bool = False,
  92. ) -> NamedCType | None:
  93. if isinstance(t, BaseType):
  94. if t.name in (BaseTy.Tensor, BaseTy.Scalar):
  95. return None
  96. elif str(t) == "SymInt":
  97. if symint:
  98. return NamedCType(binds, BaseCType(SymIntT))
  99. else:
  100. return NamedCType(binds, BaseCType(longT))
  101. # All other BaseType currently map directly to BaseCppTypes.
  102. return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
  103. elif isinstance(t, OptionalType):
  104. elem = valuetype_type(t.elem, binds=binds, mutable=mutable, symint=symint)
  105. if elem is None:
  106. return None
  107. return NamedCType(binds, OptionalCType(elem.type))
  108. elif isinstance(t, ListType):
  109. if str(t.elem) == "bool":
  110. assert t.size is not None
  111. return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))
  112. else:
  113. return None
  114. else:
  115. raise AssertionError(f"unrecognized type {repr(t)}")
  116. # Translation of types occurring in JIT arguments to a C++ argument type.
  117. # If remove_non_owning_ref_types is set, we'll guarantee that the output CType is not a non-owning reference type.
  118. # For example, we'll return std::vector<int> instead of IntArrayRef.
  119. # See Note [translation from C++ reference to value types]
  120. def argumenttype_type(
  121. t: Type,
  122. *,
  123. mutable: bool,
  124. binds: ArgName,
  125. remove_non_owning_ref_types: bool = False,
  126. symint: bool = False,
  127. ) -> NamedCType:
  128. # If it's a value type, do the value type translation
  129. r = valuetype_type(
  130. t,
  131. binds=binds,
  132. mutable=mutable,
  133. symint=symint,
  134. )
  135. if r is not None:
  136. return r
  137. if isinstance(t, BaseType):
  138. if t.name == BaseTy.Tensor:
  139. if mutable and not local.use_const_ref_for_mutable_tensors():
  140. return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
  141. else:
  142. return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
  143. elif t.name == BaseTy.Scalar:
  144. return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
  145. else:
  146. raise AssertionError(f"base type should have been value type {t}")
  147. elif isinstance(t, OptionalType):
  148. if str(t.elem) == "Tensor":
  149. if mutable and not local.use_const_ref_for_mutable_tensors():
  150. return NamedCType(
  151. binds, MutRefCType(BaseCType(tensorT))
  152. ) # TODO: fix this discrepancy
  153. else:
  154. return NamedCType(
  155. binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
  156. )
  157. elif str(t.elem) == "Scalar":
  158. return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
  159. elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
  160. return NamedCType(binds, BaseCType(optionalIntArrayRefT))
  161. elif isinstance(t.elem, ListType) and str(t.elem.elem) == "SymInt":
  162. if symint:
  163. return NamedCType(binds, BaseCType(optionalSymIntArrayRefT))
  164. else:
  165. return NamedCType(binds, BaseCType(optionalIntArrayRefT))
  166. elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
  167. return NamedCType(binds, OptionalCType(elem.type))
  168. elif isinstance(t, ListType):
  169. # TODO: remove these special cases, ArrayRef fallthrough works fine
  170. if str(t.elem) == "int":
  171. if remove_non_owning_ref_types:
  172. return NamedCType(binds, VectorCType(BaseCType(longT)))
  173. else:
  174. return NamedCType(binds, BaseCType(intArrayRefT))
  175. if str(t.elem) == "SymInt":
  176. if remove_non_owning_ref_types:
  177. if symint:
  178. return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
  179. else:
  180. return NamedCType(binds, VectorCType(BaseCType(longT)))
  181. else:
  182. if symint:
  183. return NamedCType(binds, BaseCType(symIntArrayRefT))
  184. else:
  185. return NamedCType(binds, BaseCType(intArrayRefT))
  186. if str(t.elem) == "Tensor":
  187. if local.use_ilistref_for_tensor_lists():
  188. return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
  189. else:
  190. return NamedCType(binds, BaseCType(tensorListT))
  191. elif str(t.elem) == "Scalar":
  192. return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
  193. elif str(t.elem) == "Dimname":
  194. return NamedCType(binds, BaseCType(dimnameListT))
  195. elif str(t.elem) == "Tensor?":
  196. return NamedCType(
  197. binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
  198. )
  199. elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
  200. return NamedCType(binds, ArrayRefCType(elem.type))
  201. else:
  202. raise AssertionError(f"unrecognized type {repr(t)}")
  203. # Translate a JIT argument into its C++ type
  204. def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType:
  205. return argumenttype_type(a.type, mutable=a.is_write, symint=symint, binds=binds)
  206. # Translation of a (non-multi) return type from JIT to C++
  207. # N.B: returntype_type returns a CType, not a NamedCType.
  208. # This is mostly because of the mismatch between return types and return names.
  209. # e.g. a function with a return type of 'void' has 0 return names,
  210. # and a function with a return type of 'std::tuple' has >1 return name.
  211. def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
  212. # placeholder is ignored
  213. # NB: symint is ALWAYS respected for return types. So symint argument
  214. # here is IGNORED
  215. r = valuetype_type(t, binds="__placeholder__", mutable=mutable, symint=True)
  216. if r is not None:
  217. return r.type
  218. if isinstance(t, BaseType):
  219. if t.name == BaseTy.Tensor:
  220. if mutable:
  221. if local.use_const_ref_for_mutable_tensors():
  222. return ConstRefCType(BaseCType(tensorT))
  223. else:
  224. return MutRefCType(BaseCType(tensorT))
  225. else:
  226. # Note [Tensor Copy Returns]
  227. # Currently, we use "Argument.is_write" to determine
  228. # whether or not Tensor return types should be copies or references.
  229. # If that ever changes, take a look at other locations of this note!
  230. return BaseCType(tensorT)
  231. elif t.name == BaseTy.Scalar:
  232. return BaseCType(scalarT)
  233. elif isinstance(t, ListType):
  234. assert not mutable, (
  235. "Native functions should never return a mutable tensor list. They should return void."
  236. )
  237. elem = returntype_type(t.elem, mutable=False)
  238. assert t.size is None, f"fixed size list returns not supported: {t}"
  239. return VectorCType(elem)
  240. elif isinstance(t, OptionalType):
  241. elem = returntype_type(t.elem, mutable=mutable)
  242. if str(t.elem) == "Tensor":
  243. return OptionalCType(elem)
  244. raise AssertionError(f"unrecognized return type {t}")
  245. # Translation of a single return to its C++ type
  246. def return_type(r: Return, *, symint: bool = False) -> CType:
  247. return returntype_type(r.type, mutable=r.is_write, symint=symint)
  248. # Translation of a full (possibly multi) return from JIT to its C++ type
  249. def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
  250. if len(rs) == 0:
  251. return BaseCType(voidT)
  252. elif len(rs) == 1:
  253. return return_type(rs[0], symint=symint)
  254. else:
  255. return TupleCType([return_type(r, symint=symint) for r in rs])
  256. def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
  257. returns: list[str] = []
  258. for i, r in enumerate(f.func.returns):
  259. # If we have an inplace function, the return argument is
  260. # implicitly named self.
  261. # TODO: Consider incorporating this into the data model
  262. if f.func.name.name.inplace:
  263. assert i == 0, "illegal inplace function with multiple returns"
  264. name = "self"
  265. # If we are out function, the name is the name of the
  266. # corresponding output function (r.name will get recorded
  267. # in field_name later.)
  268. elif f.func.is_out_fn():
  269. name = f.func.arguments.out[i].name
  270. # If the return argument is explicitly named...
  271. elif r.name:
  272. name_conflict = any(
  273. r.name == a.name for a in f.func.schema_order_arguments()
  274. )
  275. if name_conflict and not f.func.is_out_fn():
  276. name = f"{r.name}_return"
  277. else:
  278. name = r.name
  279. # If there is no explicit name and no fallback name was passed in, we just name the output result,
  280. # unless it's a multi-return, in which case it's result0,
  281. # result1, etc (zero-indexed)
  282. else:
  283. name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
  284. returns.append(name)
  285. return returns
  286. JIT_TO_CPP_DEFAULT = {
  287. "False": "false",
  288. "True": "true",
  289. "None": "::std::nullopt", # UGH this one is type directed
  290. "Mean": "at::Reduction::Mean",
  291. "[]": "{}",
  292. "contiguous_format": "c10::MemoryFormat::Contiguous",
  293. "long": "at::kLong",
  294. }
  295. # Convert a JIT default into C++ expression representing the default
  296. def default_expr(d: str, t: Type, *, symint: bool) -> str:
  297. if d == "None" and str(t) == "Tensor?":
  298. return "{}"
  299. if isinstance(t, BaseType) and t.name is BaseTy.str:
  300. # Schema allows single quotes but C++ needs double
  301. if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
  302. s = ""
  303. i = 1
  304. while i + 1 < len(d):
  305. if d[i] != "\\":
  306. if d[i] == '"':
  307. s += '\\"'
  308. else:
  309. s += d[i]
  310. i += 1
  311. else:
  312. if d[i + 1] == "'":
  313. s += "'"
  314. else:
  315. s += d[i : i + 2]
  316. i += 2
  317. return f'"{s}"'
  318. if isinstance(t, OptionalType):
  319. if d == "None":
  320. return "::std::nullopt"
  321. return default_expr(d, t.elem, symint=symint)
  322. if isinstance(t, ListType):
  323. if d.startswith("[") and d.endswith("]"):
  324. return "{" + d[1:-1] + "}"
  325. elif symint and d.isdigit() and str(t.elem) == "SymInt":
  326. return f"c10::SymInt({d})"
  327. elif t.size is None:
  328. # NOTE: Sized lists can have scalar defaults
  329. raise ValueError(f"Expected a list default '[...]' but found: '{d}'")
  330. return JIT_TO_CPP_DEFAULT.get(d, d)
  331. # Convert an argument into its C++ API form
  332. def argument(
  333. a: Argument | TensorOptionsArguments | SelfArgument,
  334. *,
  335. cpp_no_default_args: set[str],
  336. method: bool,
  337. faithful: bool,
  338. symint: bool = False,
  339. has_tensor_options: bool,
  340. ) -> list[Binding]:
  341. def sub_argument(
  342. a: Argument | TensorOptionsArguments | SelfArgument,
  343. ) -> list[Binding]:
  344. return argument(
  345. a,
  346. cpp_no_default_args=cpp_no_default_args,
  347. method=method,
  348. faithful=faithful,
  349. symint=symint,
  350. has_tensor_options=has_tensor_options,
  351. )
  352. if isinstance(a, Argument):
  353. binds: ArgName
  354. if a.name == "memory_format" and has_tensor_options:
  355. binds = SpecialArgName.possibly_redundant_memory_format
  356. else:
  357. binds = a.name
  358. default: str | None = None
  359. if a.name not in cpp_no_default_args and a.default is not None:
  360. default = default_expr(a.default, a.type, symint=symint)
  361. return [
  362. Binding(
  363. nctype=argument_type(a, binds=binds, symint=symint),
  364. name=a.name,
  365. default=default,
  366. argument=a,
  367. )
  368. ]
  369. elif isinstance(a, TensorOptionsArguments):
  370. if faithful:
  371. return (
  372. sub_argument(a.dtype)
  373. + sub_argument(a.layout)
  374. + sub_argument(a.device)
  375. + sub_argument(a.pin_memory)
  376. )
  377. else:
  378. default = None
  379. # Enforced by NativeFunction.__post_init__
  380. assert "options" not in cpp_no_default_args
  381. if all(x.default == "None" for x in a.all()):
  382. default = "{}"
  383. elif a.dtype.default == "long":
  384. default = "at::kLong" # TODO: this is wrong
  385. return [
  386. Binding(
  387. nctype=NamedCType("options", BaseCType(tensorOptionsT)),
  388. name="options",
  389. default=default,
  390. argument=a,
  391. )
  392. ]
  393. elif isinstance(a, SelfArgument):
  394. if method:
  395. # Caller is responsible for installing implicit this in context!
  396. return []
  397. else:
  398. return sub_argument(a.argument)
  399. else:
  400. assert_never(a)
  401. def arguments(
  402. arguments: Arguments,
  403. *,
  404. faithful: bool,
  405. symint: bool = False,
  406. method: bool,
  407. cpp_no_default_args: set[str],
  408. ) -> list[Binding]:
  409. args: list[Argument | TensorOptionsArguments | SelfArgument] = []
  410. if faithful:
  411. args.extend(arguments.non_out)
  412. args.extend(arguments.out)
  413. else:
  414. args.extend(arguments.out)
  415. args.extend(arguments.non_out)
  416. return [
  417. r.no_default() if faithful else r
  418. for a in args
  419. for r in argument(
  420. a,
  421. faithful=faithful,
  422. symint=symint,
  423. method=method,
  424. has_tensor_options=arguments.tensor_options is not None,
  425. cpp_no_default_args=cpp_no_default_args,
  426. )
  427. ]