// FusedAdagrad.h

#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
  3. namespace at::native {
  4. using fused_adagrad_fn = void (*)(
  5. const at::Tensor& param,
  6. const at::Tensor& grad,
  7. const at::Tensor& state_sum,
  8. const at::Tensor& state_step,
  9. const double lr,
  10. const double lr_decay,
  11. const double weight_decay,
  12. const double eps,
  13. const bool maximize,
  14. const float* grad_scale_ptr);
  15. DECLARE_DISPATCH(fused_adagrad_fn, fused_adagrad_stub)
  16. } // namespace at::native