DispatchStub.h 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. #pragma once
  2. #include <c10/core/DeviceType.h>
  3. #include <c10/macros/Macros.h>
  4. #include <atomic>
  5. #include <utility>
  6. #include <variant>
  7. // Implements instruction set specific function dispatch.
  8. //
  9. // Kernels that may make use of specialized instruction sets (e.g. AVX2) are
  10. // compiled multiple times with different compiler flags (e.g. -mavx2). A
  11. // DispatchStub contains a table of function pointers for a kernel. At runtime,
  12. // the fastest available kernel is chosen based on the features reported by
  13. // cpuinfo.
  14. //
  15. // Example:
  16. //
  17. // In native/MyKernel.h:
  18. // using fn_type = void(*)(const Tensor& x);
  19. // DECLARE_DISPATCH(fn_type, stub)
  20. //
  21. // In native/MyKernel.cpp
  22. // DEFINE_DISPATCH(stub);
  23. //
  24. // In native/cpu/MyKernel.cpp:
  25. // namespace {
  26. // // use anonymous namespace so that different cpu versions won't conflict
  27. // void kernel(const Tensor& x) { ... }
  28. // }
  29. // REGISTER_DISPATCH(stub, &kernel);
  30. //
  31. // To call:
  32. // stub(kCPU, tensor);
  33. //
  34. // TODO: CPU instruction set selection should be folded into whatever
  35. // the main dispatch mechanism is.
  36. //
  37. // Supported device types for registration:
  38. // - CPU: Central Processing Unit
  39. // - CUDA: NVIDIA GPUs
  40. // - HIP: AMD GPUs
  41. // - MPS: Apple Silicon GPUs (Metal Performance Shaders)
  42. // - MTIA: Meta Training and Inference Devices
  43. // - XPU: Intel GPUs
  44. // - HPU: Reserved for HPU (Intel Gaudi) device types
  45. // - PrivateUse1: Reserved for private/custom device types
  46. //
// If you want to update the list of supported devices, add a new dispatch_ptr
// member to DispatchStubImpl (declared below) and update the get_call_ptr
// switch. You will also need to update the inlined list in
// `is_device_supported`.
  50. //
  51. //
  52. // ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere
  53. C10_CLANG_DIAGNOSTIC_PUSH()
  54. C10_CLANG_DIAGNOSTIC_IGNORE("-Wundefined-var-template")
  55. namespace at::native {
  56. enum class CPUCapability {
  57. DEFAULT = 0,
  58. #if defined(HAVE_VSX_CPU_DEFINITION)
  59. VSX = 1,
  60. #elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
  61. ZVECTOR = 1,
  62. #elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
  63. SVE256 = 1,
  64. #else
  65. AVX2 = 1,
  66. AVX512 = 2,
  67. #endif
  68. NUM_OPTIONS
  69. };
  70. // Enum for error types
  71. enum class ErrorType {
  72. MissingDeviceKernel,
  73. DeviceNotSupported
  74. };
  75. // Alias for the return type using std::variant
  76. using DispatchResult = std::variant<void*, ErrorType>;
  77. CPUCapability get_cpu_capability();
  78. template <typename FnPtr, typename T>
  79. struct DispatchStub;
  80. /**
  81. * The sole purpose of this class is to outline methods that don't need to be
  82. * specialized or otherwise inlined and duplicated (by the compiler due to
  83. * template expansion), since it causes size bloat if there are a significant
  84. * number of specialization of the DispatchStub<> class.
  85. */
  86. struct TORCH_API DispatchStubImpl {
  87. // The DispatchStubImpl::try_get_call_ptr() method is used to get the call
  88. // pointer for a given device type. If the call pointer is not found,
  89. // DispatchStubImpl::try_get_call_ptr() returns an ErrorType.
  90. // The main difference between try_get_call_ptr() and get_call_ptr() is that
  91. // try_get_call_ptr() will return the ErrorType and not raise an exception.
  92. DispatchResult try_get_call_ptr(
  93. c10::DeviceType device_type
  94. , void *DEFAULT
  95. #ifdef HAVE_AVX512_CPU_DEFINITION
  96. , void *AVX512
  97. #endif
  98. #ifdef HAVE_AVX2_CPU_DEFINITION
  99. , void *AVX2
  100. #endif
  101. #ifdef HAVE_VSX_CPU_DEFINITION
  102. , void *VSX
  103. #endif
  104. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  105. , void *ZVECTOR
  106. #endif
  107. #ifdef HAVE_SVE256_CPU_DEFINITION
  108. , void *SVE256
  109. #endif
  110. );
  111. // Analogous to try_get_call_ptr(), but it will return the ErrorType and not
  112. // raise an exception.
  113. DispatchResult try_choose_cpu_impl(
  114. void *DEFAULT
  115. #ifdef HAVE_AVX512_CPU_DEFINITION
  116. , void *AVX512
  117. #endif
  118. #ifdef HAVE_AVX2_CPU_DEFINITION
  119. , void *AVX2
  120. #endif
  121. #ifdef HAVE_VSX_CPU_DEFINITION
  122. , void *VSX
  123. #endif
  124. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  125. , void *ZVECTOR
  126. #endif
  127. #ifdef HAVE_SVE256_CPU_DEFINITION
  128. , void *SVE256
  129. #endif
  130. );
  131. void* get_call_ptr(
  132. c10::DeviceType device_type
  133. , void *DEFAULT
  134. #ifdef HAVE_AVX512_CPU_DEFINITION
  135. , void *AVX512
  136. #endif
  137. #ifdef HAVE_AVX2_CPU_DEFINITION
  138. , void *AVX2
  139. #endif
  140. #ifdef HAVE_VSX_CPU_DEFINITION
  141. , void *VSX
  142. #endif
  143. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  144. , void *ZVECTOR
  145. #endif
  146. #ifdef HAVE_SVE256_CPU_DEFINITION
  147. , void *SVE256
  148. #endif
  149. );
  150. /**
  151. * The CPU Dispatch actual method is chosen in decreasing order of preference by
  152. * DispatchStubImpl::choose_cpu_impl() in case none is found by
  153. * DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr.
  154. */
  155. void* choose_cpu_impl(
  156. void *DEFAULT
  157. #ifdef HAVE_AVX512_CPU_DEFINITION
  158. , void *AVX512
  159. #endif
  160. #ifdef HAVE_AVX2_CPU_DEFINITION
  161. , void *AVX2
  162. #endif
  163. #ifdef HAVE_VSX_CPU_DEFINITION
  164. , void *VSX
  165. #endif
  166. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  167. , void *ZVECTOR
  168. #endif
  169. #ifdef HAVE_SVE256_CPU_DEFINITION
  170. , void *SVE256
  171. #endif
  172. );
  173. // Fixing dispatch error in Windows debug builds.
  174. // See https://github.com/pytorch/pytorch/issues/22681 for more details.
  175. #if defined(_MSC_VER) && defined(_DEBUG)
  176. std::atomic<void*> cpu_dispatch_ptr;
  177. void* cuda_dispatch_ptr;
  178. void* hip_dispatch_ptr;
  179. void* mps_dispatch_ptr;
  180. void* mtia_dispatch_ptr;
  181. #if defined(USE_XPU)
  182. void* xpu_dispatch_ptr;
  183. #endif
  184. void* hpu_dispatch_ptr;
  185. void* privateuse1_dispatch_ptr;
  186. #else
  187. std::atomic<void*> cpu_dispatch_ptr{nullptr};
  188. void* cuda_dispatch_ptr = nullptr;
  189. void* hip_dispatch_ptr = nullptr;
  190. void* mps_dispatch_ptr = nullptr;
  191. void* mtia_dispatch_ptr = nullptr;
  192. #if defined(USE_XPU)
  193. void* xpu_dispatch_ptr = nullptr;
  194. #endif
  195. void* hpu_dispatch_ptr = nullptr;
  196. void* privateuse1_dispatch_ptr = nullptr;
  197. #endif
  198. };
  199. template <typename rT, typename T, typename... Args>
  200. struct DispatchStub<rT (*)(Args...), T> {
  201. using FnPtr = rT (*) (Args...);
  202. DispatchStub() = default;
  203. DispatchStub(const DispatchStub&) = delete;
  204. DispatchStub& operator=(const DispatchStub&) = delete;
  205. private:
  206. FnPtr get_call_ptr(const c10::DeviceType device_type) {
  207. return reinterpret_cast<FnPtr>(
  208. impl.get_call_ptr(device_type
  209. , reinterpret_cast<void*>(DEFAULT)
  210. #ifdef HAVE_AVX512_CPU_DEFINITION
  211. , reinterpret_cast<void*>(AVX512)
  212. #endif
  213. #ifdef HAVE_AVX2_CPU_DEFINITION
  214. , reinterpret_cast<void*>(AVX2)
  215. #endif
  216. #ifdef HAVE_VSX_CPU_DEFINITION
  217. , reinterpret_cast<void*>(VSX)
  218. #endif
  219. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  220. , reinterpret_cast<void*>(ZVECTOR)
  221. #endif
  222. #ifdef HAVE_SVE256_CPU_DEFINITION
  223. , reinterpret_cast<void*>(SVE256)
  224. #endif
  225. )
  226. );
  227. }
  228. public:
  229. template <typename... ArgTypes>
  230. rT operator()(c10::DeviceType device_type, ArgTypes&&... args) {
  231. FnPtr call_ptr = get_call_ptr(device_type);
  232. return (*call_ptr)(std::forward<ArgTypes>(args)...);
  233. }
  234. void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
  235. impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  236. }
  237. #if defined(USE_XPU)
  238. void set_xpu_dispatch_ptr(FnPtr fn_ptr){
  239. impl.xpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  240. }
  241. #endif
  242. void set_hpu_dispatch_ptr(FnPtr fn_ptr) {
  243. impl.hpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  244. }
  245. void set_hip_dispatch_ptr(FnPtr fn_ptr) {
  246. impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  247. }
  248. void set_mps_dispatch_ptr(FnPtr fn_ptr) {
  249. impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  250. }
  251. void set_mtia_dispatch_ptr(FnPtr fn_ptr) {
  252. impl.mtia_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  253. }
  254. void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) {
  255. impl.privateuse1_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  256. }
  257. // Returns true if the dispatcher has a kernel registered for this device
  258. // type.
  259. bool is_device_supported(const c10::DeviceType device_type) {
  260. auto result = impl.try_get_call_ptr(device_type
  261. , reinterpret_cast<void*>(DEFAULT)
  262. #ifdef HAVE_AVX512_CPU_DEFINITION
  263. , reinterpret_cast<void*>(AVX512)
  264. #endif
  265. #ifdef HAVE_AVX2_CPU_DEFINITION
  266. , reinterpret_cast<void*>(AVX2)
  267. #endif
  268. #ifdef HAVE_VSX_CPU_DEFINITION
  269. , reinterpret_cast<void*>(VSX)
  270. #endif
  271. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  272. , reinterpret_cast<void*>(ZVECTOR)
  273. #endif
  274. #ifdef HAVE_SVE256_CPU_DEFINITION
  275. , reinterpret_cast<void*>(SVE256)
  276. #endif
  277. );
  278. if (std::holds_alternative<ErrorType>(result)){
  279. return false;
  280. }
  281. return true;
  282. }
  283. static TORCH_API FnPtr DEFAULT;
  284. #ifdef HAVE_AVX512_CPU_DEFINITION
  285. static TORCH_API FnPtr AVX512;
  286. #endif
  287. #ifdef HAVE_AVX2_CPU_DEFINITION
  288. static TORCH_API FnPtr AVX2;
  289. #endif
  290. #ifdef HAVE_VSX_CPU_DEFINITION
  291. static TORCH_API FnPtr VSX;
  292. #endif
  293. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  294. static TORCH_API FnPtr ZVECTOR;
  295. #endif
  296. #ifdef HAVE_SVE256_CPU_DEFINITION
  297. static TORCH_API FnPtr SVE256;
  298. #endif
  299. private:
  300. DispatchStubImpl impl;
  301. };
  302. namespace {
  303. template <typename DispatchStub>
  304. struct RegisterCUDADispatch {
  305. RegisterCUDADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
  306. stub.set_cuda_dispatch_ptr(value);
  307. }
  308. };
  309. template <typename DispatchStub>
  310. struct RegisterXPUDispatch {
  311. RegisterXPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){
  312. stub.set_xpu_dispatch_ptr(value);
  313. }
  314. };
  315. template <typename DispatchStub>
  316. struct RegisterHPUDispatch {
  317. RegisterHPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){
  318. stub.set_hpu_dispatch_ptr(value);
  319. }
  320. };
  321. template <typename DispatchStub>
  322. struct RegisterMPSDispatch {
  323. RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
  324. stub.set_mps_dispatch_ptr(value);
  325. }
  326. };
  327. template <typename DispatchStub>
  328. struct RegisterHIPDispatch {
  329. RegisterHIPDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
  330. // TODO: make this point at hip_dispatch_ptr
  331. stub.set_cuda_dispatch_ptr(value);
  332. }
  333. };
  334. template <typename DispatchStub>
  335. struct RegisterMTIADispatch {
  336. RegisterMTIADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
  337. stub.set_mtia_dispatch_ptr(value);
  338. }
  339. };
  340. template <typename DispatchStub>
  341. struct RegisterPRIVATEUSE1Dispatch {
  342. RegisterPRIVATEUSE1Dispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
  343. stub.set_privateuse1_dispatch_ptr(value);
  344. }
  345. };
  346. } // anonymous namespace
  347. // Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
  348. // the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g.,
  349. // adding parentheses and using helper struct to get rid of the parentheses, do
  350. // not work with MSVC. So do a `using`-declaration if you need to pass in such
  351. // `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.
  352. #define DECLARE_DISPATCH(fn, name) \
  353. struct name##_DECLARE_DISPATCH_type : DispatchStub<fn, name##_DECLARE_DISPATCH_type> { \
  354. name##_DECLARE_DISPATCH_type() = default; \
  355. name##_DECLARE_DISPATCH_type(const name##_DECLARE_DISPATCH_type&) = delete; \
  356. name##_DECLARE_DISPATCH_type& operator=(const name##_DECLARE_DISPATCH_type&) = delete; \
  357. name##_DECLARE_DISPATCH_type(name##_DECLARE_DISPATCH_type&&) = delete; \
  358. name##_DECLARE_DISPATCH_type& operator=(name##_DECLARE_DISPATCH_type&&) = delete; \
  359. ~name##_DECLARE_DISPATCH_type() = default; \
  360. }; \
  361. extern TORCH_API struct name##_DECLARE_DISPATCH_type name;
  362. #define DEFINE_DISPATCH(name) struct name##_DECLARE_DISPATCH_type name
  363. #define REGISTER_ARCH_DISPATCH(name, arch, fn) \
  364. template <> name##_DECLARE_DISPATCH_type::FnPtr TORCH_API DispatchStub<name##_DECLARE_DISPATCH_type::FnPtr, struct name##_DECLARE_DISPATCH_type>::arch = fn;
  365. #ifdef HAVE_AVX512_CPU_DEFINITION
  366. #define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
  367. #else
  368. #define REGISTER_AVX512_DISPATCH(name, fn)
  369. #endif
  370. #ifdef HAVE_AVX2_CPU_DEFINITION
  371. #define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
  372. #else
  373. #define REGISTER_AVX2_DISPATCH(name, fn)
  374. #endif
  375. #ifdef HAVE_VSX_CPU_DEFINITION
  376. #define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
  377. #else
  378. #define REGISTER_VSX_DISPATCH(name, fn)
  379. #endif
  380. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  381. #define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
  382. #else
  383. #define REGISTER_ZVECTOR_DISPATCH(name, fn)
  384. #endif
  385. #ifdef HAVE_SVE256_CPU_DEFINITION
  386. #define REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE256, fn)
  387. #else
  388. #define REGISTER_SVE256_DISPATCH(name, fn)
  389. #endif
  390. // Macro to register the same kernel for all CPU arch types. This is useful
  391. // if a kernel does not benefit from being recompiled across different arch types.
  392. #define REGISTER_ALL_CPU_DISPATCH(name, fn) \
  393. REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \
  394. REGISTER_AVX512_DISPATCH(name, fn) \
  395. REGISTER_AVX2_DISPATCH(name, fn) \
  396. REGISTER_VSX_DISPATCH(name, fn) \
  397. REGISTER_ZVECTOR_DISPATCH(name, fn) \
  398. REGISTER_SVE256_DISPATCH(name, fn)
  399. #define REGISTER_NO_CPU_DISPATCH(name) \
  400. REGISTER_ALL_CPU_DISPATCH(name, nullptr)
  401. #define REGISTER_CUDA_DISPATCH(name, fn) \
  402. static RegisterCUDADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
  403. #define REGISTER_XPU_DISPATCH(name, fn) \
  404. static RegisterXPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
  405. #define REGISTER_HPU_DISPATCH(name, fn) \
  406. static RegisterHPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
  407. #define REGISTER_HIP_DISPATCH(name, fn) \
  408. static RegisterHIPDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
  409. #define REGISTER_MPS_DISPATCH(name, fn) \
  410. static RegisterMPSDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
  411. #define REGISTER_MTIA_DISPATCH(name, fn) \
  412. static RegisterMTIADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
  413. #define REGISTER_PRIVATEUSE1_DISPATCH(name, fn) \
  414. static RegisterPRIVATEUSE1Dispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
  415. // NB: This macro must be used in an actual 'cu' file; if you try using
  416. // it from a 'cpp' file it will not work!
  417. #if defined(__CUDACC__)
  418. #define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
  419. #elif defined(__HIPCC__)
  420. // TODO: cut this over to HIP dispatch once we stop pretending that CUDA
  421. // is HIP in the PyTorch HIPify build.
  422. #define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
  423. // #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
  424. #elif defined(__OBJC__) && defined(USE_MPS)
  425. // NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
  426. #define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
  427. #elif defined(CPU_CAPABILITY)
  428. // REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches.
  429. // ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others.
  430. // ALSO_REGISTER_SVE256_DISPATCH should be used for ensuring SVE256 dispatch, among others.
  431. #ifdef CPU_CAPABILITY_AVX512
  432. #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr))
  433. #else
  434. #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
  435. #endif
  436. #define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
  437. #define ALSO_REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
  438. #endif
  439. } // namespace at::native
  440. C10_CLANG_DIAGNOSTIC_POP()