MatrixRef.h 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
#pragma once
#include <cstddef>
#include <initializer_list>
#include <type_traits>

#include <ATen/Utils.h>
#include <c10/util/ArrayRef.h>
  4. namespace at {
  5. /// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that
  6. /// we can easily view it as a multidimensional array.
  7. ///
  8. /// Like ArrayRef, this class does not own the underlying data, it is expected
  9. /// to be used in situations where the data resides in some other buffer.
  10. ///
  11. /// This is intended to be trivially copyable, so it should be passed by
  12. /// value.
  13. ///
  14. /// For now, 2D only (so the copies are actually cheap, without having
  15. /// to write a SmallVector class) and contiguous only (so we can
  16. /// return non-strided ArrayRef on index).
  17. ///
  18. /// P.S. dimension 0 indexes rows, dimension 1 indexes columns
  19. template <typename T>
  20. class MatrixRef {
  21. public:
  22. typedef size_t size_type;
  23. private:
  24. /// Underlying ArrayRef
  25. ArrayRef<T> arr;
  26. /// Stride of dim 0 (outer dimension)
  27. size_type stride0;
  28. // Stride of dim 1 is assumed to be 1
  29. public:
  30. /// Construct an empty Matrixref.
  31. /*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {}
  32. /// Construct an MatrixRef from an ArrayRef and outer stride.
  33. /*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0)
  34. : arr(arr), stride0(stride0) {
  35. TORCH_CHECK(
  36. arr.size() % stride0 == 0,
  37. "MatrixRef: ArrayRef size ",
  38. arr.size(),
  39. " not divisible by stride ",
  40. stride0)
  41. }
  42. /// @}
  43. /// @name Simple Operations
  44. /// @{
  45. /// empty - Check if the matrix is empty.
  46. bool empty() const {
  47. return arr.empty();
  48. }
  49. const T* data() const {
  50. return arr.data();
  51. }
  52. /// size - Get size a dimension
  53. size_t size(size_t dim) const {
  54. if (dim == 0) {
  55. return arr.size() / stride0;
  56. } else if (dim == 1) {
  57. return stride0;
  58. } else {
  59. TORCH_CHECK(
  60. 0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1");
  61. }
  62. }
  63. size_t numel() const {
  64. return arr.size();
  65. }
  66. /// equals - Check for element-wise equality.
  67. bool equals(MatrixRef RHS) const {
  68. return stride0 == RHS.stride0 && arr.equals(RHS.arr);
  69. }
  70. /// @}
  71. /// @name Operator Overloads
  72. /// @{
  73. ArrayRef<T> operator[](size_t Index) const {
  74. return arr.slice(Index * stride0, stride0);
  75. }
  76. /// Disallow accidental assignment from a temporary.
  77. ///
  78. /// The declaration here is extra complicated so that "arrayRef = {}"
  79. /// continues to select the move assignment operator.
  80. template <typename U>
  81. // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
  82. std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
  83. // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
  84. U&& Temporary) = delete;
  85. /// Disallow accidental assignment from a temporary.
  86. ///
  87. /// The declaration here is extra complicated so that "arrayRef = {}"
  88. /// continues to select the move assignment operator.
  89. template <typename U>
  90. std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
  91. std::initializer_list<U>) = delete;
  92. };
  93. } // end namespace at