Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/strided_copy_kernel.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ bool CheckStride(
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& dims,
int rank,
int64_t output_numel) {
if (output_numel == 0) return true;

int64_t stride = output_numel;
int64_t last_stride = 1;
for (size_t i = 0; i < rank; i++) {
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@
masked_scatter,
masked_scatter_,
moveaxis,
narrow,
put_along_axis,
ravel,
repeat_interleave,
Expand Down Expand Up @@ -937,6 +938,7 @@ def __dir__(self):
'mv',
'in_dynamic_mode',
'min',
'narrow',
'amin',
'any',
'slice',
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@
masked_scatter,
masked_scatter_,
moveaxis,
narrow,
put_along_axis,
put_along_axis_,
ravel,
Expand Down Expand Up @@ -687,6 +688,7 @@
'logical_or_',
'logical_xor',
'logical_xor_',
'narrow',
'not_equal',
'not_equal_',
'allclose',
Expand Down
91 changes: 91 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,97 @@ def slice(
return out


def narrow(
input: Tensor,
dim: int,
start: int | Tensor,
length: int,
) -> Tensor:
"""
Returns a narrowed slice of input along a single axis.

This operator selects the index range [start, start + length) on dimension dim and keeps all
the dimensions unchanged.

Args:
input (Tensor): Input tensor.
dim (int): Dimension to narrow. Supports negative indexing.
start (int|Tensor): Start index on ``dim``. Can be a Python int or a 0-D
int Tensor (int32 or int64). Negative values are supported.
length (int): Number of elements to select from ``start``. Must be
non-negative.

Returns:
Tensor: A tensor that is a narrowed view of ``input``.

Examples:
.. code-block:: python

>>> import paddle

>>> x = paddle.to_tensor([[1, 2, 3, 4],
... [5, 6, 7, 8]], dtype='int64')

>>> y1 = paddle.narrow(x, dim=1, start=1, length=2)
>>> print(y1)
Tensor(shape=[2, 2], dtype=int64, place=Place(cpu), stop_gradient=True,
[[2, 3],
[6, 7]])

>>> y2 = paddle.narrow(x, dim=-1, start=-3, length=3)
>>> print(y2)
Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
[[2, 3, 4],
[6, 7, 8]])

>>> s = paddle.to_tensor(0, dtype='int64')
>>> y3 = paddle.narrow(x, dim=1, start=s, length=2)
>>> print(y3)
Tensor(shape=[2, 2], dtype=int64, place=Place(cpu), stop_gradient=True,
[[1, 2],
[5, 6]])
"""

if isinstance(start, paddle.Tensor):
assert start.ndim == 0 and start.dtype in [
paddle.int32,
paddle.int64,
], "start must be an 0-dim integral Tensor."
start = start.item()
assert input.ndim > 0, "narrow() cannot be applied to a 0-dim tensor."
assert length >= 0, "narrow(): length must be non-negative."

rank = input.ndim
if input.ndim == 0:
rank = 1

if not (0 <= dim < rank):
_dim = dim + rank if dim < 0 else dim
if _dim < 0 or _dim >= rank:
raise IndexError(
f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {dim})"
)
dim = _dim

dim_length = input.shape[dim]
assert -dim_length <= start <= dim_length, (
f"start out of range (expected to be in range of [{-dim_length}, {dim_length}], but got {start})"
)
if start < 0:
start = start + dim_length
assert start <= dim_length - length, (
f"start ({start}) + length ({length}) exceeds dimension size ({dim_length})."
)
new_shape = list(input.shape)
new_shape[dim] = length
stride = input.strides
offset = start * stride[dim]
offset *= paddle.core.size_of_dtype(input.dtype)
return paddle.as_strided(
input, shape=new_shape, stride=stride, offset=offset
)


def transpose(
x: Tensor, perm: Sequence[int], name: str | None = None
) -> Tensor:
Expand Down
Loading
Loading