Unable to compile model with GATv2Conv layers #9603
Comments
#9007 added support for |
After installing the master branch with Error output (Click to expand)
Thank you for providing the repro and the complete error message. A quick workaround is to initialize the parameters before passing the model to torch.compile:
  model = GNN(num_channels, num_classes, 4, 4)
+ dataset = FakeDataset(num_channels=num_channels, num_classes=num_classes, task="node")
+ model(dataset[0].x, dataset[0].edge_index)
  model = torch.compile(model, dynamic=True, fullgraph=True)
- dataset = FakeDataset(num_channels=num_channels, num_classes=num_classes, task="node")
We should decorate |
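The warm-up step in the diff above can be wrapped in a small helper. This is a minimal sketch, not part of PyG or PyTorch: the helper names `has_lazy_params` and `materialize_lazy_params` are hypothetical, and the toy model uses `torch.nn.LazyLinear` as a stand-in for layers created with `in_channels=-1`. It assumes the lazy parameters are instances of `UninitializedParameter`, which should be checked against the installed PyG version.

```python
import torch
from torch.nn.parameter import UninitializedParameter


def has_lazy_params(model: torch.nn.Module) -> bool:
    """Return True if any parameter is still an uninitialized (lazy) placeholder."""
    return any(isinstance(p, UninitializedParameter) for p in model.parameters())


def materialize_lazy_params(model: torch.nn.Module, *example_inputs) -> torch.nn.Module:
    """Run one eager forward pass so lazily sized parameters get concrete shapes."""
    with torch.no_grad():
        model(*example_inputs)
    if has_lazy_params(model):
        raise RuntimeError("some parameters are still uninitialized after the warm-up pass")
    return model


# Toy stand-in (assumption) for a model whose input sizes are inferred lazily.
toy = torch.nn.Sequential(
    torch.nn.LazyLinear(4),
    torch.nn.GELU(),
    torch.nn.LazyLinear(2),
)
lazy_before = has_lazy_params(toy)
materialize_lazy_params(toy, torch.randn(3, 2))
lazy_after = has_lazy_params(toy)
# Only after this point would the model be handed to
# torch.compile(toy, dynamic=True, fullgraph=True).
```

For the model in this thread, the equivalent call would be `materialize_lazy_params(model, dataset[0].x, dataset[0].edge_index)` before `torch.compile`.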
Thanks for this suggestion. Unfortunately this still doesn't run for me, but the error states
New MWE
import torch
torch._dynamo.config.capture_dynamic_output_shape_ops = True
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv as conv
from torch_geometric.datasets import FakeDataset

class GNN(torch.nn.Module):
    def __init__(self, features, classes, hidden_width, layers):
        super().__init__()
        self.layers = torch.nn.ModuleList([conv(features, hidden_width, heads=1)])
        # in_channels=-1 defers the input size to the first forward pass.
        for _ in range(layers - 2):
            self.layers.append(conv(-1, hidden_width, heads=1))
        self.layers.append(conv(-1, classes))
        self.act = F.gelu

    def forward(self, x, edge_index):
        for layer in self.layers[:-1]:
            x = layer(x, edge_index)
            x = self.act(x)
        x = self.layers[-1](x, edge_index)
        return x

if __name__ == "__main__":
    num_channels = 2
    num_classes = 2
    model = GNN(num_channels, num_classes, 4, 4)
    dataset = FakeDataset(num_channels=num_channels, num_classes=num_classes, task="node")
    # Eager warm-up pass to initialize the lazy parameters before compiling.
    model(dataset[0].x, dataset[0].edge_index)
    model = torch.compile(model, dynamic=True, fullgraph=True)
    for data in dataset:
        out = model(data.x, data.edge_index)
        print(out)
Error output
@jusevitch I ran the script again on PyG master with PyTorch nightly, and it worked without any issues. Can you retry with a newer version of PyG and PyTorch?
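Before retrying, it can help to confirm which versions are actually installed in the active environment. The snippet below is a small sketch using only the standard library; the helper name `installed_versions` is hypothetical, and the distribution names `torch` and `torch-geometric` are assumed to match how the packages were installed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("torch", "torch-geometric")):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


report = installed_versions()
for name, ver in report.items():
    print(f"{name}: {ver or 'not installed'}")
```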
This wouldn't work because disabling the region will produce a graph break anyway. I don't see any other solutions where users can call |
🐛 Describe the bug
I'm unable to use torch.compile to compile a simple model using GATv2Conv layers. A MWE is below. (Edit: I get a similar error when using torch_geometric.compile().)
MWE
Error Output
Versions
Version Information: