Enumerate.h 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. /*
  2. * Ported from folly/container/Enumerate.h
  3. */
  4. #pragma once
  5. #include <iterator>
  6. #include <memory>
  7. #ifdef _WIN32
  8. #include <basetsd.h> // @manual
  9. using ssize_t = SSIZE_T;
  10. #endif
  11. #include <c10/macros/Macros.h>
  12. /**
  * Similar to Python's enumerate(), c10::enumerate() can be used to iterate a
  * range with a range-based for loop while also exposing the zero-based count
  * of iterations performed so far. It can be used in a constexpr context.
  17. *
  18. * For example:
  19. *
  20. * for (auto&& [index, element] : enumerate(vec)) {
  * // index names the proxy's size_t holding the zero-based iteration count.
  22. * // element is a reference to the type contained within vec, mutable
  23. * // unless vec is const.
  24. * }
  25. *
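  * Declaring the binding const does NOT make element const: element names a
  * reference data member of the per-iteration proxy, and const does not
  * propagate through references (only index becomes const).
  *
  * for (const auto&& [index, element] : enumerate(vec)) {
  * // element is still mutable unless vec is const; enumerate
  * // std::as_const(vec) to get read-only access.
  * }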
  31. *
  32. * It can also be used as follows:
  33. *
  34. * for (auto&& it : enumerate(vec)) {
  35. * // *it is a reference to the current element. Mutable unless vec is const.
  36. * // it->member can be used as well.
  37. * // it.index contains the iteration count.
  38. * }
  39. *
  * As before, const auto&& it can also be used; through a const proxy, *it
  * yields a const reference and it-> a pointer to const.
  41. */
  42. namespace c10 {
  43. namespace detail {
  // MakeConst<T>::type adds const to T. For lvalue-reference and pointer types
  // the const is applied to the referred-to / pointed-to type instead, so a
  // reference or pointer to mutable data becomes one to read-only data.
  template <class Elem>
  struct MakeConst {
    using type = Elem const;
  };

  // Elem& -> const Elem&
  template <class Elem>
  struct MakeConst<Elem&> {
    using type = Elem const&;
  };

  // Elem* -> const Elem*
  template <class Elem>
  struct MakeConst<Elem*> {
    using type = Elem const*;
  };
  56. template <class Iterator>
  57. class Enumerator {
  58. public:
  59. constexpr explicit Enumerator(Iterator it) : it_(std::move(it)) {}
  60. class Proxy {
  61. public:
  62. using difference_type = ssize_t;
  63. using value_type = typename std::iterator_traits<Iterator>::value_type;
  64. using reference = typename std::iterator_traits<Iterator>::reference;
  65. using pointer = typename std::iterator_traits<Iterator>::pointer;
  66. using iterator_category = std::input_iterator_tag;
  67. C10_ALWAYS_INLINE constexpr explicit Proxy(const Enumerator& e)
  68. : index(e.idx_), element(*e.it_) {}
  69. // Non-const Proxy: Forward constness from Iterator.
  70. C10_ALWAYS_INLINE constexpr reference operator*() {
  71. return element;
  72. }
  73. C10_ALWAYS_INLINE constexpr pointer operator->() {
  74. return std::addressof(element);
  75. }
  76. // Const Proxy: Force const references.
  77. C10_ALWAYS_INLINE constexpr typename MakeConst<reference>::type operator*()
  78. const {
  79. return element;
  80. }
  81. C10_ALWAYS_INLINE constexpr typename MakeConst<pointer>::type operator->()
  82. const {
  83. return std::addressof(element);
  84. }
  85. public:
  86. size_t index;
  87. reference element;
  88. };
  89. C10_ALWAYS_INLINE constexpr Proxy operator*() const {
  90. return Proxy(*this);
  91. }
  92. C10_ALWAYS_INLINE constexpr Enumerator& operator++() {
  93. ++it_;
  94. ++idx_;
  95. return *this;
  96. }
  97. template <typename OtherIterator>
  98. C10_ALWAYS_INLINE constexpr bool operator==(
  99. const Enumerator<OtherIterator>& rhs) const {
  100. return it_ == rhs.it_;
  101. }
  102. template <typename OtherIterator>
  103. C10_ALWAYS_INLINE constexpr bool operator!=(
  104. const Enumerator<OtherIterator>& rhs) const {
  105. return !(it_ == rhs.it_);
  106. }
  107. private:
  108. template <typename OtherIterator>
  109. friend class Enumerator;
  110. Iterator it_;
  111. size_t idx_ = 0;
  112. };
  113. template <class Range>
  114. class RangeEnumerator {
  115. Range r_;
  116. using BeginIteratorType = decltype(std::declval<Range>().begin());
  117. using EndIteratorType = decltype(std::declval<Range>().end());
  118. public:
  119. // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
  120. constexpr explicit RangeEnumerator(Range&& r) : r_(std::forward<Range>(r)) {}
  121. constexpr Enumerator<BeginIteratorType> begin() {
  122. return Enumerator<BeginIteratorType>(r_.begin());
  123. }
  124. constexpr Enumerator<EndIteratorType> end() {
  125. return Enumerator<EndIteratorType>(r_.end());
  126. }
  127. };
  128. } // namespace detail
  129. template <class Range>
  130. constexpr detail::RangeEnumerator<Range> enumerate(Range&& r) {
  131. return detail::RangeEnumerator<Range>(std::forward<Range>(r));
  132. }
  133. } // namespace c10