-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
gnn_explainer_ba_shapes.py
96 lines (78 loc) · 3.16 KB
/
gnn_explainer_ba_shapes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torch_geometric.transforms as T
from torch_geometric.datasets import ExplainerDataset
from torch_geometric.datasets.graph_generator import BAGraph
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.nn import GCN
from torch_geometric.utils import k_hop_subgraph
# Pick the compute device up front; everything below is moved onto it.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Synthetic explanation benchmark: a `BAGraph` base graph (300 nodes) with
# 80 'house' motifs attached; `T.Constant()` is applied as the transform.
dataset = ExplainerDataset(
    graph_generator=BAGraph(num_nodes=300, num_edges=5),
    motif_generator='house',
    num_motifs=80,
    transform=T.Constant(),
)
data = dataset[0]

# 80/20 split over node indices, stratified by node label. The split is
# computed before `data` is moved to `device`.
idx = torch.arange(data.num_nodes)
train_idx, test_idx = train_test_split(
    idx,
    train_size=0.8,
    stratify=data.y,
)

data = data.to(device)

# Three-layer GCN with 20 hidden channels and one output per class.
model = GCN(
    data.num_node_features,
    hidden_channels=20,
    num_layers=3,
    out_channels=dataset.num_classes,
)
model = model.to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.001,
    weight_decay=0.005,
)
def train():
    """Run one full-batch optimisation step and return its loss.

    Operates on the module-level ``model``, ``optimizer``, ``data`` and
    ``train_idx``. The cross-entropy loss is computed on the training
    nodes only.

    Returns:
        The loss of this step as a Python ``float``.
    """
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[train_idx], data.y[train_idx])
    loss.backward()
    # Clipping must happen after backward() and before step(): earlier than
    # that the gradients have just been cleared by zero_grad(), so clipping
    # would be a no-op and the update would use unclipped gradients.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
    optimizer.step()
    return float(loss)
@torch.no_grad()
def test():
    """Evaluate the module-level ``model`` on the training and test nodes.

    Returns:
        A ``(train_acc, test_acc)`` tuple of accuracies, each the fraction
        of correctly classified nodes in the respective index set.
    """
    model.eval()
    logits = model(data.x, data.edge_index)
    pred = logits.argmax(dim=-1)
    is_correct = pred == data.y

    def accuracy(node_idx):
        # Correct predictions among `node_idx`, divided by its length.
        return int(is_correct[node_idx].sum()) / node_idx.size(0)

    return accuracy(train_idx), accuracy(test_idx)
# Train for 2000 epochs; accuracies are only refreshed on the first epoch
# and on every 200th epoch so evaluation does not run on each step.
pbar = tqdm(range(1, 2001))
for epoch in pbar:
    loss = train()
    if epoch != 1 and epoch % 200 != 0:
        continue
    train_acc, test_acc = test()
    pbar.set_description(f'Loss: {loss:.4f}, Train: {train_acc:.4f}, '
                         f'Test: {test_acc:.4f}')
pbar.close()
# Explanations are computed against the trained model in eval mode.
model.eval()

for explanation_type in ('phenomenon', 'model'):
    explainer = Explainer(
        model=model,
        algorithm=GNNExplainer(epochs=300),
        explanation_type=explanation_type,
        node_mask_type='attributes',
        edge_mask_type='object',
        model_config={
            'mode': 'multiclass_classification',
            'task_level': 'node',
            'return_type': 'raw',
        },
    )

    # Labels are passed as the target only for 'phenomenon' explanations;
    # 'model' explanations are requested with no target. The choice does
    # not depend on the node, so it is made once per explanation type.
    if explanation_type == 'phenomenon':
        target = data.y
    else:
        target = None

    # Explanation ROC AUC over every fifth node starting at index 400.
    # For each node, only the edges inside its 3-hop subgraph are scored:
    # `data.edge_mask` is the ROC AUC target, the explainer's edge mask
    # is the score.
    true_masks = []
    pred_masks = []
    node_indices = range(400, data.num_nodes, 5)
    for node_index in tqdm(node_indices, leave=False, desc='Train Explainer'):
        explanation = explainer(
            data.x,
            data.edge_index,
            index=node_index,
            target=target,
        )
        # The fourth value returned by k_hop_subgraph is used as a mask
        # over edges.
        hard_edge_mask = k_hop_subgraph(
            node_index,
            num_hops=3,
            edge_index=data.edge_index,
        )[3]
        true_masks.append(data.edge_mask[hard_edge_mask].cpu())
        pred_masks.append(explanation.edge_mask[hard_edge_mask].cpu())

    auc = roc_auc_score(torch.cat(true_masks), torch.cat(pred_masks))
    print(f'Mean ROC AUC (explanation type {explanation_type:10}): {auc:.4f}')