Skip to content

Commit 6398bd1

Browse files
committed
Display error for int64 pooling inputs on Monterey (#208)
* Display error for int64 pooling inputs on Monterey Also skip the test on TestConsistency when such errors are generated
1 parent 7100976 commit 6398bd1

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

aten/src/ATen/native/mps/operations/Pooling.mm

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
3232
if (input.numel() == 0)
3333
return;
3434

35+
if (!is_macos_13_or_newer()) {
36+
TORCH_CHECK(input.scalar_type() != ScalarType::Long,
37+
"MPS: ", op_name, " op with int64 input is supported natively starting from macOS 13.0.");
38+
}
3539
const int64_t ndims = input.ndimension();
3640
const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt));
3741
const Tensor& indices = *(at::borrow_from_optional_tensor(indices_opt));

test/test_mps.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8688,6 +8688,9 @@ def get_samples():
86888688
self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
86898689

86908690
except Exception as e:
8691+
if any(s in str(e).lower() for s in ["int64", "macos 13"]):
8692+
self.skipTest(f"{str(e)}")
8693+
86918694
if not generate_new_truth:
86928695
raise e
86938696
forward_failed = True

0 commit comments

Comments
 (0)