Fix `fx` problem for Aggregation #5021

Conversation
torch_geometric/nn/fx.py
Outdated
    if op == 'call_module':
-       return isinstance(get_submodule(module, target), GlobalPooling)
+       return isinstance(get_submodule(module, target),
+                         GlobalPooling) or isinstance(
Let's just replace it here.
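For illustration, a minimal sketch of that suggestion, assuming the check in `torch_geometric/nn/fx.py` simply switches from `GlobalPooling` to the new `Aggregation` base class (the helper name and import path below are assumptions, not the merged code):

```python
from torch.nn import Module
from torch_geometric.nn.aggr import Aggregation  # assumed import path for the new base class


def is_aggregation_call(op: str, module: Module, target: str) -> bool:
    # Hypothetical helper mirroring the check in the diff above, with
    # GlobalPooling simply replaced by the Aggregation base class.
    if op == 'call_module':
        return isinstance(module.get_submodule(target), Aggregation)
    return False
```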
torch_geometric/nn/aggr/base.py
Outdated
    # @abstractmethod
    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
-               dim: int = -2) -> Tensor:
+               dim: int = -2, **kwargs) -> Tensor:
Why do we need to add kwargs here in the first place?
Because of `GraphMultisetTransformer`, which also optionally takes `edge_index`.
If we can just add this to the default arguments, then the problem is solved, but I don't see it as a common need.
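A sketch of that alternative, with `edge_index` added as an explicit optional argument instead of accepting `**kwargs` (whether this matches the final signature is not confirmed here):

```python
from typing import Optional

from torch import Tensor


# Hypothetical base signature with edge_index made explicit (illustrative only):
def forward(self, x: Tensor, index: Optional[Tensor] = None,
            ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
            dim: int = -2, edge_index: Optional[Tensor] = None) -> Tensor:
    ...
```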
torch_geometric/nn/aggr/base.py
Outdated
    def __init__(self):
        super().__init__()
        self._forward_sub = self.forward
        self.forward = self._foward
I think this is overly complicated. How about we just do

    def forward(...):
        self.validate(...)
        self._forward(...)

and overwrite `_forward` in child modules?
Yeah, I was thinking that too - just providing this option in case we want to keep the current abstract interface.
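A minimal sketch of that simpler pattern, using the `validate`/`_forward` names from the comment above (the validation shown is illustrative, not the actual base-class code):

```python
from typing import Optional

import torch
from torch import Tensor


class Aggregation(torch.nn.Module):
    # Sketch only: forward runs validation, then delegates to _forward, which
    # child classes override instead of forward itself.
    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        self.validate(x, index, ptr)
        return self._forward(x, index, ptr, dim_size, dim)

    def validate(self, x: Tensor, index: Optional[Tensor], ptr: Optional[Tensor]):
        # Illustrative check; the real validation logic may differ.
        if index is None and ptr is None:
            raise ValueError("Expected either 'index' or 'ptr' to be set")

    def _forward(self, x: Tensor, index: Optional[Tensor], ptr: Optional[Tensor],
                 dim_size: Optional[int], dim: int) -> Tensor:
        raise NotImplementedError
```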
torch_geometric/nn/aggr/base.py
Outdated
@@ -35,35 +75,6 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
     def reset_parameters(self):
         pass

     def __call__(self, x: Tensor, index: Optional[Tensor] = None,
Just so I understand better: what was the issue with this code to begin with? It looks like the tracer has problems when `__call__` is overwritten. Is that correct? Is there any way to fix this in the tracer?
Yes, it's because of this: https://github.com/pytorch/pytorch/blob/e68583b4d180066b8e4f108e0d23176a2676421c/torch/fx/_symbolic_trace.py#L702. The tracer only checks for a leaf module inside the patched methods, but since we overwrite `__call__` it is no longer patched (the `super().__call__` is still patched).
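A hedged, minimal reproduction of that behavior (the class names below are illustrative, not the actual torch_geometric code): even when a custom tracer marks the module as a leaf, the overridden `__call__` still executes its validation on `fx` Proxies, because the leaf-module short-circuit lives in the patched `torch.nn.Module.__call__`.

```python
import torch
import torch.fx
from torch import nn


class ValidatingPool(nn.Module):  # stand-in for the Aggregation base class
    def __call__(self, x, index):
        # Data-dependent check: fine with real tensors, but fails on an fx Proxy.
        if index.numel() == 0:
            raise ValueError("empty index")
        return super().__call__(x, index)

    def forward(self, x, index):
        return x.sum(dim=0)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = ValidatingPool()

    def forward(self, x, index):
        return self.pool(x, index)


class LeafTracer(torch.fx.Tracer):
    def is_leaf_module(self, m, name):
        return isinstance(m, ValidatingPool) or super().is_leaf_module(m, name)


# LeafTracer().trace(Model())  # still raises a TraceError: the overridden __call__
#                              # runs its check on a Proxy before the patched
#                              # nn.Module.__call__ (and is_leaf_module) is reached.
```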
Can we patch it ourselves?
Hmm, yes... but we'll have to copy quite a bit of code from their implementation. I can implement that and we can see what it looks like?
OK, it depends on how easy it would be to try out and how large the code to patch would be. Don't sink too much time into it if it is not easy to fix :)
I thought about this, but I don't know how the `_Patcher` context manager works, so I was afraid to suggest it even if it happened to pass tests. I could certainly invest time in understanding it if you think it's a big upgrade, though.
Can we quickly test if it works? I am always in favor of reducing the number of lines :)
Hmm, no, it doesn't work out of the box.
Ok, then leave it as it is. Might be good to add some exhaustive documentation here, and also link to the PyTorch file this code is coming from.
Agreed. Will keep an eye open for better solutions too.
Codecov Report
@@            Coverage Diff             @@
##           master    #5021      +/-   ##
==========================================
- Coverage   84.77%   82.83%    -1.95%
==========================================
  Files         331      331
  Lines       18115    18156       +41
==========================================
- Hits        15357    15039      -318
- Misses       2758     3117      +359
@rusty1s I've added a comment linking to this issue and added details in the description here. I think it is most useful for people looking at this later to be able to come back to this issue if they have any concerns.
Would like to remove the warning before merging - otherwise LGTM. Thanks!
Will do. Thanks!
This PR addresses a problem we had when using `torch.fx` on our new `nn.aggr.Aggregation` class. This class redefines the `__call__` function to add extra logic (such as input validation) around `forward`.

During `torch.fx._symbolic_trace`, calls to PyTorch modules (`torch.nn.Module`) are intentionally excluded from symbolic tracing using `is_leaf_module`; this avoids errors when symbolically tracing, for example, conditional statements involving tensors (which are not handled). However, this is implemented by patching a specific function within `torch.fx._symbolic_trace` (see the `_symbolic_trace.py` link above). This doesn't patch our `__call__`, as it no longer lives in `torch.nn.Module`.

The fix here is to reimplement the `trace` function to add an extra wrapper around our code.

Note that this should be replaced by a better solution once one is available, for example:
- support for `kwargs` in `pre_forward_hooks`: nn.Module hooks ignore kwargs pytorch/pytorch#35643 (comment)
- an upstream fix from the `torch.fx` team.
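A hedged conceptual sketch of that direction (the actual PR copies and adapts code from `torch.fx._symbolic_trace`; this only illustrates the idea of an extra wrapper, and the import path of `Aggregation` is an assumption):

```python
import torch.fx
from torch_geometric.nn.aggr import Aggregation  # assumed import path


class AggregationTracer(torch.fx.Tracer):
    def is_leaf_module(self, module, name) -> bool:
        # Treat Aggregation modules as leaves so they appear as `call_module` nodes.
        return isinstance(module, Aggregation) or super().is_leaf_module(module, name)

    def trace(self, root, concrete_args=None):
        original_call = Aggregation.__call__

        def patched_call(mod, *args, **kwargs):
            # Extra wrapper: record leaf Aggregation modules as `call_module`
            # nodes instead of tracing through their custom __call__ (and its
            # validation logic) with fx Proxies.
            if self.is_leaf_module(mod, ''):
                return self.create_proxy('call_module', self.path_of_module(mod),
                                         args, kwargs)
            return original_call(mod, *args, **kwargs)

        Aggregation.__call__ = patched_call
        try:
            return super().trace(root, concrete_args)
        finally:
            Aggregation.__call__ = original_call
```

The merged code does this with the patching machinery from `torch.fx._symbolic_trace` rather than a plain class-attribute swap, but the intended effect is the same: the custom `__call__` is bypassed for leaf modules during tracing.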