
Commit 131e497

justinchuby and xadupre authored and committed
[torchlib] Improves aten_chunk conversion (microsoft#2469)
Simplify the implementation of `aten_chunk` and allow it to work on all data types. Original author: @xadupre. Updated to conditionally use the new implementation when torch>=2.7.

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: Xavier Dupré <xadupre@users.noreply.github.com>
1 parent 127aee8 commit 131e497
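
For context, the semantics the converter has to reproduce: torch.chunk splits a tensor along dim into at most chunks pieces, where each piece holds ceil(dim_size / chunks) elements and the last piece absorbs whatever remains. A minimal sketch of that behavior (assuming only a standard PyTorch install):

import torch

# torch.chunk returns at most `chunks` pieces along `dim`; each piece has
# ceil(dim_size / chunks) elements, and the last piece may be smaller.
x = torch.arange(10)
print([t.shape[0] for t in torch.chunk(x, 3)])  # [4, 4, 2]
print([t.shape[0] for t in torch.chunk(x, 1)])  # [10]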

File tree

3 files changed: +38 −36 lines changed


onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 35 additions & 23 deletions
@@ -36,6 +36,7 @@
     graph,
     ir,
 )
+from onnxscript._internal import version_utils
 from onnxscript.function_libs.torch_lib.ops import common as common_ops
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.function_libs.torch_lib.tensor_typing import (
@@ -1647,29 +1648,40 @@ def aten_choose_qparams_optimized(
     raise NotImplementedError()


-@torch_op("aten::chunk")
-def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
-    """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"""
-    # This will create a Sequence of tensors
-    neg_1 = op.Constant(value_ints=[-1])
-    # Get size of specified dim
-    self_shape = op.Shape(self)
-    dim_size = op.Gather(self_shape, dim, axis=0)
-    # Compute size/chunk to get the number of data in one chunk
-    num_per_chunk = op.Div(dim_size, chunks)
-    num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk  # type: ignore[operator]
-
-    # Compute real chunk number
-    num_chunk = op.Div(dim_size, num_per_chunk)
-    # Get something like [n, n, n, n, ...], total num_chunk
-    list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))
-
-    remainder = op.Mod(dim_size, num_per_chunk)
-    if remainder > 0:  # type: ignore[operator]
-        # Append the remainder to the [n, n, n, n, ..., r]
-        list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)
-
-    return op.SplitToSequence(self, list_split, axis=dim)
+if version_utils.torch_older_than("2.7.0"):
+    # PyTorch <2.7 does not support determining the number of outputs for the Split op
+    # https://github.com/pytorch/pytorch/commit/9a1eac6704671c72a2e85c9138db57eb3a80bfb6
+    @torch_op("aten::chunk")
+    def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
+        """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"""
+        # This will create a Sequence of tensors
+        neg_1 = op.Constant(value_ints=[-1])
+        # Get size of specified dim
+        self_shape = op.Shape(self)
+        dim_size = op.Gather(self_shape, dim, axis=0)
+        # Compute size/chunk to get the number of data in one chunk
+        num_per_chunk = op.Div(dim_size, chunks)
+        num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk  # type: ignore[operator]
+
+        # Compute real chunk number
+        num_chunk = op.Div(dim_size, num_per_chunk)
+        # Get something like [n, n, n, n, ...], total num_chunk
+        list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))
+
+        remainder = op.Mod(dim_size, num_per_chunk)
+        if remainder > 0:  # type: ignore[operator]
+            # Append the remainder to the [n, n, n, n, ..., r]
+            list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)
+
+        return op.SplitToSequence(self, list_split, axis=dim)
+else:
+
+    @torch_op("aten::chunk", trace_only=True)
+    def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
+        """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"""
+        if chunks == 1:
+            return op.Identity(self)
+        return op.Split(self, axis=dim, num_outputs=chunks)


 @torch_op("aten::clamp", trace_only=True)

tests/function_libs/torch_lib/ops_test.py

Lines changed: 0 additions & 1 deletion
@@ -200,7 +200,6 @@ def run_test_output_match(
        reference_torch_outputs, _ = pytree.tree_flatten(torch_output)
        if (
            op.name.startswith("split")
-           or op.name.startswith("chunk")
            or op.name.startswith("unbind")
            or op.name
            in {"atleast_1d_Sequence", "atleast_2d_Sequence", "atleast_3d_Sequence"}

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 3 additions & 12 deletions
@@ -694,18 +694,9 @@ def _where_input_wrangler(
         reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619",
     ),
     TorchLibOpInfo("ceil", core_ops.aten_ceil),
-    TorchLibOpInfo(
-        "chunk",
-        core_ops.aten_chunk,
-    )
-    .xfail(
-        dtypes=(torch.float16,),
-        enabled_if=version_utils.onnxruntime_older_than("1.17"),
-        reason="fixme: SplitToSequence op inference failed. https://github.com/microsoft/onnxruntime/issues/16006",
-    )
-    .xfail(
-        dtypes=(torch.bool,),
-        reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905",
+    TorchLibOpInfo("chunk", core_ops.aten_chunk).skip(
+        enabled_if=version_utils.torch_older_than("2.7"),
+        reason="Test for chunk is not configured for torch<2.7",
     ),
     TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max_tensor).skip(
         reason="Size 0 inputs are not handled by design",
