
Detecting dynamic flow? #310

Open

eliegoudout opened this issue May 31, 2024 · 2 comments
@eliegoudout

Hi,

I came across the notion of static/dynamic control flow in PyTorch's docs. I realize that dynamic flow (that is, when the set of module calls may differ between inputs) poses an obvious problem for torchinfo: the summary is computed from a forward pass on a single input (random, all zeros, or otherwise; I haven't looked at your code to check), while another input might trigger a different execution path.

As such, I think it would be wise to issue a warning or raise an error when dynamic flow is detected? Otherwise, the output may be misleading.
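For illustration, consider a toy module (made up here, not from torchinfo) whose forward pass branches on the input's values, so different inputs exercise different submodules:

import torch
from torch import nn

class DynamicNet(nn.Module):
    """Toy module with data-dependent (dynamic) control flow."""

    def __init__(self) -> None:
        super().__init__()
        self.small = nn.Linear(8, 8)
        self.big = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 8))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Which branch runs depends on the values in x, so a summary
        # built from one input can miss the other branch entirely.
        if x.sum() > 0:
            return self.big(x)
        return self.small(x)

A summary computed on an all-negative input would only ever see self.small.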

I chose the "Feature Request" tag, but it might also be considered a "Bug Report" since it's about a fundamental limitation.

Cheers!

@TylerYep (Owner)

That sounds reasonable. PRs addressing this are welcome!

@eliegoudout (Author)

I would think about adding something like this in summary:

import warnings

import torch

warning = """
The control flow of the target module may be dynamic. As such, the
summary may vary for different inputs. For more information, see
https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing
"""
try:
    torch.fx.symbolic_trace(model)  # Fails for dynamic control flow
except torch.fx.proxy.TraceError:
    warnings.warn(warning)

but I'm a bit uncomfortable pushing this as-is, because I'm not entirely sure my understanding of the page I linked (regarding control flow and tracing) is correct; my knowledge here is very limited. Also, I don't know the cost of torch.fx.symbolic_trace(model), but I guess it could be roughly equivalent to a normal forward pass, so it could add a bit of overhead.
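For what it's worth, a quick sketch of the check in action, using the toy DynamicNet from my first message (so purely illustrative):

import warnings

import torch

model = DynamicNet()  # the toy dynamic-flow module sketched above

try:
    torch.fx.symbolic_trace(model)
except torch.fx.proxy.TraceError as err:
    # Raised because x.sum() > 0 is used as a branch condition while tracing.
    warnings.warn(f"Possible dynamic control flow detected: {err}")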
