From d2935421b8d9e16e7c17b73b210c60430b57f3bf Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Sun, 25 May 2025 08:23:43 -0700
Subject: [PATCH 1/4] [torchlib] Update linear implementation to support 1d weights

It is possible when users call `F.linear()` directly in PyTorch.
---
 onnxscript/function_libs/torch_lib/ops/nn.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 4a607e75bd..c5291fe54d 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -828,7 +828,12 @@ def aten_linear(input: TFloat, weight: TFloat, bias: Optional[TFloat] = None) ->
     if len(input.shape) == 2:
         # Use Gemm for the rank 2 input
         return op.Gemm(input, weight, bias, transB=True)
-    weight_transposed = op.Transpose(weight, perm=[1, 0])
+    if len(weight.shape) == 1:
+        # In rare cases the weight can be 1d
+        weight_transposed = op.Unsqueeze(weight, [1])
+    else:
+        assert len(weight.shape) == 2
+        weight_transposed = op.Transpose(weight, perm=[1, 0])
     mul = op.MatMul(input, weight_transposed)
     if bias is None:
         return mul

From d0f7385c7e5b8be714191461a52d3bec55e04b4f Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Sun, 25 May 2025 08:25:56 -0700
Subject: [PATCH 2/4] Update nn.py

---
 onnxscript/function_libs/torch_lib/ops/nn.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index c5291fe54d..a877fea889 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -825,15 +825,11 @@ def aten_leaky_relu_backward(
 def aten_linear(input: TFloat, weight: TFloat, bias: Optional[TFloat] = None) -> TFloat:
     """linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"""
 
-    if len(input.shape) == 2:
+    if len(input.shape) == 2 and len(weight.shape) == 2:
         # Use Gemm for the rank 2 input
         return op.Gemm(input, weight, bias, transB=True)
-    if len(weight.shape) == 1:
-        # In rare cases the weight can be 1d
-        weight_transposed = op.Unsqueeze(weight, [1])
-    else:
-        assert len(weight.shape) == 2
-        weight_transposed = op.Transpose(weight, perm=[1, 0])
+
+    weight_transposed = op.Transpose(weight, perm=[1, 0])
     mul = op.MatMul(input, weight_transposed)
     if bias is None:
         return mul

From 499b7a523a74800abad336e1f1a89e3243e76048 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Sun, 25 May 2025 08:27:49 -0700
Subject: [PATCH 3/4] Update nn.py

---
 onnxscript/function_libs/torch_lib/ops/nn.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index a877fea889..20f26243ea 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -828,8 +828,12 @@ def aten_linear(input: TFloat, weight: TFloat, bias: Optional[TFloat] = None) ->
     if len(input.shape) == 2 and len(weight.shape) == 2:
         # Use Gemm for the rank 2 input
         return op.Gemm(input, weight, bias, transB=True)
-
-    weight_transposed = op.Transpose(weight, perm=[1, 0])
+    if len(weight.shape) == 1:
+        # In rare cases the weight can be 1d
+        weight_transposed = op.Unsqueeze(weight, [1])
+    else:
+        assert len(weight.shape) == 2:
+        weight_transposed = op.Transpose(weight, perm=[1, 0])
     mul = op.MatMul(input, weight_transposed)
     if bias is None:
         return mul

From 777398520a7bbde30bca76a65649337c3efdcf79 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Sun, 25 May 2025 08:38:32 -0700
Subject: [PATCH 4/4] Update onnxscript/function_libs/torch_lib/ops/nn.py

---
 onnxscript/function_libs/torch_lib/ops/nn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 20f26243ea..49ae325698 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -832,7 +832,7 @@ def aten_linear(input: TFloat, weight: TFloat, bias: Optional[TFloat] = None) ->
         # In rare cases the weight can be 1d
         weight_transposed = op.Unsqueeze(weight, [1])
     else:
-        assert len(weight.shape) == 2:
+        assert len(weight.shape) == 2
         weight_transposed = op.Transpose(weight, perm=[1, 0])
     mul = op.MatMul(input, weight_transposed)
     if bias is None:
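
For reference, a minimal PyTorch-only sketch (not part of the patches) of why a 1d weight needs an Unsqueeze rather than a Transpose before MatMul. The tensors `x`, `w1d`, and `w2d` are illustrative; shapes follow `torch.nn.functional.linear`, which accepts a weight of shape `(out_features, in_features)` or, in the 1d case, `(in_features,)`.

# Illustration (not part of the patch series); assumes only PyTorch is installed.
import torch
import torch.nn.functional as F

x = torch.randn(3, 4, 5)   # rank-3 input; last dim is in_features
w1d = torch.randn(5)       # 1d weight: a single output feature, shape (in_features,)
w2d = torch.randn(7, 5)    # usual 2d weight: shape (out_features, in_features)

# F.linear with a 1d weight contracts the last dimension and returns shape (3, 4).
expected = F.linear(x, w1d)

# Equivalent of the exported graph: unsqueeze the weight to a column (5, 1),
# then MatMul, giving (3, 4, 1); the values match the PyTorch result.
via_matmul = x @ w1d.unsqueeze(1)
assert torch.allclose(expected, via_matmul.squeeze(-1))

# With a 2d weight the graph transposes to (in_features, out_features) before MatMul.
assert torch.allclose(F.linear(x, w2d), x @ w2d.transpose(0, 1))

The Gemm fast path in the final code requires both input and weight to be rank 2, since ONNX Gemm only accepts 2d operands; hence the `len(input.shape) == 2 and len(weight.shape) == 2` guard introduced in patch 2.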