
Commit 5974b4b

Ronian526 authored and kulinseth committed
Fix index_add type issue (#235)
* change the alphaTensor type to use self.scalar_type()
* remove index_add from the blocklist
* block bool, int16, int64, float16, and uint8, since scatterWithDataTensor gives wrong results for them
* cast tensors so that scatterWithDataTensor only ever sees float32 and int32
* throw an error for the unsupported int64 index_add case
1 parent e000329 commit 5974b4b
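For context, a minimal sketch (not part of the commit) of the behavior this change addresses, assuming a PyTorch build with MPS available. The int16 case exercises the new cast-to-int32 path; the int64 case hits the new TORCH_CHECK. Expected values follow from the diff below.

```python
import torch

if torch.backends.mps.is_available():
    index = torch.tensor([0, 2, 4], device="mps")

    # int16 previously triggered the scatterWithDataTensor bug; after this
    # commit it is cast to int32 internally and back, matching CPU results.
    x = torch.zeros(5, dtype=torch.int16, device="mps")
    src = torch.ones(3, dtype=torch.int16, device="mps")
    x.index_add_(0, index, src)
    print(x.cpu())  # expected: tensor([1, 0, 1, 0, 1], dtype=torch.int16)

    # int64 input now raises instead of silently producing wrong results.
    try:
        y = torch.zeros(5, dtype=torch.int64, device="mps")
        y.index_add_(0, index, torch.ones(3, dtype=torch.int64, device="mps"))
    except RuntimeError as err:
        print(err)  # "MPS: does not support index_add op with int64 input"
```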

File tree

- aten/src/ATen/native/mps/operations/Indexing.mm
- test/test_mps.py

2 files changed: +41 / -19 lines

aten/src/ATen/native/mps/operations/Indexing.mm

Lines changed: 39 additions & 17 deletions
```diff
@@ -503,12 +503,14 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
   using namespace mps;
   MPSStream* stream = getCurrentMPSStream();
   dim = maybe_wrap_dim(dim, self.dim());
-  if (index.numel() == 0) {
+  auto numel = index.numel();
+
+  if (numel == 0) {
     return;
   }
 
-  TORCH_CHECK(source.scalar_type() != ScalarType::Long, "index_add(): Expected non int64 dtype for source.");
-  auto casted_type = isFloatingType(source.scalar_type()) ? ScalarType::Float : ScalarType::Int;
+  TORCH_CHECK(self.scalar_type() != ScalarType::Long,
+              "MPS: does not support index_add op with int64 input");
 
   struct CachedGraph : public MPSCachedGraph
   {
@@ -538,26 +540,46 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
       MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index);
       MPSGraphTensor* sourceTensor = mpsGraphRankedPlaceHolder(mpsGraph, source);
-      MPSGraphTensor* alphaTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSScalarType(casted_type));
-      MPSGraphTensor* castedInputTensor = inputTensor;
-      MPSGraphTensor* castedSourceTensor = sourceTensor;
-      if (source.scalar_type() != casted_type) {
-        castedInputTensor = castMPSTensor(mpsGraph, castedInputTensor, casted_type);
-        castedSourceTensor = castMPSTensor(mpsGraph, castedSourceTensor, casted_type);
+      MPSGraphTensor* alphaTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()));
+
+      MPSGraphTensor* castInputTensor = inputTensor;
+      MPSGraphTensor* castSourceTensor = sourceTensor;
+      MPSGraphTensor* castAlphaTensor = alphaTensor;
+
+      MPSDataType dataType = [inputTensor dataType];
+
+      // failure due to issue #104289647: Wrong results from scatterWithDataTensor
+      if (dataType != MPSDataTypeInt32 &&
+          dataType != MPSDataTypeFloat32) {
+        dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
+        castInputTensor = [mpsGraph castTensor:inputTensor
+                                        toType:dataType
+                                          name:@"castInputTensor"];
+        castSourceTensor = [mpsGraph castTensor:sourceTensor
+                                         toType:dataType
+                                           name:@"castSourceTensor"];
+        castAlphaTensor = [mpsGraph castTensor:alphaTensor
+                                        toType:dataType
+                                          name:@"castAlphaTensor"];
       }
-      MPSGraphTensor* alphaSourceSlice = [mpsGraph multiplicationWithPrimaryTensor:castedSourceTensor
-                                                                   secondaryTensor:alphaTensor
-                                                                              name:nil];
 
-      MPSGraphTensor* outputTensor = [mpsGraph scatterWithDataTensor:castedInputTensor
+      MPSGraphTensor* alphaSourceSlice = [mpsGraph multiplicationWithPrimaryTensor:castSourceTensor
+                                                                   secondaryTensor:castAlphaTensor
+                                                                              name:nil];
+      MPSGraphTensor* outputTensor = [mpsGraph scatterWithDataTensor:castInputTensor
                                                        updatesTensor:alphaSourceSlice
                                                        indicesTensor:indexTensor
                                                                 axis:dim
                                                                 mode:MPSGraphScatterModeAdd
                                                                 name:nil];
-      if (source.scalar_type() != casted_type) {
-        outputTensor = castMPSTensor(mpsGraph, outputTensor, source.scalar_type());
-      }
+      dataType = [inputTensor dataType];
+      if (dataType != MPSDataTypeInt32 &&
+          dataType != MPSDataTypeFloat32) {
+        outputTensor = [mpsGraph castTensor:outputTensor
+                                     toType:[inputTensor dataType]
+                                       name:@"castOutputTensor"];
+      }
+
       newCachedGraph->inputTensor_ = inputTensor;
       newCachedGraph->indexTensor_ = indexTensor;
       newCachedGraph->sourceTensor_ = sourceTensor;
@@ -572,7 +594,7 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
   Placeholder indexPlaceholder = Placeholder(cachedGraph->indexTensor_, index);
   Placeholder sourcePlaceholder = Placeholder(cachedGraph->sourceTensor_, source);
   Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
-  MPSScalar alpha_scalar = getMPSScalar(alpha, casted_type);
+  MPSScalar alpha_scalar = getMPSScalar(alpha, self.scalar_type());
 
   NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
     selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
```
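The casting workaround above can be read, in rough pure-PyTorch terms, as the sketch below. `index_add_via_widening` is a hypothetical illustration of the same dtype promotion, not code from this commit: dtypes other than float32/int32 are widened (floating types to float32, everything else to int32), the scatter-add runs in the widened dtype, and the result is cast back.

```python
import torch

def index_add_via_widening(self, dim, index, source, alpha=1):
    """Hypothetical pure-PyTorch mirror of the MPSGraph workaround above."""
    # The new TORCH_CHECK: int64 input is rejected outright.
    if self.dtype == torch.int64:
        raise RuntimeError("MPS: does not support index_add op with int64 input")
    # scatterWithDataTensor is only trusted for float32/int32, so widen:
    # floating dtypes go to float32, integral (and bool) dtypes to int32.
    if self.dtype in (torch.float32, torch.int32):
        wide = self.dtype
    elif self.dtype.is_floating_point:
        wide = torch.float32
    else:
        wide = torch.int32
    out = self.to(wide).index_add(dim, index, source.to(wide), alpha=alpha)
    # Cast the result back to the caller's dtype, as the graph does.
    return out.to(self.dtype)
```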

test/test_mps.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -9238,7 +9238,8 @@ class TestConsistency(TestCase):
         'xlogy': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'zero_': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'zeros': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'zeros_like': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8']
+        'zeros_like': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'index_add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
     }
 
     ALLOWLIST_OP_GRAD = {
@@ -9421,7 +9422,6 @@ class TestConsistency(TestCase):
     # All the entries in this list should be removed
     BLOCKLIST = {
         # Functions that hard crash
-        'index_add': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
         'nn.functional.softplus': [torch.float32],
         'nonzero': [torch.bool, torch.uint8, torch.float16],
         'median': [torch.float32, torch.int16, torch.int32, torch.uint8, torch.int16],
```
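With `index_add` moved from BLOCKLIST to ALLOWLIST, TestConsistency compares the op's CPU and MPS results for those dtype codes. Below is a hedged sketch of that kind of check, assuming the usual mapping from the codes ('b8', 'f16', ...) to torch dtypes; note that, per the TORCH_CHECK in the diff above, the int64 ('i64') case may raise rather than run.

```python
import torch

# Assumed mapping from the allowlist's dtype codes to torch dtypes.
DTYPES = {'b8': torch.bool, 'f16': torch.float16, 'f32': torch.float32,
          'i16': torch.int16, 'i32': torch.int32, 'i64': torch.int64,
          'u8': torch.uint8}

def check_index_add(code):
    # Compare CPU and MPS results for one dtype, in the spirit of TestConsistency.
    dtype = DTYPES[code]
    x = torch.zeros(5, dtype=dtype)
    index = torch.tensor([0, 2, 4])
    src = torch.ones(3, dtype=dtype)
    cpu_out = torch.index_add(x, 0, index, src)
    mps_out = torch.index_add(x.to("mps"), 0, index.to("mps"), src.to("mps"))
    torch.testing.assert_close(cpu_out, mps_out.cpu())

check_index_add('f32')  # example: the float32 case should now match CPU
```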
