|
36 | 36 | graph,
|
37 | 37 | ir,
|
38 | 38 | )
|
| 39 | +from onnxscript._internal import version_utils |
39 | 40 | from onnxscript.function_libs.torch_lib.ops import common as common_ops
|
40 | 41 | from onnxscript.function_libs.torch_lib.registration import torch_op
|
41 | 42 | from onnxscript.function_libs.torch_lib.tensor_typing import (
|
@@ -1647,29 +1648,40 @@ def aten_choose_qparams_optimized(
|
1647 | 1648 | raise NotImplementedError()
|
1648 | 1649 |
|
1649 | 1650 |
|
1650 |
@torch_op("aten::chunk")
def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
    """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]

    Splits ``self`` into up to ``chunks`` pieces along ``dim`` and returns them
    as an ONNX Sequence of tensors (built via ``SplitToSequence``).

    NOTE: this is onnxscript script-mode code — the Python ``if`` below is
    compiled into an ONNX ``If`` node, so statement order and the restricted
    statement subset must be preserved.
    """
    # This will create a Sequence of tensors
    neg_1 = op.Constant(value_ints=[-1])
    # Get size of specified dim
    self_shape = op.Shape(self)
    dim_size = op.Gather(self_shape, dim, axis=0)
    # Compute size/chunk to get the number of data in one chunk.
    # ceil(dim_size / chunks) implemented as floor-div plus 1 when there is a
    # remainder (the Cast of the bool comparison yields 0 or 1).
    num_per_chunk = op.Div(dim_size, chunks)
    num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk  # type: ignore[operator]

    # Compute real chunk number (number of full-size chunks)
    num_chunk = op.Div(dim_size, num_per_chunk)
    # Get something like [n, n, n, n, ...], total num_chunk entries
    list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))

    remainder = op.Mod(dim_size, num_per_chunk)
    if remainder > 0:  # type: ignore[operator]
        # Append the remainder to get [n, n, n, n, ..., r]; the last chunk is
        # smaller when dim_size is not evenly divisible (matches torch.chunk).
        list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)

    return op.SplitToSequence(self, list_split, axis=dim)
if version_utils.torch_older_than("2.7.0"):
    # PyTorch <2.7 does not support determining the number of outputs for the Split op,
    # so fall back to a script-mode graph that produces an ONNX Sequence instead.
    # https://github.com/pytorch/pytorch/commit/9a1eac6704671c72a2e85c9138db57eb3a80bfb6
    @torch_op("aten::chunk")
    def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
        """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]

        Script-mode fallback: builds an explicit per-chunk size list and emits
        ``SplitToSequence``. The Python ``if`` below compiles to an ONNX ``If``.
        """
        # This will create a Sequence of tensors
        neg_1 = op.Constant(value_ints=[-1])
        # Get size of specified dim
        self_shape = op.Shape(self)
        dim_size = op.Gather(self_shape, dim, axis=0)
        # Compute size/chunk to get the number of data in one chunk:
        # ceil(dim_size / chunks) via floor-div plus (1 if remainder else 0).
        num_per_chunk = op.Div(dim_size, chunks)
        num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk  # type: ignore[operator]

        # Compute real chunk number (number of full-size chunks)
        num_chunk = op.Div(dim_size, num_per_chunk)
        # Get something like [n, n, n, n, ...], total num_chunk entries
        list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))

        remainder = op.Mod(dim_size, num_per_chunk)
        if remainder > 0:  # type: ignore[operator]
            # Append the remainder to get [n, n, n, n, ..., r]
            list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)

        return op.SplitToSequence(self, list_split, axis=dim)
else:

    @torch_op("aten::chunk", trace_only=True)
    def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
        """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]

        Trace-only path for torch >= 2.7, where the exporter can determine the
        number of ``Split`` outputs statically from ``num_outputs``.
        """
        if chunks == 1:
            # Single chunk: no split needed; return the input unchanged.
            # NOTE(review): this returns one tensor rather than a sequence —
            # presumably the >=2.7 exporter flattens chunk outputs; confirm.
            return op.Identity(self)
        # ONNX Split with num_outputs divides dim evenly; when not divisible,
        # the last output is smaller, matching torch.chunk semantics.
        return op.Split(self, axis=dim, num_outputs=chunks)
1673 | 1685 |
|
1674 | 1686 |
|
1675 | 1687 | @torch_op("aten::clamp", trace_only=True)
|
|
0 commit comments