From cabbe189f75b6c54b191f5b6597c2251fc222803 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 27 Aug 2025 09:23:48 -0700
Subject: [PATCH 01/19] [torchlib] Improve handling of SymInt[]

Signed-off-by: Justin Chu
---
 .../function_libs/torch_lib/ops/common.py | 27 ++++
 .../function_libs/torch_lib/ops/core.py   | 54 +++++++++----------
 2 files changed, 52 insertions(+), 29 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py
index d7784a5289..d904965c4d 100644
--- a/onnxscript/function_libs/torch_lib/ops/common.py
+++ b/onnxscript/function_libs/torch_lib/ops/common.py
@@ -5,6 +5,8 @@
 # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
 from __future__ import annotations
 
+from collections.abc import Sequence
+
 import numpy.typing as npt
 import onnx
 
@@ -78,3 +80,28 @@ def constant(
         A constant node.
     """
     return op.Constant(value=ir.tensor(array, dtype=ir.DataType(dtype)))
+
+
+def merge_dims(dims: Sequence[int | INT64]) -> INT64:
+    """Merge consecutive constant dimensions."""
+
+    if not dims:
+        return op.Constant(value_ints=[])
+
+    remaining_dims = list(dims)
+    result_dims = []
+
+    while remaining_dims:
+        current_dim = remaining_dims.pop(0)
+        if isinstance(current_dim, int):
+            merged_dims = [current_dim]
+            # Merge consecutive constant dimensions into a constant node
+            while remaining_dims and isinstance(remaining_dims[0], int):
+                merged_dims.append(remaining_dims.pop(0))
+            result_dims.append(op.Constant(value_ints=merged_dims))
+        else:
+            # A dynamic dimension, just append it
+            result_dims.append(current_dim)
+    if len(result_dims) == 1:
+        return result_dims[0]
+    return op.Concat(result_dims, axis=0)
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index ab992e0580..f7cdb761b6 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -1524,9 +1524,9 @@ def aten_broadcast_tensors(tensors: Sequence[TensorType]) -> TensorType:
 
 
 @torch_op("aten::broadcast_to")
-def aten_broadcast_to(self: TTensor, size: INT64) -> TTensor:
+def aten_broadcast_to(self: TTensor, size: Sequence[INT64]) -> TTensor:
     """broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
-
+    size = common_ops.merge_dims(size)
     return op.Expand(self, size)
 
 
@@ -3286,20 +3286,20 @@ def aten_embedding_sparse_backward(
 
 @torch_op("aten::empty.memory_format", trace_only=True)
 def aten_empty(
-    size: IntType,
+    size: Sequence[INT64],
     dtype: int = FLOAT.dtype,
     layout: str = "",
     device: str = "",
     pin_memory: bool = False,
     memory_format: str = "",
 ) -> TensorType:  # type: ignore[type-var]
-    # empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    """empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
     if dtype == -1:
         dtype = FLOAT.dtype
-    # using Zeros to simulate np.empty()
-    size = op.Cast(size, to=INT64.dtype)
-    zero = op.Constant(value_float=0.0)
-    zero = op.Cast(zero, to=dtype)
+
+    # using Zeros to simulate empty()
+    zero = op.Constant(value=ir.tensor(0, dtype=dtype))
+    size = common_ops.merge_dims(size)
     return op.Expand(zero, size)
 
 
@@ -3334,7 +3334,7 @@ def aten_empty_quantized(
 
 @torch_op("aten::empty_strided", trace_only=True)
 def aten_empty_strided(
-    size: INT64,
+    size: Sequence[INT64],
     stride: INT64,
     layout: str = "",
     device: str = "",
@@ -3343,8 +3343,8 @@ def aten_empty_strided(
     # empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
     # using Zeros to simulate empty()
-    size = op.Cast(size, to=INT64.dtype)
-    zero = op.Constant(value_float=0.0)
+    zero = op.Constant(value=ir.tensor(0, dtype=dtype))
+    size = common_ops.merge_dims(size)
     return op.Expand(zero, size)
 
 
@@ -3392,13 +3392,12 @@ def aten_exp2(self: TFloat) -> TFloat:
 
 
 @torch_op("aten::expand", trace_only=True)
-def aten_expand(self: TTensor, size: TInt, implicit: bool = False) -> TTensor:
+def aten_expand(self: TTensor, size: Sequence[INT64], implicit: bool = False) -> TTensor:
     """expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)"""
-    size = op.Cast(size, to=INT64.dtype)
     # NOTE: PyTorch supports `not changing dim` by -1, but ONNX supports `not changing dim` by 1.
     # To support -1 dim, we need to convert -1 to 1.
-    size = op.Abs(size)
-    return op.Expand(self, size)
+    size = [1 if s == -1 else s for s in size]
+    return op.Expand(self, common_ops.merge_dims(size))
 
 
 @torch_op("aten::expand_as", trace_only=True)
@@ -7301,11 +7300,9 @@ def aten_repeat_interleave(
 
 
 @torch_op("aten::reshape")
-def aten_reshape(self: TTensor, shape: IntType) -> TTensor:
+def aten_reshape(self: TTensor, shape: Sequence[INT64]) -> TTensor:
     """reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)"""
-
-    # Reshape only support INT64 as 'shape'
-    shape = op.Cast(shape, to=INT64.dtype)
+    shape = common_ops.merge_dims(shape)
     return op.Reshape(self, shape)
 
 
@@ -9045,19 +9042,18 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType:
 
 
 @torch_op(("aten::view", "aten::_unsafe_view"), trace_only=True)
-def aten_view(self: TTensor, size: IntType) -> TTensor:
+def aten_view(self: TTensor, size: Sequence[INT64]) -> TTensor:
     """view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
 
-    size = op.Cast(size, to=INT64.dtype)  # Reshape only support INT64 as second input
+    size = common_ops.merge_dims(size)
     return op.Reshape(self, size, allowzero=True)
 
 
 @torch_op(("aten::view", "aten::_unsafe_view"), complex=True)
-def aten_view_complex(self: TTensor, size: IntType) -> TTensor:
+def aten_view_complex(self: TTensor, size: Sequence[INT64]) -> TTensor:
     """view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
 
-    size = op.Cast(size, to=INT64.dtype)  # Reshape only support INT64 as second input
-    complex_size = op.Concat(size, op.Constant(value_ints=[2]), axis=0)
+    complex_size = common_ops.merge_dims([*size, 2])
     return op.Reshape(self, complex_size, allowzero=True)
 
 
@@ -9109,7 +9105,7 @@ def aten_view_as_real_copy(self: TTensor) -> TTensor:
 def aten_view_copy(self: TTensor, size: IntType) -> TTensor:
     """view_copy(Tensor self, SymInt[] size) -> Tensor"""
 
-    size = op.Cast(size, to=INT64.dtype)  # Reshape only support INT64 as second input
+    size = common_ops.merge_dims(size)
     return op.Reshape(self, size)
 
 
@@ -9153,7 +9149,7 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType:
 
 @torch_op("aten::zeros", trace_only=True)
 def aten_zeros(
-    size: IntType,
+    size: Sequence[INT64],
     dtype: int = FLOAT.dtype,
     layout: str = "",
     device: str = "",
@@ -9162,9 +9158,9 @@ def aten_zeros(
     """zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
     if dtype == -1:
         dtype = FLOAT.dtype
-    size = op.Cast(size, to=INT64.dtype)
-    zero = op.Constant(value_float=0.0)
-    zero = op.Cast(zero, to=dtype)
+
+    zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
+    size = common_ops.merge_dims(size)
     return op.Expand(zero, size)
 
 

From 5276a98fa35d69909bacda8a2dd35b9cc560bafb Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 27 Aug 2025 09:27:17 -0700
Subject: [PATCH 02/19] update

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/core.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index f7cdb761b6..8676169da0 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -1523,7 +1523,7 @@ def aten_broadcast_tensors(tensors: Sequence[TensorType]) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op("aten::broadcast_to")
+@torch_op("aten::broadcast_to", trace_only=True)
 def aten_broadcast_to(self: TTensor, size: Sequence[INT64]) -> TTensor:
     """broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
     size = common_ops.merge_dims(size)
@@ -7299,7 +7299,7 @@ def aten_repeat_interleave(
     raise NotImplementedError()
 
 
-@torch_op("aten::reshape")
+@torch_op("aten::reshape", trace_only=True)
 def aten_reshape(self: TTensor, shape: Sequence[INT64]) -> TTensor:
     """reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)"""
     shape = common_ops.merge_dims(shape)

From b8da2326cebf152e968fd0c8d04670d49a054117 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 27 Aug 2025 09:30:10 -0700
Subject: [PATCH 03/19] trace where

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/core.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 8676169da0..16c3ab209c 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -9049,7 +9049,7 @@ def aten_view(self: TTensor, size: Sequence[INT64]) -> TTensor:
     return op.Reshape(self, size, allowzero=True)
 
 
-@torch_op(("aten::view", "aten::_unsafe_view"), complex=True)
+@torch_op(("aten::view", "aten::_unsafe_view"), complex=True, trace_only=True)
 def aten_view_complex(self: TTensor, size: Sequence[INT64]) -> TTensor:
     """view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
 
@@ -9057,7 +9057,7 @@ def aten_view_complex(self: TTensor, size: Sequence[INT64]) -> TTensor:
     return op.Reshape(self, complex_size, allowzero=True)
 
 
-@torch_op("aten::view_as")
+@torch_op("aten::view_as", trace_only=True)
 def aten_view_as(self: TTensor, other: TTensor2) -> TTensor:
     """view_as(Tensor(a) self, Tensor other) -> Tensor(a)"""
 
@@ -9101,7 +9101,7 @@ def aten_view_as_real_copy(self: TTensor) -> TTensor:
     return op.Identity(self)
 
 
-@torch_op("aten::view_copy")
+@torch_op("aten::view_copy", trace_only=True)
 def aten_view_copy(self: TTensor, size: IntType) -> TTensor:
     """view_copy(Tensor self, SymInt[] size) -> Tensor"""
 
@@ -9133,7 +9133,8 @@ def reshape_to_2d(tensor):
         "aten::where.ScalarSelf",
         "aten::where.ScalarOther",
         "aten::where.self",
-    )
+    ),
+    trace_only=True,
 )
 def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor:
     """where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"""

From d9196f4cf6aa703574c1a72acd808274695b124b Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 27 Aug 2025 09:37:24 -0700
Subject: [PATCH 04/19] fix

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/common.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py
index d904965c4d..72ad51d602 100644
--- a/onnxscript/function_libs/torch_lib/ops/common.py
+++ b/onnxscript/function_libs/torch_lib/ops/common.py
@@ -104,4 +104,8 @@ def merge_dims(dims: Sequence[int | INT64]) -> INT64:
             result_dims.append(current_dim)
     if len(result_dims) == 1:
         return result_dims[0]
-    return op.Concat(result_dims, axis=0)
+
+    # Set the output type to INT64 so op.Concat can be used
+    for dim in result_dims:
+        dim.dtype = ir.DataType.INT64
+    return op.Concat(*result_dims, axis=0)

From defa3566f1070bc0545dc4296b405b019d1db05a Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 27 Aug 2025 09:49:47 -0700
Subject: [PATCH 05/19] Use reshape for symsize

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/common.py | 3 ++-
 onnxscript/function_libs/torch_lib/ops/core.py   | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py
index 72ad51d602..bc8e3a3348 100644
--- a/onnxscript/function_libs/torch_lib/ops/common.py
+++ b/onnxscript/function_libs/torch_lib/ops/common.py
@@ -100,7 +100,8 @@ def merge_dims(dims: Sequence[int | INT64]) -> INT64:
                 merged_dims.append(remaining_dims.pop(0))
             result_dims.append(op.Constant(value_ints=merged_dims))
         else:
-            # A dynamic dimension, just append it
+            # A dynamic dimension, unsqueeze and append it
+            current_dim = op.Reshape(current_dim, op.Constant(value_ints=[]))
             result_dims.append(current_dim)
     if len(result_dims) == 1:
         return result_dims[0]
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 8676169da0..6a8d79520c 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -8290,7 +8290,7 @@ def aten_swapdims(self: TensorType, dim0: int, dim1: int) -> TensorType:
 @torch_op("aten::sym_size.int", trace_only=True)
 def aten_sym_size(self: TensorType, dim: int = 0) -> INT64:
     """sym_size.int(Tensor self, int dim) -> SymInt"""
-    return op.Squeeze(op.Shape(self, end=dim + 1, start=dim))
+    return op.Reshape(op.Shape(self, end=dim + 1, start=dim), op.Constant(value_ints=[1]))
 
 
 def aten_symeig(

From 3bc3ebf0d7941efa6570fdad1b4c61e7d05c7a03 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 27 Aug 2025 09:52:16 -0700
Subject: [PATCH 06/19] attrs

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/common.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py
index bc8e3a3348..be5c31e2e1 100644
--- a/onnxscript/function_libs/torch_lib/ops/common.py
+++ b/onnxscript/function_libs/torch_lib/ops/common.py
@@ -86,7 +86,7 @@ def merge_dims(dims: Sequence[int | INT64]) -> INT64:
     """Merge consecutive constant dimensions."""
 
     if not dims:
-        return op.Constant(value_ints=[])
+        return op.Constant(value_ints=ir.AttrInt64([]))
 
     remaining_dims = list(dims)
     result_dims = []
@@ -101,7 +101,7 @@ def merge_dims(dims: Sequence[int | INT64]) -> INT64:
             result_dims.append(op.Constant(value_ints=merged_dims))
         else:
             # A dynamic dimension, unsqueeze and append it
-            current_dim = op.Reshape(current_dim, op.Constant(value_ints=[]))
+            current_dim = op.Reshape(current_dim, op.Constant(value_ints=ir.AttrInt64([])))
             result_dims.append(current_dim)
     if len(result_dims) == 1:
         return result_dims[0]

From a3eaece97e0b559705c88fec83297fff4bf28f11 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 27 Aug 2025 10:04:43 -0700
Subject: [PATCH 07/19] fix

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/common.py | 6 ++++--
 onnxscript/function_libs/torch_lib/ops/core.py   | 5 +++--
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py
index be5c31e2e1..ef211704aa 100644
--- a/onnxscript/function_libs/torch_lib/ops/common.py
+++ b/onnxscript/function_libs/torch_lib/ops/common.py
@@ -86,7 +86,7 @@ def merge_dims(dims: Sequence[int | INT64]) -> INT64:
     """Merge consecutive constant dimensions."""
 
     if not dims:
-        return op.Constant(value_ints=ir.AttrInt64([]))
+        return op.Constant(value_ints=ir.AttrInt64("value_ints", []))
 
     remaining_dims = list(dims)
     result_dims = []
@@ -101,7 +101,9 @@ def merge_dims(dims: Sequence[int | INT64]) -> INT64:
             result_dims.append(op.Constant(value_ints=merged_dims))
         else:
             # A dynamic dimension, unsqueeze and append it
-            current_dim = op.Reshape(current_dim, op.Constant(value_ints=ir.AttrInt64([])))
+            current_dim = op.Reshape(
+                current_dim, op.Constant(value_ints=ir.AttrInt64("value_ints", []))
+            )
             result_dims.append(current_dim)
     if len(result_dims) == 1:
         return result_dims[0]
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 6a8d79520c..152b09178c 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3337,6 +3337,7 @@ def aten_empty_strided(
     size: Sequence[INT64],
     stride: INT64,
     layout: str = "",
+    dtype: int = FLOAT.dtype,
     device: str = "",
     pin_memory: bool = False,
 ) -> TTensor:  # type: ignore[type-var]
@@ -3396,7 +3397,7 @@ def aten_expand(self: TTensor, size: Sequence[INT64], implicit: bool = False) ->
     """expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)"""
     # NOTE: PyTorch supports `not changing dim` by -1, but ONNX supports `not changing dim` by 1.
     # To support -1 dim, we need to convert -1 to 1.
-    size = [1 if s == -1 else s for s in size]
+    size = [1 if isinstance(s, int) and s == -1 else s for s in size]
     return op.Expand(self, common_ops.merge_dims(size))
 
 
@@ -9102,7 +9103,7 @@ def aten_view_as_real_copy(self: TTensor) -> TTensor:
 
 
 @torch_op("aten::view_copy", trace_only=True)
-def aten_view_copy(self: TTensor, size: IntType) -> TTensor:
+def aten_view_copy(self: TTensor, size: Sequence[INT64]) -> TTensor:
     """view_copy(Tensor self, SymInt[] size) -> Tensor"""
 
     size = common_ops.merge_dims(size)

From c8ef89cd9aa627aa868c4fedea2e161989d57df6 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 27 Aug 2025 10:12:23 -0700
Subject: [PATCH 08/19] fix symsize

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/core.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 152b09178c..14418f5546 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -8291,7 +8291,10 @@ def aten_swapdims(self: TensorType, dim0: int, dim1: int) -> TensorType:
 @torch_op("aten::sym_size.int", trace_only=True)
 def aten_sym_size(self: TensorType, dim: int = 0) -> INT64:
     """sym_size.int(Tensor self, int dim) -> SymInt"""
-    return op.Reshape(op.Shape(self, end=dim + 1, start=dim), op.Constant(value_ints=[1]))
+    return op.Reshape(
+        op.Shape(self, end=dim + 1, start=dim),
+        op.Constant(value=ir.tensor([], dtype=ir.DataType.INT64)),
+    )
 
 
 def aten_symeig(

From fd383988deb830b01afc9c2f60788ed326931c2d Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 27 Aug 2025 10:25:01 -0700
Subject: [PATCH 09/19] Use gather

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/core.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 14418f5546..68f65eecd5 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -8291,10 +8291,8 @@ def aten_swapdims(self: TensorType, dim0: int, dim1: int) -> TensorType:
 @torch_op("aten::sym_size.int", trace_only=True)
 def aten_sym_size(self: TensorType, dim: int = 0) -> INT64:
     """sym_size.int(Tensor self, int dim) -> SymInt"""
-    return op.Reshape(
-        op.Shape(self, end=dim + 1, start=dim),
-        op.Constant(value=ir.tensor([], dtype=ir.DataType.INT64)),
-    )
+
+    return op.Gather(op.Shape(self), dim, axis=0)
 
 
 def aten_symeig(

From 3b406aac8b7c26bc905b1cd11728550a402e04e3 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 27 Aug 2025 10:33:34 -0700
Subject: [PATCH 10/19] fix shape

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py
index ef211704aa..b6f9c8665d 100644
--- a/onnxscript/function_libs/torch_lib/ops/common.py
+++ b/onnxscript/function_libs/torch_lib/ops/common.py
@@ -102,7 +102,7 @@ def merge_dims(dims: Sequence[int | INT64]) -> INT64:
         else:
             # A dynamic dimension, unsqueeze and append it
             current_dim = op.Reshape(
-                current_dim, op.Constant(value_ints=ir.AttrInt64("value_ints", []))
+                current_dim, op.Constant(value_ints=ir.AttrInt64("value_ints", [1]))
             )
             result_dims.append(current_dim)
     if len(result_dims) == 1:

From 7ac7b931826fa2ee4aa47fce613d3fabd6ddddd4 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 27 Aug 2025 10:41:25 -0700
Subject: [PATCH 11/19] type

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/common.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py
index b6f9c8665d..61cfaed95c 100644
--- a/onnxscript/function_libs/torch_lib/ops/common.py
+++ b/onnxscript/function_libs/torch_lib/ops/common.py
@@ -86,7 +86,7 @@ def merge_dims(dims: Sequence[int | INT64]) -> INT64:
     """Merge consecutive constant dimensions."""
 
     if not dims:
-        return op.Constant(value_ints=ir.AttrInt64("value_ints", []))
+        return op.Constant(value_ints=ir.AttrInt64s("value_ints", []))
 
     remaining_dims = list(dims)
     result_dims = []
@@ -102,7 +102,7 @@ def merge_dims(dims: Sequence[int | INT64]) -> INT64:
         else:
             # A dynamic dimension, unsqueeze and append it
             current_dim = op.Reshape(
-                current_dim, op.Constant(value_ints=ir.AttrInt64("value_ints", [1]))
+                current_dim, op.Constant(value_ints=ir.AttrInt64s("value_ints", [1]))
             )
             result_dims.append(current_dim)
     if len(result_dims) == 1:

From c1bc9aea1ec5dc7e4866464f2891c7e86520b5d9 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 27 Aug 2025 11:19:28 -0700
Subject: [PATCH 12/19] fix dtype

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/core.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 68f65eecd5..3c65fdbf6d 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3298,7 +3298,7 @@ def aten_empty(
         dtype = FLOAT.dtype
 
     # using Zeros to simulate empty()
-    zero = op.Constant(value=ir.tensor(0, dtype=dtype))
+    zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
     size = common_ops.merge_dims(size)
     return op.Expand(zero, size)
 
@@ -3344,7 +3344,7 @@ def aten_empty_strided(
     # empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
     # using Zeros to simulate empty()
-    zero = op.Constant(value=ir.tensor(0, dtype=dtype))
+    zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
     size = common_ops.merge_dims(size)
     return op.Expand(zero, size)
 

From c022c6f8d57438158f40a12806641a911e05554a Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 27 Aug 2025 11:37:33 -0700
Subject: [PATCH 13/19] Simplify merge_dims

Signed-off-by: Justin Chu
---
 .../function_libs/torch_lib/ops/common.py | 27 ++++++-------------
 1 file changed, 8 insertions(+), 19 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py
index 61cfaed95c..4ae8725e76 100644
--- a/onnxscript/function_libs/torch_lib/ops/common.py
+++ b/onnxscript/function_libs/torch_lib/ops/common.py
@@ -88,25 +88,14 @@ def merge_dims(dims: Sequence[int | INT64]) -> INT64:
     if not dims:
         return op.Constant(value_ints=ir.AttrInt64s("value_ints", []))
 
-    remaining_dims = list(dims)
-    result_dims = []
-
-    while remaining_dims:
-        current_dim = remaining_dims.pop(0)
-        if isinstance(current_dim, int):
-            merged_dims = [current_dim]
-            # Merge consecutive constant dimensions into a constant node
-            while remaining_dims and isinstance(remaining_dims[0], int):
-                merged_dims.append(remaining_dims.pop(0))
-            result_dims.append(op.Constant(value_ints=merged_dims))
-        else:
-            # A dynamic dimension, unsqueeze and append it
-            current_dim = op.Reshape(
-                current_dim, op.Constant(value_ints=ir.AttrInt64s("value_ints", [1]))
-            )
-            result_dims.append(current_dim)
-    if len(result_dims) == 1:
-        return result_dims[0]
+    one_1d = op.Constant(value_ints=ir.AttrInt64s("value_ints", [1]))
+
+    dims = [
+        op.Constant(value_ints=[current_dim])
+        if isinstance(current_dim, int)
+        else op.Reshape(current_dim, one_1d)
+        for current_dim in dims
+    ]
 
     # Set the output type to INT64 so op.Concat can be used
     for dim in result_dims:

From 45f58e3195a48db7e25ce029093fb08631ea1153 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 27 Aug 2025 11:38:11 -0700
Subject: [PATCH 14/19] s

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/common.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py
index 4ae8725e76..fb2801f7a1 100644
--- a/onnxscript/function_libs/torch_lib/ops/common.py
+++ b/onnxscript/function_libs/torch_lib/ops/common.py
@@ -91,10 +91,8 @@ def merge_dims(dims: Sequence[int | INT64]) -> INT64:
     one_1d = op.Constant(value_ints=ir.AttrInt64s("value_ints", [1]))
 
     dims = [
-        op.Constant(value_ints=[current_dim])
-        if isinstance(current_dim, int)
-        else op.Reshape(current_dim, one_1d)
-        for current_dim in dims
+        op.Constant(value_ints=[d]) if isinstance(d, int) else op.Reshape(d, one_1d)
+        for d in dims
     ]
 
     # Set the output type to INT64 so op.Concat can be used

From 39e3d44b85f25a79feac29e86432729911c1a5b4 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 27 Aug 2025 11:39:18 -0700
Subject: [PATCH 15/19] Concatenate

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py
index fb2801f7a1..8fabea6861 100644
--- a/onnxscript/function_libs/torch_lib/ops/common.py
+++ b/onnxscript/function_libs/torch_lib/ops/common.py
@@ -83,7 +83,7 @@ def constant(
 
 
 def merge_dims(dims: Sequence[int | INT64]) -> INT64:
-    """Merge consecutive constant dimensions."""
+    """Concatenate dimensions into a single value."""
 
     if not dims:
         return op.Constant(value_ints=ir.AttrInt64s("value_ints", []))

From 926b1acc43c2a28a678c44f60ce4160a5f54fcd4 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 27 Aug 2025 11:39:52 -0700
Subject: [PATCH 16/19] result_dims

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py
index 8fabea6861..19d52bce1b 100644
--- a/onnxscript/function_libs/torch_lib/ops/common.py
+++ b/onnxscript/function_libs/torch_lib/ops/common.py
@@ -90,7 +90,7 @@ def merge_dims(dims: Sequence[int | INT64]) -> INT64:
 
     one_1d = op.Constant(value_ints=ir.AttrInt64s("value_ints", [1]))
 
-    dims = [
+    result_dims = [
         op.Constant(value_ints=[d]) if isinstance(d, int) else op.Reshape(d, one_1d)
         for d in dims
     ]

From c32880b7059e5b777aa4a2aed5fa238eb30d8565 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 28 Aug 2025 12:20:17 -0700
Subject: [PATCH 17/19] Revert sym_size

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/core.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 3c65fdbf6d..7016cf36f4 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -8291,8 +8291,7 @@ def aten_swapdims(self: TensorType, dim0: int, dim1: int) -> TensorType:
 @torch_op("aten::sym_size.int", trace_only=True)
 def aten_sym_size(self: TensorType, dim: int = 0) -> INT64:
     """sym_size.int(Tensor self, int dim) -> SymInt"""
-
-    return op.Gather(op.Shape(self), dim, axis=0)
+    return op.Squeeze(op.Shape(self, end=dim + 1, start=dim))
 
 
 def aten_symeig(

From baee64b8798b3c66ae199f9c9e8af211463e1328 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 28 Aug 2025 12:21:20 -0700
Subject: [PATCH 18/19] docs

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/core.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 7016cf36f4..e79383f56a 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3397,6 +3397,8 @@ def aten_expand(self: TTensor, size: Sequence[INT64], implicit: bool = False) ->
     """expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)"""
     # NOTE: PyTorch supports `not changing dim` by -1, but ONNX supports `not changing dim` by 1.
     # To support -1 dim, we need to convert -1 to 1.
+    # Even though in theory a dynamic dim can still be -1, in practice it is very unlikely
+    # and isn't expected to appear from correct usages of SymInt.
     size = [1 if isinstance(s, int) and s == -1 else s for s in size]
     return op.Expand(self, common_ops.merge_dims(size))
 

From 4334c37c483cc05fead668015345cc861df18f3c Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 3 Sep 2025 16:00:42 -0700
Subject: [PATCH 19/19] neg_one_1d

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/common.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py
index 19d52bce1b..b3ebbc1c53 100644
--- a/onnxscript/function_libs/torch_lib/ops/common.py
+++ b/onnxscript/function_libs/torch_lib/ops/common.py
@@ -88,10 +88,10 @@ def merge_dims(dims: Sequence[int | INT64]) -> INT64:
     if not dims:
         return op.Constant(value_ints=ir.AttrInt64s("value_ints", []))
 
-    one_1d = op.Constant(value_ints=ir.AttrInt64s("value_ints", [1]))
+    neg_one_1d = op.Constant(value_ints=ir.AttrInt64s("value_ints", [-1]))
 
     result_dims = [
-        op.Constant(value_ints=[d]) if isinstance(d, int) else op.Reshape(d, one_1d)
+        op.Constant(value_ints=[d]) if isinstance(d, int) else op.Reshape(d, neg_one_1d)
         for d in dims
     ]
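
A minimal usage sketch of the merge_dims helper as it stands after PATCH 19,
showing how a trace-only torch_lib op is expected to call it. This is
illustrative only: the tensor value `x` and the dynamic INT64 scalar
`dynamic_dim` are hypothetical stand-ins, while `op` and `common_ops` are the
graph-building context and helper module already used throughout torch_lib.

    from onnxscript.function_libs.torch_lib.ops import common as common_ops

    # Each Python int becomes a 1-D Constant node, and each dynamic INT64
    # scalar is reshaped to 1-D via the shared [-1] constant; the pieces are
    # typed as INT64 and concatenated into a single shape tensor that can
    # feed ops such as Reshape or Expand.
    shape = common_ops.merge_dims([2, 3, dynamic_dim, 4])  # dynamic_dim: hypothetical INT64 scalar
    reshaped = op.Reshape(x, shape)  # x: hypothetical input tensor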