Skip to content

Commit 9022c90

Browse files
committed
add narrow
1 parent 7731d1c commit 9022c90

File tree

5 files changed

+457
-0
lines changed

5 files changed

+457
-0
lines changed

paddle/phi/kernels/funcs/strided_copy_kernel.cu.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,8 @@ bool CheckStride(
218218
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& dims,
219219
int rank,
220220
int64_t output_numel) {
221+
if (output_numel == 0) return true;
222+
221223
int64_t stride = output_numel;
222224
int64_t last_stride = 1;
223225
for (size_t i = 0; i < rank; i++) {

python/paddle/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@
358358
masked_scatter,
359359
masked_scatter_,
360360
moveaxis,
361+
narrow,
361362
put_along_axis,
362363
ravel,
363364
repeat_interleave,
@@ -937,6 +938,7 @@ def __dir__(self):
937938
'mv',
938939
'in_dynamic_mode',
939940
'min',
941+
'narrow',
940942
'amin',
941943
'any',
942944
'slice',

python/paddle/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@
193193
masked_scatter,
194194
masked_scatter_,
195195
moveaxis,
196+
narrow,
196197
put_along_axis,
197198
put_along_axis_,
198199
ravel,
@@ -687,6 +688,7 @@
687688
'logical_or_',
688689
'logical_xor',
689690
'logical_xor_',
691+
'narrow',
690692
'not_equal',
691693
'not_equal_',
692694
'allclose',

python/paddle/tensor/manipulation.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,97 @@ def slice(
531531
return out
532532

533533

534+
def narrow(
535+
input: Tensor,
536+
dim: int,
537+
start: int | Tensor,
538+
length: int,
539+
) -> Tensor:
540+
"""
541+
Returns a narrowed slice of input along a single axis.
542+
543+
This operator selects the index range [start, start + length) on dimension dim and keeps all
544+
the dimensions unchanged.
545+
546+
Args:
547+
input (Tensor): Input tensor.
548+
dim (int): Dimension to narrow. Supports negative indexing.
549+
start (int|Tensor): Start index on ``dim``. Can be a Python int or a 0-D
550+
int Tensor (int32 or int64). Negative values are supported.
551+
length (int): Number of elements to select from ``start``. Must be
552+
non-negative.
553+
554+
Returns:
555+
Tensor: A tensor that is a narrowed view of ``input``.
556+
557+
Examples:
558+
.. code-block:: python
559+
560+
>>> import paddle
561+
562+
>>> x = paddle.to_tensor([[1, 2, 3, 4],
563+
... [5, 6, 7, 8]], dtype='int64')
564+
565+
>>> y1 = paddle.narrow(x, dim=1, start=1, length=2)
566+
>>> print(y1)
567+
Tensor(shape=[2, 2], dtype=int64, place=Place(cpu), stop_gradient=True,
568+
[[2, 3],
569+
[6, 7]])
570+
571+
>>> y2 = paddle.narrow(x, dim=-1, start=-3, length=3)
572+
>>> print(y2)
573+
Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
574+
[[2, 3, 4],
575+
[6, 7, 8]])
576+
577+
>>> s = paddle.to_tensor(0, dtype='int64')
578+
>>> y3 = paddle.narrow(x, dim=1, start=s, length=2)
579+
>>> print(y3)
580+
Tensor(shape=[2, 2], dtype=int64, place=Place(cpu), stop_gradient=True,
581+
[[1, 2],
582+
[5, 6]])
583+
"""
584+
585+
if isinstance(start, paddle.Tensor):
586+
assert start.ndim == 0 and start.dtype in [
587+
paddle.int32,
588+
paddle.int64,
589+
], "start must be an 0-dim integral Tensor."
590+
start = start.item()
591+
assert input.ndim > 0, "narrow() cannot be applied to a 0-dim tensor."
592+
assert length >= 0, "narrow(): length must be non-negative."
593+
594+
rank = input.ndim
595+
if input.ndim == 0:
596+
rank = 1
597+
598+
if not (0 <= dim < rank):
599+
_dim = dim + rank if dim < 0 else dim
600+
if _dim < 0 or _dim >= rank:
601+
raise IndexError(
602+
f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {dim})"
603+
)
604+
dim = _dim
605+
606+
dim_length = input.shape[dim]
607+
assert -dim_length <= start <= dim_length, (
608+
f"start out of range (expected to be in range of [{-dim_length}, {dim_length}], but got {start})"
609+
)
610+
if start < 0:
611+
start = start + dim_length
612+
assert start <= dim_length - length, (
613+
f"start ({start}) + length ({length}) exceeds dimension size ({dim_length})."
614+
)
615+
new_shape = list(input.shape)
616+
new_shape[dim] = length
617+
stride = input.strides
618+
offset = start * stride[dim]
619+
offset *= paddle.core.size_of_dtype(input.dtype)
620+
return paddle.as_strided(
621+
input, shape=new_shape, stride=stride, offset=offset
622+
)
623+
624+
534625
def transpose(
535626
x: Tensor, perm: Sequence[int], name: str | None = None
536627
) -> Tensor:

0 commit comments

Comments
 (0)