# An implementation of "Unsupervised Attributed Multiplex Network
# Embedding" (DMGI) for unsupervised learning on heterogeneous graphs:
# * Paper: <https://arxiv.org/abs/1911.06750> (AAAI 2020)
import os.path as osp

import torch
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
from torch.optim import Adam

import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.datasets import IMDB
from torch_geometric.nn import GCNConv
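
# Load the IMDB heterogeneous graph (movie, actor, and director node types);
# the dataset is downloaded automatically on first use: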
path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/IMDB')
dataset = IMDB(path)
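
# Convert the graph into a multiplex movie-movie graph: each metapath below
# becomes its own movie-to-movie edge type, and the original edge types are
# dropped: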
metapaths = [
    [('movie', 'actor'), ('actor', 'movie')],  # MAM
    [('movie', 'director'), ('director', 'movie')],  # MDM
]
data = T.AddMetaPaths(metapaths, drop_orig_edge_types=True)(dataset[0])
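

# DMGI encodes each relation with its own GCN encoder, scores (node, summary)
# pairs with a shared bilinear discriminator `M` (as in Deep Graph Infomax),
# and learns a consensus node embedding `Z` shared across all relations: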
class DMGI(torch.nn.Module):
    def __init__(self, num_nodes, in_channels, out_channels, num_relations):
        super().__init__()
        self.convs = torch.nn.ModuleList(
            [GCNConv(in_channels, out_channels) for _ in range(num_relations)])
        self.M = torch.nn.Bilinear(out_channels, out_channels, 1)
        self.Z = torch.nn.Parameter(torch.empty(num_nodes, out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        torch.nn.init.xavier_uniform_(self.M.weight)
        self.M.bias.data.zero_()
        torch.nn.init.xavier_uniform_(self.Z)

    def forward(self, x, edge_indices):
        pos_hs, neg_hs, summaries = [], [], []
        for conv, edge_index in zip(self.convs, edge_indices):
            # Positive sample: encode the original node features.
            pos_h = F.dropout(x, p=0.5, training=self.training)
            pos_h = conv(pos_h, edge_index).relu()
            pos_hs.append(pos_h)

            # Negative sample: encode row-shuffled (corrupted) features.
            neg_h = F.dropout(x, p=0.5, training=self.training)
            neg_h = neg_h[torch.randperm(neg_h.size(0), device=neg_h.device)]
            neg_h = conv(neg_h, edge_index).relu()
            neg_hs.append(neg_h)

            # Relation summary: mean-pool the positive embeddings.
            summaries.append(pos_h.mean(dim=0, keepdim=True))

        return pos_hs, neg_hs, summaries

    def loss(self, pos_hs, neg_hs, summaries):
        loss = 0.
        for pos_h, neg_h, s in zip(pos_hs, neg_hs, summaries):
            s = s.expand_as(pos_h)
            # Binary cross-entropy on the discriminator scores: positive
            # (node, summary) pairs should score high, corrupted ones low.
            loss += -torch.log(self.M(pos_h, s).sigmoid() + 1e-15).mean()
            loss += -torch.log(1 - self.M(neg_h, s).sigmoid() + 1e-15).mean()

        # Consensus regularization: pull `Z` towards the mean of the positive
        # embeddings and push it away from the mean of the negative ones.
        pos_mean = torch.stack(pos_hs, dim=0).mean(dim=0)
        neg_mean = torch.stack(neg_hs, dim=0).mean(dim=0)
        pos_reg_loss = (self.Z - pos_mean).pow(2).sum()
        neg_reg_loss = (self.Z - neg_mean).pow(2).sum()
        loss += 0.001 * (pos_reg_loss - neg_reg_loss)

        return loss


model = DMGI(data['movie'].num_nodes, data['movie'].x.size(-1),
             out_channels=64, num_relations=len(data.edge_types))

if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch_geometric.is_xpu_available():
    device = torch.device('xpu')
else:
    device = torch.device('cpu')
data, model = data.to(device), model.to(device)

optimizer = Adam(model.parameters(), lr=0.0005, weight_decay=0.0001)
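

# One full-batch training step: encode every relation and minimize the
# combined discriminator and consensus loss: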
def train():
    model.train()
    optimizer.zero_grad()
    x = data['movie'].x
    edge_indices = data.edge_index_dict.values()
    pos_hs, neg_hs, summaries = model(x, edge_indices)
    loss = model.loss(pos_hs, neg_hs, summaries)
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test():
    # Evaluate the consensus embedding `Z` by fitting a logistic regression
    # classifier on the training split and scoring the held-out splits:
    train_emb = model.Z[data['movie'].train_mask].cpu()
    val_emb = model.Z[data['movie'].val_mask].cpu()
    test_emb = model.Z[data['movie'].test_mask].cpu()

    train_y = data['movie'].y[data['movie'].train_mask].cpu()
    val_y = data['movie'].y[data['movie'].val_mask].cpu()
    test_y = data['movie'].y[data['movie'].test_mask].cpu()

    clf = LogisticRegression().fit(train_emb, train_y)
    return clf.score(val_emb, val_y), clf.score(test_emb, test_y)
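

# Train for 1,000 epochs, reporting validation and test accuracy every 50: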
for epoch in range(1, 1001):
    loss = train()
    if epoch % 50 == 0:
        val_acc, test_acc = test()
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
              f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')
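
# The learned consensus embedding can also serve other downstream tasks. As a
# minimal sketch (not part of the original example), cluster the movies with
# k-means; choosing 3 clusters to match the number of IMDB genre labels is an
# assumption:
from sklearn.cluster import KMeans

emb = model.Z.detach().cpu().numpy()
clusters = KMeans(n_clusters=3, n_init=10).fit_predict(emb)
sizes = torch.bincount(torch.as_tensor(clusters, dtype=torch.long))
print(f'k-means cluster sizes: {sizes.tolist()}')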