
Fix fx problem for Aggregation #5021

Merged: 7 commits merged into pyg-team:master on Jul 23, 2022
Conversation

@Padarn (Contributor) commented Jul 21, 2022

This PR addresses a problem we had when using torch.fx on our new nn.aggr.Aggregation class. This class redefines the __call__ function to contain:

    def __call__(self, x: Tensor, index: Optional[Tensor] = None,
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                 dim: int = -2, **kwargs) -> Tensor:

        if dim >= x.dim() or dim < -x.dim():
            raise ValueError(f"Encountered invalid dimension '{dim}' of "
                             f"source tensor with {x.dim()} dimensions")
        ...

During torch.fx symbolic tracing (torch.fx._symbolic_trace), calls to torch.nn.Module instances are intentionally excluded from tracing via is_leaf_module; this avoids errors when symbolically tracing, for example, conditional statements involving tensors (which are not handled).

However, this exclusion is implemented by patching torch.nn.Module.__call__ through a specific call inside torch.fx._symbolic_trace:

    patcher.patch_method(torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False)

(see torch/fx/_symbolic_trace.py).

This doesn't patch our __call__ as it no longer lives in torch.nn.Module.
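
To illustrate the failure mode, here is a minimal self-contained sketch with hypothetical module names (not PyG code): tracing it fails on the data-dependent check inside the overridden __call__, before the leaf-module machinery is ever reached.

    import torch
    from torch import Tensor
    from torch.fx import symbolic_trace

    class MyAggr(torch.nn.Module):
        # Overrides __call__ the way Aggregation does; torch.fx only patches
        # torch.nn.Module.__call__, so this body gets traced through.
        def __call__(self, x: Tensor, dim: int = -2) -> Tensor:
            if dim >= x.dim() or dim < -x.dim():  # control flow on a traced value
                raise ValueError("invalid dimension")
            return super().__call__(x, dim)

        def forward(self, x: Tensor, dim: int = -2) -> Tensor:
            return x.sum(dim=dim)

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.aggr = MyAggr()

        def forward(self, x: Tensor) -> Tensor:
            return self.aggr(x)

    # Fails with a torch.fx TraceError: `x.dim()` is a Proxy during tracing, and
    # the `if` needs a concrete bool before super().__call__ (the patched path)
    # is ever reached.
    symbolic_trace(Model())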

The fix here is to reimplement the trace function to add an extra wrapper around our code.
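
For orientation, a minimal sketch of a workaround along these lines. This is not the exact code added in this PR; it assumes Aggregation is importable from torch_geometric.nn.aggr and simply redirects the overridden __call__ back through torch.nn.Module.__call__ (which torch.fx does patch) for the duration of tracing.

    import torch
    from torch.fx import Tracer
    from torch_geometric.nn.aggr import Aggregation  # assumed import path

    class LeafAwareTracer(Tracer):
        def is_leaf_module(self, module, module_qualified_name) -> bool:
            # Record Aggregation instances as single `call_module` nodes.
            return (isinstance(module, Aggregation)
                    or super().is_leaf_module(module, module_qualified_name))

        def trace(self, root, concrete_args=None):
            orig_call = Aggregation.__call__

            def patched_call(mod, *args, **kwargs):
                # Skip the overridden __call__ (and its tensor-dependent checks)
                # and go through torch.nn.Module.__call__, which the fx patcher
                # wraps with the leaf-module check while tracing is active.
                return torch.nn.Module.__call__(mod, *args, **kwargs)

            Aggregation.__call__ = patched_call
            try:
                return super().trace(root, concrete_args)
            finally:
                # The generated GraphModule calls the real module at run time,
                # so the original __call__ (and its validation) still applies.
                Aggregation.__call__ = orig_call

    # Usage:
    #   graph = LeafAwareTracer().trace(model)
    #   gm = torch.fx.GraphModule(model, graph)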

Note that this should be replaced by a better solution once one is available.

     if op == 'call_module':
-        return isinstance(get_submodule(module, target), GlobalPooling)
+        return isinstance(get_submodule(module, target),
+                          GlobalPooling) or isinstance(
Member:
Let's just replace it here.


     # @abstractmethod
     def forward(self, x: Tensor, index: Optional[Tensor] = None,
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
-                dim: int = -2) -> Tensor:
+                dim: int = -2, **kwargs) -> Tensor:
Member:
Why do we need to add **kwargs here in the first place?

Contributor Author (@Padarn):
Because of the GraphMultisetTransformer, which also optionally takes edge_index.

Contributor Author (@Padarn):
If we can just add this to the default arguments, then the problem is solved, but I don't see it as a common need.
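
For illustration only, a hedged sketch with hypothetical class names (not the PyG API) of how **kwargs on the base signature lets a subclass accept an extra optional input such as edge_index without changing the base interface:

    from typing import Optional

    import torch
    from torch import Tensor

    class BaseAggr(torch.nn.Module):
        def __call__(self, x: Tensor, index: Optional[Tensor] = None,
                     dim: int = -2, **kwargs) -> Tensor:
            # Extra keyword arguments are forwarded untouched to the subclass.
            return super().__call__(x, index, dim=dim, **kwargs)

        def forward(self, x: Tensor, index: Optional[Tensor] = None,
                    dim: int = -2) -> Tensor:
            raise NotImplementedError

    class EdgeAwareAggr(BaseAggr):
        # Hypothetical subclass that optionally consumes `edge_index`.
        def forward(self, x: Tensor, index: Optional[Tensor] = None,
                    dim: int = -2, edge_index: Optional[Tensor] = None) -> Tensor:
            return x.sum(dim=dim)

    out = EdgeAwareAggr()(torch.randn(4, 3, 8),
                          edge_index=torch.zeros(2, 5, dtype=torch.long))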

+    def __init__(self):
+        super().__init__()
+        self._forward_sub = self.forward
+        self.forward = self._forward
Member:
I think this is overly complicated. How about we just do

    def forward(self, ...):
        self.validate(...)
        return self._forward(...)

and overwrite _forward in child modules?
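
For concreteness, a sketch of that pattern with hypothetical names (not the final PyG API): validation stays in a plain forward(), __call__ is left untouched, and subclasses implement _forward:

    from typing import Optional

    import torch
    from torch import Tensor

    class BaseAggregation(torch.nn.Module):
        def forward(self, x: Tensor, index: Optional[Tensor] = None,
                    dim: int = -2, **kwargs) -> Tensor:
            # Shared validation lives here; since __call__ is not overridden,
            # the standard torch.fx leaf-module patching applies again.
            if dim >= x.dim() or dim < -x.dim():
                raise ValueError(f"Encountered invalid dimension '{dim}' of "
                                 f"source tensor with {x.dim()} dimensions")
            return self._forward(x, index, dim=dim, **kwargs)

        def _forward(self, x: Tensor, index: Optional[Tensor] = None,
                     dim: int = -2, **kwargs) -> Tensor:
            raise NotImplementedError

    class MySumAggregation(BaseAggregation):
        def _forward(self, x: Tensor, index: Optional[Tensor] = None,
                     dim: int = -2, **kwargs) -> Tensor:
            return x.sum(dim=dim)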

Contributor Author (@Padarn):
Yeah, I was thinking that too; just providing this option in case we want to keep the current abstract interface.

@@ -35,35 +75,6 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
     def reset_parameters(self):
         pass

     def __call__(self, x: Tensor, index: Optional[Tensor] = None,
Member:
Just so I understand better: what was the issue with this code to begin with? It looks like the tracer has problems when __call__ is overwritten. Is that correct? Is there any way to fix this in the tracer?

Contributor Author (@Padarn, Jul 21, 2022):
Yes, it's because of this: https://github.com/pytorch/pytorch/blob/e68583b4d180066b8e4f108e0d23176a2676421c/torch/fx/_symbolic_trace.py#L702
The tracer only applies the leaf-module check in the patched method, but since we overwrite __call__ it is no longer patched (the super().__call__ is still patched).
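
A tiny standalone illustration (plain Python, not fx internals) of why patching the base class's __call__ does not intercept a subclass that overrides it:

    class Base:
        def __call__(self):
            return "base"

    class Child(Base):
        def __call__(self):                # overrides __call__, like Aggregation
            return "child -> " + super().__call__()

    # Roughly what the fx patcher does during tracing:
    Base.__call__ = lambda self: "patched"

    print(Base()())   # 'patched'          -> the patch intercepts the call
    print(Child()())  # 'child -> patched' -> Child's own body still runs first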

Member:
Can we patch it ourselves?

Contributor Author (@Padarn):
Hmm, yes... but we'll have to copy quite a bit of code from their implementation. I can implement that and we can see what it looks like?

Member:
Ok, it depends on how easy it would be to try out and how large the code to patch would be. Do not burn too much time on it if it is not easy to fix :)

Contributor Author (@Padarn):
I thought about this, but I don't know how the _Patcher context manager works, so I was afraid to suggest it even if it happened to pass tests.

Could certainly invest time in understanding it if you think it's a big upgrade though.

Member:
Can we quickly test if it works? I am always in favor of reducing the number of lines :)

Contributor Author (@Padarn):
Hmm, no, it doesn't work out of the box.

Member:
Ok, then leave it as it is. It might be good to add some exhaustive documentation here, and also link to the PyTorch file this code comes from.

Contributor Author (@Padarn):
Agreed. Will keep an eye open for better solutions too.

codecov bot commented Jul 22, 2022

Codecov Report

Merging #5021 (6827cd9) into master (be9e4af) will decrease coverage by 1.94%.
The diff coverage is 93.61%.

@@            Coverage Diff             @@
##           master    #5021      +/-   ##
==========================================
- Coverage   84.77%   82.83%   -1.95%     
==========================================
  Files         331      331              
  Lines       18115    18156      +41     
==========================================
- Hits        15357    15039     -318     
- Misses       2758     3117     +359     
Impacted Files Coverage Δ
torch_geometric/nn/to_hetero_transformer.py 95.26% <ø> (ø)
torch_geometric/nn/fx.py 90.05% <93.61%> (+1.01%) ⬆️
torch_geometric/nn/models/dimenet_utils.py 0.00% <0.00%> (-75.52%) ⬇️
torch_geometric/nn/models/dimenet.py 14.51% <0.00%> (-53.00%) ⬇️
torch_geometric/nn/glob/glob.py 60.52% <0.00%> (-26.32%) ⬇️
torch_geometric/nn/conv/utils/typing.py 81.25% <0.00%> (-17.50%) ⬇️
torch_geometric/profile/profile.py 32.94% <0.00%> (-15.30%) ⬇️
torch_geometric/nn/inits.py 67.85% <0.00%> (-7.15%) ⬇️
torch_geometric/nn/resolver.py 88.00% <0.00%> (-6.00%) ⬇️
torch_geometric/transforms/add_self_loops.py 94.44% <0.00%> (-5.56%) ⬇️
... and 11 more


Padarn changed the title from "[Discuss] Fix fx problem for Aggregation" to "Fix fx problem for Aggregation" on Jul 22, 2022
@Padarn (Contributor Author) commented Jul 22, 2022

@rusty1s I've added a comment linking to this issue and added details in the description here. I think it is most useful for people looking at this later to be able to come back to this issue if they have any concerns.

@rusty1s (Member) left a review comment:
Would like to remove the warning before merging - otherwise LGTM. Thanks!

@Padarn (Contributor Author) commented Jul 23, 2022

Will do. Thanks!

Padarn enabled auto-merge (squash) on July 23, 2022, 06:47
Padarn merged commit 06dbf5b into pyg-team:master on Jul 23, 2022