ceil_div.h 497 B

123456789101112131415161718192021222324
  1. #pragma once
  2. #include <c10/macros/Macros.h>
  3. #include <type_traits>
  4. namespace at {
  5. /**
  6. Computes ceil(a / b)
  7. */
  8. template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
  9. C10_ALWAYS_INLINE C10_HOST_DEVICE T ceil_div(T a, T b) {
  10. return (a + b - 1) / b;
  11. }
  12. /**
  13. Computes ceil(a / b) * b; i.e., rounds up `a` to the next highest
  14. multiple of b
  15. */
  16. template <typename T>
  17. C10_ALWAYS_INLINE C10_HOST_DEVICE T round_up(T a, T b) {
  18. return ceil_div(a, b) * b;
  19. }
  20. } // namespace at