SharedReduceOps.h 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545
  1. #pragma once
  2. // Please note that this file is
  3. // used across both CPU and GPU.
  4. #include <type_traits>
  5. #include <complex>
  6. #include <c10/macros/Macros.h>
  7. #include <ATen/detail/FunctionTraits.h>
  8. #include <ATen/NumericUtils.h>
  9. #include <ATen/OpMathType.h>
  10. #if defined(__CUDACC__)
  11. #include <ATen/cuda/DeviceUtils.cuh>
  12. #include <ATen/native/cuda/DeviceSqrt.cuh>
  13. #elif defined(__HIPCC__)
  14. #include <ATen/hip/DeviceUtils.cuh>
  15. #include <ATen/native/hip/DeviceSqrt.cuh>
  16. #endif
  17. #if defined(__CUDACC__) || defined(__HIPCC__)
  18. #include <thrust/pair.h>
  19. #else
  20. #include <cmath>
  21. #define device_sqrt std::sqrt
  22. #endif
  #if defined(__CUDACC__) || defined(__HIPCC__)
  // NaN-aware max for device code: if `b` is NaN it is returned; otherwise the
  // result is std::max(a, b). The HIP build additionally checks `a` for NaN
  // first.
  template <typename scalar_t>
  inline C10_DEVICE scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
  #if defined(__HIPCC__)
    // TODO: remove this special case for HIP when issue is fixed:
    // https://github.com/ROCm/hip/issues/2209
    if (at::_isnan(a)) {
      return a;
    }
  #endif
    if (at::_isnan(b)) {
      return b;
    }
    return std::max(a, b);
  }

  // NaN-aware min for device code; same structure as max_propagate_nan but
  // falls through to std::min.
  template <typename scalar_t>
  inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
  #if defined(__HIPCC__)
    // TODO: remove this special case for HIP when issue is fixed:
    // https://github.com/ROCm/hip/issues/2209
    if (at::_isnan(a)) {
      return a;
    }
  #endif
    if (at::_isnan(b)) {
      return b;
    }
    return std::min(a, b);
  }

  // Binary max/min used by the reducers in this header (undefined again at
  // the end of the header). GPU builds use the NaN-aware helpers above.
  #define MAX(X, Y) max_propagate_nan(X,Y)
  #define MIN(X, Y) min_propagate_nan(X,Y)
  #else
  // CPU builds take max_impl / min_impl from zmath.h.
  #include <ATen/native/cpu/zmath.h>
  #define MAX(X, Y) max_impl(X,Y)
  #define MIN(X, Y) min_impl(X,Y)
  #endif
  // ROCm's hcc does not cope well with std:: math calls inside kernel
  // functions, so compat_pow resolves to the CUDA compat shim when compiling
  // CUDA device code (__CUDA_ARCH__), to the HIP compat shim for any HIP
  // compilation, and to std::pow everywhere else.
  #if !defined(__CUDA_ARCH__) && !defined(__HIPCC__)
  #define compat_pow std::pow
  #elif defined(__CUDA_ARCH__)
  #include <c10/cuda/CUDAMathCompat.h>
  #define compat_pow c10::cuda::compat::pow
  #else
  #include <c10/hip/HIPMathCompat.h>
  #define compat_pow c10::hip::compat::pow
  #endif
  63. namespace at::native {
  namespace detail {
  // Backend-neutral pair type: std::pair for plain host compilers, and
  // thrust::pair whenever nvcc or hipcc is compiling.
  #if !defined(__CUDACC__) && !defined(__HIPCC__)
  template <typename First, typename Second>
  using pair = std::pair<First, Second>;
  #else
  template <typename First, typename Second>
  using pair = thrust::pair<First, Second>;
  #endif
  } // namespace detail
  // Running state for Welford's online mean / variance computation.
  //   mean - running mean of the values folded in so far
  //   m2   - running sum of squared deviations from the mean
  //   n    - element count kept as an integer; a combine step may store -1
  //          here when the merged count is not tracked in index_t
  //   nf   - the element count kept as scalar_t
  template <typename scalar_t, typename index_t>
  struct WelfordData {
    scalar_t mean;
    scalar_t m2;
    index_t n;
    scalar_t nf;

    // Empty accumulator: every field starts at zero.
    C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}

    C10_HOST_DEVICE WelfordData(
        scalar_t mean_in,
        scalar_t m2_in,
        index_t n_in,
        scalar_t nf_in)
        : mean(mean_in), m2(m2_in), n(n_in), nf(nf_in) {}
  };
  // Reduction ops built on Welford's online algorithm: accumulates the mean and
  // the sum of squared deviations (m2), then projects to
  // (variance or standard deviation, mean).
  //   correction - subtracted from the element count to form the variance
  //                divisor (the divisor is clamped at 0)
  //   take_sqrt  - when true, project() reports sqrt(variance)
  template <typename scalar_t, typename acc_scalar_t, typename index_t, typename res_t>
  struct WelfordOps {
    acc_scalar_t correction;
    bool take_sqrt;

    using acc_t = WelfordData<acc_scalar_t, index_t>;

    C10_HOST_DEVICE WelfordOps(acc_scalar_t correction_in, bool take_sqrt_in)
        : correction(correction_in), take_sqrt(take_sqrt_in) {}

    // Folds one element into `acc`. The count advances in index_t so repeated
    // increments do not accumulate rounding error; nf mirrors it as
    // acc_scalar_t because combine() needs a count that cannot overflow int32.
    // NOTE(review): if acc.n holds the -1 marker written by combine(), the new
    // count is 0 and the mean update divides by zero; presumably reduce() is
    // never applied after combine() - confirm against the reduction driver.
    inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
      const index_t count = acc.n + 1;
      const acc_scalar_t count_f = static_cast<acc_scalar_t>(count);
      const acc_scalar_t dev_before = data - acc.mean;
      const acc_scalar_t mean_after = acc.mean + dev_before / count_f;
      const acc_scalar_t dev_after = data - mean_after;
      return acc_t(mean_after, acc.m2 + dev_before * dev_after, count, count_f);
    }

    // Merges two partial accumulators (parallel form of Welford's update).
    // An accumulator with nf == 0 is treated as empty: if `a` is empty `b` is
    // returned unchanged, otherwise if `b` is empty `a` is.
    inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
      if (a.nf == 0) {
        return b;
      }
      if (b.nf == 0) {
        return a;
      }
      const acc_scalar_t mean_gap = b.mean - a.mean;
      const acc_scalar_t total = a.nf + b.nf;
      const acc_scalar_t b_share = b.nf / total;
      const acc_scalar_t merged_mean = a.mean + mean_gap * b_share;
      const acc_scalar_t merged_m2 = a.m2 + b.m2 + mean_gap * mean_gap * a.nf * b_share;
      // The merged count may not be representable in index_t, so n is set to
      // -1 to mark it as unknown; nf carries the count from here on.
      return acc_t(merged_mean, merged_m2, -1, total);
    }

    // Returns (variance, mean), or (standard deviation, mean) when take_sqrt is
    // set. The divisor nf - correction is clamped to 0, so m2 / 0 (inf or nan)
    // is produced when nf <= correction; the ubsan annotation silences the
    // float divide-by-zero report for that case.
    inline C10_DEVICE res_t project(acc_t acc) const __ubsan_ignore_float_divide_by_zero__ {
      const auto divisor = acc.nf > correction ? acc.nf - correction : 0;
      const auto var = acc.m2 / divisor;
      const auto mean = static_cast<scalar_t>(acc.mean);
      return res_t(take_sqrt ? device_sqrt(var) : var, mean);
    }

    // The Welford state carries no element index, so there is nothing to offset.
    static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
      return acc;
    }

  #if defined(__CUDACC__) || defined(__HIPCC__)
    // Applies WARP_SHFL_DOWN to every accumulator field.
    inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {
      return {
          WARP_SHFL_DOWN(acc.mean, offset),
          WARP_SHFL_DOWN(acc.m2, offset),
          WARP_SHFL_DOWN(acc.n, offset),
          WARP_SHFL_DOWN(acc.nf, offset),
      };
    }
  #endif
  };
  // Sums elements in acc_t and multiplies the total by `factor` on projection.
  // NOTE(review): the name suggests callers pass 1/N as factor so the result
  // is a mean - confirm at the call sites.
  template <typename scalar_t, typename acc_t=scalar_t, typename factor_t=acc_t, typename out_t = acc_t>
  struct MeanOps {
    factor_t factor;

    MeanOps(factor_t factor_in) : factor(factor_in) {}

    inline C10_DEVICE acc_t combine(acc_t lhs, acc_t rhs) const {
      return lhs + rhs;
    }

    inline C10_DEVICE acc_t reduce(acc_t running, scalar_t value, int64_t /*idx*/) const {
      return running + static_cast<acc_t>(value);
    }

    inline C10_DEVICE out_t project(acc_t total) const {
      return total * factor;
    }

    // A plain sum carries no element index, so there is nothing to offset.
    static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
      return acc;
    }

  #if defined(__CUDACC__) || defined(__HIPCC__)
    inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
      return WARP_SHFL_DOWN(data, offset);
    }
  #endif
  };
  // Reducer for the smallest absolute value in a set of numbers.
  // `scalar_t` is the input element type and `acc_t` the accumulation type;
  // the two differ when the input is complex.
  template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
  struct AbsMinOps {
    inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
      return MIN(a, b);
    }

    inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
      const acc_t magnitude = static_cast<acc_t>(std::abs(at::opmath_type<scalar_t>(data)));
      return MIN(acc, magnitude);
    }

    inline C10_DEVICE out_t project(acc_t a) const {
      return a;
    }

    static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
      return acc;
    }

  #if defined(__CUDACC__) || defined(__HIPCC__)
    inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
      return WARP_SHFL_DOWN(acc, offset);
    }
  #endif
  };
  // Reducer for the largest absolute value in a set of numbers.
  // `scalar_t` is the input element type and `acc_t` the accumulation type;
  // the two differ when the input is complex.
  template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
  struct AbsMaxOps {
    inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
      return MAX(a, b);
    }

    inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
      const acc_t magnitude = static_cast<acc_t>(std::abs(at::opmath_type<scalar_t>(data)));
      return MAX(acc, magnitude);
    }

    inline C10_DEVICE out_t project(acc_t a) const {
      return a;
    }

    static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
      return acc;
    }

  #if defined(__CUDACC__) || defined(__HIPCC__)
    inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
      return WARP_SHFL_DOWN(acc, offset);
    }
  #endif
  };
  // Reducer for the general p-norm: accumulates sum(|x|^p) and projects to
  // that sum raised to 1/p, where p is stored in norm_.
  // `scalar_t` is the input element type and `acc_t` the accumulation type;
  // the two differ when the input is complex.
  template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
  struct NormOps {
    acc_t norm_;

    NormOps(acc_t p) : norm_(p) {}

    inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
      const acc_t magnitude = static_cast<acc_t>(std::abs(at::opmath_type<scalar_t>(data)));
      return acc + compat_pow(magnitude, norm_);
    }

    inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
      return a + b;
    }

    inline C10_DEVICE out_t project(acc_t a) const {
      const auto inv_p = static_cast<acc_t>(1.0) / norm_;
      return compat_pow(a, inv_p);
    }

    static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
      return acc;
    }

  #if defined(__CUDACC__) || defined(__HIPCC__)
    inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
      return WARP_SHFL_DOWN(acc, offset);
    }
  #endif
  };
  // Reducer for the order-zero "norm": counts the elements that compare
  // unequal to zero.
  // `scalar_t` is the input element type and `acc_t` the accumulation type;
  // the two differ when the input is complex.
  template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
  struct NormZeroOps {
    inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
      if (data == static_cast<scalar_t>(0)) {
        return acc + static_cast<acc_t>(0);
      }
      return acc + static_cast<acc_t>(1);
    }

    inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
      return a + b;
    }

    inline C10_DEVICE out_t project(acc_t a) const {
      return a;
    }

    static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
      return acc;
    }

  #if defined(__CUDACC__) || defined(__HIPCC__)
    inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
      return WARP_SHFL_DOWN(acc, offset);
    }
  #endif
  };
  // Reducer for the order-one norm: the sum of absolute values.
  // `scalar_t` is the input element type and `acc_t` the accumulation type;
  // the two differ when the input is complex.
  template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
  struct NormOneOps {
    inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
      const acc_t magnitude = static_cast<acc_t>(std::abs(at::opmath_type<scalar_t>(data)));
      return acc + magnitude;
    }

    inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
      return a + b;
    }

    inline C10_DEVICE out_t project(acc_t a) const {
      return a;
    }

    static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
      return acc;
    }

  #if defined(__CUDACC__) || defined(__HIPCC__)
    inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
      return WARP_SHFL_DOWN(acc, offset);
    }
  #endif
  };
  // Empty tag type: passing AbsSwitch<acc_t> lets abs_if_complex deduce the
  // desired result type from its second argument.
  template<typename acc_t>
  struct AbsSwitch {};

  // Non-complex input: converted to acc_t as-is (no absolute value is taken).
  template<typename scalar_t, typename acc_t>
  inline C10_DEVICE acc_t abs_if_complex(scalar_t value, AbsSwitch<acc_t>) {
    return static_cast<acc_t>(value);
  }

  // std::complex input: its magnitude, converted to acc_t.
  template<typename scalar_t, typename acc_t>
  inline C10_DEVICE acc_t abs_if_complex(std::complex<scalar_t> value, AbsSwitch<acc_t>) {
    const auto magnitude = std::abs(value);
    return static_cast<acc_t>(magnitude);
  }

  // c10::complex input: magnitude computed in its opmath type, then converted
  // to acc_t.
  template<typename scalar_t, typename acc_t>
  inline C10_DEVICE acc_t abs_if_complex(c10::complex<scalar_t> value, AbsSwitch<acc_t>) {
    const auto magnitude = std::abs(at::opmath_type<c10::complex<scalar_t>>(value));
    return static_cast<acc_t>(magnitude);
  }
  // Reducer for the order-two (Euclidean) norm: accumulates squared
  // magnitudes and projects to their square root.
  // `scalar_t` is the input element type and `acc_t` the accumulation type;
  // the two differ when the input is complex.
  template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
  struct NormTwoOps {
    inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
      const acc_t magnitude = abs_if_complex(data, AbsSwitch<acc_t>());
      return acc + magnitude * magnitude;
    }

    inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
      return a + b;
    }

    inline C10_DEVICE out_t project(acc_t sum_of_squares) const {
      return device_sqrt(sum_of_squares);
    }

    static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
      return acc;
    }

  #if defined(__CUDACC__) || defined(__HIPCC__)
    inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
      return WARP_SHFL_DOWN(acc, offset);
    }
  #endif
  };
  // Sum reducer that treats NaN inputs as zero.
  template <typename acc_t, typename data_t>
  struct NanSumOps {
    inline C10_DEVICE acc_t reduce(acc_t a, data_t b, int64_t /*idx*/) const {
      if (at::_isnan(b)) {
        return a + acc_t{0.};
      }
      return a + acc_t{b};
    }

    inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
      return a + b;
    }

    // Converts the accumulated sum back to the input element type.
    inline C10_DEVICE data_t project(acc_t a) const {
      return data_t{a};
    }

    static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
      return acc;
    }

  #if defined(__CUDACC__) || defined(__HIPCC__)
    inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
      return WARP_SHFL_DOWN(data, offset);
    }
  #endif
  };
  namespace detail {

  // Comparator for min / argmin: returns true when (a, idx_a) should be kept
  // over (b, idx_b). A NaN `a` is always kept; a NaN `b` beats any non-NaN `a`.
  // Equal values, and two NaNs, are resolved in favour of the lower index.
  template <typename scalar_t>
  struct LessOrNan {
    C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
      if (!at::_isnan(a)) {
        return (a == b) ? idx_a < idx_b : (a < b);
      }
      return at::_isnan(b) ? idx_a < idx_b : true;
    }
  };

  // Comparator for max / argmax: same NaN and tie rules as LessOrNan, but
  // prefers the larger value.
  template <typename scalar_t>
  struct GreaterOrNan {
    C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
      if (!at::_isnan(a)) {
        return (a == b) ? idx_a < idx_b : (a > b);
      }
      return at::_isnan(b) ? idx_a < idx_b : true;
    }
  };

  // (value, index) reduction driven by comp_t: whenever comp_t returns true
  // for the left-hand pair, that pair is kept; otherwise the right one is.
  template <typename comp_t>
  struct MinMaxReductionOps {
    using scalar_t = typename binary_function_traits<comp_t>::arg1_t;
    using index_t = int64_t;
    using arg_t = detail::pair<scalar_t, index_t>;

    static C10_DEVICE arg_t project(arg_t arg) {
      return arg;
    }

    static C10_DEVICE arg_t reduce(arg_t arg, scalar_t val, int64_t idx) {
      if (comp_t{}(arg.first, val, arg.second, idx)) {
        return arg;
      }
      return arg_t(val, idx);
    }

    static C10_DEVICE arg_t combine(arg_t a, arg_t b) {
      if (comp_t{}(a.first, b.first, a.second, b.second)) {
        return a;
      }
      return b;
    }

    // Offsets the stored index by base_idx (presumably turning a chunk-local
    // index into a global one - confirm against the reduction driver).
    static C10_DEVICE arg_t translate_idx(arg_t a, int64_t base_idx) {
      return arg_t(a.first, a.second + base_idx);
    }

  #if defined(__CUDACC__) || defined(__HIPCC__)
    // Applies WARP_SHFL_DOWN to both the value and the index.
    static C10_DEVICE arg_t warp_shfl_down(arg_t arg, int offset) {
      const auto value = WARP_SHFL_DOWN(arg.first, offset);
      const auto index = WARP_SHFL_DOWN(arg.second, offset);
      return arg_t(value, index);
    }
  #endif
  };

  // Same reduction as MinMaxReductionOps, but project() yields only the index.
  template <typename comp_t>
  struct ArgReductionOps : public MinMaxReductionOps<comp_t> {
    using typename MinMaxReductionOps<comp_t>::scalar_t;
    using typename MinMaxReductionOps<comp_t>::index_t;
    using typename MinMaxReductionOps<comp_t>::arg_t;

    static C10_DEVICE index_t project(arg_t arg) {
      return arg.second;
    }
  };

  } // namespace detail
  // argmax: index of the greatest element. NaN wins over any number, and ties
  // go to the lowest index (rules from detail::GreaterOrNan).
  template <typename scalar_t>
  struct ArgMaxOps
      : public detail::ArgReductionOps<detail::GreaterOrNan<scalar_t>> {};

  // argmin: index of the smallest element. NaN wins over any number, and ties
  // go to the lowest index (rules from detail::LessOrNan).
  template <typename scalar_t>
  struct ArgMinOps
      : public detail::ArgReductionOps<detail::LessOrNan<scalar_t>> {};

  // min: (value, index) pair of the smallest element, NaN-propagating.
  template <typename scalar_t>
  struct MinOps
      : public detail::MinMaxReductionOps<detail::LessOrNan<scalar_t>> {};

  // max: (value, index) pair of the greatest element, NaN-propagating.
  template <typename scalar_t>
  struct MaxOps
      : public detail::MinMaxReductionOps<detail::GreaterOrNan<scalar_t>> {};
  // Computes (min, max) in a single pass. A NaN in either operand's component
  // is carried into the corresponding component of the result.
  template <typename scalar_t, typename acc_scalar_t, typename index_t>
  struct MinMaxOps {
    using acc_t = detail::pair<acc_scalar_t, acc_scalar_t>;

    inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
      // A single element is its own (min, max) range.
      const acc_t single{data, data};
      return combine(acc, single);
    }

    inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
      // Keep a's component when it is NaN or strictly better; otherwise take
      // b's, which also carries a NaN coming from b.
      const bool keep_a_min = at::_isnan(a.first) || a.first < b.first;
      const bool keep_a_max = at::_isnan(a.second) || a.second > b.second;
      return acc_t(keep_a_min ? a.first : b.first,
                   keep_a_max ? a.second : b.second);
    }

    inline C10_DEVICE acc_t project(acc_t acc) const {
      return acc;
    }

    static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
      return acc;
    }

  #if defined(__CUDACC__) || defined(__HIPCC__)
    inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
      return {
          WARP_SHFL_DOWN(acc.first, offset),
          WARP_SHFL_DOWN(acc.second, offset)
      };
    }
  #endif
  };
  455. } // namespace at::native
  // MAX / MIN are private to this header; drop them so they do not leak into
  // files that include it.
  #undef MIN
  #undef MAX