Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.26】为 Paddle 新增 diagonal_scatter API #6289

Merged
merged 9 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/paddle/Overview_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ tensor 元素操作相关(如:转置,reshape 等)
" :ref:`paddle.view_as <cn_api_paddle_view_as>` ", "使用 other 的 shape,返回 x 的一个 view Tensor"
" :ref:`paddle.unfold <cn_api_paddle_unfold>` ", "返回 x 的一个 view Tensor。以滑动窗口式提取 x 的值"
" :ref:`paddle.masked_fill <cn_api_paddle_masked_fill>` ", "根据 mask 信息,将 value 中的值填充到 x 中 mask 对应为 True 的位置。"
" :ref:`paddle.diagonal_scatter <cn_api_paddle_diagonal_scatter>` ", "根据给定的轴 axis 和偏移量 offset,将张量 y 的值填充到张量 x 中"
" :ref:`paddle.index_fill <cn_api_paddle_index_fill>` ", "沿着指定轴 axis 将 index 中指定位置的 x 的值填充为 value"

.. _tensor_manipulation_inplace:
Expand Down
10 changes: 10 additions & 0 deletions docs/api/paddle/Tensor_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3060,3 +3060,13 @@ masked_fill_(x, mask, value, name=None)
:::::::::

Inplace 版本的 :ref:`cn_api_paddle_masked_fill` API,对输入 `x` 采用 Inplace 策略。

diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None)
:::::::::
根据给定的轴 axis 和偏移量 offset,将张量 y 的值填充到张量 x 中。

返回:张量 y 填充到张量 x 中的结果。

返回类型:Tensor
DanGuge marked this conversation as resolved.
Show resolved Hide resolved

请参考 :ref:`cn_api_paddle_diagonal_scatter`
37 changes: 37 additions & 0 deletions docs/api/paddle/diagonal_scatter_cn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
.. _cn_api_paddle_diagonal_scatter:

diagonal_scatter
-------------------------------

.. py:function:: paddle.diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None)


根据参数 ``offset``、``axis1``、``axis2``,将张量 ``y`` 填充到张量 ``x`` 的对应位置。

这个函数将会返回一个新的 ``Tensor``。

参数 ``offset`` 确定从指定的二维平面中获取对角线的位置:

- 如果 offset = 0,则嵌入主对角线。
- 如果 offset > 0,则嵌入主对角线右上的对角线。
- 如果 offset < 0,则嵌入主对角线左下的对角线。

参数
::::::::::::

- **x** (Tensor) - 输入张量,张量的维度至少为 2 维,支持 float16、float32、float64、bfloat16、uint8、int8、int16、int32、int64、bool、complex64、complex128 数据类型。
- **y** (Tensor) - 嵌入张量,将会被嵌入到输入张量中,支持 float16、float32、float64、bfloat16、uint8、int8、int16、int32、int64、bool、complex64、complex128 数据类型。
- **offset** (int, 可选) - 从指定的二维平面嵌入对角线的位置,默认值为 0,即主对角线。
- **axis1** (int, 可选) - 对角线的第一个维度,默认值为 0。
- **axis2** (int, 可选) - 对角线的第二个维度,默认值为 1。
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。

返回
::::::::::::
``Tensor``,返回一个根据给定的轴 ``axis`` 和偏移量 ``offset``,将张量 ``y`` 填充到张量 ``x`` 对应位置的新 ``Tensor``。


代码示例
::::::::::::

COPY-FROM: paddle.diagonal_scatter
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
## [ 参数完全一致 ] torch.Tensor.diagonal_scatter

### [torch.Tensor.diagonal_scatter](https://pytorch.org/docs/stable/generated/torch.Tensor.diagonal_scatter.html?highlight=diagonal_scatter#torch.Tensor.diagonal_scatter)

```python
torch.Tensor.diagonal_scatter(input, src, offset=0, dim1=0, dim2=1)
```

### [paddle.Tensor.diagonal_scatter](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/Tensor_cn.html#diagonal-scatter-x-y-offset-0-axis1-0-axis2-1-name-none)

```python
paddle.Tensor.diagonal_scatter(x, y, offset=0, axis1=0, axis2=1)
```

两者功能一致且参数用法一致,仅参数名不一致,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
|---------|--------------| -------------------------------------------------- |
| input | x | 输入张量,被嵌入的张量,仅参数名不一致。 |
| src | y | 用于嵌入的张量,仅参数名不一致。 |
| offset | offset | 从指定的二维平面嵌入对角线的位置,默认值为 0,即主对角线。 |
| dim1 | axis1 | 对角线的第一个维度,默认值为 0,仅参数名不一致。 |
| dim2 | axis2 | 对角线的第二个维度,默认值为 1,仅参数名不一致。 |