diff --git a/test/test_ops.py b/test/test_ops.py index c2f101b39bd..eb2e31c9bcf 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1555,13 +1555,7 @@ def test_jit(self, alpha, gamma, reduction, device, dtype, seed): torch.random.manual_seed(seed) inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - if device == "cpu": - scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - else: - with torch.jit.fuser("fuser2"): - # Use fuser2 to prevent a bug on fuser: https://github.com/pytorch/pytorch/issues/75476 - # We may remove this condition once the bug is resolved - scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) + scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) tol = 1e-3 if dtype is torch.half else 1e-5 torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)