// total_variation_kernel.cu
  1. #include <torch/extension.h>
  2. #include <cuda.h>
  3. #include <cuda_runtime.h>
  4. #include <vector>
  // Limits `v` to the range [lo, hi]: the lower bound is applied first with max(),
  // then the upper bound with min(). The result is converted back to scalar_t.
  // The bounds may have a different type than the value (bound_t vs. scalar_t).
  template <typename scalar_t, typename bound_t>
  __device__ __forceinline__ scalar_t clamp(const scalar_t v, const bound_t lo, const bound_t hi) {
    const auto raised = max(v, lo);  // value after enforcing the lower bound
    return min(raised, hi);          // then enforce the upper bound
  }
  // Adds a clamped neighbour-difference term to `grad`, one thread per element.
  //
  // The flat element index is split into (i, j, k) with k varying fastest, then j,
  // then i; any leading extent beyond sz_i*sz_j*sz_k is folded away by the modulo
  // on i. For each of the six axis neighbours that exists (the element is not on
  // the corresponding boundary), param[index] - param[neighbour] is clamped to
  // [-1, 1], multiplied by the axis weight and accumulated:
  //   wx -> neighbours along k (offset +-1)
  //   wy -> neighbours along j (offset +-sz_k)
  //   wz -> neighbours along i (offset +-sz_k*sz_j)
  // The accumulated value is then added to grad[index].
  //
  // Template parameters:
  //   scalar_t    element type of `param` and `grad`
  //   dense_mode  true: every element is updated;
  //               false: elements whose grad is exactly 0 are left untouched
  //
  // Parameters:
  //   param             values to difference (read-only)
  //   grad              buffer updated in place
  //   wx, wy, wz        per-axis weights (see above)
  //   sz_i, sz_j, sz_k  extents used to decompose the flat index
  //   N                 number of elements; threads with index >= N do nothing
  //
  // Launch layout: 1-D grid of 1-D blocks covering at least N threads. No shared
  // memory is used.
  template <typename scalar_t, bool dense_mode>
  __global__ void total_variation_add_grad_cuda_kernel(
      const scalar_t* __restrict__ param,
      scalar_t* __restrict__ grad,
      float wx, float wy, float wz,
      const size_t sz_i, const size_t sz_j, const size_t sz_k, const size_t N) {
    // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be computed
    // in 32-bit unsigned arithmetic and wrap around for N above 2^32 elements.
    const size_t index = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if(index<N && (dense_mode || grad[index]!=0)) {
      const size_t k = index % sz_k;
      const size_t j = index / sz_k % sz_j;
      const size_t i = index / sz_k / sz_j % sz_i;

      // Note: the sum is kept in single precision regardless of scalar_t.
      float grad_to_add = 0;
      // k axis (contiguous neighbours)
      grad_to_add += (k==0      ? 0 : wx * clamp(param[index]-param[index-1], -1.f, 1.f));
      grad_to_add += (k==sz_k-1 ? 0 : wx * clamp(param[index]-param[index+1], -1.f, 1.f));
      // j axis (one row of sz_k elements apart)
      grad_to_add += (j==0      ? 0 : wy * clamp(param[index]-param[index-sz_k], -1.f, 1.f));
      grad_to_add += (j==sz_j-1 ? 0 : wy * clamp(param[index]-param[index+sz_k], -1.f, 1.f));
      // i axis (one plane of sz_k*sz_j elements apart)
      grad_to_add += (i==0      ? 0 : wz * clamp(param[index]-param[index-sz_k*sz_j], -1.f, 1.f));
      grad_to_add += (i==sz_i-1 ? 0 : wz * clamp(param[index]-param[index+sz_k*sz_j], -1.f, 1.f));

      grad[index] += grad_to_add;
    }
  }
  // Host launcher: adds the clamped neighbour-difference term computed from `param`
  // into `grad` in place by running total_variation_add_grad_cuda_kernel.
  //
  // Parameters:
  //   param       tensor whose sizes at dims 2, 3 and 4 are passed to the kernel as
  //               the (i, j, k) extents; all numel() elements are processed
  //   grad        tensor updated in place; must hold as many elements as `param`
  //   wx, wy, wz  per-axis weights; each is divided by 6 before the launch
  //   dense_mode  true: update every element; false: skip elements whose grad is 0
  //
  // Raises a c10::Error (via TORCH_CHECK) when the element counts differ, when the
  // grid would exceed the CUDA limit, or when the launch itself reports an error.
  //
  // NOTE(review): the kernel indexes both buffers as flat contiguous device memory
  // and is launched on the default stream; contiguity / device placement are not
  // verified here -- presumably the caller checks them, confirm.
  void total_variation_add_grad_cuda(torch::Tensor param, torch::Tensor grad, float wx, float wy, float wz, bool dense_mode) {
    const size_t N = param.numel();
    const size_t sz_i = param.size(2);
    const size_t sz_j = param.size(3);
    const size_t sz_k = param.size(4);

    // The kernel writes grad[index] for every index < N, so a smaller grad would be
    // written out of bounds.
    TORCH_CHECK(static_cast<size_t>(grad.numel()) == N,
                "total_variation_add_grad_cuda: grad must have the same number of elements as param");

    // Nothing to do for an empty tensor; a zero-block grid is an invalid launch
    // configuration.
    if (N == 0) {
      return;
    }

    const int threads = 256;
    // Upper limit on gridDim.x (2^31 - 1); also keeps the narrowing to int safe.
    constexpr size_t kMaxGridDimX = 2147483647;
    const size_t blocks_needed = (N + threads - 1) / threads;  // ceil(N / threads)
    TORCH_CHECK(blocks_needed <= kMaxGridDimX,
                "total_variation_add_grad_cuda: tensor is too large for a single 1-D grid launch");
    const int blocks = static_cast<int>(blocks_needed);

    wx /= 6;
    wy /= 6;
    wz /= 6;

    // dense_mode is a compile-time template argument of the kernel, so the runtime
    // flag selects between the two instantiations.
    AT_DISPATCH_FLOATING_TYPES(param.scalar_type(), "total_variation_add_grad_cuda", ([&] {
      const scalar_t* param_ptr = param.data_ptr<scalar_t>();
      scalar_t* grad_ptr = grad.data_ptr<scalar_t>();
      if (dense_mode) {
        total_variation_add_grad_cuda_kernel<scalar_t,true><<<blocks, threads>>>(
            param_ptr, grad_ptr,
            wx, wy, wz,
            sz_i, sz_j, sz_k, N);
      } else {
        total_variation_add_grad_cuda_kernel<scalar_t,false><<<blocks, threads>>>(
            param_ptr, grad_ptr,
            wx, wy, wz,
            sz_i, sz_j, sz_k, N);
      }
    }));

    // Kernel launches do not return a status; surface configuration errors here
    // instead of letting them show up at an unrelated later CUDA call.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "total_variation_add_grad_cuda: kernel launch failed: ",
                cudaGetErrorString(launch_status));
  }