Context.h 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694
  1. #pragma once
  2. #include <ATen/BlasBackend.h>
  3. #include <ATen/CPUGeneratorImpl.h>
  4. #include <ATen/DeviceAccelerator.h>
  5. #include <ATen/LinalgBackend.h>
  6. #include <ATen/ROCmFABackend.h>
  7. #include <ATen/SDPBackend.h>
  8. #include <ATen/core/ATenGeneral.h>
  9. #include <ATen/core/DeprecatedTypeProperties.h>
  10. #include <ATen/core/Generator.h>
  11. #include <ATen/core/LegacyTypeDispatch.h>
  12. #include <ATen/detail/AcceleratorHooksInterface.h>
  13. #include <ATen/detail/CUDAHooksInterface.h>
  14. #include <ATen/detail/HIPHooksInterface.h>
  15. #include <ATen/detail/HPUHooksInterface.h>
  16. #include <ATen/detail/IPUHooksInterface.h>
  17. #include <ATen/detail/MAIAHooksInterface.h>
  18. #include <ATen/detail/MPSHooksInterface.h>
  19. #include <ATen/detail/MTIAHooksInterface.h>
  20. #include <ATen/detail/PrivateUse1HooksInterface.h>
  21. #include <ATen/detail/XPUHooksInterface.h>
  22. #include <c10/core/QEngine.h>
  23. #include <c10/core/impl/DeviceGuardImplInterface.h>
  24. #include <c10/util/CallOnce.h>
  25. #include <c10/util/Exception.h>
  26. #include <c10/util/env.h>
  27. #include <c10/util/irange.h>
  28. #include <cstdint>
  29. #include <map>
  30. #include <mutex>
  31. namespace at {
  32. class Tensor;
  33. enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
  34. class TORCH_API Context {
  35. public:
  36. Context();
  37. const Generator& defaultGenerator(Device device) {
  38. c10::DeviceType device_type = device.type();
  39. lazyInitDevice(device_type);
  40. if (device_type == at::kCPU) {
  41. return at::detail::getDefaultCPUGenerator();
  42. } else {
  43. return getAcceleratorHooksInterface(device_type)
  44. .getDefaultGenerator(device.index());
  45. }
  46. }
  47. const AcceleratorHooksInterface& getAcceleratorHooksInterface(
  48. std::optional<c10::DeviceType> opt_device_type = std::nullopt) {
  49. if (!opt_device_type.has_value()) {
  50. opt_device_type = at::getAccelerator(true);
  51. }
  52. if (opt_device_type == at::kCUDA) {
  53. return at::detail::getCUDAHooks();
  54. } else if (opt_device_type == at::kXPU) {
  55. return at::detail::getXPUHooks();
  56. } else if (opt_device_type == at::kMPS) {
  57. return at::detail::getMPSHooks();
  58. } else if (opt_device_type == at::kPrivateUse1) {
  59. return at::detail::getPrivateUse1Hooks();
  60. } else if (opt_device_type == at::kMTIA) {
  61. return at::detail::getMTIAHooks();
  62. } else if (opt_device_type == at::kHIP) {
  63. return at::detail::getHIPHooks();
  64. } else if (opt_device_type == at::kHPU) {
  65. return at::detail::getHPUHooks();
  66. } else {
  67. TORCH_CHECK(
  68. false,
  69. opt_device_type.has_value()
  70. ? c10::DeviceTypeName(opt_device_type.value())
  71. : "None",
  72. " device type not an accelerator.");
  73. }
  74. }
  75. Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
  76. lazyInitDevice(device_type);
  77. if (device_type == at::kCPU) {
  78. return c10::DeviceType::CPU;
  79. } else {
  80. return getAcceleratorHooksInterface(device_type).getDeviceFromPtr(data);
  81. }
  82. }
  83. bool isPinnedPtr(
  84. const void* data,
  85. std::optional<c10::DeviceType> device_type = std::nullopt) {
  86. auto opt_device_type =
  87. device_type.has_value() ? device_type : at::getAccelerator();
  88. if (!opt_device_type.has_value() || // there is no accelerator
  89. !at::isAccelerator(
  90. opt_device_type.value())) { // passed device not an accelerator
  91. return false;
  92. }
  93. if (!init_[static_cast<int8_t>(opt_device_type.value())].test_once()) {
  94. // If the device is not initialized, no pointer can be pinned for it
  95. return false;
  96. }
  97. return getAcceleratorHooksInterface(opt_device_type).isPinnedPtr(data);
  98. }
  99. Allocator* getPinnedMemoryAllocator(
  100. std::optional<c10::DeviceType> device_type = std::nullopt) {
  101. auto opt_device_type =
  102. device_type.has_value() ? device_type : at::getAccelerator();
  103. if (opt_device_type) {
  104. lazyInitDevice(opt_device_type.value());
  105. }
  106. return getAcceleratorHooksInterface(device_type).getPinnedMemoryAllocator();
  107. }
  108. void lazyInitDevice(c10::DeviceType device_type) {
  109. if (device_type != at::kCPU) {
  110. c10::call_once(init_[static_cast<int8_t>(device_type)], [&] {
  111. getAcceleratorHooksInterface(device_type).init();
  112. });
  113. }
  114. }
  // Feature queries that take no arguments and return bool. They are only
  // declared here; their definitions are not part of this header.
  // NOTE(review): presumably each reports whether the named library/backend
  // was compiled in or is usable - confirm against the definitions.
  115. static bool hasOpenMP();
  116. static bool hasMKL();
  117. static bool hasKleidiAI();
  118. static bool hasLAPACK();
  119. static bool hasMKLDNN();
  120. static bool ckSupported();
  121. static bool hasEigenSparse();
  122. static bool hasMAGMA() {
  123. return detail::getCUDAHooks().hasMAGMA();
  124. }
  125. static bool hasCUDA() {
  126. return detail::getCUDAHooks().hasCUDA();
  127. }
  128. static bool hasMTIA() {
  129. return detail::getMTIAHooks().hasMTIA();
  130. }
  131. static bool hasCUDART() {
  132. return detail::getCUDAHooks().hasCUDART();
  133. }
  134. static long versionCUDART() {
  135. return detail::getCUDAHooks().versionCUDART();
  136. }
  137. static bool hasCuDNN() {
  138. return detail::getCUDAHooks().hasCuDNN();
  139. }
  140. static long versionCuDNN() {
  141. return detail::getCUDAHooks().versionCuDNN();
  142. }
  143. static long versionRuntimeCuDNN() {
  144. return detail::getCUDAHooks().versionRuntimeCuDNN();
  145. }
  146. static long versionCuDNNFrontend() {
  147. return detail::getCUDAHooks().versionCuDNNFrontend();
  148. }
  149. static bool hasCuSOLVER() {
  150. return detail::getCUDAHooks().hasCuSOLVER();
  151. }
  152. static bool hasCuBLASLt() {
  153. return detail::getCUDAHooks().hasCuBLASLt();
  154. }
  155. static bool hasROCM() {
  156. return detail::getCUDAHooks().hasROCM();
  157. }
  158. static bool hasCKSDPA() {
  159. return detail::getCUDAHooks().hasCKSDPA();
  160. }
  161. static bool hasCKGEMM() {
  162. return detail::getCUDAHooks().hasCKGEMM();
  163. }
  164. static bool hasHIP() {
  165. return detail::getHIPHooks().hasHIP();
  166. }
  167. static bool hasMPS() {
  168. return detail::getMPSHooks().hasMPS();
  169. }
  170. static bool hasIPU() {
  171. return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
  172. }
  173. static bool hasXLA() {
  174. return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
  175. }
  176. static bool hasXPU() {
  177. return detail::getXPUHooks().hasXPU();
  178. }
  179. static bool hasLazy() {
  180. return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy);
  181. }
  182. static bool hasMAIA() {
  183. return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA);
  184. }
  185. static bool hasHPU() {
  186. return detail::getHPUHooks().hasHPU();
  187. }
  188. static const at::cuda::NVRTC& getNVRTC() {
  189. return detail::getCUDAHooks().nvrtc();
  190. }
  // Declared only; defined outside this header.
  // NOTE(review): presumably switches flushing of denormal floats on/off and
  // returns whether the request could be honored - confirm.
  191. static bool setFlushDenormal(bool on);
  192. // NB: This method is *purely* whether or not a user requested
  193. // that CuDNN was enabled, it doesn't actually say anything about
  194. // whether or not CuDNN is actually usable. Use cudnn_is_acceptable
  195. // to test this instead
  196. bool userEnabledCuDNN() const;
  197. void setUserEnabledCuDNN(bool e);
  // Getter/setter pairs for per-backend user preferences. Only declarations
  // are visible here. NOTE(review): presumably each pair reads/writes the
  // similarly named private flag (enabled_mkldnn, benchmark_cudnn,
  // benchmark_limit_cudnn, immediate_miopen, deterministic_cudnn,
  // deterministic_mkldnn, enabled_nnpack) - confirm in the definitions.
  198. bool userEnabledMkldnn() const;
  199. void setUserEnabledMkldnn(bool e);
  200. bool benchmarkCuDNN() const;
  201. void setBenchmarkCuDNN(bool);
  202. int benchmarkLimitCuDNN() const;
  203. void setBenchmarkLimitCuDNN(int);
  204. bool immediateMiopen() const;
  205. void setImmediateMiopen(bool);
  206. bool deterministicCuDNN() const;
  207. void setDeterministicCuDNN(bool);
  208. bool deterministicMkldnn() const;
  209. void setDeterministicMkldnn(bool);
  210. bool userEnabledNNPACK() const;
  211. void setUserEnabledNNPACK(bool e);
  212. // Note [Disabling Fused SDP Kernels]
  213. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  214. // Flash and Memory Efficient SDP kernels are enabled by default.
  215. // However, they can be disabled by calling
  216. // at::globalContext().setSDPUseFlash(false) and
  216. // at::globalContext().setSDPUseMemEfficient(false), respectively.
  217. // This is useful for debugging purposes. For example, if you want to
  218. // compare the performance of the flash SDP kernels with the unfused
  219. // kernel, you can disable the flash SDP kernels. By disabling
  220. // the math SDP kernel, you can force your code to use flash kernels.
  221. // The math SDP kernel can be disabled by calling
  222. // at::globalContext().setSDPUseMath(false).
  // Scaled-dot-product (SDP) backend selection. Only declarations are visible
  // here. The priority order is exposed as an array of at::SDPBackend with
  // at::num_sdp_backends entries; the setter takes the order as int64_t values.
  // NOTE(review): presumably those integers are SDPBackend enum values and the
  // accessors operate on the private `sdp_priority_order` member - confirm.
  223. void setSDPPriorityOrder(const std::vector<int64_t>& order);
  224. std::array<at::SDPBackend, at::num_sdp_backends> sDPPriorityOrder();
  // Per-kernel enable switches. Note the naming asymmetry: setters are
  // setSDPUse<Kind>(bool), getters are userEnabled<Kind>SDP().
  225. void setSDPUseFlash(bool);
  226. bool userEnabledFlashSDP() const;
  227. void setSDPUseMemEfficient(bool);
  228. bool userEnabledMemEfficientSDP() const;
  229. void setSDPUseMath(bool);
  230. bool userEnabledMathSDP() const;
  231. void setSDPUseCuDNN(bool);
  232. bool userEnabledCuDNNSDP() const;
  233. void setAllowFP16BF16ReductionMathSDP(bool);
  234. bool allowFP16BF16ReductionMathSDP() const;
  235. void setSDPUseOverrideable(bool);
  236. bool userEnabledOverrideableSDP() const;
  // Preferred library backends for linear algebra, BLAS and ROCm flash
  // attention. The BLAS and ROCm FA getters are not const-qualified, unlike
  // linalgPreferredBackend().
  237. at::LinalgBackend linalgPreferredBackend() const;
  238. void setLinalgPreferredBackend(at::LinalgBackend);
  239. at::BlasBackend blasPreferredBackend();
  240. void setBlasPreferredBackend(at::BlasBackend);
  241. at::ROCmFABackend getROCmFAPreferredBackend();
  242. void setROCmFAPreferredBackend(at::ROCmFABackend);
  243. // Note [Enabling Deterministic Operations]
  244. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  245. // Operations in PyTorch that normally act nondeterministically, but have an
  246. // alternate deterministic implementation, should satisfy the following
  247. // requirements:
  248. //
  249. // * Include this comment: "See Note [Enabling Deterministic Operations]"
  250. //
  251. // * Check the value of `at::globalContext().deterministicAlgorithms()` to
  252. // toggle
  253. // between nondeterministic and deterministic implementations.
  254. //
  255. // * Have an entry in the list of PyTorch operations that toggle between
  256. // nondeterministic
  257. // and deterministic implementations, in the docstring of
  258. // `use_deterministic_algorithms()` in torch/__init__.py
  259. //
  260. // `example_func()` below shows an example of toggling between
  261. // nondeterministic and deterministic implementations:
  262. //
  263. // void example_func() {
  264. // // See Note [Enabling Deterministic Operations]
  265. // if (at::globalContext().deterministicAlgorithms()) {
  266. // example_func_deterministic();
  267. // } else {
  268. // example_func_nondeterministic();
  269. // }
  270. // }
  // Accessors for the deterministic-algorithms settings. Only declarations are
  // visible here.
  // NOTE(review): the two unnamed bools of setDeterministicAlgorithms are
  // presumably (enabled, warn_only), matching the private members
  // `_deterministic_algorithms` and `_deterministic_algorithms_warn_only` -
  // confirm the argument order in the definition.
  271. bool deterministicAlgorithms() const;
  272. bool deterministicAlgorithmsWarnOnly() const;
  273. void setDeterministicAlgorithms(bool, bool);
  274. bool deterministicFillUninitializedMemory() const;
  275. void setDeterministicFillUninitializedMemory(bool);
  276. // Note [Writing Nondeterministic Operations]
  277. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  278. // Operations in PyTorch that act nondeterministically and do not have an
  279. // alternate deterministic implementation should satisfy the following
  280. // requirements:
  281. //
  282. // * Include this comment: "See Note [Writing Nondeterministic Operations]"
  283. //
  284. // * Include a comment explaining why the operation is nondeterministic.
  285. //
  286. // * Throw an error when `Context::deterministicAlgorithms()` is true. Most
  287. // of the time, this should be accomplished by calling
  288. // `at::globalContext().alertNotDeterministic()`. However, if the
  289. // nondeterministic behavior is caused by the CuBLAS workspace
  290. // configuration in CUDA >= 10.2,
  291. // `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
  292. // called instead (in this case, a comment explaining why the operation is
  293. // nondeterministic is not necessary). See below for details on these
  294. // methods.
  295. //
  296. // * Have an entry in the list of nondeterministic PyTorch operations in the
  297. // docstring of `use_deterministic_algorithms()` in torch/__init__.py
  298. //
  299. // * Have a test function in `test/test_torch.py` whose name begins with
  300. // `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
  301. // configuration is the reason for nondeterminism, the operation should be
  302. // included in the `test_cublas_config_nondeterministic_alert` test. Any new
  303. // tests should ideally follow a pattern similar to the existing ones.
  304. //
  305. // `example_func()` below shows an example of the comments and error-throwing
  306. // code for a nondeterministic operation:
  307. //
  308. // void example_func() {
  309. // // See Note [Writing Nondeterministic Operations]
  310. // // Nondeterministic because <reason>
  311. // at::globalContext().alertNotDeterministic("example_func");
  312. // ...
  313. // }
  314. // Throws an error if `Context::deterministicAlgorithms()` is true
  // `caller` is the name reported for the offending operation (see the
  // example in Note [Writing Nondeterministic Operations]).
  315. static void alertNotDeterministic(std::string_view const& caller);
  316. // Throws an error if `Context::deterministicAlgorithms()` is true, CUDA
  317. // >= 10.2, and CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or
  318. // ":4096:8". For more details:
  319. // https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
  320. void alertCuBLASConfigNotDeterministic() const;
  // Float32 precision controls. Precision is passed around as strings here;
  // the enum form is returned by float32MatmulPrecision().
  // NOTE(review): the accepted strings are not visible in this header. The
  // private `fp32_precision` table uses backends "generic"/"mkldnn"/"cuda",
  // ops "matmul"/"conv"/"rnn"/"all" and values "none"/"tf32" - confirm that
  // these are the values setFloat32Precision/float32Precision expect.
  321. void setFloat32MatmulPrecision(const std::string& s);
  322. void setFloat32Precision(
  323. const std::string& backend,
  324. const std::string& op,
  325. const std::string& s);
  // `op` defaults to an empty string.
  326. bool allowTF32CuDNN(const std::string& op = std::string()) const;
  327. void setAllowTF32CuDNN(bool);
  328. bool allowTF32OneDNN() const;
  329. void setAllowTF32OneDNN(bool);
  330. bool allowTF32CuBLAS() const;
  331. void setAllowTF32CuBLAS(bool);
  332. Float32MatmulPrecision float32MatmulPrecision() const;
  333. std::string float32Precision(
  334. const std::string& backend,
  335. const std::string& op) const;
  // Reduced-precision reduction/accumulation switches for cuBLAS.
  336. bool allowFP16ReductionCuBLAS() const;
  337. void setAllowFP16ReductionCuBLAS(bool);
  338. bool allowBF16ReductionCuBLAS() const;
  339. void setAllowBF16ReductionCuBLAS(bool);
  340. bool allowFP16AccumulationCuBLAS() const;
  341. void setAllowFP16AccumulationCuBLAS(bool);
  342. // Matmuls can use a so-called "persistent" kernel which launches one CUDA
  343. // block for each SM on the GPU, and each block then iterates over multiple
  344. // output tiles. This enables software pipelining that hides the begin/end
  345. // latencies (e.g., epilogue), especially when only one tile fits per SM.
  346. // However, if some SMs are busy (e.g., with a background NCCL kernel), the
  347. // matmul's blocks will be scheduled in two waves and, in the absence of some
  348. // smart load balancing, the kernel will take twice as long. This flag lets
  349. // matmuls target only a subset of the SMs, so they can be fully scheduled
  350. // even next to a comms kernel, and only be a few percent slower.
  // Experimental SM carveout accessors; the value is optional, and the private
  // member `sm_carveout` defaults to std::nullopt.
  351. std::optional<int32_t> _SMCarveout_EXPERIMENTAL() const;
  352. void _setSMCarveout_EXPERIMENTAL(std::optional<int32_t>);
  // Quantization engine selection. supportedQEngines() returns a reference to
  // a vector of engines; only the declarations are visible here.
  353. at::QEngine qEngine() const;
  354. void setQEngine(at::QEngine e);
  355. static const std::vector<at::QEngine>& supportedQEngines();
  356. static bool isXNNPACKAvailable();
  357. void setCheckSparseTensorInvariants(bool e);
  358. bool checkSparseTensorInvariants() const;
  359. // This method is used to release the original weight after pre-packing.
  360. // It should be called once before loading/running the model.
  361. // NB: By default it is set to true for mobile builds.
  362. void setReleaseWeightsWhenPrepacking(bool e);
  363. bool releaseWeightsWhenPrepacking() const;
  364. void setDisplayVmapFallbackWarnings(bool enabled);
  365. bool areVmapFallbackWarningsEnabled() const;
  // NOTE(review): presumably these swap the CPU allocator for a mobile one and
  // restore the previous allocator kept in `prev_allocator_ptr_` - confirm.
  366. bool isDefaultMobileCPUAllocatorSet();
  367. void setDefaultMobileCPUAllocator();
  368. void unsetDefaultMobileCPUAllocator();
  369. bool allowFP16ReductionCPU() const;
  370. void setAllowFP16ReductionCPU(bool);
  371. // Preserved for BC
  // Each wrapper below issues a deprecation warning through
  // TORCH_WARN_DEPRECATION and then forwards to lazyInitDevice() with the
  // matching device type. New code should call lazyInitDevice() directly.
  // NOTE(review): the macro invocations carry no trailing semicolon, so these
  // bodies rely on the macro expanding to a complete statement - confirm
  // before reformatting.
  372. void lazyInitCUDA() {
  373. TORCH_WARN_DEPRECATION(
  374. "lazyInitCUDA is deprecated. Please use lazyInitDevice(at::kCUDA) instead.")
  375. lazyInitDevice(at::kCUDA);
  376. }
  377. void lazyInitHIP() {
  378. TORCH_WARN_DEPRECATION(
  379. "lazyInitHIP is deprecated. Please use lazyInitDevice(at::kHIP) instead.")
  380. lazyInitDevice(at::kHIP);
  381. }
  382. void lazyInitXPU() {
  383. TORCH_WARN_DEPRECATION(
  384. "lazyInitXPU is deprecated. Please use lazyInitDevice(at::kXPU) instead.")
  385. lazyInitDevice(at::kXPU);
  386. }
  387. void lazyInitMTIA() {
  388. TORCH_WARN_DEPRECATION(
  389. "lazyInitMTIA is deprecated. Please use lazyInitDevice(at::kMTIA) instead.")
  390. lazyInitDevice(at::kMTIA);
  391. }
  392. void lazyInitPrivateUse1() {
  393. TORCH_WARN_DEPRECATION(
  394. "lazyInitPrivateUse1 is deprecated. Please use lazyInitDevice(at::kPrivateUse1) instead.")
  395. lazyInitDevice(at::kPrivateUse1);
  396. }
  397. private:
  // Private state of Context. Most members are plain flags with in-class
  // defaults; a few defaults are computed from environment checks.
  398. static bool checkCuBLASConfigDeterministic();
  // One once-flag per device type, indexed by the DeviceType value cast to
  // int8_t (see lazyInitDevice and isPinnedPtr).
  399. std::array<c10::once_flag, at::COMPILE_TIME_MAX_DEVICE_TYPES> init_;
  400. bool enabled_cudnn = true;
  401. bool deterministic_cudnn = false;
  402. bool deterministic_mkldnn = false;
  403. bool _deterministic_algorithms = false;
  404. bool _deterministic_algorithms_warn_only = false;
  405. bool _deterministic_fill_uninitialized_memory = true;
  // Default SDP backend priority, highest priority first (presumably - the
  // consumer of this order is not visible here).
  406. std::array<at::SDPBackend, at::num_sdp_backends> sdp_priority_order = {
  407. at::SDPBackend::flash_attention,
  408. at::SDPBackend::efficient_attention,
  409. at::SDPBackend::math,
  410. at::SDPBackend::cudnn_attention,
  411. at::SDPBackend::overrideable};
  412. bool enabled_flashSDP = true;
  413. bool enabled_mem_efficientSDP = true;
  414. bool enabled_mathSDP = true;
  415. bool enabled_cudnnSDP = true;
  416. bool enabled_overrideable = true;
  417. bool allow_fp16_bf16_reduction_mathSDP = false;
  418. bool benchmark_cudnn = false;
  419. bool immediate_miopen = false;
  // HIGH when the TORCH_ALLOW_TF32_CUBLAS_OVERRIDE environment check compares
  // equal to true, HIGHEST otherwise.
  420. Float32MatmulPrecision float32_matmul_precision =
  421. c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
  422. ? at::Float32MatmulPrecision::HIGH
  423. : at::Float32MatmulPrecision::HIGHEST;
  424. int benchmark_limit_cudnn = 10;
  425. bool allow_tf32_cudnn = true;
  426. bool allow_fp16_reduction_cublas = true;
  427. bool allow_bf16_reduction_cublas = true;
  428. bool allow_fp16_accumulation_cublas = false;
  429. std::optional<int32_t> sm_carveout = std::nullopt;
  430. bool enabled_mkldnn = true;
  431. bool allow_tf32_onednn = false;
  432. bool enabled_nnpack = true;
  // Cusolver is preferred when either TORCH_LINALG_PREFER_CUSOLVER or its
  // alias TORCH_LINALG_PREFER_HIPSOLVER checks as true; otherwise Default.
  433. at::LinalgBackend linalg_preferred_backend =
  434. (c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true ||
  435. c10::utils::check_env("TORCH_LINALG_PREFER_HIPSOLVER") == true) // alias
  436. ? at::LinalgBackend::Cusolver
  437. : at::LinalgBackend::Default;
  // Cublaslt is preferred when either TORCH_BLAS_PREFER_CUBLASLT or its alias
  // TORCH_BLAS_PREFER_HIPBLASLT checks as true; otherwise Default.
  438. at::BlasBackend blas_preferred_backend =
  439. (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true ||
  440. c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") == true) // alias
  441. ? at::BlasBackend::Cublaslt
  442. : at::BlasBackend::Default;
  // Ck is preferred when TORCH_ROCM_FA_PREFER_CK checks as true.
  443. at::ROCmFABackend rocm_fa_preferred_backend =
  444. c10::utils::check_env("TORCH_ROCM_FA_PREFER_CK") == true
  445. ? at::ROCmFABackend::Ck
  446. : at::ROCmFABackend::Default;
  // Defaults to true only when C10_MOBILE is defined.
  447. #ifdef C10_MOBILE
  448. bool release_original_weights = true;
  449. #else
  450. bool release_original_weights = false;
  451. #endif
  452. bool display_vmap_fallback_warnings_ = false;
  // The only atomic member of this class.
  453. std::atomic<at::QEngine> quantized_engine = at::QEngine::NoQEngine;
  454. bool enable_sparse_tensor_invariant_checks = false;
  455. bool allow_fp16_reduction_cpu = false;
  // Precision table: backend name -> op name -> precision string. The
  // cuda/matmul default is derived from float32_matmul_precision, which is
  // declared earlier in the class and is therefore initialized first.
  456. std::map<std::string, std::map<std::string, std::string>> fp32_precision = {
  457. {"generic", {{"all", "none"}}},
  458. {"mkldnn",
  459. {{"matmul", "none"},
  460. {"conv", "none"},
  461. {"rnn", "none"},
  462. {"all", "none"}}},
  463. {"cuda",
  464. {{"matmul",
  465. float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
  466. ? "none"
  467. : "tf32"},
  468. {"conv", "tf32"},
  469. {"rnn", "tf32"},
  470. {"all", "none"}}},
  471. };
  // NOTE(review): presumably holds the allocator that was active before
  // setDefaultMobileCPUAllocator() replaced it - confirm in the definition.
  472. Allocator* prev_allocator_ptr_{nullptr};
  473. };
  474. TORCH_API Context& globalContext();
  475. inline void init() {
  476. globalContext();
  477. }
  478. TORCH_API Allocator* getCPUAllocator();
  479. inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
  480. Backend p,
  481. ScalarType s) {
  482. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  483. p, s);
  484. }
  485. inline DeprecatedTypeProperties& CPU(ScalarType s) {
  486. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  487. Backend::CPU, s);
  488. }
  489. inline DeprecatedTypeProperties& CUDA(ScalarType s) {
  490. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  491. Backend::CUDA, s);
  492. }
  493. inline DeprecatedTypeProperties& HIP(ScalarType s) {
  494. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  495. Backend::HIP, s);
  496. }
  497. inline DeprecatedTypeProperties& MPS(ScalarType s) {
  498. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  499. Backend::MPS, s);
  500. }
  501. inline bool hasCUDA() {
  502. return globalContext().hasCUDA();
  503. }
  504. inline bool hasMTIA() {
  505. return globalContext().hasMTIA();
  506. }
  507. inline bool hasHIP() {
  508. return globalContext().hasHIP();
  509. }
  510. inline bool hasIPU() {
  511. return globalContext().hasIPU();
  512. }
  513. inline bool hasXLA() {
  514. return globalContext().hasXLA();
  515. }
  516. inline bool hasMPS() {
  517. return globalContext().hasMPS();
  518. }
  519. inline bool hasMAIA() {
  520. return globalContext().hasMAIA();
  521. }
  522. inline bool hasXPU() {
  523. return globalContext().hasXPU();
  524. }
  525. inline bool hasHPU() {
  526. return globalContext().hasHPU();
  527. }
  528. // Despite its name, this function returns the number of *CUDA* GPUs.
  529. inline size_t getNumGPUs() {
  530. // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
  531. // FUNCTION. If you are interested in interrogating the number of
  532. // devices for a specific device type, add that function to the
  533. // relevant library (e.g., similar to at::cuda::device_count())
  534. if (hasCUDA() && hasHIP()) {
  535. TORCH_CHECK(
  536. false,
  537. "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
  538. "to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
  539. "means HIP. Rebuild PyTorch with one or the other disabled.");
  540. } else if (hasCUDA()) {
  541. return detail::getCUDAHooks().deviceCount();
  542. } else if (hasHIP()) {
  543. return detail::getHIPHooks().getNumGPUs();
  544. } else {
  545. return 0;
  546. }
  547. }
  548. inline bool hasOpenMP() {
  549. return globalContext().hasOpenMP();
  550. }
  551. inline bool hasMKL() {
  552. return globalContext().hasMKL();
  553. }
  554. inline bool hasKleidiAI() {
  555. return globalContext().hasKleidiAI();
  556. }
  557. inline bool hasLAPACK() {
  558. return globalContext().hasLAPACK();
  559. }
  560. inline bool hasEigenSparse() {
  561. return globalContext().hasEigenSparse();
  562. }
  563. inline bool hasMAGMA() {
  564. return globalContext().hasMAGMA();
  565. }
  566. inline bool hasMKLDNN() {
  567. return globalContext().hasMKLDNN();
  568. }
  569. inline void manual_seed(uint64_t seed) {
  570. {
  571. auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
  572. // See Note [Acquire lock when using random generators]
  573. std::lock_guard<std::mutex> lock(gen.mutex());
  574. gen.set_current_seed(seed);
  575. }
  576. const auto opt_device_type = at::getAccelerator();
  577. if (!opt_device_type.has_value()) {
  578. return;
  579. }
  580. const auto num_gpus = globalContext()
  581. .getAcceleratorHooksInterface(opt_device_type)
  582. .deviceCount();
  583. for (const auto i : c10::irange(num_gpus)) {
  584. auto gen = globalContext().defaultGenerator(
  585. Device(opt_device_type.value(), static_cast<c10::DeviceIndex>(i)));
  586. {
  587. // See Note [Acquire lock when using random generators]
  588. std::lock_guard<std::mutex> lock(gen.mutex());
  589. gen.set_current_seed(seed);
  590. }
  591. }
  592. }
  593. // When the global flag `allow_tf32` is set to true, cuBLAS handles are
  594. // automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
  595. // For some operators, such as addmv, TF32 offers no performance improvement
  596. // but causes precision loss. To help this case, this class implements
  597. // a RAII guard that can be used to quickly disable TF32 within its scope.
  598. //
  599. // Usage:
  600. // NoTF32Guard disable_tf32;
  601. struct TORCH_API NoTF32Guard {
  602. NoTF32Guard();
  603. NoTF32Guard(NoTF32Guard&& other) = delete;
  604. NoTF32Guard(const NoTF32Guard&) = delete;
  605. NoTF32Guard& operator=(const NoTF32Guard&) = delete;
  606. NoTF32Guard& operator=(NoTF32Guard&&) = delete;
  607. ~NoTF32Guard();
  608. static bool should_disable_tf32();
  609. private:
  610. bool changed = false;
  611. };
  612. struct TORCH_API ROCmBackwardPassGuard {
  613. ROCmBackwardPassGuard();
  614. ROCmBackwardPassGuard(ROCmBackwardPassGuard&& other) = delete;
  615. ROCmBackwardPassGuard(const ROCmBackwardPassGuard&) = delete;
  616. ROCmBackwardPassGuard& operator=(const ROCmBackwardPassGuard&) = delete;
  617. ROCmBackwardPassGuard& operator=(ROCmBackwardPassGuard&&) = delete;
  618. ~ROCmBackwardPassGuard();
  619. static bool is_backward_pass();
  620. };
  621. } // namespace at