diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 92b8abb36d..da16bba726 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -1647,29 +1647,12 @@ def aten_choose_qparams_optimized(
     raise NotImplementedError()
 
 
-@torch_op("aten::chunk")
+@torch_op("aten::chunk", trace_only=True)
 def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
     """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"""
-    # This will create a Sequence of tensors
-    neg_1 = op.Constant(value_ints=[-1])
-    # Get size of specified dim
-    self_shape = op.Shape(self)
-    dim_size = op.Gather(self_shape, dim, axis=0)
-    # Compute size/chunk to get the number of data in one chunk
-    num_per_chunk = op.Div(dim_size, chunks)
-    num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk  # type: ignore[operator]
-
-    # Compute real chunk number
-    num_chunk = op.Div(dim_size, num_per_chunk)
-    # Get something like [n, n, n, n, ...], total num_chunk
-    list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))
-
-    remainder = op.Mod(dim_size, num_per_chunk)
-    if remainder > 0:  # type: ignore[operator]
-        # Append the remainder to the [n, n, n, n, ..., r]
-        list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)
-
-    return op.SplitToSequence(self, list_split, axis=dim)
+    if chunks == 1:
+        return op.Identity(self)
+    return op.Split(self, axis=dim, num_outputs=chunks)
 
 
 @torch_op("aten::clamp", trace_only=True)
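
Not part of the patch, but for context: a minimal sketch of the torch.chunk semantics the rewritten aten_chunk targets. One subtlety worth noting is that torch.chunk may return fewer than `chunks` tensors when the split dimension is not evenly divisible, whereas ONNX Split with the num_outputs attribute always produces exactly num_outputs outputs; the chunks == 1 fast path is the trivial case.

# Illustration only, using PyTorch to show the expected chunk sizes.
import torch

x = torch.arange(12).reshape(4, 3)

# Evenly divisible case: 4 rows into 2 chunks -> two (2, 3) tensors,
# which maps directly to Split(axis=0, num_outputs=2).
even = torch.chunk(x, chunks=2, dim=0)
print([t.shape for t in even])  # [torch.Size([2, 3]), torch.Size([2, 3])]

# Non-divisible case: 4 rows into 3 requested chunks -> torch.chunk
# returns only two chunks of size 2, i.e. fewer tensors than requested.
uneven = torch.chunk(x, chunks=3, dim=0)
print([t.shape for t in uneven])  # [torch.Size([2, 3]), torch.Size([2, 3])]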