diff --git a/test/test_mps.py b/test/test_mps.py
index 0f074fbe88fcd..fa224eb7fe3d4 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -8680,6 +8680,13 @@ class TestConsistency(TestCase):
         '__rpow__': ['torch.float32'],
     }
 
+    FP16_LOW_PRECISION_LIST = {
+        "add", "sub",
+        "__rdiv__", "__rmul__",
+        "nn.functional.huber_loss",
+        "true_divide"
+    }
+
     # Used for accept mode only
     NEW_ALLOW_LIST = defaultdict(list)
     NEW_ALLOW_LIST_GRAD = defaultdict(list)
@@ -8746,7 +8753,7 @@ def get_samples():
                 if op.name == "nn.functional.conv2d" and dtype == torch.float32:
                     atol = 1e-4
                     rtol = 3e-5
-                elif (op.name == "add" or op.name == "sub" or op.name == "nn.functional.huber_loss") and dtype == torch.float16:
+                elif (op.name in self.FP16_LOW_PRECISION_LIST) and dtype == torch.float16:
                     atol = 1e-2
                     rtol = 1e-2
                 elif (op.name == "masked.mean"):
@@ -8813,7 +8820,7 @@ def req_grad(t):
                 cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
                 mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)
 
-                self.assertEqual(cpu_grad_inputs, mps_grad_inputs)
+                self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
             except Exception as e:
                 if not generate_new_truth:
                     raise e
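
The hunks above share one pattern: op names whose float16 results are expected to deviate from the CPU reference are gathered in a single set, and both the forward-output and gradient comparisons reuse the same relaxed tolerances. A minimal sketch of that pattern in isolation follows; FP16_LOW_PRECISION_LIST is copied from the diff, while pick_tolerances is a hypothetical helper, not part of test_mps.py.

    import torch

    # Same set introduced by the diff above.
    FP16_LOW_PRECISION_LIST = {
        "add", "sub",
        "__rdiv__", "__rmul__",
        "nn.functional.huber_loss",
        "true_divide",
    }

    def pick_tolerances(op_name, dtype):
        # Relax tolerances for ops known to lose precision in float16 on MPS;
        # otherwise return None so assertEqual falls back to its dtype defaults.
        if op_name in FP16_LOW_PRECISION_LIST and dtype == torch.float16:
            return 1e-2, 1e-2
        return None, None

    atol, rtol = pick_tolerances("__rmul__", torch.float16)  # -> (0.01, 0.01)

A set-membership check keeps the tolerance policy in one place, so adding another low-precision op means editing the set rather than growing the chained `or` condition.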