Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 35 additions & 23 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
graph,
ir,
)
from onnxscript._internal import version_utils
from onnxscript.function_libs.torch_lib.ops import common as common_ops
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import (
Expand Down Expand Up @@ -1647,29 +1648,40 @@ def aten_choose_qparams_optimized(
raise NotImplementedError()


@torch_op("aten::chunk")
def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
    """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]

    Split ``self`` along ``dim`` into a sequence of tensors.

    The split sizes are computed at runtime from the shape of ``self``: every
    piece gets the size of ``dim`` divided by ``chunks``, rounded up, and if
    that size does not divide the dimension evenly, one extra shorter piece
    holding the remainder is appended. As a consequence the number of pieces
    produced can be smaller than ``chunks``.

    Args:
        self: The tensor to split.
        chunks: The requested number of pieces.
        dim: The axis along which to split.

    Returns:
        The output of ``SplitToSequence`` applied to ``self`` with the computed
        split sizes along ``dim``.
    """
    # This will create a Sequence of tensors
    # Shape constant [-1], used below to reshape scalars into 1-D tensors.
    neg_1 = op.Constant(value_ints=[-1])
    # Get size of specified dim
    self_shape = op.Shape(self)
    dim_size = op.Gather(self_shape, dim, axis=0)
    # Compute size/chunk to get the number of data in one chunk
    num_per_chunk = op.Div(dim_size, chunks)
    # Round up: add 1 when dim_size is not an exact multiple of chunks.
    num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk  # type: ignore[operator]

    # Compute real chunk number
    # (the count of full-size pieces; a shorter trailing piece is handled below)
    num_chunk = op.Div(dim_size, num_per_chunk)
    # Get something like [n, n, n, n, ...], total num_chunk
    list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))

    # Elements left over after taking all the full-size pieces.
    remainder = op.Mod(dim_size, num_per_chunk)
    if remainder > 0:  # type: ignore[operator]
        # Append the remainder to the [n, n, n, n, ..., r]
        list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)

    return op.SplitToSequence(self, list_split, axis=dim)
# Two implementations of aten::chunk are provided; which one is defined (and
# registered via torch_op) depends on the installed PyTorch version.
if version_utils.torch_older_than("2.7.0"):
    # PyTorch <2.7 does not support determining the number of outputs for the Split op
    # https://github.com/pytorch/pytorch/commit/9a1eac6704671c72a2e85c9138db57eb3a80bfb6
    @torch_op("aten::chunk")
    def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
        """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]

        Split ``self`` along ``dim`` into a sequence of tensors.

        The split sizes are computed at runtime from the shape of ``self``:
        every piece gets the size of ``dim`` divided by ``chunks``, rounded up,
        and if that size does not divide the dimension evenly, one extra
        shorter piece holding the remainder is appended.

        Args:
            self: The tensor to split.
            chunks: The requested number of pieces.
            dim: The axis along which to split.

        Returns:
            The output of ``SplitToSequence`` applied to ``self`` with the
            computed split sizes along ``dim``.
        """
        # This will create a Sequence of tensors
        # Shape constant [-1], used below to reshape scalars into 1-D tensors.
        neg_1 = op.Constant(value_ints=[-1])
        # Get size of specified dim
        self_shape = op.Shape(self)
        dim_size = op.Gather(self_shape, dim, axis=0)
        # Compute size/chunk to get the number of data in one chunk
        num_per_chunk = op.Div(dim_size, chunks)
        # Round up: add 1 when dim_size is not an exact multiple of chunks.
        num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk  # type: ignore[operator]

        # Compute real chunk number
        # (the count of full-size pieces; a shorter trailing piece is handled below)
        num_chunk = op.Div(dim_size, num_per_chunk)
        # Get something like [n, n, n, n, ...], total num_chunk
        list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))

        # Elements left over after taking all the full-size pieces.
        remainder = op.Mod(dim_size, num_per_chunk)
        if remainder > 0:  # type: ignore[operator]
            # Append the remainder to the [n, n, n, n, ..., r]
            list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)

        return op.SplitToSequence(self, list_split, axis=dim)
else:

    @torch_op("aten::chunk", trace_only=True)
    def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
        """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]

        Split ``self`` along ``dim`` using a single ``Split`` op.

        Args:
            self: The tensor to split.
            chunks: The number of outputs requested from ``Split``.
            dim: The axis along which to split.

        Returns:
            ``Identity(self)`` when ``chunks`` is 1; otherwise the outputs of
            ``Split`` with ``num_outputs=chunks`` along ``dim``.
        """
        # A single chunk is the whole input: pass it through without a Split.
        if chunks == 1:
            return op.Identity(self)
        # NOTE(review): piece sizes are left entirely to Split's num_outputs
        # handling; confirm this matches torch.chunk when the size of `dim`
        # is not evenly divisible by `chunks`.
        return op.Split(self, axis=dim, num_outputs=chunks)


@torch_op("aten::clamp", trace_only=True)
Expand Down
1 change: 0 additions & 1 deletion tests/function_libs/torch_lib/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ def run_test_output_match(
reference_torch_outputs, _ = pytree.tree_flatten(torch_output)
if (
op.name.startswith("split")
or op.name.startswith("chunk")
or op.name.startswith("unbind")
or op.name
in {"atleast_1d_Sequence", "atleast_2d_Sequence", "atleast_3d_Sequence"}
Expand Down
15 changes: 3 additions & 12 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,18 +694,9 @@ def _where_input_wrangler(
reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619",
),
TorchLibOpInfo("ceil", core_ops.aten_ceil),
TorchLibOpInfo(
"chunk",
core_ops.aten_chunk,
)
.xfail(
dtypes=(torch.float16,),
enabled_if=version_utils.onnxruntime_older_than("1.17"),
reason="fixme: SplitToSequence op inference failed. https://github.com/microsoft/onnxruntime/issues/16006",
)
.xfail(
dtypes=(torch.bool,),
reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905",
TorchLibOpInfo("chunk", core_ops.aten_chunk).skip(
enabled_if=version_utils.torch_older_than("2.7"),
reason="Test for chunk is not configured for torch<2.7",
),
TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max_tensor).skip(
reason="Size 0 inputs are not handled by design",
Expand Down
Loading