// Fused, element-wise Adam optimizer updates for PyTorch tensors (CUDA).
//
// Each variant is a kernel plus a host wrapper:
//   adam_upd_cuda            - Adam step on every element.
//   masked_adam_upd_cuda     - Adam step only on elements whose gradient is
//                              non-zero; other elements keep their parameter
//                              and moment values unchanged.
//   adam_upd_with_perlr_cuda - Adam step scaled by a per-element
//                              learning-rate multiplier `perlr`.
//
// The host wrappers fold Adam's bias correction into a single scalar
//   step_size = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
// and the kernels then apply, per element,
//   m <- beta1 * m + (1 - beta1) * g
//   v <- beta2 * v + (1 - beta2) * g^2
//   p <- p - step_size * m / (sqrt(v) + eps)
//
// NOTE(review): the reviewed copy of this file had all text between '<' and
// '>' stripped (include targets, template parameter lists, launch configs and
// the kernel bodies). The kernel bodies here are reconstructed from the
// host-side step-size formula above, and "masked" is interpreted as
// "grad != 0". Verify against the upstream source.

#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

namespace {

// Threads per block for every launch in this file.
constexpr int kAdamThreadsPerBlock = 256;

// Global 1-D thread index. The multiply is done in size_t so it cannot wrap
// around for tensors with more than 2^32 elements.
__device__ __forceinline__ size_t global_thread_index() {
  return static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
}

// One Adam step for a single element, on register copies.
// Updates the first/second moment estimates `m` and `v` in place and returns
// the new parameter value. `scale` is the bias-corrected step size, already
// multiplied by any per-element learning rate.
template <typename scalar_t>
__device__ __forceinline__ scalar_t adam_step_element(
    const scalar_t p, const scalar_t g, scalar_t& m, scalar_t& v,
    const scalar_t scale, const float beta1, const float beta2, const float eps) {
  m = beta1 * m + (1 - beta1) * g;
  v = beta2 * v + (1 - beta2) * g * g;
  return p - scale * m / (sqrt(v) + eps);
}

// Bias-corrected Adam step size: lr * sqrt(1 - beta2^t) / (1 - beta1^t).
// `step` is the 1-based iteration count. At step 0 the denominator is zero,
// which would turn every parameter into NaN/inf, so it is rejected.
float adam_step_size(const int step, const float beta1, const float beta2, const float lr) {
  TORCH_CHECK(step >= 1, "Adam step must be >= 1, got ", step);
  return lr * sqrt(1 - pow(beta2, (float)step)) / (1 - pow(beta1, (float)step));
}

// Number of blocks that gives each of the N elements its own thread.
int adam_num_blocks(const size_t N) {
  return static_cast<int>((N + kAdamThreadsPerBlock - 1) / kAdamThreadsPerBlock);
}

// Kernel launches report configuration errors only through
// cudaGetLastError(). Turn them into an exception instead of ignoring them.
void check_adam_launch(const char* name) {
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, name, " kernel launch failed: ", cudaGetErrorString(err));
}

}  // namespace

// Standard Adam update, one thread per element.
// Launch: 1-D grid with at least N threads in total; any block size works.
// No shared memory is used. Threads with index >= N do nothing.
// Assumes param/grad/exp_avg/exp_avg_sq each hold N contiguous elements of
// scalar_t. This is not checked here (TODO confirm the binding enforces it).
template <typename scalar_t>
__global__ void adam_upd_cuda_kernel(
    scalar_t* __restrict__ param,
    const scalar_t* __restrict__ grad,
    scalar_t* __restrict__ exp_avg,
    scalar_t* __restrict__ exp_avg_sq,
    const size_t N,
    const float step_size, const float beta1, const float beta2, const float eps) {
  const size_t index = global_thread_index();
  if (index >= N) return;

  scalar_t m = exp_avg[index];
  scalar_t v = exp_avg_sq[index];
  param[index] = adam_step_element(param[index], grad[index], m, v,
                                   static_cast<scalar_t>(step_size), beta1, beta2, eps);
  exp_avg[index] = m;
  exp_avg_sq[index] = v;
}

// Masked Adam update: same as adam_upd_cuda_kernel, except that elements
// whose gradient is exactly zero are skipped entirely. Their parameter and
// both moment estimates stay unchanged.
// Launch and layout requirements are the same as adam_upd_cuda_kernel.
template <typename scalar_t>
__global__ void masked_adam_upd_cuda_kernel(
    scalar_t* __restrict__ param,
    const scalar_t* __restrict__ grad,
    scalar_t* __restrict__ exp_avg,
    scalar_t* __restrict__ exp_avg_sq,
    const size_t N,
    const float step_size, const float beta1, const float beta2, const float eps) {
  const size_t index = global_thread_index();
  if (index >= N) return;

  const scalar_t g = grad[index];
  if (g == 0) return;  // no gradient signal: leave this element's state alone

  scalar_t m = exp_avg[index];
  scalar_t v = exp_avg_sq[index];
  param[index] = adam_step_element(param[index], g, m, v,
                                   static_cast<scalar_t>(step_size), beta1, beta2, eps);
  exp_avg[index] = m;
  exp_avg_sq[index] = v;
}

// Adam update with a per-element learning-rate multiplier: element i moves
// by step_size * perlr[i] * m / (sqrt(v) + eps).
// Launch and layout requirements are the same as adam_upd_cuda_kernel.
// `perlr` is assumed to also hold N contiguous scalar_t values.
template <typename scalar_t>
__global__ void adam_upd_with_perlr_cuda_kernel(
    scalar_t* __restrict__ param,
    const scalar_t* __restrict__ grad,
    scalar_t* __restrict__ exp_avg,
    scalar_t* __restrict__ exp_avg_sq,
    scalar_t* __restrict__ perlr,
    const size_t N,
    const float step_size, const float beta1, const float beta2, const float eps) {
  const size_t index = global_thread_index();
  if (index >= N) return;

  scalar_t m = exp_avg[index];
  scalar_t v = exp_avg_sq[index];
  param[index] = adam_step_element(param[index], grad[index], m, v,
                                   step_size * perlr[index], beta1, beta2, eps);
  exp_avg[index] = m;
  exp_avg_sq[index] = v;
}

// In-place Adam step on `param`, updating `exp_avg` / `exp_avg_sq`.
// `step` is the 1-based iteration count (throws if < 1). Empty tensors are a
// no-op. The kernel runs asynchronously on PyTorch's current CUDA stream for
// param's device, and launch failures are raised as exceptions.
void adam_upd_cuda(
    torch::Tensor param,
    torch::Tensor grad,
    torch::Tensor exp_avg,
    torch::Tensor exp_avg_sq,
    const int step, const float beta1, const float beta2, const float lr, const float eps) {
  const float step_size = adam_step_size(step, beta1, beta2, lr);
  const size_t N = param.numel();
  if (N == 0) return;  // a zero-block launch would be an invalid configuration

  const c10::cuda::CUDAGuard device_guard(param.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const int blocks = adam_num_blocks(N);

  AT_DISPATCH_FLOATING_TYPES(param.scalar_type(), "adam_upd_cuda", ([&] {
    adam_upd_cuda_kernel<scalar_t><<<blocks, kAdamThreadsPerBlock, 0, stream>>>(
        param.data_ptr<scalar_t>(),
        grad.data_ptr<scalar_t>(),
        exp_avg.data_ptr<scalar_t>(),
        exp_avg_sq.data_ptr<scalar_t>(),
        N, step_size, beta1, beta2, eps);
  }));
  check_adam_launch("adam_upd_cuda");
}

// Like adam_upd_cuda, but elements with a zero gradient are left untouched.
void masked_adam_upd_cuda(
    torch::Tensor param,
    torch::Tensor grad,
    torch::Tensor exp_avg,
    torch::Tensor exp_avg_sq,
    const int step, const float beta1, const float beta2, const float lr, const float eps) {
  const float step_size = adam_step_size(step, beta1, beta2, lr);
  const size_t N = param.numel();
  if (N == 0) return;

  const c10::cuda::CUDAGuard device_guard(param.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const int blocks = adam_num_blocks(N);

  AT_DISPATCH_FLOATING_TYPES(param.scalar_type(), "masked_adam_upd_cuda", ([&] {
    masked_adam_upd_cuda_kernel<scalar_t><<<blocks, kAdamThreadsPerBlock, 0, stream>>>(
        param.data_ptr<scalar_t>(),
        grad.data_ptr<scalar_t>(),
        exp_avg.data_ptr<scalar_t>(),
        exp_avg_sq.data_ptr<scalar_t>(),
        N, step_size, beta1, beta2, eps);
  }));
  check_adam_launch("masked_adam_upd_cuda");
}

// Like adam_upd_cuda, with each element's step additionally scaled by `perlr`.
void adam_upd_with_perlr_cuda(
    torch::Tensor param,
    torch::Tensor grad,
    torch::Tensor exp_avg,
    torch::Tensor exp_avg_sq,
    torch::Tensor perlr,
    const int step, const float beta1, const float beta2, const float lr, const float eps) {
  const float step_size = adam_step_size(step, beta1, beta2, lr);
  const size_t N = param.numel();
  if (N == 0) return;

  const c10::cuda::CUDAGuard device_guard(param.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const int blocks = adam_num_blocks(N);

  AT_DISPATCH_FLOATING_TYPES(param.scalar_type(), "adam_upd_with_perlr_cuda", ([&] {
    adam_upd_with_perlr_cuda_kernel<scalar_t><<<blocks, kAdamThreadsPerBlock, 0, stream>>>(
        param.data_ptr<scalar_t>(),
        grad.data_ptr<scalar_t>(),
        exp_avg.data_ptr<scalar_t>(),
        exp_avg_sq.data_ptr<scalar_t>(),
        perlr.data_ptr<scalar_t>(),
        N, step_size, beta1, beta2, eps);
  }));
  check_adam_launch("adam_upd_with_perlr_cuda");
}