add graph cache key #576
Conversation
cache_key = calculate_model_hash(model) + "_" + flow.__version__
return f"{file_path}_{count}_{cache_key}.graph"
Add a key based on the oneflow model structure to avoid conflicts caused by custom registrations, e.g. registering a custom CrossAttention1f but then loading a graph file that was saved earlier without it.
torch2of_class_map = {
    comfy.ldm.modules.attention.CrossAttention: CrossAttention1f,
    comfy.ldm.modules.attention.SpatialTransformer: SpatialTransformer1f,
    comfy_ops_Linear: Linear1f,
    AttnBlock: AttnBlock1f,
}
This feature supports caching of the graph file; and if the model structure changes, to avoid reusing a stale cache, a key derived from the model structure is added to the file name to tell the caches apart?
This way of constructing the key looks fairly expensive: it takes the repr string of the entire module to generate it?
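For context, a minimal sketch of what such a repr-based hash could look like; this body is an assumption based only on the comment above, not the actual calculate_model_hash in this PR:

import hashlib

def calculate_model_hash(model) -> str:
    # Assumed implementation: hash the repr of the whole module, so any
    # structural change (e.g. a registered CrossAttention1f replacement)
    # produces a different cache key. repr() walks every submodule, which
    # is where the cost concern above comes from.
    return hashlib.md5(repr(model).encode("utf-8")).hexdigest()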
graph_file_management is executed on every graph call; a 100 ms overhead is unacceptable, and even 1 ms deserves a second thought.
graph_file_management is executed on every graph call; a 100 ms overhead is unacceptable, and even 1 ms deserves a second thought.

This is not executed on every graph call; it only runs the first time the graph is loaded.
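A rough sketch of the load-once pattern being described; the names and methods below are illustrative assumptions, not the actual graph_file_management code:

import os

def graph_file_management(graph, file_path: str):
    # Runs its body only on the first call: once the graph has been loaded
    # (or compiled and saved), the flag short-circuits every later
    # invocation, so the hashing cost is paid once rather than per call.
    if getattr(graph, "_cache_handled", False):
        return graph
    if os.path.exists(file_path):
        graph.load_runtime_state_dict_from(file_path)  # hypothetical loader
    else:
        graph.compile_and_save_to(file_path)  # hypothetical compile + save
    graph._cache_handled = True
    return graph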
options={},
graph_path=None,
graph_device=None,
self, torch_module, oneflow_module, use_graph=True, dynamic=True, options={}
Delete the leftover graph_path=None,
graph_device=None,
Please add interface documentation for these two parameters here:
- 'size' which configures the cache size when cache is enabled. Note that after onediff v0.12, cache is disabled by default.
done
- 'graph_file' which configures the graph file path, default None.
- 'graph_file_device' which configures the device of the graph file, default None.
This could be fleshed out a bit more: if graph_file is configured, a cache of the compilation result is generated; if graph_file_device is configured, the compiled result is converted to that device when it is loaded, to support changing devices.
- 'graph_file' (None) generates a compilation cache file. If the file exists, it is loaded; if not, the compilation result is saved to it after the first compilation.
- 'graph_file_device' (None) sets the device for the graph file. When loading, the compiled result is moved to the specified device, enabling flexible device changes.
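For reference, a hedged usage sketch of how these options might be passed; the option keys come from the docstring above, but the exact import path and accepted value types are assumptions:

import torch
from onediff.infer_compiler import oneflow_compile  # import path assumed

model = torch.nn.Linear(4, 4).to("cuda")

# Cache the compiled graph to disk; on a later run the file is loaded
# instead of recompiling, and moved onto the requested device.
compiled_model = oneflow_compile(
    model,
    options={
        "graph_file": "model_compiled.graph",  # enables the on-disk cache
        "graph_file_device": "cuda:1",         # device the graph is loaded onto
    },
)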
with cost_time(
    debug=transform_mgr.debug_mode, message="calculate model input count"
):
    args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor)
    count = len(
        [v for v in args_tree.iter_nodes() if isinstance(v, flow.Tensor)]
    )

with cost_time(debug=transform_mgr.debug_mode, message="get model"):
    model = self._deployable_module_model.oneflow_module

with cost_time(
    debug=transform_mgr.debug_mode,
    message="calculate model hash for cache key",
):
    cache_key = calculate_model_hash(model) + "_" + flow.__version__
return f"{file_path}_{count}_{cache_key}.graph"
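The cost_time helper used above is presumably a timing context manager; a minimal sketch of that pattern (the real implementation may differ):

import time
from contextlib import contextmanager

@contextmanager
def cost_time(debug: bool = False, message: str = ""):
    # Times the wrapped block and logs it only in debug mode, matching the
    # "<message> run time <seconds> seconds" lines in the log output below.
    start = time.time()
    try:
        yield
    finally:
        if debug:
            print(f"DEBUG - {message} run time {time.time() - start} seconds")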
DEBUG [2024-01-28 02:34:06] - calculate model input count run time 7.200241088867188e-05 seconds
DEBUG [2024-01-28 02:34:06] - Convert <class 'comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel'> ...
DEBUG [2024-01-28 02:34:09] - Convert id(self._torch_module)=140361274282560 done!
DEBUG [2024-01-28 02:34:09] - get model run time 3.119259834289551 seconds
DEBUG [2024-01-28 02:34:09] - calculate model hash for cache key run time 0.014719009399414062 seconds
Why does "get model" take 3 s? Isn't this just a simple attribute access?
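The debug log above hints at an answer: the Convert lines fire during "get model", so oneflow_module likely performs the torch-to-oneflow conversion lazily on first access. A minimal sketch of that pattern (assumed, not the actual onediff code):

class DeployableModuleModel:
    def __init__(self, torch_module):
        self._torch_module = torch_module
        self._oneflow_module = None  # converted lazily

    @property
    def oneflow_module(self):
        # First access pays the torch -> oneflow conversion cost (the ~3 s
        # "Convert ... done!" seen in the log); later accesses are plain
        # attribute reads.
        if self._oneflow_module is None:
            self._oneflow_module = torch2oflow(self._torch_module)  # hypothetical converter
        return self._oneflow_module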