diff --git a/rfcs/APIs/20231206_api_design_for_slice_scatter.md b/rfcs/APIs/20231206_api_design_for_slice_scatter.md index 4a7db8b6d..775cc7c28 100644 --- a/rfcs/APIs/20231206_api_design_for_slice_scatter.md +++ b/rfcs/APIs/20231206_api_design_for_slice_scatter.md @@ -3,11 +3,14 @@ | API名称 | paddle.slice_scatter | | ------------------------------------------------------------ | ----------------------------------------- | | 提交作者 | megemini (柳顺) | -| 提交时间 | 2023-12-13 | -| 版本号 | V1.0 | +| 提交时间 | 2023-12-22 | +| 版本号 | V2.0 | | 依赖飞桨版本 | develop | | 文件名 | 20231213_api_design_for_slice_scatter.md
| +**修订记录** + +v2.0 修改函数签名,支持 `list of int` 的参数 # 一、概述 @@ -266,17 +269,17 @@ paddle 目前的 `set_value` 算子已经支持 `axes`, `starts`, `ends`, `steps 添加 Python API: ```python -paddle.slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None) +paddle.slice_scatter(x, value, axes, starts, ends, strides, name=None) ``` 参数表: -- x: (Tensor) 输入的 tensor。数据类型支持 `float32`、`float64`。 -- value: (Tensor) 用于填充的 tensor。数据类型与input一致,形状与`x[*x.shape[:axis], start:end:step, *x.shape[axis+1:]]`取出的slice一致。 -- axis: (int) y的数据将被填充至x的axis维度。 -- start: (Optional[int]) 待插入slice位置的起始index。 -- stop: (Optional[int]) 待插入slice位置的结束index。 -- step: (int) 待插入slice的步长。 +- x: (Tensor) 输入的 tensor。 +- value: (Tensor) 用于填充的 tensor。数据类型与input一致。 +- axes: (list|tuple) y的数据将被填充至x的axis维度。 +- starts: (list|tuple) 待插入slice位置的起始index。 +- ends: (list|tuple) 待插入slice位置的结束index。 +- strides: (list|tuple) 待插入slice的步长。 - name: (Optional[str]) op 名称 ## 底层OP设计 @@ -288,45 +291,8 @@ paddle.slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None) 此次使用 `set_value` 算子实现接口: ``` python -def slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None): - - if x.ndim != value.ndim: - raise ValueError( - f"The input x and value should have save dimension, but got input of {x.ndim} and value of {value.ndim}." - ) - - x_shape = x.shape - value_shape = value.shape - - index = list(range(start or 0, stop or x_shape[axis], step)) - exp_shape = [*x_shape[:axis], len(index), *x_shape[axis+1:]] - if exp_shape != value_shape: - raise ValueError( - "The value.shape should be same of [*x_shape[:axis], len(index), *x_shape[axis+1:]]," - f"but got value.shape of {value.shape} and slice shape {exp_shape}." - ) - - starts = [start] - ends = [stop] - steps = [step] - axes = [axis] - none_axes = [] - decrease_axes = [] - inputs = {'Input': x} - attrs = { - 'axes': axes, - 'starts': starts, - 'ends': ends, - 'steps': steps, - 'decrease_axes': decrease_axes, - 'none_axes': none_axes, - } - - dtype = x.dtype - attrs['dtype'] = dtype - - value = value.astype(dtype) - inputs["ValueTensor"] = value +def slice_scatter(x, value, axes, starts, ends, strides, name=None): + ... check params if in_dynamic_or_pir_mode(): return _C_ops.set_value_with_tensor( @@ -334,7 +300,7 @@ def slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None): value, starts, ends, - steps, + strides, axes, decrease_axes, none_axes, @@ -354,22 +320,14 @@ def slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None): return output ``` -有几点说明: - -- x 与 src 需要有相同的 ndim -- values_shape 需要与 slice 的 exp_shape 一致 -- 参数 axis/start/stop/step 不支持 list。因为,多个 axis 的话可能导致 slice 的 shape 错误。 - 比如,x 为 [8, 8], src 为 [8, 2],则 axis 只能为 1。 - - # 六、测试和验收的考量 - 覆盖动态图和静态图的测试场景 - 覆盖 CPU、GPU 两种测试场景 -- 支持各种Tensor精度,FP32、FP64(带验证) -- 需要检查前向和反向计算的精度正确性 -- 处理0维输入数据 -- 处理可选参数不存在或不一致的情况 +- 支持各种Tensor精度,FP32、FP64 等(待验证) +- 需要检查计算正确性 +- 需要检查多维的情况 +- 需要检查 broadcast 情况 # 七、可行性分析和排期规划