PyTorch allows developers to pass the stride of a Conv2d layer as an integer or as a list/tuple.
It accepts tuples of length one or two. If the tuple contains two values, the lowering to linalg works as expected.
However, if the tuple contains a single value, the lowering step to linalg fails and torch-mlir crashes.
Minimal example:
import torch
import torch.nn as nn

class SingleConvNet(nn.Module):
    def __init__(self, stride=1):
        super(SingleConvNet, self).__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=stride, padding=1)

    def forward(self, x):
        x = self.conv(x)
        return x

# Example usage:
model_bad = SingleConvNet(stride=(1,))
model_good = SingleConvNet(stride=(1, 1))
model_good2 = SingleConvNet(stride=1)
input_tensor = torch.randn(1, 1, 28, 28)

The "bad" version gets converted to:
module @module {
func.func @main(%arg0: !torch.vtensor<[1,1,28,28],f32>) -> !torch.vtensor<[1,1,28,28],f32> attributes {torch.assume_strict_symbolic_shapes} {
%0 = torch.vtensor.literal(dense_resource<torch_tensor_1_1_3_3_torch.float32> : tensor<1x1x3x3xf32>) : !torch.vtensor<[1,1,3,3],f32>
%1 = torch.vtensor.literal(dense<-0.158787966> : tensor<1xf32>) : !torch.vtensor<[1],f32>
%int1 = torch.constant.int 1
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%int1_0 = torch.constant.int 1
%int1_1 = torch.constant.int 1
%3 = torch.prim.ListConstruct %int1_0, %int1_1 : (!torch.int, !torch.int) -> !torch.list<int>
%int1_2 = torch.constant.int 1
%int1_3 = torch.constant.int 1
%4 = torch.prim.ListConstruct %int1_2, %int1_3 : (!torch.int, !torch.int) -> !torch.list<int>
%int1_4 = torch.constant.int 1
%5 = torch.aten.conv2d %arg0, %0, %1, %2, %3, %4, %int1_4 : !torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,28,28],f32>
return %5 : !torch.vtensor<[1,1,28,28],f32>
}
}
{-#
dialect_resources: {
builtin: {
torch_tensor_1_1_3_3_torch.float32: "0x040000001B4AD23D566BAABE22FD9B3EF6B5283EAED98A3E36B51DBE2016503D004BE3BB3FEC9B3E"
}
}
#-}
while the correct/"good" versions look like this:
module @module {
func.func @main(%arg0: !torch.vtensor<[1,1,28,28],f32>) -> !torch.vtensor<[1,1,28,28],f32> attributes {torch.assume_strict_symbolic_shapes} {
%0 = torch.vtensor.literal(dense_resource<torch_tensor_1_1_3_3_torch.float32> : tensor<1x1x3x3xf32>) : !torch.vtensor<[1,1,3,3],f32>
%1 = torch.vtensor.literal(dense<-0.155213714> : tensor<1xf32>) : !torch.vtensor<[1],f32>
%int1 = torch.constant.int 1
%int1_0 = torch.constant.int 1
%2 = torch.prim.ListConstruct %int1, %int1_0 : (!torch.int, !torch.int) -> !torch.list<int>
%int1_1 = torch.constant.int 1
%int1_2 = torch.constant.int 1
%3 = torch.prim.ListConstruct %int1_1, %int1_2 : (!torch.int, !torch.int) -> !torch.list<int>
%int1_3 = torch.constant.int 1
%int1_4 = torch.constant.int 1
%4 = torch.prim.ListConstruct %int1_3, %int1_4 : (!torch.int, !torch.int) -> !torch.list<int>
%int1_5 = torch.constant.int 1
%5 = torch.aten.conv2d %arg0, %0, %1, %2, %3, %4, %int1_5 : !torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,28,28],f32>
return %5 : !torch.vtensor<[1,1,28,28],f32>
}
}
{-#
dialect_resources: {
builtin: {
torch_tensor_1_1_3_3_torch.float32: "0x04000000606EEABDB6239DBDCE38013E60017E3D5BE7B03D982C01BE18EA1CBEDE1241BE8BD746BE"
}
}
#-}
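For reference, an export roughly along these lines reproduces the IR above (a minimal sketch; it assumes the fx importer entry point torch_mlir.fx.export_and_import is available in the build being used):

from torch_mlir import fx

# Assumption: torch_mlir.fx.export_and_import exists in this build.
# Importing to the torch dialect succeeds for all three models; the crash
# only appears when lowering the "bad" module further to linalg.
module_bad = fx.export_and_import(model_bad, input_tensor)
print(module_bad)

# Requesting linalg directly runs the failing ConvertTorchToLinalg pass:
# fx.export_and_import(model_bad, input_tensor, output_type="linalg-on-tensors")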
I attached the crash dump.
The crash seems to occur during the ConvertTorchToLinalg stage; the last stage that completes before it is ConvertTorchToTensor.
crash_dump.txt
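Until this is fixed, a possible user-side workaround is to normalize single-element stride tuples to the two-element form before building the module, so the importer emits a two-element torch.prim.ListConstruct for the stride (sketch below; expand_pair is just an illustrative helper, not part of any API):

def expand_pair(v):
    # Illustrative helper: expand a 1-element tuple/list to the 2-element form.
    if isinstance(v, (tuple, list)) and len(v) == 1:
        return (v[0], v[0])
    return v

model_workaround = SingleConvNet(stride=expand_pair((1,)))  # equivalent to stride=(1, 1)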