CPUFallback.h 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. #pragma once
  2. #include <ATen/core/ivalue.h>
  3. #include <ATen/core/stack.h>
  4. #include <ATen/core/boxing/KernelFunction.h>
  5. #include <ATen/core/dispatch/Dispatcher.h>
  6. #include <c10/util/Metaprogramming.h>
  7. #include <torch/library.h>
  8. namespace at::native {
  9. // This function implements a boxed fallback to CPU.
  10. // External backends can add their own custom logging on top if it to customize their own CPU fallbacks.
  11. TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false,
  12. c10::DispatchKey cpu_dispatch_key = c10::DispatchKey::CPU);
  13. // This is a helper function that backends can use to directly call their boxed CPU fallback
  14. // TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
  15. template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes>
  16. struct _call_fallback_fn final {};
  17. template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes>
  18. struct _call_fallback_fn<fallback_fn, Op, symint, ReturnType(ParameterTypes...)> final {
  19. static ReturnType call(typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
  20. auto op = c10::Dispatcher::singleton()
  21. // TODO: figure out how to make compiler happy without dynamic casts
  22. .findSchemaOrThrow((const char*) Op::name, (const char*) Op::overload_name)
  23. //.findSchemaOrThrow("a", "b")
  24. .typed<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>();
  25. return c10::impl::BoxedKernelWrapper<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>::call(
  26. c10::BoxedKernel::makeFromFunction<fallback_fn>(),
  27. op,
  28. c10::DispatchKeySet(), // we know that the cpu_fallback doesn't use the dispatch keyset.
  29. // TODO: get std::forward<> to work
  30. args...
  31. );
  32. }
  33. };
  34. template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
  35. using call_fallback_fn_symint = _call_fallback_fn<fallback_fn, Op, true, typename Op::schema>;
  36. template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
  37. using call_fallback_fn = _call_fallback_fn<fallback_fn, Op, false, typename Op::schema>;
  38. } // namespace at::native