| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
- #include <torch/extension.h>
- #include <ATen/ATen.h>
- #include <ATen/cuda/CUDAContext.h>
- #include <c10/cuda/CUDAException.h>
- #include <c10/cuda/CUDAGuard.h>
- #include "cuda_launch.h"
- #include "cuda_kernel.h"
- #include <vector>
- //////////////////////////////////////////////////////////////////////////////////////////////////
- //////////////////////////////////////////////////////////////////////////////////////////////////
- // Host-side launcher for index_max_cuda_kernel (declared in cuda_kernel.h).
- //
- // Arguments (shapes are the original author's annotations; they are not re-validated here):
- //   index_vals  : CUDA tensor read as float, annotated [batch_size, 32, num_block]
- //   indices     : CUDA tensor read as int,   annotated [batch_size, num_block]
- //   A_num_block : sizes the first output ({batch_size, A_num_block * 32}) and the
- //                 dynamic shared-memory request (A_num_block * 32 floats per block)
- //   B_num_block : forwarded to the kernel unchanged
- //
- // Returns {max_vals, max_vals_scatter}. Both are zero-initialised with index_vals'
- // options and handed to the kernel as its two output pointers.
- //
- // Launch layout: one thread block of 256 threads per batch entry, on PyTorch's
- // current CUDA stream for the tensors' device.
- std::vector<at::Tensor> index_max_kernel(
-     at::Tensor index_vals, // [batch_size, 32, num_block]
-     at::Tensor indices, // [batch_size, num_block],
-     int A_num_block,
-     int B_num_block
- ) {
-     TORCH_CHECK(index_vals.is_cuda() && indices.is_cuda(),
-         "index_max_kernel: index_vals and indices must be CUDA tensors");
-     TORCH_CHECK(index_vals.device() == indices.device(),
-         "index_max_kernel: index_vals and indices must be on the same device");
-     // The kernel only sees raw pointers, so the memory layout has to be dense.
-     index_vals = index_vals.contiguous();
-     indices = indices.contiguous();
-     // Launch on the device that owns the data, not whichever device happens to be current.
-     const c10::cuda::CUDAGuard device_guard(index_vals.device());
-     const int batch_size = static_cast<int>(indices.size(0));
-     const int num_block = static_cast<int>(indices.size(1));
-     at::Tensor max_vals = at::zeros({batch_size, A_num_block * 32}, index_vals.options());
-     at::Tensor max_vals_scatter = at::zeros({batch_size, 32, num_block}, index_vals.options());
-     // A zero-sized grid is an invalid launch configuration; the zero-filled outputs
-     // are already the complete result for an empty batch.
-     if (batch_size == 0) {
-         return {max_vals, max_vals_scatter};
-     }
-     const dim3 threads(256);
-     const dim3 blocks(batch_size);
-     // Widen before multiplying so the byte count is computed in size_t, not int.
-     const size_t shared_mem = static_cast<size_t>(A_num_block) * 32 * sizeof(float);
-     // Use the stream the at::zeros fills above were enqueued on, so the kernel is
-     // ordered after them even when the caller runs under a non-default stream.
-     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-     index_max_cuda_kernel<<<blocks, threads, shared_mem, stream>>>(
-         index_vals.data_ptr<float>(),
-         indices.data_ptr<int>(),
-         max_vals.data_ptr<float>(),
-         max_vals_scatter.data_ptr<float>(),
-         batch_size,
-         A_num_block,
-         B_num_block,
-         num_block
-     );
-     // Kernel launches do not return a status; without this a rejected configuration
-     // (e.g. too much dynamic shared memory) would silently leave the outputs at zero.
-     C10_CUDA_KERNEL_LAUNCH_CHECK();
-     return {max_vals, max_vals_scatter};
- }
- // Host-side launcher for mm_to_sparse_cuda_kernel (declared in cuda_kernel.h).
- //
- // Arguments (shapes are the original author's annotations; they are not re-validated here):
- //   dense_A : CUDA tensor read as float, annotated [batch_size, A_num_block, dim, 32]
- //   dense_B : CUDA tensor read as float, annotated [batch_size, B_num_block, dim, 32]
- //   indices : CUDA tensor read as int,   annotated [batch_size, num_block]
- //
- // Returns sparse_C, a {batch_size, num_block, 32, 32} tensor zero-initialised with
- // dense_A's options and handed to the kernel as its output pointer.
- //
- // Launch layout: blockDim = (64, 4), gridDim = (num_block / 4, batch_size), on
- // PyTorch's current CUDA stream. num_block must therefore be a multiple of 4.
- at::Tensor mm_to_sparse_kernel(
-     at::Tensor dense_A, // [batch_size, A_num_block, dim, 32]
-     at::Tensor dense_B, // [batch_size, B_num_block, dim, 32]
-     at::Tensor indices // [batch_size, num_block]
- ) {
-     TORCH_CHECK(dense_A.is_cuda() && dense_B.is_cuda() && indices.is_cuda(),
-         "mm_to_sparse_kernel: dense_A, dense_B and indices must be CUDA tensors");
-     TORCH_CHECK(dense_A.device() == dense_B.device() && dense_A.device() == indices.device(),
-         "mm_to_sparse_kernel: all tensors must be on the same device");
-     // The kernel only sees raw pointers, so the memory layout has to be dense.
-     dense_A = dense_A.contiguous();
-     dense_B = dense_B.contiguous();
-     indices = indices.contiguous();
-     // Launch on the device that owns the data, not whichever device happens to be current.
-     const c10::cuda::CUDAGuard device_guard(dense_A.device());
-     const int batch_size = static_cast<int>(dense_A.size(0));
-     const int A_num_block = static_cast<int>(dense_A.size(1));
-     const int B_num_block = static_cast<int>(dense_B.size(1));
-     const int dim = static_cast<int>(dense_A.size(2));
-     const int num_block = static_cast<int>(indices.size(1));
-     // The grid below is num_block / 4 wide with 4 rows of threads per block, so any
-     // remainder would never be launched and its output would silently stay zero.
-     TORCH_CHECK(num_block % 4 == 0,
-         "mm_to_sparse_kernel: num_block must be a multiple of 4, got ", num_block);
-     at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options());
-     // A zero-sized grid is an invalid launch configuration; nothing to compute anyway.
-     if (batch_size == 0 || num_block == 0) {
-         return sparse_C;
-     }
-     const dim3 threads(64, 4);
-     const dim3 blocks(num_block / 4, batch_size);
-     // Use the stream the at::zeros fill above was enqueued on, so the kernel is
-     // ordered after it even when the caller runs under a non-default stream.
-     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-     mm_to_sparse_cuda_kernel<<<blocks, threads, 0, stream>>>(
-         dense_A.data_ptr<float>(),
-         dense_B.data_ptr<float>(),
-         indices.data_ptr<int>(),
-         sparse_C.data_ptr<float>(),
-         batch_size,
-         A_num_block,
-         B_num_block,
-         dim,
-         num_block
-     );
-     // Kernel launches do not return a status; surface a rejected configuration here.
-     C10_CUDA_KERNEL_LAUNCH_CHECK();
-     return sparse_C;
- }
- // Host-side launcher for sparse_dense_mm_cuda_kernel (declared in cuda_kernel.h).
- //
- // Arguments (shapes are the original author's annotations; they are not re-validated here):
- //   sparse_A    : CUDA tensor read as float, annotated [batch_size, num_block, 32, 32]
- //   indices     : CUDA tensor read as int,   annotated [batch_size, num_block]
- //   dense_B     : CUDA tensor read as float, annotated [batch_size, B_num_block, dim, 32]
- //   A_num_block : second dimension of the returned tensor; also forwarded to the kernel
- //
- // Returns dense_C, a {batch_size, A_num_block, dim, 32} tensor zero-initialised with
- // dense_B's options and handed to the kernel as its output pointer.
- //
- // Launch layout: blockDim = (128, 2), gridDim = (num_block / 2, batch_size), on
- // PyTorch's current CUDA stream. num_block must therefore be a multiple of 2.
- at::Tensor sparse_dense_mm_kernel(
-     at::Tensor sparse_A, // [batch_size, num_block, 32, 32]
-     at::Tensor indices, // [batch_size, num_block]
-     at::Tensor dense_B, // [batch_size, B_num_block, dim, 32]
-     int A_num_block
- ) {
-     TORCH_CHECK(sparse_A.is_cuda() && indices.is_cuda() && dense_B.is_cuda(),
-         "sparse_dense_mm_kernel: sparse_A, indices and dense_B must be CUDA tensors");
-     TORCH_CHECK(sparse_A.device() == indices.device() && sparse_A.device() == dense_B.device(),
-         "sparse_dense_mm_kernel: all tensors must be on the same device");
-     // The kernel only sees raw pointers, so the memory layout has to be dense.
-     sparse_A = sparse_A.contiguous();
-     indices = indices.contiguous();
-     dense_B = dense_B.contiguous();
-     // Launch on the device that owns the data, not whichever device happens to be current.
-     const c10::cuda::CUDAGuard device_guard(sparse_A.device());
-     const int batch_size = static_cast<int>(sparse_A.size(0));
-     const int num_block = static_cast<int>(sparse_A.size(1));
-     const int B_num_block = static_cast<int>(dense_B.size(1));
-     const int dim = static_cast<int>(dense_B.size(2));
-     // The grid below is num_block / 2 wide with 2 rows of threads per block, so an
-     // odd num_block would leave the last sparse block unprocessed without any error.
-     TORCH_CHECK(num_block % 2 == 0,
-         "sparse_dense_mm_kernel: num_block must be a multiple of 2, got ", num_block);
-     at::Tensor dense_C = at::zeros({batch_size, A_num_block, dim, 32}, dense_B.options());
-     // A zero-sized grid is an invalid launch configuration; nothing to compute anyway.
-     if (batch_size == 0 || num_block == 0) {
-         return dense_C;
-     }
-     const dim3 threads(128, 2);
-     const dim3 blocks(num_block / 2, batch_size);
-     // Use the stream the at::zeros fill above was enqueued on, so the kernel is
-     // ordered after it even when the caller runs under a non-default stream.
-     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-     sparse_dense_mm_cuda_kernel<<<blocks, threads, 0, stream>>>(
-         sparse_A.data_ptr<float>(),
-         indices.data_ptr<int>(),
-         dense_B.data_ptr<float>(),
-         dense_C.data_ptr<float>(),
-         batch_size,
-         A_num_block,
-         B_num_block,
-         dim,
-         num_block
-     );
-     // Kernel launches do not return a status; surface a rejected configuration here.
-     C10_CUDA_KERNEL_LAUNCH_CHECK();
-     return dense_C;
- }
- // Host-side launcher for reduce_sum_cuda_kernel (declared in cuda_kernel.h).
- //
- // Arguments (shapes are the original author's annotations; they are not re-validated here):
- //   sparse_A    : CUDA tensor read as float, annotated [batch_size, num_block, 32, 32]
- //   indices     : CUDA tensor read as int,   annotated [batch_size, num_block]
- //   A_num_block : second dimension of the returned tensor; also forwarded to the kernel
- //   B_num_block : forwarded to the kernel unchanged
- //
- // Returns dense_C, a {batch_size, A_num_block, 32} tensor zero-initialised with
- // sparse_A's options and handed to the kernel as its output pointer.
- //
- // Launch layout: blockDim = (32, 4), gridDim = (num_block / 4, batch_size), on
- // PyTorch's current CUDA stream. num_block must therefore be a multiple of 4.
- at::Tensor reduce_sum_kernel(
-     at::Tensor sparse_A, // [batch_size, num_block, 32, 32]
-     at::Tensor indices, // [batch_size, num_block]
-     int A_num_block,
-     int B_num_block
- ) {
-     TORCH_CHECK(sparse_A.is_cuda() && indices.is_cuda(),
-         "reduce_sum_kernel: sparse_A and indices must be CUDA tensors");
-     TORCH_CHECK(sparse_A.device() == indices.device(),
-         "reduce_sum_kernel: sparse_A and indices must be on the same device");
-     // The kernel only sees raw pointers, so the memory layout has to be dense.
-     sparse_A = sparse_A.contiguous();
-     indices = indices.contiguous();
-     // Launch on the device that owns the data, not whichever device happens to be current.
-     const c10::cuda::CUDAGuard device_guard(sparse_A.device());
-     const int batch_size = static_cast<int>(sparse_A.size(0));
-     const int num_block = static_cast<int>(sparse_A.size(1));
-     // The grid below is num_block / 4 wide with 4 rows of threads per block, so any
-     // remainder would never be launched and would be missing from the sums.
-     TORCH_CHECK(num_block % 4 == 0,
-         "reduce_sum_kernel: num_block must be a multiple of 4, got ", num_block);
-     at::Tensor dense_C = at::zeros({batch_size, A_num_block, 32}, sparse_A.options());
-     // A zero-sized grid is an invalid launch configuration; nothing to compute anyway.
-     if (batch_size == 0 || num_block == 0) {
-         return dense_C;
-     }
-     const dim3 threads(32, 4);
-     const dim3 blocks(num_block / 4, batch_size);
-     // Use the stream the at::zeros fill above was enqueued on, so the kernel is
-     // ordered after it even when the caller runs under a non-default stream.
-     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-     reduce_sum_cuda_kernel<<<blocks, threads, 0, stream>>>(
-         sparse_A.data_ptr<float>(),
-         indices.data_ptr<int>(),
-         dense_C.data_ptr<float>(),
-         batch_size,
-         A_num_block,
-         B_num_block,
-         num_block
-     );
-     // Kernel launches do not return a status; surface a rejected configuration here.
-     C10_CUDA_KERNEL_LAUNCH_CHECK();
-     return dense_C;
- }
- // Host-side launcher for scatter_cuda_kernel (declared in cuda_kernel.h).
- //
- // Arguments (shapes are the original author's annotations; they are not re-validated here):
- //   dense_A     : CUDA tensor read as float, annotated [batch_size, A_num_block, 32]
- //   indices     : CUDA tensor read as int,   annotated [batch_size, num_block]
- //   B_num_block : forwarded to the kernel unchanged
- //
- // Returns sparse_C, a {batch_size, num_block, 32, 32} tensor zero-initialised with
- // dense_A's options and handed to the kernel as its output pointer.
- //
- // Launch layout: blockDim = (32, 4), gridDim = (num_block / 4, batch_size), on
- // PyTorch's current CUDA stream. num_block must therefore be a multiple of 4.
- at::Tensor scatter_kernel(
-     at::Tensor dense_A, // [batch_size, A_num_block, 32]
-     at::Tensor indices, // [batch_size, num_block]
-     int B_num_block
- ) {
-     TORCH_CHECK(dense_A.is_cuda() && indices.is_cuda(),
-         "scatter_kernel: dense_A and indices must be CUDA tensors");
-     TORCH_CHECK(dense_A.device() == indices.device(),
-         "scatter_kernel: dense_A and indices must be on the same device");
-     // The kernel only sees raw pointers, so the memory layout has to be dense.
-     dense_A = dense_A.contiguous();
-     indices = indices.contiguous();
-     // Launch on the device that owns the data, not whichever device happens to be current.
-     const c10::cuda::CUDAGuard device_guard(dense_A.device());
-     const int batch_size = static_cast<int>(dense_A.size(0));
-     const int A_num_block = static_cast<int>(dense_A.size(1));
-     const int num_block = static_cast<int>(indices.size(1));
-     // The grid below is num_block / 4 wide with 4 rows of threads per block, so any
-     // remainder would never be launched and its output would silently stay zero.
-     TORCH_CHECK(num_block % 4 == 0,
-         "scatter_kernel: num_block must be a multiple of 4, got ", num_block);
-     at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options());
-     // A zero-sized grid is an invalid launch configuration; nothing to compute anyway.
-     if (batch_size == 0 || num_block == 0) {
-         return sparse_C;
-     }
-     const dim3 threads(32, 4);
-     const dim3 blocks(num_block / 4, batch_size);
-     // Use the stream the at::zeros fill above was enqueued on, so the kernel is
-     // ordered after it even when the caller runs under a non-default stream.
-     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-     scatter_cuda_kernel<<<blocks, threads, 0, stream>>>(
-         dense_A.data_ptr<float>(),
-         indices.data_ptr<int>(),
-         sparse_C.data_ptr<float>(),
-         batch_size,
-         A_num_block,
-         B_num_block,
-         num_block
-     );
-     // Kernel launches do not return a status; surface a rejected configuration here.
-     C10_CUDA_KERNEL_LAUNCH_CHECK();
-     return sparse_C;
- }
|