diff --git a/test/test_ops.py b/test/test_ops.py
index 79294ed173e..5570d969cd7 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -120,6 +120,13 @@ def get_script_fn(*args, **kwargs):
     def expected_fn(*args, **kwargs):
         pass

+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
+    def test_autocast(self):
+        for x_dtype in (torch.float, torch.half):
+            for rois_dtype in (torch.float, torch.half):
+                with torch.cuda.amp.autocast():
+                    self._test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
+

 class RoIPoolTester(RoIOpTester, unittest.TestCase):
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
@@ -295,13 +302,6 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r
     def _test_boxes_shape(self):
         self._helper_boxes_shape(ops.roi_align)

-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
-    def test_roi_align_autocast(self):
-        for x_dtype in (torch.float, torch.half):
-            for rois_dtype in (torch.float, torch.half):
-                with torch.cuda.amp.autocast():
-                    self._test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
-

 class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
@@ -425,7 +425,8 @@ def test_nms(self):
         self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 4), torch.rand(4), 0.5)

     @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
-    def test_nms_cuda(self):
+    def test_nms_cuda(self, dtype=torch.float64):
+        tol = 1e-3 if dtype is torch.half else 1e-5
         err_msg = 'NMS incompatible between CPU and CUDA for IoU={}'

         for iou in [0.2, 0.5, 0.8]:
@@ -437,9 +438,15 @@ def test_nms_cuda(self):
             if not is_eq:
                 # if the indices are not the same, ensure that it's because the scores
                 # are duplicate
-                is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()])
+                is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
             self.assertTrue(is_eq, err_msg.format(iou))

+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
+    def test_autocast(self):
+        for dtype in (torch.float, torch.half):
+            with torch.cuda.amp.autocast():
+                self.test_nms_cuda(dtype=dtype)
+

 class NewEmptyTensorTester(unittest.TestCase):
     def test_new_empty_tensor(self):
@@ -492,7 +499,7 @@ def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
         out += bias.view(1, n_out_channels, 1, 1)
         return out

-    def get_fn_args(self, device, contiguous, batch_sz):
+    def get_fn_args(self, device, contiguous, batch_sz, dtype):
         n_in_channels = 6
         n_out_channels = 2
         n_weight_grps = 2
@@ -511,15 +518,15 @@ def get_fn_args(self, device, contiguous, batch_sz):
         out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
         out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1

-        x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=self.dtype, requires_grad=True)
+        x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=dtype, requires_grad=True)

         offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
-                             device=device, dtype=self.dtype, requires_grad=True)
+                             device=device, dtype=dtype, requires_grad=True)

         weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
-                             device=device, dtype=self.dtype, requires_grad=True)
+                             device=device, dtype=dtype, requires_grad=True)

-        bias = torch.randn(n_out_channels, device=device, dtype=self.dtype, requires_grad=True)
+        bias = torch.randn(n_out_channels, device=device, dtype=dtype, requires_grad=True)

         if not contiguous:
             x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
@@ -528,26 +535,29 @@ def get_fn_args(self, device, contiguous, batch_sz):

         return x, weight, offset, bias, stride, pad, dilation

-    def _test_forward(self, device, contiguous):
+    def _test_forward(self, device, contiguous, dtype=None):
+        dtype = self.dtype if dtype is None else dtype
         for batch_sz in [0, 33]:
-            self._test_forward_with_batchsize(device, contiguous, batch_sz)
+            self._test_forward_with_batchsize(device, contiguous, batch_sz, dtype)

-    def _test_forward_with_batchsize(self, device, contiguous, batch_sz):
-        x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz)
+    def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype):
+        x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
         in_channels = 6
         out_channels = 2
         kernel_size = (3, 2)
         groups = 2

         layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
-                                 dilation=dilation, groups=groups).to(device=x.device, dtype=x.dtype)
+                                 dilation=dilation, groups=groups).to(device=x.device, dtype=dtype)
         res = layer(x, offset)

         weight = layer.weight.data
         bias = layer.bias.data
         expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation)

-        self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(res, expected))
+        tol = 1e-3 if dtype is torch.half else 1e-5
+        self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
+                        '\nres:\n{}\nexpected:\n{}'.format(res, expected))

         # test for wrong sizes
         with self.assertRaises(RuntimeError):
@@ -559,7 +569,7 @@ def _test_backward(self, device, contiguous):
             self._test_backward_with_batchsize(device, contiguous, batch_sz)

     def _test_backward_with_batchsize(self, device, contiguous, batch_sz):
-        x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz)
+        x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, self.dtype)

         def func(x_, offset_, weight_, bias_):
             return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)
@@ -603,6 +613,12 @@ def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
             res_grads = init_weight.grad.to("cpu")
             self.assertTrue(true_cpu_grads.allclose(res_grads))

+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
+    def test_autocast(self):
+        for dtype in (torch.float, torch.half):
+            with torch.cuda.amp.autocast():
+                self._test_forward(torch.device("cuda"), False, dtype=dtype)
+

 class FrozenBNTester(unittest.TestCase):
     def test_frozenbatchnorm2d_repr(self):
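
Note: the new test_autocast methods above share one pattern: run the op's forward pass inside torch.cuda.amp.autocast() for both float32 and float16 inputs, then compare against a full-precision reference with a looser tolerance (1e-3) when half precision is involved. A minimal standalone sketch of that pattern, assuming a CUDA device and using illustrative box/score values that are not taken from the test suite:

import torch
from torchvision import ops


def check_nms_under_autocast(dtype=torch.half, iou_threshold=0.5):
    # Illustrative inputs; the real tests generate ~1000 boxes with a target IoU.
    boxes = torch.tensor([[0., 0., 10., 10.],
                          [1., 1., 11., 11.],
                          [50., 50., 60., 60.]], device="cuda", dtype=dtype)
    scores = torch.tensor([0.9, 0.8, 0.7], device="cuda", dtype=dtype)

    # Full-precision CPU reference.
    keep_ref = ops.nms(boxes.float().cpu(), scores.float().cpu(), iou_threshold)

    # Same call under autocast with reduced-precision CUDA inputs.
    with torch.cuda.amp.autocast():
        keep_amp = ops.nms(boxes, scores, iou_threshold)

    # Kept indices should agree; the scores comparison uses the looser half tolerance.
    tol = 1e-3 if dtype is torch.half else 1e-5
    assert torch.equal(keep_ref, keep_amp.cpu())
    assert torch.allclose(scores.float().cpu()[keep_ref],
                          scores.float().cpu()[keep_amp.cpu()],
                          rtol=tol, atol=tol)


if torch.cuda.is_available():
    for dtype in (torch.float, torch.half):
        check_nms_under_autocast(dtype=dtype)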