Skip to content

Commit d905b2a

Browse files
authored
Improves aten_chunk conversion (#2434)
1 parent 75c1a4d commit d905b2a

File tree

1 file changed

+4
-21
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+4
-21
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,29 +1647,12 @@ def aten_choose_qparams_optimized(
16471647
raise NotImplementedError()
16481648

16491649

1650-
@torch_op("aten::chunk")
1650+
@torch_op("aten::chunk", trace_only=True)
16511651
def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
16521652
"""chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"""
1653-
# This will create a Sequence of tensors
1654-
neg_1 = op.Constant(value_ints=[-1])
1655-
# Get size of specified dim
1656-
self_shape = op.Shape(self)
1657-
dim_size = op.Gather(self_shape, dim, axis=0)
1658-
# Compute size/chunk to get the number of data in one chunk
1659-
num_per_chunk = op.Div(dim_size, chunks)
1660-
num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk # type: ignore[operator]
1661-
1662-
# Compute real chunk number
1663-
num_chunk = op.Div(dim_size, num_per_chunk)
1664-
# Get something like [n, n, n, n, ...], total num_chunk
1665-
list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))
1666-
1667-
remainder = op.Mod(dim_size, num_per_chunk)
1668-
if remainder > 0: # type: ignore[operator]
1669-
# Append the remainder to the [n, n, n, n, ..., r]
1670-
list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)
1671-
1672-
return op.SplitToSequence(self, list_split, axis=dim)
1653+
if chunks == 1:
1654+
return op.Identity(self)
1655+
return op.Split(self, axis=dim, num_outputs=chunks)
16731656

16741657

16751658
@torch_op("aten::clamp", trace_only=True)

0 commit comments

Comments
 (0)