【Hackathon 5th No.28】Add a slice_scatter API for Paddle, v2 #790

Merged
merged 7 commits into from
Dec 26, 2023
80 changes: 19 additions & 61 deletions rfcs/APIs/20231206_api_design_for_slice_scatter.md
@@ -3,11 +3,14 @@
| API name | paddle.slice_scatter |
| ------------------------------------------------------------ | ----------------------------------------- |
| Author<input type="checkbox" class="rowselector hidden"> | megemini (柳顺) |
| Submission date<input type="checkbox" class="rowselector hidden"> | 2023-12-13 |
| Version | V1.0 |
| Submission date<input type="checkbox" class="rowselector hidden"> | 2023-12-22 |
| Version | V2.0 |
| Required Paddle version<input type="checkbox" class="rowselector hidden"> | develop |
| File name | 20231213_api_design_for_slice_scatter.md<br> |

**Revision history**

v2.0: changed the function signature so that the parameters accept `list of int`.


# 1. Overview
@@ -266,17 +269,17 @@ Paddle's existing `set_value` operator already supports `axes`, `starts`, `ends`, `steps

Add a Python API:
```python
paddle.slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None)
paddle.slice_scatter(x, value, axes, starts, ends, strides, name=None)
```

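Parameters:

- x: (Tensor) The input tensor. Supported data types: `float32`, `float64`.
- value: (Tensor) The tensor used to fill `x`. Its data type must match that of `x`, and its shape must match the slice selected by `x[*x.shape[:axis], start:end:step, *x.shape[axis+1:]]`.
- axis: (int) The dimension of `x` along which the data of `value` is written.
- start: (Optional[int]) Start index of the slice to write into.
- stop: (Optional[int]) End index of the slice to write into.
- step: (int) Step of the slice to write into.
- x: (Tensor) The input tensor.
- value: (Tensor) The tensor used to fill `x`. Its data type must match that of `x`.
- axes: (list|tuple) The dimensions of `x` along which the data of `value` is written.
- starts: (list|tuple) Start indices of the slice to write into, one per axis.
- ends: (list|tuple) End indices of the slice to write into, one per axis.
- strides: (list|tuple) Strides of the slice to write into, one per axis.
- name: (Optional[str]) The op name.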
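
The v2 parameters describe one slice per listed axis, so the shape that `value` has to fit can be worked out from `x.shape` alone. The helper below is a minimal sketch of that calculation. `expected_slice_shape` is a hypothetical name, not part of Paddle, and it assumes that `starts`/`ends` are clamped the same way Python slicing clamps them.

```python
def expected_slice_shape(x_shape, axes, starts, ends, strides):
    """Shape of the region of x selected by the given per-axis slices."""
    if not (len(axes) == len(starts) == len(ends) == len(strides)):
        raise ValueError("axes, starts, ends and strides must have the same length")
    shape = list(x_shape)
    for axis, start, end, stride in zip(axes, starts, ends, strides):
        dim = x_shape[axis]
        # slice.indices clamps start/end to the dimension, like Python slicing
        # (assumed here; not confirmed against Paddle's own checks).
        shape[axis] = len(range(*slice(start, end, stride).indices(dim)))
    return shape


# x has shape [8, 8]; select every row and columns 0, 2, 4, 6.
print(expected_slice_shape([8, 8], axes=[1], starts=[0], ends=[8], strides=[2]))
# Two axes at once: rows 2..5 and columns 0, 3, 6.
print(expected_slice_shape([8, 8], axes=[0, 1], starts=[2, 0], ends=[6, 8], strides=[1, 3]))
```

The first call yields `[8, 4]` and the second `[4, 3]`.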

## Underlying OP design
@@ -288,53 +291,16 @@ paddle.slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None)
This time the interface is implemented with the `set_value` operator:

``` python
def slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None):

    if x.ndim != value.ndim:
        raise ValueError(
            f"The input x and value should have same dimension, but got input of {x.ndim} and value of {value.ndim}."
        )

    x_shape = x.shape
    value_shape = value.shape

    index = list(range(start or 0, stop or x_shape[axis], step))
    exp_shape = [*x_shape[:axis], len(index), *x_shape[axis+1:]]
    if exp_shape != value_shape:
        raise ValueError(
            "The value.shape should be same as [*x_shape[:axis], len(index), *x_shape[axis+1:]], "
            f"but got value.shape of {value.shape} and slice shape {exp_shape}."
        )

    starts = [start]
    ends = [stop]
    steps = [step]
    axes = [axis]
    none_axes = []
    decrease_axes = []
    inputs = {'Input': x}
    attrs = {
        'axes': axes,
        'starts': starts,
        'ends': ends,
        'steps': steps,
        'decrease_axes': decrease_axes,
        'none_axes': none_axes,
    }

    dtype = x.dtype
    attrs['dtype'] = dtype

    value = value.astype(dtype)
    inputs["ValueTensor"] = value
def slice_scatter(x, value, axes, starts, ends, strides, name=None):
    ... check params

    if in_dynamic_or_pir_mode():
        return _C_ops.set_value_with_tensor(
            x,
            value,
            starts,
            ends,
            steps,
            strides,
            axes,
            decrease_axes,
            none_axes,
@@ -354,22 +320,14 @@ def slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None):
    return output
```
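
To make the intended result concrete without relying on `set_value`, the following is a pure-Python reference sketch that works on nested lists. It is not Paddle's implementation: `slice_scatter_ref` is a hypothetical helper, and it assumes that `value` has the same number of dimensions as `x` and is not broadcast.

```python
import copy


def slice_scatter_ref(x, value, axes, starts, ends, strides):
    """Return a copy of nested-list x with the selected region replaced by value."""
    slices = {
        axis: slice(start, end, stride)
        for axis, start, end, stride in zip(axes, starts, ends, strides)
    }

    def fill(dst, src, depth):
        # Indices of dst touched along this dimension: the slice if the axis
        # is listed in axes, otherwise every index.
        if depth in slices:
            indices = range(*slices[depth].indices(len(dst)))
        else:
            indices = range(len(dst))
        if len(indices) != len(src):
            raise ValueError(
                f"value has {len(src)} entries along axis {depth}, expected {len(indices)}"
            )
        for j, i in enumerate(indices):
            if isinstance(dst[i], list):
                fill(dst[i], src[j], depth + 1)
            else:
                dst[i] = src[j]

    out = copy.deepcopy(x)  # the input is left untouched
    fill(out, value, 0)
    return out


x = [[0] * 4 for _ in range(3)]
value = [[1, 1], [1, 1], [1, 1]]
result = slice_scatter_ref(x, value, axes=[1], starts=[0], ends=[4], strides=[2])
print(result)  # [[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]]
```

A reference like this can serve as the expected value in unit tests, since it returns a new object and writes only to the positions the slices select.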

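A few notes:

- x and src (the `value` argument) must have the same ndim.
- value_shape must match the slice's exp_shape.
- The parameters axis/start/stop/step do not accept lists, because multiple axes could produce a wrong slice shape.
For example, if x has shape [8, 8] and src has shape [8, 2], then axis can only be 1.
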
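
The `[8, 8]` / `[8, 2]` example can be checked mechanically: with a single axis, every dimension other than the sliced one must be identical between `x` and `src`. The sketch below lists the axes for which that holds. `matching_axes` is a hypothetical helper written for this illustration only.

```python
def matching_axes(x_shape, value_shape):
    """Axes along which some slice of x could have exactly value_shape."""
    if len(x_shape) != len(value_shape):
        return []
    candidates = []
    for axis in range(len(x_shape)):
        # All other dimensions must be identical; only the sliced one may shrink.
        others_match = all(
            x_shape[d] == value_shape[d] for d in range(len(x_shape)) if d != axis
        )
        if others_match and value_shape[axis] <= x_shape[axis]:
            candidates.append(axis)
    return candidates


print(matching_axes([8, 8], [8, 2]))  # [1]
```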

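# 6. Testing and acceptance considerations

- Cover both dynamic-graph and static-graph test scenarios
- Cover both CPU and GPU test scenarios
- Support the various tensor precisions, FP32 and FP64 (to be verified)
- Check the numerical correctness of the forward and backward computation
- Handle 0-D input data
- Handle optional parameters that are missing or inconsistent
- Support the various tensor precisions, FP32, FP64, etc. (to be verified)
- Check the correctness of the computation
- Check multi-dimensional cases
- Check broadcast cases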
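
For the broadcast cases, a test needs to know which `value` shapes are expected to be accepted for a given slice shape. The sketch below assumes NumPy-style broadcasting rules, which is an assumption and not something this RFC spells out. `broadcastable_to` is a hypothetical helper.

```python
def broadcastable_to(value_shape, slice_shape):
    """True if value_shape can be broadcast to slice_shape under NumPy-style rules."""
    if len(value_shape) > len(slice_shape):
        return False
    # Compare trailing dimensions; each must be equal or 1 on the value side.
    for v, s in zip(reversed(value_shape), reversed(slice_shape)):
        if v != s and v != 1:
            return False
    return True


cases = [
    ([8, 2], [8, 2]),  # identical shapes
    ([1, 2], [8, 2]),  # leading dimension of size 1
    ([2], [8, 2]),     # fewer dimensions
    ([8, 3], [8, 2]),  # mismatched trailing dimension
]
for value_shape, slice_shape in cases:
    print(value_shape, slice_shape, broadcastable_to(value_shape, slice_shape))
```

Only the last case is rejected under these rules.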

# 7. Feasibility analysis and schedule
