irange.h 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. // Copyright 2004-present Facebook. All Rights Reserved.
  2. #pragma once
  3. #include <c10/util/TypeSafeSignMath.h>
  4. #include <algorithm>
  5. #include <cstddef>
  6. #include <iterator>
  7. #include <type_traits>
  8. namespace c10 {
  9. namespace detail {
  10. template <
  11. typename I,
  12. bool one_sided = false,
  13. std::enable_if_t<std::is_integral_v<I>, int> = 0>
  14. struct integer_iterator {
  15. using iterator_category = std::input_iterator_tag;
  16. using value_type = I;
  17. using difference_type = std::ptrdiff_t;
  18. using pointer = I*;
  19. using reference = I&;
  20. explicit constexpr integer_iterator(I val) : value(val) {}
  21. constexpr I operator*() const {
  22. return value;
  23. }
  24. constexpr I const* operator->() const {
  25. return &value;
  26. }
  27. constexpr integer_iterator& operator++() {
  28. ++value;
  29. return *this;
  30. }
  31. constexpr integer_iterator operator++(int) {
  32. const auto copy = *this;
  33. ++*this;
  34. return copy;
  35. }
  36. constexpr bool operator==(const integer_iterator& other) const {
  37. if constexpr (one_sided) {
  38. // Range-for loops' end test is `begin != end`, not `begin <
  39. // end`. To handle `c10::irange(n)` where n < 0 (which should be
  40. // empty), we just make `begin != end` fail whenever `end` is
  41. // negative.
  42. return is_negative(other.value) || value == other.value;
  43. } else {
  44. return value == other.value;
  45. }
  46. // Suppress "warning: missing return statement at end of non-void function"
  47. // which Nvidia's Robert Crovella confirms is an NVCC compiler error
  48. // here https://stackoverflow.com/a/64561686/752843 on 2020-10-27
  49. // `__builtin_unreachable();` would be best here, but it's not
  50. // available with all compilers. So we instead return an arbitrary
  51. // value trusting that this line will, in fact, never be reached.
  52. return false; // Horrible hack
  53. }
  54. constexpr bool operator!=(const integer_iterator& other) const {
  55. return !(*this == other);
  56. }
  57. protected:
  58. I value;
  59. };
  60. } // namespace detail
  61. template <
  62. typename I,
  63. bool one_sided = false,
  64. std::enable_if_t<std::is_integral_v<I>, bool> = true>
  65. struct integer_range {
  66. public:
  67. constexpr integer_range(I begin, I end) : begin_(begin), end_(end) {}
  68. using iterator = detail::integer_iterator<I, one_sided>;
  69. constexpr iterator begin() const {
  70. return begin_;
  71. }
  72. constexpr iterator end() const {
  73. return end_;
  74. }
  75. private:
  76. iterator begin_;
  77. iterator end_;
  78. };
  79. /// Creates an integer range for the half-open interval [begin, end)
  80. /// If end<=begin, then the range is empty.
  81. /// The range has the type of the `end` integer; `begin` integer is
  82. /// cast to this type.
  83. template <
  84. typename Integer1,
  85. typename Integer2,
  86. std::enable_if_t<std::is_integral_v<Integer1>, bool> = true,
  87. std::enable_if_t<std::is_integral_v<Integer2>, bool> = true>
  88. constexpr integer_range<Integer2> irange(Integer1 begin, Integer2 end) {
  89. // If end<=begin then the range is empty; we can achieve this effect by
  90. // choosing the larger of {begin, end} as the loop terminator
  91. return {
  92. static_cast<Integer2>(begin),
  93. std::max(static_cast<Integer2>(begin), end)};
  94. }
  95. /// Creates an integer range for the half-open interval [0, end)
  96. /// If end<=begin, then the range is empty
  97. template <
  98. typename Integer,
  99. std::enable_if_t<std::is_integral_v<Integer>, bool> = true>
  100. constexpr integer_range<Integer, true> irange(Integer end) {
  101. return {Integer(), end};
  102. }
  103. } // namespace c10