// adam_upd_kernel.cu
  1. #include <torch/extension.h>
  2. #include <cuda.h>
  3. #include <cuda_runtime.h>
  4. #include <vector>
  /**
   * Element-wise Adam update over flat arrays of length N.
   *
   * For every index i in [0, N):
   *   exp_avg[i]    = beta1 * exp_avg[i]    + (1 - beta1) * grad[i]
   *   exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1 - beta2) * grad[i] * grad[i]
   *   param[i]     -= step_size * exp_avg[i] / (sqrt(exp_avg_sq[i]) + eps)
   *
   * Launch layout: 1-D grid of 1-D blocks; no shared memory is used. Each thread
   * starts at its global index and advances by the total number of launched
   * threads, so every element is processed for any grid size >= 1.
   *
   * @param param       updated in place
   * @param grad        read only
   * @param exp_avg     first-moment buffer, updated in place
   * @param exp_avg_sq  second-moment buffer, updated in place
   * @param N           number of elements in each array
   * @param step_size   multiplier applied to the update as given; the kernel
   *                    itself performs no bias correction
   * @param beta1, beta2 decay factors of the two moment buffers
   * @param eps         added to sqrt(exp_avg_sq) in the denominator
   */
  template <typename scalar_t>
  __global__ void adam_upd_cuda_kernel(
      scalar_t* __restrict__ param,
      const scalar_t* __restrict__ grad,
      scalar_t* __restrict__ exp_avg,
      scalar_t* __restrict__ exp_avg_sq,
      const size_t N,
      const float step_size, const float beta1, const float beta2, const float eps) {
    // Widen BEFORE multiplying: blockIdx.x * blockDim.x is a 32-bit product and
    // would wrap around for very large N if only the result were widened.
    const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
    for (size_t index = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         index < N; index += stride) {
      const scalar_t g = grad[index];
      // Keep the new moments in locals so they are not read back after the store.
      const scalar_t m = beta1 * exp_avg[index] + (1 - beta1) * g;
      const scalar_t v = beta2 * exp_avg_sq[index] + (1 - beta2) * g * g;
      exp_avg[index] = m;
      exp_avg_sq[index] = v;
      param[index] -= step_size * m / (sqrt(v) + eps);
    }
  }
  /**
   * Element-wise Adam update that skips elements whose gradient is exactly zero.
   *
   * For every index i in [0, N) with grad[i] != 0:
   *   exp_avg[i]    = beta1 * exp_avg[i]    + (1 - beta1) * grad[i]
   *   exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1 - beta2) * grad[i] * grad[i]
   *   param[i]     -= step_size * exp_avg[i] / (sqrt(exp_avg_sq[i]) + eps)
   * Elements with grad[i] == 0 are left completely untouched (param and both
   * moment buffers keep their previous values).
   *
   * Launch layout: 1-D grid of 1-D blocks; no shared memory is used. Each thread
   * starts at its global index and advances by the total number of launched
   * threads, so every element is visited for any grid size >= 1.
   *
   * @param param       updated in place
   * @param grad        read only; also acts as the update mask
   * @param exp_avg     first-moment buffer, updated in place
   * @param exp_avg_sq  second-moment buffer, updated in place
   * @param N           number of elements in each array
   * @param step_size   multiplier applied to the update as given; the kernel
   *                    itself performs no bias correction
   * @param beta1, beta2 decay factors of the two moment buffers
   * @param eps         added to sqrt(exp_avg_sq) in the denominator
   */
  template <typename scalar_t>
  __global__ void masked_adam_upd_cuda_kernel(
      scalar_t* __restrict__ param,
      const scalar_t* __restrict__ grad,
      scalar_t* __restrict__ exp_avg,
      scalar_t* __restrict__ exp_avg_sq,
      const size_t N,
      const float step_size, const float beta1, const float beta2, const float eps) {
    // Widen BEFORE multiplying: blockIdx.x * blockDim.x is a 32-bit product and
    // would wrap around for very large N if only the result were widened.
    const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
    for (size_t index = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         index < N; index += stride) {
      const scalar_t g = grad[index];
      if (g != 0) {
        // Keep the new moments in locals so they are not read back after the store.
        const scalar_t m = beta1 * exp_avg[index] + (1 - beta1) * g;
        const scalar_t v = beta2 * exp_avg_sq[index] + (1 - beta2) * g * g;
        exp_avg[index] = m;
        exp_avg_sq[index] = v;
        param[index] -= step_size * m / (sqrt(v) + eps);
      }
    }
  }
  /**
   * Element-wise Adam update with an additional per-element multiplier.
   *
   * For every index i in [0, N):
   *   exp_avg[i]    = beta1 * exp_avg[i]    + (1 - beta1) * grad[i]
   *   exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1 - beta2) * grad[i] * grad[i]
   *   param[i]     -= step_size * perlr[i] * exp_avg[i] / (sqrt(exp_avg_sq[i]) + eps)
   *
   * Launch layout: 1-D grid of 1-D blocks; no shared memory is used. Each thread
   * starts at its global index and advances by the total number of launched
   * threads, so every element is processed for any grid size >= 1.
   *
   * @param param       updated in place
   * @param grad        read only
   * @param exp_avg     first-moment buffer, updated in place
   * @param exp_avg_sq  second-moment buffer, updated in place
   * @param perlr       per-element factor multiplied into step_size; only read
   *                    by this kernel
   * @param N           number of elements in each array
   * @param step_size   multiplier applied to the update as given; the kernel
   *                    itself performs no bias correction
   * @param beta1, beta2 decay factors of the two moment buffers
   * @param eps         added to sqrt(exp_avg_sq) in the denominator
   */
  template <typename scalar_t>
  __global__ void adam_upd_with_perlr_cuda_kernel(
      scalar_t* __restrict__ param,
      const scalar_t* __restrict__ grad,
      scalar_t* __restrict__ exp_avg,
      scalar_t* __restrict__ exp_avg_sq,
      scalar_t* __restrict__ perlr,
      const size_t N,
      const float step_size, const float beta1, const float beta2, const float eps) {
    // Widen BEFORE multiplying: blockIdx.x * blockDim.x is a 32-bit product and
    // would wrap around for very large N if only the result were widened.
    const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
    for (size_t index = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         index < N; index += stride) {
      const scalar_t g = grad[index];
      // Keep the new moments in locals so they are not read back after the store.
      const scalar_t m = beta1 * exp_avg[index] + (1 - beta1) * g;
      const scalar_t v = beta2 * exp_avg_sq[index] + (1 - beta2) * g * g;
      exp_avg[index] = m;
      exp_avg_sq[index] = v;
      param[index] -= step_size * perlr[index] * m / (sqrt(v) + eps);
    }
  }
  /**
   * Host launcher for adam_upd_cuda_kernel: applies one in-place Adam step to
   * `param` using `grad` and the running moment buffers.
   *
   * The bias-corrected step size is computed on the host as
   *   step_size = lr * sqrt(1 - beta2^step) / (1 - beta1^step)
   * and passed to the kernel, which is dispatched on param's floating dtype
   * (float / double).
   *
   * @param param       parameter tensor, updated in place
   * @param grad        gradient tensor, same element count as param
   * @param exp_avg     first-moment tensor, updated in place
   * @param exp_avg_sq  second-moment tensor, updated in place
   * @param step        1-based optimizer step counter (must be >= 1)
   * @param beta1, beta2, lr, eps  Adam hyper-parameters
   *
   * Raises (via TORCH_CHECK) on mismatched element counts, step < 1, a tensor too
   * large for a single 1-D launch, or a failed kernel launch.
   *
   * The tensors are indexed as flat arrays through their raw data pointers —
   * assumes contiguous tensors on the current CUDA device; this function does
   * not verify that (TODO confirm the caller does).
   * NOTE(review): the launch passes no stream argument, so it goes to the default
   * stream — confirm that is intended when PyTorch runs on a non-default stream.
   * The launch is not synchronized here.
   */
  void adam_upd_cuda(
      torch::Tensor param,
      torch::Tensor grad,
      torch::Tensor exp_avg,
      torch::Tensor exp_avg_sq,
      const int step, const float beta1, const float beta2, const float lr, const float eps) {
    const int64_t numel = param.numel();
    // The kernel reads/writes every tensor with the same flat index in [0, N);
    // a shorter tensor would be accessed out of bounds.
    TORCH_CHECK(grad.numel() == numel && exp_avg.numel() == numel && exp_avg_sq.numel() == numel,
                "adam_upd_cuda: grad, exp_avg and exp_avg_sq must have the same number of elements as param");
    // step <= 0 makes the bias-correction terms zero/negative => NaN step size.
    TORCH_CHECK(step >= 1, "adam_upd_cuda: step must be >= 1, got ", step);

    const size_t N = static_cast<size_t>(numel);
    if (N == 0) {
      return;  // nothing to update; a zero-block launch is an invalid configuration
    }

    const int threads = 256;
    const size_t blocks = (N + threads - 1) / threads;
    // gridDim.x is limited to 2^31 - 1; fail loudly instead of silently narrowing.
    TORCH_CHECK(blocks <= static_cast<size_t>(2147483647),
                "adam_upd_cuda: tensor with ", N, " elements is too large for a single launch");

    const float step_size = lr * sqrt(1 - pow(beta2, (float)step)) / (1 - pow(beta1, (float)step));

    AT_DISPATCH_FLOATING_TYPES(param.scalar_type(), "adam_upd_cuda", ([&] {
      adam_upd_cuda_kernel<scalar_t><<<static_cast<unsigned int>(blocks), threads>>>(
          param.data_ptr<scalar_t>(),
          grad.data_ptr<scalar_t>(),
          exp_avg.data_ptr<scalar_t>(),
          exp_avg_sq.data_ptr<scalar_t>(),
          N, step_size, beta1, beta2, eps);
    }));

    // Kernel launches do not return a status; surface configuration errors here.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "adam_upd_cuda: kernel launch failed: ", cudaGetErrorString(launch_err));
  }
  /**
   * Host launcher for masked_adam_upd_cuda_kernel: applies one in-place Adam step
   * to `param`, skipping every element whose gradient is exactly zero (those
   * elements keep their parameter value and both moment values).
   *
   * The bias-corrected step size is computed on the host as
   *   step_size = lr * sqrt(1 - beta2^step) / (1 - beta1^step)
   * and passed to the kernel, which is dispatched on param's floating dtype
   * (float / double).
   *
   * @param param       parameter tensor, updated in place
   * @param grad        gradient tensor, same element count as param
   * @param exp_avg     first-moment tensor, updated in place
   * @param exp_avg_sq  second-moment tensor, updated in place
   * @param step        1-based optimizer step counter (must be >= 1)
   * @param beta1, beta2, lr, eps  Adam hyper-parameters
   *
   * Raises (via TORCH_CHECK) on mismatched element counts, step < 1, a tensor too
   * large for a single 1-D launch, or a failed kernel launch.
   *
   * The tensors are indexed as flat arrays through their raw data pointers —
   * assumes contiguous tensors on the current CUDA device; this function does
   * not verify that (TODO confirm the caller does).
   * NOTE(review): the launch passes no stream argument, so it goes to the default
   * stream — confirm that is intended when PyTorch runs on a non-default stream.
   * The launch is not synchronized here.
   */
  void masked_adam_upd_cuda(
      torch::Tensor param,
      torch::Tensor grad,
      torch::Tensor exp_avg,
      torch::Tensor exp_avg_sq,
      const int step, const float beta1, const float beta2, const float lr, const float eps) {
    const int64_t numel = param.numel();
    // The kernel reads/writes every tensor with the same flat index in [0, N);
    // a shorter tensor would be accessed out of bounds.
    TORCH_CHECK(grad.numel() == numel && exp_avg.numel() == numel && exp_avg_sq.numel() == numel,
                "masked_adam_upd_cuda: grad, exp_avg and exp_avg_sq must have the same number of elements as param");
    // step <= 0 makes the bias-correction terms zero/negative => NaN step size.
    TORCH_CHECK(step >= 1, "masked_adam_upd_cuda: step must be >= 1, got ", step);

    const size_t N = static_cast<size_t>(numel);
    if (N == 0) {
      return;  // nothing to update; a zero-block launch is an invalid configuration
    }

    const int threads = 256;
    const size_t blocks = (N + threads - 1) / threads;
    // gridDim.x is limited to 2^31 - 1; fail loudly instead of silently narrowing.
    TORCH_CHECK(blocks <= static_cast<size_t>(2147483647),
                "masked_adam_upd_cuda: tensor with ", N, " elements is too large for a single launch");

    const float step_size = lr * sqrt(1 - pow(beta2, (float)step)) / (1 - pow(beta1, (float)step));

    AT_DISPATCH_FLOATING_TYPES(param.scalar_type(), "masked_adam_upd_cuda", ([&] {
      masked_adam_upd_cuda_kernel<scalar_t><<<static_cast<unsigned int>(blocks), threads>>>(
          param.data_ptr<scalar_t>(),
          grad.data_ptr<scalar_t>(),
          exp_avg.data_ptr<scalar_t>(),
          exp_avg_sq.data_ptr<scalar_t>(),
          N, step_size, beta1, beta2, eps);
    }));

    // Kernel launches do not return a status; surface configuration errors here.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "masked_adam_upd_cuda: kernel launch failed: ", cudaGetErrorString(launch_err));
  }
  /**
   * Host launcher for adam_upd_with_perlr_cuda_kernel: applies one in-place Adam
   * step to `param` where each element's step is additionally scaled by the
   * matching entry of `perlr`.
   *
   * The bias-corrected step size is computed on the host as
   *   step_size = lr * sqrt(1 - beta2^step) / (1 - beta1^step)
   * and passed to the kernel, which is dispatched on param's floating dtype
   * (float / double).
   *
   * @param param       parameter tensor, updated in place
   * @param grad        gradient tensor, same element count as param
   * @param exp_avg     first-moment tensor, updated in place
   * @param exp_avg_sq  second-moment tensor, updated in place
   * @param perlr       per-element multiplier tensor, same element count as param
   * @param step        1-based optimizer step counter (must be >= 1)
   * @param beta1, beta2, lr, eps  Adam hyper-parameters
   *
   * Raises (via TORCH_CHECK) on mismatched element counts, step < 1, a tensor too
   * large for a single 1-D launch, or a failed kernel launch.
   *
   * The tensors are indexed as flat arrays through their raw data pointers —
   * assumes contiguous tensors on the current CUDA device; this function does
   * not verify that (TODO confirm the caller does).
   * NOTE(review): the launch passes no stream argument, so it goes to the default
   * stream — confirm that is intended when PyTorch runs on a non-default stream.
   * The launch is not synchronized here.
   */
  void adam_upd_with_perlr_cuda(
      torch::Tensor param,
      torch::Tensor grad,
      torch::Tensor exp_avg,
      torch::Tensor exp_avg_sq,
      torch::Tensor perlr,
      const int step, const float beta1, const float beta2, const float lr, const float eps) {
    const int64_t numel = param.numel();
    // The kernel reads/writes every tensor with the same flat index in [0, N);
    // a shorter tensor would be accessed out of bounds.
    TORCH_CHECK(grad.numel() == numel && exp_avg.numel() == numel &&
                    exp_avg_sq.numel() == numel && perlr.numel() == numel,
                "adam_upd_with_perlr_cuda: grad, exp_avg, exp_avg_sq and perlr must have the same number of elements as param");
    // step <= 0 makes the bias-correction terms zero/negative => NaN step size.
    TORCH_CHECK(step >= 1, "adam_upd_with_perlr_cuda: step must be >= 1, got ", step);

    const size_t N = static_cast<size_t>(numel);
    if (N == 0) {
      return;  // nothing to update; a zero-block launch is an invalid configuration
    }

    const int threads = 256;
    const size_t blocks = (N + threads - 1) / threads;
    // gridDim.x is limited to 2^31 - 1; fail loudly instead of silently narrowing.
    TORCH_CHECK(blocks <= static_cast<size_t>(2147483647),
                "adam_upd_with_perlr_cuda: tensor with ", N, " elements is too large for a single launch");

    const float step_size = lr * sqrt(1 - pow(beta2, (float)step)) / (1 - pow(beta1, (float)step));

    AT_DISPATCH_FLOATING_TYPES(param.scalar_type(), "adam_upd_with_perlr_cuda", ([&] {
      adam_upd_with_perlr_cuda_kernel<scalar_t><<<static_cast<unsigned int>(blocks), threads>>>(
          param.data_ptr<scalar_t>(),
          grad.data_ptr<scalar_t>(),
          exp_avg.data_ptr<scalar_t>(),
          exp_avg_sq.data_ptr<scalar_t>(),
          perlr.data_ptr<scalar_t>(),
          N, step_size, beta1, beta2, eps);
    }));

    // Kernel launches do not return a status; surface configuration errors here.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "adam_upd_with_perlr_cuda: kernel launch failed: ", cudaGetErrorString(launch_err));
  }