Skip to content

Commit fc792e4

Browse files
authored
[torchlib] Improve handling of SymInt[] (#2522)
Previously sizes coming in as `SymInt[]` are first concatenated as INT64 then used. This created inefficiencies where we could not process any static dims from the size list and had to treat the whole shape as dynamic. In aten_expand, this meant we needed to add `Abs` on the shape. This change updates the functions that take `SymInt[]` such that they are no longer turned into INT64 first. I updated aten_expand to process constant `-1` values so an `Abs` is not required. I also added a helper `merge_dims` to create constants for consecutive constant dims first before concatinating. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 456a6bc commit fc792e4

File tree

2 files changed

+57
-36
lines changed

2 files changed

+57
-36
lines changed

onnxscript/function_libs/torch_lib/ops/common.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
66
from __future__ import annotations
77

8+
from collections.abc import Sequence
9+
810
import numpy.typing as npt
911
import onnx
1012

@@ -78,3 +80,22 @@ def constant(
7880
A constant node.
7981
"""
8082
return op.Constant(value=ir.tensor(array, dtype=ir.DataType(dtype)))
83+
84+
85+
def merge_dims(dims: Sequence[int | INT64]) -> INT64:
86+
"""Concatenate dimensions into a single value."""
87+
88+
if not dims:
89+
return op.Constant(value_ints=ir.AttrInt64s("value_ints", []))
90+
91+
neg_one_1d = op.Constant(value_ints=ir.AttrInt64s("value_ints", [-1]))
92+
93+
result_dims = [
94+
op.Constant(value_ints=[d]) if isinstance(d, int) else op.Reshape(d, neg_one_1d)
95+
for d in dims
96+
]
97+
98+
# Set the output type to INT64 so op.Concat can be used
99+
for dim in result_dims:
100+
dim.dtype = ir.DataType.INT64
101+
return op.Concat(*result_dims, axis=0)

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1523,10 +1523,10 @@ def aten_broadcast_tensors(tensors: Sequence[TensorType]) -> TensorType:
15231523
raise NotImplementedError()
15241524

15251525

1526-
@torch_op("aten::broadcast_to")
1527-
def aten_broadcast_to(self: TTensor, size: INT64) -> TTensor:
1526+
@torch_op("aten::broadcast_to", trace_only=True)
1527+
def aten_broadcast_to(self: TTensor, size: Sequence[INT64]) -> TTensor:
15281528
"""broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
1529-
1529+
size = common_ops.merge_dims(size)
15301530
return op.Expand(self, size)
15311531

15321532

@@ -3286,20 +3286,20 @@ def aten_embedding_sparse_backward(
32863286

32873287
@torch_op("aten::empty.memory_format", trace_only=True)
32883288
def aten_empty(
3289-
size: IntType,
3289+
size: Sequence[INT64],
32903290
dtype: int = FLOAT.dtype,
32913291
layout: str = "",
32923292
device: str = "",
32933293
pin_memory: bool = False,
32943294
memory_format: str = "",
32953295
) -> TensorType: # type: ignore[type-var]
3296-
# empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
3296+
"""empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
32973297
if dtype == -1:
32983298
dtype = FLOAT.dtype
3299-
# using Zeros to simulate np.empty()
3300-
size = op.Cast(size, to=INT64.dtype)
3301-
zero = op.Constant(value_float=0.0)
3302-
zero = op.Cast(zero, to=dtype)
3299+
3300+
# using Zeros to simulate empty()
3301+
zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
3302+
size = common_ops.merge_dims(size)
33033303

33043304
return op.Expand(zero, size)
33053305

@@ -3334,17 +3334,18 @@ def aten_empty_quantized(
33343334

33353335
@torch_op("aten::empty_strided", trace_only=True)
33363336
def aten_empty_strided(
3337-
size: INT64,
3337+
size: Sequence[INT64],
33383338
stride: INT64,
33393339
layout: str = "",
3340+
dtype: int = FLOAT.dtype,
33403341
device: str = "",
33413342
pin_memory: bool = False,
33423343
) -> TTensor: # type: ignore[type-var]
33433344
# empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
33443345

33453346
# using Zeros to simulate empty()
3346-
size = op.Cast(size, to=INT64.dtype)
3347-
zero = op.Constant(value_float=0.0)
3347+
zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
3348+
size = common_ops.merge_dims(size)
33483349

33493350
return op.Expand(zero, size)
33503351

@@ -3392,13 +3393,14 @@ def aten_exp2(self: TFloat) -> TFloat:
33923393

33933394

33943395
@torch_op("aten::expand", trace_only=True)
3395-
def aten_expand(self: TTensor, size: TInt, implicit: bool = False) -> TTensor:
3396+
def aten_expand(self: TTensor, size: Sequence[INT64], implicit: bool = False) -> TTensor:
33963397
"""expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)"""
3397-
size = op.Cast(size, to=INT64.dtype)
33983398
# NOTE: PyTorch supports `not changing dim` by -1, but ONNX supports `not changing dim` by 1.
33993399
# To support -1 dim, we need to convert -1 to 1.
3400-
size = op.Abs(size)
3401-
return op.Expand(self, size)
3400+
# Even though in theory a dynamic dim can still be -1, in practice it is very unlikely
3401+
# and isn't expected to appear from correct usages of SymInt.
3402+
size = [1 if isinstance(s, int) and s == -1 else s for s in size]
3403+
return op.Expand(self, common_ops.merge_dims(size))
34023404

34033405

34043406
@torch_op("aten::expand_as", trace_only=True)
@@ -7409,12 +7411,10 @@ def aten_repeat_interleave_Tensor(
74097411
)
74107412

74117413

7412-
@torch_op("aten::reshape")
7413-
def aten_reshape(self: TTensor, shape: IntType) -> TTensor:
7414+
@torch_op("aten::reshape", trace_only=True)
7415+
def aten_reshape(self: TTensor, shape: Sequence[INT64]) -> TTensor:
74147416
"""reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)"""
7415-
7416-
# Reshape only support INT64 as 'shape'
7417-
shape = op.Cast(shape, to=INT64.dtype)
7417+
shape = common_ops.merge_dims(shape)
74187418
return op.Reshape(self, shape)
74197419

74207420

@@ -9153,23 +9153,22 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType:
91539153

91549154

91559155
@torch_op(("aten::view", "aten::_unsafe_view"), trace_only=True)
9156-
def aten_view(self: TTensor, size: IntType) -> TTensor:
9156+
def aten_view(self: TTensor, size: Sequence[INT64]) -> TTensor:
91579157
"""view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
91589158

9159-
size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input
9159+
size = common_ops.merge_dims(size)
91609160
return op.Reshape(self, size, allowzero=True)
91619161

91629162

9163-
@torch_op(("aten::view", "aten::_unsafe_view"), complex=True)
9164-
def aten_view_complex(self: TTensor, size: IntType) -> TTensor:
9163+
@torch_op(("aten::view", "aten::_unsafe_view"), complex=True, trace_only=True)
9164+
def aten_view_complex(self: TTensor, size: Sequence[INT64]) -> TTensor:
91659165
"""view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
91669166

9167-
size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input
9168-
complex_size = op.Concat(size, op.Constant(value_ints=[2]), axis=0)
9167+
complex_size = common_ops.merge_dims([*size, 2])
91699168
return op.Reshape(self, complex_size, allowzero=True)
91709169

91719170

9172-
@torch_op("aten::view_as")
9171+
@torch_op("aten::view_as", trace_only=True)
91739172
def aten_view_as(self: TTensor, other: TTensor2) -> TTensor:
91749173
"""view_as(Tensor(a) self, Tensor other) -> Tensor(a)"""
91759174

@@ -9213,11 +9212,11 @@ def aten_view_as_real_copy(self: TTensor) -> TTensor:
92139212
return op.Identity(self)
92149213

92159214

9216-
@torch_op("aten::view_copy")
9217-
def aten_view_copy(self: TTensor, size: IntType) -> TTensor:
9215+
@torch_op("aten::view_copy", trace_only=True)
9216+
def aten_view_copy(self: TTensor, size: Sequence[INT64]) -> TTensor:
92189217
"""view_copy(Tensor self, SymInt[] size) -> Tensor"""
92199218

9220-
size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input
9219+
size = common_ops.merge_dims(size)
92219220
return op.Reshape(self, size)
92229221

92239222

@@ -9245,7 +9244,8 @@ def reshape_to_2d(tensor):
92459244
"aten::where.ScalarSelf",
92469245
"aten::where.ScalarOther",
92479246
"aten::where.self",
9248-
)
9247+
),
9248+
trace_only=True,
92499249
)
92509250
def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor:
92519251
"""where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"""
@@ -9261,7 +9261,7 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType:
92619261

92629262
@torch_op("aten::zeros", trace_only=True)
92639263
def aten_zeros(
9264-
size: IntType,
9264+
size: Sequence[INT64],
92659265
dtype: int = FLOAT.dtype,
92669266
layout: str = "",
92679267
device: str = "",
@@ -9270,9 +9270,9 @@ def aten_zeros(
92709270
"""zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
92719271
if dtype == -1:
92729272
dtype = FLOAT.dtype
9273-
size = op.Cast(size, to=INT64.dtype)
9274-
zero = op.Constant(value_float=0.0)
9275-
zero = op.Cast(zero, to=dtype)
9273+
9274+
zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
9275+
size = common_ops.merge_dims(size)
92769276

92779277
return op.Expand(zero, size)
92789278

0 commit comments

Comments
 (0)