diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm index 6edb7bdb2a3de..e6d25b82eae3b 100644 --- a/aten/src/ATen/native/mps/operations/BinaryOps.mm +++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm @@ -73,9 +73,16 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha // this type inference is only required at the time of graph creation ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), other.scalar_type()); - // Integer input must be cast to float if output is float - if (isIntegralType(common_dtype, true) && isFloatingType(output.scalar_type())) { - common_dtype = output_.scalar_type(); + if (isIntegralType(common_dtype, true)) { + // integer inputs must be cast to float, if output is float + if (isFloatingType(output_.scalar_type())) { + common_dtype = output_.scalar_type(); + // in boolean comparison ops with signed vs. unsigned integers, we always cast to the unsigned type + } else if (output_.scalar_type() == ScalarType::Bool && + (self.scalar_type() == ScalarType::Byte || + other.scalar_type() == ScalarType::Byte)) { + common_dtype = ScalarType::Byte; + } } if (self.scalar_type() != common_dtype) { primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype); @@ -230,16 +237,15 @@ void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alp name:nil]; }); \ } -// Boolean Ops require casting output to "MPSDataTypeBool" +// output of Boolean Ops will be cast to "MPSDataTypeBool" at the end of binaryOpTensor() #define CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(func_out, func_stub, other_type) \ TORCH_IMPL_FUNC(func_out) (const Tensor& self, const other_type& other, const Tensor& output) { \ mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \ ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \ MPSGraph* mpsGraph = cachedGraph->graph(); \ - MPSGraphTensor* outputTensor = [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \ - secondaryTensor:secondaryCastTensor \ - name:nil]; \ - return mps::castMPSTensor(mpsGraph, outputTensor, ScalarType::Bool); }); \ + return [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \ + secondaryTensor:secondaryCastTensor \ + name:nil]; }); \ } // Boolean Binary Ops diff --git a/test/test_mps.py b/test/test_mps.py index 2c040f7e6d544..286b63655989c 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -2903,6 +2903,14 @@ def test_eq(self): self.assertEqual(result_cpu, result_mps.to('cpu')) + def test_signed_vs_unsigned_comparison(self): + cpu_x = torch.tensor((-1, 2, 3), device='cpu', dtype=torch.uint8) + mps_x = torch.tensor((-1, 2, 3), device='mps', dtype=torch.uint8) + # in the comparison of signed vs. unsigned we should always cast to unsigned + self.assertEqual(cpu_x == -1, mps_x == -1) + self.assertEqual(cpu_x > -1, mps_x > -1) + self.assertEqual(cpu_x < -1, mps_x < -1) + def test_eq_int64(self): values1 = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]] values2 = [[[1, 2, 15], [4, 5, 6]], [[7, 8, 9], [0, 11, 12]]] @@ -8088,7 +8096,7 @@ class TestConsistency(TestCase): 'nn.functional.conv1d': ['f32'], 'nn.functional.conv2d': ['f32'], 'nn.functional.conv_transpose1d': ['f32'], - 'nn.functional.cosine_embedding_loss': ['b8', 'f32', 'i16', 'i32', 'i64'], + 'nn.functional.cosine_embedding_loss': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], 'nn.functional.cosine_similarity': ['f32'], 'nn.functional.elu': ['f32'], 'nn.functional.feature_alpha_dropout': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], @@ -8464,7 +8472,6 @@ class TestConsistency(TestCase): 'nn.functional.avg_pool2d': ['torch.float32', 'torch.int64'], 'nn.functional.bilinear': ['torch.float32'], 'nn.functional.conv_transpose2d': ['torch.float32'], - 'nn.functional.cosine_embedding_loss': ['torch.uint8'], 'nn.functional.interpolate': ['torch.float32', 'torch.float32', 'torch.float32'], 'nn.functional.max_pool1d': ['torch.float32'], 'nn.functional.max_pool2d': ['torch.float32'],