Skip to content

Commit 9421fd8

Browse files
committed
Fix cast issue in scatter() with uint8 type (#178)
* Fix cast issue in scatter() with uint8 type. Also fix float16 for scatter_add().
* Remove the redundant line in scatter().
1 parent b6f836a commit 9421fd8

File tree

2 files changed

+30
-27
lines changed

2 files changed

+30
-27
lines changed

aten/src/ATen/native/mps/operations/ScatterGather.mm

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@
173173
// Do we need to slice into the src tensor?
174174
bool needSlice = false;
175175
bool inputNeedSlice = false;
176+
bool needsCast = false;
176177

177178
for (const auto i : c10::irange(num_input_dims)) {
178179
TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis")
@@ -184,8 +185,12 @@
184185
}
185186
TORCH_CHECK(reduce != "mean", "Scatter reduce mean mode not yet supported in MPS")
186187

187-
bool needsCast = isIntegralType(self.scalar_type(), true) &&
188-
(reduce != "set" || self.scalar_type() == ScalarType::Byte);
188+
MPSDataType src_type = getMPSDataType(src.scalar_type());
189+
if (reduce != "set" || self.scalar_type() == ScalarType::Byte) {
190+
src_type = isFloatingType(src.scalar_type()) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
191+
needsCast = true;
192+
}
193+
189194
string key = func_name + getTensorsStringKey({self, index, src, output}) + ":" + std::to_string(dim) + ":" + std::string(reduce);
190195
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
191196
if(!cachedGraph) {
@@ -196,55 +201,53 @@
196201
MPSGraph* mpsGraph = make_mps_graph();
197202
newCachedGraph = new CachedGraph(mpsGraph);
198203

199-
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()), input_shape);
200-
MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(index.scalar_type()), index_shape);
201-
MPSGraphTensor* srcTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(src.scalar_type()), src_shape);
204+
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
205+
MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index);
206+
MPSGraphTensor* srcTensor = mpsGraphRankedPlaceHolder(mpsGraph, src);
202207

203-
MPSGraphTensor* getSrc = srcTensor;
204-
MPSGraphTensor* getInput = inputTensor;
208+
MPSGraphTensor* outputTensor = nil;
209+
MPSGraphTensor* castSrcTensor = srcTensor;
210+
MPSGraphTensor* castInputTensor = inputTensor;
211+
212+
if (needsCast) {
213+
castSrcTensor = [mpsGraph castTensor:srcTensor toType:src_type name:@"cast"];
214+
castInputTensor = [mpsGraph castTensor:inputTensor toType:src_type name:@"cast"];
215+
}
216+
MPSGraphTensor* castIndexTensor = [mpsGraph castTensor:indexTensor toType:MPSDataTypeInt32 name:@"cast"];
217+
218+
MPSGraphTensor* slicedSrc = castSrcTensor;
219+
MPSGraphTensor* slicedInput = castInputTensor;
205220

206221
// Use in case input needs to be smaller to get scatter
207-
NSMutableArray<NSNumber*>* scatterInputShape = [NSMutableArray arrayWithArray:input_shape];;
222+
NSMutableArray<NSNumber*>* scatterInputShape = [NSMutableArray arrayWithArray:input_shape];
208223

209-
// Slice into the src tensor IF NEEDED
224+
// Slice into the src or input tensors IF NEEDED
210225
if (needSlice || inputNeedSlice) {
211226
NSMutableArray<NSNumber*> *starts = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
212227
NSMutableArray<NSNumber*> *strides = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
213228
NSMutableArray<NSNumber*> *ends_src = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
214229

215230
for (const auto i : c10::irange(num_input_dims)) {
216-
// All strides are 1
217231
strides[i] = @1;
218-
// All starts are 0
219232
starts[i] = @0;
220233
ends_src[i] = index_shape[i];
221234
scatterInputShape[i] = (i != dim) ? index_shape[i] : input_shape[i];
222235
}
223236
if (needSlice) {
224-
getSrc = [mpsGraph sliceTensor:srcTensor
237+
slicedSrc = [mpsGraph sliceTensor:castSrcTensor
225238
starts:starts
226239
ends:ends_src
227240
strides:strides
228241
name:nil];
229242
}
230243
if (inputNeedSlice) {
231-
getInput = [mpsGraph sliceTensor:inputTensor
244+
slicedInput = [mpsGraph sliceTensor:castInputTensor
232245
starts:starts
233246
ends:scatterInputShape
234247
strides:strides
235248
name:nil];
236249
}
237250
}
238-
MPSGraphTensor* outputTensor = nil;
239-
MPSGraphTensor* castSrcTensor = getSrc;
240-
MPSGraphTensor* castInputTensor = getInput;
241-
242-
if (needsCast) {
243-
castSrcTensor = castMPSTensor(mpsGraph, getSrc, ScalarType::Int);
244-
castInputTensor = castMPSTensor(mpsGraph, getInput, ScalarType::Int);
245-
}
246-
MPSGraphTensor* castIndexTensor = castMPSTensor(mpsGraph, indexTensor, ScalarType::Int);
247-
248251
MPSGraphScatterMode scatter_mode = MPSGraphScatterModeSet;
249252

250253
if(reduce == "sum" || reduce == "add")
@@ -258,8 +261,8 @@
258261

259262
// Scatter this into the input with set mode
260263
MPSGraphTensor* scatterTensor = [mpsGraph scatterAlongAxis: (NSInteger) dim
261-
withDataTensor: castInputTensor
262-
updatesTensor: castSrcTensor
264+
withDataTensor: slicedInput
265+
updatesTensor: slicedSrc
263266
indicesTensor: castIndexTensor
264267
mode: scatter_mode
265268
name: nil];
@@ -301,7 +304,7 @@
301304
withShape:@[@-1]
302305
name:nil];
303306

304-
outputTensor = [mpsGraph scatterNDWithDataTensor:inputTensor
307+
outputTensor = [mpsGraph scatterNDWithDataTensor:castInputTensor
305308
updatesTensor:flatValuesTensor
306309
indicesTensor:scatter_fullIndexTensor
307310
batchDimensions:0

test/test_mps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7647,7 +7647,7 @@ class TestConsistency(TestCase):
76477647
'round': ['f32', 'f16', 'i16', 'i32', 'i64'],
76487648
'rsqrt': ['b8', 'f32', 'i16', 'i32', 'u8'],
76497649
'scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
7650-
'scatter_add': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
7650+
'scatter_add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
76517651
'select_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
76527652
'sgn': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
76537653
'short': ['i16'],

0 commit comments

Comments
 (0)