incorrect handling of strides for nn.Conv2d layers #4119

@MaxS1996

Description

PyTorch allows developers to pass the stride for a Conv2d layer as an integer or as a list/tuple.
It accepts tuples of length one or two. If the tuple contains two values, the lowering to linalg works as expected.
However, if the tuple contains a single value, the lowering to linalg fails and torch-mlir crashes.

Minimal example:

import torch
import torch.nn as nn

class SingleConvNet(nn.Module):
    def __init__(self, stride=1):
        super().__init__()
        # stride is forwarded to nn.Conv2d unchanged (int, 1-tuple, or 2-tuple)
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=stride, padding=1)

    def forward(self, x):
        x = self.conv(x)
        return x

# Example usage:
model_bad = SingleConvNet(stride=(1,))     # single-element tuple: lowering crashes
model_good = SingleConvNet(stride=(1,1))   # two-element tuple: works
model_good2 = SingleConvNet(stride=1)      # plain integer: works

input_tensor = torch.randn(1, 1, 28, 28)

The "bad" version is converted to the following. Note that the stride list (%2) is constructed from a single element:

module @module {
  func.func @main(%arg0: !torch.vtensor<[1,1,28,28],f32>) -> !torch.vtensor<[1,1,28,28],f32> attributes {torch.assume_strict_symbolic_shapes} {
    %0 = torch.vtensor.literal(dense_resource<torch_tensor_1_1_3_3_torch.float32> : tensor<1x1x3x3xf32>) : !torch.vtensor<[1,1,3,3],f32>
    %1 = torch.vtensor.literal(dense<-0.158787966> : tensor<1xf32>) : !torch.vtensor<[1],f32>
    %int1 = torch.constant.int 1
    %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
    %int1_0 = torch.constant.int 1
    %int1_1 = torch.constant.int 1
    %3 = torch.prim.ListConstruct %int1_0, %int1_1 : (!torch.int, !torch.int) -> !torch.list<int>
    %int1_2 = torch.constant.int 1
    %int1_3 = torch.constant.int 1
    %4 = torch.prim.ListConstruct %int1_2, %int1_3 : (!torch.int, !torch.int) -> !torch.list<int>
    %int1_4 = torch.constant.int 1
    %5 = torch.aten.conv2d %arg0, %0, %1, %2, %3, %4, %int1_4 : !torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,28,28],f32>
    return %5 : !torch.vtensor<[1,1,28,28],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      torch_tensor_1_1_3_3_torch.float32: "0x040000001B4AD23D566BAABE22FD9B3EF6B5283EAED98A3E36B51DBE2016503D004BE3BB3FEC9B3E"
    }
  }
#-}

The correct ("good") versions, in which the stride list (%2) has two elements, look like this:

module @module {
  func.func @main(%arg0: !torch.vtensor<[1,1,28,28],f32>) -> !torch.vtensor<[1,1,28,28],f32> attributes {torch.assume_strict_symbolic_shapes} {
    %0 = torch.vtensor.literal(dense_resource<torch_tensor_1_1_3_3_torch.float32> : tensor<1x1x3x3xf32>) : !torch.vtensor<[1,1,3,3],f32>
    %1 = torch.vtensor.literal(dense<-0.155213714> : tensor<1xf32>) : !torch.vtensor<[1],f32>
    %int1 = torch.constant.int 1
    %int1_0 = torch.constant.int 1
    %2 = torch.prim.ListConstruct %int1, %int1_0 : (!torch.int, !torch.int) -> !torch.list<int>
    %int1_1 = torch.constant.int 1
    %int1_2 = torch.constant.int 1
    %3 = torch.prim.ListConstruct %int1_1, %int1_2 : (!torch.int, !torch.int) -> !torch.list<int>
    %int1_3 = torch.constant.int 1
    %int1_4 = torch.constant.int 1
    %4 = torch.prim.ListConstruct %int1_3, %int1_4 : (!torch.int, !torch.int) -> !torch.list<int>
    %int1_5 = torch.constant.int 1
    %5 = torch.aten.conv2d %arg0, %0, %1, %2, %3, %4, %int1_5 : !torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,28,28],f32>
    return %5 : !torch.vtensor<[1,1,28,28],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      torch_tensor_1_1_3_3_torch.float32: "0x04000000606EEABDB6239DBDCE38013E60017E3D5BE7B03D982C01BE18EA1CBEDE1241BE8BD746BE"
    }
  }
#-}

I have attached the crash dump (crash_dump.txt).
The crash seems to occur during the ConvertTorchToLinalg stage; the last stage that completed before it is ConvertTorchToTensor.
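
Until the lowering handles one-element stride tuples, one possible workaround is to expand the stride to two explicit values before constructing the layer, so that the exported stride list has two elements as in the "good" IR above. The helper below is only a minimal sketch: `normalize_stride_2d` is a hypothetical name and not part of PyTorch or torch-mlir, and it assumes that a one-element stride is meant to apply to both spatial dimensions. It has not been verified against the crash itself.

```python
from typing import Sequence, Tuple, Union


def normalize_stride_2d(stride: Union[int, Sequence[int]]) -> Tuple[int, int]:
    """Return an explicit (height, width) stride pair.

    Hypothetical helper for illustration only. It assumes a single value
    (an int or a 1-element sequence) applies to both spatial dimensions.
    """
    if isinstance(stride, int):
        return (stride, stride)

    values = tuple(stride)
    if len(values) == 1:
        # The problematic case from this issue: (s,) becomes (s, s).
        return (values[0], values[0])
    if len(values) == 2:
        return (values[0], values[1])

    raise ValueError(
        f"expected an int or a sequence of length 1 or 2, got {stride!r}"
    )
```

With this helper, `SingleConvNet(stride=normalize_stride_2d((1,)))` would be constructed the same way as `model_good`, since the layer then receives `(1, 1)`.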
