@@ -503,12 +503,14 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
503503 using namespace mps ;
504504 MPSStream* stream = getCurrentMPSStream ();
505505 dim = maybe_wrap_dim (dim, self.dim ());
506- if (index.numel () == 0 ) {
506+ auto numel = index.numel ();
507+
508+ if (numel == 0 ) {
507509 return ;
508510 }
509511
510- TORCH_CHECK (source .scalar_type () != ScalarType::Long, " index_add(): Expected non int64 dtype for source. " );
511- auto casted_type = isFloatingType (source. scalar_type ()) ? ScalarType::Float : ScalarType::Int ;
512+ TORCH_CHECK (self .scalar_type () != ScalarType::Long,
513+ " MPS: does not support index_add op with int64 input " ) ;
512514
513515 struct CachedGraph : public MPSCachedGraph
514516 {
@@ -538,26 +540,46 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
538540 MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder (mpsGraph, self);
539541 MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder (mpsGraph, index);
540542 MPSGraphTensor* sourceTensor = mpsGraphRankedPlaceHolder (mpsGraph, source);
541- MPSGraphTensor* alphaTensor = mpsGraphScalarPlaceHolder (mpsGraph, getMPSScalarType (casted_type));
542- MPSGraphTensor* castedInputTensor = inputTensor;
543- MPSGraphTensor* castedSourceTensor = sourceTensor;
544- if (source.scalar_type () != casted_type) {
545- castedInputTensor = castMPSTensor (mpsGraph, castedInputTensor, casted_type);
546- castedSourceTensor = castMPSTensor (mpsGraph, castedSourceTensor, casted_type);
543+ MPSGraphTensor* alphaTensor = mpsGraphScalarPlaceHolder (mpsGraph, getMPSScalarType (self.scalar_type ()));
544+
545+ MPSGraphTensor* castInputTensor = inputTensor;
546+ MPSGraphTensor* castSourceTensor = sourceTensor;
547+ MPSGraphTensor* castAlphaTensor = alphaTensor;
548+
549+ MPSDataType dataType = [inputTensor dataType ];
550+
551+ // Workaround for issue #104289647 (wrong results from scatterWithDataTensor): cast operands that are not Int32/Float32 to one of those types before scattering.
552+ if (dataType != MPSDataTypeInt32 &&
553+ dataType != MPSDataTypeFloat32) {
554+ dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
555+ castInputTensor = [mpsGraph castTensor: inputTensor
556+ toType: dataType
557+ name: @" castInputTensor" ];
558+ castSourceTensor = [mpsGraph castTensor: sourceTensor
559+ toType: dataType
560+ name: @" castSourceTensor" ];
561+ castAlphaTensor = [mpsGraph castTensor: alphaTensor
562+ toType: dataType
563+ name: @" castAlphaTensor" ];
547564 }
548- MPSGraphTensor* alphaSourceSlice = [mpsGraph multiplicationWithPrimaryTensor: castedSourceTensor
549- secondaryTensor: alphaTensor
550- name: nil ];
551565
552- MPSGraphTensor* outputTensor = [mpsGraph scatterWithDataTensor: castedInputTensor
566+ MPSGraphTensor* alphaSourceSlice = [mpsGraph multiplicationWithPrimaryTensor: castSourceTensor
567+ secondaryTensor: castAlphaTensor
568+ name: nil ];
569+ MPSGraphTensor* outputTensor = [mpsGraph scatterWithDataTensor: castInputTensor
553570 updatesTensor: alphaSourceSlice
554571 indicesTensor: indexTensor
555572 axis: dim
556573 mode: MPSGraphScatterModeAdd
557574 name: nil ];
558- if (source.scalar_type () != casted_type) {
559- outputTensor = castMPSTensor (mpsGraph, outputTensor, source.scalar_type ());
560- }
575+ dataType = [inputTensor dataType ];
576+ if (dataType != MPSDataTypeInt32 &&
577+ dataType != MPSDataTypeFloat32) {
578+ outputTensor = [mpsGraph castTensor: outputTensor
579+ toType: [inputTensor dataType ]
580+ name: @" castOutputTensor" ];
581+ }
582+
561583 newCachedGraph->inputTensor_ = inputTensor;
562584 newCachedGraph->indexTensor_ = indexTensor;
563585 newCachedGraph->sourceTensor_ = sourceTensor;
@@ -572,7 +594,7 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
572594 Placeholder indexPlaceholder = Placeholder (cachedGraph->indexTensor_ , index);
573595 Placeholder sourcePlaceholder = Placeholder (cachedGraph->sourceTensor_ , source);
574596 Placeholder outputPlaceholder = Placeholder (cachedGraph->outputTensor_ , result);
575- MPSScalar alpha_scalar = getMPSScalar (alpha, casted_type );
597+ MPSScalar alpha_scalar = getMPSScalar (alpha, self. scalar_type () );
576598
577599 NSDictionary <MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
578600 selfPlaceholder.getMPSGraphTensor () : selfPlaceholder.getMPSGraphTensorData (),
0 commit comments