TensorGeometry.h 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. #pragma once
  2. #include <ATen/core/TensorBase.h>
  3. #include <c10/core/WrapDimMinimal.h>
  4. namespace at {
  5. // Return if the tensor geometry represented by `sizes` and `strides` is
  6. // contiguous Although we cache is_contiguous in tensor now, this is till useful
  7. // because it allows checking if a particular geometry is contiguous without
  8. // explicitly constructing a tensor, e.g., when you want to choose a kernel
  9. // strategy based on whether a subgeometry is contiguous.
  10. TORCH_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides);
  11. struct TORCH_API TensorGeometry {
  12. TensorGeometry() = default;
  13. explicit TensorGeometry(c10::SymIntArrayRef sizes)
  14. : sizes_(sizes.vec()),
  15. strides_(sizes.size()),
  16. has_symbolic_sizes_strides_(
  17. !c10::asIntArrayRefSlowOpt(sizes).has_value()) {
  18. int64_t dim = static_cast<int64_t>(sizes.size());
  19. c10::SymInt expected_stride = 1;
  20. for (int64_t i = dim - 1; i >= 0; i--) {
  21. strides_[i] = expected_stride;
  22. expected_stride *= sizes_[i];
  23. }
  24. numel_ = expected_stride;
  25. }
  26. explicit TensorGeometry(const TensorBase& t)
  27. : sizes_(t.sym_sizes().vec()),
  28. strides_(t.sym_strides().vec()),
  29. storage_offset_(t.sym_storage_offset()),
  30. numel_(t.sym_numel()),
  31. has_symbolic_sizes_strides_(
  32. t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {}
  33. explicit TensorGeometry(
  34. std::vector<at::SymInt> sizes,
  35. std::vector<at::SymInt> strides,
  36. at::SymInt storage_offset)
  37. : sizes_(std::move(sizes)),
  38. strides_(std::move(strides)),
  39. storage_offset_(std::move(storage_offset)) {
  40. recompute();
  41. }
  42. // true if the tensor is contiguous
  43. bool is_contiguous() const;
  44. int64_t dim() const {
  45. return static_cast<int64_t>(sizes_.size());
  46. }
  47. int64_t size(int64_t dim) const {
  48. TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
  49. dim = c10::maybe_wrap_dim(dim, this->dim());
  50. return sizes_.at(static_cast<size_t>(dim)).as_int_unchecked();
  51. }
  52. c10::IntArrayRef sizes() const {
  53. TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
  54. return c10::asIntArrayRefUnchecked(sizes_);
  55. }
  56. int64_t stride(int64_t dim) const {
  57. TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
  58. dim = c10::maybe_wrap_dim(dim, this->dim());
  59. return strides_.at(static_cast<size_t>(dim)).as_int_unchecked();
  60. }
  61. c10::IntArrayRef strides() const {
  62. TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
  63. return c10::asIntArrayRefUnchecked(strides_);
  64. }
  65. int64_t storage_offset() const {
  66. TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
  67. return storage_offset_.as_int_unchecked();
  68. }
  69. int64_t numel() const {
  70. TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
  71. return numel_.as_int_unchecked();
  72. }
  73. c10::SymInt sym_size(int64_t dim) const {
  74. dim = c10::maybe_wrap_dim(dim, this->dim());
  75. return sizes_.at(static_cast<size_t>(dim));
  76. }
  77. c10::SymIntArrayRef sym_sizes() const {
  78. return sizes_;
  79. }
  80. c10::SymInt sym_stride(int64_t dim) const {
  81. dim = c10::maybe_wrap_dim(dim, this->dim());
  82. return strides_.at(static_cast<size_t>(dim));
  83. }
  84. c10::SymIntArrayRef sym_strides() const {
  85. return strides_;
  86. }
  87. c10::SymInt sym_storage_offset() const {
  88. return storage_offset_;
  89. }
  90. c10::SymInt sym_numel() const {
  91. return numel_;
  92. }
  93. TensorGeometry transpose(int64_t dim0, int64_t dim1) {
  94. TensorGeometry r = *this; // copy
  95. TORCH_CHECK(
  96. dim0 < dim(),
  97. "transpose: dim0=",
  98. dim0,
  99. " out of range (dim=",
  100. dim(),
  101. ")")
  102. TORCH_CHECK(
  103. dim1 < dim(),
  104. "transpose: dim1=",
  105. dim1,
  106. " out of range (dim=",
  107. dim(),
  108. ")")
  109. std::swap(r.sizes_[dim0], r.sizes_[dim1]);
  110. std::swap(r.strides_[dim0], r.strides_[dim1]);
  111. return r;
  112. }
  113. std::vector<c10::SymInt>& mutable_sizes() {
  114. return sizes_;
  115. }
  116. std::vector<c10::SymInt>& mutable_strides() {
  117. return strides_;
  118. }
  119. c10::SymInt& mutable_storage_offset() {
  120. return storage_offset_;
  121. }
  122. void recompute() {
  123. // recalculate numel after a change
  124. c10::SymInt numel = 1;
  125. for (const auto& i : sizes_) {
  126. numel = numel * i;
  127. }
  128. numel_ = std::move(numel);
  129. has_symbolic_sizes_strides_ =
  130. !c10::asIntArrayRefSlowOpt(sizes_).has_value();
  131. }
  132. private:
  133. std::vector<c10::SymInt> sizes_;
  134. std::vector<c10::SymInt> strides_;
  135. c10::SymInt storage_offset_;
  136. c10::SymInt numel_;
  137. bool has_symbolic_sizes_strides_{false};
  138. };
  139. } // namespace at