22 changes: 14 additions & 8 deletions aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -73,9 +73,16 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
 
   // this type inference is only required at the time of graph creation
   ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), other.scalar_type());
-  // Integer input must be cast to float if output is float
-  if (isIntegralType(common_dtype, true) && isFloatingType(output_.scalar_type())) {
-    common_dtype = output_.scalar_type();
+  if (isIntegralType(common_dtype, true)) {
+    // integer inputs must be cast to float, if output is float
+    if (isFloatingType(output_.scalar_type())) {
+      common_dtype = output_.scalar_type();
+      // in boolean comparison ops with signed vs. unsigned integers, we always cast to the unsigned type
+    } else if (output_.scalar_type() == ScalarType::Bool &&
+               (self.scalar_type() == ScalarType::Byte ||
+                other.scalar_type() == ScalarType::Byte)) {
+      common_dtype = ScalarType::Byte;
+    }
   }
   if (self.scalar_type() != common_dtype) {
     primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype);
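The new promotion rule can be read as follows (a rough Python sketch; resolve_common_dtype is a hypothetical name, not a PyTorch API):

    import torch

    def resolve_common_dtype(self_dtype, other_dtype, out_dtype):
        # mirrors the hunk above: start from ordinary type promotion
        common = torch.promote_types(self_dtype, other_dtype)
        if not (common.is_floating_point or common.is_complex):
            if out_dtype.is_floating_point:
                # integer inputs are computed in float when the output is float
                common = out_dtype
            elif out_dtype == torch.bool and torch.uint8 in (self_dtype, other_dtype):
                # signed-vs-unsigned comparisons are computed in uint8 (Byte)
                common = torch.uint8
        return common

    # uint8 vs. int8 would ordinarily promote to int16; for a bool-valued
    # comparison the rule above computes in uint8 instead
    assert resolve_common_dtype(torch.uint8, torch.int8, torch.bool) == torch.uint8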
@@ -230,16 +237,15 @@ void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alpha
                                                    name:nil]; });                              \
 }
 
-// Boolean Ops require casting output to "MPSDataTypeBool"
+// output of Boolean Ops will be cast to "MPSDataTypeBool" at the end of binaryOpTensor()
 #define CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(func_out, func_stub, other_type)                 \
 TORCH_IMPL_FUNC(func_out) (const Tensor& self, const other_type& other, const Tensor& output) { \
   mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub,                      \
     ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {                         \
       MPSGraph* mpsGraph = cachedGraph->graph();                                               \
-      MPSGraphTensor* outputTensor = [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor  \
-                                                            secondaryTensor:secondaryCastTensor \
-                                                                       name:nil];              \
-      return mps::castMPSTensor(mpsGraph, outputTensor, ScalarType::Bool); });                 \
+      return [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor                          \
+                                    secondaryTensor:secondaryCastTensor                        \
+                                               name:nil]; });                                  \
 }
 
 // Boolean Binary Ops
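With the Bool cast now applied once at the end of binaryOpTensor(), each boolean-op macro can return the raw MPSGraph result. A hypothetical Python analogue of the overall flow (boolean_binary_op is illustrative, not the actual Objective-C++ path):

    import torch

    def boolean_binary_op(a, b, op):
        # compute in the common dtype; uint8 wins over signed integer
        # inputs, matching the promotion change above
        common = torch.promote_types(a.dtype, b.dtype)
        if torch.uint8 in (a.dtype, b.dtype):
            common = torch.uint8
        raw = op(a.to(common), b.to(common))  # what the op macro now returns
        return raw.to(torch.bool)             # the single cast in binaryOpTensor()

    x = torch.tensor([255, 2, 3], dtype=torch.uint8)
    y = torch.tensor(-1)                      # int64; wraps to 255 when cast to uint8
    print(boolean_binary_op(x, y, torch.eq))  # tensor([ True, False, False])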
11 changes: 9 additions & 2 deletions test/test_mps.py
@@ -2903,6 +2903,14 @@ def test_eq(self):
 
         self.assertEqual(result_cpu, result_mps.to('cpu'))
 
+    def test_signed_vs_unsigned_comparison(self):
+        cpu_x = torch.tensor((-1, 2, 3), device='cpu', dtype=torch.uint8)
+        mps_x = torch.tensor((-1, 2, 3), device='mps', dtype=torch.uint8)
+        # in the comparison of signed vs. unsigned we should always cast to unsigned
+        self.assertEqual(cpu_x == -1, mps_x == -1)
+        self.assertEqual(cpu_x > -1, mps_x > -1)
+        self.assertEqual(cpu_x < -1, mps_x < -1)
+
     def test_eq_int64(self):
         values1 = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
         values2 = [[[1, 2, 15], [4, 5, 6]], [[7, 8, 9], [0, 11, 12]]]
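For reference, the values the test compares (assuming the usual two's-complement wrap of -1 to 255 under uint8, which is the behavior the PR aligns MPS with):

    import torch

    x = torch.tensor((-1, 2, 3), dtype=torch.uint8)  # stores [255, 2, 3]
    print(x == -1)  # expected tensor([ True, False, False]) once -1 wraps to 255
    print(x > -1)   # expected tensor([False, False, False])
    print(x < -1)   # expected tensor([False,  True,  True])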
@@ -8088,7 +8096,7 @@ class TestConsistency(TestCase):
         'nn.functional.conv1d': ['f32'],
         'nn.functional.conv2d': ['f32'],
         'nn.functional.conv_transpose1d': ['f32'],
-        'nn.functional.cosine_embedding_loss': ['b8', 'f32', 'i16', 'i32', 'i64'],
+        'nn.functional.cosine_embedding_loss': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'nn.functional.cosine_similarity': ['f32'],
         'nn.functional.elu': ['f32'],
         'nn.functional.feature_alpha_dropout': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -8464,7 +8472,6 @@ class TestConsistency(TestCase):
         'nn.functional.avg_pool2d': ['torch.float32', 'torch.int64'],
         'nn.functional.bilinear': ['torch.float32'],
         'nn.functional.conv_transpose2d': ['torch.float32'],
-        'nn.functional.cosine_embedding_loss': ['torch.uint8'],
         'nn.functional.interpolate': ['torch.float32', 'torch.float32', 'torch.float32'],
         'nn.functional.max_pool1d': ['torch.float32'],
         'nn.functional.max_pool2d': ['torch.float32'],