Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 4 additions & 21 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,29 +1647,12 @@ def aten_choose_qparams_optimized(
raise NotImplementedError()


@torch_op("aten::chunk")
@torch_op("aten::chunk", trace_only=True)
def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
"""chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"""
# This will create a Sequence of tensors
neg_1 = op.Constant(value_ints=[-1])
# Get size of specified dim
self_shape = op.Shape(self)
dim_size = op.Gather(self_shape, dim, axis=0)
# Compute size/chunk to get the number of data in one chunk
num_per_chunk = op.Div(dim_size, chunks)
num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk # type: ignore[operator]

# Compute real chunk number
num_chunk = op.Div(dim_size, num_per_chunk)
# Get something like [n, n, n, n, ...], total num_chunk
list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))

remainder = op.Mod(dim_size, num_per_chunk)
if remainder > 0: # type: ignore[operator]
# Append the remainder to the [n, n, n, n, ..., r]
list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)

return op.SplitToSequence(self, list_split, axis=dim)
if chunks == 1:
return op.Identity(self)
return op.Split(self, axis=dim, num_outputs=chunks)


@torch_op("aten::clamp", trace_only=True)
Expand Down
Loading