accumulate.h 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. // Copyright 2004-present Facebook. All Rights Reserved.
  2. #pragma once
  3. #include <c10/util/Exception.h>
  4. #include <cstdint>
  5. #include <functional>
  6. #include <iterator>
  7. #include <numeric>
  8. #include <type_traits>
  9. #include <utility>
  10. namespace c10 {
  /// Adds up every element of an integral container.
  ///
  /// The running total is held in an `int64_t` no matter what the container's
  /// element type is, so summing many narrow values (e.g. `int32_t`) does not
  /// overflow the narrower type.
  ///
  /// @param container container exposing `value_type`, `begin()` and `end()`,
  ///        whose `value_type` is an integral type.
  /// @return the sum of all elements as `int64_t`; `0` for an empty container.
  template <
      typename C,
      std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
  inline int64_t sum_integers(const C& container) {
    int64_t total = 0;
    for (const auto& element : container) {
      total = total + element;
    }
    return total;
  }
  /// Sums the integral values in the half-open iterator range [begin, end).
  ///
  /// Accumulation happens in `int64_t` rather than in the iterator's
  /// `value_type`, so a long run of narrow integers cannot overflow the
  /// narrower type.
  ///
  /// @param begin iterator to the first element to add.
  /// @param end   one-past-the-last iterator.
  /// @return the sum as `int64_t`; `0` when the range is empty.
  template <
      typename Iter,
      std::enable_if_t<
          std::is_integral_v<typename std::iterator_traits<Iter>::value_type>,
          int> = 0>
  inline int64_t sum_integers(Iter begin, Iter end) {
    int64_t total{0};
    for (; begin != end; ++begin) {
      total = total + *begin;
    }
    return total;
  }
  /// Multiplies together every element of an integral container.
  ///
  /// `std::accumulate` takes its result type from the initial value, so the
  /// initial value is an `int64_t`. Products of narrow element types are then
  /// computed in 64 bits instead of overflowing the element type.
  ///
  /// @param container container whose `value_type` is an integral type.
  /// @return product of all elements as `int64_t`; `1` (the empty product)
  ///         when the container is empty.
  template <
      typename C,
      std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
  inline int64_t multiply_integers(const C& container) {
    const int64_t identity = 1;
    return std::accumulate(
        std::cbegin(container),
        std::cend(container),
        identity,
        std::multiplies<>{});
  }
  /// Multiplies the integral values in the half-open iterator range
  /// [begin, end).
  ///
  /// The running product is kept in `int64_t`, independent of the iterator's
  /// `value_type`, so multiplying narrow integers does not overflow that type.
  ///
  /// @param begin iterator to the first factor.
  /// @param end   one-past-the-last iterator.
  /// @return the product as `int64_t`; `1` when the range is empty.
  template <
      typename Iter,
      std::enable_if_t<
          std::is_integral_v<typename std::iterator_traits<Iter>::value_type>,
          int> = 0>
  inline int64_t multiply_integers(Iter begin, Iter end) {
    int64_t product = 1;
    while (begin != end) {
      product = product * *begin;
      ++begin;
    }
    return product;
  }
  63. /// Return product of all dimensions starting from k
  64. /// Returns 1 if k>=dims.size()
  65. template <
  66. typename C,
  67. std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
  68. inline int64_t numelements_from_dim(const int k, const C& dims) {
  69. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(k >= 0);
  70. if (k > static_cast<int>(dims.size())) {
  71. return 1;
  72. } else {
  73. auto cbegin = dims.cbegin();
  74. std::advance(cbegin, k);
  75. return multiply_integers(cbegin, dims.cend());
  76. }
  77. }
  78. /// Product of all dims up to k (not including dims[k])
  79. /// Throws an error if k>dims.size()
  80. template <
  81. typename C,
  82. std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
  83. inline int64_t numelements_to_dim(const int k, const C& dims) {
  84. TORCH_INTERNAL_ASSERT(0 <= k);
  85. TORCH_INTERNAL_ASSERT((unsigned)k <= dims.size());
  86. auto cend = dims.cbegin();
  87. std::advance(cend, k);
  88. return multiply_integers(dims.cbegin(), cend);
  89. }
  90. /// Product of all dims between k and l (including dims[k] and excluding
  91. /// dims[l]) k and l may be supplied in either order
  92. template <
  93. typename C,
  94. std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
  95. inline int64_t numelements_between_dim(int k, int l, const C& dims) {
  96. TORCH_INTERNAL_ASSERT(0 <= k);
  97. TORCH_INTERNAL_ASSERT(0 <= l);
  98. if (k > l) {
  99. std::swap(k, l);
  100. }
  101. TORCH_INTERNAL_ASSERT((unsigned)l < dims.size());
  102. auto cbegin = dims.cbegin();
  103. auto cend = dims.cbegin();
  104. std::advance(cbegin, k);
  105. std::advance(cend, l);
  106. return multiply_integers(cbegin, cend);
  107. }
  108. } // namespace c10