atomic.h 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. #pragma once
  2. #include <metal_atomic>
  3. namespace c10 {
  4. namespace metal {
  // Atomic operations helper.
  // AtomicType<T> names the atomic storage type (`type`) and the `atomic_add`
  // routine to use for element type T. The primary template is deliberately
  // empty: element types are handled through explicit specializations.
  template <class T>
  struct AtomicType {};

  // Shorthand for the atomic storage type selected by AtomicType<T>.
  template <class T>
  using AtomicType_t = typename AtomicType<T>::type;
  10. template <>
  11. struct AtomicType<float> {
  12. using type = ::metal::atomic<float>;
  13. static inline void atomic_add(device type* data, long offset, float value) {
  14. ::metal::atomic_fetch_add_explicit(
  15. data + offset, value, ::metal::memory_order_relaxed);
  16. }
  17. };
  18. template <>
  19. struct AtomicType<int> {
  20. using type = ::metal::atomic<int>;
  21. static inline void atomic_add(device type* data, long offset, int value) {
  22. ::metal::atomic_fetch_add_explicit(
  23. data + offset, value, ::metal::memory_order_relaxed);
  24. }
  25. };
  26. // As of Metal3.2 atomic operations are not supported on half-precision floats,
  27. // so they must be simulated Using atomic compare and exchange over 32-bit
  28. // atomic type
  29. template <typename T>
  30. static inline void atomic_add_helper(
  31. device ::metal::atomic<uint>* data,
  32. long offset,
  33. T value) {
  34. constexpr auto elem_per_enum = sizeof(uint) / sizeof(T);
  35. auto ptr = data + (offset / elem_per_enum);
  36. auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
  37. union {
  38. uint i;
  39. T t[elem_per_enum];
  40. } val;
  41. do {
  42. val.i = old;
  43. val.t[offset & (elem_per_enum - 1)] += value;
  44. } while (!::metal::atomic_compare_exchange_weak_explicit(
  45. ptr,
  46. &old,
  47. val.i,
  48. ::metal::memory_order_relaxed,
  49. ::metal::memory_order_relaxed));
  50. }
  51. template <>
  52. struct AtomicType<half> {
  53. using type = ::metal::atomic<uint>;
  54. static inline void atomic_add(device type* data, long offset, half value) {
  55. atomic_add_helper(data, offset, value);
  56. }
  57. };
  58. template <>
  59. struct AtomicType<short> {
  60. using type = ::metal::atomic<uint>;
  61. static inline void atomic_add(device type* data, long offset, short value) {
  62. atomic_add_helper(data, offset, value);
  63. }
  64. };
  65. template <>
  66. struct AtomicType<char> {
  67. using type = ::metal::atomic<uint>;
  68. static inline void atomic_add(device type* data, long offset, char value) {
  69. atomic_add_helper(data, offset, value);
  70. }
  71. };
  72. template <>
  73. struct AtomicType<uchar> {
  74. using type = ::metal::atomic<uint>;
  75. static inline void atomic_add(device type* data, long offset, char value) {
  76. atomic_add_helper(data, offset, value);
  77. }
  78. };
  79. template <>
  80. struct AtomicType<bfloat> {
  81. using type = ::metal::atomic<uint>;
  82. static inline void atomic_add(device type* data, long offset, bfloat value) {
  83. atomic_add_helper<bfloat>(data, offset, value);
  84. }
  85. };
  86. // Metal supports atomic_store_explicit for bools, but
  87. // sizeof(::metal::atomic_bool) is 4 Therefore it could not be used to
  88. // atomically modify unaligned memory, so fall back to compare and exchange
  89. // trick As accumulation over booleans are just or operation, do nothing if
  90. // value is false
  91. template <>
  92. struct AtomicType<bool> {
  93. using type = ::metal::atomic<uint>;
  94. static inline void atomic_add(device type* data, long offset, bool value) {
  95. if (!value) {
  96. return;
  97. }
  98. auto ptr = data + (offset >> 2);
  99. auto old =
  100. ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
  101. union {
  102. uint i;
  103. bool t[4];
  104. } val;
  105. do {
  106. val.i = old;
  107. val.t[offset & 3] = true;
  108. } while (!::metal::atomic_compare_exchange_weak_explicit(
  109. ptr,
  110. &old,
  111. val.i,
  112. ::metal::memory_order_relaxed,
  113. ::metal::memory_order_relaxed));
  114. }
  115. };
  116. // ComplexHalf atomic op
  117. template <>
  118. struct AtomicType<half2> {
  119. using type = ::metal::atomic<uint>;
  120. static inline void atomic_add(device type* data, long offset, half2 value) {
  121. auto ptr = data + offset;
  122. auto old =
  123. ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
  124. while (!::metal::atomic_compare_exchange_weak_explicit(
  125. ptr,
  126. &old,
  127. as_type<uint>(as_type<half2>(old) + value),
  128. ::metal::memory_order_relaxed,
  129. ::metal::memory_order_relaxed))
  130. ;
  131. }
  132. };
  133. // There are no atomic 64-bit add in Metal yet, but templates below implements a
  134. // consistent add I.e. if multiple threads are modify the same 64-bit value,
  135. // results stored at the address will eventually be equal to its original value
  136. // plus sum of all operands
  137. template <>
  138. struct AtomicType<long> {
  139. using type = ::metal::atomic<uint>;
  140. static inline void atomic_add(device type* data, long offset, long value) {
  141. const auto value_bits = as_type<ulong>(value);
  142. const uint low = static_cast<uint>(value_bits);
  143. uint high = static_cast<uint>(value_bits >> 32);
  144. auto ptr = data + (offset << 1);
  145. auto old_low =
  146. atomic_fetch_add_explicit(ptr, low, ::metal::memory_order_relaxed);
  147. high += (old_low + low < old_low) ? 1 : 0;
  148. atomic_fetch_add_explicit(ptr + 1, high, ::metal::memory_order_relaxed);
  149. }
  150. };
  151. // ComplexFloat atomic op, which again is not really atomic, but eventually
  152. // consistent
  153. template <>
  154. struct AtomicType<float2> {
  155. using type = ::metal::atomic<float>;
  156. static inline void atomic_add(device type* data, long offset, float2 value) {
  157. auto ptr = data + (offset << 1);
  158. atomic_fetch_add_explicit(ptr + 0, value.x, ::metal::memory_order_relaxed);
  159. atomic_fetch_add_explicit(ptr + 1, value.y, ::metal::memory_order_relaxed);
  160. }
  161. };
  162. } // namespace metal
  163. } // namespace c10