DeviceUtils.cuh 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. #pragma once
  2. #include <cuda.h>
  3. #include <c10/util/complex.h>
  4. #include <c10/util/Half.h>
  5. __device__ __forceinline__ unsigned int ACTIVE_MASK()
  6. {
  7. #if !defined(USE_ROCM)
  8. return __activemask();
  9. #else
  10. // will be ignored anyway
  11. return 0xffffffff;
  12. #endif
  13. }
  14. __device__ __forceinline__ void WARP_SYNC(unsigned mask = 0xffffffff) {
  15. #if !defined(USE_ROCM)
  16. return __syncwarp(mask);
  17. #endif
  18. }
  19. #if defined(USE_ROCM)
  20. __device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate)
  21. {
  22. return __ballot(predicate);
  23. }
  24. #else
  25. __device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int mask = 0xffffffff)
  26. {
  27. #if !defined(USE_ROCM)
  28. return __ballot_sync(mask, predicate);
  29. #else
  30. return __ballot(predicate);
  31. #endif
  32. }
  33. #endif
  34. template <typename T>
  35. __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
  36. {
  37. #if !defined(USE_ROCM)
  38. return __shfl_xor_sync(mask, value, laneMask, width);
  39. #else
  40. return __shfl_xor(value, laneMask, width);
  41. #endif
  42. }
  43. template <typename T>
  44. __device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = warpSize, unsigned int mask = 0xffffffff)
  45. {
  46. #if !defined(USE_ROCM)
  47. return __shfl_sync(mask, value, srcLane, width);
  48. #else
  49. return __shfl(value, srcLane, width);
  50. #endif
  51. }
  52. template <typename T>
  53. __device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
  54. {
  55. #if !defined(USE_ROCM)
  56. return __shfl_up_sync(mask, value, delta, width);
  57. #else
  58. return __shfl_up(value, delta, width);
  59. #endif
  60. }
  61. template <typename T>
  62. __device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
  63. {
  64. #if !defined(USE_ROCM)
  65. return __shfl_down_sync(mask, value, delta, width);
  66. #else
  67. return __shfl_down(value, delta, width);
  68. #endif
  69. }
  70. #if defined(USE_ROCM)
  71. template<>
  72. __device__ __forceinline__ int64_t WARP_SHFL_DOWN<int64_t>(int64_t value, unsigned int delta, int width , unsigned int mask)
  73. {
  74. //(HIP doesn't support int64_t). Trick from https://devblogs.nvidia.com/faster-parallel-reductions-kepler/
  75. int2 a = *reinterpret_cast<int2*>(&value);
  76. a.x = __shfl_down(a.x, delta);
  77. a.y = __shfl_down(a.y, delta);
  78. return *reinterpret_cast<int64_t*>(&a);
  79. }
  80. #endif
  81. template<>
  82. __device__ __forceinline__ c10::Half WARP_SHFL_DOWN<c10::Half>(c10::Half value, unsigned int delta, int width, unsigned int mask)
  83. {
  84. return c10::Half(WARP_SHFL_DOWN<unsigned short>(value.x, delta, width, mask), c10::Half::from_bits_t{});
  85. }
  86. template <typename T>
  87. __device__ __forceinline__ c10::complex<T> WARP_SHFL_DOWN(c10::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
  88. {
  89. #if !defined(USE_ROCM)
  90. return c10::complex<T>(
  91. __shfl_down_sync(mask, value.real_, delta, width),
  92. __shfl_down_sync(mask, value.imag_, delta, width));
  93. #else
  94. return c10::complex<T>(
  95. __shfl_down(value.real_, delta, width),
  96. __shfl_down(value.imag_, delta, width));
  97. #endif
  98. }
  99. /**
  100. * For CC 3.5+, perform a load using __ldg
  101. */
  102. template <typename T>
  103. __device__ __forceinline__ T doLdg(const T* p) {
  104. #if __CUDA_ARCH__ >= 350 && !defined(USE_ROCM)
  105. return __ldg(p);
  106. #else
  107. return *p;
  108. #endif
  109. }