cub.cuh 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644
  1. #pragma once
  2. #include <ATen/cuda/cub.h>
  3. #include <cstddef>
  4. #include <type_traits>
  5. #include <iterator>
  6. #include <limits>
  7. #ifndef USE_ROCM
  8. #include <cuda/std/functional>
  9. #endif
  10. #include <ATen/cuda/cub_definitions.cuh>
  11. #include <ATen/cuda/CUDAContextLight.h>
  12. #if USE_GLOBAL_CUB_WRAPPED_NAMESPACE()
  13. #include <cub/cub.cuh>
  14. #else
  15. // include cub in a safe manner, see:
  16. // https://github.com/pytorch/pytorch/pull/55292
  17. #undef CUB_NS_POSTFIX //undef to avoid redefinition warnings
  18. #undef CUB_NS_PREFIX
  19. #undef CUB_NS_QUALIFIER
  20. #define CUB_NS_PREFIX namespace at_cuda_detail {
  21. #define CUB_NS_POSTFIX }
  22. #define CUB_NS_QUALIFIER ::at_cuda_detail::cub
  23. #include <cub/cub.cuh>
  24. #undef CUB_NS_POSTFIX
  25. #undef CUB_NS_PREFIX
  26. #undef CUB_NS_QUALIFIER
  27. #endif
  28. #include <ATen/cuda/Exceptions.h>
  29. #include <c10/cuda/CUDACachingAllocator.h>
  30. #include <c10/cuda/CUDAStream.h>
  // CUB_WRAPPER(func, args...)
  //
  // Runs a CUB device-wide algorithm using CUB's two-call temporary-storage
  // protocol:
  //   1. func(nullptr, bytes, args...) only reports the temporary-storage size;
  //   2. that many bytes are allocated from the CUDA caching allocator;
  //   3. func(storage, bytes, args...) performs the actual work.
  // Both calls are wrapped in AT_CUDA_CHECK. The temporary buffer is owned by a
  // DataPtr local to the macro and is released when the macro's scope ends.
  //
  // `args...` is expanded twice, so its expressions are evaluated twice and
  // must be free of side effects. The macro's locals use distinctive names so
  // they cannot shadow a caller variable that is referenced from `args...`
  // (a caller variable called e.g. `temp_storage` would otherwise silently
  // bind to the macro's own local).
  #define CUB_WRAPPER(func, ...) do {                                          \
    size_t cub_wrapper_temp_bytes_ = 0;                                        \
    AT_CUDA_CHECK(func(nullptr, cub_wrapper_temp_bytes_, __VA_ARGS__));        \
    auto& cub_wrapper_allocator_ = *::c10::cuda::CUDACachingAllocator::get();  \
    auto cub_wrapper_temp_storage_ =                                           \
        cub_wrapper_allocator_.allocate(cub_wrapper_temp_bytes_);              \
    AT_CUDA_CHECK(func(cub_wrapper_temp_storage_.get(),                        \
                       cub_wrapper_temp_bytes_, __VA_ARGS__));                 \
  } while (false)
  // Helpers for spelling names that differ between CUDA and ROCm builds:
  //   NO_ROCM(x)     expands to `x` on CUDA and to nothing on ROCm, which drops
  //                  CUDA-only qualifiers such as `at_cuda_detail`;
  //   ROCM_HIPCUB(x) expands to `x` on CUDA and to `::hipcub` on ROCm.
  #if !defined(USE_ROCM)
  #define NO_ROCM(x) x
  #define ROCM_HIPCUB(x) x
  #else
  #define NO_ROCM(x)
  #define ROCM_HIPCUB(x) ::hipcub
  #endif
  // Version-portable spellings of the fancy iterators and the max functor used
  // with CUB in this file.
  //   - Without CUB_V3_PLUS(): CUB's own TransformInputIterator,
  //     CountingInputIterator, ConstantInputIterator and Max (from
  //     `at_cuda_detail::cub` on CUDA, `::hipcub` on ROCm).
  //   - With CUB_V3_PLUS(): Thrust's transform/counting/constant iterators and
  //     `::cuda::maximum<>`.
  // ATEN_CUB_TRANSFORM_ITERATOR always receives the value type as its first
  // argument. Only the pre-v3 CUB iterator consumes it; the Thrust spelling
  // drops it.
  #if !CUB_V3_PLUS()
  #define ATEN_CUB_TRANSFORM_ITERATOR(...) NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::TransformInputIterator<__VA_ARGS__>
  #define ATEN_CUB_COUNTING_ITERATOR(...) NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::CountingInputIterator<__VA_ARGS__>
  #define ATEN_CUB_CONSTANT_ITERATOR(...) NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::ConstantInputIterator<__VA_ARGS__>
  #define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max()
  #else
  #include <thrust/iterator/transform_iterator.h>
  #include <thrust/iterator/counting_iterator.h>
  #include <thrust/iterator/constant_iterator.h>
  #define ATEN_CUB_TRANSFORM_ITERATOR(ValueType, ...) ::thrust::transform_iterator<__VA_ARGS__>
  #define ATEN_CUB_COUNTING_ITERATOR(...) ::thrust::counting_iterator<__VA_ARGS__>
  #define ATEN_CUB_CONSTANT_ITERATOR(...) ::thrust::constant_iterator<__VA_ARGS__>
  #define ATEN_CUB_MAXIMUM() ::cuda::maximum<>()
  #endif
  #if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM)

  #if !defined(USE_ROCM)
  namespace at_cuda_detail {
  #endif

  // Backport of https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16:
  // specialize CUB's (hipCUB's on ROCm) FpLimits and NumericTraits so that
  // c10::BFloat16 is described as an unsigned-short-backed floating-point type.

  // Largest finite bfloat16 (bit pattern 0x7F7F) and its negation (0xFF7F).
  // The values are built with BFloat16's from_bits constructor. The previous
  // approach, reinterpret_cast'ing an `unsigned short` lvalue to BFloat16&,
  // violates strict aliasing.
  template <>
  struct ROCM_HIPCUB(cub)::FpLimits<c10::BFloat16>
  {
    static __host__ __device__ __forceinline__ c10::BFloat16 Max() {
      return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
    }

    static __host__ __device__ __forceinline__ c10::BFloat16 Lowest() {
      return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits());
    }
  };

  template <>
  struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
      ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};

  #if !defined(USE_ROCM)
  } // namespace at_cuda_detail
  #endif

  #endif
  // CUDA builds only: make the wrapped CUB namespace reachable from
  // at::native code as `at::native::cub`.
  #ifndef USE_ROCM
  namespace at {
  namespace native {
  namespace cub = ::at_cuda_detail::cub;
  } // namespace native
  } // namespace at
  #endif
  89. namespace at::cuda::cub {
  namespace detail {

  // Maps an ATen element type to the type handed to CUB. By default the type
  // is passed through unchanged; c10's reduced-precision types are mapped to
  // the native CUDA/HIP types by the specializations below.
  template <typename T>
  struct cuda_type {
    typedef T type;
  };

  // c10::Half -> __half
  template <>
  struct cuda_type<c10::Half> {
    typedef __half type;
  };

  #if defined(USE_ROCM)
  // c10::BFloat16 -> hip_bfloat16 on ROCm.
  template <>
  struct cuda_type<c10::BFloat16> {
    typedef hip_bfloat16 type;
  };
  #elif CUB_SUPPORTS_NV_BFLOAT16()
  // c10::BFloat16 -> __nv_bfloat16 when CUB supports it natively. Otherwise
  // c10::BFloat16 falls through to the primary template above.
  template <>
  struct cuda_type<c10::BFloat16> {
    typedef __nv_bfloat16 type;
  };
  #endif

  } // namespace detail
  // Sorts (key, value) pairs segment by segment with CUB's
  // DeviceSegmentedRadixSort on the current CUDA stream.
  //
  //   keys_in, values_in   : device inputs with `num_elements` entries each.
  //   keys_out, values_out : device outputs. `keys_out` may be nullptr; a
  //                          scratch buffer of `num_elements` keys is then
  //                          allocated for CUB and freed when this returns.
  //   num_segments, begin_offsets, end_offsets :
  //                          segment description, passed to CUB unchanged.
  //   descending           : use SortPairsDescending instead of SortPairs.
  //   begin_bit, end_bit   : key bit range for the radix sort (defaults to all
  //                          bits of key_t).
  //
  // Keys are reinterpreted as detail::cuda_type<key_t>::type before reaching
  // CUB, e.g. c10::Half keys are sorted as __half.
  // Throws (TORCH_CHECK) if num_elements or num_segments exceeds INT_MAX.
  template<typename key_t, typename value_t, typename OffsetIteratorT>
  inline void segmented_sort_pairs(
      const key_t *keys_in, key_t *keys_out,
      const value_t *values_in, value_t *values_out,
      int64_t num_elements, int64_t num_segments,
      OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
      bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
  ) {
    TORCH_CHECK(num_elements <= std::numeric_limits<int>::max(),
      "cub sort does not support sorting more than INT_MAX elements");
    // Report the segment limit as such, so users are pointed at the right
    // argument.
    TORCH_CHECK(num_segments <= std::numeric_limits<int>::max(),
      "cub sort does not support sorting more than INT_MAX segments");
    using key_t_ = typename detail::cuda_type<key_t>::type;

    auto allocator = c10::cuda::CUDACachingAllocator::get();
    c10::DataPtr keys_out_owner;
    if (keys_out == nullptr) {
      // The caller does not want the sorted keys, but CUB must write them
      // somewhere.
      keys_out_owner = allocator->allocate(num_elements * sizeof(key_t));
      keys_out = reinterpret_cast<key_t *>(keys_out_owner.get());
    }

    const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
    key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);

    if (descending) {
      CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairsDescending,
        keys_in_, keys_out_, values_in, values_out,
        num_elements, num_segments, begin_offsets, end_offsets,
        begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
    } else {
      CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairs,
        keys_in_, keys_out_, values_in, values_out,
        num_elements, num_segments, begin_offsets, end_offsets,
        begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
    }
  }
  #if CUB_SUPPORTS_UNIQUE_BY_KEY()
  // Wrapper over CUB's DeviceSelect::UniqueByKey that exposes only the value
  // output and the selected count, on the current CUDA stream.
  // CUB still needs a destination for the unique keys, so a scratch buffer of
  // `num_input_items` keys is allocated and dropped when the function returns.
  template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT>
  inline void unique_by_key(
      KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
      ValuesOutputIteratorT values_out,
      NumSelectedIteratorT num_selected, int64_t num_input_items)
  {
    // TODO: use thrust::discard_iterator to handle null keys_out when https://github.com/NVIDIA/cub/issues/406 is fixed.
    using key_type = typename std::iterator_traits<KeysInputIteratorT>::value_type;
    c10::DataPtr scratch_keys =
        c10::cuda::CUDACachingAllocator::get()->allocate(num_input_items * sizeof(key_type));
    auto* scratch_keys_ptr = static_cast<key_type *>(scratch_keys.get());
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
        keys_in, values_in, scratch_keys_ptr, values_out, num_selected,
        num_input_items, c10::cuda::getCurrentCUDAStream());
  }
  #endif
  namespace impl {

  // Single-thread kernel used to stitch consecutive scan chunks together: it
  // converts *a and *b to the value type of `out` (the accumulation type) and
  // stores scan_op(*a, *b) to *out.
  template<typename InputIteratorT1, typename InputIteratorT2, typename OutputIteratorT, class ScanOpT>
  C10_LAUNCH_BOUNDS_1(1)
  __global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputIteratorT out, ScanOpT scan_op){
    // NOTE: out here not the final scan output, but an intermediate of the accumulation type.
    using acc_t = typename std::iterator_traits<OutputIteratorT>::value_type;
    *out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
  }

  #if !CUB_SUPPORTS_FUTURE_VALUE()
  // Random-access view of the sequence (*first, iter[0], iter[1], ...):
  // position 0 is read through the device pointer `first`, and position k > 0
  // is iter[k - 1] converted to ValueT. exclusive_scan uses it to prepend a
  // device-resident carry value to a chunk when CUB lacks FutureValue.
  template<typename ValueT, typename InputIteratorT>
  struct chained_iterator {
    using iterator_category = std::random_access_iterator_tag;
    using difference_type = std::ptrdiff_t;
    using value_type = ValueT;
    using pointer = ValueT*;
    using reference = ValueT&;

    InputIteratorT iter;
    ValueT *first;
    // Current position of this iterator within the chained sequence.
    difference_type offset = 0;

    __device__ ValueT operator[](difference_type i) {
      i += offset;
      if (i == 0) {
        return *first;
      }
      return ValueT(iter[i - 1]);
    }

    // Advance relative to the current position. The existing offset is kept,
    // so (it + a) + b == it + (a + b); it used to be overwritten by `i`.
    __device__ chained_iterator operator+(difference_type i) {
      return chained_iterator{iter, first, offset + i};
    }

    __device__ ValueT operator*() {
      return (*this)[0];
    }
  };
  #endif

  // even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
  // so split at int_max/2
  constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
  } // namespace impl
  // Inclusive scan of `num_items` elements from `input` into `output` using
  // `scan_op`, enqueued on the current CUDA stream without host synchronization.
  //
  // On ROCm the whole range goes to hipCUB in a single call. On CUDA, CUB is
  // given at most `max_cub_size` items per call: even though cub is supposed
  // to support int_max elements, in reality it doesn't, so the input is split
  // at int_max/2. Each later chunk starting at `start` is seeded with
  // scan_op(output[start - 1], input[start]). That value is computed on device
  // by a one-thread kernel, so no host round trip is needed.
  template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, int max_cub_size=impl::max_cub_size>
  inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
  #if defined(USE_ROCM)
    // For ROCm, use hipCUB chained iterators
    CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::InclusiveScan,
        input, output, scan_op, num_items,
        at::cuda::getCurrentCUDAStream());
    C10_HIP_KERNEL_LAUNCH_CHECK();
  #else
    using input_t = typename std::iterator_traits<InputIteratorT>::value_type;
    const auto stream = at::cuda::getCurrentCUDAStream();

    // Leading chunk: a plain inclusive scan.
    int chunk_size = std::min<int64_t>(num_items, max_cub_size);
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
        input, output, scan_op, chunk_size, stream);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    auto* allocator = c10::cuda::CUDACachingAllocator::get();
    for (int64_t start = max_cub_size; start < num_items; start += max_cub_size) {
      chunk_size = std::min<int64_t>(num_items - start, max_cub_size);

      // carry = scan_op(output[start - 1], input[start]): the inclusive result
      // for position `start`, kept in a one-element device buffer.
      c10::DataPtr carry = allocator->allocate(sizeof(input_t));
      auto carry_ptr = reinterpret_cast<input_t *>(carry.get());
      impl::transform_vals<<<1, 1, 0, stream>>>(
          output + start - 1, input + start, carry_ptr, scan_op);
      C10_CUDA_KERNEL_LAUNCH_CHECK();

  #if !CUB_SUPPORTS_FUTURE_VALUE()
      // No FutureValue: scan an iterator that substitutes the carry for the
      // chunk's first element.
      using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
      using tuple = typename ArgIndexInputIterator::value_type;
      auto substitute_first = [=] __device__ (const tuple &x)->input_t {
        if (x.key != 0) {
          return x.value;
        }
        return *carry_ptr;
      };
      auto chunk_input = ATEN_CUB_TRANSFORM_ITERATOR(input_t, decltype(substitute_first), ArgIndexInputIterator)(
          ArgIndexInputIterator(input + start), substitute_first);
      CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
          chunk_input, output + start, scan_op, chunk_size, stream);
  #else
      // output[start + k] = carry op input[start + 1] op ... op input[start + k],
      // i.e. an exclusive scan of input[start + 1 ...] seeded with the carry.
      // NOTE(review): CUB appears to load all `chunk_size` inputs of this call,
      // i.e. up to input[start + chunk_size], which for the final chunk is
      // input[num_items] (one past the end). Confirm whether that read happens
      // and whether it is benign.
      CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
          input + start + 1, output + start, scan_op,
          ::at_cuda_detail::cub::FutureValue<input_t>(carry_ptr),
          chunk_size, stream);
  #endif
    }
  #endif
  }
  268. # if defined(CUDA_VERSION) || defined(USE_ROCM)
  // Stateful prefix callback for a block-wide scan that spans several tiles.
  // It holds the running total carried from tile to tile. Each call returns
  // the total accumulated so far (the prefix for the current tile), then adds
  // the current tile's aggregate to it.
  template<typename T>
  struct BlockPrefixCallbackOp
  {
    T running_total;

    __host__ __device__ BlockPrefixCallbackOp(T initial_total) : running_total(initial_total) {}

    // Entered by the first warp of the block; the value returned by thread 0
    // seeds the block-wide scan.
    __host__ __device__ T operator()(T block_aggregate)
    {
      const T prefix = running_total;
      running_total += block_aggregate;
      return prefix;
    }
  };
  // Second pass of inclusive_deterministic_scan.
  //
  // Block b owns `iters_per_cta` consecutive tiles of
  // BLOCK_THREADS * ITEMS_PER_THREAD elements, beginning at element
  // b * iters_per_cta * (BLOCK_THREADS * ITEMS_PER_THREAD). The block first
  // sums agg[0 .. b-1] (the per-block totals of all earlier blocks) to get its
  // starting prefix. It then runs a block-wide inclusive sum over each tile,
  // carrying the running total across tiles, and stores the results to d_out.
  //
  // Launch: 1-D grid with blockDim.x == BLOCK_THREADS. The prefix reduction
  // reads at most two agg entries per thread, so gridDim.x <= 2 * BLOCK_THREADS.
  template<int BLOCK_THREADS, int ITEMS_PER_THREAD, typename T>
  __global__ void final_scan_kernel(const T* d_in, T* d_out, T* agg, int64_t nelem, int iters_per_cta) {
    constexpr int TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD;
    // 64-bit from the first multiplication on. Previously
    // TILE_ITEMS * iters_per_cta was evaluated in int before widening and
    // could overflow.
    const int64_t offset = static_cast<int64_t>(TILE_ITEMS) * iters_per_cta * blockIdx.x;
    int64_t remaining = nelem - offset;
    // This block starts past the end of the input: nothing to do.
    if (remaining <= 0) {
      return;
    }
    d_in += offset;
    d_out += offset;

    using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_WARP_TRANSPOSE>;
    // Specialize BlockStore type for our thread block (uses warp-striped loads for coalescing, then transposes in shared
    // memory to a blocked arrangement)
    using BlockStoreT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_STORE_WARP_TRANSPOSE>;
    // Specialize BlockScan type for our thread block
    using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan<T, BLOCK_THREADS, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_SCAN_WARP_SCANS>;
    using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce<T, BLOCK_THREADS>;

    // Shared memory, reused by the reduce/load/scan/store phases, which are
    // separated by __syncthreads().
    __shared__ union TempStorage
    {
      typename BlockLoadT::TempStorage load;
      typename BlockStoreT::TempStorage store;
      typename BlockScanT::TempStorage scan;
      typename BlockReduceT::TempStorage reduce;
    } temp_storage;

    // Starting prefix = sum of agg[0 .. blockIdx.x - 1]. Thread t contributes
    // agg[t] and, when there are more predecessors than threads, also
    // agg[t + blockDim.x].
    T agg_data = threadIdx.x >= blockIdx.x ? T(0) : agg[threadIdx.x];
    if (threadIdx.x + blockDim.x < blockIdx.x) {
      agg_data += agg[threadIdx.x + blockDim.x];
    }
    T aggregate = BlockReduceT(temp_storage.reduce).Sum(agg_data);
    __syncthreads();
    BlockPrefixCallbackOp<T> prefix_op(aggregate);

    // Per-thread tile data
    T data[ITEMS_PER_THREAD];
    for (int i = 0; i < iters_per_cta; i++) {
      // Load items into a blocked arrangement. For a partial final tile the
      // slots are zeroed first, so entries past the end hold a defined value.
      if (remaining >= TILE_ITEMS) {
        BlockLoadT(temp_storage.load).Load(d_in, data);
      } else {
        #pragma unroll
        for (int j = 0; j < ITEMS_PER_THREAD; j++) {
          data[j] = 0;
        }
        BlockLoadT(temp_storage.load).Load(d_in, data, remaining);
      }

      // Barrier for smem reuse
      __syncthreads();

      // Compute inclusive prefix sum
      BlockScanT(temp_storage.scan).InclusiveSum(data, data, prefix_op);

      // Barrier for smem reuse
      __syncthreads();

      // Store items from a blocked arrangement
      if (remaining >= TILE_ITEMS) {
        BlockStoreT(temp_storage.store).Store(d_out, data);
      } else {
        BlockStoreT(temp_storage.store).Store(d_out, data, remaining);
      }
      d_in += TILE_ITEMS;
      d_out += TILE_ITEMS;
      remaining -= TILE_ITEMS;
      if (remaining <= 0) return;
      __syncthreads();
    }
  }
  // Per-element transform applied while computing block sums:
  //   nonzero == true  -> 1 if the value compares unequal to T(0), else 0;
  //   nonzero == false -> the value itself, converted to aggT.
  template <typename T, typename aggT, bool nonzero>
  struct TransformFunctor {
    __device__ aggT operator()(T value) const {
      if constexpr (nonzero) {
        return (value != T(0)) ? 1 : 0;
      } else {
        return value;
      }
    }
  };
  // First pass of inclusive_deterministic_scan. With nonzero == true it also
  // yields per-block nonzero counts.
  //
  // Block b reduces its `iters_per_cta` consecutive tiles of
  // BLOCK_THREADS * ITEMS_PER_THREAD input elements, each mapped through
  // TransformFunctor<T, aggT, nonzero>, and thread 0 writes the total to
  // agg[b]. With nonzero == false, the block whose range reaches the end of
  // the input skips that write. Blocks that start past the end write nothing.
  //
  // Launch: 1-D grid with blockDim.x == BLOCK_THREADS.
  template<int BLOCK_THREADS, int ITEMS_PER_THREAD, bool nonzero, typename T, typename aggT>
  __global__ void calc_block_sums(const T * d_in, aggT * agg, int64_t nelem, int iters_per_cta){
    constexpr int TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD;
    // 64-bit from the first multiplication on. Previously
    // TILE_ITEMS * iters_per_cta was evaluated in int before widening and
    // could overflow.
    const int64_t offset = static_cast<int64_t>(TILE_ITEMS) * iters_per_cta * blockIdx.x;
    int64_t remaining = nelem - offset;
    // NOTE: a block starting past the end of the input returns without writing
    // agg[blockIdx.x]; callers must not read such entries.
    if (remaining <= 0) {
      return;
    }
    d_in += offset;

    using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<aggT, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_STRIPED>;
    using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce<aggT, BLOCK_THREADS>;
    // Shared memory, reused by the load and reduce phases.
    __shared__ union TempStorage
    {
      typename BlockLoadT::TempStorage load;
      typename BlockReduceT::TempStorage reduce;
    } temp_storage;

    aggT data[ITEMS_PER_THREAD];
    aggT agg_val = 0;
    TransformFunctor<T, aggT, nonzero> transform_functor;
    auto iter_in = ATEN_CUB_TRANSFORM_ITERATOR(aggT, TransformFunctor<T, aggT, nonzero>, const T*)(d_in, transform_functor);
    for (int i = 0; i < iters_per_cta; i++) {
      if (remaining >= TILE_ITEMS) {
        BlockLoadT(temp_storage.load).Load(iter_in, data);
      } else {
        // Partial final tile: slots past the end are filled with aggT(0).
        BlockLoadT(temp_storage.load).Load(iter_in, data, remaining, aggT(0));
      }
      // Barrier before the reduction reuses the load's shared memory.
      __syncthreads();
      agg_val += BlockReduceT(temp_storage.reduce).Sum(data);

      iter_in += TILE_ITEMS;
      remaining -= TILE_ITEMS;
      if (remaining <= 0) {
        // for nonzeros we need to write out last blocks
        // accumulated value to be able to compute
        // total number of nonzeros
        if (nonzero && threadIdx.x == 0) {
          agg[blockIdx.x] = agg_val;
        }
        return;
      }
      __syncthreads();
    }
    if (threadIdx.x == 0) {
      agg[blockIdx.x] = agg_val;
    }
  }
  // Functor mapping a value to 1 if it compares unequal to T(0), and to 0
  // otherwise.
  template <typename T>
  struct NonZeroOp {
    __host__ __device__ __forceinline__ int operator()(const T& a) const {
      const T zero(0);
      return static_cast<int>(a != zero);
    }
  };
  // Threads per block for the deterministic-scan kernels, chosen from the
  // element size in bytes: 128 for >= 16 bytes, 256 for >= 8, otherwise 512.
  template<int size>
  constexpr int block_threads(){
    return size >= 16 ? 128 : (size >= 8 ? 256 : 512);
  }
  // Two-pass inclusive prefix sum of `num_items` elements from `input` into
  // `output` on the current CUDA stream:
  //   1. calc_block_sums: each block sums a fixed contiguous range of tiles
  //      into agg[block];
  //   2. final_scan_kernel: each block scans its range, seeded with the sum of
  //      the agg entries of all earlier blocks.
  // The work partition depends only on num_items, sizeof(scalar_t) and the SM
  // count (presumably so results are reproducible run to run -- confirm).
  // Only std::plus<scalar_t> is supported (enforced by static_assert);
  // `scan_op` is used only for its type.
  template<typename scalar_t, typename ScanOpT>
  inline void inclusive_deterministic_scan(const scalar_t * input, scalar_t * output, ScanOpT scan_op, int64_t num_items) {
    static_assert(std::is_same_v<ScanOpT, std::plus<scalar_t>>, "");
    // Nothing to scan. This also avoids launching kernels with an empty grid,
    // which CUDA rejects as an invalid configuration.
    if (num_items <= 0) {
      return;
    }
    constexpr int BLOCK_THREADS = block_threads<sizeof(scalar_t)>();
    constexpr int ITEMS_PER_THREAD = 16;
    constexpr int64_t TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD;
    const int64_t num_tiles = (num_items + TILE_ITEMS - 1) / TILE_ITEMS;
    const int64_t num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
    // Spread the tiles over at most one block per SM.
    const int iters_per_cta = (num_tiles + num_sms - 1) / num_sms;
    // Launch only blocks that own at least one tile. This is never more than
    // min(num_sms, num_tiles), and it avoids idle blocks that would leave
    // their agg entries unwritten.
    const int64_t grid_size = (num_tiles + iters_per_cta - 1) / iters_per_cta;
    // simple reduction in scan kernel handles at most 2 items per thread
    TORCH_INTERNAL_ASSERT(2 * BLOCK_THREADS >= grid_size);
    auto& allocator = *c10::cuda::CUDACachingAllocator::get();
    auto agg = allocator.allocate(grid_size * sizeof(scalar_t));
    calc_block_sums<BLOCK_THREADS, ITEMS_PER_THREAD, false>
      <<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
        input, (scalar_t*)agg.get(), num_items, iters_per_cta);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    final_scan_kernel<BLOCK_THREADS, ITEMS_PER_THREAD>
      <<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
        input, output, (scalar_t*)agg.get(), num_items, iters_per_cta);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
  446. #endif
  // Exclusive scan of `num_items` elements from `input` into `output` using
  // `scan_op`, seeded with `init_value`, and enqueued on the current CUDA
  // stream without host synchronization.
  //
  // On ROCm the whole range goes to hipCUB in one call. On CUDA, CUB is given
  // at most `max_cub_size` items per call: even though cub is supposed to
  // support int_max elements, in reality it doesn't, so the input is split at
  // int_max/2. Each later chunk starting at `start` is seeded with
  // scan_op(output[start - 1], input[start - 1]). That value is computed on
  // device by a one-thread kernel.
  template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename InitValueT, int max_cub_size=impl::max_cub_size>
  inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, InitValueT init_value, int64_t num_items) {
  #if defined(USE_ROCM)
    // For ROCm, use hipCUB chained iterators
    CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::ExclusiveScan,
        input, output, scan_op, init_value, num_items,
        at::cuda::getCurrentCUDAStream());
    C10_HIP_KERNEL_LAUNCH_CHECK();
  #else
    const auto stream = at::cuda::getCurrentCUDAStream();

    // Leading chunk: a plain exclusive scan seeded with init_value.
    int chunk_size = std::min<int64_t>(num_items, max_cub_size);
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
        input, output, scan_op, init_value, chunk_size, stream);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    auto* allocator = c10::cuda::CUDACachingAllocator::get();
    for (int64_t start = max_cub_size; start < num_items; start += max_cub_size) {
      chunk_size = std::min<int64_t>(num_items - start, max_cub_size);

      // seed = scan_op(output[start - 1], input[start - 1]): the exclusive
      // result for position `start`, kept in a one-element device buffer.
      c10::DataPtr seed = allocator->allocate(sizeof(InitValueT));
      auto seed_ptr = reinterpret_cast<InitValueT *>(seed.get());
      impl::transform_vals<<<1, 1, 0, stream>>>(
          output + start - 1, input + start - 1, seed_ptr, scan_op);
      C10_CUDA_KERNEL_LAUNCH_CHECK();

  #if !CUB_SUPPORTS_FUTURE_VALUE()
      // No FutureValue: an inclusive scan over (seed, input[start], ...) gives
      // output[start + k] = seed op input[start] op ... op input[start + k - 1].
      auto chunk_input = impl::chained_iterator<InitValueT, InputIteratorT>{
          input + start, seed_ptr};
      CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
          chunk_input, output + start, scan_op, chunk_size, stream);
  #else
      CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
          input + start, output + start, scan_op,
          ::at_cuda_detail::cub::FutureValue<InitValueT>(seed_ptr),
          chunk_size, stream);
  #endif
    }
  #endif
  }
  #if CUB_SUPPORTS_SCAN_BY_KEY()
  // Segmented inclusive sum via CUB's DeviceScan::InclusiveSumByKey on the
  // current CUDA stream. Keys are compared with an equality functor
  // (::cuda::std::equal_to<> on CUDA, hipcub::Equality on ROCm).
  // Throws (TORCH_CHECK) if num_items exceeds INT_MAX.
  template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
  inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
    TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
      "cub InclusiveSumByKey does not support more than INT_MAX elements");
  #if !defined(USE_ROCM)
    CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveSumByKey,
        keys, input, output, num_items, NO_ROCM(::cuda)::std::equal_to<>(), at::cuda::getCurrentCUDAStream());
  #else
    CUB_WRAPPER(cub::DeviceScan::InclusiveSumByKey,
        keys, input, output, num_items, hipcub::Equality(), at::cuda::getCurrentCUDAStream());
  #endif
  }

  // Segmented inclusive scan with `scan_op` via CUB's
  // DeviceScan::InclusiveScanByKey on the current CUDA stream, using the same
  // key-equality functors as inclusive_sum_by_key.
  // Throws (TORCH_CHECK) if num_items exceeds INT_MAX; the message names the
  // CUB API actually used here (it previously said InclusiveSumByKey).
  template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename ScanOpT>
  inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
    TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
      "cub InclusiveScanByKey does not support more than INT_MAX elements");
  #if !defined(USE_ROCM)
    CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveScanByKey,
        keys, input, output, scan_op, num_items, NO_ROCM(::cuda)::std::equal_to<>(), at::cuda::getCurrentCUDAStream());
  #else
    CUB_WRAPPER(cub::DeviceScan::InclusiveScanByKey,
        keys, input, output, scan_op, num_items, hipcub::Equality(), at::cuda::getCurrentCUDAStream());
  #endif
  }
  #endif
  // Wrapper over CUB's DeviceSelect::Unique on the current CUDA stream: the
  // selected items go to `output` and their count to `num_selected_out`.
  // Throws (TORCH_CHECK) if num_items exceeds INT_MAX.
  template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
  void unique(InputIteratorT input, OutputIteratorT output,
              NumSelectedIteratorT num_selected_out, int64_t num_items) {
    constexpr int64_t kMaxItems = std::numeric_limits<int>::max();
    TORCH_CHECK(num_items <= kMaxItems,
                "cub unique does not support more than INT_MAX elements");
    const auto stream = at::cuda::getCurrentCUDAStream();
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique,
                input, output, num_selected_out, num_items, stream);
  }
  // Wrapper over CUB's DeviceRunLengthEncode::Encode on the current CUDA
  // stream. Outputs are passed in CUB's order: run values (`output`), run
  // lengths (`counts_out`), number of runs (`length_out`).
  // Throws (TORCH_CHECK) if num_items exceeds INT_MAX.
  template <typename InputIteratorT, typename OutputIteratorT, typename CountsOutputIteratorT,
            typename LengthOutputIteratorT>
  void run_length_encode(InputIteratorT input, OutputIteratorT output, CountsOutputIteratorT counts_out,
                         LengthOutputIteratorT length_out, int64_t num_items) {
    constexpr int64_t kMaxItems = std::numeric_limits<int>::max();
    TORCH_CHECK(num_items <= kMaxItems,
                "cub run_length_encode does not support more than INT_MAX elements");
    const auto stream = at::cuda::getCurrentCUDAStream();
    CUB_WRAPPER(
        NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode,
        input, output, counts_out, length_out, num_items, stream);
  }
  // Reduces `num_items` elements of `input` with `op`, starting from `init`,
  // into `output` via CUB's DeviceReduce::Reduce on the current CUDA stream.
  // Throws (TORCH_CHECK) if num_items exceeds INT_MAX.
  template <typename InputIteratorT, typename OutputIteratorT, typename ReductionOpT, typename T>
  void reduce(InputIteratorT input, OutputIteratorT output, int64_t num_items, ReductionOpT op, T init) {
    constexpr int64_t kMaxItems = std::numeric_limits<int>::max();
    TORCH_CHECK(num_items <= kMaxItems,
                "cub reduce does not support more than INT_MAX elements");
    const auto stream = at::cuda::getCurrentCUDAStream();
    CUB_WRAPPER(
        NO_ROCM(at_cuda_detail)::cub::DeviceReduce::Reduce,
        input, output, num_items, op, init, stream);
  }
  558. } // namespace at::cuda::cub