Description
Is there an existing issue for this?
- I have searched the existing issues
Current Behavior
SparseResBlock does not work when kernel size > 1 and stride > 1.
Here is the relevant part of my code:
model = SparseResBlock(C, C_OUT, kernel_size=3, stride=2).to("cuda")
output = model(sparse_map)
The error message is as follows:
Traceback (most recent call last):
  File "example_from_dense.py", line 64, in <module>
    output = model(sparse_map)
  File "/opt/conda/envs/habitat/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/habitat/lib/python3.7/site-packages/torchsparse-2.1.0-py3.7-linux-x86_64.egg/torchsparse/backbones/modules/blocks.py", line 84, in forward
    x = self.relu(self.main(x) + self.shortcut(x))
  File "/opt/conda/envs/habitat/lib/python3.7/site-packages/torchsparse-2.1.0-py3.7-linux-x86_64.egg/torchsparse/tensor.py", line 109, in __add__
    feats=self.feats + other.feats,
RuntimeError: The size of tensor a (101) must match the size of tensor b (16) at non-singleton dimension 0
So the issue seems to be in SparseTensor.__add__. When I set the kernel size to 1 OR the stride to 1, there is no issue.
I thought the issue came from a mismatch of SparseTensor.spatial_range: since there is no padding setting, kernel > 1 AND stride > 1 would make self.main(x) and self.shortcut(x) have different spatial_range.
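To make that suspicion concrete, here is a small sketch of the output-size arithmetic. It uses the standard dense-convolution formula and assumes (I have not verified this in the TorchSparse source) that spatial_range is computed the same way; the helper name conv_output_size is made up for illustration.

```python
def conv_output_size(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output length along one axis for a standard (dense) convolution."""
    return (size + 2 * padding - kernel) // stride + 1


# Spatial size 480, as in the logs of this report.
main_no_pad = conv_output_size(480, kernel=3, stride=2)             # main branch, no padding
shortcut = conv_output_size(480, kernel=1, stride=2)                # 1x1 strided shortcut
main_padded = conv_output_size(480, kernel=3, stride=2, padding=1)  # main branch, padding=1

print(main_no_pad, shortcut, main_padded)  # 239 240 240
```

Under that assumption, an unpadded 3x3 stride-2 branch and a 1x1 stride-2 shortcut end up one cell apart (239 vs 240), while padding of 1 brings them back in line.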
So I implemented my own version as below:
import torch.nn as nn
from torchsparse import nn as spnn


class TorchSparseBasicBlock(nn.Module):
    def __init__(self, nInput: int, nInner: int, stride: int):
        super(TorchSparseBasicBlock, self).__init__()
        nOut = nInner
        # Main branch: strided 3x3 conv followed by a second 3x3 conv.
        self.conv = nn.Sequential(
            spnn.Conv3d(
                nInner,
                nInner,
                kernel_size=(1, 3, 3),
                stride=(1, stride, stride),
                padding=(0, 1, 1),
            ),
            spnn.BatchNorm(nInner),
            spnn.ReLU(True),
            spnn.Conv3d(nInner, nOut, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
            spnn.BatchNorm(nOut),
        )
        # Shortcut branch: strided 1x1 conv when the shape changes.
        if stride > 1 or nInput != nOut:
            self.downsample = nn.Sequential(
                spnn.Conv3d(
                    nInput,
                    nOut,
                    kernel_size=(1, 1, 1),
                    stride=(1, stride, stride),
                ),
                spnn.BatchNorm(nOut),
            )
        else:
            self.downsample = nn.Identity()
        self.relu = spnn.ReLU(True)

    def forward(self, x):
        print("Input.shape:", x.spatial_range)
        conv = self.conv(x)
        print("Conv.shape:", conv.spatial_range)
        print("Conv", self.conv)
        downsample = self.downsample(x)
        print("Downsample.shape:", downsample.spatial_range)
        print("Downsample", self.downsample)
        x = self.relu(conv + downsample)
        # x = self.relu(self.conv(x) + self.downsample(x))
        return x
I added some print statements for debugging. Here's what I got:
Input.shape: torch.Size([8, 1, 480, 480])
Conv.shape: torch.Size([8, 1, 240, 240])
Conv Sequential(
  (0): Conv3d(128, 128, kernel_size=(1, 3, 3), stride=(1, 2, 2), bias=False)
  (1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Conv3d(128, 128, kernel_size=(1, 3, 3), bias=False)
  (4): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
Downsample.shape: torch.Size([8, 1, 240, 240])
Downsample Sequential(
  (0): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False)
  (1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
The same error message persists, even though the spatial_range of both SparseTensors is the same. However, I found that the following variant works:
self.downsample = nn.Sequential(
    spnn.Conv3d(
        nInput,
        nOut,
        kernel_size=(1, 3, 3),
        stride=(1, stride, stride),
        padding=(0, 1, 1),
    ),
    spnn.BatchNorm(nOut),
)
Basically, I only changed the downsample branch:
- Change the kernel size from (1, 1, 1) to (1, 3, 3).
- Add padding to make sure the spatial_range stays the same when strided.
With this change it works. Partial output:
Input.shape: torch.Size([8, 1, 480, 480])
Conv.shape: torch.Size([8, 1, 240, 240])
Conv Sequential(
  (0): Conv3d(128, 128, kernel_size=(1, 3, 3), stride=(1, 2, 2), bias=False)
  (1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Conv3d(128, 128, kernel_size=(1, 3, 3), bias=False)
  (4): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
Downsample.shape: torch.Size([8, 1, 240, 240])
Downsample Sequential(
  (0): Conv3d(64, 128, kernel_size=(1, 3, 3), stride=(1, 2, 2), bias=False)
  (1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
Here the spatial_range of both tensors is again the same.
So my main concern is:
Why doesn't SparseTensor.__add__ work as expected under certain conditions, even when the spatial_range of both operands is aligned?
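One possible explanation, sketched in plain Python rather than with TorchSparse itself: the traceback shows __add__ adding feats row by row, so both operands need the same number of active points, not just the same spatial_range. The sketch below assumes (this is my guess, not confirmed against the TorchSparse source) that a strided sparse convolution activates every output site whose receptive field contains at least one active input. The helper active_outputs is hypothetical and works in 1D for brevity.

```python
def active_outputs(active_inputs, kernel, stride, padding=0):
    """1D output positions whose receptive field contains an active input.

    Assumed rule for illustration only; not taken from TorchSparse.
    """
    outputs = set()
    for i in active_inputs:
        for k in range(kernel):
            # Output o reads input i at kernel offset k when
            # o * stride - padding + k == i.
            shifted = i + padding - k
            if shifted >= 0 and shifted % stride == 0:
                outputs.add(shifted // stride)
    return sorted(outputs)


# Four active sites on an axis of length 8; every branch below maps it to length 4.
active = [0, 1, 3, 6]

main = active_outputs(active, kernel=3, stride=2, padding=1)
shortcut_k1 = active_outputs(active, kernel=1, stride=2)
shortcut_k3 = active_outputs(active, kernel=3, stride=2, padding=1)

print(main)         # [0, 1, 2, 3]
print(shortcut_k1)  # [0, 3]
print(shortcut_k3)  # [0, 1, 2, 3]
```

If that assumption holds, the 3x3 branch and the 1x1 shortcut produce different numbers of active points (here 4 vs 2) even though their spatial ranges match, which would be consistent with the 101 vs 16 mismatch in the error. Giving the shortcut the same kernel size, stride, and padding as the main branch yields the same set of points, which matches what I observed.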
Expected Behavior
SparseTensor.__add__ should work when the spatial_range of both operands is the same.
Environment
- GCC: 7.5.0
- NVCC: 11.1
- PyTorch: 1.10.0+cu111
- PyTorch CUDA: 11.1
- TorchSparse: 2.1.0
Anything else?
No response