diff --git a/rfcs/APIs/20231206_api_design_for_slice_scatter.md b/rfcs/APIs/20231206_api_design_for_slice_scatter.md
index 4a7db8b6d..775cc7c28 100644
--- a/rfcs/APIs/20231206_api_design_for_slice_scatter.md
+++ b/rfcs/APIs/20231206_api_design_for_slice_scatter.md
@@ -3,11 +3,14 @@
| API名称 | paddle.slice_scatter |
| ------------------------------------------------------------ | ----------------------------------------- |
| 提交作者 | megemini (柳顺) |
-| 提交时间 | 2023-12-13 |
-| 版本号 | V1.0 |
+| 提交时间 | 2023-12-22 |
+| 版本号 | V2.0 |
| 依赖飞桨版本 | develop |
| 文件名 | 20231213_api_design_for_slice_scatter.md
|
+**修订记录**
+
+v2.0 修改函数签名,支持 `list of int` 的参数
# 一、概述
@@ -266,17 +269,17 @@ paddle 目前的 `set_value` 算子已经支持 `axes`, `starts`, `ends`, `steps
添加 Python API:
```python
-paddle.slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None)
+paddle.slice_scatter(x, value, axes, starts, ends, strides, name=None)
```
参数表:
-- x: (Tensor) 输入的 tensor。数据类型支持 `float32`、`float64`。
-- value: (Tensor) 用于填充的 tensor。数据类型与input一致,形状与`x[*x.shape[:axis], start:end:step, *x.shape[axis+1:]]`取出的slice一致。
-- axis: (int) y的数据将被填充至x的axis维度。
-- start: (Optional[int]) 待插入slice位置的起始index。
-- stop: (Optional[int]) 待插入slice位置的结束index。
-- step: (int) 待插入slice的步长。
+- x: (Tensor) 输入的 tensor。
+- value: (Tensor) 用于填充的 tensor。数据类型与input一致。
+- axes: (list|tuple) y的数据将被填充至x的axis维度。
+- starts: (list|tuple) 待插入slice位置的起始index。
+- ends: (list|tuple) 待插入slice位置的结束index。
+- strides: (list|tuple) 待插入slice的步长。
- name: (Optional[str]) op 名称
## 底层OP设计
@@ -288,45 +291,8 @@ paddle.slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None)
此次使用 `set_value` 算子实现接口:
``` python
-def slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None):
-
- if x.ndim != value.ndim:
- raise ValueError(
- f"The input x and value should have save dimension, but got input of {x.ndim} and value of {value.ndim}."
- )
-
- x_shape = x.shape
- value_shape = value.shape
-
- index = list(range(start or 0, stop or x_shape[axis], step))
- exp_shape = [*x_shape[:axis], len(index), *x_shape[axis+1:]]
- if exp_shape != value_shape:
- raise ValueError(
- "The value.shape should be same of [*x_shape[:axis], len(index), *x_shape[axis+1:]],"
- f"but got value.shape of {value.shape} and slice shape {exp_shape}."
- )
-
- starts = [start]
- ends = [stop]
- steps = [step]
- axes = [axis]
- none_axes = []
- decrease_axes = []
- inputs = {'Input': x}
- attrs = {
- 'axes': axes,
- 'starts': starts,
- 'ends': ends,
- 'steps': steps,
- 'decrease_axes': decrease_axes,
- 'none_axes': none_axes,
- }
-
- dtype = x.dtype
- attrs['dtype'] = dtype
-
- value = value.astype(dtype)
- inputs["ValueTensor"] = value
+def slice_scatter(x, value, axes, starts, ends, strides, name=None):
+ ... check params
if in_dynamic_or_pir_mode():
return _C_ops.set_value_with_tensor(
@@ -334,7 +300,7 @@ def slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None):
value,
starts,
ends,
- steps,
+ strides,
axes,
decrease_axes,
none_axes,
@@ -354,22 +320,14 @@ def slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None):
return output
```
-有几点说明:
-
-- x 与 src 需要有相同的 ndim
-- values_shape 需要与 slice 的 exp_shape 一致
-- 参数 axis/start/stop/step 不支持 list。因为,多个 axis 的话可能导致 slice 的 shape 错误。
- 比如,x 为 [8, 8], src 为 [8, 2],则 axis 只能为 1。
-
-
# 六、测试和验收的考量
- 覆盖动态图和静态图的测试场景
- 覆盖 CPU、GPU 两种测试场景
-- 支持各种Tensor精度,FP32、FP64(带验证)
-- 需要检查前向和反向计算的精度正确性
-- 处理0维输入数据
-- 处理可选参数不存在或不一致的情况
+- 支持各种Tensor精度,FP32、FP64 等(待验证)
+- 需要检查计算正确性
+- 需要检查多维的情况
+- 需要检查 broadcast 情况
# 七、可行性分析和排期规划