[Hackathon 5th No.32] Add tensor_split / hsplit / dsplit APIs to Paddle #682
**Question**: Static graph mode does not work yet; `while starts < total_n` never seems to exit. Any guidance would be appreciated!
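In static graph mode, when control flow depends on the value of a tensor, you cannot use Python's `if`/`while`/`for` directly; you need the dedicated control-flow APIs such as `cond` / `while_loop`. Here, unless it is really necessary, you can operate on non-tensor types instead.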
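One way to follow this advice is to compute the split boundaries from the static shape with plain Python integers, so that no loop condition depends on a tensor value. The following is a minimal sketch, assuming the size of the split axis is known when the graph is built. `section_bounds` is a hypothetical helper, not part of Paddle, and the rule for uneven splits mirrors `numpy.array_split`.

```python
def section_bounds(total_n, sections):
    """Return (start, end) pairs that cut `total_n` items into `sections` parts.

    Hypothetical helper for illustration. All values are Python ints, so an
    ordinary `for` loop is safe in static graph mode.
    """
    if sections <= 0:
        raise ValueError("sections must be a positive integer")
    base, extra = divmod(total_n, sections)
    bounds = []
    start = 0
    for i in range(sections):
        # The first `extra` parts get one additional element.
        end = start + base + (1 if i < extra else 0)
        bounds.append((start, end))
        start = end
    return bounds


print(section_bounds(7, 3))  # [(0, 3), (3, 5), (5, 7)]
```

Each `(start, end)` pair could then be passed to a slicing op, which keeps the loop itself outside the graph.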
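I tried `while_loop` here, but it raised an error saying that the length of the returned values does not match the input. I'm not sure whether my call is wrong or this is an API bug; I'll try again.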
^*^ Note: `Paddle`'s `split` function has the signature `split(x, num_or_sections, axis=0, name=None)`, which differs from the one described above, but this does not affect the analysis that follows.
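- Since this section covers the current state of Paddle, it would be best to expand on how `paddle.split` differs from the APIs above; if its behavior is similar or identical to what was described, you can simply say so.
- Please also cover the existing API `paddle.vsplit`: are its parameters aligned with `split` or with `tensor_split`? `vsplit` / `dsplit` / `hsplit` look like one group of APIs.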
OK, I'll update it as soon as possible.
- Parameters
> x (Tensor) – The input Tensor. Supported data types: float32, float64, int32, int64.
> num_or_sections (Tensor|int|list|tuple)
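- About `x`: in theory, tensor manipulation APIs should support all data types. Some may be missing here because of limits in the underlying APIs, and the official docs may currently list fewer types than are actually supported, so please verify which data types really work.
- `num_or_sections`: its semantics need to be explained. This should be the core difference from `paddle.split`, and the parameter name could also be aligned with its semantics.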
- Parameters
> x (Tensor) – The input Tensor. Supported data types: float32, float64, int32, int64.
> num_or_sections (Tensor|int|list|tuple)
For these two APIs, the same comments as for `tensor_split` apply.
- Return value
> output (List of Tensors)

In addition, none of these APIs need a `name` parameter, because the output may consist of multiple Tensors.
`name` can be added here; see the existing `vsplit` / `split`.
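`vsplit`/`split` do have a `name` parameter, but the code never actually uses it:
As you can see, `vsplit` calls `split`, and `split` never actually uses `name`.
That is why I raised this question. More importantly, if there is a `name`, how should it be assigned? Since the return value is a list, should `name` be a list as well? 🤔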
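@megemini The default is fine here. Please check the docs for what the `name` parameter does: it mainly lets you specify a name that replaces the auto-generated OP name prefix. You can call `split` with `name` set in static graph mode and compare the output.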
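On the other hand, `hsplit`, `dsplit`, and `vsplit` are a group of functionally similar APIs, and `Paddle` implements `vsplit` on top of `split`. They could therefore be implemented with `split`, in the same way as in `TensorFlow` and `Numpy`.

# 5. Design Approach and Implementation Plan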
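Here I suggest explicitly emphasizing the following two points:
- How the proposed `tensor_split` differs from the existing `split`
- Whether the parameters of `dsplit` / `hsplit` align with `split` or with `tensor_split`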
Sorry for taking so long to update. This update mainly includes:
Please review. Thank you very much!
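In addition, regarding the difference between `split` and `tensor_split`, we quote [PyTorch Documentation Study: TORCH.TENSOR_SPLIT](https://blog.csdn.net/Jamesgender/article/details/130559738):

> This method looks a lot like the split method. Both split the input into several views according to indices_or_sections. The differences are: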
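There is an important difference between `split` and `tensor_split` here: the former's `num_or_sections` means a count or the section lengths, while the latter's `indices_or_sections` means a count or the index positions at which to cut. This semantic difference strongly affects how the API is used, so the design needs to spell it out.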
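The semantic difference can be illustrated on a plain Python list. This is only a sketch: `split_by_sizes` and `split_by_indices` are hypothetical helpers that model the two readings of a list argument; they are not Paddle APIs.

```python
def split_by_sizes(seq, sizes):
    """Interpret `sizes` as section lengths (the `num_or_sections` reading)."""
    parts, start = [], 0
    for size in sizes:
        parts.append(seq[start:start + size])
        start += size
    return parts


def split_by_indices(seq, indices):
    """Interpret `indices` as cut positions (the `indices_or_sections` reading)."""
    parts, start = [], 0
    for stop in indices:
        parts.append(seq[start:stop])
        start = stop
    parts.append(seq[start:])  # the remainder after the last cut
    return parts


data = list(range(6))
print(split_by_sizes(data, [2, 4]))    # [[0, 1], [2, 3, 4, 5]]
print(split_by_indices(data, [2, 4]))  # [[0, 1], [2, 3], [4, 5]]
```

The same argument `[2, 4]` yields two pieces under one reading and three under the other, which is why the parameter name and documentation should make the intended reading explicit.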
Right, this wasn't explained clearly before; it has been updated. Thanks a lot!
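When the split argument is a `list|tuple`:
- `split`, together with the corresponding `vsplit`, `dsplit`, and `hsplit`, forms one group of APIs whose input `must not go out of bounds`; that is, the length of the list or tuple cannot exceed the size of the input Tensor's dimension being split, and the argument may contain one `-1`.
- `tensor_split` `may go out of bounds`; consequently, the split argument cannot contain `-1`.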
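The two rules can be sketched in plain Python rather than Paddle. `resolve_sizes` is a hypothetical helper showing how a single `-1` can be inferred when the sizes must add up exactly to the dimension. The slicing lines assume that index-based splitting follows ordinary slice clamping, as `numpy.split` does for out-of-range indices.

```python
def resolve_sizes(dim, sizes):
    """Fill in at most one -1 so that the section sizes sum to `dim`."""
    sizes = list(sizes)
    if sizes.count(-1) > 1:
        raise ValueError("at most one -1 is allowed")
    if -1 in sizes:
        known = sum(s for s in sizes if s != -1)
        if known > dim:
            raise ValueError("section sizes exceed the dimension")
        sizes[sizes.index(-1)] = dim - known
    elif sum(sizes) != dim:
        raise ValueError("section sizes must sum to the dimension")
    return sizes


print(resolve_sizes(6, [2, -1, 1]))  # [2, 3, 1]

# Index-based cuts: a position past the end yields an empty piece
# instead of an error.
data = list(range(5))
cuts = [2, 10]
pieces = [data[a:b] for a, b in zip([0] + cuts, cuts + [len(data)])]
print(pieces)  # [[0, 1], [2, 3, 4], []]
```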
Please add a note here on the semantic difference between the parameters `num_or_sections` and `indices_or_sections`.
- Return value
> output (List of Tensors)

In addition, none of these APIs need a `name` parameter, because the output may consist of multiple Tensors.
`name` has been added now, so this sentence needs to be removed.
I missed this during the earlier cleanup ... ...
I have rechecked everything and updated it; please review. 😂
LGTM
PR types
Others
PR changes
Docs
Description
[Hackathon 5th No.32] Add tensor_split / hsplit / dsplit APIs to Paddle
Please review!