OptionalArrayRef.h 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. // This file defines OptionalArrayRef<T>, a class that has almost the same
  2. // exact functionality as std::optional<ArrayRef<T>>, except that its
  3. // converting constructor fixes a dangling pointer issue.
  4. //
  // The implicit converting constructor of std::optional<ArrayRef<T>> can cause
  // the underlying ArrayRef<T> to store a dangling pointer: when the argument
  // must first be converted to T, that conversion creates a temporary inside the
  // constructor, and the ArrayRef is left pointing at it after it is destroyed.
  // OptionalArrayRef<T> prevents this by wrapping a std::optional<ArrayRef<T>>
  // and fixing the constructor implementation.
  9. //
  10. // See https://github.com/pytorch/pytorch/issues/63645 for more on this.
  11. #pragma once
  12. #include <c10/util/ArrayRef.h>
  13. #include <cstdint>
  14. #include <initializer_list>
  15. #include <optional>
  16. #include <type_traits>
  17. #include <utility>
  18. namespace c10 {
  19. template <typename T>
  20. class OptionalArrayRef final {
  21. public:
  22. // Constructors
  23. constexpr OptionalArrayRef() noexcept = default;
  24. constexpr OptionalArrayRef(std::nullopt_t) noexcept {}
  25. OptionalArrayRef(const OptionalArrayRef& other) = default;
  26. OptionalArrayRef(OptionalArrayRef&& other) noexcept = default;
  27. constexpr OptionalArrayRef(const std::optional<ArrayRef<T>>& other) noexcept
  28. : wrapped_opt_array_ref(other) {}
  29. constexpr OptionalArrayRef(std::optional<ArrayRef<T>>&& other) noexcept
  30. : wrapped_opt_array_ref(std::move(other)) {}
  31. constexpr OptionalArrayRef(const T& value) noexcept
  32. : wrapped_opt_array_ref(value) {}
  33. template <
  34. typename U = ArrayRef<T>,
  35. std::enable_if_t<
  36. !std::is_same_v<std::decay_t<U>, OptionalArrayRef> &&
  37. !std::is_same_v<std::decay_t<U>, std::in_place_t> &&
  38. std::is_constructible_v<ArrayRef<T>, U&&> &&
  39. std::is_convertible_v<U&&, ArrayRef<T>> &&
  40. !std::is_convertible_v<U&&, T>,
  41. bool> = false>
  42. constexpr OptionalArrayRef(U&& value) noexcept(
  43. std::is_nothrow_constructible_v<ArrayRef<T>, U&&>)
  44. : wrapped_opt_array_ref(std::forward<U>(value)) {}
  45. template <
  46. typename U = ArrayRef<T>,
  47. std::enable_if_t<
  48. !std::is_same_v<std::decay_t<U>, OptionalArrayRef> &&
  49. !std::is_same_v<std::decay_t<U>, std::in_place_t> &&
  50. std::is_constructible_v<ArrayRef<T>, U&&> &&
  51. !std::is_convertible_v<U&&, ArrayRef<T>>,
  52. bool> = false>
  53. constexpr explicit OptionalArrayRef(U&& value) noexcept(
  54. std::is_nothrow_constructible_v<ArrayRef<T>, U&&>)
  55. : wrapped_opt_array_ref(std::forward<U>(value)) {}
  56. template <typename... Args>
  57. constexpr explicit OptionalArrayRef(
  58. std::in_place_t ip,
  59. Args&&... args) noexcept
  60. : wrapped_opt_array_ref(ip, std::forward<Args>(args)...) {}
  61. template <typename U, typename... Args>
  62. constexpr explicit OptionalArrayRef(
  63. std::in_place_t ip,
  64. std::initializer_list<U> il,
  65. Args&&... args)
  66. : wrapped_opt_array_ref(ip, il, std::forward<Args>(args)...) {}
  67. constexpr OptionalArrayRef(const std::initializer_list<T>& Vec)
  68. : wrapped_opt_array_ref(ArrayRef<T>(Vec)) {}
  69. // Destructor
  70. ~OptionalArrayRef() = default;
  71. // Assignment
  72. constexpr OptionalArrayRef& operator=(std::nullopt_t) noexcept {
  73. wrapped_opt_array_ref = std::nullopt;
  74. return *this;
  75. }
  76. OptionalArrayRef& operator=(const OptionalArrayRef& other) = default;
  77. OptionalArrayRef& operator=(OptionalArrayRef&& other) noexcept = default;
  78. constexpr OptionalArrayRef& operator=(
  79. const std::optional<ArrayRef<T>>& other) noexcept {
  80. wrapped_opt_array_ref = other;
  81. return *this;
  82. }
  83. constexpr OptionalArrayRef& operator=(
  84. std::optional<ArrayRef<T>>&& other) noexcept {
  85. wrapped_opt_array_ref = std::move(other);
  86. return *this;
  87. }
  88. template <
  89. typename U = ArrayRef<T>,
  90. typename = std::enable_if_t<
  91. !std::is_same_v<std::decay_t<U>, OptionalArrayRef> &&
  92. std::is_constructible_v<ArrayRef<T>, U&&> &&
  93. std::is_assignable_v<ArrayRef<T>&, U&&>>>
  94. constexpr OptionalArrayRef& operator=(U&& value) noexcept(
  95. std::is_nothrow_constructible_v<ArrayRef<T>, U&&> &&
  96. std::is_nothrow_assignable_v<ArrayRef<T>&, U&&>) {
  97. wrapped_opt_array_ref = std::forward<U>(value);
  98. return *this;
  99. }
  100. // Observers
  101. constexpr ArrayRef<T>* operator->() noexcept {
  102. return &wrapped_opt_array_ref.value();
  103. }
  104. constexpr const ArrayRef<T>* operator->() const noexcept {
  105. return &wrapped_opt_array_ref.value();
  106. }
  107. constexpr ArrayRef<T>& operator*() & noexcept {
  108. return wrapped_opt_array_ref.value();
  109. }
  110. constexpr const ArrayRef<T>& operator*() const& noexcept {
  111. return wrapped_opt_array_ref.value();
  112. }
  113. constexpr ArrayRef<T>&& operator*() && noexcept {
  114. return std::move(wrapped_opt_array_ref.value());
  115. }
  116. constexpr const ArrayRef<T>&& operator*() const&& noexcept {
  117. return std::move(wrapped_opt_array_ref.value());
  118. }
  119. constexpr explicit operator bool() const noexcept {
  120. return wrapped_opt_array_ref.has_value();
  121. }
  122. constexpr bool has_value() const noexcept {
  123. return wrapped_opt_array_ref.has_value();
  124. }
  125. constexpr ArrayRef<T>& value() & {
  126. return wrapped_opt_array_ref.value();
  127. }
  128. constexpr const ArrayRef<T>& value() const& {
  129. // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  130. return wrapped_opt_array_ref.value();
  131. }
  132. constexpr ArrayRef<T>&& value() && {
  133. return std::move(wrapped_opt_array_ref.value());
  134. }
  135. constexpr const ArrayRef<T>&& value() const&& {
  136. return std::move(wrapped_opt_array_ref.value());
  137. }
  138. template <typename U>
  139. constexpr std::
  140. enable_if_t<std::is_convertible_v<U&&, ArrayRef<T>>, ArrayRef<T>>
  141. value_or(U&& default_value) const& {
  142. return wrapped_opt_array_ref.value_or(std::forward<U>(default_value));
  143. }
  144. template <typename U>
  145. constexpr std::
  146. enable_if_t<std::is_convertible_v<U&&, ArrayRef<T>>, ArrayRef<T>>
  147. value_or(U&& default_value) && {
  148. return wrapped_opt_array_ref.value_or(std::forward<U>(default_value));
  149. }
  150. // Modifiers
  151. constexpr void swap(OptionalArrayRef& other) noexcept {
  152. std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref);
  153. }
  154. constexpr void reset() noexcept {
  155. wrapped_opt_array_ref.reset();
  156. }
  157. template <typename... Args>
  158. constexpr std::
  159. enable_if_t<std::is_constructible_v<ArrayRef<T>, Args&&...>, ArrayRef<T>&>
  160. emplace(Args&&... args) noexcept(
  161. std::is_nothrow_constructible_v<ArrayRef<T>, Args&&...>) {
  162. return wrapped_opt_array_ref.emplace(std::forward<Args>(args)...);
  163. }
  164. template <typename U, typename... Args>
  165. constexpr ArrayRef<T>& emplace(
  166. std::initializer_list<U> il,
  167. Args&&... args) noexcept {
  168. return wrapped_opt_array_ref.emplace(il, std::forward<Args>(args)...);
  169. }
  170. private:
  171. std::optional<ArrayRef<T>> wrapped_opt_array_ref;
  172. };
  173. using OptionalIntArrayRef = OptionalArrayRef<int64_t>;
  174. inline bool operator==(
  175. const OptionalIntArrayRef& a1,
  176. const IntArrayRef& other) {
  177. if (!a1.has_value()) {
  178. return false;
  179. }
  180. return a1.value() == other;
  181. }
  182. inline bool operator==(
  183. const c10::IntArrayRef& a1,
  184. const c10::OptionalIntArrayRef& a2) {
  185. return a2 == a1;
  186. }
  187. } // namespace c10