cuda_launch.cu 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  #include <vector>

  #include <torch/extension.h>
  #include <ATen/ATen.h>
  #include <ATen/cuda/CUDAContext.h>
  #include <c10/cuda/CUDAException.h>
  #include <c10/cuda/CUDAGuard.h>

  #include "cuda_launch.h"
  #include "cuda_kernel.h"
  6. //////////////////////////////////////////////////////////////////////////////////////////////////
  7. //////////////////////////////////////////////////////////////////////////////////////////////////
  // Host launcher for index_max_cuda_kernel (declared in cuda_kernel.h).
  //
  // Launch layout: one CUDA block of 256 threads per batch element, with
  // A_num_block * 32 floats of dynamic shared memory per block, enqueued on
  // PyTorch's current CUDA stream so it is ordered with the ATen ops that
  // produce its inputs and consume its outputs. A shared-memory request above
  // the default 48 KB dynamic limit (A_num_block > 384) fails at launch and is
  // reported by the launch check below.
  //
  // Args:
  //   index_vals   float32 CUDA tensor, [batch_size, 32, num_block]
  //   indices      int32 CUDA tensor, [batch_size, num_block]; its sizes define
  //                batch_size and num_block
  //   A_num_block  forwarded to the kernel; also sizes max_vals and shared memory
  //   B_num_block  forwarded to the kernel unchanged
  // Returns:
  //   {max_vals [batch_size, A_num_block * 32],
  //    max_vals_scatter [batch_size, 32, num_block]},
  //   both allocated with index_vals' options and zero-filled before the launch.
  std::vector<at::Tensor> index_max_kernel(
    at::Tensor index_vals, // [batch_size, 32, num_block]
    at::Tensor indices, // [batch_size, num_block]
    int A_num_block,
    int B_num_block
  ) {
    TORCH_CHECK(index_vals.is_cuda(), "index_max_kernel: index_vals must be a CUDA tensor");
    TORCH_CHECK(indices.is_cuda(), "index_max_kernel: indices must be a CUDA tensor");
    TORCH_CHECK(index_vals.device() == indices.device(),
                "index_max_kernel: index_vals and indices must be on the same device");

    // Make the inputs' device current for the allocations and the launch below.
    const c10::cuda::CUDAGuard device_guard(index_vals.device());

    // The kernel walks raw pointers assuming the dense layouts documented above;
    // contiguous() returns the tensor itself when it already has that layout.
    index_vals = index_vals.contiguous();
    indices = indices.contiguous();

    int batch_size = indices.size(0);
    int num_block = indices.size(1);
    at::Tensor max_vals = at::zeros({batch_size, A_num_block * 32}, index_vals.options());
    at::Tensor max_vals_scatter = at::zeros({batch_size, 32, num_block}, index_vals.options());

    // An empty batch would mean a zero-sized grid, which CUDA rejects as an
    // invalid configuration; the (empty) outputs are already complete.
    if (batch_size == 0) {
      return {max_vals, max_vals_scatter};
    }

    dim3 threads(256);
    dim3 blocks(batch_size);
    const size_t shared_mem = static_cast<size_t>(A_num_block) * 32 * sizeof(float);
    index_max_cuda_kernel<<<blocks, threads, shared_mem, at::cuda::getCurrentCUDAStream()>>>(
      index_vals.data_ptr<float>(),
      indices.data_ptr<int>(),
      max_vals.data_ptr<float>(),
      max_vals_scatter.data_ptr<float>(),
      batch_size,
      A_num_block,
      B_num_block,
      num_block
    );
    // Surface launch-configuration errors here; otherwise they stay pending and
    // get reported by whichever unrelated CUDA call checks the error state next.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return {max_vals, max_vals_scatter};
  }
  // Host launcher for mm_to_sparse_cuda_kernel (declared in cuda_kernel.h).
  //
  // Launch layout: grid (num_block / 4, batch_size) of 64 x 4 thread blocks, no
  // dynamic shared memory, enqueued on PyTorch's current CUDA stream.
  //
  // Args:
  //   dense_A  float32 CUDA tensor, [batch_size, A_num_block, dim, 32]
  //   dense_B  float32 CUDA tensor, [batch_size, B_num_block, dim, 32]
  //   indices  int32 CUDA tensor, [batch_size, num_block]; num_block must be a
  //            multiple of 4 (see the check below)
  // Returns:
  //   sparse_C [batch_size, num_block, 32, 32], allocated with dense_A's options
  //   and zero-filled before the launch.
  at::Tensor mm_to_sparse_kernel(
    at::Tensor dense_A, // [batch_size, A_num_block, dim, 32]
    at::Tensor dense_B, // [batch_size, B_num_block, dim, 32]
    at::Tensor indices // [batch_size, num_block]
  ) {
    TORCH_CHECK(dense_A.is_cuda() && dense_B.is_cuda() && indices.is_cuda(),
                "mm_to_sparse_kernel: dense_A, dense_B and indices must be CUDA tensors");
    TORCH_CHECK(dense_B.device() == dense_A.device() && indices.device() == dense_A.device(),
                "mm_to_sparse_kernel: dense_A, dense_B and indices must be on the same device");

    // Make the inputs' device current for the allocation and the launch below.
    const c10::cuda::CUDAGuard device_guard(dense_A.device());

    // The kernel indexes raw pointers assuming the dense layouts documented above.
    dense_A = dense_A.contiguous();
    dense_B = dense_B.contiguous();
    indices = indices.contiguous();

    int batch_size = dense_A.size(0);
    int A_num_block = dense_A.size(1);
    int B_num_block = dense_B.size(1);
    int dim = dense_A.size(2);
    int num_block = indices.size(1);

    // The grid spans only (num_block / 4) * 4 y-threads per batch element. No tail
    // handling is visible here, so a remainder would be skipped and its slices of
    // sparse_C silently left at zero; require an exact multiple instead.
    TORCH_CHECK(num_block % 4 == 0,
                "mm_to_sparse_kernel: num_block must be a multiple of 4, got ", num_block);

    at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options());

    // A zero-sized grid dimension is an invalid launch configuration.
    if (batch_size == 0 || num_block == 0) {
      return sparse_C;
    }

    dim3 threads(64, 4);
    dim3 blocks(num_block / 4, batch_size);
    mm_to_sparse_cuda_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      dense_A.data_ptr<float>(),
      dense_B.data_ptr<float>(),
      indices.data_ptr<int>(),
      sparse_C.data_ptr<float>(),
      batch_size,
      A_num_block,
      B_num_block,
      dim,
      num_block
    );
    // Raise launch errors here instead of leaving them pending for later calls.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return sparse_C;
  }
  // Host launcher for sparse_dense_mm_cuda_kernel (declared in cuda_kernel.h).
  //
  // Launch layout: grid (num_block / 2, batch_size) of 128 x 2 thread blocks, no
  // dynamic shared memory, enqueued on PyTorch's current CUDA stream.
  //
  // Args:
  //   sparse_A     float32 CUDA tensor, [batch_size, num_block, 32, 32]; num_block
  //                must be even (see the check below)
  //   indices      int32 CUDA tensor, [batch_size, num_block]
  //   dense_B      float32 CUDA tensor, [batch_size, B_num_block, dim, 32]
  //   A_num_block  number of row blocks in the output
  // Returns:
  //   dense_C [batch_size, A_num_block, dim, 32], allocated with dense_B's
  //   options and zero-filled before the launch.
  at::Tensor sparse_dense_mm_kernel(
    at::Tensor sparse_A, // [batch_size, num_block, 32, 32]
    at::Tensor indices, // [batch_size, num_block]
    at::Tensor dense_B, // [batch_size, B_num_block, dim, 32]
    int A_num_block
  ) {
    TORCH_CHECK(sparse_A.is_cuda() && indices.is_cuda() && dense_B.is_cuda(),
                "sparse_dense_mm_kernel: sparse_A, indices and dense_B must be CUDA tensors");
    TORCH_CHECK(indices.device() == sparse_A.device() && dense_B.device() == sparse_A.device(),
                "sparse_dense_mm_kernel: sparse_A, indices and dense_B must be on the same device");

    // Make the inputs' device current for the allocation and the launch below.
    const c10::cuda::CUDAGuard device_guard(sparse_A.device());

    // The kernel indexes raw pointers assuming the dense layouts documented above.
    sparse_A = sparse_A.contiguous();
    indices = indices.contiguous();
    dense_B = dense_B.contiguous();

    int batch_size = sparse_A.size(0);
    int num_block = sparse_A.size(1);
    int B_num_block = dense_B.size(1);
    int dim = dense_B.size(2);

    // The grid spans only (num_block / 2) * 2 y-threads per batch element; an odd
    // num_block would silently leave the last sparse block unprocessed.
    TORCH_CHECK(num_block % 2 == 0,
                "sparse_dense_mm_kernel: num_block must be a multiple of 2, got ", num_block);

    at::Tensor dense_C = at::zeros({batch_size, A_num_block, dim, 32}, dense_B.options());

    // A zero-sized grid dimension is an invalid launch configuration.
    if (batch_size == 0 || num_block == 0) {
      return dense_C;
    }

    dim3 threads(128, 2);
    dim3 blocks(num_block / 2, batch_size);
    sparse_dense_mm_cuda_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      sparse_A.data_ptr<float>(),
      indices.data_ptr<int>(),
      dense_B.data_ptr<float>(),
      dense_C.data_ptr<float>(),
      batch_size,
      A_num_block,
      B_num_block,
      dim,
      num_block
    );
    // Raise launch errors here instead of leaving them pending for later calls.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return dense_C;
  }
  // Host launcher for reduce_sum_cuda_kernel (declared in cuda_kernel.h).
  //
  // Launch layout: grid (num_block / 4, batch_size) of 32 x 4 thread blocks, no
  // dynamic shared memory, enqueued on PyTorch's current CUDA stream.
  //
  // Args:
  //   sparse_A     float32 CUDA tensor, [batch_size, num_block, 32, 32]; num_block
  //                must be a multiple of 4 (see the check below)
  //   indices      int32 CUDA tensor, [batch_size, num_block]
  //   A_num_block  number of row blocks in the output
  //   B_num_block  forwarded to the kernel unchanged
  // Returns:
  //   dense_C [batch_size, A_num_block, 32], allocated with sparse_A's options
  //   and zero-filled before the launch.
  at::Tensor reduce_sum_kernel(
    at::Tensor sparse_A, // [batch_size, num_block, 32, 32]
    at::Tensor indices, // [batch_size, num_block]
    int A_num_block,
    int B_num_block
  ) {
    TORCH_CHECK(sparse_A.is_cuda() && indices.is_cuda(),
                "reduce_sum_kernel: sparse_A and indices must be CUDA tensors");
    TORCH_CHECK(indices.device() == sparse_A.device(),
                "reduce_sum_kernel: sparse_A and indices must be on the same device");

    // Make the inputs' device current for the allocation and the launch below.
    const c10::cuda::CUDAGuard device_guard(sparse_A.device());

    // The kernel indexes raw pointers assuming the dense layouts documented above.
    sparse_A = sparse_A.contiguous();
    indices = indices.contiguous();

    int batch_size = sparse_A.size(0);
    int num_block = sparse_A.size(1);

    // The grid spans only (num_block / 4) * 4 y-threads per batch element; a
    // remainder would be silently skipped, so require an exact multiple.
    TORCH_CHECK(num_block % 4 == 0,
                "reduce_sum_kernel: num_block must be a multiple of 4, got ", num_block);

    at::Tensor dense_C = at::zeros({batch_size, A_num_block, 32}, sparse_A.options());

    // A zero-sized grid dimension is an invalid launch configuration.
    if (batch_size == 0 || num_block == 0) {
      return dense_C;
    }

    dim3 threads(32, 4);
    dim3 blocks(num_block / 4, batch_size);
    reduce_sum_cuda_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      sparse_A.data_ptr<float>(),
      indices.data_ptr<int>(),
      dense_C.data_ptr<float>(),
      batch_size,
      A_num_block,
      B_num_block,
      num_block
    );
    // Raise launch errors here instead of leaving them pending for later calls.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return dense_C;
  }
  // Host launcher for scatter_cuda_kernel (declared in cuda_kernel.h).
  //
  // Launch layout: grid (num_block / 4, batch_size) of 32 x 4 thread blocks, no
  // dynamic shared memory, enqueued on PyTorch's current CUDA stream.
  //
  // Args:
  //   dense_A      float32 CUDA tensor, [batch_size, A_num_block, 32]
  //   indices      int32 CUDA tensor, [batch_size, num_block]; num_block must be a
  //                multiple of 4 (see the check below)
  //   B_num_block  forwarded to the kernel unchanged
  // Returns:
  //   sparse_C [batch_size, num_block, 32, 32], allocated with dense_A's options
  //   and zero-filled before the launch.
  at::Tensor scatter_kernel(
    at::Tensor dense_A, // [batch_size, A_num_block, 32]
    at::Tensor indices, // [batch_size, num_block]
    int B_num_block
  ) {
    TORCH_CHECK(dense_A.is_cuda() && indices.is_cuda(),
                "scatter_kernel: dense_A and indices must be CUDA tensors");
    TORCH_CHECK(indices.device() == dense_A.device(),
                "scatter_kernel: dense_A and indices must be on the same device");

    // Make the inputs' device current for the allocation and the launch below.
    const c10::cuda::CUDAGuard device_guard(dense_A.device());

    // The kernel indexes raw pointers assuming the dense layouts documented above.
    dense_A = dense_A.contiguous();
    indices = indices.contiguous();

    int batch_size = dense_A.size(0);
    int A_num_block = dense_A.size(1);
    int num_block = indices.size(1);

    // The grid spans only (num_block / 4) * 4 y-threads per batch element; a
    // remainder would leave the tail of sparse_C silently zero.
    TORCH_CHECK(num_block % 4 == 0,
                "scatter_kernel: num_block must be a multiple of 4, got ", num_block);

    at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options());

    // A zero-sized grid dimension is an invalid launch configuration.
    if (batch_size == 0 || num_block == 0) {
      return sparse_C;
    }

    dim3 threads(32, 4);
    dim3 blocks(num_block / 4, batch_size);
    scatter_cuda_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      dense_A.data_ptr<float>(),
      indices.data_ptr<int>(),
      sparse_C.data_ptr<float>(),
      batch_size,
      A_num_block,
      B_num_block,
      num_block
    );
    // Raise launch errors here instead of leaving them pending for later calls.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return sparse_C;
  }