From 2b16299fdf6d128ab4cc270f28fef0488489f4e4 Mon Sep 17 00:00:00 2001 From: David Berard Date: Tue, 10 Jan 2023 01:33:26 -0800 Subject: [PATCH] Remove torch.jit.fuser("fuser2") in test (#7069) * [WIP] Remove torch.jit.fuser("fuser2") in test Internally we're considering removing support for fuser2, so we need to remove this special case from the test. * completely remove special-casing --- test/test_ops.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index c2f101b39bd..eb2e31c9bcf 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1555,13 +1555,7 @@ def test_jit(self, alpha, gamma, reduction, device, dtype, seed): torch.random.manual_seed(seed) inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - if device == "cpu": - scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - else: - with torch.jit.fuser("fuser2"): - # Use fuser2 to prevent a bug on fuser: https://github.com/pytorch/pytorch/issues/75476 - # We may remove this condition once the bug is resolved - scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) + scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) tol = 1e-3 if dtype is torch.half else 1e-5 torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)