ScanUtils.cuh 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. #pragma once
  2. #include <ATen/ceil_div.h>
  3. #include <ATen/cuda/DeviceUtils.cuh>
  4. #include <ATen/cuda/AsmUtils.cuh>
  5. #include <c10/macros/Macros.h>
  6. // Collection of in-kernel scan / prefix sum utilities
  7. namespace at::cuda {
  8. // Inclusive prefix sum for binary vars using intra-warp voting +
  9. // shared memory
  10. template <typename T, bool KillWARDependency, class BinaryFunction>
  11. __device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
  12. // Within-warp, we use warp voting.
  13. #if defined (USE_ROCM)
  14. unsigned long long int vote = WARP_BALLOT(in);
  15. T index = __popcll(getLaneMaskLe() & vote);
  16. T carry = __popcll(vote);
  17. #else
  18. T vote = WARP_BALLOT(in);
  19. T index = __popc(getLaneMaskLe() & vote);
  20. T carry = __popc(vote);
  21. #endif
  22. int warp = threadIdx.x / C10_WARP_SIZE;
  23. // Per each warp, write out a value
  24. if (getLaneId() == 0) {
  25. smem[warp] = carry;
  26. }
  27. __syncthreads();
  28. // Sum across warps in one thread. This appears to be faster than a
  29. // warp shuffle scan for CC 3.0+
  30. if (threadIdx.x == 0) {
  31. int current = 0;
  32. for (int i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) {
  33. T v = smem[i];
  34. smem[i] = binop(smem[i], current);
  35. current = binop(current, v);
  36. }
  37. }
  38. __syncthreads();
  39. // load the carry from the preceding warp
  40. if (warp >= 1) {
  41. index = binop(index, smem[warp - 1]);
  42. }
  43. *out = index;
  44. if (KillWARDependency) {
  45. __syncthreads();
  46. }
  47. }
  48. // Exclusive prefix sum for binary vars using intra-warp voting +
  49. // shared memory
  50. template <typename T, bool KillWARDependency, class BinaryFunction>
  51. __device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) {
  52. inclusiveBinaryPrefixScan<T, false, BinaryFunction>(smem, in, out, binop);
  53. // Inclusive to exclusive
  54. *out -= (T) in;
  55. // The outgoing carry for all threads is the last warp's sum
  56. *carry = smem[at::ceil_div<int>(blockDim.x, C10_WARP_SIZE) - 1];
  57. if (KillWARDependency) {
  58. __syncthreads();
  59. }
  60. }
  61. } // namespace at::cuda