| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707 |
- #include <torch/extension.h>
- #include <cuda.h>
- #include <cuda_runtime.h>
- #include <vector>
- /*
- Points sampling helper functions.
- */
/*
   Ray vs. axis-aligned box slab test, one thread per ray.
   rays_o / rays_d are read as flat buffers holding 3 values per ray;
   xyz_min / xyz_max hold 3 values each. For every ray the entry and exit
   parameters are passed through max(min(., far), near) and stored in
   t_min[i_ray] / t_max[i_ray].
   Launch layout: 1D grid of 1D blocks with at least n_rays threads in total.
*/
template <typename scalar_t>
__global__ void infer_t_minmax_cuda_kernel(
        scalar_t* __restrict__ rays_o,
        scalar_t* __restrict__ rays_d,
        scalar_t* __restrict__ xyz_min,
        scalar_t* __restrict__ xyz_max,
        const float near, const float far, const int n_rays,
        scalar_t* __restrict__ t_min,
        scalar_t* __restrict__ t_max) {
  const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_ray >= n_rays) {
    return;
  }
  const int base = i_ray * 3;

  // A direction component that is exactly zero is replaced by 1e-6 so the
  // divisions below never divide by zero.
  const float dir_x = ((rays_d[base  ]==0) ? 1e-6 : rays_d[base  ]);
  const float dir_y = ((rays_d[base+1]==0) ? 1e-6 : rays_d[base+1]);
  const float dir_z = ((rays_d[base+2]==0) ? 1e-6 : rays_d[base+2]);

  // Ray parameters at which the ray crosses the max / min plane of each axis.
  const float hi_x = (xyz_max[0] - rays_o[base  ]) / dir_x;
  const float hi_y = (xyz_max[1] - rays_o[base+1]) / dir_y;
  const float hi_z = (xyz_max[2] - rays_o[base+2]) / dir_z;
  const float lo_x = (xyz_min[0] - rays_o[base  ]) / dir_x;
  const float lo_y = (xyz_min[1] - rays_o[base+1]) / dir_y;
  const float lo_z = (xyz_min[2] - rays_o[base+2]) / dir_z;

  // Entry = largest per-axis near crossing, exit = smallest per-axis far crossing.
  const float t_enter = max(max(min(hi_x, lo_x), min(hi_y, lo_y)), min(hi_z, lo_z));
  const float t_exit  = min(min(max(hi_x, lo_x), max(hi_y, lo_y)), max(hi_z, lo_z));

  t_min[i_ray] = max(min(t_enter, far), near);
  t_max[i_ray] = max(min(t_exit, far), near);
}
/*
   Number of samples per ray, one thread per ray.
   n_samples[i_ray] = max(ceil((t_max - t_min) * |rays_d| / stepdist), 1),
   where |rays_d| is the Euclidean norm of the 3 direction values of the ray.
   Launch layout: 1D grid of 1D blocks with at least n_rays threads in total.
*/
template <typename scalar_t>
__global__ void infer_n_samples_cuda_kernel(
        scalar_t* __restrict__ rays_d,
        scalar_t* __restrict__ t_min,
        scalar_t* __restrict__ t_max,
        const float stepdist,
        const int n_rays,
        int64_t* __restrict__ n_samples) {
  const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_ray >= n_rays) {
    return;
  }
  const scalar_t* dir = rays_d + i_ray * 3;
  const float rnorm = sqrt(dir[0]*dir[0] + dir[1]*dir[1] + dir[2]*dir[2]);
  // Every ray gets at least one sample, which keeps the later
  // sample_pts_on_rays_cuda bookkeeping simple.
  n_samples[i_ray] = max(ceil((t_max[i_ray]-t_min[i_ray]) * rnorm / stepdist), 1.);
}
/*
   Per-ray start point and unit direction, one thread per ray.
   rays_start = rays_o + rays_d * t_min[i_ray]
   rays_dir   = rays_d / |rays_d|
   All ray buffers are read/written as flat arrays with 3 values per ray.
   Launch layout: 1D grid of 1D blocks with at least n_rays threads in total.
*/
template <typename scalar_t>
__global__ void infer_ray_start_dir_cuda_kernel(
        scalar_t* __restrict__ rays_o,
        scalar_t* __restrict__ rays_d,
        scalar_t* __restrict__ t_min,
        const int n_rays,
        scalar_t* __restrict__ rays_start,
        scalar_t* __restrict__ rays_dir) {
  const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_ray >= n_rays) {
    return;
  }
  const int base = i_ray * 3;
  const float rnorm = sqrt(
      rays_d[base  ]*rays_d[base  ] +
      rays_d[base+1]*rays_d[base+1] +
      rays_d[base+2]*rays_d[base+2]);

  // Advance the origin to the ray's entry parameter.
  for(int c=0; c<3; ++c) {
    rays_start[base+c] = rays_o[base+c] + rays_d[base+c] * t_min[i_ray];
  }
  // Normalise the direction.
  for(int c=0; c<3; ++c) {
    rays_dir[base+c] = rays_d[base+c] / rnorm;
  }
}
/*
   Host launcher for infer_t_minmax_cuda_kernel.
   Returns {t_min, t_max}, each of shape {rays_o.size(0)} with the options
   (dtype/device) of rays_o. The kernel reads the raw data pointers as flat,
   densely packed buffers, so callers are expected to pass contiguous tensors
   of one floating dtype that live on the same CUDA device.
*/
std::vector<torch::Tensor> infer_t_minmax_cuda(
        torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor xyz_min, torch::Tensor xyz_max,
        const float near, const float far) {
  const int n_rays = rays_o.size(0);
  auto t_min = torch::empty({n_rays}, rays_o.options());
  auto t_max = torch::empty({n_rays}, rays_o.options());
  if(n_rays==0) {
    // A launch with zero blocks is an invalid configuration; nothing to compute.
    return {t_min, t_max};
  }

  const int threads = 256;
  const int blocks = (n_rays + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(rays_o.scalar_type(), "infer_t_minmax_cuda", ([&] {
    infer_t_minmax_cuda_kernel<scalar_t><<<blocks, threads>>>(
        rays_o.data_ptr<scalar_t>(),
        rays_d.data_ptr<scalar_t>(),
        xyz_min.data_ptr<scalar_t>(),
        xyz_max.data_ptr<scalar_t>(),
        near, far, n_rays,
        t_min.data_ptr<scalar_t>(),
        t_max.data_ptr<scalar_t>());
  }));

  // Kernel launches do not report errors by themselves.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
      "infer_t_minmax_cuda: kernel launch failed: ", cudaGetErrorString(err));

  return {t_min, t_max};
}
/*
   Host launcher for infer_n_samples_cuda_kernel.
   Returns an int64 tensor of shape {t_min.size(0)} holding the per-ray sample
   count. The result is allocated on the device of t_min (not the process-wide
   default CUDA device). Inputs are read through raw data pointers, so they are
   expected to be contiguous and to share one floating dtype.
*/
torch::Tensor infer_n_samples_cuda(torch::Tensor rays_d, torch::Tensor t_min, torch::Tensor t_max, const float stepdist) {
  const int n_rays = t_min.size(0);
  auto n_samples = torch::empty({n_rays}, t_min.options().dtype(torch::kInt64));
  if(n_rays==0) {
    // A launch with zero blocks is an invalid configuration; nothing to compute.
    return n_samples;
  }

  const int threads = 256;
  const int blocks = (n_rays + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(t_min.scalar_type(), "infer_n_samples_cuda", ([&] {
    infer_n_samples_cuda_kernel<scalar_t><<<blocks, threads>>>(
        rays_d.data_ptr<scalar_t>(),
        t_min.data_ptr<scalar_t>(),
        t_max.data_ptr<scalar_t>(),
        stepdist,
        n_rays,
        n_samples.data_ptr<int64_t>());
  }));

  // Kernel launches do not report errors by themselves.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
      "infer_n_samples_cuda: kernel launch failed: ", cudaGetErrorString(err));

  return n_samples;
}
/*
   Host launcher for infer_ray_start_dir_cuda_kernel.
   Returns {rays_start, rays_dir}, both allocated with empty_like(rays_o).
   Inputs are read through raw data pointers, so they are expected to be
   contiguous, to share one floating dtype and to live on the same CUDA device.
*/
std::vector<torch::Tensor> infer_ray_start_dir_cuda(torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_min) {
  const int n_rays = rays_o.size(0);
  auto rays_start = torch::empty_like(rays_o);
  auto rays_dir = torch::empty_like(rays_o);
  if(n_rays==0) {
    // A launch with zero blocks is an invalid configuration; nothing to compute.
    return {rays_start, rays_dir};
  }

  const int threads = 256;
  const int blocks = (n_rays + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(rays_o.scalar_type(), "infer_ray_start_dir_cuda", ([&] {
    infer_ray_start_dir_cuda_kernel<scalar_t><<<blocks, threads>>>(
        rays_o.data_ptr<scalar_t>(),
        rays_d.data_ptr<scalar_t>(),
        t_min.data_ptr<scalar_t>(),
        n_rays,
        rays_start.data_ptr<scalar_t>(),
        rays_dir.data_ptr<scalar_t>());
  }));

  // Kernel launches do not report errors by themselves.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
      "infer_ray_start_dir_cuda: kernel launch failed: ", cudaGetErrorString(err));

  return {rays_start, rays_dir};
}
- /*
- Sampling query points on rays.
- */
/*
   Writes a 1 into ray_id at N_steps_cumsum[idx-1] for every idx in
   [1, n_rays). Together with a later cumulative sum over ray_id (done by the
   caller) this turns the markers into per-point ray indices.
   Launch layout: 1D grid of 1D blocks with at least n_rays threads in total.
*/
__global__ void __set_1_at_ray_seg_start(
        int64_t* __restrict__ ray_id,
        int64_t* __restrict__ N_steps_cumsum,
        const int n_rays) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  // Thread 0 has no predecessor entry; threads at or past n_rays have no work.
  if(tid <= 0 || tid >= n_rays) {
    return;
  }
  const int64_t seg_begin = N_steps_cumsum[tid-1];
  ray_id[seg_begin] = 1;
}
/*
   For every point idx in [0, total_len): step_id[idx] is idx minus the
   cumulative sample count of all rays before the point's ray (0 for ray 0),
   i.e. the point's position within its own ray segment.
   Launch layout: 1D grid of 1D blocks with at least total_len threads.
*/
__global__ void __set_step_id(
        int64_t* __restrict__ step_id,
        int64_t* __restrict__ ray_id,
        int64_t* __restrict__ N_steps_cumsum,
        const int total_len) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if(idx >= total_len) {
    return;
  }
  const int rid = ray_id[idx];
  // First point index belonging to ray `rid`.
  int64_t seg_begin = 0;
  if(rid != 0) {
    seg_begin = N_steps_cumsum[rid-1];
  }
  step_id[idx] = idx - seg_begin;
}
/*
   Places one query point per thread.
   Point idx belongs to ray ray_id[idx] at step step_id[idx]; its position is
   rays_start + rays_dir * (stepdist * step). The point is written to
   rays_pts (3 values per point) and mask_outbbox[idx] is set when any
   coordinate lies strictly outside [xyz_min, xyz_max].
   Launch layout: 1D grid of 1D blocks with at least total_len threads.
*/
template <typename scalar_t>
__global__ void sample_pts_on_rays_cuda_kernel(
        scalar_t* __restrict__ rays_start,
        scalar_t* __restrict__ rays_dir,
        scalar_t* __restrict__ xyz_min,
        scalar_t* __restrict__ xyz_max,
        int64_t* __restrict__ ray_id,
        int64_t* __restrict__ step_id,
        const float stepdist, const int total_len,
        scalar_t* __restrict__ rays_pts,
        bool* __restrict__ mask_outbbox) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if(idx >= total_len) {
    return;
  }
  const int i_ray = ray_id[idx];
  const int i_step = step_id[idx];
  const int src = i_ray * 3;  // offset into the per-ray buffers
  const int dst = idx * 3;    // offset into the per-point output

  const float travelled = stepdist * i_step;
  const float px = rays_start[src  ] + rays_dir[src  ] * travelled;
  const float py = rays_start[src+1] + rays_dir[src+1] * travelled;
  const float pz = rays_start[src+2] + rays_dir[src+2] * travelled;

  rays_pts[dst  ] = px;
  rays_pts[dst+1] = py;
  rays_pts[dst+2] = pz;

  const bool below_min = (xyz_min[0]>px) | (xyz_min[1]>py) | (xyz_min[2]>pz);
  const bool above_max = (xyz_max[0]<px) | (xyz_max[1]<py) | (xyz_max[2]<pz);
  mask_outbbox[idx] = below_min | above_max;
}
/*
   Samples query points along every ray at a fixed step distance.
   Pipeline: ray/bbox intersection -> per-ray sample count -> per-point
   ray index and step index -> per-point xyz and out-of-bbox mask.
   Returns {rays_pts, mask_outbbox, ray_id, step_id, N_steps, t_min, t_max}.
   With zero rays, empty tensors of the matching shapes/dtypes are returned
   without launching any kernel. Inputs are read through raw data pointers,
   so they are expected to be contiguous, to share one floating dtype and to
   live on the same CUDA device.
*/
std::vector<torch::Tensor> sample_pts_on_rays_cuda(
        torch::Tensor rays_o, torch::Tensor rays_d,
        torch::Tensor xyz_min, torch::Tensor xyz_max,
        const float near, const float far, const float stepdist) {
  const int threads = 256;
  const int n_rays = rays_o.size(0);

  if(n_rays==0) {
    // A launch with zero blocks is an invalid configuration; hand back empties.
    const auto index_opts = rays_o.options().dtype(torch::kInt64);
    return {torch::empty({0, 3}, rays_o.options()),
            torch::empty({0}, rays_o.options().dtype(torch::kBool)),
            torch::empty({0}, index_opts),
            torch::empty({0}, index_opts),
            torch::empty({0}, index_opts),
            torch::empty({0}, rays_o.options()),
            torch::empty({0}, rays_o.options())};
  }

  // Kernel launches do not report errors by themselves.
  const auto check_launch = [](const char* kernel_name) {
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
        "sample_pts_on_rays_cuda: ", kernel_name, " launch failed: ", cudaGetErrorString(err));
  };

  // Compute ray-bbox intersection
  auto t_minmax = infer_t_minmax_cuda(rays_o, rays_d, xyz_min, xyz_max, near, far);
  auto t_min = t_minmax[0];
  auto t_max = t_minmax[1];

  // Compute the number of points required.
  // Assign ray index and step index to each.
  auto N_steps = infer_n_samples_cuda(rays_d, t_min, t_max, stepdist);
  auto N_steps_cumsum = N_steps.cumsum(0);
  const int total_len = N_steps.sum().item<int>();
  // Index buffers follow N_steps so they share its device with N_steps_cumsum.
  auto ray_id = torch::zeros({total_len}, N_steps.options());
  __set_1_at_ray_seg_start<<<(n_rays+threads-1)/threads, threads>>>(
      ray_id.data_ptr<int64_t>(), N_steps_cumsum.data_ptr<int64_t>(), n_rays);
  check_launch("__set_1_at_ray_seg_start");
  // Segment-start markers -> per-point ray index.
  ray_id.cumsum_(0);
  auto step_id = torch::empty({total_len}, ray_id.options());
  __set_step_id<<<(total_len+threads-1)/threads, threads>>>(
      step_id.data_ptr<int64_t>(), ray_id.data_ptr<int64_t>(), N_steps_cumsum.data_ptr<int64_t>(), total_len);
  check_launch("__set_step_id");

  // Compute the global xyz of each point
  auto rays_start_dir = infer_ray_start_dir_cuda(rays_o, rays_d, t_min);
  auto rays_start = rays_start_dir[0];
  auto rays_dir = rays_start_dir[1];

  // Outputs live on the device of rays_o rather than the default CUDA device.
  auto rays_pts = torch::empty({total_len, 3}, rays_o.options());
  auto mask_outbbox = torch::empty({total_len}, rays_o.options().dtype(torch::kBool));

  AT_DISPATCH_FLOATING_TYPES(rays_o.scalar_type(), "sample_pts_on_rays_cuda", ([&] {
    sample_pts_on_rays_cuda_kernel<scalar_t><<<(total_len+threads-1)/threads, threads>>>(
        rays_start.data_ptr<scalar_t>(),
        rays_dir.data_ptr<scalar_t>(),
        xyz_min.data_ptr<scalar_t>(),
        xyz_max.data_ptr<scalar_t>(),
        ray_id.data_ptr<int64_t>(),
        step_id.data_ptr<int64_t>(),
        stepdist, total_len,
        rays_pts.data_ptr<scalar_t>(),
        mask_outbbox.data_ptr<bool>());
  }));
  check_launch("sample_pts_on_rays_cuda_kernel");

  return {rays_pts, mask_outbbox, ray_id, step_id, N_steps, t_min, t_max};
}
/*
   Places N_samples points per ray at evenly spaced parameters in [0, 1]:
   point = rays_o + rays_d * (i_step / (N_samples-1)). With a single sample
   per ray the parameter is 0 (the original 0/0 division produced NaN).
   One thread per point; rays_pts receives 3 values per point and
   mask_outbbox[idx] is set when any coordinate lies strictly outside
   [xyz_min, xyz_max]. Index arithmetic is done in 64 bits so that
   N_samples*n_rays and idx*3 cannot wrap.
   Launch layout: 1D grid of 1D blocks with at least N_samples*n_rays threads.
*/
template <typename scalar_t>
__global__ void sample_ndc_pts_on_rays_cuda_kernel(
        const scalar_t* __restrict__ rays_o,
        const scalar_t* __restrict__ rays_d,
        const scalar_t* __restrict__ xyz_min,
        const scalar_t* __restrict__ xyz_max,
        const int N_samples, const int n_rays,
        scalar_t* __restrict__ rays_pts,
        bool* __restrict__ mask_outbbox) {
  const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t n_total = static_cast<int64_t>(N_samples) * n_rays;
  if(idx >= n_total) {
    return;
  }
  const int64_t i_ray = idx / N_samples;
  const int i_step = static_cast<int>(idx % N_samples);
  const int64_t offset_p = idx * 3;
  const int64_t offset_r = i_ray * 3;

  // Guard the single-sample case: i_step/(N_samples-1) would be 0/0.
  const float dist = (N_samples > 1) ? ((float)i_step) / (N_samples-1) : 0.f;

  const float px = rays_o[offset_r  ] + rays_d[offset_r  ] * dist;
  const float py = rays_o[offset_r+1] + rays_d[offset_r+1] * dist;
  const float pz = rays_o[offset_r+2] + rays_d[offset_r+2] * dist;
  rays_pts[offset_p  ] = px;
  rays_pts[offset_p+1] = py;
  rays_pts[offset_p+2] = pz;
  mask_outbbox[idx] = (xyz_min[0]>px) | (xyz_min[1]>py) | (xyz_min[2]>pz) |
                      (xyz_max[0]<px) | (xyz_max[1]<py) | (xyz_max[2]<pz);
}
/*
   Host launcher for sample_ndc_pts_on_rays_cuda_kernel.
   Returns {rays_pts of shape {n_rays, N_samples, 3}, mask_outbbox of shape
   {n_rays, N_samples}}. Outputs are allocated on the device of rays_o.
   Nothing is launched when there are no points to compute. Inputs are read
   through raw data pointers, so they are expected to be contiguous and to
   share one floating dtype.
*/
std::vector<torch::Tensor> sample_ndc_pts_on_rays_cuda(
        torch::Tensor rays_o, torch::Tensor rays_d,
        torch::Tensor xyz_min, torch::Tensor xyz_max,
        const int N_samples) {
  const int threads = 256;
  const int n_rays = rays_o.size(0);

  auto rays_pts = torch::empty({n_rays, N_samples, 3}, rays_o.options());
  auto mask_outbbox = torch::empty({n_rays, N_samples}, rays_o.options().dtype(torch::kBool));

  // 64-bit product: n_rays*N_samples may not fit in an int.
  const int64_t n_pts = static_cast<int64_t>(n_rays) * N_samples;
  if(n_pts==0) {
    // A launch with zero blocks is an invalid configuration; nothing to compute.
    return {rays_pts, mask_outbbox};
  }
  const int blocks = static_cast<int>((n_pts + threads - 1) / threads);

  AT_DISPATCH_FLOATING_TYPES(rays_o.scalar_type(), "sample_ndc_pts_on_rays_cuda", ([&] {
    sample_ndc_pts_on_rays_cuda_kernel<scalar_t><<<blocks, threads>>>(
        rays_o.data_ptr<scalar_t>(),
        rays_d.data_ptr<scalar_t>(),
        xyz_min.data_ptr<scalar_t>(),
        xyz_max.data_ptr<scalar_t>(),
        N_samples, n_rays,
        rays_pts.data_ptr<scalar_t>(),
        mask_outbbox.data_ptr<bool>());
  }));

  // Kernel launches do not report errors by themselves.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
      "sample_ndc_pts_on_rays_cuda: kernel launch failed: ", cudaGetErrorString(err));

  return {rays_pts, mask_outbbox};
}
// Euclidean length of the 3-vector (x, y, z).
template <typename scalar_t>
__device__ __forceinline__ scalar_t norm3(const scalar_t x, const scalar_t y, const scalar_t z) {
  const scalar_t squared_len = x*x + y*y + z*z;
  return sqrt(squared_len);
}
/*
   Background sampling, one thread per (ray, step) pair.
   Step i of N_samples is placed at ray parameter
       t_max[i_ray] - 1 + 1 / (1 - i/N_samples)
   and the resulting point is then rescaled by o2i_p (see the reference
   PyTorch code below) before being written to rays_pts (3 values per point).
   Launch layout: 1D grid of 1D blocks with at least N_samples*n_rays threads.
*/
template <typename scalar_t>
__global__ void sample_bg_pts_on_rays_cuda_kernel(
        const scalar_t* __restrict__ rays_o,
        const scalar_t* __restrict__ rays_d,
        const scalar_t* __restrict__ t_max,
        const float bg_preserve,
        const int N_samples, const int n_rays,
        scalar_t* __restrict__ rays_pts) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if(idx >= N_samples*n_rays) {
    return;
  }
  const int i_ray = idx / N_samples;
  const int i_step = idx % N_samples;
  const int dst = idx * 3;
  const int src = i_ray * 3;

  /* Original pytorch implementation
  ori_t_outer = t_max[:,None] - 1 + 1 / torch.linspace(1, 0, N_outer+1)[:-1]
  ori_ray_pts_outer = (rays_o[:,None,:] + rays_d[:,None,:] * ori_t_outer[:,:,None]).reshape(-1,3)
  t_outer = ori_ray_pts_outer.norm(dim=-1)
  R_outer = t_outer / ori_ray_pts_outer.abs().amax(1)
  # r = R * R / t
  o2i_p = R_outer.pow(2) / t_outer.pow(2) * (1-self.bg_preserve) + R_outer / t_outer * self.bg_preserve
  ray_pts_outer = (ori_ray_pts_outer * o2i_p[:,None]).reshape(len(rays_o), -1, 3)
  */

  // Ray parameter of this step before rescaling.
  const float t_inner = t_max[i_ray];
  const float step_frac = ((float)i_step) / N_samples;
  const float t_ori = t_inner - 1. + 1. / (1. - step_frac);

  // Un-rescaled point on the ray.
  const float qx = rays_o[src  ] + rays_d[src  ] * t_ori;
  const float qy = rays_o[src+1] + rays_d[src+1] * t_ori;
  const float qz = rays_o[src+2] + rays_d[src+2] * t_ori;

  // Euclidean norm and largest absolute coordinate of that point.
  const float len = norm3(qx, qy, qz);
  const float max_abs = max(abs(qx), max(abs(qy), abs(qz)));
  const float ratio = len / max_abs;

  // Blend between the two rescaling terms according to bg_preserve.
  const float o2i_p = ratio*ratio / (len*len) * (1.-bg_preserve) + ratio / len * bg_preserve;

  rays_pts[dst  ] = qx * o2i_p;
  rays_pts[dst+1] = qy * o2i_p;
  rays_pts[dst+2] = qz * o2i_p;
}
/*
   Host launcher for sample_bg_pts_on_rays_cuda_kernel.
   Returns rays_pts of shape {n_rays, N_samples, 3}, allocated on the device
   of rays_o. Nothing is launched when there are no points to compute. Inputs
   are read through raw data pointers, so they are expected to be contiguous
   and to share one floating dtype.
*/
torch::Tensor sample_bg_pts_on_rays_cuda(
        torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_max,
        const float bg_preserve, const int N_samples) {
  const int threads = 256;
  const int n_rays = rays_o.size(0);

  auto rays_pts = torch::empty({n_rays, N_samples, 3}, rays_o.options());

  // 64-bit product: n_rays*N_samples may not fit in an int.
  const int64_t n_pts = static_cast<int64_t>(n_rays) * N_samples;
  if(n_pts==0) {
    // A launch with zero blocks is an invalid configuration; nothing to compute.
    return rays_pts;
  }
  const int blocks = static_cast<int>((n_pts + threads - 1) / threads);

  AT_DISPATCH_FLOATING_TYPES(rays_o.scalar_type(), "sample_bg_pts_on_rays_cuda", ([&] {
    sample_bg_pts_on_rays_cuda_kernel<scalar_t><<<blocks, threads>>>(
        rays_o.data_ptr<scalar_t>(),
        rays_d.data_ptr<scalar_t>(),
        t_max.data_ptr<scalar_t>(),
        bg_preserve,
        N_samples, n_rays,
        rays_pts.data_ptr<scalar_t>());
  }));

  // Kernel launches do not report errors by themselves.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
      "sample_bg_pts_on_rays_cuda: kernel launch failed: ", cudaGetErrorString(err));

  return rays_pts;
}
- /*
- MaskCache lookup to skip known freespace.
- */
// True when (i, j, k) is a valid index into a grid of size sz_i x sz_j x sz_k.
static __forceinline__ __device__
bool check_xyz(int i, int j, int k, int sz_i, int sz_j, int sz_k) {
  const bool i_ok = (0 <= i) && (i < sz_i);
  const bool j_ok = (0 <= j) && (j < sz_j);
  const bool k_ok = (0 <= k) && (k < sz_k);
  return i_ok && j_ok && k_ok;
}
/*
   Looks up one boolean grid cell per point, one thread per point.
   Each coordinate is mapped to a grid index with
       round(xyz * xyz2ijk_scale + xyz2ijk_shift)
   and, when (i, j, k) is inside the sz_i x sz_j x sz_k grid, out[i_pt] takes
   the value of world at flat index (i*sz_j + j)*sz_k + k. Points that map
   outside the grid leave out[i_pt] untouched, so the caller must
   pre-initialise out. The flat index is computed in 64 bits so grids with
   more than 2^31 cells do not wrap.
   Launch layout: 1D grid of 1D blocks with at least n_pts threads in total.
*/
template <typename scalar_t>
__global__ void maskcache_lookup_cuda_kernel(
        bool* __restrict__ world,
        scalar_t* __restrict__ xyz,
        bool* __restrict__ out,
        scalar_t* __restrict__ xyz2ijk_scale,
        scalar_t* __restrict__ xyz2ijk_shift,
        const int sz_i, const int sz_j, const int sz_k, const int n_pts) {
  const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_pt >= n_pts) {
    return;
  }
  const int64_t offset = static_cast<int64_t>(i_pt) * 3;
  const int i = round(xyz[offset  ] * xyz2ijk_scale[0] + xyz2ijk_shift[0]);
  const int j = round(xyz[offset+1] * xyz2ijk_scale[1] + xyz2ijk_shift[1]);
  const int k = round(xyz[offset+2] * xyz2ijk_scale[2] + xyz2ijk_shift[2]);
  if(check_xyz(i, j, k, sz_i, sz_j, sz_k)) {
    // Same value as i*sz_j*sz_k + j*sz_k + k, evaluated without 32-bit overflow.
    const int64_t flat = (static_cast<int64_t>(i) * sz_j + j) * sz_k + k;
    out[i_pt] = world[flat];
  }
}
/*
   Host launcher for maskcache_lookup_cuda_kernel.
   world is read as a 3D boolean grid (its first three sizes give the grid
   extents) and xyz as 3 values per point. Returns a bool tensor of shape
   {xyz.size(0)}, zero-initialised so that points the kernel skips stay false,
   and allocated on the device of xyz. Inputs are read through raw data
   pointers, so they are expected to be contiguous.
*/
torch::Tensor maskcache_lookup_cuda(
        torch::Tensor world,
        torch::Tensor xyz,
        torch::Tensor xyz2ijk_scale,
        torch::Tensor xyz2ijk_shift) {
  const int sz_i = world.size(0);
  const int sz_j = world.size(1);
  const int sz_k = world.size(2);
  const int n_pts = xyz.size(0);

  auto out = torch::zeros({n_pts}, xyz.options().dtype(torch::kBool));
  if(n_pts==0) {
    return out;
  }

  const int threads = 256;
  const int blocks = (n_pts + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(xyz.scalar_type(), "maskcache_lookup_cuda", ([&] {
    maskcache_lookup_cuda_kernel<scalar_t><<<blocks, threads>>>(
        world.data_ptr<bool>(),
        xyz.data_ptr<scalar_t>(),
        out.data_ptr<bool>(),
        xyz2ijk_scale.data_ptr<scalar_t>(),
        xyz2ijk_shift.data_ptr<scalar_t>(),
        sz_i, sz_j, sz_k, n_pts);
  }));

  // Kernel launches do not report errors by themselves.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
      "maskcache_lookup_cuda: kernel launch failed: ", cudaGetErrorString(err));

  return out;
}
- /*
- Ray marching helper function.
- */
/*
   Element-wise density -> alpha conversion with a shared interval:
       exp_d = exp(density + shift)
       alpha = 1 - (1 + exp_d)^(-interval)
   exp_d is stored as well (the backward kernel in this file consumes it).
   Launch layout: 1D grid of 1D blocks with at least n_pts threads in total.
*/
template <typename scalar_t>
__global__ void raw2alpha_cuda_kernel(
        scalar_t* __restrict__ density,
        const float shift, const float interval, const int n_pts,
        scalar_t* __restrict__ exp_d,
        scalar_t* __restrict__ alpha) {
  const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_pt >= n_pts) {
    return;
  }
  const scalar_t e = exp(density[i_pt] + shift); // can be inf
  const scalar_t one_plus_e = 1 + e;
  exp_d[i_pt] = e;
  alpha[i_pt] = 1 - pow(one_plus_e, -interval);
}
/*
   Element-wise density -> alpha conversion with a per-point interval:
       exp_d = exp(density + shift)
       alpha = 1 - (1 + exp_d)^(-interval[i_pt])
   exp_d is stored as well (the backward kernel in this file consumes it).
   Launch layout: 1D grid of 1D blocks with at least n_pts threads in total.
*/
template <typename scalar_t>
__global__ void raw2alpha_nonuni_cuda_kernel(
        scalar_t* __restrict__ density,
        const float shift, scalar_t* __restrict__ interval, const int n_pts,
        scalar_t* __restrict__ exp_d,
        scalar_t* __restrict__ alpha) {
  const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_pt >= n_pts) {
    return;
  }
  const scalar_t e = exp(density[i_pt] + shift); // can be inf
  const scalar_t one_plus_e = 1 + e;
  exp_d[i_pt] = e;
  alpha[i_pt] = 1 - pow(one_plus_e, -interval[i_pt]);
}
/*
   Host launcher for raw2alpha_cuda_kernel.
   Returns {exp_d, alpha}, both allocated with empty_like(density). The first
   dimension of density is used as the element count and the data is read
   through a raw pointer, so density is expected to be a contiguous 1D CUDA
   tensor of a floating dtype.
*/
std::vector<torch::Tensor> raw2alpha_cuda(torch::Tensor density, const float shift, const float interval) {
  const int n_pts = density.size(0);
  auto exp_d = torch::empty_like(density);
  auto alpha = torch::empty_like(density);
  if(n_pts==0) {
    return {exp_d, alpha};
  }

  const int threads = 256;
  const int blocks = (n_pts + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(density.scalar_type(), "raw2alpha_cuda", ([&] {
    raw2alpha_cuda_kernel<scalar_t><<<blocks, threads>>>(
        density.data_ptr<scalar_t>(),
        shift, interval, n_pts,
        exp_d.data_ptr<scalar_t>(),
        alpha.data_ptr<scalar_t>());
  }));

  // Kernel launches do not report errors by themselves.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
      "raw2alpha_cuda: kernel launch failed: ", cudaGetErrorString(err));

  return {exp_d, alpha};
}
/*
   Host launcher for raw2alpha_nonuni_cuda_kernel (per-point interval).
   Returns {exp_d, alpha}, both allocated with empty_like(density). The first
   dimension of density is used as the element count; density and interval are
   read through raw pointers, so both are expected to be contiguous CUDA
   tensors of the same floating dtype with one interval per density value.
*/
std::vector<torch::Tensor> raw2alpha_nonuni_cuda(torch::Tensor density, const float shift, torch::Tensor interval) {
  const int n_pts = density.size(0);
  auto exp_d = torch::empty_like(density);
  auto alpha = torch::empty_like(density);
  if(n_pts==0) {
    return {exp_d, alpha};
  }

  const int threads = 256;
  const int blocks = (n_pts + threads - 1) / threads;

  // The dispatch name appears in dtype error messages; use this function's own name.
  AT_DISPATCH_FLOATING_TYPES(density.scalar_type(), "raw2alpha_nonuni_cuda", ([&] {
    raw2alpha_nonuni_cuda_kernel<scalar_t><<<blocks, threads>>>(
        density.data_ptr<scalar_t>(),
        shift, interval.data_ptr<scalar_t>(), n_pts,
        exp_d.data_ptr<scalar_t>(),
        alpha.data_ptr<scalar_t>());
  }));

  // Kernel launches do not report errors by themselves.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
      "raw2alpha_nonuni_cuda: kernel launch failed: ", cudaGetErrorString(err));

  return {exp_d, alpha};
}
/*
   Element-wise gradient for the shared-interval density -> alpha conversion:
       grad = min(exp_d, 1e10) * (1 + exp_d)^(-interval-1) * interval * grad_back
   The min() caps exp_d, which the forward pass notes can be inf.
   Launch layout: 1D grid of 1D blocks with at least n_pts threads in total.
*/
template <typename scalar_t>
__global__ void raw2alpha_backward_cuda_kernel(
        scalar_t* __restrict__ exp_d,
        scalar_t* __restrict__ grad_back,
        const float interval, const int n_pts,
        scalar_t* __restrict__ grad) {
  const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_pt >= n_pts) {
    return;
  }
  const scalar_t e = exp_d[i_pt];
  grad[i_pt] = min(e, 1e10) * pow(1+e, -interval-1) * interval * grad_back[i_pt];
}
/*
   Element-wise gradient for the per-point-interval density -> alpha conversion:
       grad = min(exp_d, 1e10) * (1 + exp_d)^(-interval[i]-1) * interval[i] * grad_back
   The min() caps exp_d, which the forward pass notes can be inf.
   Launch layout: 1D grid of 1D blocks with at least n_pts threads in total.
*/
template <typename scalar_t>
__global__ void raw2alpha_nonuni_backward_cuda_kernel(
        scalar_t* __restrict__ exp_d,
        scalar_t* __restrict__ grad_back,
        scalar_t* __restrict__ interval, const int n_pts,
        scalar_t* __restrict__ grad) {
  const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_pt >= n_pts) {
    return;
  }
  const scalar_t e = exp_d[i_pt];
  const scalar_t dt = interval[i_pt];
  grad[i_pt] = min(e, 1e10) * pow(1+e, -dt-1) * dt * grad_back[i_pt];
}
/*
   Host launcher for raw2alpha_backward_cuda_kernel.
   Returns grad, allocated with empty_like(exp_d). The first dimension of
   exp_d is used as the element count; exp_d and grad_back are read through
   raw pointers, so both are expected to be contiguous CUDA tensors of the
   same floating dtype and length.
*/
torch::Tensor raw2alpha_backward_cuda(torch::Tensor exp_d, torch::Tensor grad_back, const float interval) {
  const int n_pts = exp_d.size(0);
  auto grad = torch::empty_like(exp_d);
  if(n_pts==0) {
    return grad;
  }

  const int threads = 256;
  const int blocks = (n_pts + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(exp_d.scalar_type(), "raw2alpha_backward_cuda", ([&] {
    raw2alpha_backward_cuda_kernel<scalar_t><<<blocks, threads>>>(
        exp_d.data_ptr<scalar_t>(),
        grad_back.data_ptr<scalar_t>(),
        interval, n_pts,
        grad.data_ptr<scalar_t>());
  }));

  // Kernel launches do not report errors by themselves.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
      "raw2alpha_backward_cuda: kernel launch failed: ", cudaGetErrorString(err));

  return grad;
}
/*
   Host launcher for raw2alpha_nonuni_backward_cuda_kernel (per-point interval).
   Returns grad, allocated with empty_like(exp_d). The first dimension of
   exp_d is used as the element count; exp_d, grad_back and interval are read
   through raw pointers, so all three are expected to be contiguous CUDA
   tensors of the same floating dtype and length.
*/
torch::Tensor raw2alpha_nonuni_backward_cuda(torch::Tensor exp_d, torch::Tensor grad_back, torch::Tensor interval) {
  const int n_pts = exp_d.size(0);
  auto grad = torch::empty_like(exp_d);
  if(n_pts==0) {
    return grad;
  }

  const int threads = 256;
  const int blocks = (n_pts + threads - 1) / threads;

  // The dispatch name appears in dtype error messages; use this function's own name.
  AT_DISPATCH_FLOATING_TYPES(exp_d.scalar_type(), "raw2alpha_nonuni_backward_cuda", ([&] {
    raw2alpha_nonuni_backward_cuda_kernel<scalar_t><<<blocks, threads>>>(
        exp_d.data_ptr<scalar_t>(),
        grad_back.data_ptr<scalar_t>(),
        interval.data_ptr<scalar_t>(), n_pts,
        grad.data_ptr<scalar_t>());
  }));

  // Kernel launches do not report errors by themselves.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
      "raw2alpha_nonuni_backward_cuda: kernel launch failed: ", cudaGetErrorString(err));

  return grad;
}
/*
   Front-to-back compositing weights, one thread per ray.
   The thread walks its ray's points [i_start, i_end) keeping the running
   transmittance T_cum (starting at 1):
       T[i]      = T_cum
       weight[i] = T_cum * alpha[i]
       T_cum    *= 1 - alpha[i]
   The walk stops early once T_cum drops below 1e-3. i_end[i_ray] is
   overwritten with one past the last processed point and alphainv_last
   receives the final T_cum. Points after an early stop are not written.
   Launch layout: 1D grid of 1D blocks with at least n_rays threads in total.
*/
template <typename scalar_t>
__global__ void alpha2weight_cuda_kernel(
        scalar_t* __restrict__ alpha,
        const int n_rays,
        scalar_t* __restrict__ weight,
        scalar_t* __restrict__ T,
        scalar_t* __restrict__ alphainv_last,
        int64_t* __restrict__ i_start,
        int64_t* __restrict__ i_end) {
  const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_ray >= n_rays) {
    return;
  }
  const int seg_begin = i_start[i_ray];
  const int seg_end = i_end[i_ray];

  float T_cum = 1.;
  int cursor = seg_begin;
  while(cursor < seg_end) {
    T[cursor] = T_cum;
    weight[cursor] = T_cum * alpha[cursor];
    T_cum *= (1. - alpha[cursor]);
    ++cursor;
    // Remaining transmittance is negligible: stop marching this ray.
    if(T_cum<1e-3) {
      break;
    }
  }

  i_end[i_ray] = cursor;
  alphainv_last[i_ray] = T_cum;
}
/*
   Finds the boundaries between consecutive ray segments in ray_id.
   Wherever ray_id[index] differs from ray_id[index-1] (index in [1, n_pts)),
   index is recorded as the start of the new ray and as the end of the
   previous one. The start of the first segment and the end of the last one
   are not written here.
   Launch layout: 1D grid of 1D blocks with at least n_pts threads in total.
*/
__global__ void __set_i_for_segment_start_end(
        int64_t* __restrict__ ray_id,
        const int n_pts,
        int64_t* __restrict__ i_start,
        int64_t* __restrict__ i_end) {
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  if(index <= 0 || index >= n_pts) {
    return;
  }
  const int64_t cur_ray = ray_id[index];
  const int64_t prev_ray = ray_id[index-1];
  if(cur_ray == prev_ray) {
    return;
  }
  i_start[cur_ray] = index;
  i_end[prev_ray] = index;
}
/*
   Converts per-point alpha into compositing weights, ray by ray.
   ray_id gives the ray index of every point; segment boundaries are detected
   where consecutive ray_id values differ, so points of one ray are expected
   to be stored consecutively.
   Returns {weight, T, alphainv_last, i_start, i_end}. weight is zero- and T
   one-initialised, so entries the kernel skips keep those values. The index
   tensors are allocated on the device of alpha. alpha and ray_id are read
   through raw pointers and are expected to be contiguous.
*/
std::vector<torch::Tensor> alpha2weight_cuda(torch::Tensor alpha, torch::Tensor ray_id, const int n_rays) {
  const int n_pts = alpha.size(0);
  const int threads = 256;
  const auto index_opts = alpha.options().dtype(torch::kInt64);

  auto weight = torch::zeros_like(alpha);
  auto T = torch::ones_like(alpha);
  auto alphainv_last = torch::ones({n_rays}, alpha.options());
  auto i_start = torch::zeros({n_rays}, index_opts);
  auto i_end = torch::zeros({n_rays}, index_opts);
  if(n_pts==0) {
    return {weight, T, alphainv_last, i_start, i_end};
  }

  // Kernel launches do not report errors by themselves.
  const auto check_launch = [](const char* kernel_name) {
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
        "alpha2weight_cuda: ", kernel_name, " launch failed: ", cudaGetErrorString(err));
  };

  __set_i_for_segment_start_end<<<(n_pts+threads-1)/threads, threads>>>(
      ray_id.data_ptr<int64_t>(), n_pts, i_start.data_ptr<int64_t>(), i_end.data_ptr<int64_t>());
  check_launch("__set_i_for_segment_start_end");
  // The boundary kernel never closes the final segment; do it here.
  i_end[ray_id[n_pts-1]] = n_pts;

  const int blocks = (n_rays + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(alpha.scalar_type(), "alpha2weight_cuda", ([&] {
    alpha2weight_cuda_kernel<scalar_t><<<blocks, threads>>>(
        alpha.data_ptr<scalar_t>(),
        n_rays,
        weight.data_ptr<scalar_t>(),
        T.data_ptr<scalar_t>(),
        alphainv_last.data_ptr<scalar_t>(),
        i_start.data_ptr<int64_t>(),
        i_end.data_ptr<int64_t>());
  }));
  check_launch("alpha2weight_cuda_kernel");

  return {weight, T, alphainv_last, i_start, i_end};
}
/*
   Backward pass of the alpha -> weight compositing, one thread per ray.
   The thread walks its ray's points from i_end-1 down to i_start with an
   accumulator back_cum that starts at grad_last * alphainv_last:
       grad[i]   = grad_weights[i] * T[i] - back_cum / (1 - alpha[i] + 1e-10)
       back_cum += grad_weights[i] * weight[i]
   The 1e-10 keeps the division finite when alpha[i] equals 1.
   Launch layout: 1D grid of 1D blocks with at least n_rays threads in total.
*/
template <typename scalar_t>
__global__ void alpha2weight_backward_cuda_kernel(
        scalar_t* __restrict__ alpha,
        scalar_t* __restrict__ weight,
        scalar_t* __restrict__ T,
        scalar_t* __restrict__ alphainv_last,
        int64_t* __restrict__ i_start,
        int64_t* __restrict__ i_end,
        const int n_rays,
        scalar_t* __restrict__ grad_weights,
        scalar_t* __restrict__ grad_last,
        scalar_t* __restrict__ grad) {
  const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_ray >= n_rays) {
    return;
  }
  const int seg_begin = i_start[i_ray];
  const int seg_end = i_end[i_ray];

  float back_cum = grad_last[i_ray] * alphainv_last[i_ray];
  int cursor = seg_end;
  while(cursor > seg_begin) {
    --cursor;
    grad[cursor] = grad_weights[cursor] * T[cursor] - back_cum / (1-alpha[cursor] + 1e-10);
    back_cum += grad_weights[cursor] * weight[cursor];
  }
}
/*
   Host launcher for alpha2weight_backward_cuda_kernel.
   Takes the tensors of the forward pass (alpha, weight, T, alphainv_last,
   i_start, i_end) plus the incoming gradients grad_weights / grad_last and
   returns grad, zero-initialised with zeros_like(alpha) so that points the
   kernel does not visit keep a zero gradient. All tensors are read through
   raw pointers and are expected to be contiguous and on the same CUDA
   device; i_start / i_end must be int64.
*/
torch::Tensor alpha2weight_backward_cuda(
        torch::Tensor alpha, torch::Tensor weight, torch::Tensor T, torch::Tensor alphainv_last,
        torch::Tensor i_start, torch::Tensor i_end, const int n_rays,
        torch::Tensor grad_weights, torch::Tensor grad_last) {
  auto grad = torch::zeros_like(alpha);
  if(n_rays==0) {
    return grad;
  }

  const int threads = 256;
  const int blocks = (n_rays + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(alpha.scalar_type(), "alpha2weight_backward_cuda", ([&] {
    alpha2weight_backward_cuda_kernel<scalar_t><<<blocks, threads>>>(
        alpha.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        T.data_ptr<scalar_t>(),
        alphainv_last.data_ptr<scalar_t>(),
        i_start.data_ptr<int64_t>(),
        i_end.data_ptr<int64_t>(),
        n_rays,
        grad_weights.data_ptr<scalar_t>(),
        grad_last.data_ptr<scalar_t>(),
        grad.data_ptr<scalar_t>());
  }));

  // Kernel launches do not report errors by themselves.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
      "alpha2weight_backward_cuda: kernel launch failed: ", cudaGetErrorString(err));

  return grad;
}
|