SparseCsrTensorUtils.h 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. #pragma once
  2. #include <ATen/SparseCsrTensorImpl.h>
  3. #include <ATen/SparseTensorImpl.h>
  4. #include <ATen/core/Tensor.h>
  5. #ifndef AT_PER_OPERATOR_HEADERS
  6. #include <ATen/Functions.h>
  7. #include <ATen/NativeFunctions.h>
  8. #include <ATen/Operators.h>
  9. #else
  10. #include <ATen/ops/_sparse_compressed_tensor_unsafe.h>
  11. #include <ATen/ops/resize_as_sparse_native.h>
  12. #endif
  13. #define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...) \
  14. [&] { \
  15. const auto& the_layout = LAYOUT; \
  16. switch (the_layout) { \
  17. case kSparseCsr: \
  18. case kSparseCsc: \
  19. case kSparseBsr: \
  20. case kSparseBsc: \
  21. return __VA_ARGS__(); \
  22. default: \
  23. TORCH_CHECK( \
  24. false, \
  25. NAME, \
  26. " expected sparse compressed tensor layout but got ", \
  27. the_layout); \
  28. } \
  29. }()
  30. #define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( \
  31. LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION) \
  32. [&]() { \
  33. const auto& the_layout = LAYOUT; \
  34. switch (the_layout) { \
  35. case kSparseCsr: \
  36. case kSparseBsr: \
  37. return (ROW_DIM_ACTION)(); \
  38. case kSparseCsc: \
  39. case kSparseBsc: \
  40. return (COLUMN_DIM_ACTION)(); \
  41. default: \
  42. TORCH_CHECK( \
  43. false, \
  44. NAME, \
  45. " expected sparse compressed tensor layout but got ", \
  46. the_layout); \
  47. } \
  48. }()
  49. #define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( \
  50. LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION) \
  51. [&]() { \
  52. const auto& the_layout = LAYOUT; \
  53. switch (the_layout) { \
  54. case kSparseCsr: \
  55. case kSparseCsc: \
  56. return (NO_BLOCK_ACTION)(); \
  57. case kSparseBsr: \
  58. case kSparseBsc: \
  59. return (BLOCK_ACTION)(); \
  60. default: \
  61. TORCH_CHECK( \
  62. false, \
  63. NAME, \
  64. " expected sparse compressed tensor layout but got ", \
  65. the_layout); \
  66. } \
  67. }()
  68. #define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS( \
  69. LAYOUT, NAME, ROW_DIM_ACTION) \
  70. [&]() { \
  71. const auto& the_layout = LAYOUT; \
  72. switch (the_layout) { \
  73. case kSparseCsr: \
  74. case kSparseBsr: \
  75. return (ROW_DIM_ACTION)(); \
  76. default: \
  77. TORCH_CHECK( \
  78. false, \
  79. NAME, \
  80. " expected sparse row compressed tensor layout but got ", \
  81. the_layout); \
  82. } \
  83. }()
  84. #define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS( \
  85. LAYOUT, NAME, COL_DIM_ACTION) \
  86. [&]() { \
  87. const auto& the_layout = LAYOUT; \
  88. switch (the_layout) { \
  89. case kSparseCsc: \
  90. case kSparseBsc: \
  91. return (COL_DIM_ACTION)(); \
  92. default: \
  93. TORCH_CHECK( \
  94. false, \
  95. NAME, \
  96. " expected sparse column compressed tensor layout but got ", \
  97. the_layout); \
  98. } \
  99. }()
  100. #define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  101. [&]() { \
  102. const auto& the_layout = LAYOUT; \
  103. switch (the_layout) { \
  104. case kSparseCsr: \
  105. case kSparseCsc: \
  106. return (ACTION)(); \
  107. default: \
  108. TORCH_CHECK( \
  109. false, \
  110. NAME, \
  111. " expected sparse compressed (non-block) tensor layout but got ", \
  112. the_layout); \
  113. } \
  114. }()
  115. #define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  116. [&]() { \
  117. const auto& the_layout = LAYOUT; \
  118. switch (the_layout) { \
  119. case kSparseBsr: \
  120. case kSparseBsc: \
  121. return (ACTION)(); \
  122. default: \
  123. TORCH_CHECK( \
  124. false, \
  125. NAME, \
  126. " expected sparse compressed block tensor layout but got ", \
  127. the_layout); \
  128. } \
  129. }()
  130. #define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \
  131. AT_DISPATCH_SWITCH( \
  132. TYPE, \
  133. NAME, \
  134. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
  135. kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__))
  136. namespace at::sparse_csr {
  137. // Implements RAII object to manage checking sparse tensor invariants:
  138. class CheckSparseTensorInvariants {
  139. bool old_state;
  140. public:
  141. CheckSparseTensorInvariants(bool state)
  142. : old_state(at::globalContext().checkSparseTensorInvariants()) {
  143. at::globalContext().setCheckSparseTensorInvariants(state);
  144. }
  145. CheckSparseTensorInvariants(CheckSparseTensorInvariants&& other) = delete;
  146. CheckSparseTensorInvariants(const CheckSparseTensorInvariants&) = delete;
  147. CheckSparseTensorInvariants& operator=(const CheckSparseTensorInvariants&) =
  148. delete;
  149. CheckSparseTensorInvariants& operator=(CheckSparseTensorInvariants&&) =
  150. delete;
  151. ~CheckSparseTensorInvariants() {
  152. at::globalContext().setCheckSparseTensorInvariants(old_state);
  153. }
  154. };
  155. using SparseCsrTensor = Tensor;
  156. inline bool is_sparse_compressed(const Layout& layout) {
  157. switch (layout) {
  158. case kSparseCsr:
  159. case kSparseCsc:
  160. case kSparseBsr:
  161. case kSparseBsc:
  162. return true;
  163. default:;
  164. }
  165. return false;
  166. }
  167. inline bool is_sparse_compressed(const Tensor& self) {
  168. return is_sparse_compressed(self.layout());
  169. }
  170. inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
  171. AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
  172. self.layout(), "get_sparse_csr_impl", [&] {});
  173. return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
  174. }
  175. inline std::string layoutToString(
  176. Layout layout,
  177. bool upper = false,
  178. bool lower = false) {
  179. switch (layout) {
  180. case kSparseCsr:
  181. return (upper ? "CSR" : (lower ? "csr" : "Csr"));
  182. case kSparseCsc:
  183. return (upper ? "CSC" : (lower ? "csc" : "Csc"));
  184. case kSparseBsr:
  185. return (upper ? "BSR" : (lower ? "bsr" : "Bsr"));
  186. case kSparseBsc:
  187. return (upper ? "BSC" : (lower ? "bsc" : "Bsc"));
  188. default:
  189. TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
  190. return "";
  191. }
  192. }
  193. inline bool isCompressedRow(Layout layout) {
  194. return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
  195. layout, "isCompressedRow", [&] { return true; }, [&] { return false; });
  196. }
  197. inline bool isCompressedColumn(Layout layout) {
  198. return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
  199. layout,
  200. "isCompressedColumn",
  201. [&] { return false; },
  202. [&] { return true; });
  203. }
  204. inline std::string compressedIndicesName(Layout layout) {
  205. return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
  206. layout,
  207. "compressedIndicesName",
  208. [&] { return "crow_indices"; },
  209. [&] { return "ccol_indices"; });
  210. }
  211. inline std::string plainIndicesName(Layout layout) {
  212. return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
  213. layout,
  214. "plainIndicesName",
  215. [&] { return "col_indices"; },
  216. [&] { return "row_indices"; });
  217. }
  218. inline std::string compressedDimName(Layout layout) {
  219. switch (layout) {
  220. case kSparseCsr:
  221. return "row";
  222. case kSparseCsc:
  223. return "column";
  224. case kSparseBsr:
  225. return "row block";
  226. case kSparseBsc:
  227. return "column block";
  228. default:
  229. TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
  230. return "";
  231. }
  232. }
  233. inline std::string plainDimName(Layout layout) {
  234. switch (layout) {
  235. case kSparseCsr:
  236. return "column";
  237. case kSparseCsc:
  238. return "row";
  239. case kSparseBsr:
  240. return "column block";
  241. case kSparseBsc:
  242. return "row block";
  243. default:
  244. TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
  245. return "";
  246. }
  247. }
  248. inline size_t rowDimension(Layout layout, IntArrayRef size) {
  249. return size.size() - (isCompressedRow(layout) ? 2 : 1);
  250. }
  251. inline size_t columnDimension(Layout layout, IntArrayRef size) {
  252. return size.size() - (isCompressedColumn(layout) ? 2 : 1);
  253. }
  254. inline size_t compressedDimension(
  255. Layout layout,
  256. IntArrayRef size,
  257. size_t dense_ndim = 0) {
  258. return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1);
  259. }
  260. inline size_t plainDimension(
  261. Layout layout,
  262. IntArrayRef size,
  263. size_t dense_ndim = 0) {
  264. return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2);
  265. }
  266. inline int64_t numBatchDimensions(Tensor const& self) {
  267. return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
  268. self.layout(),
  269. "numBatchDimensions",
  270. [&self] { return self.crow_indices().dim() - 1; },
  271. [&self] { return self.ccol_indices().dim() - 1; });
  272. }
  273. inline std::pair<Tensor, Tensor> getCompressedPlainIndices(Tensor const& self) {
  274. return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
  275. self.layout(),
  276. "getCompressedPlainIndices",
  277. [&self] {
  278. return std::make_pair(self.crow_indices(), self.col_indices());
  279. },
  280. [&self] {
  281. return std::make_pair(self.ccol_indices(), self.row_indices());
  282. });
  283. }
  284. inline ScalarType getIndexDtype(Tensor const& self) {
  285. switch (self.layout()) {
  286. case kSparseCsr:
  287. case kSparseBsr:
  288. return self.crow_indices().scalar_type();
  289. case kSparseCsc:
  290. case kSparseBsc:
  291. return self.ccol_indices().scalar_type();
  292. case kSparse:
  293. return self._indices().scalar_type();
  294. default:
  295. return ScalarType::Long;
  296. }
  297. }
  298. inline Layout flip_compressed_layout(Layout layout) {
  299. switch (layout) {
  300. case kSparseCsr:
  301. return kSparseCsc;
  302. case kSparseCsc:
  303. return kSparseCsr;
  304. case kSparseBsr:
  305. return kSparseBsc;
  306. case kSparseBsc:
  307. return kSparseBsr;
  308. default:
  309. TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
  310. return kSparseCsr;
  311. }
  312. }
  313. inline DimVector getBlockSize(Tensor const& self) {
  314. int64_t n_batch = numBatchDimensions(self);
  315. return at::DimVector(self.values().sizes().slice(n_batch + 1, 2));
  316. }
  317. inline at::OptionalArray<at::SymInt> getSymIntBlockSize(Tensor const& self) {
  318. if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) {
  319. int64_t n_batch = numBatchDimensions(self);
  320. return self.values().sym_sizes().slice(n_batch + 1, 2).vec();
  321. } else {
  322. return {};
  323. }
  324. }
  325. template <typename binary_op_t, typename binary_op_out_t>
  326. inline bool only_sparse_compressed_binary_op_trivial_cases(
  327. const Tensor& self,
  328. const Tensor& other,
  329. const Scalar& alpha,
  330. Tensor& out,
  331. const binary_op_t& binary_op,
  332. const binary_op_out_t& binary_op_out) {
  333. // Only sparse compressed! Just like the name says :)
  334. TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(self));
  335. TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(other));
  336. TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(out));
  337. // Bypass BLAS if there are matches in (self, other, out)
  338. if (self.is_same(out) && self.is_same(other)) {
  339. binary_op_out(self.values(), other.values(), alpha);
  340. return true;
  341. }
  342. if (self.is_same(other)) {
  343. auto [compressed_indices, plain_indices] =
  344. at::sparse_csr::getCompressedPlainIndices(self);
  345. static_cast<SparseCsrTensorImpl*>(out.unsafeGetTensorImpl())
  346. ->set_member_tensors(
  347. compressed_indices,
  348. plain_indices,
  349. binary_op(self.values(), other.values(), alpha),
  350. self.sizes());
  351. return true;
  352. }
  353. return false;
  354. }
  355. inline bool only_sparse_compressed_add_trivial_cases(
  356. const Tensor& self,
  357. const Tensor& other,
  358. const Scalar& alpha,
  359. Tensor& out) {
  360. return only_sparse_compressed_binary_op_trivial_cases(
  361. self,
  362. other,
  363. alpha,
  364. out,
  365. [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
  366. return v1.add(v2, alpha);
  367. },
  368. [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
  369. return v1.add_(v2, alpha);
  370. });
  371. }
  372. inline Tensor to_type(const Tensor& input, ScalarType dtype) {
  373. auto [compressed_indices, plain_indices] =
  374. at::sparse_csr::getCompressedPlainIndices(input);
  375. return at::_sparse_compressed_tensor_unsafe(
  376. compressed_indices,
  377. plain_indices,
  378. std::move(input.values()).to(dtype),
  379. input.sizes(),
  380. dtype,
  381. input.layout(),
  382. input.device(),
  383. input.options().pinned_memory_opt());
  384. }
  385. template <typename acc_t, typename scalar_t>
  386. inline std::tuple<Tensor, Tensor> create_acc_buffer(
  387. TensorOptions option,
  388. ScalarType type,
  389. int64_t nnz = -1) {
  390. Tensor new_values, new_values_acc;
  391. constexpr bool need_acc = !std::is_same_v<scalar_t, acc_t>;
  392. bool is_integral = at::isIntegralType(type, /*includeBool=*/true);
  393. if constexpr (need_acc) {
  394. auto acc_dtype = CppTypeToScalarType<acc_t>::value;
  395. new_values_acc = at::empty({}, option.dtype(acc_dtype));
  396. new_values = is_integral ? new_values_acc : at::empty({}, option);
  397. } else {
  398. new_values = new_values_acc = at::empty({}, option);
  399. }
  400. if (nnz != -1) {
  401. return std::make_tuple(
  402. new_values.resize_(nnz), new_values_acc.resize_(nnz));
  403. } else {
  404. return std::make_tuple(new_values, new_values_acc);
  405. }
  406. }
  407. inline void copy_from_acc_buffer(Tensor& new_values, Tensor& new_values_acc) {
  408. if (!new_values_acc.is_same(new_values)) {
  409. new_values.copy_(new_values_acc);
  410. }
  411. }
  412. } // namespace at::sparse_csr