1818
1919import paddle
2020from paddle import _C_ops
21- from paddle .tensor import fill_constant
2221
23- from ..base .data_feeder import (
24- check_dtype ,
25- check_type ,
26- check_variable_and_dtype ,
27- )
2822from ..base .framework import Variable
2923from ..framework import (
30- LayerHelper ,
3124 in_dynamic_mode ,
32- in_pir_mode ,
3325)
3426
3527if TYPE_CHECKING :
4537@forbid_keywords (["x" , "num_or_sections" , "axis" , "name" ], "paddle.split" )
4638def split (
4739 tensor : Tensor , split_size_or_sections : int | Sequence [int ], dim : int = 0
48- ) -> tuple [Tensor ]:
40+ ) -> tuple [Tensor , ... ]:
4941 """
5042 (PyTorch Compatible API) Split the input tensor into multiple sub-Tensors.
5143
@@ -72,7 +64,7 @@ def split(
7264
7365 >>> import paddle
7466
75- >>> # x is a Tensor of shape [3, 9 , 5]
67+ >>> # x is a Tensor of shape [3, 8 , 5]
7668 >>> x = paddle.rand([3, 8, 5])
7769
7870 >>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=3, dim=1)
@@ -170,7 +162,7 @@ def SaveGetShapeOnDim(shape, dim: int) -> int:
170162 )
171163 else :
172164 return tuple (_C_ops .split (tensor , split_size_or_sections , dim ))
173- elif in_pir_mode () :
165+ else :
174166 if isinstance (dim , paddle .pir .Value ):
175167 dim .stop_gradient = True
176168 if isinstance (dim , int ):
@@ -212,108 +204,3 @@ def SaveGetShapeOnDim(shape, dim: int) -> int:
212204 split_size_or_sections
213205 )
214206 return tuple (_C_ops .split (tensor , split_size_or_sections , dim ))
215-
216- else :
217- check_variable_and_dtype (
218- tensor ,
219- 'input' ,
220- [
221- 'bool' ,
222- 'bfloat16' ,
223- 'float16' ,
224- 'uint16' ,
225- 'float32' ,
226- 'float64' ,
227- 'int32' ,
228- 'int64' ,
229- 'uint8' ,
230- 'int8' ,
231- ],
232- 'split' ,
233- )
234- check_type (
235- split_size_or_sections ,
236- 'split_size_or_sections' ,
237- (list , int , tuple ),
238- 'split' ,
239- )
240- check_type (dim , 'dim' , (int , Variable ), 'split' )
241- if isinstance (dim , Variable ):
242- check_dtype (dim .dtype , 'dim' , ['int32' , 'int64' ], 'split' )
243-
244- helper = LayerHelper ('split' , ** locals ())
245-
246- input_shape = tensor .shape
247- inputs = {'X' : tensor }
248- attrs = {'num' : 0 }
249-
250- def _get_SectionsTensorList (one_list ):
251- tensor_list = []
252- unk_dim_idx = - 1
253- for idx , dim_size in enumerate (one_list ):
254- if isinstance (dim_size , Variable ):
255- dim_size .stop_gradient = True
256- tensor_list .append (dim_size )
257- else :
258- assert isinstance (dim_size , int )
259- if dim_size == - 1 :
260- assert unk_dim_idx == - 1 , (
261- "Only one value of 'num_or_section' in split can "
262- f"be -1. But received num_or_section[{ idx } ] is also -1."
263- )
264- unk_dim_idx = idx
265- temp_out = helper .create_variable_for_type_inference (
266- 'int32'
267- )
268- fill_constant (
269- [1 ], 'int32' , dim_size , force_cpu = True , out = temp_out
270- )
271- tensor_list .append (temp_out )
272- return tuple (tensor_list )
273-
274- if isinstance (dim , Variable ):
275- dim .stop_gradient = True
276- inputs ['AxisTensor' ] = dim
277- else :
278- assert len (tensor .shape ) + dim >= 0 , "(rank(x) + dim) must >= 0"
279- dim = (len (input_shape ) + dim ) if dim < 0 else dim
280- attrs ['axis' ] = dim
281-
282- if isinstance (split_size_or_sections , int ):
283- shape_on_dim = SaveGetShapeOnDim (tensor .shape , dim )
284- split_size_or_sections = GetSplitSize (
285- split_size_or_sections , shape_on_dim
286- )
287-
288- if isinstance (split_size_or_sections , int ):
289- # after GetSplitSize, if the result is int, split_size_or_sections is actually equivalent to the original num_or_sections (num)
290- attrs ['num' ] = split_size_or_sections
291- assert (
292- split_size_or_sections > 0
293- ), 'split_size_or_sections must be than 0.'
294- num = split_size_or_sections
295- else :
296- if isinstance (dim , int ) and input_shape [dim ] > 0 :
297- assert (
298- len (split_size_or_sections ) <= input_shape [dim ]
299- ), 'len(split_size_or_sections) must not be more than input.shape[dim].'
300- num = len (split_size_or_sections )
301- attrs ['sections' ] = [
302- - 1 if isinstance (ele , Variable ) else ele
303- for ele in split_size_or_sections
304- ]
305- if paddle .utils ._contain_var (split_size_or_sections ):
306- inputs ['SectionsTensorList' ] = _get_SectionsTensorList (
307- split_size_or_sections
308- )
309-
310- outs = [
311- helper .create_variable_for_type_inference (
312- dtype = helper .input_dtype ()
313- )
314- for i in range (num )
315- ]
316- helper .append_op (
317- type = 'split' , inputs = inputs , outputs = {'Out' : outs }, attrs = attrs
318- )
319- return tuple (outs )
0 commit comments