Skip to content

Commit c138e51

Browse files
DenisVieriu97 authored and kulinseth committed
Add support for native binary ops (#398)
* More fixes for strided kernels
* Fix type inference crash
* Address comments
* More fixes
* Remove logs
* Clean up
* More clean up
* More clean up #2
* More clean up #2
* More clean up #3
* Fix metal version
* Use native binary kernels
* Fix build failure & add more ops
* Fix lint
* Fix failing tests

---------

Co-authored-by: Kulin Seth <kulin_seth@apple.com>
1 parent cb89e79 commit c138e51

File tree

9 files changed

+585
-100
lines changed

9 files changed

+585
-100
lines changed

aten/src/ATen/mps/IndexKernels.h

Lines changed: 97 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,24 @@
33
namespace at {
44
namespace mps {
55

6-
static const char * indexing_metal_shaders = R"INDEX_METAL(
6+
// GET_IDX_TEMPLATE expands to a sequence of adjacent string literals holding
// the Metal source of a `get_idx` helper. The helper maps a linear thread id
// to a uint3 of offsets: for each of `num_dimensions` dimensions it takes
// idx % iter_shape[dim], divides idx by iter_shape[dim], and accumulates
// remainder * strides[dim] into the result (so dimension 0 varies fastest).
// The braces are written doubled ("{{" / "}}"); presumably this lets the same
// text be passed through a format-style API that escapes braces this way --
// TODO confirm. When the text is pasted verbatim, as in indexing_metal_shaders
// below, the doubled braces are simply nested blocks.
// NOTE(review): the kernels that call get_idx add the .x/.y components to
// char pointers and divide .z by sizeof(int64_t), so the strides look like
// byte strides -- confirm against the host-side code that fills the buffer.
#define GET_IDX_TEMPLATE \
7+
"static inline uint3 get_idx( " \
8+
" uint tid, " \
9+
" constant uint * iter_shape, " \
10+
" const uint num_dimensions, " \
11+
" constant packed_uint3 * strides) {{ " \
12+
" uint3 data_offsets = 0; " \
13+
" uint32_t idx = tid; " \
14+
" for (uint32_t dim = 0; dim < num_dimensions; dim++) {{ " \
15+
" uint32_t remainder = idx % iter_shape[dim]; " \
16+
" idx /= iter_shape[dim]; " \
17+
" data_offsets += remainder * strides[dim]; " \
18+
" }} " \
19+
" return data_offsets; " \
20+
"}}"
21+
22+
static const char * indexing_metal_shaders = GET_IDX_TEMPLATE
23+
R"INDEX_METAL(
724
#include <metal_stdlib>
825
#include <metal_atomic>
926
@@ -18,7 +35,6 @@ struct IndexAB {
1835
struct IndexAB {
1936
constant int64_t* indexArray;
2037
};
21-
2238
#endif
2339
2440
template<typename T>
@@ -30,11 +46,17 @@ kernel void index_select(
3046
#endif
3147
constant void * indexSizes [[buffer(1)]],
3248
constant void * indexStrides [[buffer(2)]],
33-
constant uint3 * offsets [[buffer(3)]],
3449
constant void * inputData [[buffer(4)]],
3550
device void * outputData [[buffer(5)]],
3651
constant uint32_t & num_indices [[buffer(6)]],
52+
constant uint * iter_shape [[buffer(7)]],
53+
constant uint & num_dimensions [[buffer(8)]],
54+
constant packed_uint3 * strides [[buffer(9)]],
55+
3756
uint thread_index [[thread_position_in_grid]]) {
57+
58+
uint3 offsets = get_idx(thread_index, iter_shape, num_dimensions, strides);
59+
3860
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
3961
constant int64_t * index_strides = (constant int64_t *)indexStrides;
4062
int64_t offset = 0;
@@ -44,14 +66,14 @@ kernel void index_select(
4466
#else
4567
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
4668
#endif
47-
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
69+
int64_t index = indexArray[offsets.z / sizeof(int64_t)];
4870
if (index < 0) {
4971
index += index_sizes[i];
5072
}
5173
offset += index * index_strides[i];
5274
}
53-
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x);
54-
constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y + offset);
75+
device T * out = (device T*)((device char*)outputData + offsets.x);
76+
constant T * in = (constant T*)((constant char*)inputData + offsets.y + offset);
5577
*out = *in;
5678
}
5779
@@ -64,12 +86,19 @@ kernel void index_put(
6486
#endif
6587
constant void * indexSizes [[buffer(1)]],
6688
constant void * indexStrides [[buffer(2)]],
67-
constant uint3 * offsets [[buffer(3)]],
6889
constant void * inputData [[buffer(4)]],
6990
device void * outputData [[buffer(5)]],
7091
constant uint32_t & num_indices [[buffer(6)]],
92+
93+
constant uint * iter_shape [[buffer(7)]],
94+
constant uint & num_dimensions [[buffer(8)]],
95+
constant packed_uint3 * strides [[buffer(9)]],
96+
7197
uint thread_index [[thread_position_in_grid]]) {
7298
99+
uint3 offsets = get_idx(thread_index, iter_shape, num_dimensions, strides);
100+
101+
73102
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
74103
constant int64_t * index_strides = (constant int64_t *)indexStrides;
75104
int64_t offset = 0;
@@ -79,15 +108,15 @@ kernel void index_put(
79108
#else
80109
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
81110
#endif
82-
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
111+
int64_t index = indexArray[offsets.z / sizeof(int64_t)];
83112
84113
if (index < 0) {
85114
index += index_sizes[i];
86115
}
87116
offset += index * index_strides[i];
88117
}
89-
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
90-
constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
118+
device T * out = (device T*)((device char*)outputData + offsets.x + offset);
119+
constant T * in = (constant T*)((constant char*)inputData + offsets.y);
91120
*out = *in;
92121
}
93122
@@ -96,26 +125,30 @@ kernel void index_put(
96125
template \
97126
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE)]] \
98127
kernel void index_ ## INDEX_OP_TYPE<DTYPE>( \
99-
constant IndexAB & indexAB [[buffer(0)]], \
100-
constant void * indexSizes [[buffer(1)]], \
101-
constant void * indexStrides [[buffer(2)]], \
102-
constant uint3 * offsets [[buffer(3)]], \
103-
constant void * inputData [[buffer(4)]], \
104-
device void * outputData [[buffer(5)]], \
105-
constant uint32_t & num_indices [[buffer(6)]], \
128+
constant IndexAB & indexAB [[buffer(0)]], \
129+
constant void * indexSizes [[buffer(1)]], \
130+
constant void * indexStrides [[buffer(2)]], \
131+
constant void * inputData [[buffer(4)]], \
132+
device void * outputData [[buffer(5)]], \
133+
constant uint32_t & num_indices [[buffer(6)]], \
134+
constant uint * iter_shape [[buffer(7)]], \
135+
constant uint & num_dimensions [[buffer(8)]], \
136+
constant packed_uint3 * strides [[buffer(9)]], \
106137
uint thread_index [[thread_position_in_grid]]);
107138
#else
108139
#define REGISTER_INDEX_OP(DTYPE_SIZE, DTYPE, INDEX_OP_TYPE) \
109140
template \
110141
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE)]] \
111142
kernel void index_ ## INDEX_OP_TYPE<DTYPE>( \
112-
constant IndexAB * indexAB [[buffer(0)]], \
113-
constant void * indexSizes [[buffer(1)]], \
114-
constant void * indexStrides [[buffer(2)]], \
115-
constant uint3 * offsets [[buffer(3)]], \
116-
constant void * inputData [[buffer(4)]], \
117-
device void * outputData [[buffer(5)]], \
143+
constant IndexAB * indexAB [[buffer(0)]], \
144+
constant void * indexSizes [[buffer(1)]], \
145+
constant void * indexStrides [[buffer(2)]], \
146+
constant void * inputData [[buffer(4)]], \
147+
device void * outputData [[buffer(5)]], \
118148
constant uint32_t & num_indices [[buffer(6)]], \
149+
constant uint * iter_shape [[buffer(7)]], \
150+
constant uint & num_dimensions [[buffer(8)]], \
151+
constant packed_uint3 * strides [[buffer(9)]], \
119152
uint thread_index [[thread_position_in_grid]]);
120153
#endif
121154
@@ -147,17 +180,20 @@ kernel void kernel_index_offsets(constant packed_uint3 * strides [[buffe
147180
template<typename T, typename E>
148181
kernel void index_put_accumulate_native_dtypes(
149182
#if __METAL_VERSION__ >= 300
150-
constant IndexAB * indexAB [[buffer(0)]],
183+
constant IndexAB * indexAB [[buffer(0)]],
151184
#else
152-
constant IndexAB & indexAB [[buffer(0)]],
185+
constant IndexAB & indexAB [[buffer(0)]],
153186
#endif
154-
constant void * indexSizes [[buffer(1)]],
155-
constant void * indexStrides [[buffer(2)]],
156-
constant uint3 * offsets [[buffer(3)]],
157-
constant void * inputData [[buffer(4)]],
158-
device void * outputData [[buffer(5)]],
159-
constant uint32_t& num_indices [[buffer(6)]],
187+
constant void * indexSizes [[buffer(1)]],
188+
constant void * indexStrides [[buffer(2)]],
189+
constant void * inputData [[buffer(4)]],
190+
device void * outputData [[buffer(5)]],
191+
constant uint32_t & num_indices [[buffer(6)]],
192+
constant uint * iter_shape [[buffer(7)]],
193+
constant uint & num_dimensions [[buffer(8)]],
194+
constant packed_uint3 * strides [[buffer(9)]],
160195
uint thread_index [[thread_position_in_grid]]) {
196+
uint3 offsets = get_idx(thread_index, iter_shape, num_dimensions, strides);
161197
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
162198
constant int64_t * index_strides = (constant int64_t *)indexStrides;
163199
int64_t offset = 0;
@@ -167,14 +203,14 @@ kernel void index_put_accumulate_native_dtypes(
167203
#else
168204
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
169205
#endif
170-
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
206+
int64_t index = indexArray[offsets.z / sizeof(int64_t)];
171207
if (index < 0) {
172208
index += index_sizes[i];
173209
}
174210
offset += index * index_strides[i];
175211
}
176-
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
177-
constant E * in = (constant E*)((constant char*)inputData + offsets[thread_index].y);
212+
device T * out = (device T*)((device char*)outputData + offsets.x + offset);
213+
constant E * in = (constant E*)((constant char*)inputData + offsets.y);
178214
atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
179215
}
180216
@@ -191,17 +227,20 @@ __attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * a
191227
template<typename T>
192228
kernel void atomic_index_put_accumulate(
193229
#if __METAL_VERSION__ >= 300
194-
constant IndexAB * indexAB [[buffer(0)]],
230+
constant IndexAB * indexAB [[buffer(0)]],
195231
#else
196-
constant IndexAB & indexAB [[buffer(0)]],
232+
constant IndexAB & indexAB [[buffer(0)]],
197233
#endif
198-
constant void * indexSizes [[buffer(1)]],
199-
constant void * indexStrides [[buffer(2)]],
200-
constant uint3 * offsets [[buffer(3)]],
201-
constant void * inputData [[buffer(4)]],
202-
device void * outputData [[buffer(5)]],
203-
constant uint32_t& num_indices [[buffer(6)]],
234+
constant void * indexSizes [[buffer(1)]],
235+
constant void * indexStrides [[buffer(2)]],
236+
constant void * inputData [[buffer(4)]],
237+
device void * outputData [[buffer(5)]],
238+
constant uint32_t & num_indices [[buffer(6)]],
239+
constant uint * iter_shape [[buffer(7)]],
240+
constant uint & num_dimensions [[buffer(8)]],
241+
constant packed_uint3 * strides [[buffer(9)]],
204242
uint thread_index [[thread_position_in_grid]]) {
243+
uint3 offsets = get_idx(thread_index, iter_shape, num_dimensions, strides);
205244
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
206245
constant int64_t * index_strides = (constant int64_t *)indexStrides;
207246
int64_t offset = 0;
@@ -211,14 +250,14 @@ kernel void atomic_index_put_accumulate(
211250
#else
212251
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
213252
#endif
214-
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
253+
int64_t index = indexArray[offsets.z / sizeof(int64_t)];
215254
if (index < 0) {
216255
index += index_sizes[i];
217256
}
218257
offset += index * index_strides[i];
219258
}
220-
device void * out = (device void*)((device char*)outputData + offsets[thread_index].x + offset);
221-
constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
259+
device void * out = (device void*)((device char*)outputData + offsets.x + offset);
260+
constant T * in = (constant T*)((constant char*)inputData + offsets.y);
222261
atomic_fetch_add_relaxed<T>(out, *in);
223262
}
224263
@@ -232,26 +271,30 @@ kernel void atomic_index_put_accumulate<float>(
232271
#endif
233272
constant void * indexSizes [[buffer(1)]],
234273
constant void * indexStrides [[buffer(2)]],
235-
constant uint3 * offsets [[buffer(3)]],
236274
constant void * inputData [[buffer(4)]],
237275
device void * outputData [[buffer(5)]],
238276
constant uint32_t& num_indices [[buffer(6)]],
277+
constant uint * iter_shape [[buffer(7)]],
278+
constant uint & num_dimensions [[buffer(8)]],
279+
constant packed_uint3 * strides [[buffer(9)]],
239280
uint thread_index [[thread_position_in_grid]]);
240281
241282
template
242283
[[host_name("index_put_accumulate_32bit_int")]]
243284
kernel void index_put_accumulate_native_dtypes<atomic_int, int>(
244285
#if __METAL_VERSION__ >= 300
245-
constant IndexAB * indexAB [[buffer(0)]],
286+
constant IndexAB * indexAB [[buffer(0)]],
246287
#else
247-
constant IndexAB & indexAB [[buffer(0)]],
288+
constant IndexAB & indexAB [[buffer(0)]],
248289
#endif
249-
constant void * indexSizes [[buffer(1)]],
250-
constant void * indexStrides [[buffer(2)]],
251-
constant uint3 * offsets [[buffer(3)]],
252-
constant void * inputData [[buffer(4)]],
253-
device void * outputData [[buffer(5)]],
254-
constant uint32_t& num_indices [[buffer(6)]],
290+
constant void * indexSizes [[buffer(1)]],
291+
constant void * indexStrides [[buffer(2)]],
292+
constant void * inputData [[buffer(4)]],
293+
device void * outputData [[buffer(5)]],
294+
constant uint32_t& num_indices [[buffer(6)]],
295+
constant uint * iter_shape [[buffer(7)]],
296+
constant uint & num_dimensions [[buffer(8)]],
297+
constant packed_uint3 * strides [[buffer(9)]],
255298
uint thread_index [[thread_position_in_grid]]);
256299
)INDEX_METAL";
257300

aten/src/ATen/native/mps/OperationUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArray
6161
std::string getMPSShapeString(MPSShape* shape);
6262
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
6363
std::string getArrayRefString(const IntArrayRef s);
64+
const std::string& getMetalScalarType(const Tensor& t);
65+
const std::string& getMetalScalarType(const c10::ScalarType& scalar_type);
6466
// use has_storage() on the returned tensor to determine if src actually is a view
6567
Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst);
6668
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output);
@@ -223,7 +225,6 @@ struct MPSGraphCache
223225
}
224226

225227
MPSCachedGraph* LookUp(const std::string& key) const {
226-
227228
__block MPSCachedGraph* cachedGraph = nullptr;
228229

229230
MPSCacheKey hash = std::hash<std::string>{}(key);

aten/src/ATen/native/mps/OperationUtils.mm

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,27 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
223223
return ss.str();
224224
}
225225

226+
// Returns the name of the Metal scalar type that corresponds to `scalar_type`
// (for example c10::ScalarType::Byte -> "uchar").
//
// @param scalar_type  The tensor element type to translate.
// @return             Reference to a string owned by a function-local static
//                     table; it stays valid for the lifetime of the program.
// @throws             Raises via TORCH_CHECK when `scalar_type` has no entry
//                     in the table below.
const std::string& getMetalScalarType(const c10::ScalarType& scalar_type) {
  // The table is never modified after construction, so keep it const; this
  // also guarantees the returned references can never be invalidated.
  static const std::unordered_map<c10::ScalarType, std::string> scalarToMetalType = {
    {c10::ScalarType::Float, "float"},
    {c10::ScalarType::Half,  "half"},
    {c10::ScalarType::Long,  "long"},
    {c10::ScalarType::Int,   "int"},
    {c10::ScalarType::Short, "short"},
    {c10::ScalarType::Char,  "char"},
    {c10::ScalarType::Byte,  "uchar"},
    {c10::ScalarType::Bool,  "bool"},
  };

  const auto it = scalarToMetalType.find(scalar_type);
  // The lookup is keyed by scalar type, not by byte size, so say so in the
  // diagnostic instead of the previous "Unsupported type byte size" wording.
  TORCH_CHECK(it != scalarToMetalType.end(), "Unsupported scalar type for Metal kernels: ", scalar_type);
  return it->second;
}
242+
243+
const std::string& getMetalScalarType(const Tensor& t) {
244+
return getMetalScalarType(t.scalar_type());
245+
}
246+
226247
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
227248
std::string str;
228249
// The key format per tensor would look like ":Float32[1,1,1,10]:"
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// Copyright © 2023 Apple Inc.
2+
#pragma once
3+
4+
namespace at {
5+
namespace native {
6+
namespace mps {
7+
// Declaration only; the definition is not part of this header.
// Takes two input tensors (`self`, `other`), an `output` tensor, a scalar
// `alpha`, and the operation name `op_name`.
// NOTE(review): presumably this runs the binary op named by `op_name` through
// a native Metal kernel and the bool reports whether the op was handled that
// way -- TODO confirm against the implementation.
// NOTE(review): `output` is taken by const reference; presumably its storage
// is still written through the tensor handle -- verify.
bool dispatchNativeBinaryKernel(const Tensor& self,
8+
const Tensor& other,
9+
const Tensor& output,
10+
const Scalar& alpha,
11+
const std::string& op_name);
12+
} // namespace mps
13+
} // namespace native
14+
} // namespace at

0 commit comments

Comments
 (0)