Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mock class for specific workflow #416

Merged
merged 6 commits into from
Dec 13, 2023

Conversation

doombeaker
Copy link
Contributor

@doombeaker doombeaker commented Dec 12, 2023

  • 适配动态 batchsize 以及一个特定工作流中的动态 shape 问题
  • 动态 batchsize 不支持 community 版本,说明

@doombeaker doombeaker marked this pull request as ready for review December 13, 2023 11:52
Comment on lines +116 to 118
x = x.flatten(2, 3).permute(0, 2, 1)
# x = x.reshape(b, c, -1).permute(0, 2, 1)
# x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

最原始的 ComfyUI 写法,以及上一版的 onediff 写法都保留,是因为这是一个修正动态 shape 问题的例子。
原始需求:

(b, c, h, w) -> (b, h*w, c)

第一版 onediff 的写法:

x = x.reshape(b, c, -1).permute(0, 2, 1)

此时 reshape 中的 b, c 如果不变,只有一个维度变化(-1 那个维度),那么 reshape 可以动态推导出 -1 那个维度的具体值。

但是,如果 -1 维度变化的同时, b 也变化,reshape 就无法正常推导了。
所以要用最新的版本:

x = x.flatten(2, 3).permute(0, 2, 1)

先把 (b, c, h, w) 中的后 2 个维度压到一起 (b, c, (h*w)),然后调整维度顺序,得到 (b, (h,w), c,达到与之前等价的效果。

del hsp
if len(hs) > 0:
# output_shape = hs[-1].shape
output_shape = hs[-1]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处修改为 tensor,目的是传给 oneflow 的参数是 tensor,这样每次推导时能拿到 tensor 的 meta 信息。之前的 shape 作为属性,无法每次推导时更新。

if self.dims == 3:
raise ValueError("output_shape shoud not be Tensor for dims == 3")
else:
x = F.interpolate_like(x, like=output_shape, mode="nearest")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个是 oneflow 新增的算子

@@ -1,19 +1,18 @@
import oneflow
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件从 onediff_comfy_nodes 搬到 onediff 下了,它确实应该是 onediff 层次的东西(获取 oneflow 版本信息)

@@ -26,12 +25,6 @@ def is_quantization_enabled():
return hasattr(oneflow._C, "dynamic_quantization")


def is_community_version(stop_if_not=False):
def is_community_version():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

去掉了打印功能,开发者自己决定是否打印,或者做什么处理(否则这个 PR 里调用了 2 此,打印两次,比较多余)

@ccssu ccssu merged commit 3daab10 into main Dec 13, 2023
4 checks passed
@ccssu ccssu deleted the add_mock_class_for_specific_workflow branch December 13, 2023 14:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants