@@ -1647,29 +1647,12 @@ def aten_choose_qparams_optimized(
1647
1647
raise NotImplementedError ()
1648
1648
1649
1649
1650
- @torch_op ("aten::chunk" )
1650
+ @torch_op ("aten::chunk" , trace_only = True )
1651
1651
def aten_chunk (self : TTensor , chunks : int , dim : int = 0 ) -> Sequence [TTensor ]:
1652
1652
"""chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"""
1653
- # This will create a Sequence of tensors
1654
- neg_1 = op .Constant (value_ints = [- 1 ])
1655
- # Get size of specified dim
1656
- self_shape = op .Shape (self )
1657
- dim_size = op .Gather (self_shape , dim , axis = 0 )
1658
- # Compute size/chunk to get the number of data in one chunk
1659
- num_per_chunk = op .Div (dim_size , chunks )
1660
- num_per_chunk = op .Cast (op .Mod (dim_size , chunks ) > 0 , to = INT64 .dtype ) + num_per_chunk # type: ignore[operator]
1661
-
1662
- # Compute real chunk number
1663
- num_chunk = op .Div (dim_size , num_per_chunk )
1664
- # Get something like [n, n, n, n, ...], total num_chunk
1665
- list_split = op .Expand (num_per_chunk , op .Reshape (num_chunk , neg_1 ))
1666
-
1667
- remainder = op .Mod (dim_size , num_per_chunk )
1668
- if remainder > 0 : # type: ignore[operator]
1669
- # Append the remainder to the [n, n, n, n, ..., r]
1670
- list_split = op .Concat (list_split , op .Reshape (remainder , neg_1 ), axis = 0 )
1671
-
1672
- return op .SplitToSequence (self , list_split , axis = dim )
1653
+ if chunks == 1 :
1654
+ return op .Identity (self )
1655
+ return op .Split (self , axis = dim , num_outputs = chunks )
1673
1656
1674
1657
1675
1658
@torch_op ("aten::clamp" , trace_only = True )
0 commit comments