-
Notifications
You must be signed in to change notification settings - Fork 109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add mock class for specific workflow #416
Conversation
doombeaker
commented
Dec 12, 2023
•
edited
Loading
edited
- 适配动态 batchsize 以及一个特定工作流中的动态 shape 问题
- 动态 batchsize 不支持 community 版本,说明
…lass_for_specific_workflow
…lass_for_specific_workflow
x = x.flatten(2, 3).permute(0, 2, 1) | ||
# x = x.reshape(b, c, -1).permute(0, 2, 1) | ||
# x = rearrange(x, 'b c h w -> b (h w) c').contiguous() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
最原始的 ComfyUI 写法,以及上一版的 onediff 写法都保留,是因为这是一个修正动态 shape 问题的例子。
原始需求:
(b, c, h, w) -> (b, h*w, c)
第一版 onediff 的写法:
x = x.reshape(b, c, -1).permute(0, 2, 1)
此时 reshape 中的 b
, c
如果不变,只有一个维度变化(-1
那个维度),那么 reshape
可以动态推导出 -1
那个维度的具体值。
但是,如果 -1
维度变化的同时, b
也变化,reshape 就无法正常推导了。
所以要用最新的版本:
x = x.flatten(2, 3).permute(0, 2, 1)
先把 (b, c, h, w)
中的后 2 个维度压到一起 (b, c, (h*w))
,然后调整维度顺序,得到 (b, (h,w), c
,达到与之前等价的效果。
del hsp | ||
if len(hs) > 0: | ||
# output_shape = hs[-1].shape | ||
output_shape = hs[-1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
此处修改为 tensor,目的是传给 oneflow 的参数是 tensor,这样每次推导时能拿到 tensor 的 meta 信息。之前的 shape 作为属性,无法每次推导时更新。
if self.dims == 3: | ||
raise ValueError("output_shape shoud not be Tensor for dims == 3") | ||
else: | ||
x = F.interpolate_like(x, like=output_shape, mode="nearest") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个是 oneflow 新增的算子
@@ -1,19 +1,18 @@ | |||
import oneflow |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个文件从 onediff_comfy_nodes 搬到 onediff 下了,它确实应该是 onediff 层次的东西(获取 oneflow 版本信息)
@@ -26,12 +25,6 @@ def is_quantization_enabled(): | |||
return hasattr(oneflow._C, "dynamic_quantization") | |||
|
|||
|
|||
def is_community_version(stop_if_not=False): | |||
def is_community_version(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
去掉了打印功能,开发者自己决定是否打印,或者做什么处理(否则这个 PR 里调用了 2 此,打印两次,比较多余)