adam_upd.cpp 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. #include <torch/extension.h>
  2. #include <vector>
  3. void adam_upd_cuda(
  4. torch::Tensor param,
  5. torch::Tensor grad,
  6. torch::Tensor exp_avg,
  7. torch::Tensor exp_avg_sq,
  8. int step, float beta1, float beta2, float lr, float eps);
  9. void masked_adam_upd_cuda(
  10. torch::Tensor param,
  11. torch::Tensor grad,
  12. torch::Tensor exp_avg,
  13. torch::Tensor exp_avg_sq,
  14. int step, float beta1, float beta2, float lr, float eps);
  15. void adam_upd_with_perlr_cuda(
  16. torch::Tensor param,
  17. torch::Tensor grad,
  18. torch::Tensor exp_avg,
  19. torch::Tensor exp_avg_sq,
  20. torch::Tensor perlr,
  21. int step, float beta1, float beta2, float lr, float eps);
  22. // C++ interface
  23. #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
  24. #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  25. #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
  26. void adam_upd(
  27. torch::Tensor param,
  28. torch::Tensor grad,
  29. torch::Tensor exp_avg,
  30. torch::Tensor exp_avg_sq,
  31. int step, float beta1, float beta2, float lr, float eps) {
  32. CHECK_INPUT(param);
  33. CHECK_INPUT(grad);
  34. CHECK_INPUT(exp_avg);
  35. CHECK_INPUT(exp_avg_sq);
  36. adam_upd_cuda(param, grad, exp_avg, exp_avg_sq,
  37. step, beta1, beta2, lr, eps);
  38. }
  39. void masked_adam_upd(
  40. torch::Tensor param,
  41. torch::Tensor grad,
  42. torch::Tensor exp_avg,
  43. torch::Tensor exp_avg_sq,
  44. int step, float beta1, float beta2, float lr, float eps) {
  45. CHECK_INPUT(param);
  46. CHECK_INPUT(grad);
  47. CHECK_INPUT(exp_avg);
  48. CHECK_INPUT(exp_avg_sq);
  49. masked_adam_upd_cuda(param, grad, exp_avg, exp_avg_sq,
  50. step, beta1, beta2, lr, eps);
  51. }
  52. void adam_upd_with_perlr(
  53. torch::Tensor param,
  54. torch::Tensor grad,
  55. torch::Tensor exp_avg,
  56. torch::Tensor exp_avg_sq,
  57. torch::Tensor perlr,
  58. int step, float beta1, float beta2, float lr, float eps) {
  59. CHECK_INPUT(param);
  60. CHECK_INPUT(grad);
  61. CHECK_INPUT(exp_avg);
  62. CHECK_INPUT(exp_avg_sq);
  63. adam_upd_with_perlr_cuda(param, grad, exp_avg, exp_avg_sq, perlr,
  64. step, beta1, beta2, lr, eps);
  65. }
  66. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  67. m.def("adam_upd", &adam_upd,
  68. "Adam update");
  69. m.def("masked_adam_upd", &masked_adam_upd,
  70. "Adam update ignoring zero grad");
  71. m.def("adam_upd_with_perlr", &adam_upd_with_perlr,
  72. "Adam update ignoring zero grad with per-voxel lr");
  73. }