utils.h 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. // Metal helper functions
  2. #pragma once
  3. #include <c10/metal/common.h>
  4. #include <metal_stdlib>
  5. namespace c10 {
  6. namespace metal {
  7. namespace detail {
  8. template <typename T>
  9. struct vectypes {};
  10. template <>
  11. struct vectypes<float> {
  12. using type4 = float4;
  13. using type3 = float3;
  14. using type2 = float2;
  15. };
  16. template <>
  17. struct vectypes<half> {
  18. using type4 = half4;
  19. using type3 = half3;
  20. using type2 = half2;
  21. };
  22. template <>
  23. struct vectypes<bfloat> {
  24. using type4 = bfloat4;
  25. using type3 = bfloat3;
  26. using type2 = bfloat2;
  27. };
  28. template <>
  29. struct vectypes<short> {
  30. using type4 = short4;
  31. using type3 = short3;
  32. using type2 = short2;
  33. };
  34. template <>
  35. struct vectypes<int> {
  36. using type4 = int4;
  37. using type3 = int3;
  38. using type2 = int2;
  39. };
  40. template <>
  41. struct vectypes<long> {
  42. using type4 = short4;
  43. using type3 = short3;
  44. using type2 = short2;
  45. };
  46. template <typename T>
  47. struct OpMathType {
  48. using type = T;
  49. };
  50. template <>
  51. struct OpMathType<half> {
  52. using type = float;
  53. };
  54. template <>
  55. struct OpMathType<short> {
  56. using type = int;
  57. };
  58. template <>
  59. struct OpMathType<char> {
  60. using type = int;
  61. };
  62. template <>
  63. struct OpMathType<uchar> {
  64. using type = int;
  65. };
  66. template <>
  67. struct OpMathType<bfloat> {
  68. using type = float;
  69. };
  70. // Type promotion structure for higher precision accumulation
  71. template <typename T>
  72. struct AccumulationType {
  73. using type = T;
  74. };
  75. // Specialization for half - promote to float for accumulation
  76. template <>
  77. struct AccumulationType<half> {
  78. using type = float;
  79. };
  80. // Specialization for bfloat - promote to float for accumulation
  81. template <>
  82. struct AccumulationType<bfloat> {
  83. using type = float;
  84. };
  85. } // namespace detail
  86. template <typename T>
  87. ::metal::enable_if_t<::metal::is_floating_point_v<T>, T> max(T a, T b) {
  88. return ::metal::isunordered(a, b) ? NAN : ::metal::max(a, b);
  89. }
  90. template <typename T, typename U>
  91. ::metal::enable_if_t<::metal::is_integral_v<T>&& ::metal::is_integral_v<U>, T>
  92. max(T a, U b) {
  93. return ::metal::max(a, static_cast<T>(b));
  94. }
  95. template <typename T>
  96. ::metal::enable_if_t<::metal::is_floating_point_v<T>, T> min(T a, T b) {
  97. return ::metal::isunordered(a, b) ? NAN : ::metal::min(a, b);
  98. }
  99. template <typename T, typename U>
  100. ::metal::enable_if_t<::metal::is_integral_v<T>&& ::metal::is_integral_v<U>, T>
  101. min(T a, U b) {
  102. return ::metal::min(a, static_cast<T>(b));
  103. }
  104. template <>
  105. inline bfloat min(bfloat a, bfloat b) {
  106. return bfloat(
  107. ::metal::isunordered(a, b) ? NAN : ::metal::min(float(a), float(b)));
  108. }
  109. template <>
  110. inline bfloat max(bfloat a, bfloat b) {
  111. return bfloat(
  112. ::metal::isunordered(a, b) ? NAN : ::metal::max(float(a), float(b)));
  113. }
  114. template <typename T>
  115. using vec2type_t = typename detail::vectypes<T>::type2;
  116. template <typename T>
  117. using vec4type_t = typename detail::vectypes<T>::type4;
  118. template <typename T>
  119. using opmath_t = typename detail::OpMathType<T>::type;
  120. template <typename T>
  121. using accum_t = typename detail::AccumulationType<T>::type;
  122. // TODO: Move it to type_traits header may be
  123. template <typename F, typename... Args>
  124. using result_of = decltype(::metal::declval<F>()(::metal::declval<Args>()...));
  125. template <typename T>
  126. constexpr constant bool is_complex_v =
  127. ::metal::is_same_v<T, float2> || ::metal::is_same_v<T, half2>;
  128. template <typename T>
  129. constexpr constant bool is_scalar_floating_point_v =
  130. ::metal::is_floating_point_v<T> && ::metal::is_scalar_v<T>;
  131. template <typename T>
  132. constexpr constant bool is_scalar_integral_v =
  133. ::metal::is_integral_v<T> && ::metal::is_scalar_v<T>;
  134. template <typename U, typename V>
  135. using common_dtype = decltype(U(0) + V(0));
  136. // floor_divide
  137. template <
  138. typename T,
  139. typename U,
  140. ::metal::enable_if_t<
  141. is_scalar_integral_v<T> && is_scalar_integral_v<U>,
  142. bool> = true>
  143. inline common_dtype<T, U> floor_divide(T x, U y) {
  144. const auto quot = x / y;
  145. return (x < 0) == (y < 0) ? quot : (x % y != 0) ? quot - 1 : quot;
  146. }
  147. template <
  148. typename T,
  149. typename U,
  150. ::metal::enable_if_t<
  151. is_scalar_floating_point_v<T> && is_scalar_floating_point_v<U>,
  152. bool> = true>
  153. inline common_dtype<T, U> floor_divide(T x, U y) {
  154. return ::metal::floor(x / y);
  155. }
  156. // fmod
  157. template <
  158. typename T,
  159. typename U,
  160. ::metal::enable_if_t<
  161. is_scalar_integral_v<T> && is_scalar_integral_v<U>,
  162. bool> = true>
  163. inline common_dtype<T, U> fmod(T x, U y) {
  164. return x % y;
  165. }
  166. template <
  167. typename T,
  168. typename U,
  169. ::metal::enable_if_t<
  170. is_scalar_floating_point_v<T> && is_scalar_floating_point_v<U>,
  171. bool> = true>
  172. inline common_dtype<T, U> fmod(T x, U y) {
  173. return ::metal::fmod(x, y);
  174. }
  175. // cast_to primitives
  176. // - No-op if types as the same
  177. template <
  178. typename T,
  179. typename U,
  180. ::metal::enable_if_t<::metal::is_same_v<U, T>, bool> = true>
  181. inline T cast_to(const U from) {
  182. return from;
  183. }
  184. // - Simple cast between scalar and complex dtypes
  185. template <
  186. typename T,
  187. typename U,
  188. ::metal::enable_if_t<
  189. !::metal::is_same_v<U, T> && (is_complex_v<T> == is_complex_v<U>),
  190. bool> = true>
  191. inline T cast_to(const U from) {
  192. return static_cast<T>(from);
  193. }
  194. // - Scalar to complex
  195. template <
  196. typename T,
  197. typename U,
  198. ::metal::enable_if_t<is_complex_v<T> && !is_complex_v<U>, bool> = true>
  199. inline T cast_to(const U from) {
  200. return T(float(from), 0.0);
  201. }
  202. // - Complex to scalar (should not really be used, but exists for compliteness)
  203. template <
  204. typename T,
  205. typename U,
  206. ::metal::enable_if_t<!is_complex_v<T> && is_complex_v<U>, bool> = true>
  207. inline T cast_to(const U from) {
  208. return static_cast<T>(from.x);
  209. }
  210. // Generalizable math operators (used for both scalar and complex)
  211. template <
  212. typename T,
  213. typename U,
  214. ::metal::enable_if_t<!is_complex_v<T>, bool> = true>
  215. inline common_dtype<T, U> mul(const T x, const U y) {
  216. return x * y;
  217. }
  218. template <
  219. typename T,
  220. typename U,
  221. ::metal::enable_if_t<is_complex_v<T> && is_complex_v<U>, bool> = true>
  222. inline common_dtype<T, U> mul(const T x, const U y) {
  223. return T(x.x * y.x - x.y * y.y, x.x * y.y + x.y * y.x);
  224. }
  225. template <
  226. typename T,
  227. typename U,
  228. ::metal::enable_if_t<!is_complex_v<T>, bool> = true>
  229. inline common_dtype<T, U> div(const T x, const U y) {
  230. return x / y;
  231. }
  232. template <
  233. typename T,
  234. typename U,
  235. ::metal::enable_if_t<is_complex_v<T> && is_complex_v<U>, bool> = true>
  236. inline common_dtype<T, U> div(const T x, const U y) {
  237. return T(::metal::dot(x, y), x.y * y.x - x.x * y.y) / ::metal::dot(y, y);
  238. }
  239. // Remainder operator
  240. template <
  241. typename T,
  242. typename U,
  243. ::metal::enable_if_t<
  244. is_scalar_floating_point_v<T> || is_scalar_floating_point_v<U>,
  245. bool> = true>
  246. inline float remainder(const T x, const U y) {
  247. const auto x_f = static_cast<float>(x);
  248. const auto y_f = static_cast<float>(y);
  249. return x_f - y_f * floor_divide(x_f, y_f);
  250. }
  251. template <
  252. typename T,
  253. typename U,
  254. ::metal::enable_if_t<
  255. is_scalar_integral_v<T> && is_scalar_integral_v<U>,
  256. bool> = true>
  257. inline common_dtype<T, U> remainder(const T x, const U y) {
  258. auto rc = x % y;
  259. return rc == 0 || (x ^ y) > 0 ? rc : rc + y;
  260. }
  261. // Based on algorithm described in
  262. // https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
  263. inline float log1p(float x) {
  264. const auto xp1 = 1.0f + x;
  265. // First two elements of Taylor series for log(1+x) in Horner's form are:
  266. // log(1+x) = x * (1 - x * (.5 ...)), but if 1 + x == x, then it's just x
  267. if (xp1 == 1.0f) {
  268. return x;
  269. }
  270. auto rc = ::metal::precise::log(xp1);
  271. if (x > -.5 && x < .5) {
  272. // Order of operations is important here for higher precision
  273. rc *= x / (xp1 - 1.0f);
  274. }
  275. return rc;
  276. }
  277. template <typename T1, typename T2 = T1>
  278. struct pair {
  279. T1 first;
  280. T2 second;
  281. };
  282. } // namespace metal
  283. } // namespace c10