Skip to content

Commit f16d8d8

Browse files
committed
add narrow
1 parent ce17cf1 commit f16d8d8

File tree

5 files changed

+445
-0
lines changed

5 files changed

+445
-0
lines changed

paddle/phi/kernels/funcs/strided_copy_kernel.cu.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,8 @@ bool CheckStride(
218218
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& dims,
219219
int rank,
220220
int64_t output_numel) {
221+
if (output_numel == 0) return true;
222+
221223
int64_t stride = output_numel;
222224
int64_t last_stride = 1;
223225
for (size_t i = 0; i < rank; i++) {

python/paddle/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@
333333
masked_scatter,
334334
masked_scatter_,
335335
moveaxis,
336+
narrow,
336337
put_along_axis,
337338
ravel,
338339
repeat_interleave,
@@ -873,6 +874,7 @@
873874
'mv',
874875
'in_dynamic_mode',
875876
'min',
877+
'narrow',
876878
'amin',
877879
'any',
878880
'slice',

python/paddle/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@
191191
masked_scatter,
192192
masked_scatter_,
193193
moveaxis,
194+
narrow,
194195
put_along_axis,
195196
put_along_axis_,
196197
ravel,
@@ -671,6 +672,7 @@
671672
'logical_or_',
672673
'logical_xor',
673674
'logical_xor_',
675+
'narrow',
674676
'not_equal',
675677
'not_equal_',
676678
'allclose',

python/paddle/tensor/manipulation.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,97 @@ def slice(
528528
return out
529529

530530

531+
def narrow(
532+
input: Tensor,
533+
dim: int,
534+
start: Sequence[int | Tensor],
535+
length: int,
536+
) -> Tensor:
537+
"""
538+
Returns a narrowed slice of input along a single axis.
539+
540+
This operator selects the index range [start, start + length) on dimension dim and keeps all
541+
the dimensions unchanged.
542+
543+
Args:
544+
input (Tensor): Input tensor.
545+
dim (int): Dimension to narrow. Supports negative indexing.
546+
start (int|Tensor): Start index on ``dim``. Can be a Python int or a 0-D
547+
int Tensor (int32 or int64). Negative values are supported.
548+
length (int): Number of elements to select from ``start``. Must be
549+
non-negative.
550+
551+
Returns:
552+
Tensor: A tensor that is a narrowed view of ``input``.
553+
554+
Examples:
555+
.. code-block:: python
556+
557+
>>> import paddle
558+
559+
>>> x = paddle.to_tensor([[1, 2, 3, 4],
560+
... [5, 6, 7, 8]], dtype='int64')
561+
562+
>>> y1 = paddle.narrow(x, dim=1, start=1, length=2)
563+
>>> print(y1)
564+
Tensor(shape=[2, 2], dtype=int64, place=Place(cpu), stop_gradient=True,
565+
[[2, 3],
566+
[6, 7]])
567+
568+
>>> y2 = paddle.narrow(x, dim=-1, start=-3, length=3)
569+
>>> print(y2)
570+
Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
571+
[[2, 3, 4],
572+
[6, 7, 8]])
573+
574+
>>> s = paddle.to_tensor(0, dtype='int64')
575+
>>> y3 = paddle.narrow(x, dim=1, start=s, length=2)
576+
>>> print(y3)
577+
Tensor(shape=[2, 2], dtype=int64, place=Place(cpu), stop_gradient=True,
578+
[[1, 2],
579+
[5, 6]])
580+
"""
581+
582+
if isinstance(start, paddle.Tensor):
583+
assert start.ndim == 0 and start.dtype in [
584+
paddle.int32,
585+
paddle.int64,
586+
], "start must be an 0-dim integral Tensor."
587+
start = start.item()
588+
assert input.ndim > 0, "narrow() cannot be applied to a 0-dim tensor."
589+
assert length >= 0, "narrow(): length must be non-negative."
590+
591+
rank = input.ndim
592+
if input.ndim == 0:
593+
rank = 1
594+
595+
if not (0 <= dim < rank):
596+
_dim = dim + rank if dim < 0 else dim
597+
if _dim < 0 or _dim >= rank:
598+
raise IndexError(
599+
f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {dim})"
600+
)
601+
dim = _dim
602+
603+
dim_length = input.shape[dim]
604+
assert (
605+
-dim_length <= start <= dim_length
606+
), f"start out of range (expected to be in range of [{-dim_length}, {dim_length}], but got {start})"
607+
if start < 0:
608+
start = start + dim_length
609+
assert (
610+
start <= dim_length - length
611+
), f"start ({start}) + length ({length}) exceeds dimension size ({dim_length})."
612+
new_shape = list(input.shape)
613+
new_shape[dim] = length
614+
stride = input.strides
615+
offset = start * stride[dim]
616+
617+
return paddle.as_strided(
618+
input, shape=new_shape, stride=stride, offset=offset
619+
)
620+
621+
531622
def transpose(
532623
x: Tensor, perm: Sequence[int], name: str | None = None
533624
) -> Tensor:
@@ -7415,6 +7506,7 @@ def as_strided(
74157506
[8, 6]
74167507
>>> # the stride is [6, 1].
74177508
"""
7509+
offset *= paddle.core.size_of_dtype(x.dtype)
74187510
return _C_ops.as_strided(x, shape, stride, offset)
74197511

74207512

0 commit comments

Comments
 (0)