ScalarOps.h 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. #pragma once
  2. #include <ATen/Tensor.h>
  3. #include <c10/core/Scalar.h>
  4. #ifndef AT_PER_OPERATOR_HEADERS
  5. #include <ATen/Functions.h>
  6. #else
  7. #include <ATen/ops/scalar_tensor.h>
  8. #endif
  9. namespace at::detail {
  10. // When filling a number into a 1-element CPU tensor, we want to skip
  11. // everything but manipulating the data pointer directly.
  12. // Ideally this fast path should be implemented in TensorIterator,
  13. // but we also want to skip compute_types, which is not avoidable
  14. // in TensorIterator for now.
  //
  // Takes the destination tensor `self` by mutable reference and the number
  // to store as `value`; returns a Tensor reference.
  // NOTE(review): only the declaration is visible in this header; presumably
  // the returned reference is `self` -- confirm against the definition.
  15. Tensor& scalar_fill(Tensor& self, const Scalar& value);
  // Builds a Tensor from the scalar `s`, with an optional dtype and an
  // optional device. Declared with TORCH_API; only the declaration is visible
  // in this header.
  // c10::scalar_to_tensor in this file calls it as its CPU fast track, passing
  // `s.type()` as the dtype and at::kCPU as the device.
  // NOTE(review): the behavior when `dtype_opt` or `device_opt` is empty is
  // not shown here -- confirm against the definition.
  16. TORCH_API Tensor scalar_tensor_static(
  17. const Scalar& s,
  18. std::optional<ScalarType> dtype_opt,
  19. std::optional<Device> device_opt);
  20. } // namespace at::detail
  21. // This is in the c10 namespace because we use ADL to find the functions in it.
  22. namespace c10 {
  23. // FIXME: this should be (and was) Scalar::toTensor, but there is currently no
  24. // way to implement this without going through Derived Types (which are not part
  25. // of core).
  26. inline at::Tensor scalar_to_tensor(
  27. const Scalar& s,
  28. const Device device = at::kCPU) {
  29. // This is the fast track we have for CPU scalar tensors.
  30. if (device == at::kCPU) {
  31. return at::detail::scalar_tensor_static(s, s.type(), at::kCPU);
  32. }
  33. return at::scalar_tensor(s, at::device(device).dtype(s.type()));
  34. }
  35. } // namespace c10
  36. namespace at::native {
  37. inline Tensor wrapped_scalar_tensor(
  38. const Scalar& scalar,
  39. const Device device = at::kCPU) {
  40. auto tensor = scalar_to_tensor(scalar, device);
  41. tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
  42. return tensor;
  43. }
  44. } // namespace at::native