Skip to content

Commit 5040edf

Browse files
razarmehrkulinseth
authored andcommitted
Cast to unsigned type when comparing signed vs. unsigned integers in BinaryOps (#173)
Also remove the double cast to boolean in comparison ops
1 parent aaf9078 commit 5040edf

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

aten/src/ATen/native/mps/operations/BinaryOps.mm

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,16 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
7373

7474
// this type inference is only required at the time of graph creation
7575
ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), other.scalar_type());
76-
// Integer input must be cast to float if output is float
77-
if (isIntegralType(common_dtype, true) && isFloatingType(output.scalar_type())) {
78-
common_dtype = output_.scalar_type();
76+
if (isIntegralType(common_dtype, true)) {
77+
// integer inputs must be cast to float, if output is float
78+
if (isFloatingType(output_.scalar_type())) {
79+
common_dtype = output_.scalar_type();
80+
// in boolean comparison ops with signed vs. unsigned integers, we always cast to the unsigned type
81+
} else if (output_.scalar_type() == ScalarType::Bool &&
82+
(self.scalar_type() == ScalarType::Byte ||
83+
other.scalar_type() == ScalarType::Byte)) {
84+
common_dtype = ScalarType::Byte;
85+
}
7986
}
8087
if (self.scalar_type() != common_dtype) {
8188
primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype);
@@ -230,16 +237,15 @@ void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alp
230237
name:nil]; }); \
231238
}
232239

233-
// Boolean Ops require casting output to "MPSDataTypeBool"
240+
// output of Boolean Ops will be cast to "MPSDataTypeBool" at the end of binaryOpTensor()
234241
#define CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(func_out, func_stub, other_type) \
235242
TORCH_IMPL_FUNC(func_out) (const Tensor& self, const other_type& other, const Tensor& output) { \
236243
mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \
237244
^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \
238245
MPSGraph* mpsGraph = cachedGraph->graph(); \
239-
MPSGraphTensor* outputTensor = [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \
240-
secondaryTensor:secondaryCastTensor \
241-
name:nil]; \
242-
return mps::castMPSTensor(mpsGraph, outputTensor, ScalarType::Bool); }); \
246+
return [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \
247+
secondaryTensor:secondaryCastTensor \
248+
name:nil]; }); \
243249
}
244250

245251
// Boolean Binary Ops

test/test_mps.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2911,6 +2911,14 @@ def test_eq(self):
29112911

29122912
self.assertEqual(result_cpu, result_mps.to('cpu'))
29132913

2914+
def test_signed_vs_unsigned_comparison(self):
2915+
cpu_x = torch.tensor((-1, 2, 3), device='cpu', dtype=torch.uint8)
2916+
mps_x = torch.tensor((-1, 2, 3), device='mps', dtype=torch.uint8)
2917+
# in the comparison of signed vs. unsigned we should always cast to unsigned
2918+
self.assertEqual(cpu_x == -1, mps_x == -1)
2919+
self.assertEqual(cpu_x > -1, mps_x > -1)
2920+
self.assertEqual(cpu_x < -1, mps_x < -1)
2921+
29142922
def test_eq_int64(self):
29152923
values1 = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
29162924
values2 = [[[1, 2, 15], [4, 5, 6]], [[7, 8, 9], [0, 11, 12]]]
@@ -8173,7 +8181,7 @@ class TestConsistency(TestCase):
81738181
'nn.functional.conv1d': ['f32'],
81748182
'nn.functional.conv2d': ['f32'],
81758183
'nn.functional.conv_transpose1d': ['f32'],
8176-
'nn.functional.cosine_embedding_loss': ['b8', 'f32', 'i16', 'i32', 'i64'],
8184+
'nn.functional.cosine_embedding_loss': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
81778185
'nn.functional.cosine_similarity': ['f32'],
81788186
'nn.functional.elu': ['f32'],
81798187
'nn.functional.feature_alpha_dropout': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -8550,7 +8558,6 @@ class TestConsistency(TestCase):
85508558
'nn.functional.avg_pool2d': ['torch.float32', 'torch.int64'],
85518559
'nn.functional.bilinear': ['torch.float32'],
85528560
'nn.functional.conv_transpose2d': ['torch.float32'],
8553-
'nn.functional.cosine_embedding_loss': ['torch.uint8'],
85548561
'nn.functional.interpolate': ['torch.float32', 'torch.float32', 'torch.float32'],
85558562
'nn.functional.max_pool1d': ['torch.float32'],
85568563
'nn.functional.max_pool2d': ['torch.float32'],

0 commit comments

Comments
 (0)