add graph cache key #576

Merged: 9 commits merged into main from fix_graph_management_utils on Jan 29, 2024
Conversation

@ccssu (Contributor) commented Jan 26, 2024

No description provided.

Comment on lines 34 to 35
cache_key = calculate_model_hash(model) + "_" + flow.__version__
return f"{file_path}_{count}_{cache_key}.graph"
@ccssu (Contributor Author) commented Jan 26, 2024:

Add a key based on the oneflow model structure to avoid conflicts caused by custom class registrations, for example registering a custom CrossAttention1f but then loading a previously saved graph file.

torch2of_class_map = {
    comfy.ldm.modules.attention.CrossAttention: CrossAttention1f,
    comfy.ldm.modules.attention.SpatialTransformer: SpatialTransformer1f,
    comfy_ops_Linear: Linear1f,
    AttnBlock: AttnBlock1f,
}

Collaborator:

This feature supports caching of the compiled graph file; so, if the model structure changes, a key derived from the model structure is added to the file name to avoid reusing a stale cache?

Collaborator:

This way of constructing the key seems rather expensive: it takes the repr string of the entire module to generate it?

@ccssu (Contributor Author):

Yes, the key is generated from the repr string of the entire module. The overhead is about 0.1 to 0.2 s on SDXL 1.0, which seems acceptable, and the main benefit is safety. As highlighted by the red box in the screenshot below, the real bottleneck here is the torch2flow conversion.

[screenshot: profiling output]
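A minimal sketch of how such a repr-based hash could be computed (assuming hashlib; the actual calculate_model_hash in the PR may differ):

import hashlib

def calculate_model_hash(model):
    # Hash the repr of the whole module: registering a custom class such as
    # CrossAttention1f changes the repr, and therefore the cache key.
    return hashlib.sha256(repr(model).encode("utf-8")).hexdigest()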

Collaborator:

graph_file_management is executed on every graph call; a 100 ms overhead is unacceptable, and even 1 ms deserves careful thought.

@ccssu (Contributor Author) commented Jan 29, 2024:

> graph_file_management is executed on every graph call; a 100 ms overhead is unacceptable, and even 1 ms deserves careful thought.

This is not executed on every graph call; it only runs the first time the graph is loaded.
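As a rough illustration of that control flow, a hypothetical run-once guard (the flag name and decorator shape are assumptions; the real graph_file_management in onediff may be structured differently):

def graph_file_management(func):
    def wrapper(self, *args, **kwargs):
        # Only handle the graph file (cache key, load/save) on the first call.
        if not getattr(self, "_graph_file_handled", False):
            self._graph_file_handled = True
            # ... compute the cache key and load or save the graph file here ...
        return func(self, *args, **kwargs)
    return wrapper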

options={},
graph_path=None,
graph_device=None,
self, torch_module, oneflow_module, use_graph=True, dynamic=True, options={}
@ccssu (Contributor Author):

Removed the leftover graph_path=None and graph_device=None parameters.

Collaborator:

Please add interface documentation for these two parameters here, in the style of:

- 'size' which configures the cache size when the cache is enabled. Note that after onediff v0.12, the cache is disabled by default.

@ccssu (Contributor Author):

done

        - 'graph_file' which configures the graph file path, default None.
        - 'graph_file_device' which configures the device of the graph file, default None.

Collaborator:

This could be made more complete: if graph_file is configured, a cache of the compilation result is generated; if graph_file_device is also configured, the compilation result is moved to that device when it is loaded, so the device can be changed.

@ccssu (Contributor Author):

  • 'graph_file' (None) enables a compilation cache file. If the file exists, it is loaded; if not, the compilation result is saved there after the first compilation.
  • 'graph_file_device' (None) sets the device for the graph file. The loaded compilation result is moved to the specified device, so compilation can be reused across devices.
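A hypothetical usage sketch of these two options (the option names follow the docstring above; the file name and device string are made-up examples):

options = {
    "graph_file": "unet_sdxl.graph",   # save/load the compiled graph here
    "graph_file_device": "cuda:0",     # move the loaded graph to this device
}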

@ccssu ccssu requested a review from strint January 26, 2024 06:52
Comment on lines 31 to 47
with cost_time(
    debug=transform_mgr.debug_mode, message="calculate model input count"
):
    args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor)
    count = len(
        [v for v in args_tree.iter_nodes() if isinstance(v, flow.Tensor)]
    )

with cost_time(debug=transform_mgr.debug_mode, message="get model"):
    model = self._deployable_module_model.oneflow_module

with cost_time(
    debug=transform_mgr.debug_mode,
    message="calculate model hash for cache key",
):
    cache_key = calculate_model_hash(model) + "_" + flow.__version__
return f"{file_path}_{count}_{cache_key}.graph"
@ccssu (Contributor Author) commented Jan 28, 2024:

DEBUG [2024-01-28 02:34:06] - calculate model input count run time 7.200241088867188e-05 seconds
DEBUG [2024-01-28 02:34:06] - Convert <class 'comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel'> ...
DEBUG [2024-01-28 02:34:09] - Convert id(self._torch_module)=140361274282560 done!
DEBUG [2024-01-28 02:34:09] - get model run time 3.119259834289551 seconds
DEBUG [2024-01-28 02:34:09] - calculate model hash for cache key run time 0.014719009399414062 seconds
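For context, a minimal sketch of what a cost_time helper like the one in the excerpt above could look like, assuming it is a simple timing context manager; the actual implementation in onediff may differ:

import time
from contextlib import contextmanager

@contextmanager
def cost_time(debug=False, message=""):
    # Hypothetical sketch: time the wrapped block and log it when debug is on.
    start = time.time()
    try:
        yield
    finally:
        if debug:
            print(f"DEBUG - {message} run time {time.time() - start} seconds")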

Collaborator:

Why does "get model" take 3 s? Isn't this just a simple attribute access?

@ccssu (Contributor Author) commented Jan 29, 2024:

> Why does "get model" take 3 s? Isn't this just a simple attribute access?

Frequent enabling and disabling of mock_torch accounts for nearly half of that time. Caching all of the classes in the model up front here should save more than 1 s.

[screenshot: profiling output]
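A hedged sketch of the up-front class caching idea (the helper name and the use of torch2of_class_map as the lookup table are assumptions, not the PR's actual code):

from functools import lru_cache

@lru_cache(maxsize=None)
def cached_class_transform(torch_cls):
    # Hypothetical: memoize the torch -> oneflow class lookup so that
    # mock_torch does not need to be toggled for every module.
    return torch2of_class_map.get(torch_cls, torch_cls)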

@strint strint merged commit ca046d1 into main Jan 29, 2024
1 check passed
@strint strint deleted the fix_graph_management_utils branch January 29, 2024 13:18