render_utils.cpp 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. #include <torch/extension.h>
  2. #include <vector>
  3. std::vector<torch::Tensor> infer_t_minmax_cuda(
  4. torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor xyz_min, torch::Tensor xyz_max,
  5. const float near, const float far);
  6. torch::Tensor infer_n_samples_cuda(torch::Tensor rays_d, torch::Tensor t_min, torch::Tensor t_max, const float stepdist);
  7. std::vector<torch::Tensor> infer_ray_start_dir_cuda(torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_min);
  8. std::vector<torch::Tensor> sample_pts_on_rays_cuda(
  9. torch::Tensor rays_o, torch::Tensor rays_d,
  10. torch::Tensor xyz_min, torch::Tensor xyz_max,
  11. const float near, const float far, const float stepdist);
  12. std::vector<torch::Tensor> sample_ndc_pts_on_rays_cuda(
  13. torch::Tensor rays_o, torch::Tensor rays_d,
  14. torch::Tensor xyz_min, torch::Tensor xyz_max,
  15. const int N_samples);
  16. torch::Tensor sample_bg_pts_on_rays_cuda(
  17. torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_max,
  18. const float bg_preserve, const int N_samples);
  19. torch::Tensor maskcache_lookup_cuda(torch::Tensor world, torch::Tensor xyz, torch::Tensor xyz2ijk_scale, torch::Tensor xyz2ijk_shift);
  20. std::vector<torch::Tensor> raw2alpha_cuda(torch::Tensor density, const float shift, const float interval);
  21. std::vector<torch::Tensor> raw2alpha_nonuni_cuda(torch::Tensor density, const float shift, torch::Tensor interval);
  22. torch::Tensor raw2alpha_backward_cuda(torch::Tensor exp, torch::Tensor grad_back, const float interval);
  23. torch::Tensor raw2alpha_nonuni_backward_cuda(torch::Tensor exp, torch::Tensor grad_back, torch::Tensor interval);
  24. std::vector<torch::Tensor> alpha2weight_cuda(torch::Tensor alpha, torch::Tensor ray_id, const int n_rays);
  25. torch::Tensor alpha2weight_backward_cuda(
  26. torch::Tensor alpha, torch::Tensor weight, torch::Tensor T, torch::Tensor alphainv_last,
  27. torch::Tensor i_start, torch::Tensor i_end, const int n_rays,
  28. torch::Tensor grad_weights, torch::Tensor grad_last);
  // Argument-validation helpers used by the wrappers below. Each one fails via
  // TORCH_CHECK (which throws) when its condition does not hold; the message
  // names the offending argument through #x stringification.
  //
  // CHECK_CUDA queries Tensor::is_cuda() directly instead of going through the
  // deprecated Tensor::type() accessor.
  #define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
  #define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
  // Wrapped in do/while(0) so CHECK_INPUT(x); is a single statement: inside an
  // unbraced `if` both checks are guarded (not just CHECK_CUDA), and an
  // `if (...) CHECK_INPUT(x); else ...` construct compiles.
  #define CHECK_INPUT(x)     \
    do {                     \
      CHECK_CUDA(x);         \
      CHECK_CONTIGUOUS(x);   \
    } while (0)
  32. std::vector<torch::Tensor> infer_t_minmax(
  33. torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor xyz_min, torch::Tensor xyz_max,
  34. const float near, const float far) {
  35. CHECK_INPUT(rays_o);
  36. CHECK_INPUT(rays_d);
  37. CHECK_INPUT(xyz_min);
  38. CHECK_INPUT(xyz_max);
  39. return infer_t_minmax_cuda(rays_o, rays_d, xyz_min, xyz_max, near, far);
  40. }
  41. torch::Tensor infer_n_samples(torch::Tensor rays_d, torch::Tensor t_min, torch::Tensor t_max, const float stepdist) {
  42. CHECK_INPUT(rays_d);
  43. CHECK_INPUT(t_min);
  44. CHECK_INPUT(t_max);
  45. return infer_n_samples_cuda(rays_d, t_min, t_max, stepdist);
  46. }
  47. std::vector<torch::Tensor> infer_ray_start_dir(torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_min) {
  48. CHECK_INPUT(rays_o);
  49. CHECK_INPUT(rays_d);
  50. CHECK_INPUT(t_min);
  51. return infer_ray_start_dir_cuda(rays_o, rays_d, t_min);
  52. }
  53. std::vector<torch::Tensor> sample_pts_on_rays(
  54. torch::Tensor rays_o, torch::Tensor rays_d,
  55. torch::Tensor xyz_min, torch::Tensor xyz_max,
  56. const float near, const float far, const float stepdist) {
  57. CHECK_INPUT(rays_o);
  58. CHECK_INPUT(rays_d);
  59. CHECK_INPUT(xyz_min);
  60. CHECK_INPUT(xyz_max);
  61. assert(rays_o.dim()==2);
  62. assert(rays_o.size(1)==3);
  63. return sample_pts_on_rays_cuda(rays_o, rays_d, xyz_min, xyz_max, near, far, stepdist);
  64. }
  65. std::vector<torch::Tensor> sample_ndc_pts_on_rays(
  66. torch::Tensor rays_o, torch::Tensor rays_d,
  67. torch::Tensor xyz_min, torch::Tensor xyz_max,
  68. const int N_samples) {
  69. CHECK_INPUT(rays_o);
  70. CHECK_INPUT(rays_d);
  71. CHECK_INPUT(xyz_min);
  72. CHECK_INPUT(xyz_max);
  73. assert(rays_o.dim()==2);
  74. assert(rays_o.size(1)==3);
  75. return sample_ndc_pts_on_rays_cuda(rays_o, rays_d, xyz_min, xyz_max, N_samples);
  76. }
  77. torch::Tensor sample_bg_pts_on_rays(
  78. torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_max,
  79. const float bg_preserve, const int N_samples) {
  80. CHECK_INPUT(rays_o);
  81. CHECK_INPUT(rays_d);
  82. CHECK_INPUT(t_max);
  83. return sample_bg_pts_on_rays_cuda(rays_o, rays_d, t_max, bg_preserve, N_samples);
  84. }
  85. torch::Tensor maskcache_lookup(torch::Tensor world, torch::Tensor xyz, torch::Tensor xyz2ijk_scale, torch::Tensor xyz2ijk_shift) {
  86. CHECK_INPUT(world);
  87. CHECK_INPUT(xyz);
  88. CHECK_INPUT(xyz2ijk_scale);
  89. CHECK_INPUT(xyz2ijk_shift);
  90. assert(world.dim()==3);
  91. assert(xyz.dim()==2);
  92. assert(xyz.size(1)==3);
  93. return maskcache_lookup_cuda(world, xyz, xyz2ijk_scale, xyz2ijk_shift);
  94. }
  95. std::vector<torch::Tensor> raw2alpha(torch::Tensor density, const float shift, const float interval) {
  96. CHECK_INPUT(density);
  97. assert(density.dim()==1);
  98. return raw2alpha_cuda(density, shift, interval);
  99. }
  100. std::vector<torch::Tensor> raw2alpha_nonuni(torch::Tensor density, const float shift, torch::Tensor interval) {
  101. CHECK_INPUT(density);
  102. assert(density.dim()==1);
  103. return raw2alpha_nonuni_cuda(density, shift, interval);
  104. }
  105. torch::Tensor raw2alpha_backward(torch::Tensor exp, torch::Tensor grad_back, const float interval) {
  106. CHECK_INPUT(exp);
  107. CHECK_INPUT(grad_back);
  108. return raw2alpha_backward_cuda(exp, grad_back, interval);
  109. }
  110. torch::Tensor raw2alpha_nonuni_backward(torch::Tensor exp, torch::Tensor grad_back, torch::Tensor interval) {
  111. CHECK_INPUT(exp);
  112. CHECK_INPUT(grad_back);
  113. return raw2alpha_nonuni_backward_cuda(exp, grad_back, interval);
  114. }
  115. std::vector<torch::Tensor> alpha2weight(torch::Tensor alpha, torch::Tensor ray_id, const int n_rays) {
  116. CHECK_INPUT(alpha);
  117. CHECK_INPUT(ray_id);
  118. assert(alpha.dim()==1);
  119. assert(ray_id.dim()==1);
  120. assert(alpha.sizes()==ray_id.sizes());
  121. return alpha2weight_cuda(alpha, ray_id, n_rays);
  122. }
  123. torch::Tensor alpha2weight_backward(
  124. torch::Tensor alpha, torch::Tensor weight, torch::Tensor T, torch::Tensor alphainv_last,
  125. torch::Tensor i_start, torch::Tensor i_end, const int n_rays,
  126. torch::Tensor grad_weights, torch::Tensor grad_last) {
  127. CHECK_INPUT(alpha);
  128. CHECK_INPUT(weight);
  129. CHECK_INPUT(T);
  130. CHECK_INPUT(alphainv_last);
  131. CHECK_INPUT(i_start);
  132. CHECK_INPUT(i_end);
  133. CHECK_INPUT(grad_weights);
  134. CHECK_INPUT(grad_last);
  135. return alpha2weight_backward_cuda(
  136. alpha, weight, T, alphainv_last,
  137. i_start, i_end, n_rays,
  138. grad_weights, grad_last);
  139. }
  140. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  141. m.def("infer_t_minmax", &infer_t_minmax, "Inference t_min and t_max of ray-bbox intersection");
  142. m.def("infer_n_samples", &infer_n_samples, "Inference the number of points to sample on each ray");
  143. m.def("infer_ray_start_dir", &infer_ray_start_dir, "Inference the starting point and shooting direction of each ray");
  144. m.def("sample_pts_on_rays", &sample_pts_on_rays, "Sample points on rays");
  145. m.def("sample_ndc_pts_on_rays", &sample_ndc_pts_on_rays, "Sample points on rays");
  146. m.def("sample_bg_pts_on_rays", &sample_bg_pts_on_rays, "Sample points on bg");
  147. m.def("maskcache_lookup", &maskcache_lookup, "Lookup to skip know freespace.");
  148. m.def("raw2alpha", &raw2alpha, "Raw values [-inf, inf] to alpha [0, 1].");
  149. m.def("raw2alpha_backward", &raw2alpha_backward, "Backward pass of the raw to alpha");
  150. m.def("raw2alpha_nonuni", &raw2alpha_nonuni, "Raw values [-inf, inf] to alpha [0, 1].");
  151. m.def("raw2alpha_nonuni_backward", &raw2alpha_nonuni_backward, "Backward pass of the raw to alpha");
  152. m.def("alpha2weight", &alpha2weight, "Per-point alpha to accumulated blending weight");
  153. m.def("alpha2weight_backward", &alpha2weight_backward, "Backward pass of alpha2weight");
  154. }