Commit 8dfef44

[Hackathon 5th No.32] Add tensor_split / hsplit / dsplit APIs for Paddle - v2.0 (#776)

* [Change] Implement hsplit / vsplit / dsplit via tensor_split

* [Change] Revision history

* [Update] Chapters 4 and 5
megemini authored Dec 8, 2023
1 parent b41100b commit 8dfef44
Showing 1 changed file: rfcs/APIs/20231003_api_design_for_tensor_split.md (79 additions, 65 deletions).

| API name | tensor_split / hsplit / dsplit |
| - | - |
| Author | megemini (柳顺) |
| Submission date | 2023-12-07 |
| Version | V2.0 |
| Paddle version dependency | develop |
| File name | 20231003_api_design_for_tensor_split.md |

**Revision history**

v2.0
- `vsplit`, `hsplit`, and `dsplit` are implemented via `tensor_split`

# I. Overview

## 1. Background
``` python
def vsplit(x, num_or_sections, name=None):
    ...
    return split(x, num_or_sections, axis=0, name=name)
```

As shown above, `paddle.vsplit` is currently implemented via `split`.
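For reference, a small illustrative snippet of this current, `split`-based behavior (the shapes in the comments follow from the description above; the expectation that a non-divisible count fails is an assumption about the current `split` semantics, not a quoted error message):

``` python
import paddle

x = paddle.arange(12.0).reshape([6, 2])

# 6 rows split evenly into 3 sub-tensors of 2 rows each.
a, b, c = paddle.vsplit(x, 3)
print(a.shape, b.shape, c.shape)  # [2, 2] [2, 2] [2, 2]

# paddle.vsplit(x, 4) would be expected to fail, since 6 is not divisible by 4.
```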

# III. Survey of Industry Solutions


Both `TensorFlow` and `Numpy` implement `hsplit` and `dsplit` via `split`, while the corresponding `PyTorch` interfaces are implemented via `tensor_split`.

To summarize:

- `PyTorch`: `tensor_split`, `vsplit`, `dsplit`, and `hsplit` form one group, all implemented via `tensor_split` (indices_or_sections); the signature of `split` (split_size_or_sections) also differs from the other functions.

- `TensorFlow` and `Numpy`: `split`, `vsplit`, `dsplit`, and `hsplit` form one group, all implemented via `split` (indices_or_sections); `Numpy` additionally provides a separate `array_split` function (indices_or_sections).

- `Paddle`: `split` and `vsplit` form one group, both implemented via `split` (num_or_sections).

As can be seen, in the mainstream designs `vsplit`, `dsplit`, and `hsplit` form one family of interfaces, even though the underlying implementations may differ.

# V. Design and Implementation

As noted in the previous chapter, `Paddle`'s `split` and `vsplit` form one group, both implemented via `split` (num_or_sections). However, because `Paddle`'s `vsplit` differs from the mainstream designs (indices_or_sections):

**Important note**:
This design changes the implementation of `vsplit` so that `vsplit`, `hsplit`, and `dsplit` form one family of interfaces. The change to `vsplit` is backward-incompatible:

- For an `int` input, the split previously had to be even; after the change it may be uneven.
- For a `list|tuple` of ints, the values previously gave the length of each section; after the change they give the end index of each section.

To align with the mainstream designs, this proposal groups `vsplit`, `hsplit`, and `dsplit` together with `tensor_split`, where `vsplit`, `hsplit`, and `dsplit` are implemented via `tensor_split`, and `split` remains a separate API.

The main differences between `split` (num_or_sections) and `tensor_split` (num_or_indices) in this design are:

When the split parameter is an `int`:
- `split` (num_or_sections) splits into `equal` sections.
- `tensor_split` (num_or_indices), together with the corresponding `vsplit`, `dsplit`, and `hsplit` family, may split into `unequal` sections.

When the split parameter is a `list|tuple` of ints:
- `split` (num_or_sections): the values give the length of each section and must stay `in bounds`, i.e. the length of the list or tuple cannot exceed the size of the dimension being split, and one of the values may be `-1`.
- `tensor_split` (num_or_indices), together with the corresponding `vsplit`, `dsplit`, and `hsplit` family: the values give the indices at which to split, and they may go `out of bounds`.

Since `vsplit`, `hsplit`, and `dsplit` are implemented via `tensor_split`, their main parameter follows the `tensor_split` (num_or_indices) signature; `split` is implemented separately, and its signature (num_or_sections) reflects the difference.

Where:

- `num_or_sections`, `split_size_or_sections`: number of sections or section lengths
  - `int`: the number of sections to split into
  - `list`: the length of each section

- `num_or_indices`, `indices_or_sections`: number of sections or split indices
  - `int`: the number of sections to split into
  - `list`: the indices at which to split

Note in particular that `num_or_indices` is used here rather than `indices_or_sections`:
the split parameter accepts an `int` or an array of ints; semantically, an `int` means the number of sections (num) and an int array means the end indices (indices). The name `indices_or_sections` does not fully match these semantics, so it is changed to `num_or_indices`, as the sketch below illustrates.
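A minimal sketch of the intended difference, assuming the proposed `tensor_split` signature (the existing `paddle.split` calls are shown only for contrast; tensors and numbers are illustrative):

``` python
import paddle

x = paddle.arange(8)

# Existing `split` (num_or_sections): an int must divide the axis evenly,
# and a list gives the *length* of each section.
paddle.split(x, num_or_sections=2)        # two sections of length 4
paddle.split(x, num_or_sections=[3, -1])  # lengths [3, 5]; one -1 is allowed

# Proposed `tensor_split` (num_or_indices): an int may split unevenly,
# and a list gives the *end indices* of the sections.
paddle.tensor_split(x, num_or_indices=3)        # lengths [3, 3, 2]
paddle.tensor_split(x, num_or_indices=[3, 5])   # slices [0:3], [3:5], [5:8]
paddle.tensor_split(x, num_or_indices=[3, 100]) # out-of-bounds index -> [0:3], [3:8], empty
```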

## Naming and Parameter Design

Add the following Python-level APIs:

- `paddle.tensor_split(x, num_or_indices, axis=0, name=None)`
- `Tensor.tensor_split(num_or_indices, axis=0, name=None)`

- Parameters
> x (Tensor) – The input Tensor. Supported data types: float16, bfloat16, float32, float64, int32, int64, uint8.
> num_or_indices (int|list|tuple) – If an int, the number of sections along `axis` (the sections may be of unequal size); if a list or tuple of ints, the end indices at which to split.
> axis (int, optional) – dimension along which to split the tensor. Default: 0
> name: (str|None): Name for this layer. Please refer to :ref:`api_guide_Name`, Default None.

*Note*:
- Testing shows that `split` does not support the int16, complex64, or complex128 data types; in addition, uint16 is converted to bfloat16, so it is not supported either.
- Returns
> output (List of Tensors)
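A hypothetical call under this proposed signature (neither the functional nor the Tensor-method form exists yet; the shapes in the comments follow the semantics described above):

``` python
import paddle

x = paddle.arange(24).reshape([6, 4])

# Uneven split of 6 rows into 4 pieces -> row counts [2, 2, 1, 1].
parts = paddle.tensor_split(x, num_or_indices=4, axis=0)
print([p.shape for p in parts])  # [[2, 4], [2, 4], [1, 4], [1, 4]]

# Method form, splitting the 4 columns at end indices 1 and 3.
cols = x.tensor_split([1, 3], axis=1)
print([c.shape for c in cols])   # [[6, 1], [6, 2], [6, 1]]
```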

- `paddle.vsplit(x, num_or_indices, name=None)`
- `Tensor.vsplit(num_or_indices, name=None)`

- Parameters
> x (Tensor) – The input Tensor. Supported data types: float16, bfloat16, float32, float64, int32, int64, uint8.
> num_or_indices (int|list|tuple) – If an int, the number of sections (the sections may be of unequal size); if a list or tuple of ints, the end indices at which to split.
> name: (str|None): Name for this layer. Please refer to :ref:`api_guide_Name`, Default None.

- Returns
> output (List of Tensors)

- `paddle.hsplit(x, num_or_indices, name=None)`
- `Tensor.hsplit(num_or_indices, name=None)`

- Parameters
> x (Tensor) – The input Tensor. Supported data types: float16, bfloat16, float32, float64, int32, int64, uint8.
> num_or_indices (int|list|tuple) – If an int, the number of sections (the sections may be of unequal size); if a list or tuple of ints, the end indices at which to split.
> name: (str|None): Name for this layer. Please refer to :ref:`api_guide_Name`, Default None.

- Returns
> output (List of Tensors)

- `paddle.dsplit(x, num_or_indices, name=None)`
- `Tensor.dsplit(num_or_indices, name=None)`

- Parameters
> x (Tensor) – The input Tensor. Supported data types: float16, bfloat16, float32, float64, int32, int64, uint8.
> num_or_indices (int|list|tuple) – If an int, the number of sections (the sections may be of unequal size); if a list or tuple of ints, the end indices at which to split.
> name: (str|None): Name for this layer. Please refer to :ref:`api_guide_Name`, Default None.

- Returns
> output (List of Tensors)

Concrete interfaces:

- `paddle.tensor_split(x, num_or_indices, axis=0, name=None)`

``` python
import numpy as np
import paddle

def tensor_split(x, num_or_indices, axis=0, name=None):
    if x.ndim < 1:
        raise ValueError(
            f"The input tensor's dimension must be greater than 0, but got {x.ndim}"
        )

    total_n = x.shape[axis]

    def _tensor_split_array(total_n, sections, axis):
        # `sections` holds the end indices of the leading sub-tensors along `axis`.
        splits = []

        starts = 0
        ends = 0
        for idx in sections:
            ends = idx
            sub_array = paddle.slice(x, axes=[axis], starts=[starts], ends=[ends])
            splits.append(sub_array)
            starts = ends

        # The last sub-tensor runs from the final index to the end of the axis.
        starts = ends
        ends = total_n
        sub_array = paddle.slice(x, axes=[axis], starts=[starts], ends=[ends])
        splits.append(sub_array)

        return splits

    def _tensor_split_int(total_n, sections, axis):
        if sections <= 0:
            raise ValueError('num_or_indices must be larger than 0.')

        # Distribute the remainder over the leading sections, then turn the
        # section lengths into end indices via a cumulative sum.
        base, mod = divmod(total_n, sections)
        section_array = [base + 1] * mod + [base] * (sections - mod)
        section_array = np.cumsum(section_array[:-1], dtype=int)

        return _tensor_split_array(total_n, section_array, axis)

    if isinstance(num_or_indices, int):
        return _tensor_split_int(total_n, num_or_indices, axis)

    elif isinstance(num_or_indices, (list, tuple)):
        return _tensor_split_array(total_n, num_or_indices, axis)

    else:
        raise ValueError(
            f"The num_or_indices should be int, list or tuple of ints, but got {type(num_or_indices)}"
        )
```
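To make the index arithmetic in `_tensor_split_int` concrete, the same computation can be run standalone (illustrative numbers only):

``` python
import numpy as np

# Splitting an axis of length 8 into 3 sections: the remainder goes to the
# leading sections, giving lengths [3, 3, 2] and end indices [3, 6].
total_n, sections = 8, 3
base, mod = divmod(total_n, sections)                         # base=2, mod=2
section_array = [base + 1] * mod + [base] * (sections - mod)  # [3, 3, 2]
print(np.cumsum(section_array[:-1], dtype=int))               # [3 6]
```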

- `paddle.vsplit(x, num_or_indices, name=None)`

``` python
def vsplit(x, num_or_indices, name=None):
    if x.ndim < 2:
        raise ValueError(
            f"The input tensor's dimension must be greater than 1, but got {x.ndim}"
        )
    return tensor_split(x, num_or_indices, axis=0, name=name)
```

- `paddle.hsplit(x, num_or_indices, name=None)`

``` python
def hsplit(x, num_or_indices, name=None):
    if x.ndim < 1:
        raise ValueError(
            f"The input tensor's dimension must be greater than 0, but got {x.ndim}"
        )
    return tensor_split(x, num_or_indices, axis=1, name=name)
```

- `paddle.dsplit(x, num_or_indices, name=None)`

``` python

def dsplit(x, num_or_indices, name=None):
    if x.ndim < 3:
        raise ValueError(
            f"The input tensor's dimension must be greater than 2, but got {x.ndim}"
        )
    return tensor_split(x, num_or_indices, axis=2, name=name)
```
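Taken together, a hypothetical end-to-end usage of the proposed family on a 3-D tensor (the helpers simply forward to `tensor_split` on a fixed axis, so the shapes in the comments follow from the semantics above):

``` python
import paddle

x = paddle.arange(2 * 4 * 6, dtype='float32').reshape([2, 4, 6])

print([t.shape for t in paddle.vsplit(x, 2)])       # axis 0: [[1, 4, 6], [1, 4, 6]]
print([t.shape for t in paddle.hsplit(x, [1, 3])])  # axis 1: [[2, 1, 6], [2, 2, 6], [2, 1, 6]]
print([t.shape for t in paddle.dsplit(x, 4)])       # axis 2: [[2, 4, 2], [2, 4, 2], [2, 4, 1], [2, 4, 1]]
```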

# VI. Testing and Acceptance Considerations
