DLConvertor.h 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. #pragma once
  2. #include <ATen/ATen.h>
  3. #include <ATen/Tensor.h>
  4. #include <ATen/dlpack.h>
  5. // this converter will:
  6. // 1) take a Tensor object and wrap it in the DLPack tensor
  7. // 2) take a dlpack tensor and convert it to the ATen Tensor
  8. namespace at {
  9. TORCH_API ScalarType toScalarType(const DLDataType& dtype);
  10. TORCH_API DLManagedTensor* toDLPack(const Tensor& src);
  11. TORCH_API struct DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src);
  12. TORCH_API Tensor
  13. fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter = {});
  14. TORCH_API Tensor fromDLPackVersioned(
  15. DLManagedTensorVersioned* src,
  16. std::function<void(void*)> deleter = {});
  17. TORCH_API DLDataType getDLDataType(const Tensor& t);
  18. TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id);
  19. // Copies the Tensor if there's a device mismatch or copy is forced.
  20. // This should be used before actually creating the DLPack capsule.
  21. TORCH_API Tensor maybeCopyTensor(
  22. const Tensor& data,
  23. std::optional<DLDevice> optional_dl_device,
  24. std::optional<bool> copy);
  25. // Converts the given at::Device into a DLDevice.
  26. TORCH_API DLDevice torchDeviceToDLDevice(at::Device device);
  27. // This trait class is used for retrieving different attributes, such as the
  28. // PyCapsule names and conversion functions for both DLPack tensor classes:
  29. // `DLManagedTensor` and `DLManagedTensorVersioned`.
  30. //
  31. // Each specialization should contain the following 2 traits:
  32. // - `capsule`: actual name of the capsule
  33. // - `used`: name of the capsule after using it
  34. // - `toDLPack`: function for converting a tensor into a DLPack capsule
  35. // - `fromDLPack`: function for creating a tensor from a DLPack capsule
  36. //
  37. // While `toDLPack` is the directly exposed to Python, `fromDLPack` is not.
  38. // Although it contains the core implementation, it lacks the required book
  39. // keeping logic contained in its caller `tensor_fromDLPack`.
  40. //
  41. // That said, `fromDLPack` is used directly in a few DLPack tests that live
  42. // inside ATen (no Python available).
  43. template <class T>
  44. struct DLPackTraits {};
  45. template <>
  46. struct DLPackTraits<DLManagedTensor> {
  47. inline static const char* capsule = "dltensor";
  48. inline static const char* used = "used_dltensor";
  49. inline static auto toDLPack = at::toDLPack;
  50. inline static auto fromDLPack = at::fromDLPack;
  51. };
  52. template <>
  53. struct DLPackTraits<DLManagedTensorVersioned> {
  54. inline static const char* capsule = "dltensor_versioned";
  55. inline static const char* used = "used_dltensor_versioned";
  56. inline static auto toDLPack = at::toDLPackVersioned;
  57. inline static auto fromDLPack = at::fromDLPackVersioned;
  58. };
  59. } // namespace at