Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
from __future__ import annotations

from collections.abc import Sequence

import numpy.typing as npt
import onnx

Expand Down Expand Up @@ -78,3 +80,22 @@ def constant(
A constant node.
"""
return op.Constant(value=ir.tensor(array, dtype=ir.DataType(dtype)))


def merge_dims(dims: Sequence[int | INT64]) -> INT64:
    """Build one INT64 value by concatenating every entry of ``dims`` along axis 0.

    Args:
        dims: The dimensions to merge. Each entry is either a Python ``int``
            or a graph value.

    Returns:
        The result of ``op.Concat`` over the converted entries. When ``dims``
        is empty, a constant built from an empty list of ints is returned
        instead.
    """
    if not dims:
        return op.Constant(value_ints=ir.AttrInt64s("value_ints", []))

    # Target shape used to reshape every non-int entry before concatenation.
    flatten_shape = op.Constant(value_ints=ir.AttrInt64s("value_ints", [-1]))

    pieces = []
    for entry in dims:
        if isinstance(entry, int):
            # Python ints become one-element constants.
            piece = op.Constant(value_ints=[entry])
        else:
            piece = op.Reshape(entry, flatten_shape)
        pieces.append(piece)

    # Mark every piece as INT64 so that op.Concat can be applied to them.
    for piece in pieces:
        piece.dtype = ir.DataType.INT64

    return op.Concat(*pieces, axis=0)
72 changes: 36 additions & 36 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1523,10 +1523,10 @@ def aten_broadcast_tensors(tensors: Sequence[TensorType]) -> TensorType:
raise NotImplementedError()


@torch_op("aten::broadcast_to")
def aten_broadcast_to(self: TTensor, size: INT64) -> TTensor:
@torch_op("aten::broadcast_to", trace_only=True)
def aten_broadcast_to(self: TTensor, size: Sequence[INT64]) -> TTensor:
"""broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)"""

size = common_ops.merge_dims(size)
return op.Expand(self, size)


Expand Down Expand Up @@ -3286,20 +3286,20 @@ def aten_embedding_sparse_backward(

@torch_op("aten::empty.memory_format", trace_only=True)
def aten_empty(
size: IntType,
size: Sequence[INT64],
dtype: int = FLOAT.dtype,
layout: str = "",
device: str = "",
pin_memory: bool = False,
memory_format: str = "",
) -> TensorType: # type: ignore[type-var]
# empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
"""empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
if dtype == -1:
dtype = FLOAT.dtype
# using Zeros to simulate np.empty()
size = op.Cast(size, to=INT64.dtype)
zero = op.Constant(value_float=0.0)
zero = op.Cast(zero, to=dtype)

# using Zeros to simulate empty()
zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
size = common_ops.merge_dims(size)

return op.Expand(zero, size)

Expand Down Expand Up @@ -3334,17 +3334,18 @@ def aten_empty_quantized(

@torch_op("aten::empty_strided", trace_only=True)
def aten_empty_strided(
size: INT64,
size: Sequence[INT64],
stride: INT64,
layout: str = "",
dtype: int = FLOAT.dtype,
device: str = "",
pin_memory: bool = False,
) -> TTensor: # type: ignore[type-var]
# empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

# using Zeros to simulate empty()
size = op.Cast(size, to=INT64.dtype)
zero = op.Constant(value_float=0.0)
zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
size = common_ops.merge_dims(size)

return op.Expand(zero, size)

Expand Down Expand Up @@ -3392,13 +3393,14 @@ def aten_exp2(self: TFloat) -> TFloat:


@torch_op("aten::expand", trace_only=True)
def aten_expand(self: TTensor, size: TInt, implicit: bool = False) -> TTensor:
def aten_expand(self: TTensor, size: Sequence[INT64], implicit: bool = False) -> TTensor:
"""expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)"""
size = op.Cast(size, to=INT64.dtype)
# NOTE: PyTorch supports `not changing dim` by -1, but ONNX supports `not changing dim` by 1.
# To support -1 dim, we need to convert -1 to 1.
size = op.Abs(size)
return op.Expand(self, size)
# Even though in theory a dynamic dim can still be -1, in practice it is very unlikely
# and isn't expected to appear from correct usages of SymInt.
size = [1 if isinstance(s, int) and s == -1 else s for s in size]
return op.Expand(self, common_ops.merge_dims(size))


@torch_op("aten::expand_as", trace_only=True)
Expand Down Expand Up @@ -7409,12 +7411,10 @@ def aten_repeat_interleave_Tensor(
)


@torch_op("aten::reshape")
def aten_reshape(self: TTensor, shape: IntType) -> TTensor:
@torch_op("aten::reshape", trace_only=True)
def aten_reshape(self: TTensor, shape: Sequence[INT64]) -> TTensor:
"""reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)"""

# Reshape only support INT64 as 'shape'
shape = op.Cast(shape, to=INT64.dtype)
shape = common_ops.merge_dims(shape)
return op.Reshape(self, shape)


Expand Down Expand Up @@ -9153,23 +9153,22 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType:


@torch_op(("aten::view", "aten::_unsafe_view"), trace_only=True)
def aten_view(self: TTensor, size: IntType) -> TTensor:
def aten_view(self: TTensor, size: Sequence[INT64]) -> TTensor:
"""view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""

size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input
size = common_ops.merge_dims(size)
return op.Reshape(self, size, allowzero=True)


@torch_op(("aten::view", "aten::_unsafe_view"), complex=True)
def aten_view_complex(self: TTensor, size: IntType) -> TTensor:
@torch_op(("aten::view", "aten::_unsafe_view"), complex=True, trace_only=True)
def aten_view_complex(self: TTensor, size: Sequence[INT64]) -> TTensor:
"""view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""

size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input
complex_size = op.Concat(size, op.Constant(value_ints=[2]), axis=0)
complex_size = common_ops.merge_dims([*size, 2])
return op.Reshape(self, complex_size, allowzero=True)


@torch_op("aten::view_as")
@torch_op("aten::view_as", trace_only=True)
def aten_view_as(self: TTensor, other: TTensor2) -> TTensor:
"""view_as(Tensor(a) self, Tensor other) -> Tensor(a)"""

Expand Down Expand Up @@ -9213,11 +9212,11 @@ def aten_view_as_real_copy(self: TTensor) -> TTensor:
return op.Identity(self)


@torch_op("aten::view_copy")
def aten_view_copy(self: TTensor, size: IntType) -> TTensor:
@torch_op("aten::view_copy", trace_only=True)
def aten_view_copy(self: TTensor, size: Sequence[INT64]) -> TTensor:
"""view_copy(Tensor self, SymInt[] size) -> Tensor"""

size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input
size = common_ops.merge_dims(size)
return op.Reshape(self, size)


Expand Down Expand Up @@ -9245,7 +9244,8 @@ def reshape_to_2d(tensor):
"aten::where.ScalarSelf",
"aten::where.ScalarOther",
"aten::where.self",
)
),
trace_only=True,
)
def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor:
"""where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"""
Expand All @@ -9261,7 +9261,7 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType:

@torch_op("aten::zeros", trace_only=True)
def aten_zeros(
size: IntType,
size: Sequence[INT64],
dtype: int = FLOAT.dtype,
layout: str = "",
device: str = "",
Expand All @@ -9270,9 +9270,9 @@ def aten_zeros(
"""zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
if dtype == -1:
dtype = FLOAT.dtype
size = op.Cast(size, to=INT64.dtype)
zero = op.Constant(value_float=0.0)
zero = op.Cast(zero, to=dtype)

zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
size = common_ops.merge_dims(size)

return op.Expand(zero, size)

Expand Down
Loading