MapAllocator.h 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. #pragma once
  2. #include <c10/core/Allocator.h>
  3. #include <string_view>
  4. namespace at {
  5. enum MappedAllocatorModes {
  6. ALLOCATOR_MAPPED_SHARED = 1,
  7. ALLOCATOR_MAPPED_SHAREDMEM = 2,
  8. ALLOCATOR_MAPPED_EXCLUSIVE = 4,
  9. ALLOCATOR_MAPPED_NOCREATE = 8,
  10. ALLOCATOR_MAPPED_KEEPFD = 16,
  11. ALLOCATOR_MAPPED_FROMFD = 32,
  12. ALLOCATOR_MAPPED_UNLINK = 64
  13. };
  14. // Sentinel value/type to help distinguish the file descriptor constructor from
  15. // the non-file descriptor constructor
  16. enum WithFd { WITH_FD };
  17. TORCH_API std::string NewProcessWideShmHandle();
  18. class TORCH_API MapAllocator {
  19. public:
  20. MapAllocator(std::string_view filename, int flags, size_t size);
  21. MapAllocator(
  22. WithFd,
  23. std::string_view filename,
  24. int fd,
  25. int flags,
  26. size_t size);
  27. MapAllocator(const MapAllocator&) = delete;
  28. MapAllocator& operator=(const MapAllocator&) = delete;
  29. MapAllocator(MapAllocator&&) = delete;
  30. MapAllocator& operator=(MapAllocator&&) = delete;
  31. const char* filename() const {
  32. return filename_.c_str();
  33. }
  34. int fd() const {
  35. #ifdef _WIN32
  36. TORCH_CHECK(false, "MapAllocator::fd() is unsupported on Windows");
  37. #else
  38. return fd_;
  39. #endif
  40. }
  41. ptrdiff_t size() const {
  42. return size_;
  43. }
  44. // Return a pointer to the actual data for this allocator
  45. // (in the case of the refcounted allocator, this is offset
  46. // from the base pointer.)
  47. virtual void* data() const {
  48. return base_ptr_;
  49. }
  50. int flags() const {
  51. return flags_;
  52. }
  53. static MapAllocator* fromDataPtr(const at::DataPtr&);
  54. static at::DataPtr makeDataPtr(
  55. std::string_view filename,
  56. int flags,
  57. size_t size,
  58. size_t* actual_size_out);
  59. static at::DataPtr makeDataPtr(
  60. WithFd,
  61. const char* filename,
  62. int fd,
  63. int flags,
  64. size_t size,
  65. size_t* actual_size_out);
  66. // Closes the data. Helps us avoid destructor shenanigans
  67. virtual void close();
  68. // This is very dangerous. You have to redefine this destructor for each
  69. // subclass
  70. virtual ~MapAllocator();
  71. protected:
  72. bool closed_ = false;
  73. std::string filename_;
  74. int flags_ = 0;
  75. ptrdiff_t size_; /* mapped size */
  76. #ifdef _WIN32
  77. void* handle_;
  78. void* event_;
  79. std::string eventname_;
  80. #else
  81. int fd_ = -1;
  82. #endif
  83. void* base_ptr_ = nullptr;
  84. };
  85. // Base-from-member idiom
  86. struct TORCH_API RefcountedMapAllocatorArgCheck {
  87. RefcountedMapAllocatorArgCheck(int flags);
  88. };
  89. class TORCH_API RefcountedMapAllocator : private RefcountedMapAllocatorArgCheck,
  90. public MapAllocator {
  91. public:
  92. RefcountedMapAllocator(const char* filename, int flags, size_t size);
  93. RefcountedMapAllocator(
  94. WithFd,
  95. const char* filename,
  96. int fd,
  97. int flags,
  98. size_t size);
  99. static RefcountedMapAllocator* fromDataPtr(const at::DataPtr&);
  100. RefcountedMapAllocator(const RefcountedMapAllocator&) = delete;
  101. RefcountedMapAllocator(RefcountedMapAllocator&&) = delete;
  102. RefcountedMapAllocator& operator=(const RefcountedMapAllocator&) = delete;
  103. RefcountedMapAllocator& operator=(RefcountedMapAllocator&&) = delete;
  104. static at::DataPtr makeDataPtr(
  105. const char* filename,
  106. int flags,
  107. size_t size,
  108. size_t* actual_size_out);
  109. static at::DataPtr makeDataPtr(
  110. WithFd,
  111. const char* filename,
  112. int fd,
  113. int flags,
  114. size_t size,
  115. size_t* actual_size_out);
  116. void* data() const override;
  117. void incref();
  118. int decref();
  119. void close() override;
  120. ~RefcountedMapAllocator() override {
  121. RefcountedMapAllocator::close();
  122. }
  123. protected:
  124. void checkFlags();
  125. void initializeAlloc();
  126. };
  127. } // namespace at