diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 4a607e75bd..49ae325698 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -825,10 +825,15 @@ def aten_leaky_relu_backward( def aten_linear(input: TFloat, weight: TFloat, bias: Optional[TFloat] = None) -> TFloat: """linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor""" - if len(input.shape) == 2: + if len(input.shape) == 2 and len(weight.shape) == 2: # Use Gemm for the rank 2 input return op.Gemm(input, weight, bias, transB=True) - weight_transposed = op.Transpose(weight, perm=[1, 0]) + if len(weight.shape) == 1: + # In rare cases the weight can be 1d + weight_transposed = op.Unsqueeze(weight, [1]) + else: + assert len(weight.shape) == 2 + weight_transposed = op.Transpose(weight, perm=[1, 0]) mul = op.MatMul(input, weight_transposed) if bias is None: return mul