// FusedSGD.h

  #pragma once

  #include <ATen/core/Tensor.h>
  #include <ATen/native/DispatchStub.h>
  3. namespace at::native {
  4. using fused_sgd_fn = void (*)(
  5. const at::Tensor& param,
  6. const at::Tensor& grad,
  7. const at::Tensor& momentum_buffer,
  8. const double weight_decay,
  9. const double momentum,
  10. const double lr,
  11. const double dampening,
  12. const bool nesterov,
  13. const bool maximize,
  14. const bool is_first_step,
  15. const float* grad_scale_ptr);
  16. DECLARE_DISPATCH(fused_sgd_fn, fused_sgd_stub)
  17. } // namespace at::native