render_utils_kernel.cu 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707
  1. #include <torch/extension.h>
  2. #include <cuda.h>
  3. #include <cuda_runtime.h>
  4. #include <vector>
  5. /*
  6. Points sampling helper functions.
  7. */
  // Ray vs. axis-aligned box slab test, one thread per ray.
  //
  // Each ray reads three consecutive values from rays_o and rays_d (offset
  // 3 * ray). Per axis, the distances to the two box planes are computed.
  // The entry distance is the largest per-axis minimum and the exit distance
  // the smallest per-axis maximum. Both are then clamped as
  // max(min(t, far), near). Direction components that are exactly zero are
  // replaced by 1e-6 so the divisions stay finite. Intermediate math is float.
  template <typename scalar_t>
  __global__ void infer_t_minmax_cuda_kernel(
      scalar_t* __restrict__ rays_o,
      scalar_t* __restrict__ rays_d,
      scalar_t* __restrict__ xyz_min,
      scalar_t* __restrict__ xyz_max,
      const float near, const float far, const int n_rays,
      scalar_t* __restrict__ t_min,
      scalar_t* __restrict__ t_max) {
    const int ray = blockIdx.x * blockDim.x + threadIdx.x;
    if (ray >= n_rays) {
      return;  // tail threads of the last block
    }
    const scalar_t* origin = rays_o + 3 * ray;
    const scalar_t* direction = rays_d + 3 * ray;

    float near_hit[3];
    float far_hit[3];
    for (int axis = 0; axis < 3; ++axis) {
      const float step = (direction[axis] == 0) ? 1e-6 : direction[axis];
      const float to_upper = (xyz_max[axis] - origin[axis]) / step;
      const float to_lower = (xyz_min[axis] - origin[axis]) / step;
      near_hit[axis] = min(to_upper, to_lower);
      far_hit[axis] = max(to_upper, to_lower);
    }
    const float t_enter = max(max(near_hit[0], near_hit[1]), near_hit[2]);
    const float t_leave = min(min(far_hit[0], far_hit[1]), far_hit[2]);
    t_min[ray] = max(min(t_enter, far), near);
    t_max[ray] = max(min(t_leave, far), near);
  }
  // Number of samples each ray needs to cover [t_min, t_max] at spacing
  // `stepdist`. The t-range is scaled by |rays_d|, so rays_d need not be unit
  // length. The count is never below one: a ray that misses the box still
  // gets a single sample, which keeps every ray represented in
  // sample_pts_on_rays_cuda.
  template <typename scalar_t>
  __global__ void infer_n_samples_cuda_kernel(
      scalar_t* __restrict__ rays_d,
      scalar_t* __restrict__ t_min,
      scalar_t* __restrict__ t_max,
      const float stepdist,
      const int n_rays,
      int64_t* __restrict__ n_samples) {
    const int ray = blockIdx.x * blockDim.x + threadIdx.x;
    if (ray >= n_rays) {
      return;
    }
    const scalar_t dx = rays_d[3 * ray];
    const scalar_t dy = rays_d[3 * ray + 1];
    const scalar_t dz = rays_d[3 * ray + 2];
    const float dir_len = sqrt(dx * dx + dy * dy + dz * dz);
    // `auto` keeps the ceil argument in its natural precision (double when
    // scalar_t is double) until the final integer conversion.
    const auto steps = ceil((t_max[ray] - t_min[ray]) * dir_len / stepdist);
    n_samples[ray] = max(steps, 1.);
  }
  // Per-ray sampling origin and unit direction, one thread per ray:
  //   rays_start = rays_o + rays_d * t_min
  //   rays_dir   = rays_d / |rays_d|
  // The norm is computed in float. A zero-length rays_d is not guarded and
  // yields non-finite directions.
  template <typename scalar_t>
  __global__ void infer_ray_start_dir_cuda_kernel(
      scalar_t* __restrict__ rays_o,
      scalar_t* __restrict__ rays_d,
      scalar_t* __restrict__ t_min,
      const int n_rays,
      scalar_t* __restrict__ rays_start,
      scalar_t* __restrict__ rays_dir) {
    const int ray = blockIdx.x * blockDim.x + threadIdx.x;
    if (ray >= n_rays) {
      return;
    }
    const int base = 3 * ray;
    const float len = sqrt(rays_d[base] * rays_d[base] +
                           rays_d[base + 1] * rays_d[base + 1] +
                           rays_d[base + 2] * rays_d[base + 2]);
    const scalar_t t0 = t_min[ray];
    for (int c = 0; c < 3; ++c) {
      rays_start[base + c] = rays_o[base + c] + rays_d[base + c] * t0;
      rays_dir[base + c] = rays_d[base + c] / len;
    }
  }
  // Host wrapper: per-ray entry/exit distances for the box [xyz_min, xyz_max],
  // clamped by the kernel as max(min(t, far), near).
  // Returns {t_min, t_max}, each of length n_rays with rays_o's dtype and
  // device. The kernel indexes rays_o/rays_d as 3 consecutive values per ray
  // (assumes contiguous (n_rays, 3) inputs; TODO confirm at call sites).
  std::vector<torch::Tensor> infer_t_minmax_cuda(
      torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor xyz_min, torch::Tensor xyz_max,
      const float near, const float far) {
    const int n_rays = rays_o.size(0);
    auto t_min = torch::empty({n_rays}, rays_o.options());
    auto t_max = torch::empty({n_rays}, rays_o.options());
    // With no rays the grid would have zero blocks. CUDA rejects that as an
    // invalid launch configuration, and the error would then surface in an
    // unrelated later launch check.
    if (n_rays == 0) {
      return {t_min, t_max};
    }
    const int threads = 256;
    const int blocks = (n_rays + threads - 1) / threads;
    AT_DISPATCH_FLOATING_TYPES(rays_o.scalar_type(), "infer_t_minmax_cuda", ([&] {
      infer_t_minmax_cuda_kernel<scalar_t><<<blocks, threads>>>(
          rays_o.data_ptr<scalar_t>(),
          rays_d.data_ptr<scalar_t>(),
          xyz_min.data_ptr<scalar_t>(),
          xyz_max.data_ptr<scalar_t>(),
          near, far, n_rays,
          t_min.data_ptr<scalar_t>(),
          t_max.data_ptr<scalar_t>());
    }));
    return {t_min, t_max};
  }
  // Host wrapper: number of samples per ray (int64, at least 1 per ray).
  // Returns a tensor of length n_rays on the same device as t_min.
  torch::Tensor infer_n_samples_cuda(torch::Tensor rays_d, torch::Tensor t_min, torch::Tensor t_max, const float stepdist) {
    const int n_rays = t_min.size(0);
    // Allocate next to the inputs instead of on whatever the current default
    // CUDA device happens to be.
    auto n_samples = torch::empty({n_rays}, t_min.options().dtype(torch::kInt64));
    if (n_rays == 0) {
      return n_samples;  // a zero-block launch is an invalid configuration
    }
    const int threads = 256;
    const int blocks = (n_rays + threads - 1) / threads;
    AT_DISPATCH_FLOATING_TYPES(t_min.scalar_type(), "infer_n_samples_cuda", ([&] {
      infer_n_samples_cuda_kernel<scalar_t><<<blocks, threads>>>(
          rays_d.data_ptr<scalar_t>(),
          t_min.data_ptr<scalar_t>(),
          t_max.data_ptr<scalar_t>(),
          stepdist,
          n_rays,
          n_samples.data_ptr<int64_t>());
    }));
    return n_samples;
  }
  // Host wrapper: returns {rays_start, rays_dir}, both shaped like rays_o, where
  //   rays_start = rays_o + rays_d * t_min
  //   rays_dir   = rays_d / |rays_d|
  std::vector<torch::Tensor> infer_ray_start_dir_cuda(torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_min) {
    const int n_rays = rays_o.size(0);
    auto rays_start = torch::empty_like(rays_o);
    auto rays_dir = torch::empty_like(rays_o);
    // Skip the launch entirely for empty input: a zero-block grid is an
    // invalid launch configuration.
    if (n_rays == 0) {
      return {rays_start, rays_dir};
    }
    const int threads = 256;
    const int blocks = (n_rays + threads - 1) / threads;
    AT_DISPATCH_FLOATING_TYPES(rays_o.scalar_type(), "infer_ray_start_dir_cuda", ([&] {
      infer_ray_start_dir_cuda_kernel<scalar_t><<<blocks, threads>>>(
          rays_o.data_ptr<scalar_t>(),
          rays_d.data_ptr<scalar_t>(),
          t_min.data_ptr<scalar_t>(),
          n_rays,
          rays_start.data_ptr<scalar_t>(),
          rays_dir.data_ptr<scalar_t>());
    }));
    return {rays_start, rays_dir};
  }
  128. /*
  129. Sampling query points on rays.
  130. */
  // Marks the first flattened slot of every ray segment except ray 0 with a 1.
  // In sample_pts_on_rays_cuda, ray_id is zero-filled before this kernel and
  // prefix-summed after it, so each slot ends up holding its ray index.
  // N_steps_cumsum[r - 1] is the first slot of ray r.
  __global__ void __set_1_at_ray_seg_start(
      int64_t* __restrict__ ray_id,
      int64_t* __restrict__ N_steps_cumsum,
      const int n_rays) {
    const int r = blockIdx.x * blockDim.x + threadIdx.x;
    // Ray 0 starts at slot 0 and needs no marker; threads past n_rays idle.
    if (r <= 0 || r >= n_rays) {
      return;
    }
    ray_id[N_steps_cumsum[r - 1]] = 1;
  }
  // For each flattened sample, stores its index within its own ray: the
  // global index minus the first slot of that ray. Ray 0's segment starts at 0.
  __global__ void __set_step_id(
      int64_t* __restrict__ step_id,
      int64_t* __restrict__ ray_id,
      int64_t* __restrict__ N_steps_cumsum,
      const int total_len) {
    const int p = blockIdx.x * blockDim.x + threadIdx.x;
    if (p >= total_len) {
      return;
    }
    const int r = ray_id[p];
    const int64_t seg_begin = (r == 0) ? 0 : N_steps_cumsum[r - 1];
    step_id[p] = p - seg_begin;
  }
  // One thread per flattened sample. The point is
  //   rays_start[ray] + rays_dir[ray] * (stepdist * step)
  // computed in float and written to rays_pts[3 * p .. 3 * p + 2].
  // mask_outbbox[p] is true when the point lies outside [xyz_min, xyz_max]
  // on any axis.
  template <typename scalar_t>
  __global__ void sample_pts_on_rays_cuda_kernel(
      scalar_t* __restrict__ rays_start,
      scalar_t* __restrict__ rays_dir,
      scalar_t* __restrict__ xyz_min,
      scalar_t* __restrict__ xyz_max,
      int64_t* __restrict__ ray_id,
      int64_t* __restrict__ step_id,
      const float stepdist, const int total_len,
      scalar_t* __restrict__ rays_pts,
      bool* __restrict__ mask_outbbox) {
    const int p = blockIdx.x * blockDim.x + threadIdx.x;
    if (p >= total_len) {
      return;
    }
    const int r = ray_id[p];
    const int s = step_id[p];
    const float t = stepdist * s;
    const scalar_t* start = rays_start + 3 * r;
    const scalar_t* dir = rays_dir + 3 * r;

    bool outside = false;
    for (int c = 0; c < 3; ++c) {
      const float coord = start[c] + dir[c] * t;
      rays_pts[3 * p + c] = coord;
      outside = outside || (xyz_min[c] > coord) || (xyz_max[c] < coord);
    }
    mask_outbbox[p] = outside;
  }
  // Samples points at fixed spacing `stepdist` along every ray, starting where
  // the ray enters the box [xyz_min, xyz_max] (clamped by near/far).
  //
  // Returns {rays_pts (total_len, 3), mask_outbbox (total_len),
  //          ray_id (total_len), step_id (total_len),
  //          N_steps (n_rays), t_min (n_rays), t_max (n_rays)}.
  // For each flattened point, ray_id is its ray and step_id its index along
  // that ray. Every ray contributes at least one point. All outputs are
  // allocated on rays_o's device.
  std::vector<torch::Tensor> sample_pts_on_rays_cuda(
      torch::Tensor rays_o, torch::Tensor rays_d,
      torch::Tensor xyz_min, torch::Tensor xyz_max,
      const float near, const float far, const float stepdist) {
    const int threads = 256;
    const int n_rays = rays_o.size(0);
    const auto float_opts = rays_o.options();
    const auto long_opts = rays_o.options().dtype(torch::kInt64);
    const auto bool_opts = rays_o.options().dtype(torch::kBool);

    // No rays: return correctly shaped empty outputs. Otherwise every launch
    // below would use a zero-block grid, which is an invalid configuration.
    if (n_rays == 0) {
      return {torch::empty({0, 3}, float_opts), torch::empty({0}, bool_opts),
              torch::empty({0}, long_opts), torch::empty({0}, long_opts),
              torch::empty({0}, long_opts), torch::empty({0}, float_opts),
              torch::empty({0}, float_opts)};
    }

    // Compute ray-bbox intersection
    auto t_minmax = infer_t_minmax_cuda(rays_o, rays_d, xyz_min, xyz_max, near, far);
    auto t_min = t_minmax[0];
    auto t_max = t_minmax[1];

    // Compute the number of points required.
    auto N_steps = infer_n_samples_cuda(rays_d, t_min, t_max, stepdist);
    auto N_steps_cumsum = N_steps.cumsum(0);
    const int64_t total_len_64 = N_steps.sum().item<int64_t>();
    // The kernels index points with 32-bit ints and address rays_pts at
    // idx * 3, so that product must also fit in an int.
    const int64_t max_flat = 3 * total_len_64;
    TORCH_CHECK(static_cast<int64_t>(static_cast<int>(max_flat)) == max_flat,
                "sample_pts_on_rays_cuda: ", total_len_64,
                " sample points exceed the kernels' 32-bit indexing range");
    const int total_len = static_cast<int>(total_len_64);

    // Assign ray index and step index to each point. Every ray has at least
    // one sample, so total_len >= n_rays > 0 and both grids are non-empty.
    auto ray_id = torch::zeros({total_len}, long_opts);
    __set_1_at_ray_seg_start<<<(n_rays + threads - 1) / threads, threads>>>(
        ray_id.data_ptr<int64_t>(), N_steps_cumsum.data_ptr<int64_t>(), n_rays);
    ray_id.cumsum_(0);
    auto step_id = torch::empty({total_len}, long_opts);
    __set_step_id<<<(total_len + threads - 1) / threads, threads>>>(
        step_id.data_ptr<int64_t>(), ray_id.data_ptr<int64_t>(),
        N_steps_cumsum.data_ptr<int64_t>(), total_len);

    // Compute the global xyz of each point
    auto rays_start_dir = infer_ray_start_dir_cuda(rays_o, rays_d, t_min);
    auto rays_start = rays_start_dir[0];
    auto rays_dir = rays_start_dir[1];
    auto rays_pts = torch::empty({total_len, 3}, float_opts);
    auto mask_outbbox = torch::empty({total_len}, bool_opts);
    AT_DISPATCH_FLOATING_TYPES(rays_o.scalar_type(), "sample_pts_on_rays_cuda", ([&] {
      sample_pts_on_rays_cuda_kernel<scalar_t><<<(total_len + threads - 1) / threads, threads>>>(
          rays_start.data_ptr<scalar_t>(),
          rays_dir.data_ptr<scalar_t>(),
          xyz_min.data_ptr<scalar_t>(),
          xyz_max.data_ptr<scalar_t>(),
          ray_id.data_ptr<int64_t>(),
          step_id.data_ptr<int64_t>(),
          stepdist, total_len,
          rays_pts.data_ptr<scalar_t>(),
          mask_outbbox.data_ptr<bool>());
    }));
    return {rays_pts, mask_outbbox, ray_id, step_id, N_steps, t_min, t_max};
  }
  // NDC sampling, one thread per (ray, step) pair.
  //
  // Places N_samples evenly spaced points at t = i / (N_samples - 1), from
  // rays_o (t = 0) to rays_o + rays_d (t = 1). With N_samples == 1 the single
  // point is placed at t = 0, as torch.linspace(0, 1, 1) would do. The
  // original 0/0 produced NaN for every point.
  // mask_outbbox is true for points outside [xyz_min, xyz_max] on any axis.
  template <typename scalar_t>
  __global__ void sample_ndc_pts_on_rays_cuda_kernel(
      const scalar_t* __restrict__ rays_o,
      const scalar_t* __restrict__ rays_d,
      const scalar_t* __restrict__ xyz_min,
      const scalar_t* __restrict__ xyz_max,
      const int N_samples, const int n_rays,
      scalar_t* __restrict__ rays_pts,
      bool* __restrict__ mask_outbbox) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < N_samples * n_rays) {
      const int i_ray = idx / N_samples;
      const int i_step = idx % N_samples;
      const int offset_p = idx * 3;
      const int offset_r = i_ray * 3;
      // Guard the single-sample case, where (N_samples - 1) is zero.
      const float dist = (N_samples > 1) ? ((float)i_step) / (N_samples - 1) : 0.f;
      const float px = rays_o[offset_r    ] + rays_d[offset_r    ] * dist;
      const float py = rays_o[offset_r + 1] + rays_d[offset_r + 1] * dist;
      const float pz = rays_o[offset_r + 2] + rays_d[offset_r + 2] * dist;
      rays_pts[offset_p    ] = px;
      rays_pts[offset_p + 1] = py;
      rays_pts[offset_p + 2] = pz;
      mask_outbbox[idx] = (xyz_min[0] > px) | (xyz_min[1] > py) | (xyz_min[2] > pz) |
                          (xyz_max[0] < px) | (xyz_max[1] < py) | (xyz_max[2] < pz);
    }
  }
  // Host wrapper for NDC sampling.
  // Returns {rays_pts (n_rays, N_samples, 3), mask_outbbox (n_rays, N_samples)},
  // both allocated on rays_o's device.
  std::vector<torch::Tensor> sample_ndc_pts_on_rays_cuda(
      torch::Tensor rays_o, torch::Tensor rays_d,
      torch::Tensor xyz_min, torch::Tensor xyz_max,
      const int N_samples) {
    const int threads = 256;
    const int n_rays = rays_o.size(0);
    auto rays_pts = torch::empty({n_rays, N_samples, 3}, rays_o.options());
    auto mask_outbbox = torch::empty({n_rays, N_samples}, rays_o.options().dtype(torch::kBool));
    const int n_pts = n_rays * N_samples;
    // Nothing to compute, and a zero-block grid is an invalid launch configuration.
    if (n_pts == 0) {
      return {rays_pts, mask_outbbox};
    }
    AT_DISPATCH_FLOATING_TYPES(rays_o.scalar_type(), "sample_ndc_pts_on_rays_cuda", ([&] {
      sample_ndc_pts_on_rays_cuda_kernel<scalar_t><<<(n_pts + threads - 1) / threads, threads>>>(
          rays_o.data_ptr<scalar_t>(),
          rays_d.data_ptr<scalar_t>(),
          xyz_min.data_ptr<scalar_t>(),
          xyz_max.data_ptr<scalar_t>(),
          N_samples, n_rays,
          rays_pts.data_ptr<scalar_t>(),
          mask_outbbox.data_ptr<bool>());
    }));
    return {rays_pts, mask_outbbox};
  }
  // Euclidean length of the 3-vector (x, y, z), computed in scalar_t.
  template <typename scalar_t>
  __device__ __forceinline__ scalar_t norm3(const scalar_t x, const scalar_t y, const scalar_t z) {
    const scalar_t sum_sq = x * x + y * y + z * z;
    return sqrt(sum_sq);
  }
  // Background sampling beyond t_max, one thread per (ray, step) pair.
  //
  // Sample i of N_samples sits at t = t_max - 1 + 1 / (1 - i / N_samples)
  // along the ray. The raw point p is then scaled by
  //   s = (R / t)^2 * (1 - bg_preserve) + (R / t) * bg_preserve,
  // with t = |p| and R = t / max(|px|, |py|, |pz|).
  // There is no guard for a raw point at the origin (division by zero).
  // The mixed float/double arithmetic mirrors the reference exactly.
  /* Original pytorch implementation
  ori_t_outer = t_max[:,None] - 1 + 1 / torch.linspace(1, 0, N_outer+1)[:-1]
  ori_ray_pts_outer = (rays_o[:,None,:] + rays_d[:,None,:] * ori_t_outer[:,:,None]).reshape(-1,3)
  t_outer = ori_ray_pts_outer.norm(dim=-1)
  R_outer = t_outer / ori_ray_pts_outer.abs().amax(1)
  # r = R * R / t
  o2i_p = R_outer.pow(2) / t_outer.pow(2) * (1-self.bg_preserve) + R_outer / t_outer * self.bg_preserve
  ray_pts_outer = (ori_ray_pts_outer * o2i_p[:,None]).reshape(len(rays_o), -1, 3)
  */
  template <typename scalar_t>
  __global__ void sample_bg_pts_on_rays_cuda_kernel(
      const scalar_t* __restrict__ rays_o,
      const scalar_t* __restrict__ rays_d,
      const scalar_t* __restrict__ t_max,
      const float bg_preserve,
      const int N_samples, const int n_rays,
      scalar_t* __restrict__ rays_pts) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= N_samples * n_rays) {
      return;
    }
    const int ray = tid / N_samples;
    const int step = tid % N_samples;
    const scalar_t* o = rays_o + 3 * ray;
    const scalar_t* d = rays_d + 3 * ray;

    const float t_far = t_max[ray];
    const float t_sample = t_far - 1. + 1. / (1. - ((float)step) / N_samples);

    float raw[3];
    for (int c = 0; c < 3; ++c) {
      raw[c] = o[c] + d[c] * t_sample;
    }
    const float dist = norm3(raw[0], raw[1], raw[2]);
    const float linf = max(abs(raw[0]), max(abs(raw[1]), abs(raw[2])));
    const float R = dist / linf;
    const float scale = R * R / (dist * dist) * (1. - bg_preserve) + R / dist * bg_preserve;
    for (int c = 0; c < 3; ++c) {
      rays_pts[3 * tid + c] = raw[c] * scale;
    }
  }
  // Host wrapper for background sampling.
  // Returns rays_pts of shape (n_rays, N_samples, 3) on rays_o's device.
  torch::Tensor sample_bg_pts_on_rays_cuda(
      torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_max,
      const float bg_preserve, const int N_samples) {
    const int threads = 256;
    const int n_rays = rays_o.size(0);
    auto rays_pts = torch::empty({n_rays, N_samples, 3}, rays_o.options());
    const int n_pts = n_rays * N_samples;
    // Empty output: skip the launch (a zero-block grid is an invalid configuration).
    if (n_pts == 0) {
      return rays_pts;
    }
    AT_DISPATCH_FLOATING_TYPES(rays_o.scalar_type(), "sample_bg_pts_on_rays_cuda", ([&] {
      sample_bg_pts_on_rays_cuda_kernel<scalar_t><<<(n_pts + threads - 1) / threads, threads>>>(
          rays_o.data_ptr<scalar_t>(),
          rays_d.data_ptr<scalar_t>(),
          t_max.data_ptr<scalar_t>(),
          bg_preserve,
          N_samples, n_rays,
          rays_pts.data_ptr<scalar_t>());
    }));
    return rays_pts;
  }
  328. /*
  329. MaskCache lookup to skip known freespace.
  330. */
  // True when (i, j, k) is a valid index into an sz_i x sz_j x sz_k grid.
  static __forceinline__ __device__
  bool check_xyz(int i, int j, int k, int sz_i, int sz_j, int sz_k) {
    auto in_range = [](int v, int n) { return 0 <= v && v < n; };
    return in_range(i, sz_i) && in_range(j, sz_j) && in_range(k, sz_k);
  }
  // Nearest-voxel lookup in a boolean grid, one thread per point.
  //
  // Each point maps to grid indices via ijk = round(xyz * scale + shift) per
  // axis. Points landing inside the sz_i x sz_j x sz_k grid copy the grid
  // value (flattened with k fastest). Points outside leave out[pt] untouched.
  template <typename scalar_t>
  __global__ void maskcache_lookup_cuda_kernel(
      bool* __restrict__ world,
      scalar_t* __restrict__ xyz,
      bool* __restrict__ out,
      scalar_t* __restrict__ xyz2ijk_scale,
      scalar_t* __restrict__ xyz2ijk_shift,
      const int sz_i, const int sz_j, const int sz_k, const int n_pts) {
    const int pt = blockIdx.x * blockDim.x + threadIdx.x;
    if (pt >= n_pts) {
      return;
    }
    const scalar_t* p = xyz + 3 * pt;
    const int i = round(p[0] * xyz2ijk_scale[0] + xyz2ijk_shift[0]);
    const int j = round(p[1] * xyz2ijk_scale[1] + xyz2ijk_shift[1]);
    const int k = round(p[2] * xyz2ijk_scale[2] + xyz2ijk_shift[2]);
    if (!check_xyz(i, j, k, sz_i, sz_j, sz_k)) {
      return;  // outside the grid: keep the caller's initial value
    }
    out[pt] = world[i * sz_j * sz_k + j * sz_k + k];
  }
  // Looks up the boolean grid `world` (indexed as [i][j][k]) at every point of
  // xyz. Points that fall outside the grid read as false.
  // Returns a bool tensor of length n_pts on xyz's device.
  torch::Tensor maskcache_lookup_cuda(
      torch::Tensor world,
      torch::Tensor xyz,
      torch::Tensor xyz2ijk_scale,
      torch::Tensor xyz2ijk_shift) {
    const int sz_i = world.size(0);
    const int sz_j = world.size(1);
    const int sz_k = world.size(2);
    const int n_pts = xyz.size(0);
    // Zero-initialised (false) because the kernel skips out-of-grid points.
    // Allocated on xyz's device rather than the current default CUDA device.
    auto out = torch::zeros({n_pts}, xyz.options().dtype(torch::kBool));
    if (n_pts == 0) {
      return out;
    }
    const int threads = 256;
    const int blocks = (n_pts + threads - 1) / threads;
    AT_DISPATCH_FLOATING_TYPES(xyz.scalar_type(), "maskcache_lookup_cuda", ([&] {
      maskcache_lookup_cuda_kernel<scalar_t><<<blocks, threads>>>(
          world.data_ptr<bool>(),
          xyz.data_ptr<scalar_t>(),
          out.data_ptr<bool>(),
          xyz2ijk_scale.data_ptr<scalar_t>(),
          xyz2ijk_shift.data_ptr<scalar_t>(),
          sz_i, sz_j, sz_k, n_pts);
    }));
    return out;
  }
  380. /*
  381. Ray marching helper function.
  382. */
  // alpha = 1 - (1 + exp(density + shift))^(-interval), with a scalar interval.
  // exp_d keeps exp(density + shift), which raw2alpha_backward_cuda consumes.
  // It may overflow to inf.
  template <typename scalar_t>
  __global__ void raw2alpha_cuda_kernel(
      scalar_t* __restrict__ density,
      const float shift, const float interval, const int n_pts,
      scalar_t* __restrict__ exp_d,
      scalar_t* __restrict__ alpha) {
    const int pt = blockIdx.x * blockDim.x + threadIdx.x;
    if (pt >= n_pts) {
      return;
    }
    const scalar_t ed = exp(density[pt] + shift);
    exp_d[pt] = ed;
    alpha[pt] = 1 - pow(1 + ed, -interval);
  }
  // Same as raw2alpha_cuda_kernel, but with one interval per point:
  //   alpha[i] = 1 - (1 + exp(density[i] + shift))^(-interval[i]).
  // exp_d keeps exp(density + shift) and may overflow to inf.
  template <typename scalar_t>
  __global__ void raw2alpha_nonuni_cuda_kernel(
      scalar_t* __restrict__ density,
      const float shift, scalar_t* __restrict__ interval, const int n_pts,
      scalar_t* __restrict__ exp_d,
      scalar_t* __restrict__ alpha) {
    const int pt = blockIdx.x * blockDim.x + threadIdx.x;
    if (pt >= n_pts) {
      return;
    }
    const scalar_t ed = exp(density[pt] + shift);
    exp_d[pt] = ed;
    alpha[pt] = 1 - pow(1 + ed, -interval[pt]);
  }
  // Host wrapper: returns {exp_d, alpha}, both created with empty_like(density).
  std::vector<torch::Tensor> raw2alpha_cuda(torch::Tensor density, const float shift, const float interval) {
    const int n_pts = density.size(0);
    auto exp_d = torch::empty_like(density);
    auto alpha = torch::empty_like(density);
    if (n_pts == 0) {
      return {exp_d, alpha};  // nothing to launch
    }
    constexpr int kThreads = 256;
    const int n_blocks = (n_pts + kThreads - 1) / kThreads;
    AT_DISPATCH_FLOATING_TYPES(density.scalar_type(), "raw2alpha_cuda", ([&] {
      raw2alpha_cuda_kernel<scalar_t><<<n_blocks, kThreads>>>(
          density.data_ptr<scalar_t>(),
          shift, interval, n_pts,
          exp_d.data_ptr<scalar_t>(),
          alpha.data_ptr<scalar_t>());
    }));
    return {exp_d, alpha};
  }
  // Host wrapper with per-point intervals: returns {exp_d, alpha}, both created
  // with empty_like(density). `interval` must provide at least one value per
  // point, because the kernel reads interval[i] for every i < n_pts.
  std::vector<torch::Tensor> raw2alpha_nonuni_cuda(torch::Tensor density, const float shift, torch::Tensor interval) {
    const int n_pts = density.size(0);
    auto exp_d = torch::empty_like(density);
    auto alpha = torch::empty_like(density);
    if (n_pts == 0) {
      return {exp_d, alpha};
    }
    // Reject a short interval tensor instead of reading past its end on the GPU.
    TORCH_CHECK(interval.numel() >= n_pts,
                "raw2alpha_nonuni_cuda: interval has ", interval.numel(),
                " elements but density has ", n_pts, " points");
    const int threads = 256;
    const int blocks = (n_pts + threads - 1) / threads;
    // The dispatch name appears in unsupported-dtype errors; it previously
    // read "raw2alpha_cuda" (copy-paste).
    AT_DISPATCH_FLOATING_TYPES(density.scalar_type(), "raw2alpha_nonuni_cuda", ([&] {
      raw2alpha_nonuni_cuda_kernel<scalar_t><<<blocks, threads>>>(
          density.data_ptr<scalar_t>(),
          shift, interval.data_ptr<scalar_t>(), n_pts,
          exp_d.data_ptr<scalar_t>(),
          alpha.data_ptr<scalar_t>());
    }));
    return {exp_d, alpha};
  }
  // Gradient through raw2alpha with a scalar interval. With e = exp_d:
  //   grad = e * (1 + e)^(-interval - 1) * interval * grad_back
  // This is d alpha / d density chained with grad_back. In the leading factor,
  // e is clamped to 1e10, so an infinite exp_d gives 1e10 * 0 = 0 instead of
  // inf * 0 = NaN.
  template <typename scalar_t>
  __global__ void raw2alpha_backward_cuda_kernel(
      scalar_t* __restrict__ exp_d,
      scalar_t* __restrict__ grad_back,
      const float interval, const int n_pts,
      scalar_t* __restrict__ grad) {
    const int pt = blockIdx.x * blockDim.x + threadIdx.x;
    if (pt >= n_pts) {
      return;
    }
    const scalar_t e = exp_d[pt];
    grad[pt] = min(e, 1e10) * pow(1 + e, -interval - 1) * interval * grad_back[pt];
  }
  // Per-point-interval variant of raw2alpha_backward_cuda_kernel:
  //   grad[i] = min(e, 1e10) * (1 + e)^(-interval[i] - 1) * interval[i] * grad_back[i]
  // with e = exp_d[i]. The 1e10 clamp avoids inf * 0 = NaN.
  template <typename scalar_t>
  __global__ void raw2alpha_nonuni_backward_cuda_kernel(
      scalar_t* __restrict__ exp_d,
      scalar_t* __restrict__ grad_back,
      scalar_t* __restrict__ interval, const int n_pts,
      scalar_t* __restrict__ grad) {
    const int pt = blockIdx.x * blockDim.x + threadIdx.x;
    if (pt >= n_pts) {
      return;
    }
    const scalar_t e = exp_d[pt];
    const scalar_t itv = interval[pt];
    grad[pt] = min(e, 1e10) * pow(1 + e, -itv - 1) * itv * grad_back[pt];
  }
  // Host wrapper: grad_back chained through d alpha / d density for a scalar
  // interval. Returns a tensor created with empty_like(exp_d).
  torch::Tensor raw2alpha_backward_cuda(torch::Tensor exp_d, torch::Tensor grad_back, const float interval) {
    const int n_pts = exp_d.size(0);
    auto grad = torch::empty_like(exp_d);
    if (n_pts == 0) {
      return grad;  // nothing to launch
    }
    constexpr int kThreads = 256;
    const int n_blocks = (n_pts + kThreads - 1) / kThreads;
    AT_DISPATCH_FLOATING_TYPES(exp_d.scalar_type(), "raw2alpha_backward_cuda", ([&] {
      raw2alpha_backward_cuda_kernel<scalar_t><<<n_blocks, kThreads>>>(
          exp_d.data_ptr<scalar_t>(),
          grad_back.data_ptr<scalar_t>(),
          interval, n_pts,
          grad.data_ptr<scalar_t>());
    }));
    return grad;
  }
  // Host wrapper: grad_back chained through d alpha / d density with per-point
  // intervals. Returns a tensor created with empty_like(exp_d). The kernel
  // reads interval[i] and grad_back[i] for every i < n_pts, so both must be at
  // least that long.
  torch::Tensor raw2alpha_nonuni_backward_cuda(torch::Tensor exp_d, torch::Tensor grad_back, torch::Tensor interval) {
    const int n_pts = exp_d.size(0);
    auto grad = torch::empty_like(exp_d);
    if (n_pts == 0) {
      return grad;
    }
    // Fail loudly instead of reading out of bounds on the GPU.
    TORCH_CHECK(interval.numel() >= n_pts,
                "raw2alpha_nonuni_backward_cuda: interval has ", interval.numel(),
                " elements but exp_d has ", n_pts, " points");
    TORCH_CHECK(grad_back.numel() >= n_pts,
                "raw2alpha_nonuni_backward_cuda: grad_back has ", grad_back.numel(),
                " elements but exp_d has ", n_pts, " points");
    const int threads = 256;
    const int blocks = (n_pts + threads - 1) / threads;
    // The dispatch name appears in unsupported-dtype errors; it previously
    // read "raw2alpha_backward_cuda" (copy-paste).
    AT_DISPATCH_FLOATING_TYPES(exp_d.scalar_type(), "raw2alpha_nonuni_backward_cuda", ([&] {
      raw2alpha_nonuni_backward_cuda_kernel<scalar_t><<<blocks, threads>>>(
          exp_d.data_ptr<scalar_t>(),
          grad_back.data_ptr<scalar_t>(),
          interval.data_ptr<scalar_t>(), n_pts,
          grad.data_ptr<scalar_t>());
    }));
    return grad;
  }
  // Front-to-back compositing over each ray's range [i_start, i_end), one
  // thread per ray:
  //   T[i]      = product of (1 - alpha[j]) over the earlier points j of the ray
  //   weight[i] = T[i] * alpha[i]
  // Processing stops early once the running transmittance drops below 1e-3.
  // i_end[ray] is then overwritten with one past the last processed point, and
  // alphainv_last[ray] receives the final transmittance. Entries after an
  // early stop are not written by this kernel.
  template <typename scalar_t>
  __global__ void alpha2weight_cuda_kernel(
      scalar_t* __restrict__ alpha,
      const int n_rays,
      scalar_t* __restrict__ weight,
      scalar_t* __restrict__ T,
      scalar_t* __restrict__ alphainv_last,
      int64_t* __restrict__ i_start,
      int64_t* __restrict__ i_end) {
    const int ray = blockIdx.x * blockDim.x + threadIdx.x;
    if (ray >= n_rays) {
      return;
    }
    const int first = i_start[ray];
    const int limit = i_end[ray];
    float trans = 1.;
    int i = first;
    while (i < limit) {
      T[i] = trans;
      weight[i] = trans * alpha[i];
      trans *= (1. - alpha[i]);
      ++i;
      if (trans < 1e-3) {
        break;  // remaining points would contribute negligibly
      }
    }
    i_end[ray] = i;
    alphainv_last[ray] = trans;
  }
  // At every index p where ray_id changes from its left neighbour, records p
  // as the start of the new ray's run and as the end of the previous ray's run.
  // Presumably ray_id is grouped so each ray's points are contiguous (only
  // neighbour changes are detected; verify against callers). Index 0 and the
  // final run's end are not handled here.
  __global__ void __set_i_for_segment_start_end(
      int64_t* __restrict__ ray_id,
      const int n_pts,
      int64_t* __restrict__ i_start,
      int64_t* __restrict__ i_end) {
    const int p = blockIdx.x * blockDim.x + threadIdx.x;
    if (p <= 0 || p >= n_pts) {
      return;
    }
    const int64_t cur = ray_id[p];
    const int64_t prev = ray_id[p - 1];
    if (cur == prev) {
      return;
    }
    i_start[cur] = p;
    i_end[prev] = p;
  }
  // Host wrapper: per-point compositing weights along rays.
  // Returns {weight, T, alphainv_last, i_start, i_end}:
  //   - weight and T are shaped like alpha; they start at 0 and 1, so points
  //     past an early stop keep those values.
  //   - alphainv_last holds the final transmittance per ray.
  //   - i_start/i_end give each ray's processed range of points.
  // All outputs are allocated on alpha's device.
  std::vector<torch::Tensor> alpha2weight_cuda(torch::Tensor alpha, torch::Tensor ray_id, const int n_rays) {
    const int n_pts = alpha.size(0);
    const int threads = 256;
    auto weight = torch::zeros_like(alpha);
    auto T = torch::ones_like(alpha);
    auto alphainv_last = torch::ones({n_rays}, alpha.options());
    // Index tensors live on alpha's device like the other outputs, instead of
    // the current default CUDA device.
    const auto index_opts = alpha.options().dtype(torch::kInt64);
    auto i_start = torch::zeros({n_rays}, index_opts);
    auto i_end = torch::zeros({n_rays}, index_opts);
    if (n_pts == 0) {
      return {weight, T, alphainv_last, i_start, i_end};
    }
    __set_i_for_segment_start_end<<<(n_pts + threads - 1) / threads, threads>>>(
        ray_id.data_ptr<int64_t>(), n_pts, i_start.data_ptr<int64_t>(), i_end.data_ptr<int64_t>());
    // The boundary kernel never closes the last ray's run; close it here.
    i_end[ray_id[n_pts - 1]] = n_pts;
    const int blocks = (n_rays + threads - 1) / threads;
    AT_DISPATCH_FLOATING_TYPES(alpha.scalar_type(), "alpha2weight_cuda", ([&] {
      alpha2weight_cuda_kernel<scalar_t><<<blocks, threads>>>(
          alpha.data_ptr<scalar_t>(),
          n_rays,
          weight.data_ptr<scalar_t>(),
          T.data_ptr<scalar_t>(),
          alphainv_last.data_ptr<scalar_t>(),
          i_start.data_ptr<int64_t>(),
          i_end.data_ptr<int64_t>());
    }));
    return {weight, T, alphainv_last, i_start, i_end};
  }
  // Backward of alpha2weight, one thread per ray.
  //
  // Walks the ray's range [i_start, i_end) from back to front. A running term
  // starts at grad_last * alphainv_last and accumulates
  // grad_weights[i] * weight[i]. Each point's gradient is
  //   grad_weights[i] * T[i] - running / (1 - alpha[i] + 1e-10),
  // where `running` covers only the points after i. The 1e-10 keeps the
  // division finite when alpha[i] == 1.
  template <typename scalar_t>
  __global__ void alpha2weight_backward_cuda_kernel(
      scalar_t* __restrict__ alpha,
      scalar_t* __restrict__ weight,
      scalar_t* __restrict__ T,
      scalar_t* __restrict__ alphainv_last,
      int64_t* __restrict__ i_start,
      int64_t* __restrict__ i_end,
      const int n_rays,
      scalar_t* __restrict__ grad_weights,
      scalar_t* __restrict__ grad_last,
      scalar_t* __restrict__ grad) {
    const int ray = blockIdx.x * blockDim.x + threadIdx.x;
    if (ray >= n_rays) {
      return;
    }
    const int lo = i_start[ray];
    int i = i_end[ray];
    float acc = grad_last[ray] * alphainv_last[ray];
    // Post-decrement: visits i_end-1, i_end-2, ..., i_start.
    while (i-- > lo) {
      grad[i] = grad_weights[i] * T[i] - acc / (1 - alpha[i] + 1e-10);
      acc += grad_weights[i] * weight[i];
    }
  }
  // Host wrapper for the alpha2weight backward pass. Returns grad shaped like
  // alpha. Points outside every ray's [i_start, i_end) range keep a zero gradient.
  torch::Tensor alpha2weight_backward_cuda(
      torch::Tensor alpha, torch::Tensor weight, torch::Tensor T, torch::Tensor alphainv_last,
      torch::Tensor i_start, torch::Tensor i_end, const int n_rays,
      torch::Tensor grad_weights, torch::Tensor grad_last) {
    auto grad = torch::zeros_like(alpha);
    if (n_rays == 0) {
      return grad;  // nothing to launch
    }
    constexpr int kThreads = 256;
    const int n_blocks = (n_rays + kThreads - 1) / kThreads;
    AT_DISPATCH_FLOATING_TYPES(alpha.scalar_type(), "alpha2weight_backward_cuda", ([&] {
      alpha2weight_backward_cuda_kernel<scalar_t><<<n_blocks, kThreads>>>(
          alpha.data_ptr<scalar_t>(),
          weight.data_ptr<scalar_t>(),
          T.data_ptr<scalar_t>(),
          alphainv_last.data_ptr<scalar_t>(),
          i_start.data_ptr<int64_t>(),
          i_end.data_ptr<int64_t>(),
          n_rays,
          grad_weights.data_ptr<scalar_t>(),
          grad_last.data_ptr<scalar_t>(),
          grad.data_ptr<scalar_t>());
    }));
    return grad;
  }