@@ -531,6 +531,97 @@ def slice(
531531 return out
532532
533533
534+ def narrow (
535+ input : Tensor ,
536+ dim : int ,
537+ start : int | Tensor ,
538+ length : int ,
539+ ) -> Tensor :
540+ """
541+ Returns a narrowed slice of input along a single axis.
542+
543+ This operator selects the index range [start, start + length) on dimension dim and keeps all
544+ the dimensions unchanged.
545+
546+ Args:
547+ input (Tensor): Input tensor.
548+ dim (int): Dimension to narrow. Supports negative indexing.
549+ start (int|Tensor): Start index on ``dim``. Can be a Python int or a 0-D
550+ int Tensor (int32 or int64). Negative values are supported.
551+ length (int): Number of elements to select from ``start``. Must be
552+ non-negative.
553+
554+ Returns:
555+ Tensor: A tensor that is a narrowed view of ``input``.
556+
557+ Examples:
558+ .. code-block:: python
559+
560+ >>> import paddle
561+
562+ >>> x = paddle.to_tensor([[1, 2, 3, 4],
563+ ... [5, 6, 7, 8]], dtype='int64')
564+
565+ >>> y1 = paddle.narrow(x, dim=1, start=1, length=2)
566+ >>> print(y1)
567+ Tensor(shape=[2, 2], dtype=int64, place=Place(cpu), stop_gradient=True,
568+ [[2, 3],
569+ [6, 7]])
570+
571+ >>> y2 = paddle.narrow(x, dim=-1, start=-3, length=3)
572+ >>> print(y2)
573+ Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
574+ [[2, 3, 4],
575+ [6, 7, 8]])
576+
577+ >>> s = paddle.to_tensor(0, dtype='int64')
578+ >>> y3 = paddle.narrow(x, dim=1, start=s, length=2)
579+ >>> print(y3)
580+ Tensor(shape=[2, 2], dtype=int64, place=Place(cpu), stop_gradient=True,
581+ [[1, 2],
582+ [5, 6]])
583+ """
584+
585+ if isinstance (start , paddle .Tensor ):
586+ assert start .ndim == 0 and start .dtype in [
587+ paddle .int32 ,
588+ paddle .int64 ,
589+ ], "start must be an 0-dim integral Tensor."
590+ start = start .item ()
591+ assert input .ndim > 0 , "narrow() cannot be applied to a 0-dim tensor."
592+ assert length >= 0 , "narrow(): length must be non-negative."
593+
594+ rank = input .ndim
595+ if input .ndim == 0 :
596+ rank = 1
597+
598+ if not (0 <= dim < rank ):
599+ _dim = dim + rank if dim < 0 else dim
600+ if _dim < 0 or _dim >= rank :
601+ raise IndexError (
602+ f"Dimension out of range (expected to be in range of [{ - rank } , { rank - 1 } ], but got { dim } )"
603+ )
604+ dim = _dim
605+
606+ dim_length = input .shape [dim ]
607+ assert - dim_length <= start <= dim_length , (
608+ f"start out of range (expected to be in range of [{ - dim_length } , { dim_length } ], but got { start } )"
609+ )
610+ if start < 0 :
611+ start = start + dim_length
612+ assert start <= dim_length - length , (
613+ f"start ({ start } ) + length ({ length } ) exceeds dimension size ({ dim_length } )."
614+ )
615+ new_shape = list (input .shape )
616+ new_shape [dim ] = length
617+ stride = input .strides
618+ offset = start * stride [dim ]
619+ offset *= paddle .core .size_of_dtype (input .dtype )
620+ return paddle .as_strided (
621+ input , shape = new_shape , stride = stride , offset = offset
622+ )
623+
624+
534625def transpose (
535626 x : Tensor , perm : Sequence [int ], name : str | None = None
536627) -> Tensor :
0 commit comments