Skip to content

Commit

Permalink
Set edge_dim=2 and replace -1 with exact values in order to fix weigh…
Browse files Browse the repository at this point in the history
…ts loading.
  • Loading branch information
Anya497 committed Feb 2, 2024
1 parent 942c783 commit 37e8235
Showing 1 changed file with 7 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ def __init__(self, hidden_channels, out_channels):
self.conv1 = RGCNConv(hidden_channels, hidden_channels, 3)
self.conv10 = TAGConv(7, hidden_channels, 3)
self.conv2 = TAGConv(hidden_channels, hidden_channels, 3)
self.conv3 = ResGatedGraphConv((-1, -1), hidden_channels, edge_dim=1)
self.conv32 = SAGEConv((-1, -1), hidden_channels)
self.conv4 = SAGEConv((-1, -1), hidden_channels)
self.conv42 = SAGEConv((-1, -1), hidden_channels)
self.conv5 = SAGEConv(-1, hidden_channels)
self.conv3 = ResGatedGraphConv(
(hidden_channels, 7), hidden_channels, edge_dim=2
)
self.conv32 = SAGEConv((hidden_channels, hidden_channels), hidden_channels)
self.conv4 = SAGEConv((hidden_channels, hidden_channels), hidden_channels)
self.conv42 = SAGEConv((hidden_channels, hidden_channels), hidden_channels)
self.conv5 = SAGEConv(hidden_channels, hidden_channels)
self.lin = Linear(hidden_channels, out_channels)

def forward(
Expand Down

0 comments on commit 37e8235

Please sign in to comment.