IndexKernels.h 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. #pragma once
  namespace at::mps {

// Metal source template for strided *scatter* copies: thread `linear_index`
// reads element `linear_index` of the contiguous `src` buffer and writes it to
// the strided position in `dst` described by `size`/`stride`.
//
// Placeholders (substituted by the caller before the Metal source is compiled;
// braces that are literal Metal syntax are written doubled, which suggests
// fmt/std::format-style expansion -- NOTE(review): confirm against the caller):
//   {0}  source element type
//   {1}  destination element type
//   {2}  expression returned by cast<{1}, {0}>, presumably converting the
//        parameter `x` to {1}
//
// Kernels:
//   scatter_kernel_n        arbitrary rank; `size`/`stride` are arrays of `ndim`
//                           entries and the destination offset is accumulated
//                           in 64-bit arithmetic.
//   scatter_kernel_4/3/2/1  fixed-rank variants; offsets are computed in 32-bit
//                           arithmetic (NOTE(review): presumably the caller only
//                           selects these when every offset fits in 32 bits --
//                           confirm).
// `linear_index` is decomposed row-major (the last dimension varies fastest).
// Threads with linear_index >= numel return immediately, so the dispatch may be
// rounded up past numel.
static const char* SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
template<typename Y, typename X>
Y cast(const X x);

template<>
{1} cast<{1}, {0}>(const {0} x) {{
  return {2};
}}

kernel void scatter_kernel_n(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant uint32_t * size       [[buffer(2)]],
                             constant uint32_t * stride     [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]],
                             constant int32_t & ndim        [[buffer(5)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  uint64_t dst_offs = 0;
  auto dst_idx = linear_index;
  for(int dim = ndim - 1; dim >= 0; --dim) {{
    // Widen before multiplying: a 32-bit product could wrap before it is
    // added to the 64-bit offset.
    dst_offs += static_cast<uint64_t>(stride[dim]) * (dst_idx % size[dim]);
    dst_idx /= size[dim];
  }}

  dst[dst_offs] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_4(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint4 & size   [[buffer(2)]],
                             constant packed_uint4 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint4 local_index;
  local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
  local_index.y = linear_index / (size[3] * size[2]) % size[1];
  local_index.z = linear_index / size[3] % size[2];
  local_index.w = linear_index % size[3];

  const packed_uint4 strided_index = local_index * stride;
  dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_3(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint3 & size   [[buffer(2)]],
                             constant packed_uint3 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint3 local_index;
  local_index.x = linear_index / (size[2] * size[1]) % size[0];
  local_index.y = linear_index / size[2] % size[1];
  local_index.z = linear_index % size[2];

  const packed_uint3 strided_index = local_index * stride;
  dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_2(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint2 & size   [[buffer(2)]],
                             constant packed_uint2 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint2 local_index;
  local_index.x = linear_index / size[1] % size[0];
  local_index.y = linear_index % size[1];

  const packed_uint2 strided_index = local_index * stride;
  dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_1(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant int & size            [[buffer(2)]],
                             constant int & stride          [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  const int local_index = linear_index % size;
  const int strided_index = local_index * stride;
  dst[strided_index] = cast<{1}>(src[linear_index]);
}}
)METAL_SCATTER";

// Metal source template for strided *gather* copies: thread `linear_index`
// reads the strided position in `src` described by `size`/`stride` and writes
// it to element `linear_index` of the contiguous `dst` buffer.
//
// Placeholders (substituted by the caller before the Metal source is compiled;
// braces that are literal Metal syntax are written doubled, which suggests
// fmt/std::format-style expansion -- NOTE(review): confirm against the caller):
//   {0}  source element type
//   {1}  destination element type
//   {2}  expression returned by cast<{1}, {0}>, presumably converting the
//        parameter `x` to {1}
//
// Kernels:
//   gather_kernel_n        arbitrary rank; `size`/`stride` are arrays of `ndim`
//                          entries and the source offset is accumulated in
//                          64-bit arithmetic.
//   gather_kernel_4/3/2/1  fixed-rank variants; offsets are computed in 32-bit
//                          arithmetic (NOTE(review): presumably the caller only
//                          selects these when every offset fits in 32 bits --
//                          confirm).
// `linear_index` is decomposed row-major (the last dimension varies fastest).
// Threads with linear_index >= numel return immediately, so the dispatch may be
// rounded up past numel.
static const char* GATHER_OPS_TEMPLATE = R"METAL_GATHER(
template<typename Y, typename X>
Y cast(const X x);

template<>
{1} cast<{1}, {0}>(const {0} x) {{
  return {2};
}}

kernel void gather_kernel_n(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant uint32_t * size        [[buffer(2)]],
                            constant uint32_t * stride      [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]],
                            constant int32_t & ndim         [[buffer(5)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  uint64_t src_offs = 0;
  auto src_idx = linear_index;
  for(int dim = ndim - 1; dim >= 0; --dim) {{
    // Widen before multiplying: a 32-bit product could wrap before it is
    // added to the 64-bit offset.
    src_offs += static_cast<uint64_t>(stride[dim]) * (src_idx % size[dim]);
    src_idx /= size[dim];
  }}

  dst[linear_index] = cast<{1}>(src[src_offs]);
}}

kernel void gather_kernel_4(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant packed_uint4 & size    [[buffer(2)]],
                            constant packed_uint4 & stride  [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint4 local_index;
  local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
  local_index.y = linear_index / (size[3] * size[2]) % size[1];
  local_index.z = linear_index / size[3] % size[2];
  local_index.w = linear_index % size[3];

  const packed_uint4 strided_index = local_index * stride;
  dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
}}

kernel void gather_kernel_3(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant packed_uint3 & size    [[buffer(2)]],
                            constant packed_uint3 & stride  [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint3 local_index;
  local_index.x = linear_index / (size[2] * size[1]) % size[0];
  local_index.y = linear_index / size[2] % size[1];
  local_index.z = linear_index % size[2];

  const packed_uint3 strided_index = local_index * stride;
  dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
}}

kernel void gather_kernel_2(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant packed_uint2 & size    [[buffer(2)]],
                            constant packed_uint2 & stride  [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint2 local_index;
  local_index.x = linear_index / size[1] % size[0];
  local_index.y = linear_index % size[1];

  const packed_uint2 strided_index = local_index * stride;
  dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
}}

kernel void gather_kernel_1(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant int & size             [[buffer(2)]],
                            constant int & stride           [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  const int local_index = linear_index % size;
  const int strided_index = local_index * stride;
  dst[linear_index] = cast<{1}>(src[strided_index]);
}}
)METAL_GATHER";

} // namespace at::mps