Skip to content

Commit

Permalink
Merge pull request #27 from zheng-da/master
Browse files Browse the repository at this point in the history
add RGCN
  • Loading branch information
acbull committed Jan 17, 2021
2 parents 9a13d9c + 3ae707d commit 8ac50bf
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions example_OAG/GPT_GNN/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.nn import GCNConv, GATConv, RGCNConv
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, uniform
from torch_geometric.utils import softmax
Expand Down Expand Up @@ -164,12 +164,14 @@ def __init__(self, conv_name, in_hid, out_hid, num_types, num_relations, n_heads
self.base_conv = GCNConv(in_hid, out_hid)
elif self.conv_name == 'gat':
self.base_conv = GATConv(in_hid, out_hid // n_heads, heads=n_heads)
elif self.conv_name == 'rgcn':
self.base_conv = RGCNConv(in_hid, out_hid, num_relations)
def forward(self, meta_xs, node_type, edge_index, edge_type, edge_time):
if self.conv_name == 'hgt':
return self.base_conv(meta_xs, node_type, edge_index, edge_type, edge_time)
elif self.conv_name == 'gcn':
return self.base_conv(meta_xs, edge_index)
elif self.conv_name == 'gat':
return self.base_conv(meta_xs, edge_index)


elif self.conv_name == 'rgcn':
return self.base_conv(meta_xs, edge_index, edge_type)

0 comments on commit 8ac50bf

Please sign in to comment.