ub360_utils_kernel.cu 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. #include <torch/extension.h>
  2. #include <cuda.h>
  3. #include <cuda_runtime.h>
  4. #include <vector>
  /*
  Helper kernel to skip oversampled points, especially near the foreground
  scene bbox boundary.

  Launch layout: 1-D grid of 1-D blocks, one thread per ray; threads whose
  global index is >= n_rays do nothing. No shared memory is used.

  Each thread walks the n_pts consecutive entries of its ray
  (dist[i_ray * n_pts + k], k = 0..n_pts-1), keeping a running sum of the
  distances. Whenever the running sum exceeds `thres`, the corresponding mask
  entry is set to true and the running sum restarts from zero; otherwise the
  mask entry is set to false.

  Parameters:
    dist   - read-only distances, n_rays * n_pts elements, indexed row-major.
    thres  - cumulative-distance threshold.
    n_rays - number of rays (rows).
    n_pts  - number of points per ray (columns).
    mask   - output, n_rays * n_pts elements; every element is written.
  */
  template <typename scalar_t>
  __global__ void cumdist_thres_cuda_kernel(
      const scalar_t* __restrict__ dist,
      const float thres,
      const int n_rays,
      const int n_pts,
      bool* __restrict__ mask) {
    const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
    if (i_ray >= n_rays) {
      return;
    }

    // 64-bit offsets: i_ray * n_pts overflows a 32-bit int once the tensor
    // holds more than 2^31 elements.
    const int64_t i_s = static_cast<int64_t>(i_ray) * n_pts;
    const int64_t i_t = i_s + n_pts;

    float cum_dist = 0.0f;
    for (int64_t i = i_s; i < i_t; ++i) {
      cum_dist += dist[i];
      const bool over = (cum_dist > thres);
      // Reset with a select rather than multiplying by 0: if the running sum
      // is +inf, inf * 0 would yield NaN and disable every later comparison.
      cum_dist = over ? 0.0f : cum_dist;
      mask[i] = over;
    }
  }
  /*
  Host launcher for cumdist_thres_cuda_kernel.

  Parameters:
    dist  - 2-D floating-point CUDA tensor, indexed as (ray, point).
    thres - cumulative-distance threshold forwarded to the kernel.

  Returns a bool tensor with the same shape and device as `dist`; an entry is
  true where the running per-ray sum of `dist` exceeded `thres` (the sum
  restarts after each such entry).

  Raises (via TORCH_CHECK) if `dist` is not a CUDA tensor, is not 2-D, or if
  the kernel launch reports an error.

  NOTE(review): the kernel is launched with the default launch stream and on
  whatever device is current; callers working with non-default streams or
  multiple GPUs should confirm the current device/stream match `dist`.
  */
  torch::Tensor cumdist_thres_cuda(torch::Tensor dist, float thres) {
    TORCH_CHECK(dist.is_cuda(), "cumdist_thres_cuda: dist must be a CUDA tensor");
    TORCH_CHECK(dist.dim() == 2, "cumdist_thres_cuda: dist must be 2-D (n_rays, n_pts)");

    // The kernel uses flat row-major indexing, so make the layout explicit.
    // This is a no-op when the caller already passes a contiguous tensor.
    const auto dist_c = dist.contiguous();

    const int n_rays = static_cast<int>(dist_c.size(0));
    const int n_pts = static_cast<int>(dist_c.size(1));

    // Allocate the output on the same device as the input, not the default
    // CUDA device.
    auto mask = torch::zeros({n_rays, n_pts}, dist_c.options().dtype(torch::kBool));

    // A zero-sized grid is an invalid launch configuration; nothing to do.
    if (n_rays == 0 || n_pts == 0) {
      return mask;
    }

    const int threads = 256;
    const int blocks = (n_rays + threads - 1) / threads;  // ceil-div: one thread per ray

    AT_DISPATCH_FLOATING_TYPES(dist_c.scalar_type(), "cumdist_thres_cuda", ([&] {
      cumdist_thres_cuda_kernel<scalar_t><<<blocks, threads>>>(
          dist_c.data_ptr<scalar_t>(), thres,
          n_rays, n_pts,
          mask.data_ptr<bool>());
    }));

    // Kernel launches do not return a status; surface bad-configuration errors here.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "cumdist_thres_cuda: kernel launch failed: ", cudaGetErrorString(err));

    return mask;
  }