Adding Autocast tests for all C++ Ops #2938

Merged (4 commits, Oct 30, 2020)
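
This PR moves the roi_align-specific autocast test up into the shared RoIOpTester base class (so the RoI pool and PS-RoI variants inherit it) and adds equivalent test_autocast methods for nms and DeformConv2d. For context, torch.cuda.amp.autocast is a context manager that runs CUDA ops in mixed precision. The sketch below shows the call pattern the new tests exercise; it is illustrative only, and the shapes and values are not taken from the PR:

import torch
import torchvision.ops as ops

x = torch.rand(1, 3, 16, 16, device="cuda", dtype=torch.half)
# a single RoI covering the whole image: (batch_index, x1, y1, x2, y2)
rois = torch.tensor([[0., 0., 0., 15., 15.]], device="cuda", dtype=torch.float)

with torch.cuda.amp.autocast():
    # inputs deliberately mix float and half; the op's autocast wrapper
    # reconciles the dtypes before dispatching to the CUDA kernel
    out = ops.roi_align(x, rois, output_size=(4, 4))
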
58 changes: 37 additions & 21 deletions test/test_ops.py

@@ -120,6 +120,13 @@ def get_script_fn(*args, **kwargs):
    def expected_fn(*args, **kwargs):
        pass

+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
+    def test_autocast(self):
+        for x_dtype in (torch.float, torch.half):
+            for rois_dtype in (torch.float, torch.half):
+                with torch.cuda.amp.autocast():
+                    self._test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
+

class RoIPoolTester(RoIOpTester, unittest.TestCase):
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
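
Because test_autocast now lives on RoIOpTester, which is not itself a TestCase, every concrete subclass picks it up through ordinary method inheritance when unittest collects test_* methods. A minimal sketch of that pattern, with hypothetical class names:

import unittest

class BaseTester:
    # not a unittest.TestCase, so it is never collected on its own
    def test_shared(self):
        self.assertTrue(self.value > 0)

class ConcreteTester(BaseTester, unittest.TestCase):
    value = 1  # each concrete subclass supplies its own fixture

if __name__ == "__main__":
    unittest.main()  # runs ConcreteTester.test_shared
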
@@ -295,13 +302,6 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r
    def _test_boxes_shape(self):
        self._helper_boxes_shape(ops.roi_align)

-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
-    def test_roi_align_autocast(self):
-        for x_dtype in (torch.float, torch.half):
-            for rois_dtype in (torch.float, torch.half):
-                with torch.cuda.amp.autocast():
-                    self._test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
-

class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
@@ -425,7 +425,8 @@ def test_nms(self):
        self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 4), torch.rand(4), 0.5)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
-    def test_nms_cuda(self):
+    def test_nms_cuda(self, dtype=torch.float64):
+        tol = 1e-3 if dtype is torch.half else 1e-5
        err_msg = 'NMS incompatible between CPU and CUDA for IoU={}'

        for iou in [0.2, 0.5, 0.8]:
@@ -437,9 +438,15 @@ def test_nms_cuda(self):
            if not is_eq:
                # if the indices are not the same, ensure that it's because the scores
                # are duplicate
-                is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()])
+                is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
            self.assertTrue(is_eq, err_msg.format(iou))

+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
+    def test_autocast(self):
+        for dtype in (torch.float, torch.half):
+            with torch.cuda.amp.autocast():
+                self.test_nms_cuda(dtype=dtype)
+

class NewEmptyTensorTester(unittest.TestCase):
    def test_new_empty_tensor(self):
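
The looser half-precision tolerance matches fp16's machine epsilon (10-bit mantissa, eps just under 1e-3), while other dtypes keep the tighter 1e-5. A quick standalone check, not part of the diff:

import torch

print(torch.finfo(torch.half).eps)   # 0.0009765625, so rtol/atol of 1e-3
print(torch.finfo(torch.float).eps)  # ~1.19e-07, comfortably inside 1e-5
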
@@ -492,7 +499,7 @@ def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
        out += bias.view(1, n_out_channels, 1, 1)
        return out

-    def get_fn_args(self, device, contiguous, batch_sz):
+    def get_fn_args(self, device, contiguous, batch_sz, dtype):
        n_in_channels = 6
        n_out_channels = 2
        n_weight_grps = 2
@@ -511,15 +518,15 @@ def get_fn_args(self, device, contiguous, batch_sz):
        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1

-        x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=self.dtype, requires_grad=True)
+        x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=dtype, requires_grad=True)

        offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
-                             device=device, dtype=self.dtype, requires_grad=True)
+                             device=device, dtype=dtype, requires_grad=True)

        weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
-                             device=device, dtype=self.dtype, requires_grad=True)
+                             device=device, dtype=dtype, requires_grad=True)

-        bias = torch.randn(n_out_channels, device=device, dtype=self.dtype, requires_grad=True)
+        bias = torch.randn(n_out_channels, device=device, dtype=dtype, requires_grad=True)

        if not contiguous:
            x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
@@ -528,26 +535,29 @@ def get_fn_args(self, device, contiguous, batch_sz):

        return x, weight, offset, bias, stride, pad, dilation

-    def _test_forward(self, device, contiguous):
+    def _test_forward(self, device, contiguous, dtype=None):
+        dtype = self.dtype if dtype is None else dtype
        for batch_sz in [0, 33]:
-            self._test_forward_with_batchsize(device, contiguous, batch_sz)
+            self._test_forward_with_batchsize(device, contiguous, batch_sz, dtype)

-    def _test_forward_with_batchsize(self, device, contiguous, batch_sz):
-        x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz)
+    def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype):
+        x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
        in_channels = 6
        out_channels = 2
        kernel_size = (3, 2)
        groups = 2

        layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
-                                 dilation=dilation, groups=groups).to(device=x.device, dtype=x.dtype)
+                                 dilation=dilation, groups=groups).to(device=x.device, dtype=dtype)
        res = layer(x, offset)

        weight = layer.weight.data
        bias = layer.bias.data
        expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation)

-        self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(res, expected))
+        tol = 1e-3 if dtype is torch.half else 1e-5
+        self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
+                        '\nres:\n{}\nexpected:\n{}'.format(res, expected))

        # test for wrong sizes
        with self.assertRaises(RuntimeError):
@@ -559,7 +569,7 @@ def _test_backward(self, device, contiguous):
            self._test_backward_with_batchsize(device, contiguous, batch_sz)

    def _test_backward_with_batchsize(self, device, contiguous, batch_sz):
-        x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz)
+        x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, self.dtype)

        def func(x_, offset_, weight_, bias_):
            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)
@@ -603,6 +613,12 @@ def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
            res_grads = init_weight.grad.to("cpu")
            self.assertTrue(true_cpu_grads.allclose(res_grads))

+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
+    def test_autocast(self):
+        for dtype in (torch.float, torch.half):
+            with torch.cuda.amp.autocast():
+                self._test_forward(torch.device("cuda"), False, dtype=dtype)
+

class FrozenBNTester(unittest.TestCase):
    def test_frozenbatchnorm2d_repr(self):
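
As with the RoI ops, the DeformConv2d test reuses _test_forward with an explicit dtype. A minimal standalone sketch of the same call pattern (shapes are illustrative, not taken from the test):

import torch
import torchvision.ops as ops

x = torch.rand(1, 6, 8, 8, device="cuda")
# offsets: 2 * kernel_h * kernel_w channels per offset group (one group, 3x3 kernel)
offset = torch.zeros(1, 2 * 3 * 3, 6, 6, device="cuda")
layer = ops.DeformConv2d(6, 2, kernel_size=3).to("cuda")

with torch.cuda.amp.autocast():
    out = layer(x, offset)  # the autocast wrapper handles any needed dtype casts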