Context.h 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/BlasBackend.h>
  4. #include <ATen/CPUGeneratorImpl.h>
  5. #include <ATen/DeviceAccelerator.h>
  6. #include <ATen/LinalgBackend.h>
  7. #include <ATen/ROCmFABackend.h>
  8. #include <ATen/SDPBackend.h>
  9. #include <ATen/core/ATenGeneral.h>
  10. #include <ATen/core/DeprecatedTypeProperties.h>
  11. #include <ATen/core/Generator.h>
  12. #include <ATen/core/LegacyTypeDispatch.h>
  13. #include <ATen/detail/AcceleratorHooksInterface.h>
  14. #include <ATen/detail/CUDAHooksInterface.h>
  15. #include <ATen/detail/HIPHooksInterface.h>
  16. #include <ATen/detail/HPUHooksInterface.h>
  17. #include <ATen/detail/IPUHooksInterface.h>
  18. #include <ATen/detail/MAIAHooksInterface.h>
  19. #include <ATen/detail/MPSHooksInterface.h>
  20. #include <ATen/detail/MTIAHooksInterface.h>
  21. #include <ATen/detail/PrivateUse1HooksInterface.h>
  22. #include <ATen/detail/XLAHooksInterface.h>
  23. #include <ATen/detail/XPUHooksInterface.h>
  24. #include <c10/core/QEngine.h>
  25. #include <c10/core/impl/DeviceGuardImplInterface.h>
  26. #include <c10/util/CallOnce.h>
  27. #include <c10/util/Exception.h>
  28. #include <c10/util/env.h>
  29. #include <c10/util/hash.h>
  30. #include <c10/util/irange.h>
  31. #include <cstdint>
  32. #include <map>
  33. #include <mutex>
  34. #include <unordered_map>
  35. namespace at {
  36. class Tensor;
  37. enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
  38. enum class CuBLASReductionOption : uint8_t {
  39. AllowReducedPrecisionWithSplitK = 0,
  40. DisallowReducedPrecisionAllowSplitK = 1,
  41. DisallowReducedPrecisionDisallowSplitK = 2,
  42. };
  43. enum class TORCH_API Float32Backend { GENERIC, CUDA, MKLDNN };
  44. enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL };
  45. enum class TORCH_API Float32Precision { NONE, IEEE, TF32, BF16 };
  46. TORCH_API Float32Backend str2backend(const std::string& name);
  47. TORCH_API Float32Op str2op(const std::string& name);
  48. TORCH_API Float32Precision str2precision(const std::string& name);
  49. TORCH_API std::string precision2str(Float32Precision prec);
  50. class TORCH_API Context {
  51. public:
  52. Context();
  53. const Generator& defaultGenerator(Device device) {
  54. c10::DeviceType device_type = device.type();
  55. lazyInitDevice(device_type);
  56. if (device_type == at::kCPU) {
  57. return at::detail::getDefaultCPUGenerator();
  58. } else {
  59. return getAcceleratorHooksInterface(device_type)
  60. .getDefaultGenerator(device.index());
  61. }
  62. }
  63. const AcceleratorHooksInterface& getAcceleratorHooksInterface(
  64. std::optional<c10::DeviceType> opt_device_type = std::nullopt) {
  65. if (!opt_device_type.has_value()) {
  66. opt_device_type = at::getAccelerator(true);
  67. }
  68. if (opt_device_type == at::kCUDA) {
  69. return at::detail::getCUDAHooks();
  70. } else if (opt_device_type == at::kXPU) {
  71. return at::detail::getXPUHooks();
  72. } else if (opt_device_type == at::kMPS) {
  73. return at::detail::getMPSHooks();
  74. } else if (opt_device_type == at::kPrivateUse1) {
  75. return at::detail::getPrivateUse1Hooks();
  76. } else if (opt_device_type == at::kMTIA) {
  77. return at::detail::getMTIAHooks();
  78. } else if (opt_device_type == at::kHIP) {
  79. return at::detail::getHIPHooks();
  80. } else if (opt_device_type == at::kHPU) {
  81. return at::detail::getHPUHooks();
  82. } else if (opt_device_type == at::kXLA) {
  83. return at::detail::getXLAHooks();
  84. } else {
  85. TORCH_CHECK(
  86. false,
  87. opt_device_type.has_value()
  88. ? c10::DeviceTypeName(opt_device_type.value())
  89. : "None",
  90. " device type not an accelerator.");
  91. }
  92. }
  93. Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
  94. lazyInitDevice(device_type);
  95. if (device_type == at::kCPU) {
  96. return c10::DeviceType::CPU;
  97. } else {
  98. return getAcceleratorHooksInterface(device_type).getDeviceFromPtr(data);
  99. }
  100. }
  101. bool isPinnedPtr(
  102. const void* data,
  103. std::optional<c10::DeviceType> device_type = std::nullopt) {
  104. auto opt_device_type =
  105. device_type.has_value() ? device_type : at::getAccelerator();
  106. if (!opt_device_type.has_value() || // there is no accelerator
  107. !at::isAccelerator(
  108. opt_device_type.value())) { // passed device not an accelerator
  109. return false;
  110. }
  111. if (!init_[static_cast<int8_t>(opt_device_type.value())].test_once()) {
  112. // If the device is not initialized, no pointer can be pinned for it
  113. return false;
  114. }
  115. return getAcceleratorHooksInterface(opt_device_type).isPinnedPtr(data);
  116. }
  117. Allocator* getPinnedMemoryAllocator(
  118. std::optional<c10::DeviceType> device_type = std::nullopt) {
  119. auto opt_device_type =
  120. device_type.has_value() ? device_type : at::getAccelerator();
  121. if (opt_device_type) {
  122. lazyInitDevice(opt_device_type.value());
  123. }
  124. return getAcceleratorHooksInterface(device_type).getPinnedMemoryAllocator();
  125. }
  126. void lazyInitDevice(c10::DeviceType device_type) {
  127. if (device_type != at::kCPU) {
  128. c10::call_once(init_[static_cast<int8_t>(device_type)], [&] {
  129. getAcceleratorHooksInterface(device_type).init();
  130. });
  131. }
  132. }
  133. static bool hasOpenMP();
  134. static bool hasMKL();
  135. static bool hasKleidiAI();
  136. static bool hasLAPACK();
  137. static bool hasMKLDNN();
  138. static bool ckSupported();
  139. static bool hasEigenSparse();
  140. static bool hasMAGMA() {
  141. return detail::getCUDAHooks().hasMAGMA();
  142. }
  143. static bool hasCUDA() {
  144. return detail::getCUDAHooks().hasCUDA();
  145. }
  146. static bool hasMTIA() {
  147. return detail::getMTIAHooks().hasMTIA();
  148. }
  149. static bool hasCUDART() {
  150. return detail::getCUDAHooks().hasCUDART();
  151. }
  152. static long versionCUDART() {
  153. return detail::getCUDAHooks().versionCUDART();
  154. }
  155. static bool hasCuDNN() {
  156. return detail::getCUDAHooks().hasCuDNN();
  157. }
  158. static long versionCuDNN() {
  159. return detail::getCUDAHooks().versionCuDNN();
  160. }
  161. static long versionRuntimeCuDNN() {
  162. return detail::getCUDAHooks().versionRuntimeCuDNN();
  163. }
  164. static long versionCuDNNFrontend() {
  165. return detail::getCUDAHooks().versionCuDNNFrontend();
  166. }
  167. static bool hasCuSOLVER() {
  168. return detail::getCUDAHooks().hasCuSOLVER();
  169. }
  170. static bool hasCuBLASLt() {
  171. return detail::getCUDAHooks().hasCuBLASLt();
  172. }
  173. static bool hasROCM() {
  174. return detail::getCUDAHooks().hasROCM();
  175. }
  176. static bool hasCKSDPA() {
  177. return detail::getCUDAHooks().hasCKSDPA();
  178. }
  179. static bool hasCKGEMM() {
  180. return detail::getCUDAHooks().hasCKGEMM();
  181. }
  182. static bool hasHIP() {
  183. return detail::getHIPHooks().hasHIP();
  184. }
  185. static bool hasMPS() {
  186. return detail::getMPSHooks().hasMPS();
  187. }
  188. static bool hasIPU() {
  189. return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
  190. }
  191. static bool hasXLA() {
  192. return detail::getXLAHooks().hasXLA();
  193. }
  194. static bool hasXPU() {
  195. return detail::getXPUHooks().hasXPU();
  196. }
  197. static bool hasLazy() {
  198. return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy);
  199. }
  200. static bool hasMAIA() {
  201. return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA);
  202. }
  203. static bool hasHPU() {
  204. return detail::getHPUHooks().hasHPU();
  205. }
  206. static const at::cuda::NVRTC& getNVRTC() {
  207. return detail::getCUDAHooks().nvrtc();
  208. }
  209. static bool setFlushDenormal(bool on);
  210. // NB: This method is *purely* whether or not a user requested
  211. // that CuDNN was enabled, it doesn't actually say anything about
  212. // whether or not CuDNN is actually usable. Use cudnn_is_acceptable
  213. // to test this instead
  214. bool userEnabledCuDNN() const;
  215. void setUserEnabledCuDNN(bool e);
  216. bool userEnabledMkldnn() const;
  217. void setUserEnabledMkldnn(bool e);
  218. bool benchmarkCuDNN() const;
  219. void setBenchmarkCuDNN(bool /*b*/);
  220. int benchmarkLimitCuDNN() const;
  221. void setBenchmarkLimitCuDNN(int /*b*/);
  222. bool immediateMiopen() const;
  223. void setImmediateMiopen(bool /*b*/);
  224. bool deterministicCuDNN() const;
  225. void setDeterministicCuDNN(bool /*b*/);
  226. bool deterministicMkldnn() const;
  227. void setDeterministicMkldnn(bool /*b*/);
  228. bool userEnabledNNPACK() const;
  229. void setUserEnabledNNPACK(bool e);
  230. // Note [Disabling Fused SDP Kernels]
  231. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  232. // Flash and Memory Efficient SDP kernels are enabled by default.
  233. // However, they can be disabled by setting
  234. // at::globalContext().setUserEnabledFlashSDP(false) flag.
  235. // This is useful for debugging purposes. For example, if you want to
  236. // compare the performance of the flash SDP kernels with the unfused
  237. // kernel, you can disable the flash SDP kernels. By disabling
  238. // the math SDP kernel, you can force your code to use flash kernels.
  239. // The math SDP kernel can be disabled by setting
  240. // at::globalContext().setUserEnabledMathSDP(false) flag.
  241. void setSDPPriorityOrder(const std::vector<int64_t>& order);
  242. std::array<at::SDPBackend, at::num_sdp_backends> sDPPriorityOrder();
  243. void setSDPUseFlash(bool /*e*/);
  244. bool userEnabledFlashSDP() const;
  245. void setSDPUseMemEfficient(bool /*e*/);
  246. bool userEnabledMemEfficientSDP() const;
  247. void setSDPUseMath(bool /*e*/);
  248. bool userEnabledMathSDP() const;
  249. void setSDPUseCuDNN(bool /*e*/);
  250. bool userEnabledCuDNNSDP() const;
  251. void setAllowFP16BF16ReductionMathSDP(bool /*e*/);
  252. bool allowFP16BF16ReductionMathSDP() const;
  253. void setSDPUseOverrideable(bool /*e*/);
  254. bool userEnabledOverrideableSDP() const;
  255. at::LinalgBackend linalgPreferredBackend() const;
  256. void setLinalgPreferredBackend(at::LinalgBackend /*b*/);
  257. at::BlasBackend blasPreferredBackend();
  258. void setBlasPreferredBackend(at::BlasBackend /*b*/);
  259. at::ROCmFABackend getROCmFAPreferredBackend();
  260. void setROCmFAPreferredBackend(at::ROCmFABackend /*b*/);
  261. // Note [Enabling Deterministic Operations]
  262. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  263. // Operations in PyTorch that normally act nondeterministically, but have an
  264. // alternate deterministic implementation, should satisfy the following
  265. // requirements:
  266. //
  267. // * Include this comment: "See Note [Enabling Deterministic Operations]"
  268. //
  269. // * Check the value of `at::globalContext().deterministicAlgorithms()` to
  270. // toggle
  271. // between nondeterministic and deterministic implementations.
  272. //
  273. // * Have an entry in the list of PyTorch operations that toggle between
  274. // nondeterministic
  275. // and deterministic implementations, in the docstring of
  276. // `use_deterministic_algorithms()` in torch/__init__.py
  277. //
  278. // `example_func()` below shows an example of toggling between
  279. // nondeterministic and deterministic implementations:
  280. //
  281. // void example_func() {
  282. // // See Note [Enabling Deterministic Operations]
  283. // if (at::globalContext().deterministicAlgorithms()) {
  284. // example_func_deterministic();
  285. // } else {
  286. // example_func_nondeterministic();
  287. // }
  288. // }
  289. bool deterministicAlgorithms() const;
  290. bool deterministicAlgorithmsWarnOnly() const;
  291. void setDeterministicAlgorithms(bool /*b*/, bool /*warn_only*/);
  292. bool deterministicFillUninitializedMemory() const;
  293. void setDeterministicFillUninitializedMemory(bool /*b*/);
  294. // Note [Writing Nondeterministic Operations]
  295. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  296. // Operations in PyTorch that act nondeterministically and do not have an
  297. // alternate deterministic implementation should satisfy the following
  298. // requirements:
  299. //
  300. // * Include this comment: "See Note [Writing Nondeterministic Operations]"
  301. //
  302. // * Include a comment explaining why the operation is nondeterministic.
  303. //
  304. // * Throw an error when `Context::deterministicAlgorithms()` is true. Most
  305. // of the time, this should be accomplished by calling
  306. // `at::globalContext().alertNotDeterminstic().
  307. //
  308. // * Have an entry in the list of nondeterministic PyTorch operations in the
  309. // docstring of `use_deterministic_algorithms()` in torch/__init__.py
  310. //
  311. // * Have a test function in `test/test_torch.py` whose name begins with
  312. // `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
  313. // configuration is the reason for nondeterminism, the operation should be
  314. // included in the `test_cublas_config_nondeterministic_alert` test. Any new
  315. // tests should ideally follow a pattern similar to the existing ones.
  316. //
  317. // `example_func()` below shows an example of the comments and error-throwing
  318. // code for a nondeterministic operation:
  319. //
  320. // void example_func() {
  321. // // See Note [Writing Nondeterministic Operations]
  322. // // Nondeterministic because <reason>
  323. // at::globalContext().alertNondeterministic("example_func");
  324. // ...
  325. // }
  326. // Throws an error if `Context::deterministicAlgorithms()` is true
  327. static void alertNotDeterministic(std::string_view const& caller);
  328. void setFloat32MatmulPrecision(const std::string& s);
  329. void setFloat32Precision(
  330. Float32Backend backend,
  331. Float32Op op,
  332. Float32Precision p);
  333. bool allowTF32CuDNN(std::optional<Float32Op> op = std::nullopt) const;
  334. void setAllowTF32CuDNN(bool /*b*/);
  335. bool allowTF32OneDNN() const;
  336. void setAllowTF32OneDNN(bool /*b*/);
  337. bool allowTF32CuBLAS() const;
  338. void setAllowTF32CuBLAS(bool /*b*/);
  339. Float32MatmulPrecision float32MatmulPrecision() const;
  340. Float32Precision float32Precision(Float32Backend backend, Float32Op op) const;
  341. CuBLASReductionOption allowFP16ReductionCuBLAS() const;
  342. void setAllowFP16ReductionCuBLAS(
  343. bool allow_reduced_precision,
  344. bool allow_splitk = true);
  345. CuBLASReductionOption allowBF16ReductionCuBLAS() const;
  346. void setAllowBF16ReductionCuBLAS(
  347. bool allow_reduced_precision,
  348. bool allow_splitk = true);
  349. bool allowFP16AccumulationCuBLAS() const;
  350. void setAllowFP16AccumulationCuBLAS(bool /*b*/);
  351. bool rocmAllowGroupGemmCk() const;
  352. // Matmuls can use a so-called "persistent" kernel which launches one CUDA
  353. // block for each SM on the GPU, and each block then iterates over multiple
  354. // output tiles. This allows to use software pipelining to hide the begin/end
  355. // latencies (e.g., epilogue), especially when only one tile fits per SM.
  356. // However, if some SMs are busy (e.g., with a background NCCL kernel), the
  357. // matmul's blocks will be scheduled in two waves and, in the absence of some
  358. // smart load balancing, the kernel will take twice as long. This flag allows
  359. // to make matmuls target only a subset of the SMs, so they can fully schedule
  360. // even next to a comms kernel, and only be a few percent slower.
  361. std::optional<int32_t> _SMCarveout_EXPERIMENTAL() const;
  362. void _setSMCarveout_EXPERIMENTAL(std::optional<int32_t> /*c*/);
  363. at::QEngine qEngine() const;
  364. void setQEngine(at::QEngine e);
  365. static const std::vector<at::QEngine>& supportedQEngines();
  366. static bool isXNNPACKAvailable();
  367. void setCheckSparseTensorInvariants(bool e);
  368. bool checkSparseTensorInvariants() const;
  369. // This method is used to release the original weight after pre-packing.
  370. // It should be called once before loading/running the model.
  371. // NB: By default it is set to true for mobile builds.
  372. void setReleaseWeightsWhenPrepacking(bool e);
  373. bool releaseWeightsWhenPrepacking() const;
  374. void setDisplayVmapFallbackWarnings(bool enabled);
  375. bool areVmapFallbackWarningsEnabled() const;
  376. void setWarnOnAccumulateGradStreamMismatch(bool enabled);
  377. bool warnOnAccumulateGradStreamMismatch() const;
  378. bool isDefaultMobileCPUAllocatorSet();
  379. void setDefaultMobileCPUAllocator();
  380. void unsetDefaultMobileCPUAllocator();
  381. bool allowFP16ReductionCPU() const;
  382. void setAllowFP16ReductionCPU(bool /*b*/);
  383. // Preserved for BC
  384. void lazyInitCUDA() {
  385. TORCH_WARN_DEPRECATION(
  386. "lazyInitCUDA is deprecated. Please use lazyInitDevice(at::kCUDA) instead.")
  387. lazyInitDevice(at::kCUDA);
  388. }
  389. void lazyInitHIP() {
  390. TORCH_WARN_DEPRECATION(
  391. "lazyInitHIP is deprecated. Please use lazyInitDevice(at::kHIP) instead.")
  392. lazyInitDevice(at::kHIP);
  393. }
  394. void lazyInitXPU() {
  395. TORCH_WARN_DEPRECATION(
  396. "lazyInitXPU is deprecated. Please use lazyInitDevice(at::kXPU) instead.")
  397. lazyInitDevice(at::kXPU);
  398. }
  399. void lazyInitMTIA() {
  400. TORCH_WARN_DEPRECATION(
  401. "lazyInitMTIA is deprecated. Please use lazyInitDevice(at::kMTIA) instead.")
  402. lazyInitDevice(at::kMTIA);
  403. }
  404. void lazyInitPrivateUse1() {
  405. TORCH_WARN_DEPRECATION(
  406. "lazyInitPrivateUse1 is deprecated. Please use lazyInitDevice(at::kPrivateUse1) instead.")
  407. lazyInitDevice(at::kPrivateUse1);
  408. }
  409. private:
  410. std::array<c10::once_flag, at::COMPILE_TIME_MAX_DEVICE_TYPES> init_;
  411. bool enabled_cudnn = true;
  412. bool deterministic_cudnn = false;
  413. bool deterministic_mkldnn = false;
  414. bool _deterministic_algorithms = false;
  415. bool _deterministic_algorithms_warn_only = false;
  416. bool _deterministic_fill_uninitialized_memory = true;
  417. std::array<at::SDPBackend, at::num_sdp_backends> sdp_priority_order = {
  418. at::SDPBackend::flash_attention,
  419. at::SDPBackend::efficient_attention,
  420. at::SDPBackend::math,
  421. at::SDPBackend::cudnn_attention,
  422. at::SDPBackend::overrideable};
  423. bool enabled_flashSDP = true;
  424. bool enabled_mem_efficientSDP = true;
  425. bool enabled_mathSDP = true;
  426. bool enabled_cudnnSDP = true;
  427. bool enabled_overrideable = true;
  428. bool allow_fp16_bf16_reduction_mathSDP = false;
  429. bool benchmark_cudnn = false;
  430. bool immediate_miopen = false;
  431. Float32MatmulPrecision float32_matmul_precision =
  432. c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
  433. ? at::Float32MatmulPrecision::HIGH
  434. : at::Float32MatmulPrecision::HIGHEST;
  435. int benchmark_limit_cudnn = 10;
  436. bool allow_tf32_cudnn = true;
  437. CuBLASReductionOption allow_fp16_reduction_cublas =
  438. CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
  439. CuBLASReductionOption allow_bf16_reduction_cublas =
  440. CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
  441. bool allow_fp16_accumulation_cublas = false;
  442. std::optional<int32_t> sm_carveout = std::nullopt;
  443. bool enabled_mkldnn = true;
  444. bool allow_tf32_onednn = false;
  445. bool enabled_nnpack = true;
  446. at::LinalgBackend linalg_preferred_backend =
  447. (c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true ||
  448. c10::utils::check_env("TORCH_LINALG_PREFER_HIPSOLVER") == true) // alias
  449. ? at::LinalgBackend::Cusolver
  450. : at::LinalgBackend::Default;
  451. at::BlasBackend blas_preferred_backend =
  452. (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true ||
  453. c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") == true) // alias
  454. ? at::BlasBackend::Cublaslt
  455. : at::BlasBackend::Default;
  456. at::ROCmFABackend rocm_fa_preferred_backend =
  457. c10::utils::check_env("TORCH_ROCM_FA_PREFER_CK") == true
  458. ? at::ROCmFABackend::Ck
  459. : at::ROCmFABackend::Default;
  460. #ifdef C10_MOBILE
  461. bool release_original_weights = true;
  462. #else
  463. bool release_original_weights = false;
  464. #endif
  465. bool display_vmap_fallback_warnings_ = false;
  466. bool warn_on_accumulate_grad_stream_mismatch_ = true;
  467. std::atomic<at::QEngine> quantized_engine = at::QEngine::NoQEngine;
  468. bool enable_sparse_tensor_invariant_checks = false;
  469. bool allow_fp16_reduction_cpu = false;
  470. using Key = std::pair<Float32Backend, Float32Op>;
  471. std::unordered_map<Key, Float32Precision, c10::hash<Key>> fp32_precision = {
  472. {{Float32Backend::GENERIC, Float32Op::ALL}, Float32Precision::NONE},
  473. {{Float32Backend::MKLDNN, Float32Op::ALL}, Float32Precision::NONE},
  474. {{Float32Backend::MKLDNN, Float32Op::CONV}, Float32Precision::NONE},
  475. {{Float32Backend::MKLDNN, Float32Op::RNN}, Float32Precision::NONE},
  476. {{Float32Backend::MKLDNN, Float32Op::MATMUL}, Float32Precision::NONE},
  477. {{Float32Backend::CUDA, Float32Op::ALL}, Float32Precision::NONE},
  478. {{Float32Backend::CUDA, Float32Op::CONV}, Float32Precision::TF32},
  479. {{Float32Backend::CUDA, Float32Op::RNN}, Float32Precision::TF32},
  480. {{Float32Backend::CUDA, Float32Op::MATMUL},
  481. float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
  482. ? Float32Precision::NONE
  483. : Float32Precision::TF32},
  484. };
  485. Allocator* prev_allocator_ptr_{nullptr};
  486. };
  487. TORCH_API Context& globalContext();
  488. inline void init() {
  489. globalContext();
  490. }
  491. TORCH_API Allocator* getCPUAllocator();
  492. inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
  493. Backend p,
  494. ScalarType s) {
  495. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  496. p, s);
  497. }
  498. inline DeprecatedTypeProperties& CPU(ScalarType s) {
  499. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  500. Backend::CPU, s);
  501. }
  502. inline DeprecatedTypeProperties& CUDA(ScalarType s) {
  503. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  504. Backend::CUDA, s);
  505. }
  506. inline DeprecatedTypeProperties& HIP(ScalarType s) {
  507. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  508. Backend::HIP, s);
  509. }
  510. inline DeprecatedTypeProperties& MPS(ScalarType s) {
  511. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  512. Backend::MPS, s);
  513. }
  514. inline bool hasCUDA() {
  515. return globalContext().hasCUDA();
  516. }
  517. inline bool hasMTIA() {
  518. return globalContext().hasMTIA();
  519. }
  520. inline bool hasHIP() {
  521. return globalContext().hasHIP();
  522. }
  523. inline bool hasIPU() {
  524. return globalContext().hasIPU();
  525. }
  526. inline bool hasXLA() {
  527. return globalContext().hasXLA();
  528. }
  529. inline bool hasMPS() {
  530. return globalContext().hasMPS();
  531. }
  532. inline bool hasMAIA() {
  533. return globalContext().hasMAIA();
  534. }
  535. inline bool hasXPU() {
  536. return globalContext().hasXPU();
  537. }
  538. inline bool hasHPU() {
  539. return globalContext().hasHPU();
  540. }
  541. // Despite its name, this function returns the number of *CUDA* GPUs.
  542. inline size_t getNumGPUs() {
  543. // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
  544. // FUNCTION. If you are interested in interrogating the number of
  545. // devices for a specific device type, add that function to the
  546. // relevant library (e.g., similar to at::cuda::device_count())
  547. if (hasCUDA() && hasHIP()) {
  548. TORCH_CHECK(
  549. false,
  550. "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
  551. "to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
  552. "means HIP. Rebuild PyTorch with one or the other disabled.");
  553. } else if (hasCUDA()) {
  554. return detail::getCUDAHooks().deviceCount();
  555. } else if (hasHIP()) {
  556. return detail::getHIPHooks().getNumGPUs();
  557. } else {
  558. return 0;
  559. }
  560. }
  561. inline bool hasOpenMP() {
  562. return globalContext().hasOpenMP();
  563. }
  564. inline bool hasMKL() {
  565. return globalContext().hasMKL();
  566. }
  567. inline bool hasKleidiAI() {
  568. return globalContext().hasKleidiAI();
  569. }
  570. inline bool hasLAPACK() {
  571. return globalContext().hasLAPACK();
  572. }
  573. inline bool hasEigenSparse() {
  574. return globalContext().hasEigenSparse();
  575. }
  576. inline bool hasMAGMA() {
  577. return globalContext().hasMAGMA();
  578. }
  579. inline bool hasMKLDNN() {
  580. return globalContext().hasMKLDNN();
  581. }
  582. inline void manual_seed(uint64_t seed) {
  583. {
  584. auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
  585. // See Note [Acquire lock when using random generators]
  586. std::lock_guard<std::mutex> lock(gen.mutex());
  587. gen.set_current_seed(seed);
  588. }
  589. const auto opt_device_type = at::getAccelerator();
  590. if (!opt_device_type.has_value()) {
  591. return;
  592. }
  593. const auto num_gpus = globalContext()
  594. .getAcceleratorHooksInterface(opt_device_type)
  595. .deviceCount();
  596. for (const auto i : c10::irange(num_gpus)) {
  597. auto gen = globalContext().defaultGenerator(
  598. Device(opt_device_type.value(), static_cast<c10::DeviceIndex>(i)));
  599. {
  600. // See Note [Acquire lock when using random generators]
  601. std::lock_guard<std::mutex> lock(gen.mutex());
  602. gen.set_current_seed(seed);
  603. }
  604. }
  605. }
  606. // When the global flag `allow_tf32` is set to true, cuBLAS handles are
  607. // automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
  608. // For some operators, such as addmv, TF32 offers no performance improvement
  609. // but causes precision loss. To help this case, this class implements
  610. // a RAII guard that can be used to quickly disable TF32 within its scope.
  611. //
  612. // Usage:
  613. // NoTF32Guard disable_tf32;
  614. struct TORCH_API NoTF32Guard {
  615. NoTF32Guard();
  616. NoTF32Guard(NoTF32Guard&& other) = delete;
  617. NoTF32Guard(const NoTF32Guard&) = delete;
  618. NoTF32Guard& operator=(const NoTF32Guard&) = delete;
  619. NoTF32Guard& operator=(NoTF32Guard&&) = delete;
  620. ~NoTF32Guard();
  621. static bool should_disable_tf32();
  622. private:
  623. bool changed = false;
  624. };
  625. struct TORCH_API ROCmBackwardPassGuard {
  626. ROCmBackwardPassGuard();
  627. ROCmBackwardPassGuard(ROCmBackwardPassGuard&& other) = delete;
  628. ROCmBackwardPassGuard(const ROCmBackwardPassGuard&) = delete;
  629. ROCmBackwardPassGuard& operator=(const ROCmBackwardPassGuard&) = delete;
  630. ROCmBackwardPassGuard& operator=(ROCmBackwardPassGuard&&) = delete;
  631. ~ROCmBackwardPassGuard();
  632. static bool is_backward_pass();
  633. };
  634. } // namespace at
  635. #else
  636. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  637. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)