diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py index d7784a5289..b3ebbc1c53 100644 --- a/onnxscript/function_libs/torch_lib/ops/common.py +++ b/onnxscript/function_libs/torch_lib/ops/common.py @@ -5,6 +5,8 @@ # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value" from __future__ import annotations +from collections.abc import Sequence + import numpy.typing as npt import onnx @@ -78,3 +80,22 @@ def constant( A constant node. """ return op.Constant(value=ir.tensor(array, dtype=ir.DataType(dtype))) + + +def merge_dims(dims: Sequence[int | INT64]) -> INT64: + """Concatenate dimensions into a single value.""" + + if not dims: + return op.Constant(value_ints=ir.AttrInt64s("value_ints", [])) + + neg_one_1d = op.Constant(value_ints=ir.AttrInt64s("value_ints", [-1])) + + result_dims = [ + op.Constant(value_ints=[d]) if isinstance(d, int) else op.Reshape(d, neg_one_1d) + for d in dims + ] + + # Set the output type to INT64 so op.Concat can be used + for dim in result_dims: + dim.dtype = ir.DataType.INT64 + return op.Concat(*result_dims, axis=0) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 8bb1665aaf..3607a11361 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1523,10 +1523,10 @@ def aten_broadcast_tensors(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -@torch_op("aten::broadcast_to") -def aten_broadcast_to(self: TTensor, size: INT64) -> TTensor: +@torch_op("aten::broadcast_to", trace_only=True) +def aten_broadcast_to(self: TTensor, size: Sequence[INT64]) -> TTensor: """broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)""" - + size = common_ops.merge_dims(size) return op.Expand(self, size) @@ -3286,20 +3286,20 @@ def aten_embedding_sparse_backward( @torch_op("aten::empty.memory_format", trace_only=True) def aten_empty( - size: IntType, + size: Sequence[INT64], dtype: int = FLOAT.dtype, layout: str = "", device: str = "", pin_memory: bool = False, memory_format: str = "", ) -> TensorType: # type: ignore[type-var] - # empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + """empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" if dtype == -1: dtype = FLOAT.dtype - # using Zeros to simulate np.empty() - size = op.Cast(size, to=INT64.dtype) - zero = op.Constant(value_float=0.0) - zero = op.Cast(zero, to=dtype) + + # using Zeros to simulate empty() + zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype))) + size = common_ops.merge_dims(size) return op.Expand(zero, size) @@ -3334,17 +3334,18 @@ def aten_empty_quantized( @torch_op("aten::empty_strided", trace_only=True) def aten_empty_strided( - size: INT64, + size: Sequence[INT64], stride: INT64, layout: str = "", + dtype: int = FLOAT.dtype, device: str = "", pin_memory: bool = False, ) -> TTensor: # type: ignore[type-var] # empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor # using Zeros to simulate empty() - size = op.Cast(size, to=INT64.dtype) - zero = op.Constant(value_float=0.0) + zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype))) + size = common_ops.merge_dims(size) return op.Expand(zero, size) @@ -3392,13 +3393,14 @@ def aten_exp2(self: TFloat) -> TFloat: @torch_op("aten::expand", trace_only=True) -def aten_expand(self: TTensor, size: TInt, implicit: bool = False) -> TTensor: +def aten_expand(self: TTensor, size: Sequence[INT64], implicit: bool = False) -> TTensor: """expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)""" - size = op.Cast(size, to=INT64.dtype) # NOTE: PyTorch supports `not changing dim` by -1, but ONNX supports `not changing dim` by 1. # To support -1 dim, we need to convert -1 to 1. - size = op.Abs(size) - return op.Expand(self, size) + # Even though in theory a dynamic dim can still be -1, in practice it is very unlikely + # and isn't expected to appear from correct usages of SymInt. + size = [1 if isinstance(s, int) and s == -1 else s for s in size] + return op.Expand(self, common_ops.merge_dims(size)) @torch_op("aten::expand_as", trace_only=True) @@ -7409,12 +7411,10 @@ def aten_repeat_interleave_Tensor( ) -@torch_op("aten::reshape") -def aten_reshape(self: TTensor, shape: IntType) -> TTensor: +@torch_op("aten::reshape", trace_only=True) +def aten_reshape(self: TTensor, shape: Sequence[INT64]) -> TTensor: """reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)""" - - # Reshape only support INT64 as 'shape' - shape = op.Cast(shape, to=INT64.dtype) + shape = common_ops.merge_dims(shape) return op.Reshape(self, shape) @@ -9153,23 +9153,22 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType: @torch_op(("aten::view", "aten::_unsafe_view"), trace_only=True) -def aten_view(self: TTensor, size: IntType) -> TTensor: +def aten_view(self: TTensor, size: Sequence[INT64]) -> TTensor: """view(Tensor(a) self, SymInt[] size) -> Tensor(a)""" - size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input + size = common_ops.merge_dims(size) return op.Reshape(self, size, allowzero=True) -@torch_op(("aten::view", "aten::_unsafe_view"), complex=True) -def aten_view_complex(self: TTensor, size: IntType) -> TTensor: +@torch_op(("aten::view", "aten::_unsafe_view"), complex=True, trace_only=True) +def aten_view_complex(self: TTensor, size: Sequence[INT64]) -> TTensor: """view(Tensor(a) self, SymInt[] size) -> Tensor(a)""" - size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input - complex_size = op.Concat(size, op.Constant(value_ints=[2]), axis=0) + complex_size = common_ops.merge_dims([*size, 2]) return op.Reshape(self, complex_size, allowzero=True) -@torch_op("aten::view_as") +@torch_op("aten::view_as", trace_only=True) def aten_view_as(self: TTensor, other: TTensor2) -> TTensor: """view_as(Tensor(a) self, Tensor other) -> Tensor(a)""" @@ -9213,11 +9212,11 @@ def aten_view_as_real_copy(self: TTensor) -> TTensor: return op.Identity(self) -@torch_op("aten::view_copy") -def aten_view_copy(self: TTensor, size: IntType) -> TTensor: +@torch_op("aten::view_copy", trace_only=True) +def aten_view_copy(self: TTensor, size: Sequence[INT64]) -> TTensor: """view_copy(Tensor self, SymInt[] size) -> Tensor""" - size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input + size = common_ops.merge_dims(size) return op.Reshape(self, size) @@ -9245,7 +9244,8 @@ def reshape_to_2d(tensor): "aten::where.ScalarSelf", "aten::where.ScalarOther", "aten::where.self", - ) + ), + trace_only=True, ) def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor: """where.self(Tensor condition, Tensor self, Tensor other) -> Tensor""" @@ -9261,7 +9261,7 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType: @torch_op("aten::zeros", trace_only=True) def aten_zeros( - size: IntType, + size: Sequence[INT64], dtype: int = FLOAT.dtype, layout: str = "", device: str = "", @@ -9270,9 +9270,9 @@ def aten_zeros( """zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" if dtype == -1: dtype = FLOAT.dtype - size = op.Cast(size, to=INT64.dtype) - zero = op.Constant(value_float=0.0) - zero = op.Cast(zero, to=dtype) + + zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype))) + size = common_ops.merge_dims(size) return op.Expand(zero, size)