Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion detectron2/layers/deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,14 @@ def forward(
if not ctx.with_bias:
bias = input.new_empty(1) # fake tensor
if not input.is_cuda:
raise NotImplementedError("Deformable Conv is not supported on CPUs!")
# TODO: let torchvision support full features of our deformconv.
if deformable_groups != 1:
raise NotImplementedError(
"Deformable Conv with deformable_groups != 1 is not supported on CPUs!"
)
return deform_conv2d(
input, offset, weight, stride=stride, padding=padding, dilation=dilation, mask=mask
)
if (
weight.requires_grad
or mask.requires_grad
Expand Down
82 changes: 45 additions & 37 deletions tests/layers/test_deformable.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
"This test fails under cuda11 + torch1.8.",
)
class DeformableTest(unittest.TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Deformable not supported for cpu")
def test_forward_output(self):
device = torch.device("cuda")
def forward_output(self, device):
N, C, H, W = shape = 1, 1, 5, 5
kernel_size = 3
padding = 1
Expand Down Expand Up @@ -57,44 +55,24 @@ def test_forward_output(self):
output = output.detach().cpu().numpy()
self.assertTrue(np.allclose(output.flatten(), deform_results.flatten() * 0.5))

def test_forward_output_on_cpu(self):
device = torch.device("cpu")
N, C, H, W = shape = 1, 1, 5, 5
kernel_size = 3
padding = 1

inputs = torch.arange(np.prod(shape), dtype=torch.float32).reshape(*shape).to(device)
@unittest.skipIf(not torch.cuda.is_available(), "This test requires gpu access")
def test_forward_output_cuda(self):
self.forward_output(torch.device("cuda"))

offset_channels = kernel_size * kernel_size * 2
offset = torch.full((N, offset_channels, H, W), 0.5, dtype=torch.float32).to(device)

# Test DCN v1 on cpu
deform = DeformConv(C, C, kernel_size=kernel_size, padding=padding).to(device)
deform.weight = torch.nn.Parameter(torch.ones_like(deform.weight))
output = deform(inputs, offset)
output = output.detach().cpu().numpy()
deform_results = np.array(
[
[30, 41.25, 48.75, 45, 28.75],
[62.25, 81, 90, 80.25, 50.25],
[99.75, 126, 135, 117.75, 72.75],
[105, 131.25, 138.75, 120, 73.75],
[71.75, 89.25, 93.75, 80.75, 49.5],
]
)
self.assertTrue(np.allclose(output.flatten(), deform_results.flatten()))
def test_forward_output_cpu(self):
self.forward_output(torch.device("cpu"))

@unittest.skipIf(not torch.cuda.is_available(), "This test requires gpu access")
def test_forward_output_on_cpu_equals_output_on_gpu(self):
N, C, H, W = shape = 2, 4, 10, 10
kernel_size = 3
padding = 1

for groups in [1, 2]:
inputs = torch.arange(np.prod(shape), dtype=torch.float32).reshape(*shape)
offset_channels = kernel_size * kernel_size * 2
offset = torch.full((N, offset_channels, H, W), 0.5, dtype=torch.float32)
inputs = torch.arange(np.prod(shape), dtype=torch.float32).reshape(*shape)
offset_channels = kernel_size * kernel_size * 2
offset = torch.full((N, offset_channels, H, W), 0.5, dtype=torch.float32)

for groups in [1, 2]:
deform_gpu = DeformConv(
C, C, kernel_size=kernel_size, padding=padding, groups=groups
).to("cuda")
Expand All @@ -107,11 +85,42 @@ def test_forward_output_on_cpu_equals_output_on_gpu(self):
deform_cpu.weight = torch.nn.Parameter(torch.ones_like(deform_cpu.weight))
output_cpu = deform_cpu(inputs.to("cpu"), offset.to("cpu")).detach().numpy()

self.assertTrue(np.allclose(output_gpu.flatten(), output_cpu.flatten()))
self.assertTrue(np.allclose(output_gpu.flatten(), output_cpu.flatten()))

mask_channels = kernel_size * kernel_size
mask = torch.full((N, mask_channels, H, W), 0.5, dtype=torch.float32)
for groups in [1, 2]:
modulate_deform_gpu = ModulatedDeformConv(
C, C, kernel_size=kernel_size, padding=padding, groups=groups, bias=False
).to("cuda")
modulate_deform_gpu.weight = torch.nn.Parameter(
torch.ones_like(modulate_deform_gpu.weight)
)
output_modulate_gpu = (
modulate_deform_gpu(inputs.to("cuda"), offset.to("cuda"), mask.to("cuda"))
.detach()
.cpu()
.numpy()
)

modulate_deform_cpu = ModulatedDeformConv(
C, C, kernel_size=kernel_size, padding=padding, groups=groups, bias=False
).to("cpu")
modulate_deform_cpu.weight = torch.nn.Parameter(
torch.ones_like(modulate_deform_cpu.weight)
)
output_modulate_cpu = (
modulate_deform_cpu(inputs.to("cpu"), offset.to("cpu"), mask.to("cpu"))
.detach()
.numpy()
)

self.assertTrue(
np.allclose(output_modulate_gpu.flatten(), output_modulate_cpu.flatten())
)

@unittest.skipIf(not torch.cuda.is_available(), "Deformable not supported for cpu")
def test_small_input(self):
device = torch.device("cuda")
device = torch.device("cpu")
for kernel_size in [3, 5]:
padding = kernel_size // 2
N, C, H, W = shape = (1, 1, kernel_size - 1, kernel_size - 1)
Expand All @@ -132,9 +141,8 @@ def test_small_input(self):
output = modulate_deform(inputs, offset, mask)
self.assertTrue(output.shape == inputs.shape)

@unittest.skipIf(not torch.cuda.is_available(), "Deformable not supported for cpu")
def test_raise_exception(self):
device = torch.device("cuda")
device = torch.device("cpu")
N, C, H, W = shape = 1, 1, 3, 3
kernel_size = 3
padding = 1
Expand Down