// common.h (1.2 KB)

  1. #pragma once
  2. // Set of global constants that could be shareable between CPU and Metal code
  3. #ifdef __METAL__
  4. #include <metal_array>
  5. #define C10_METAL_CONSTEXPR constant constexpr
  6. #else
  7. #include <array>
  8. #define C10_METAL_CONSTEXPR constexpr
  9. #endif
  10. #define C10_METAL_ALL_TYPES_FUNCTOR(_) \
  11. _(Byte, 0) \
  12. _(Char, 1) \
  13. _(Short, 2) \
  14. _(Int, 3) \
  15. _(Long, 4) \
  16. _(Half, 5) \
  17. _(Float, 6) \
  18. _(ComplexHalf, 8) \
  19. _(ComplexFloat, 9) \
  20. _(Bool, 11) \
  21. _(BFloat16, 15)
  22. namespace c10 {
  23. namespace metal {
  24. C10_METAL_CONSTEXPR unsigned max_ndim = 16;
  25. C10_METAL_CONSTEXPR unsigned simdgroup_size = 32;
  26. #ifdef __METAL__
  27. template <typename T, unsigned N>
  28. using array = ::metal::array<T, N>;
  29. #else
  30. template <typename T, unsigned N>
  31. using array = std::array<T, N>;
  32. #endif
  33. enum class ScalarType {
  34. #define _DEFINE_ENUM_VAL_(_v, _n) _v = _n,
  35. C10_METAL_ALL_TYPES_FUNCTOR(_DEFINE_ENUM_VAL_)
  36. #undef _DEFINE_ENUM_VAL_
  37. };
  38. } // namespace metal
  39. } // namespace c10