diff --git a/AIAgent/ml/models/RGCNEdgeTypeTAG3VerticesDoubleHistory2/model.py b/AIAgent/ml/models/RGCNEdgeTypeTAG3VerticesDoubleHistory2/model.py index 16f9bafa..0fc5d3ab 100644 --- a/AIAgent/ml/models/RGCNEdgeTypeTAG3VerticesDoubleHistory2/model.py +++ b/AIAgent/ml/models/RGCNEdgeTypeTAG3VerticesDoubleHistory2/model.py @@ -9,11 +9,13 @@ def __init__(self, hidden_channels, out_channels): self.conv1 = RGCNConv(hidden_channels, hidden_channels, 3) self.conv10 = TAGConv(7, hidden_channels, 3) self.conv2 = TAGConv(hidden_channels, hidden_channels, 3) - self.conv3 = ResGatedGraphConv((-1, -1), hidden_channels, edge_dim=1) - self.conv32 = SAGEConv((-1, -1), hidden_channels) - self.conv4 = SAGEConv((-1, -1), hidden_channels) - self.conv42 = SAGEConv((-1, -1), hidden_channels) - self.conv5 = SAGEConv(-1, hidden_channels) + self.conv3 = ResGatedGraphConv( + (hidden_channels, 7), hidden_channels, edge_dim=2 + ) + self.conv32 = SAGEConv((hidden_channels, hidden_channels), hidden_channels) + self.conv4 = SAGEConv((hidden_channels, hidden_channels), hidden_channels) + self.conv42 = SAGEConv((hidden_channels, hidden_channels), hidden_channels) + self.conv5 = SAGEConv(hidden_channels, hidden_channels) self.lin = Linear(hidden_channels, out_channels) def forward(