// FusedAdam.h (683 B)

  1. #include <ATen/core/Tensor.h>
  2. #include <ATen/native/DispatchStub.h>
  3. namespace at::native {
  4. enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 };
  5. using fused_adam_fn = void (*)(
  6. const at::Tensor& param,
  7. const at::Tensor& grad,
  8. const at::Tensor& exp_avg,
  9. const at::Tensor& exp_avg_sq,
  10. const at::Tensor& max_exp_avg_sq,
  11. const at::Tensor& state_step,
  12. const double lr,
  13. const double beta1,
  14. const double beta2,
  15. const double weight_decay,
  16. const double eps,
  17. const bool amsgrad,
  18. const bool maximize,
  19. const float* grad_scale_ptr,
  20. const ADAM_MODE);
  21. DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub)
  22. } // namespace at::native