33namespace at {
44namespace mps {
55
6- static const char * indexing_metal_shaders = R"INDEX_METAL(
6+ #define GET_IDX_TEMPLATE \
7+ " static inline uint3 get_idx( " \
8+ " uint tid, " \
9+ " constant uint * iter_shape, " \
10+ " const uint num_dimensions, " \
11+ " constant packed_uint3 * strides) {{ " \
12+ " uint3 data_offsets = 0; " \
13+ " uint32_t idx = tid; " \
14+ " for (uint32_t dim = 0; dim < num_dimensions; dim++) {{ " \
15+ " uint32_t remainder = idx % iter_shape[dim]; " \
16+ " idx /= iter_shape[dim]; " \
17+ " data_offsets += remainder * strides[dim]; " \
18+ " }} " \
19+ " return data_offsets; " \
20+ " }}"
21+
22+ static const char * indexing_metal_shaders = GET_IDX_TEMPLATE
23+ R"INDEX_METAL(
724#include <metal_stdlib>
825#include <metal_atomic>
926
@@ -18,7 +35,6 @@ struct IndexAB {
1835struct IndexAB {
1936 constant int64_t* indexArray;
2037};
21-
2238#endif
2339
2440template<typename T>
@@ -30,11 +46,17 @@ kernel void index_select(
3046#endif
3147 constant void * indexSizes [[buffer(1)]],
3248 constant void * indexStrides [[buffer(2)]],
33- constant uint3 * offsets [[buffer(3)]],
3449 constant void * inputData [[buffer(4)]],
3550 device void * outputData [[buffer(5)]],
3651 constant uint32_t & num_indices [[buffer(6)]],
52+ constant uint * iter_shape [[buffer(7)]],
53+ constant uint & num_dimensions [[buffer(8)]],
54+ constant packed_uint3 * strides [[buffer(9)]],
55+
3756 uint thread_index [[thread_position_in_grid]]) {
57+
58+ uint3 offsets = get_idx(thread_index, iter_shape, num_dimensions, strides);
59+
3860 constant int64_t * index_sizes = (constant int64_t *)indexSizes;
3961 constant int64_t * index_strides = (constant int64_t *)indexStrides;
4062 int64_t offset = 0;
@@ -44,14 +66,14 @@ kernel void index_select(
4466#else
4567 constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
4668#endif
47- int64_t index = indexArray[offsets[thread_index] .z / sizeof(int64_t)];
69+ int64_t index = indexArray[offsets.z / sizeof(int64_t)];
4870 if (index < 0) {
4971 index += index_sizes[i];
5072 }
5173 offset += index * index_strides[i];
5274 }
53- device T * out = (device T*)((device char*)outputData + offsets[thread_index] .x);
54- constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index] .y + offset);
75+ device T * out = (device T*)((device char*)outputData + offsets.x);
76+ constant T * in = (constant T*)((constant char*)inputData + offsets.y + offset);
5577 *out = *in;
5678}
5779
@@ -64,12 +86,19 @@ kernel void index_put(
6486#endif
6587 constant void * indexSizes [[buffer(1)]],
6688 constant void * indexStrides [[buffer(2)]],
67- constant uint3 * offsets [[buffer(3)]],
6889 constant void * inputData [[buffer(4)]],
6990 device void * outputData [[buffer(5)]],
7091 constant uint32_t & num_indices [[buffer(6)]],
92+
93+ constant uint * iter_shape [[buffer(7)]],
94+ constant uint & num_dimensions [[buffer(8)]],
95+ constant packed_uint3 * strides [[buffer(9)]],
96+
7197 uint thread_index [[thread_position_in_grid]]) {
7298
99+ uint3 offsets = get_idx(thread_index, iter_shape, num_dimensions, strides);
100+
101+
73102 constant int64_t * index_sizes = (constant int64_t *)indexSizes;
74103 constant int64_t * index_strides = (constant int64_t *)indexStrides;
75104 int64_t offset = 0;
@@ -79,15 +108,15 @@ kernel void index_put(
79108#else
80109 constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
81110#endif
82- int64_t index = indexArray[offsets[thread_index] .z / sizeof(int64_t)];
111+ int64_t index = indexArray[offsets.z / sizeof(int64_t)];
83112
84113 if (index < 0) {
85114 index += index_sizes[i];
86115 }
87116 offset += index * index_strides[i];
88117 }
89- device T * out = (device T*)((device char*)outputData + offsets[thread_index] .x + offset);
90- constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index] .y);
118+ device T * out = (device T*)((device char*)outputData + offsets.x + offset);
119+ constant T * in = (constant T*)((constant char*)inputData + offsets.y);
91120 *out = *in;
92121}
93122
@@ -96,26 +125,30 @@ kernel void index_put(
96125template \
97126[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE)]] \
98127kernel void index_ ## INDEX_OP_TYPE<DTYPE>( \
99- constant IndexAB & indexAB [[buffer(0)]], \
100- constant void * indexSizes [[buffer(1)]], \
101- constant void * indexStrides [[buffer(2)]], \
102- constant uint3 * offsets [[buffer(3)]], \
103- constant void * inputData [[buffer(4)]], \
104- device void * outputData [[buffer(5)]], \
105- constant uint32_t & num_indices [[buffer(6)]], \
128+ constant IndexAB & indexAB [[buffer(0)]], \
129+ constant void * indexSizes [[buffer(1)]], \
130+ constant void * indexStrides [[buffer(2)]], \
131+ constant void * inputData [[buffer(4)]], \
132+ device void * outputData [[buffer(5)]], \
133+ constant uint32_t & num_indices [[buffer(6)]], \
134+ constant uint * iter_shape [[buffer(7)]], \
135+ constant uint & num_dimensions [[buffer(8)]], \
136+ constant packed_uint3 * strides [[buffer(9)]], \
106137 uint thread_index [[thread_position_in_grid]]);
107138#else
108139#define REGISTER_INDEX_OP(DTYPE_SIZE, DTYPE, INDEX_OP_TYPE) \
109140template \
110141[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE)]] \
111142kernel void index_ ## INDEX_OP_TYPE<DTYPE>( \
112- constant IndexAB * indexAB [[buffer(0)]], \
113- constant void * indexSizes [[buffer(1)]], \
114- constant void * indexStrides [[buffer(2)]], \
115- constant uint3 * offsets [[buffer(3)]], \
116- constant void * inputData [[buffer(4)]], \
117- device void * outputData [[buffer(5)]], \
143+ constant IndexAB * indexAB [[buffer(0)]], \
144+ constant void * indexSizes [[buffer(1)]], \
145+ constant void * indexStrides [[buffer(2)]], \
146+ constant void * inputData [[buffer(4)]], \
147+ device void * outputData [[buffer(5)]], \
118148 constant uint32_t & num_indices [[buffer(6)]], \
149+ constant uint * iter_shape [[buffer(7)]], \
150+ constant uint & num_dimensions [[buffer(8)]], \
151+ constant packed_uint3 * strides [[buffer(9)]], \
119152 uint thread_index [[thread_position_in_grid]]);
120153#endif
121154
@@ -147,17 +180,20 @@ kernel void kernel_index_offsets(constant packed_uint3 * strides [[buffe
147180template<typename T, typename E>
148181kernel void index_put_accumulate_native_dtypes(
149182#if __METAL_VERSION__ >= 300
150- constant IndexAB * indexAB [[buffer(0)]],
183+ constant IndexAB * indexAB [[buffer(0)]],
151184#else
152- constant IndexAB & indexAB [[buffer(0)]],
185+ constant IndexAB & indexAB [[buffer(0)]],
153186#endif
154- constant void * indexSizes [[buffer(1)]],
155- constant void * indexStrides [[buffer(2)]],
156- constant uint3 * offsets [[buffer(3)]],
157- constant void * inputData [[buffer(4)]],
158- device void * outputData [[buffer(5)]],
159- constant uint32_t& num_indices [[buffer(6)]],
187+ constant void * indexSizes [[buffer(1)]],
188+ constant void * indexStrides [[buffer(2)]],
189+ constant void * inputData [[buffer(4)]],
190+ device void * outputData [[buffer(5)]],
191+ constant uint32_t & num_indices [[buffer(6)]],
192+ constant uint * iter_shape [[buffer(7)]],
193+ constant uint & num_dimensions [[buffer(8)]],
194+ constant packed_uint3 * strides [[buffer(9)]],
160195 uint thread_index [[thread_position_in_grid]]) {
196+ uint3 offsets = get_idx(thread_index, iter_shape, num_dimensions, strides);
161197 constant int64_t * index_sizes = (constant int64_t *)indexSizes;
162198 constant int64_t * index_strides = (constant int64_t *)indexStrides;
163199 int64_t offset = 0;
@@ -167,14 +203,14 @@ kernel void index_put_accumulate_native_dtypes(
167203#else
168204 constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
169205#endif
170- int64_t index = indexArray[offsets[thread_index] .z / sizeof(int64_t)];
206+ int64_t index = indexArray[offsets.z / sizeof(int64_t)];
171207 if (index < 0) {
172208 index += index_sizes[i];
173209 }
174210 offset += index * index_strides[i];
175211 }
176- device T * out = (device T*)((device char*)outputData + offsets[thread_index] .x + offset);
177- constant E * in = (constant E*)((constant char*)inputData + offsets[thread_index] .y);
212+ device T * out = (device T*)((device char*)outputData + offsets.x + offset);
213+ constant E * in = (constant E*)((constant char*)inputData + offsets.y);
178214 atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
179215}
180216
@@ -191,17 +227,20 @@ __attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * a
191227template<typename T>
192228kernel void atomic_index_put_accumulate(
193229#if __METAL_VERSION__ >= 300
194- constant IndexAB * indexAB [[buffer(0)]],
230+ constant IndexAB * indexAB [[buffer(0)]],
195231#else
196- constant IndexAB & indexAB [[buffer(0)]],
232+ constant IndexAB & indexAB [[buffer(0)]],
197233#endif
198- constant void * indexSizes [[buffer(1)]],
199- constant void * indexStrides [[buffer(2)]],
200- constant uint3 * offsets [[buffer(3)]],
201- constant void * inputData [[buffer(4)]],
202- device void * outputData [[buffer(5)]],
203- constant uint32_t& num_indices [[buffer(6)]],
234+ constant void * indexSizes [[buffer(1)]],
235+ constant void * indexStrides [[buffer(2)]],
236+ constant void * inputData [[buffer(4)]],
237+ device void * outputData [[buffer(5)]],
238+ constant uint32_t & num_indices [[buffer(6)]],
239+ constant uint * iter_shape [[buffer(7)]],
240+ constant uint & num_dimensions [[buffer(8)]],
241+ constant packed_uint3 * strides [[buffer(9)]],
204242 uint thread_index [[thread_position_in_grid]]) {
243+ uint3 offsets = get_idx(thread_index, iter_shape, num_dimensions, strides);
205244 constant int64_t * index_sizes = (constant int64_t *)indexSizes;
206245 constant int64_t * index_strides = (constant int64_t *)indexStrides;
207246 int64_t offset = 0;
@@ -211,14 +250,14 @@ kernel void atomic_index_put_accumulate(
211250#else
212251 constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
213252#endif
214- int64_t index = indexArray[offsets[thread_index] .z / sizeof(int64_t)];
253+ int64_t index = indexArray[offsets.z / sizeof(int64_t)];
215254 if (index < 0) {
216255 index += index_sizes[i];
217256 }
218257 offset += index * index_strides[i];
219258 }
220- device void * out = (device void*)((device char*)outputData + offsets[thread_index] .x + offset);
221- constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index] .y);
259+ device void * out = (device void*)((device char*)outputData + offsets.x + offset);
260+ constant T * in = (constant T*)((constant char*)inputData + offsets.y);
222261 atomic_fetch_add_relaxed<T>(out, *in);
223262}
224263
@@ -232,26 +271,30 @@ kernel void atomic_index_put_accumulate<float>(
232271#endif
233272 constant void * indexSizes [[buffer(1)]],
234273 constant void * indexStrides [[buffer(2)]],
235- constant uint3 * offsets [[buffer(3)]],
236274 constant void * inputData [[buffer(4)]],
237275 device void * outputData [[buffer(5)]],
238276 constant uint32_t& num_indices [[buffer(6)]],
277+ constant uint * iter_shape [[buffer(7)]],
278+ constant uint & num_dimensions [[buffer(8)]],
279+ constant packed_uint3 * strides [[buffer(9)]],
239280 uint thread_index [[thread_position_in_grid]]);
240281
241282template
242283[[host_name("index_put_accumulate_32bit_int")]]
243284kernel void index_put_accumulate_native_dtypes<atomic_int, int>(
244285#if __METAL_VERSION__ >= 300
245- constant IndexAB * indexAB [[buffer(0)]],
286+ constant IndexAB * indexAB [[buffer(0)]],
246287#else
247- constant IndexAB & indexAB [[buffer(0)]],
288+ constant IndexAB & indexAB [[buffer(0)]],
248289#endif
249- constant void * indexSizes [[buffer(1)]],
250- constant void * indexStrides [[buffer(2)]],
251- constant uint3 * offsets [[buffer(3)]],
252- constant void * inputData [[buffer(4)]],
253- device void * outputData [[buffer(5)]],
254- constant uint32_t& num_indices [[buffer(6)]],
290+ constant void * indexSizes [[buffer(1)]],
291+ constant void * indexStrides [[buffer(2)]],
292+ constant void * inputData [[buffer(4)]],
293+ device void * outputData [[buffer(5)]],
294+ constant uint32_t& num_indices [[buffer(6)]],
295+ constant uint * iter_shape [[buffer(7)]],
296+ constant uint & num_dimensions [[buffer(8)]],
297+ constant packed_uint3 * strides [[buffer(9)]],
255298 uint thread_index [[thread_position_in_grid]]);
256299)INDEX_METAL" ;
257300
0 commit comments