Skip to content

Commit fd1139e

Browse files
authored
Display error for int64 pooling inputs on Monterey (#208)
* Display error for int64 pooling inputs on Monterey Also skip the test on TestConsistency when such errors are generated
1 parent 314d316 commit fd1139e

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

aten/src/ATen/native/mps/operations/Pooling.mm

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
3232
if (input.numel() == 0)
3333
return;
3434

35+
if (!is_macos_13_or_newer()) {
36+
TORCH_CHECK(input.scalar_type() != ScalarType::Long,
37+
"MPS: ", op_name, " op with int64 input is supported natively starting from macOS 13.0.");
38+
}
3539
const int64_t ndims = input.ndimension();
3640
const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt));
3741
const Tensor& indices = *(at::borrow_from_optional_tensor(indices_opt));

test/test_mps.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8769,6 +8769,9 @@ def get_samples():
87698769
self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
87708770

87718771
except Exception as e:
8772+
if any(s in str(e).lower() for s in ["int64", "macos 13"]):
8773+
self.skipTest(f"{str(e)}")
8774+
87728775
if not generate_new_truth:
87738776
raise e
87748777
forward_failed = True

0 commit comments

Comments
 (0)