autocast_mode.h 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971
  1. #pragma once
  2. #include <ATen/ATen.h>
  3. #include <ATen/NativeFunctions.h>
  4. #include <ATen/Operators.h>
  5. #include <torch/library.h>
  6. #include <c10/core/impl/LocalDispatchKeySet.h>
  7. #include <c10/util/intrusive_ptr.h>
  8. namespace at::autocast {
  9. TORCH_API bool is_autocast_enabled(at::DeviceType device_type);
  10. TORCH_API void set_autocast_enabled(at::DeviceType device_type, bool enabled);
  11. TORCH_API at::ScalarType get_autocast_dtype(at::DeviceType device_type);
  12. TORCH_API void set_autocast_dtype(
  13. at::DeviceType device_type,
  14. at::ScalarType dtype);
  15. TORCH_API void clear_cache();
  16. TORCH_API int increment_nesting();
  17. TORCH_API int decrement_nesting();
  18. TORCH_API bool is_autocast_cache_enabled();
  19. TORCH_API void set_autocast_cache_enabled(bool enabled);
  20. // deprecated CUDA-specific autocast APIs
  21. C10_DEPRECATED_MESSAGE(
  22. "at::autocast::is_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.")
  23. inline bool is_enabled() {
  24. TORCH_WARN_DEPRECATION(
  25. "at::autocast::",
  26. __func__,
  27. "() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.")
  28. return is_autocast_enabled(at::kCUDA);
  29. }
  30. C10_DEPRECATED_MESSAGE(
  31. "at::autocast::set_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.")
  32. inline void set_enabled(bool enabled) {
  33. TORCH_WARN_DEPRECATION(
  34. "at::autocast::",
  35. __func__,
  36. "(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.")
  37. set_autocast_enabled(at::kCUDA, enabled);
  38. }
  39. C10_DEPRECATED_MESSAGE(
  40. "at::autocast::get_autocast_gpu_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.")
  41. inline at::ScalarType get_autocast_gpu_dtype() {
  42. TORCH_WARN_DEPRECATION(
  43. "at::autocast::",
  44. __func__,
  45. "() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.")
  46. return get_autocast_dtype(at::kCUDA);
  47. }
  48. C10_DEPRECATED_MESSAGE(
  49. "at::autocast::set_autocast_gpu_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.")
  50. inline void set_autocast_gpu_dtype(at::ScalarType dtype) {
  51. TORCH_WARN_DEPRECATION(
  52. "at::autocast::",
  53. __func__,
  54. "(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.")
  55. set_autocast_dtype(at::kCUDA, dtype);
  56. }
  57. #define DECLARE_DEPRECATED_AUTOCAST_APIS(name, device_type) \
  58. C10_DEPRECATED_MESSAGE( \
  59. "at::autocast::is_" #name \
  60. "_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \
  61. ") instead.") \
  62. inline bool is_##name##_enabled() { \
  63. TORCH_WARN_DEPRECATION( \
  64. "at::autocast::", \
  65. __func__, \
  66. "() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \
  67. ") instead.") \
  68. return is_autocast_enabled(device_type); \
  69. } \
  70. \
  71. C10_DEPRECATED_MESSAGE( \
  72. "at::autocast::set_" #name \
  73. "_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \
  74. ", enabled) instead.") \
  75. inline void set_##name##_enabled(bool enabled) { \
  76. TORCH_WARN_DEPRECATION( \
  77. "at::autocast::", \
  78. __func__, \
  79. "(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \
  80. ", enabled) instead.") \
  81. set_autocast_enabled(device_type, enabled); \
  82. } \
  83. \
  84. C10_DEPRECATED_MESSAGE( \
  85. "at::autocast::get_autocast_" #name \
  86. "_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(" #device_type \
  87. ") instead.") \
  88. inline at::ScalarType get_autocast_##name##_dtype() { \
  89. TORCH_WARN_DEPRECATION( \
  90. "at::autocast::", \
  91. __func__, \
  92. "() is deprecated. Please at::autocast::get_autocast_dtype(" #device_type \
  93. ") instead.") \
  94. return get_autocast_dtype(device_type); \
  95. } \
  96. \
  97. C10_DEPRECATED_MESSAGE( \
  98. "at::autocast::set_autocast_" #name \
  99. "_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \
  100. ", dtype) instead.") \
  101. inline void set_autocast_##name##_dtype(at::ScalarType dtype) { \
  102. TORCH_WARN_DEPRECATION( \
  103. "at::autocast::", \
  104. __func__, \
  105. "(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \
  106. ", dtype) instead.") \
  107. set_autocast_dtype(device_type, dtype); \
  108. }
  109. #define AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(_) \
  110. _(cpu, at::kCPU) \
  111. _(mtia, at::kMTIA) \
  112. _(xpu, at::kXPU) \
  113. _(xla, at::kXLA) \
  114. _(hpu, at::kHPU) \
  115. _(ipu, at::kIPU) \
  116. _(privateuseone, at::kPrivateUse1)
  117. // deprecated other backend specific autocast APIs
  118. // NOLINTNEXTLINE(misc-use-internal-linkage)
  119. AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(DECLARE_DEPRECATED_AUTOCAST_APIS)
  120. const std::array<at::DeviceType, 10> _AUTOCAST_SUPPORTED_DEVICES{
  121. at::kCPU,
  122. at::kCUDA,
  123. at::kMTIA,
  124. at::kMAIA,
  125. at::kXPU,
  126. at::kIPU,
  127. at::kHPU,
  128. at::kXLA,
  129. at::kPrivateUse1,
  130. at::kMPS};
  131. namespace {
  132. inline bool is_autocast_eligible(
  133. const Tensor& tensor,
  134. c10::DeviceType device_type) {
  135. switch (device_type) {
  136. case c10::DeviceType::CUDA:
  137. return (tensor.is_cuda() || tensor.is_xla()) &&
  138. tensor.is_floating_point();
  139. case c10::DeviceType::CPU:
  140. return (tensor.is_cpu() || tensor.is_mkldnn()) &&
  141. tensor.is_floating_point();
  142. case c10::DeviceType::MTIA:
  143. return tensor.is_mtia() && tensor.is_floating_point();
  144. case c10::DeviceType::MAIA:
  145. return tensor.is_maia() && tensor.is_floating_point();
  146. case c10::DeviceType::XPU:
  147. return tensor.is_xpu() && tensor.is_floating_point();
  148. case c10::DeviceType::IPU:
  149. return tensor.is_ipu() && tensor.is_floating_point();
  150. case c10::DeviceType::HPU:
  151. return tensor.is_hpu() && tensor.is_floating_point();
  152. case c10::DeviceType::XLA:
  153. return tensor.is_xla() && tensor.is_floating_point();
  154. case c10::DeviceType::PrivateUse1:
  155. return tensor.is_privateuseone() && tensor.is_floating_point();
  156. case c10::DeviceType::MPS:
  157. return tensor.is_mps() && tensor.is_floating_point();
  158. default:
  159. return false;
  160. }
  161. }
  162. } // namespace
  163. inline DispatchKey get_autocast_dispatch_key_from_device_type(
  164. c10::DeviceType device_type) {
  165. switch (device_type) {
  166. case c10::DeviceType::CUDA:
  167. return DispatchKey::Autocast;
  168. case c10::DeviceType::CPU:
  169. return DispatchKey::AutocastCPU;
  170. case c10::DeviceType::MTIA:
  171. return DispatchKey::AutocastMTIA;
  172. case c10::DeviceType::MAIA:
  173. return DispatchKey::AutocastMAIA;
  174. case c10::DeviceType::XPU:
  175. return DispatchKey::AutocastXPU;
  176. case c10::DeviceType::IPU:
  177. return DispatchKey::AutocastIPU;
  178. case c10::DeviceType::HPU:
  179. return DispatchKey::AutocastHPU;
  180. case c10::DeviceType::XLA:
  181. return DispatchKey::AutocastXLA;
  182. case c10::DeviceType::PrivateUse1:
  183. return DispatchKey::AutocastPrivateUse1;
  184. case c10::DeviceType::MPS:
  185. return DispatchKey::AutocastMPS;
  186. default:
  187. TORCH_CHECK(
  188. false,
  189. "unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
  190. }
  191. }
  192. inline bool is_autocast_available(c10::DeviceType device_type) {
  193. if (std::find(
  194. _AUTOCAST_SUPPORTED_DEVICES.begin(),
  195. _AUTOCAST_SUPPORTED_DEVICES.end(),
  196. device_type) != _AUTOCAST_SUPPORTED_DEVICES.end()) {
  197. return true;
  198. } else {
  199. return false;
  200. }
  201. }
  202. inline at::ScalarType get_lower_precision_fp_from_device_type(
  203. c10::DeviceType device_type) {
  204. if (is_autocast_available(device_type)) {
  205. return get_autocast_dtype(device_type);
  206. } else {
  207. TORCH_CHECK(
  208. false,
  209. "unknown device type for autocast in get_lower_precision_fp_from_device_type");
  210. }
  211. }
  212. /********************************************************************
  213. Logic to extract the promote type from any Tensor or TensorList args.
  214. ********************************************************************/
  215. // Overload to catch Tensor args.
  216. // If nextArg is floating-point, compare its scalar_type with our
  217. // current best guess for the promote type, and update if necessary.
  218. inline at::ScalarType prioritize(
  219. at::ScalarType current,
  220. const Tensor& nextArg,
  221. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  222. if (current == at::kDouble) {
  223. TORCH_CHECK(false, "promote type is double in at::autocast::prioritize");
  224. return current;
  225. }
  226. at::ScalarType lower_precision_fp =
  227. get_lower_precision_fp_from_device_type(device_type);
  228. if (is_autocast_eligible(nextArg, device_type)) {
  229. auto next = nextArg.scalar_type();
  230. if (next == at::kDouble) {
  231. return current; // ignores double tensors
  232. } else if (current == at::kFloat || next == at::kFloat) {
  233. return at::kFloat; // prioritizes float over lower_precision_fp
  234. } else if (current == lower_precision_fp && next == lower_precision_fp) {
  235. return lower_precision_fp;
  236. } else {
  237. TORCH_CHECK(
  238. false, "Unexpected floating ScalarType in at::autocast::prioritize");
  239. return current;
  240. }
  241. } else {
  242. return current;
  243. }
  244. }
  245. // Overload to catch TensorList args (for e.g. cat, stack).
  246. // Reuses the overload above to process each Tensor in the list.
  247. inline at::ScalarType prioritize(
  248. at::ScalarType current,
  249. const TensorList& list,
  250. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  251. for (const auto& tensor : list) {
  252. current = prioritize(current, tensor, device_type);
  253. }
  254. return current;
  255. }
  256. inline at::ScalarType prioritize(
  257. at::ScalarType current,
  258. const ITensorListRef& list,
  259. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  260. for (const auto& tensor : list) {
  261. current = prioritize(current, tensor, device_type);
  262. }
  263. return current;
  264. }
  265. // Template to catch non-Tensor args (no-op that returns current best guess)
  266. template <typename T>
  267. inline at::ScalarType prioritize(
  268. at::ScalarType current,
  269. T nextArg,
  270. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  271. return current;
  272. }
  273. // Overload for the tail case.
  274. inline at::ScalarType promote_type(
  275. at::ScalarType current,
  276. c10::DeviceType device_type) {
  277. return current;
  278. }
  279. // Unpack args and determine if incoming lower_precision_fp tensors need to be
  280. // promoted to float32. Non-Tensor arguments are ignored.
  281. template <typename Arg0, typename... Args>
  282. inline at::ScalarType promote_type(
  283. at::ScalarType current,
  284. c10::DeviceType device_type,
  285. Arg0 arg0,
  286. Args... args) {
  287. auto new_current = prioritize(current, arg0, device_type);
  288. return promote_type(new_current, device_type, args...);
  289. }
  290. /****************************************************
  291. Logic to apply cached casting to any Tensor argument.
  292. ****************************************************/
  293. inline bool is_eligible(
  294. const Tensor& arg,
  295. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  296. return (
  297. arg.defined() && is_autocast_eligible(arg, device_type) &&
  298. (arg.scalar_type() != at::kDouble));
  299. }
  300. // Overload to catch Tensor args
  301. TORCH_API Tensor cached_cast(
  302. at::ScalarType to_type,
  303. const Tensor& arg,
  304. c10::DeviceType device_type = c10::DeviceType::CUDA);
  305. // Overload to process std::optional<Tensor>
  306. inline std::optional<Tensor> cached_cast(
  307. at::ScalarType to_type,
  308. const std::optional<Tensor>& arg,
  309. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  310. if (arg.has_value()) {
  311. return cached_cast(to_type, *arg, device_type);
  312. } else {
  313. return std::nullopt;
  314. }
  315. }
  316. // Overload to process TensorLists
  317. inline std::vector<Tensor> cached_cast(
  318. at::ScalarType to_type,
  319. const TensorList& arg,
  320. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  321. std::vector<Tensor> vec;
  322. vec.reserve(arg.size());
  323. for (const auto& t : arg) {
  324. vec.emplace_back(cached_cast(to_type, t, device_type));
  325. }
  326. return vec;
  327. }
  328. inline std::vector<Tensor> cached_cast(
  329. at::ScalarType to_type,
  330. const ITensorListRef& arg,
  331. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  332. std::vector<Tensor> vec;
  333. vec.reserve(arg.size());
  334. for (const auto& t : arg) {
  335. vec.emplace_back(cached_cast(to_type, t, device_type));
  336. }
  337. return vec;
  338. }
  339. // Template to catch non-Tensor args.
  340. template <typename T>
  341. inline T cached_cast(
  342. at::ScalarType to_type,
  343. T arg,
  344. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  345. return arg;
  346. }
  347. /*******************************************************
  348. Logic to flip an output dtype flag.
  349. Keep it simple for now by assuming only one such flag is
  350. present in the argument list. If I ever need a function
  351. with more than flag I'll figure out something else.
  352. The policy is:
  353. If the user has explicitly specified a dtype, respect it.
  354. Otherwise, set it to the autocast type.
  355. ********************************************************/
  356. // Overload to catch dtype flags
  357. std::optional<ScalarType> inline set_opt_dtype(
  358. at::ScalarType to_type,
  359. const std::optional<ScalarType>& dtype) {
  360. return dtype.has_value() ? dtype : to_type;
  361. }
  362. // Template to catch other args
  363. template <typename T>
  364. inline T set_opt_dtype(at::ScalarType to_type, T arg) {
  365. return arg;
  366. }
  367. template <typename... Args>
  368. inline bool firstarg_is_eligible(
  369. c10::DeviceType device_type,
  370. const Tensor& arg,
  371. Args... args) {
  372. return is_eligible(arg, device_type);
  373. }
  374. template <typename... Args>
  375. inline at::ScalarType type_from_firstarg(
  376. c10::DeviceType device_type,
  377. at::ScalarType to_type,
  378. const Tensor& arg,
  379. Args... args) {
  380. return (is_eligible(arg, device_type) ? to_type : arg.scalar_type());
  381. }
  382. // Policies correspond to op categories that need code-divergent handling.
  383. // Wrapper templates below are specialized based on a policy template parameter.
  384. enum class CastPolicy : uint8_t {
  385. lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before
  386. // running the op. Currently, lower_precision_fp is
  387. // fp16 for AutocastCUDA, and is defined by user
  388. // (default bf16) for AutocastCPU or other device.
  389. fp32, // Cast all inputs to at::kFloat before running the op.
  390. fp32_set_opt_dtype, // Treats functions (like softmax) that
  391. // 1. we'd like to run in fp32 and
  392. // 2. have a std::optional<ScalarType> arg that controls
  393. // the output type.
  394. // fp32_set_opt_dtype wrappers' policy is: if the output
  395. // type is already set, don't touch it, otherwise, set
  396. // it to at::kFloat.
  397. fp32_append_dtype, // Treats functions (like norm) that
  398. // 1. we'd like to run in fp32 and
  399. // 2. have some overloads that accept an output type and
  400. // other overloads that don't.
  401. // fp32_append_dtype wrappers wrap the overloads that don't
  402. // have an output dtype.
  403. // The wrapper policy is: append at::kFloat to the args,
  404. // and redispatch to the type-aware overload.
  405. promote, // Run in the widest dtype among several args.
  406. };
  407. /********************************************************************************************************
  408. Templates to provide wrapper functions
  409. I'm copying the pattern used in core/boxing/impl/WrapFunctionIntoFunctor.h to
  410. extract args and return type. (see also
  411. https://stackoverflow.com/questions/46533698/how-to-deduce-argument-list-from-function-pointer)
  412. This strategy uses an exterior "WrapFunction" that extracts arguments on behalf
  413. of (in my case several specializations of) an interior "WrapFunction_".
  414. Interior WrapFunction_ specializations are defined for each CastPolicy.
  415. ********************************************************************************************************/
  416. // Base template for WrapFunction_, which is specialized to contain a "call"
  417. // method each CastPolicy
  418. template <
  419. CastPolicy policy,
  420. c10::DeviceType device_type,
  421. class Redispatch,
  422. Redispatch* F,
  423. class Ret,
  424. class ArgList>
  425. struct WrapFunction_ {};
  426. // CastPolicy::lower_precision_fp General_DeviceType
  427. template <
  428. c10::DeviceType device_type,
  429. class Redispatch,
  430. Redispatch* F,
  431. class Ret,
  432. class... Args>
  433. struct WrapFunction_<
  434. CastPolicy::lower_precision_fp,
  435. device_type,
  436. Redispatch,
  437. F,
  438. Ret,
  439. guts::typelist::typelist<Args...>> {
  440. static Ret call(Args... args) {
  441. c10::impl::ExcludeDispatchKeyGuard no_autocast(
  442. get_autocast_dispatch_key_from_device_type(device_type));
  443. return (*F)(cached_cast(
  444. get_lower_precision_fp_from_device_type(device_type),
  445. args,
  446. device_type)...);
  447. }
  448. };
  449. // CastPolicy::fp32 General_DeviceType
  450. template <
  451. c10::DeviceType device_type,
  452. class Redispatch,
  453. Redispatch* F,
  454. class Ret,
  455. class... Args>
  456. struct WrapFunction_<
  457. CastPolicy::fp32,
  458. device_type,
  459. Redispatch,
  460. F,
  461. Ret,
  462. guts::typelist::typelist<Args...>> {
  463. static Ret call(Args... args) {
  464. c10::impl::ExcludeDispatchKeyGuard no_autocast(
  465. get_autocast_dispatch_key_from_device_type(device_type));
  466. return (*F)(cached_cast(at::kFloat, args, device_type)...);
  467. }
  468. };
  469. // CastPolicy::fp32_set_opt_dtype General_DeviceType
  470. template <
  471. c10::DeviceType device_type,
  472. class Redispatch,
  473. Redispatch* F,
  474. class Ret,
  475. class... Args>
  476. struct WrapFunction_<
  477. CastPolicy::fp32_set_opt_dtype,
  478. device_type,
  479. Redispatch,
  480. F,
  481. Ret,
  482. guts::typelist::typelist<Args...>> {
  483. static Ret call(Args... args) {
  484. c10::impl::ExcludeDispatchKeyGuard no_autocast(
  485. get_autocast_dispatch_key_from_device_type(device_type));
  486. if (firstarg_is_eligible(device_type, args...)) {
  487. return (*F)(set_opt_dtype(at::kFloat, args)...);
  488. } else {
  489. // If ineligible, calls F with unaltered args. Does not set opt dtype,
  490. // because setting opt dtype explicitly may interfere with internal
  491. // implicit promotion decisions.
  492. return (*F)(args...);
  493. }
  494. }
  495. };
  496. // CastPolicy::fp32_append_dtype General_DeviceType
  497. template <
  498. c10::DeviceType device_type,
  499. class Redispatch,
  500. Redispatch* F,
  501. class Ret,
  502. class... Args>
  503. struct WrapFunction_<
  504. CastPolicy::fp32_append_dtype,
  505. device_type,
  506. Redispatch,
  507. F,
  508. Ret,
  509. guts::typelist::typelist<Args...>> {
  510. static Ret call(Args... args) {
  511. c10::impl::ExcludeDispatchKeyGuard no_autocast(
  512. get_autocast_dispatch_key_from_device_type(device_type));
  513. at::ScalarType out_type =
  514. type_from_firstarg(device_type, at::kFloat, args...);
  515. return (*F)(args..., out_type);
  516. }
  517. };
  518. // CastPolicy::promote General_DeviceType
  519. template <
  520. c10::DeviceType device_type,
  521. class Redispatch,
  522. Redispatch* F,
  523. class Ret,
  524. class... Args>
  525. struct WrapFunction_<
  526. CastPolicy::promote,
  527. device_type,
  528. Redispatch,
  529. F,
  530. Ret,
  531. guts::typelist::typelist<Args...>> {
  532. static Ret call(Args... args) {
  533. c10::impl::ExcludeDispatchKeyGuard no_autocast(
  534. get_autocast_dispatch_key_from_device_type(device_type));
  535. auto to_type = promote_type(
  536. get_lower_precision_fp_from_device_type(device_type),
  537. device_type,
  538. args...);
  539. return (*F)(cached_cast(to_type, args, device_type)...);
  540. }
  541. };
  542. // Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating
  543. // core/boxing/impl/WrapFunctionIntoFunctor.h)
  544. template <
  545. CastPolicy policy,
  546. c10::DeviceType device_type,
  547. class Registered, // The signature for which we're registering. The
  548. // dispatcher's calling code invokes our registered
  549. // functions with arguments matching Registered, so we
  550. // register WrapFunction_::call methods with a matching
  551. // signature to properly field those arguments.
  552. // guts::function_traits below extracts return_type and
  553. // parameter_types from Registered, which WrapFunction_
  554. // templates above use to declare their call methods.
  555. class Redispatch, // The signature for the function we're redispatching to.
  556. // In most cases this is the same as Registered, but for
  557. // some ops (for example, ops where we append a dtype)
  558. // it's useful to redispatch to a function with a
  559. // different signature.
  560. Redispatch* F> // The actual function we're redispatching to.
  561. struct WrapFunction final {
  562. using type = WrapFunction_<
  563. policy,
  564. device_type,
  565. Redispatch,
  566. F,
  567. typename guts::function_traits<Registered>::return_type,
  568. typename guts::function_traits<Registered>::parameter_types>;
  569. };
  570. /*****************************************************************************************************************
  571. This section performs load-time registration for autocast wrappers.
  572. It's debatable at what level operations should be patched. We'd like casts to
  573. be autograd-exposed and precede autograd history recording, so that for
  574. lower_precision_fp ops, input tensors are saved for backward in
  575. lower_precision_fp rather than fp32. Saving inputs in lower_precision_fp
  576. can significantly reduce a model's memory footprint.
  577. Option 1 (strawman): Patch only at the level of explicit calls into
  578. cudnn/cublas (cudnn_convolution, etc), because those are the code paths that are
  579. guaranteed to use Tensor Cores, therefore they're the ones that will benefit
  580. most from lower_precision_fp. Potential pitfall: convolutions (and other ops)
  581. are wrapped in several layers of at::* calls. If one of those happens to record
  582. autograd history, then we've lost the opportunity to save inputs in
  583. lower_precision_fp.
  584. Option 2: Patch the Python-exposed surface of calls, to make 100% sure autograd
  585. history recording can't sneak in ahead of autocast. This mirrors Apex most
  586. closely.
  587. I think Option 2 is the right answer for all ops, not just convolutions. Option
  588. 2 is what I implement here.
  589. *****************************************************************************************************************/
  590. /********************************************************************************************************************
  591. Explicit registration for out-of-place ops
  592. The stuff below could be codegenned. Ed said
  593. > you are going to have to write the function definition at some point, I
  594. wouldn't try to get clever about it. Therefore, for the moment, this is all
  595. copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
  596. ********************************************************************************************************************/
  597. } // namespace at::autocast
  598. #define ADD_NS(RAW_OP) at::RAW_OP
  599. #define _KERNEL_OVERLOAD_NARG_IMPL(_0, _1, _2, N, ...) N
  600. #define _KERNEL_OVERLOAD_NARG(...) \
  601. C10_EXPAND_MSVC_WORKAROUND(_KERNEL_OVERLOAD_NARG_IMPL(__VA_ARGS__, 2, 1))
  602. // Common cases where registration signature matches redispatch signature
  603. // (that's why SIGNATURE is repeated in the WrapFunction instantiation)
  604. #define KERNEL1(DISPATCHKEY, OP, POLICY) \
  605. m.impl( \
  606. TORCH_SELECTIVE_NAME("aten::" #OP), \
  607. &::at::autocast::WrapFunction< \
  608. ::at::autocast::CastPolicy::POLICY, \
  609. DISPATCHKEY, \
  610. decltype(ATEN_FN(OP)), \
  611. decltype(ATEN_FN(OP)), \
  612. &ATEN_FN(OP)>::type::call);
  613. #define KERNEL2(DISPATCHKEY, OP, OVERLOAD, POLICY) \
  614. m.impl( \
  615. TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
  616. &::at::autocast::WrapFunction< \
  617. ::at::autocast::CastPolicy::POLICY, \
  618. DISPATCHKEY, \
  619. decltype(ATEN_FN2(OP, OVERLOAD)), \
  620. decltype(ATEN_FN2(OP, OVERLOAD)), \
  621. &ATEN_FN2(OP, OVERLOAD)>::type::call);
  622. #define _KERNEL_DISPATCH(DISPATCHKEY, NARG, ...) \
  623. C10_CONCATENATE(KERNEL, NARG)(DISPATCHKEY, __VA_ARGS__)
  624. #define _KERNEL_IMPL(DISPATCHKEY, ...) \
  625. _KERNEL_DISPATCH(DISPATCHKEY, _KERNEL_OVERLOAD_NARG(__VA_ARGS__), __VA_ARGS__)
  626. // It will dispatch to KERNEL1 or KERNEL2 based on its inputs.
  627. #define KERNEL(DISPATCHKEY, ...) _KERNEL_IMPL(DISPATCHKEY, __VA_ARGS__)
  628. // Less-common but still useful case: redispatching to a function
  629. // with a new signature (e.g. appending a dtype)
  630. #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
  631. DISPATCHKEY, \
  632. REDISPATCH_FUNC, \
  633. REGISTER_NAME, \
  634. REGISTER_SIGNATURE, \
  635. REDISPATCH_SIGNATURE, \
  636. POLICY) \
  637. m.impl( \
  638. TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
  639. &::at::autocast::WrapFunction< \
  640. ::at::autocast::CastPolicy::POLICY, \
  641. DISPATCHKEY, \
  642. REGISTER_SIGNATURE, \
  643. REDISPATCH_SIGNATURE, \
  644. &REDISPATCH_FUNC>::type::call);
  645. // KERNEL_CPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU
  646. // registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastCPU
  647. #define KERNEL_CPU(...) KERNEL(c10::DeviceType::CPU, __VA_ARGS__)
  648. #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU( \
  649. REDISPATCH_FUNC, \
  650. REGISTER_NAME, \
  651. REGISTER_SIGNATURE, \
  652. REDISPATCH_SIGNATURE, \
  653. POLICY) \
  654. KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
  655. c10::DeviceType::CPU, \
  656. REDISPATCH_FUNC, \
  657. REGISTER_NAME, \
  658. REGISTER_SIGNATURE, \
  659. REDISPATCH_SIGNATURE, \
  660. POLICY)
  661. // KERNEL_CUDA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA
  662. // registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastCUDA
  663. #define KERNEL_CUDA(...) KERNEL(c10::DeviceType::CUDA, __VA_ARGS__)
  664. #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA( \
  665. REDISPATCH_FUNC, \
  666. REGISTER_NAME, \
  667. REGISTER_SIGNATURE, \
  668. REDISPATCH_SIGNATURE, \
  669. POLICY) \
  670. KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
  671. c10::DeviceType::CUDA, \
  672. REDISPATCH_FUNC, \
  673. REGISTER_NAME, \
  674. REGISTER_SIGNATURE, \
  675. REDISPATCH_SIGNATURE, \
  676. POLICY)
  677. // KERNEL_MTIA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MTIA
  678. // registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMTIA
  679. #define KERNEL_MTIA(...) KERNEL(c10::DeviceType::MTIA, __VA_ARGS__)
  680. #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MTIA( \
  681. REDISPATCH_FUNC, \
  682. REGISTER_NAME, \
  683. REGISTER_SIGNATURE, \
  684. REDISPATCH_SIGNATURE, \
  685. POLICY) \
  686. KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
  687. c10::DeviceType::MTIA, \
  688. REDISPATCH_FUNC, \
  689. REGISTER_NAME, \
  690. REGISTER_SIGNATURE, \
  691. REDISPATCH_SIGNATURE, \
  692. POLICY)
  693. // KERNEL_MAIA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA
  694. // registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMAIA
  695. #define KERNEL_MAIA(...) KERNEL(c10::DeviceType::MAIA, __VA_ARGS__)
  696. #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA( \
  697. REDISPATCH_FUNC, \
  698. REGISTER_NAME, \
  699. REGISTER_SIGNATURE, \
  700. REDISPATCH_SIGNATURE, \
  701. POLICY) \
  702. KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
  703. c10::DeviceType::MAIA, \
  704. REDISPATCH_FUNC, \
  705. REGISTER_NAME, \
  706. REGISTER_SIGNATURE, \
  707. REDISPATCH_SIGNATURE, \
  708. POLICY)
  709. // KERNEL_XPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU
  710. // registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastXPU
  711. #define KERNEL_XPU(...) KERNEL(c10::DeviceType::XPU, __VA_ARGS__)
  712. #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU( \
  713. REDISPATCH_FUNC, \
  714. REGISTER_NAME, \
  715. REGISTER_SIGNATURE, \
  716. REDISPATCH_SIGNATURE, \
  717. POLICY) \
  718. KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
  719. c10::DeviceType::XPU, \
  720. REDISPATCH_FUNC, \
  721. REGISTER_NAME, \
  722. REGISTER_SIGNATURE, \
  723. REDISPATCH_SIGNATURE, \
  724. POLICY)
  725. // KERNEL_PRIVATEUSEONE/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE
  726. // registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastPrivateUse1
  727. #define KERNEL_PRIVATEUSEONE(...) \
  728. KERNEL(c10::DeviceType::PrivateUse1, __VA_ARGS__)
  729. #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE( \
  730. REDISPATCH_FUNC, \
  731. REGISTER_NAME, \
  732. REGISTER_SIGNATURE, \
  733. REDISPATCH_SIGNATURE, \
  734. POLICY) \
  735. KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
  736. c10::DeviceType::PrivateUse1, \
  737. REDISPATCH_FUNC, \
  738. REGISTER_NAME, \
  739. REGISTER_SIGNATURE, \
  740. REDISPATCH_SIGNATURE, \
  741. POLICY)
  742. // KERNEL_MPS
  743. // registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMPS
  744. #define KERNEL_MPS(...) KERNEL(c10::DeviceType::MPS, __VA_ARGS__)
// Op lists for the different autocast policies, written as X-macros so that
// every backend can reuse the same per-policy op list.
  747. #define AT_FORALL_LOWER_PRECISION_FP(_) \
  748. _(_convolution, deprecated) \
  749. _(_convolution) \
  750. _(conv1d) \
  751. _(conv2d) \
  752. _(conv3d) \
  753. _(conv_tbc) \
  754. _(conv_transpose1d) \
  755. _(conv_transpose2d, input) \
  756. _(conv_transpose3d, input) \
  757. _(convolution) \
  758. _(prelu) \
  759. _(addmm) \
  760. _(addmv) \
  761. _(addr) \
  762. _(matmul) \
  763. _(einsum) \
  764. _(mm) \
  765. _(mv) \
  766. _(linalg_vecdot) \
  767. _(linear) \
  768. _(addbmm) \
  769. _(baddbmm) \
  770. _(bmm) \
  771. _(chain_matmul) \
  772. _(linalg_multi_dot) \
  773. _(_thnn_fused_lstm_cell) \
  774. _(_thnn_fused_gru_cell) \
  775. _(lstm_cell) \
  776. _(gru_cell) \
  777. _(rnn_tanh_cell) \
  778. _(rnn_relu_cell) \
  779. _(_scaled_dot_product_flash_attention) \
  780. _(scaled_dot_product_attention)
  781. #define AT_FORALL_FP32(_) \
  782. _(acos) \
  783. _(asin) \
  784. _(cosh) \
  785. _(erfinv) \
  786. _(exp) \
  787. _(expm1) \
  788. _(log) \
  789. _(log10) \
  790. _(log2) \
  791. _(log1p) \
  792. _(reciprocal) \
  793. _(rsqrt) \
  794. _(sinh) \
  795. _(tan) \
  796. _(pow, Tensor_Scalar) \
  797. _(pow, Tensor_Tensor) \
  798. _(pow, Scalar) \
  799. _(softplus) \
  800. _(layer_norm) \
  801. _(native_layer_norm) \
  802. _(group_norm) \
  803. _(frobenius_norm, dim) \
  804. _(nuclear_norm) \
  805. _(nuclear_norm, dim) \
  806. _(cosine_similarity) \
  807. _(poisson_nll_loss) \
  808. _(cosine_embedding_loss) \
  809. _(nll_loss) \
  810. _(nll_loss2d) \
  811. _(hinge_embedding_loss) \
  812. _(kl_div) \
  813. _(l1_loss) \
  814. _(smooth_l1_loss) \
  815. _(huber_loss) \
  816. _(mse_loss) \
  817. _(margin_ranking_loss) \
  818. _(multilabel_margin_loss) \
  819. _(soft_margin_loss) \
  820. _(triplet_margin_loss) \
  821. _(multi_margin_loss) \
  822. _(binary_cross_entropy_with_logits) \
  823. _(dist) \
  824. _(pdist) \
  825. _(cdist) \
  826. _(renorm) \
  827. _(logsumexp) \
  828. _(upsample_nearest1d) \
  829. _(_upsample_nearest_exact1d) \
  830. _(upsample_nearest2d) \
  831. _(_upsample_nearest_exact2d) \
  832. _(upsample_nearest3d) \
  833. _(_upsample_nearest_exact3d) \
  834. _(upsample_linear1d) \
  835. _(upsample_bilinear2d) \
  836. _(_upsample_bilinear2d_aa) \
  837. _(upsample_trilinear3d) \
  838. _(upsample_bicubic2d) \
  839. _(_upsample_bicubic2d_aa)
  840. #define AT_FORALL_FP32_SET_OPT_DTYPE(_) \
  841. _(prod) \
  842. _(prod, dim_int) \
  843. _(prod, dim_Dimname) \
  844. _(softmax, int) \
  845. _(softmax, Dimname) \
  846. _(log_softmax, int) \
  847. _(log_softmax, Dimname) \
  848. _(cumprod) \
  849. _(cumprod, dimname) \
  850. _(cumsum) \
  851. _(cumsum, dimname) \
  852. _(linalg_vector_norm) \
  853. _(linalg_matrix_norm) \
  854. _(linalg_matrix_norm, str_ord) \
  855. _(sum) \
  856. _(sum, dim_IntList) \
  857. _(sum, dim_DimnameList)
  858. #define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_) \
  859. _(ADD_NS(norm), \
  860. "norm.Scalar", \
  861. Tensor(const Tensor&, const Scalar&), \
  862. Tensor(const Tensor&, const std::optional<Scalar>&, ScalarType), \
  863. fp32_append_dtype) \
  864. _(ADD_NS(norm), \
  865. "norm.ScalarOpt_dim", \
  866. Tensor(const Tensor&, const std::optional<Scalar>&, IntArrayRef, bool), \
  867. Tensor( \
  868. const Tensor&, \
  869. const std::optional<Scalar>&, \
  870. IntArrayRef, \
  871. bool, \
  872. ScalarType), \
  873. fp32_append_dtype) \
  874. _(ADD_NS(norm), \
  875. "norm.names_ScalarOpt_dim", \
  876. Tensor(const Tensor&, const std::optional<Scalar>&, DimnameList, bool), \
  877. Tensor( \
  878. const Tensor&, \
  879. const std::optional<Scalar>&, \
  880. DimnameList, \
  881. bool, \
  882. ScalarType), \
  883. fp32_append_dtype)
  884. #define AT_FORALL_PROMOTE(_) \
  885. _(addcdiv) \
  886. _(addcmul) \
  887. _(atan2) \
  888. _(bilinear) \
  889. _(cross) \
  890. _(dot) \
  891. _(vdot) \
  892. _(grid_sampler) \
  893. _(index_put) \
  894. _(tensordot) \
  895. _(scatter_add)