
Commit 88a1fb1

SoniaMaz8 and rflamary authored
[FEAT] Template Fused Gromov Wasserstein layer (#488)
* TFGW_layer and test
* TGFW-layer and test
* readme
* releases
* modif
* modif
* modif
* modif
* autopep
* importations
* releases modification
* debug
* debug
* blank spaces
* add skip
* import torch_geometric
* Update RELEASES.md Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
* Update ot/gnn/_layers.py Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
* doc
* change name doc
* example
* autopep
* example time
* autopep
* remove reference in TGWPooling
* typo
* change batch==None
* add wasserstein layer + TSNE
* more comments
* add citation OTGNN
* Update README.md Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
* modif code review
* autopep
* debug test
* readme
* change math description
* debug html
* debug html
* typo
* typos
* debug
* add gnn to doc
* change ref numbers
* warning, module title

---------

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent b11f5a0 commit 88a1fb1

File tree

11 files changed: +942 −2 lines changed

CONTRIBUTORS.md

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ The contributors to this library are:
 * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
 * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)
 * [Clément Bonet](https://clbonet.github.io) (Wasserstein on circle, Spherical Sliced-Wasserstein)
+* [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers)
 
 ## Acknowledgments

README.md

Lines changed: 5 additions & 0 deletions
@@ -54,6 +54,7 @@ POT provides the following Machine Learning related solvers:
 * [Linear OT mapping](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) [14] and [Joint OT mapping estimation](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_mapping.html) [8].
 * [Wasserstein Discriminant Analysis](https://pythonot.github.io/auto_examples/others/plot_WDA.html) [11] (requires autograd + pymanopt).
 * [JCPOT algorithm for multi-source domain adaptation with target shift](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_jcpot.html) [27].
+* Graph Neural Network OT layers TFGW [53] and TW (OT-GNN) [54] (https://pythonot.github.io/auto_examples/gromov/plot_gnn_TFGW.html)
 
 Some other examples are available in the [documentation](https://pythonot.github.io/auto_examples/index.html).

@@ -314,3 +315,7 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
 [51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). [Gromov-Wasserstein learning for graph matching and node embedding](http://proceedings.mlr.press/v97/xu19b.html). In International Conference on Machine Learning (ICML), 2019.
 
 [52] Collas, A., Vayer, T., Flamary, R., & Breloy, A. (2023). [Entropic Wasserstein Component Analysis](https://arxiv.org/abs/2303.05119). ArXiv.
+
+[53] Vincent-Cuaz, C., Flamary, R., Corneli, M., Vayer, T., & Courty, N. (2022). [Template based graph neural network with optimal transport distances](https://papers.nips.cc/paper_files/paper/2022/file/4d3525bc60ba1adc72336c0392d3d902-Paper-Conference.pdf). Advances in Neural Information Processing Systems, 35.
+
+[54] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020). [Optimal transport graph neural networks](https://arxiv.org/pdf/2006.04804). ArXiv.
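The TFGW layer referenced above represents a graph by the vector of its fused Gromov-Wasserstein (FGW) distances to a set of learned template graphs [53]. A minimal sketch of one such distance, computed with POT's existing ot.gromov API, follows; the toy graph, the 4-node template, and alpha=0.5 are illustrative assumptions, not values taken from this PR:

import numpy as np
import ot

rng = np.random.RandomState(0)

# toy graph: symmetric 10-node structure matrix C1 and 2-d node features F1
C1 = (rng.rand(10, 10) > 0.5).astype(float)
C1 = (C1 + C1.T) / 2
F1 = rng.randn(10, 2)

# one template: a small graph whose structure and features the layer would learn
C2 = rng.rand(4, 4)
C2 = (C2 + C2.T) / 2
F2 = rng.randn(4, 2)

# uniform node weights and pairwise feature cost
p, q = ot.unif(10), ot.unif(4)
M = ot.dist(F1, F2)

# FGW distance to this template: one coordinate of the graph's TFGW embedding
d = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=0.5)
print(float(d))

Inside ot.gnn.TFGWPooling the templates themselves are trainable parameters, so this embedding is learned end to end with the rest of the network.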

RELEASES.md

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 ## 0.9.1dev
 
 #### New features
+- Template-based Fused Gromov Wasserstein GNN layer in `ot.gnn` (PR #488)
 - Make alpha parameter in semi-relaxed Fused Gromov Wasserstein differentiable (PR #483)
 - Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
 - Added the sparsity-constrained OT solver to `ot.smooth` and added `projection_sparse_simplex` to `ot.utils` (PR #459)

docs/source/all.rst

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ API and modules
     dr
     factored
     gaussian
+    gnn
     gromov
     lp
     optim

examples/gromov/plot_gnn_TFGW.py

Lines changed: 256 additions & 0 deletions
@@ -0,0 +1,256 @@
# -*- coding: utf-8 -*-
"""
======================================================================
Graph classification with Template Based Fused Gromov Wasserstein
======================================================================

This example illustrates how to train a graph classification GNN based on the
Template Fused Gromov Wasserstein (TFGW) layer proposed in [53].

[53] C. Vincent-Cuaz, R. Flamary, M. Corneli, T. Vayer, N. Courty (2022).
Template based graph neural network with optimal transport distances.
Advances in Neural Information Processing Systems, 35.

"""

# Author: Sonia Mazelet <sonia.mazelet@ens-paris-saclay.fr>
#         Rémi Flamary <remi.flamary@unice.fr>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 1

# %%

import matplotlib.pyplot as pl
import torch
import networkx as nx
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx, one_hot
from torch_geometric.utils import stochastic_blockmodel_graph as sbm
from torch_geometric.data import Data as GraphData
import torch.nn as nn
from torch_geometric.nn import Linear, GCNConv
from ot.gnn import TFGWPooling
from sklearn.manifold import TSNE


##############################################################################
# Generate data
# -------------

# We create 2 classes of stochastic block model (SBM) graphs, with 1 block
# and 2 blocks respectively.

torch.manual_seed(0)

n_graphs = 50
n_nodes = 10
n_node_classes = 2

# edge probabilities for the SBMs
P1 = [[0.8]]
P2 = [[0.9, 0.1], [0.1, 0.9]]

# block sizes
block_sizes1 = [n_nodes]
block_sizes2 = [n_nodes // 2, n_nodes // 2]

# one-hot node features
x1 = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
x1 = one_hot(x1, num_classes=n_node_classes)
x1 = torch.reshape(x1, (n_nodes, n_node_classes))

x2 = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
x2 = one_hot(x2, num_classes=n_node_classes)
x2 = torch.reshape(x2, (n_nodes, n_node_classes))

graphs1 = [GraphData(x=x1, edge_index=sbm(block_sizes1, P1), y=torch.tensor([0])) for i in range(n_graphs)]
graphs2 = [GraphData(x=x2, edge_index=sbm(block_sizes2, P2), y=torch.tensor([1])) for i in range(n_graphs)]

graphs = graphs1 + graphs2

# split the data into train and test sets
train_graphs, test_graphs = random_split(graphs, [n_graphs, n_graphs])

train_loader = DataLoader(train_graphs, batch_size=10, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=10, shuffle=False)


# %%

##############################################################################
# Plot data
# ---------

# plot one graph of each class

fontsize = 10

pl.figure(0, figsize=(8, 2.5))
pl.clf()
pl.subplot(121)
pl.axis('off')
pl.title('Graph of class 1', fontsize=fontsize)
G = to_networkx(graphs1[0], to_undirected=True)
pos = nx.spring_layout(G, seed=0)
nx.draw_networkx(G, pos, with_labels=False, node_color="tab:blue")

pl.subplot(122)
pl.axis('off')
pl.title('Graph of class 2', fontsize=fontsize)
G = to_networkx(graphs2[0], to_undirected=True)
pos = nx.spring_layout(G, seed=0)
nx.draw_networkx(G, pos, with_labels=False, nodelist=[0, 1, 2, 3, 4], node_color="tab:blue")
nx.draw_networkx(G, pos, with_labels=False, nodelist=[5, 6, 7, 8, 9], node_color="tab:red")

pl.tight_layout()
pl.show()

# %%

##############################################################################
# Pooling architecture using the TFGW layer
# -----------------------------------------


class pooling_TFGW(nn.Module):
    """
    Pooling architecture using the TFGW layer.
    """

    def __init__(self, n_features, n_templates, n_template_nodes, n_classes, n_hidden_layers, feature_init_mean=0., feature_init_std=1.):
        """
        Pooling architecture using the TFGW layer.
        """
        super().__init__()

        self.n_templates = n_templates
        self.n_template_nodes = n_template_nodes
        # n_hidden_layers is used as the hidden feature dimension of the GCN layer
        self.n_hidden_layers = n_hidden_layers
        self.n_features = n_features

        self.conv = GCNConv(self.n_features, self.n_hidden_layers)

        self.TFGW = TFGWPooling(self.n_hidden_layers, self.n_templates, self.n_template_nodes, feature_init_mean=feature_init_mean, feature_init_std=feature_init_std)

        self.linear = Linear(self.n_templates, n_classes)

    def forward(self, x, edge_index, batch=None):
        x = self.conv(x, edge_index)

        x = self.TFGW(x, edge_index, batch)

        x_latent = x  # save latent embeddings for visualization

        x = self.linear(x)

        return x, x_latent


##############################################################################
# Graph classification training
# -----------------------------


n_epochs = 25

# store latent embeddings and classes for TSNE visualization
embeddings_for_TSNE = []
classes = []

model = pooling_TFGW(n_features=2, n_templates=2, n_template_nodes=2, n_classes=2, n_hidden_layers=2, feature_init_mean=0.5, feature_init_std=0.5)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.0005)
criterion = torch.nn.CrossEntropyLoss()

all_accuracy = []
all_loss = []

for epoch in range(n_epochs):

    losses = []
    accs = []

    for data in train_loader:
        optimizer.zero_grad()  # reset gradients accumulated by the previous step
        out, latent_embedding = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()

        pred = out.argmax(dim=1)
        train_correct = pred == data.y
        train_acc = int(train_correct.sum()) / data.num_graphs

        accs.append(train_acc)
        losses.append(loss.item())

        # store last classes and embeddings for TSNE visualization
        if epoch == n_epochs - 1:
            embeddings_for_TSNE.append(latent_embedding)
            classes.append(data.y)

    print(f'Epoch: {epoch:03d}, Loss: {torch.mean(torch.tensor(losses)):.4f}, Train Accuracy: {torch.mean(torch.tensor(accs)):.4f}')

    all_accuracy.append(torch.mean(torch.tensor(accs)))
    all_loss.append(torch.mean(torch.tensor(losses)))


pl.figure(1, figsize=(8, 2.5))
pl.clf()
pl.subplot(121)
pl.plot(all_loss)
pl.xlabel('epochs')
pl.title('Loss')

pl.subplot(122)
pl.plot(all_accuracy)
pl.xlabel('epochs')
pl.title('Accuracy')

pl.tight_layout()
pl.show()

# Test

test_accs = []

for data in test_loader:
    out, latent_embedding = model(data.x, data.edge_index, data.batch)
    pred = out.argmax(dim=1)
    test_correct = pred == data.y
    test_acc = int(test_correct.sum()) / data.num_graphs
    test_accs.append(test_acc)
    embeddings_for_TSNE.append(latent_embedding)
    classes.append(data.y)

classes = torch.hstack(classes)

print(f'Test Accuracy: {torch.mean(torch.tensor(test_accs)):.4f}')

# %%
##############################################################################
# TSNE visualization of graph classification
# ------------------------------------------

indices = torch.randint(2 * n_graphs, (60,))  # select a subset of embeddings for TSNE visualization
latent_embeddings = torch.vstack(embeddings_for_TSNE).detach().numpy()[indices, :]

TSNE_embeddings = TSNE(n_components=2, perplexity=20, random_state=1).fit_transform(latent_embeddings)

class_0 = classes[indices] == 0
class_1 = classes[indices] == 1

TSNE_embeddings_0 = TSNE_embeddings[class_0, :]
TSNE_embeddings_1 = TSNE_embeddings[class_1, :]

pl.figure(2, figsize=(6, 2.5))
pl.scatter(TSNE_embeddings_0[:, 0], TSNE_embeddings_0[:, 1],
           alpha=0.5, marker='o', label='class 1')
pl.scatter(TSNE_embeddings_1[:, 0], TSNE_embeddings_1[:, 1],
           alpha=0.5, marker='o', label='class 2')
pl.legend()
pl.title('TSNE in the latent space after training')
pl.show()

# %%
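The example exercises only TFGWPooling; the commit also adds a Wasserstein-only TWPooling layer (exported from ot/gnn/__init__.py below). A sketch of the same architecture with the pooling swapped, assuming TWPooling mirrors the constructor and forward signature TFGWPooling uses above (ot/gnn/_layers.py is not shown in this diff, so treat that signature as an assumption):

import torch.nn as nn
from torch_geometric.nn import Linear, GCNConv
from ot.gnn import TWPooling


class pooling_TW(nn.Module):
    # same pipeline as pooling_TFGW above, with the FGW-based pooling replaced
    # by the Wasserstein-only TW layer (graph structure is ignored by the distance)
    def __init__(self, n_features, n_templates, n_template_nodes, n_classes, n_hidden):
        super().__init__()
        self.conv = GCNConv(n_features, n_hidden)
        # assumed signature: (feature dimension, n templates, nodes per template)
        self.pool = TWPooling(n_hidden, n_templates, n_template_nodes)
        self.linear = Linear(n_templates, n_classes)

    def forward(self, x, edge_index, batch=None):
        x = self.conv(x, edge_index)
        x = self.pool(x, edge_index, batch)  # one Wasserstein distance per template
        return self.linear(x)


model_tw = pooling_TW(n_features=2, n_templates=2, n_template_nodes=2, n_classes=2, n_hidden=2)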

ot/gnn/__init__.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
"""
Layers and functions for optimal transport in Graph Neural Networks.

.. warning::
    Note that by default the module is not imported in :mod:`ot`. In order to
    use it you need to explicitly import :mod:`ot.gnn`. This module depends on
    PyTorch Geometric, and its layers are compatible with the PyTorch Geometric API.

"""

# Author: Sonia Mazelet <sonia.mazelet@ens-paris-saclay.fr>
#         Rémi Flamary <remi.flamary@unice.fr>
#
# License: MIT License

# All submodules and packages


from ._utils import (FGW_distance_to_templates, wasserstein_distance_to_templates)

from ._layers import (TFGWPooling, TWPooling)

__all__ = ['FGW_distance_to_templates', 'wasserstein_distance_to_templates', 'TFGWPooling', 'TWPooling']
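As the warning above notes, ot.gnn is not pulled in by a plain `import ot`. A minimal usage sketch with the explicit import; the toy graph and template sizes are arbitrary choices, and the TFGWPooling call pattern follows the example file above:

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from ot.gnn import TFGWPooling  # explicit import required

# one toy 4-node graph with 3-dimensional node features
x = torch.randn(4, 3)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
batch = next(iter(DataLoader([Data(x=x, edge_index=edge_index)], batch_size=1)))

# pool node features into one embedding per graph: a vector of FGW
# distances to 2 learned templates of 2 nodes each
layer = TFGWPooling(3, 2, 2)
emb = layer(batch.x, batch.edge_index, batch.batch)
print(emb.shape)  # expected (1, 2): one distance per template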
