# -*- coding: utf-8 -*-
"""
===================================================================
Graph classification with Template Based Fused Gromov-Wasserstein
===================================================================

This example illustrates how to train a graph classification GNN based on the
Template Fused Gromov-Wasserstein (TFGW) layer proposed in [52].

[52] C. Vincent-Cuaz, R. Flamary, M. Corneli, T. Vayer, N. Courty (2022).
Template based graph neural network with optimal transport distances.
Advances in Neural Information Processing Systems, 35.
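
As a quick reminder of how the layer works (a paraphrase of [52], with
notation introduced here for illustration), TFGW embeds a graph with
structure matrix :math:`C` and node features :math:`F` as the vector of its
fused Gromov-Wasserstein distances to :math:`K` learned template graphs
:math:`(\overline{C}_k, \overline{F}_k)_{k \leq K}`:

.. math::
    TFGW(C, F) = \left[ FGW_\alpha(C, F, \overline{C}_k, \overline{F}_k) \right]_{k=1}^{K}

This :math:`K`-dimensional embedding is then fed to a linear classifier.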

"""

# Author: Sonia Mazelet <sonia.mazelet@ens-paris-saclay.fr>
#         Rémi Flamary <remi.flamary@unice.fr>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 1

#%%

import matplotlib.pyplot as pl
import torch
import networkx as nx
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx, one_hot
from torch_geometric.utils import stochastic_blockmodel_graph as sbm
from torch_geometric.data import Data as GraphData
import torch.nn as nn
from torch_geometric.nn import Linear, GCNConv
from ot.gnn import TFGWPooling
from sklearn.manifold import TSNE

##############################################################################
# Generate data
# -------------

# parameters

# We create two classes of stochastic block model (SBM) graphs,
# with one block and two blocks respectively.

torch.manual_seed(0)

n_graphs = 50
n_nodes = 10
n_node_classes = 2
# edge probabilities for the SBMs: entry P[i][j] is the probability of an
# edge between a node of block i and a node of block j
P1 = [[0.8]]
P2 = [[0.9, 0.1], [0.1, 0.9]]

# block sizes
block_sizes1 = [n_nodes]
block_sizes2 = [n_nodes // 2, n_nodes // 2]

# one-hot node features: a single feature for class 1, one feature per block
# for class 2 (one_hot already returns a (n_nodes, n_node_classes) matrix,
# so no reshape is needed)
x1 = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
x1 = one_hot(x1, num_classes=n_node_classes)

x2 = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
x2 = one_hot(x2, num_classes=n_node_classes)

graphs1 = [GraphData(x=x1, edge_index=sbm(block_sizes1, P1), y=torch.tensor([0])) for i in range(n_graphs)]
graphs2 = [GraphData(x=x2, edge_index=sbm(block_sizes2, P2), y=torch.tensor([1])) for i in range(n_graphs)]

graphs = graphs1 + graphs2

# split the data into train and test sets
train_graphs, test_graphs = random_split(graphs, [n_graphs, n_graphs])

train_loader = DataLoader(train_graphs, batch_size=10, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=10, shuffle=False)
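
# A quick sanity check (added for this rewrite, not part of the original
# example): we expect 2 * n_graphs graphs in total, each carrying a one-hot
# feature matrix of width n_node_classes.
assert len(graphs) == 2 * n_graphs
assert graphs[0].x.shape == (n_nodes, n_node_classes)
print(f"{len(train_graphs)} train graphs, {len(test_graphs)} test graphs")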


#%%

##############################################################################
# Plot data
# ---------

# plot one graph of each class

fontsize = 10

pl.figure(0, figsize=(8, 2.5))
pl.clf()
pl.subplot(121)
pl.axis('off')
pl.title('Graph of class 1', fontsize=fontsize)
G = to_networkx(graphs1[0], to_undirected=True)
pos = nx.spring_layout(G, seed=0)
nx.draw_networkx(G, pos, with_labels=False, node_color="tab:blue")

pl.subplot(122)
pl.axis('off')
pl.title('Graph of class 2', fontsize=fontsize)
G = to_networkx(graphs2[0], to_undirected=True)
pos = nx.spring_layout(G, seed=0)
nx.draw_networkx(G, pos, with_labels=False, nodelist=[0, 1, 2, 3, 4], node_color="tab:blue")
nx.draw_networkx(G, pos, with_labels=False, nodelist=[5, 6, 7, 8, 9], node_color="tab:red")

pl.tight_layout()
pl.show()

#%%

##############################################################################
# Pooling architecture using the TFGW layer
# -------------------------------------------

class pooling_TFGW(nn.Module):
    """
    Pooling architecture using the TFGW layer.
    """

    def __init__(self, n_features, n_templates, n_template_nodes, n_classes, n_hidden_layers, feature_init_mean=0., feature_init_std=1.):
        """
        Pooling architecture using the TFGW layer.

        Parameters
        ----------
        n_features : int
            Number of input node features.
        n_templates : int
            Number of graph templates.
        n_template_nodes : int
            Number of nodes in each template.
        n_classes : int
            Number of classes.
        n_hidden_layers : int
            Dimension of the hidden node features produced by the GCN layer.
        feature_init_mean : float, optional
            Mean of the random normal initialization of the template features.
        feature_init_std : float, optional
            Standard deviation of the random normal initialization of the template features.
        """
        super().__init__()

        self.n_templates = n_templates
        self.n_template_nodes = n_template_nodes
        self.n_hidden_layers = n_hidden_layers
        self.n_features = n_features

        # one GCN layer mapping the input features to the hidden dimension
        self.conv = GCNConv(self.n_features, self.n_hidden_layers)

        # TFGW layer: embeds each graph as its vector of FGW distances to the templates
        self.TFGW = TFGWPooling(self.n_hidden_layers, self.n_templates, self.n_template_nodes, feature_init_mean=feature_init_mean, feature_init_std=feature_init_std)

        # linear classifier on top of the TFGW embedding
        self.linear = Linear(self.n_templates, n_classes)

    def forward(self, x, edge_index, batch=None):
        x = self.conv(x, edge_index)

        x = self.TFGW(x, edge_index, batch)

        x_latent = x  # save the latent TFGW embeddings for visualization

        x = self.linear(x)

        return x, x_latent


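# A minimal shape check (a sketch added for illustration, not part of the
# original example): with n_templates templates, the TFGW embedding of a
# batch of B graphs has shape (B, n_templates) and the logits have shape
# (B, n_classes); here B = 10 is the batch size.
model_check = pooling_TFGW(n_features=2, n_templates=2, n_template_nodes=2, n_classes=2, n_hidden_layers=2)
batch_check = next(iter(train_loader))
logits_check, latent_check = model_check(batch_check.x, batch_check.edge_index, batch_check.batch)
print(logits_check.shape, latent_check.shape)  # torch.Size([10, 2]) torch.Size([10, 2])
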
##############################################################################
# Graph classification training
# ------------------------------


n_epochs = 25

# store latent embeddings and classes for TSNE visualization
embeddings_for_TSNE = []
classes = []

model = pooling_TFGW(n_features=2, n_templates=2, n_template_nodes=2, n_classes=2, n_hidden_layers=2, feature_init_mean=0.5, feature_init_std=0.5)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.0005)
criterion = torch.nn.CrossEntropyLoss()

all_accuracy = []
all_loss = []

for epoch in range(n_epochs):

    losses = []
    accs = []

    for data in train_loader:
        out, latent_embedding = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)
        optimizer.zero_grad()  # reset gradients before the backward pass
        loss.backward()
        optimizer.step()

        pred = out.argmax(dim=1)
        train_correct = pred == data.y
        train_acc = int(train_correct.sum()) / data.num_graphs

        accs.append(train_acc)
        losses.append(loss.item())

        # store the last classes and embeddings for TSNE visualization
        if epoch == n_epochs - 1:
            embeddings_for_TSNE.append(latent_embedding)
            classes.append(data.y)

    print(f'Epoch: {epoch:03d}, Loss: {torch.mean(torch.tensor(losses)):.4f}, Train Accuracy: {torch.mean(torch.tensor(accs)):.4f}')

    all_accuracy.append(torch.mean(torch.tensor(accs)))
    all_loss.append(torch.mean(torch.tensor(losses)))


pl.figure(1, figsize=(8, 2.5))
pl.clf()
pl.subplot(121)
pl.plot(all_loss)
pl.xlabel('epochs')
pl.title('Loss')

pl.subplot(122)
pl.plot(all_accuracy)
pl.xlabel('epochs')
pl.title('Accuracy')

pl.tight_layout()
pl.show()

# Test

test_accs = []

model.eval()
with torch.no_grad():  # no gradients are needed for evaluation
    for data in test_loader:
        out, latent_embedding = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)
        test_correct = pred == data.y
        test_acc = int(test_correct.sum()) / data.num_graphs
        test_accs.append(test_acc)
        embeddings_for_TSNE.append(latent_embedding)
        classes.append(data.y)

classes = torch.hstack(classes)

print(f'Test Accuracy: {torch.mean(torch.tensor(test_accs)):.4f}')

#%%
##############################################################################
# TSNE visualization of graph classification
# -------------------------------------------

# select a random subset of the embeddings (without replacement) for TSNE visualization
indices = torch.randperm(2 * n_graphs)[:60]
latent_embeddings = torch.vstack(embeddings_for_TSNE).detach().numpy()[indices, :]

TSNE_embeddings = TSNE(n_components=2, perplexity=20, random_state=1).fit_transform(latent_embeddings)

class_0 = classes[indices] == 0
class_1 = classes[indices] == 1

TSNE_embeddings_0 = TSNE_embeddings[class_0, :]
TSNE_embeddings_1 = TSNE_embeddings[class_1, :]

pl.figure(2, figsize=(6, 2.5))
pl.scatter(TSNE_embeddings_0[:, 0], TSNE_embeddings_0[:, 1],
           alpha=0.5, marker='o', label='class 1')
pl.scatter(TSNE_embeddings_1[:, 0], TSNE_embeddings_1[:, 1],
           alpha=0.5, marker='o', label='class 2')
pl.legend()
pl.title('TSNE in the latent space after training')
pl.show()


# %%