-
Notifications
You must be signed in to change notification settings - Fork 796
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
symbolic shape #9902
base: master
Are you sure you want to change the base?
symbolic shape #9902
Conversation
Signed-off-by: daquexian <daquexian566@gmail.com>
Signed-off-by: daquexian <daquexian566@gmail.com>
Signed-off-by: daquexian <daquexian566@gmail.com>
Signed-off-by: daquexian <daquexian566@gmail.com>
Signed-off-by: daquexian <daquexian566@gmail.com>
Signed-off-by: daquexian <daquexian566@gmail.com>
Signed-off-by: daquexian <daquexian566@gmail.com>
Signed-off-by: daquexian <daquexian566@gmail.com>
Signed-off-by: daquexian <daquexian566@gmail.com>
Signed-off-by: daquexian <daquexian566@gmail.com>
Signed-off-by: daquexian <daquexian566@gmail.com>
Signed-off-by: daquexian <daquexian566@gmail.com>
|
||
bool operator==(const Dim& a, const Dim& b) { | ||
if (a.is_known() && b.is_known()) { return a.value_ == b.value_; } | ||
// reflexivity: Dim::Unknown() == Dim::Unknown() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为了让各种 stl 容器能正常工作,unknown == unknown 一定要成立,这并不完全自然,接下来可能会通过区分不同的 unknown dim 来解决这个问题,类似 ONNX/TVM/... 里的做法
x = x * w1.to(flow.float32) | ||
x = x.unsqueeze(0) | ||
y = x.sum(dim=1) | ||
if lazy_mode.is_enabled(): | ||
# Shape inference works correctly even with the presence of | ||
# symbolic dimensions: | ||
assert x.shape == (1, flow.Dim.unknown(), 4) | ||
# y has a static shape even though x has a symbolic shape | ||
assert y.shape == (1, 4) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
symbolic shape 在形状推导中自动传染:(unknown, 4) 经过 unsqueeze 变成 (1, unknown, 4),(1, unknown, 4) 经过 sum(dim=1) 变成静态的 (1, 4)
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally. |
… errors Signed-off-by: daquexian <daquexian566@gmail.com>
Signed-off-by: daquexian <daquexian566@gmail.com>
CI failed when running job: Build cpu. PR label automerge has been removed |
Signed-off-by: daquexian <daquexian566@gmail.com>
Signed-off-by: daquexian <daquexian566@gmail.com>
CI failed when running job: Build cpu. PR label automerge has been removed |
Signed-off-by: daquexian <daquexian566@gmail.com>
Speed stats:
|
CI failed when running job: cpu-misc. PR label automerge has been removed |
CI failed when running job: cuda-module. PR label automerge has been removed |
Speed stats:
|
CI failed when running job: cpu-module. PR label automerge has been removed |
CI failed when running job: cuda-misc. PR label automerge has been removed |
CI failed when running job: cuda-speed-test. PR label automerge has been removed |
symbolic shape 用符号来表示未知的长度,提供了表达动态形状的能力,是支持动态形状的必备环节。这篇知乎文章讲述了一些背景和 TVM/BladeDisc 的相关实现:https://zhuanlan.zhihu.com/p/608027985
本 PR 实现了:
Dim
,它可以表示已知的长度(Dim a = 4
)也可以表示未知的长度(Dim::Unknown()
),同时 Shape 类也从vector<int64_t>
改为了vector<Dim>
。一个 Dim 对象内部成员只有一个 int64_t,这样 Shape 的数据才是紧凑的,和 oneflow 中需要int64_t* header_ptr
作为形状指针的地方才能兼容。没有用 -1 这样的 int64_t 特殊值来表示动态长度的原因是这样无法自动实现动态形状的传染(具体见下一条),也不类型安全。合并后的注意事项:
反例: