
[BUG] SparseTensor.__add__ sometimes won't work even when both spatial_range values stay the same #325

Open
@ytliu74

Description


Is there an existing issue for this?

  • I have searched the existing issues

Current Behavior

SparseResBlock does not work when kernel_size > 1 and stride > 1.

Here's the relevant part of my code:

model = SparseResBlock(C, C_OUT, kernel_size=3, stride=2).to("cuda")
output = model(sparse_map)

The error message is as follows:

Traceback (most recent call last):
  File "example_from_dense.py", line 64, in <module>
    output = model(sparse_map)
  File "/opt/conda/envs/habitat/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/habitat/lib/python3.7/site-packages/torchsparse-2.1.0-py3.7-linux-x86_64.egg/torchsparse/backbones/modules/blocks.py", line 84, in forward
    x = self.relu(self.main(x) + self.shortcut(x))
  File "/opt/conda/envs/habitat/lib/python3.7/site-packages/torchsparse-2.1.0-py3.7-linux-x86_64.egg/torchsparse/tensor.py", line 109, in __add__
    feats=self.feats + other.feats,
RuntimeError: The size of tensor a (101) must match the size of tensor b (16) at non-singleton dimension 0

So it seems to be an issue with SparseTensor.__add__. When I set the kernel size to 1 OR the stride to 1, there is no error.
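For reference, the traceback points at feats=self.feats + other.feats, so my mental model (just a sketch of what the error message suggests, not the actual library code) is that __add__ adds the per-point feature tensors elementwise, which can only succeed when both branches produce the same number of active points:

import torch

# Hypothetical shapes taken from the error message: 101 active points in one
# branch, 16 in the other, with some feature width C.
C = 64
main_feats = torch.randn(101, C)      # stands in for self.main(x).feats
shortcut_feats = torch.randn(16, C)   # stands in for self.shortcut(x).feats

try:
    main_feats + shortcut_feats
except RuntimeError as e:
    print(e)  # "The size of tensor a (101) must match the size of tensor b (16) ..."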

I thought the issue came from a mismatch in SparseTensor.spatial_range: since there is no padding setting, kernel > 1 AND stride > 1 would give self.main(x) and self.shortcut(x) different spatial_range values.
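To make that concrete, here is the back-of-the-envelope size arithmetic I had in mind (this assumes the sparse conv tracks spatial_range with the usual dense output-size formula; 480 is the input width from my debug output further below):

def conv_out_size(size, kernel, stride, padding=0):
    # Standard convolution output-size formula.
    return (size + 2 * padding - kernel) // stride + 1

print(conv_out_size(480, kernel=3, stride=2))             # 239 <- main branch without padding
print(conv_out_size(480, kernel=1, stride=2))             # 240 <- shortcut branch
print(conv_out_size(480, kernel=3, stride=2, padding=1))  # 240 <- main branch with padding=1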

So I implemented my own version, shown below:

import torch.nn as nn
import torchsparse.nn as spnn


class TorchSparseBasicBlock(nn.Module):
    def __init__(self, nInput: int, nInner: int, stride: int):
        super().__init__()

        nOut = nInner
        # Main branch: two sparse convs (3x3 over the last two dims), the first one strided.
        self.conv = nn.Sequential(
            spnn.Conv3d(
                nInner,
                nInner,
                kernel_size=(1, 3, 3),
                stride=(1, stride, stride),
                padding=(0, 1, 1),
            ),
            spnn.BatchNorm(nInner),
            spnn.ReLU(True),
            spnn.Conv3d(nInner, nOut, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
            spnn.BatchNorm(nOut),
        )

        # Shortcut branch: a 1x1x1 strided conv to match channels and resolution.
        if stride > 1 or nInput != nOut:
            self.downsample = nn.Sequential(
                spnn.Conv3d(
                    nInput,
                    nOut,
                    kernel_size=(1, 1, 1),
                    stride=(1, stride, stride),
                ),
                spnn.BatchNorm(nOut),
            )
        else:
            self.downsample = nn.Identity()

        self.relu = spnn.ReLU(True)

    def forward(self, x):
        print("Input.shape:", x.spatial_range)
        conv = self.conv(x)
        print("Conv.shape:", conv.spatial_range)
        print("Conv", self.conv)
        downsample = self.downsample(x)
        print("Downsample.shape:", downsample.spatial_range)
        print("Downsample", self.downsample)
        # This is the line that raises the RuntimeError, inside SparseTensor.__add__.
        x = self.relu(conv + downsample)
        return x
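For completeness, this is roughly how I invoke it in my script (C and C_OUT are the same channel constants as in the snippet at the top):

model = TorchSparseBasicBlock(C, C_OUT, stride=2).to("cuda")
output = model(sparse_map)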

I added some prints inside forward for debugging. Here's what I got:

Input.shape: torch.Size([8, 1, 480, 480])
Conv.shape: torch.Size([8, 1, 240, 240])
Conv Sequential(
  (0): Conv3d(128, 128, kernel_size=(1, 3, 3), stride=(1, 2, 2), bias=False)
  (1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Conv3d(128, 128, kernel_size=(1, 3, 3), bias=False)
  (4): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
Downsample.shape: torch.Size([8, 1, 240, 240])
Downsample Sequential(
  (0): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False)
  (1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

But the same error message persists, even though the spatial_range of both SparseTensors is identical. However, I did find a variant that works:

            self.downsample = nn.Sequential(
                spnn.Conv3d(
                    nInput,
                    nOut,
                    kernel_size=(1, 3, 3),
                    stride=(1, stride, stride),
                    padding=(0, 1, 1),
                ),
                spnn.BatchNorm(nOut),
            )

Basically, I only updated the downsample branch:

  1. Changed the kernel size from (1, 1, 1) to (1, 3, 3).
  2. Added padding so that the spatial_range stays the same when strided.

With those changes it works. Partial output:
Input.shape: torch.Size([8, 1, 480, 480])
Conv.shape: torch.Size([8, 1, 240, 240])
Conv Sequential(
  (0): Conv3d(128, 128, kernel_size=(1, 3, 3), stride=(1, 2, 2), bias=False)
  (1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Conv3d(128, 128, kernel_size=(1, 3, 3), bias=False)
  (4): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
Downsample.shape: torch.Size([8, 1, 240, 240])
Downsample Sequential(
  (0): Conv3d(64, 128, kernel_size=(1, 3, 3), stride=(1, 2, 2), bias=False)
  (1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

Both spatial_range values still stay the same.
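As an extra sanity check, here is the kind of diagnostic I would drop into forward (the helper name compare_branches is just illustrative; it only relies on the spatial_range and feats attributes that already appear above and in the traceback). Since the error message complains about 101 vs. 16, the mismatch seems to be in the number of active points rather than in spatial_range:

def compare_branches(conv_out, downsample_out):
    # Diagnostic only: even with identical spatial_range, the two branches can
    # carry different numbers of active points, in which case the elementwise
    # feats addition inside __add__ cannot succeed.
    print("conv:       spatial_range =", conv_out.spatial_range,
          " active points =", conv_out.feats.shape[0])
    print("downsample: spatial_range =", downsample_out.spatial_range,
          " active points =", downsample_out.feats.shape[0])
    print("same number of active points:",
          conv_out.feats.shape[0] == downsample_out.feats.shape[0])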

So my main question is:
Why doesn't SparseTensor.__add__ work as expected under certain conditions, even when both spatial_range values are aligned?

Expected Behavior

SparseTensor.__add__ should work when both spatial_range values are the same.

Environment

- GCC: 7.5.0
- NVCC: 11.1
- PyTorch: 1.10.0+cu111
- PyTorch CUDA: 11.1
- TorchSparse: 2.1.0

Anything else?

No response
