Merged

47 commits
22142b2
TFGW_layer and test
SoniaMaz8 Jul 7, 2023
1da3edc
TGFW-layer and test
SoniaMaz8 Jul 7, 2023
9d6fa41
readme
SoniaMaz8 Jul 7, 2023
804ca7f
releases
SoniaMaz8 Jul 7, 2023
011552d
modif
SoniaMaz8 Jul 7, 2023
b483c1b
modif
SoniaMaz8 Jul 7, 2023
0b515f3
modif
SoniaMaz8 Jul 7, 2023
e04900d
modif
SoniaMaz8 Jul 7, 2023
81662c4
autopep
SoniaMaz8 Jul 7, 2023
a6c3edd
importations
SoniaMaz8 Jul 7, 2023
272bdea
releases modification
SoniaMaz8 Jul 7, 2023
330d8ed
debug
SoniaMaz8 Jul 7, 2023
55c3234
debug
SoniaMaz8 Jul 7, 2023
79db4f0
blank spaces
SoniaMaz8 Jul 7, 2023
c5eb1d5
add skip
SoniaMaz8 Jul 7, 2023
258ce22
import torch_geometric
SoniaMaz8 Jul 7, 2023
2840de6
Merge branch 'master' into TFGW-layer
rflamary Jul 7, 2023
90e57a2
Update RELEASES.md
SoniaMaz8 Jul 10, 2023
ec8aaa9
Update ot/gnn/_layers.py
SoniaMaz8 Jul 10, 2023
a37f77f
doc
SoniaMaz8 Jul 17, 2023
65ab563
change name doc
SoniaMaz8 Jul 17, 2023
fbcb052
example
SoniaMaz8 Jul 18, 2023
e9db027
autopep
SoniaMaz8 Jul 18, 2023
0667e6f
example time
SoniaMaz8 Jul 18, 2023
e342bc7
autopep
SoniaMaz8 Jul 18, 2023
643c500
remove reference in TGWPooling
SoniaMaz8 Jul 18, 2023
27cfce4
typo
SoniaMaz8 Jul 18, 2023
3dccf1c
change batch==None
SoniaMaz8 Jul 18, 2023
d1ab94d
add wasserstein layer + TSNE
SoniaMaz8 Jul 18, 2023
87a9e9f
more comments
SoniaMaz8 Jul 18, 2023
5d254ff
add citation OTGNN
SoniaMaz8 Jul 18, 2023
1e582fb
Update README.md
SoniaMaz8 Jul 19, 2023
68807ec
modif code review
SoniaMaz8 Jul 19, 2023
c1a43ca
Merge branch 'TFGW-layer' of https://github.com/SoniaMaz8/POT into TF…
SoniaMaz8 Jul 19, 2023
17445fb
autopep
SoniaMaz8 Jul 19, 2023
adc0807
debug test
SoniaMaz8 Jul 19, 2023
aabe980
readme
SoniaMaz8 Jul 19, 2023
f0334e5
change math description
SoniaMaz8 Jul 19, 2023
0eb720d
debug html
SoniaMaz8 Jul 19, 2023
5ed9b94
debug html
SoniaMaz8 Jul 19, 2023
b4e09aa
typo
SoniaMaz8 Jul 19, 2023
ea33f6a
typos
SoniaMaz8 Jul 19, 2023
cdf1851
debug
SoniaMaz8 Jul 19, 2023
859e566
Merge branch 'master' into TFGW-layer
rflamary Jul 19, 2023
46e69bc
add gnn to doc
SoniaMaz8 Jul 19, 2023
b8a28f8
change ref numbers
SoniaMaz8 Jul 19, 2023
05d6607
warning, module title
SoniaMaz8 Jul 19, 2023
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -42,6 +42,7 @@ The contributors to this library are:
* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
* [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)
* [Clément Bonet](https://clbonet.github.io) (Wasserstein on circle, Spherical Sliced-Wasserstein)
* [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers)

## Acknowledgments

5 changes: 5 additions & 0 deletions README.md
@@ -54,6 +54,7 @@ POT provides the following Machine Learning related solvers:
* [Linear OT mapping](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) [14] and [Joint OT mapping estimation](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_mapping.html) [8].
* [Wasserstein Discriminant Analysis](https://pythonot.github.io/auto_examples/others/plot_WDA.html) [11] (requires autograd + pymanopt).
* [JCPOT algorithm for multi-source domain adaptation with target shift](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_jcpot.html) [27].
* Graph Neural Network OT layers TFGW [53] and TW (OT-GNN) [54] ([example](https://pythonot.github.io/auto_examples/gromov/plot_gnn_TFGW.html))

Some other examples are available in the [documentation](https://pythonot.github.io/auto_examples/index.html).

@@ -314,3 +315,7 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). [Gromov-wasserstein learning for graph matching and node embedding](http://proceedings.mlr.press/v97/xu19b.html). In International Conference on Machine Learning (ICML), 2019.

[52] Collas, A., Vayer, T., Flamary, R., & Breloy, A. (2023). [Entropic Wasserstein Component Analysis](https://arxiv.org/abs/2303.05119). ArXiv.

[53] C. Vincent-Cuaz, R. Flamary, M. Corneli, T. Vayer, N. Courty (2022). [Template based graph neural network with optimal transport distances](https://papers.nips.cc/paper_files/paper/2022/file/4d3525bc60ba1adc72336c0392d3d902-Paper-Conference.pdf). Advances in Neural Information Processing Systems, 35.

[54] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020). [Optimal transport graph neural networks](https://arxiv.org/pdf/2006.04804).
1 change: 1 addition & 0 deletions RELEASES.md
@@ -3,6 +3,7 @@
## 0.9.1dev

#### New features
- Template-based Fused Gromov-Wasserstein and Wasserstein GNN layers in `ot.gnn` (PR #488)
- Make alpha parameter in semi-relaxed Fused Gromov Wasserstein differentiable (PR #483)
- Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
- Added the sparsity-constrained OT solver to `ot.smooth` and added `projection_sparse_simplex` to `ot.utils` (PR #459)
1 change: 1 addition & 0 deletions docs/source/all.rst
@@ -22,6 +22,7 @@ API and modules
dr
factored
gaussian
gnn
gromov
lp
optim
256 changes: 256 additions & 0 deletions examples/gromov/plot_gnn_TFGW.py
@@ -0,0 +1,256 @@
# -*- coding: utf-8 -*-
"""
==================================================================
Graph classification with Template Based Fused Gromov Wasserstein
==================================================================

This example illustrates how to train a graph classification GNN based on the Template based Fused Gromov-Wasserstein (TFGW) layer proposed in [53], and then visualizes the learned graph embeddings with TSNE.
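
In short (our paraphrase of [53]), the TFGW layer embeds each input graph as the vector of its fused Gromov-Wasserstein distances to a set of learned template graphs, :math:`\phi(G)=\left(\mathrm{FGW}_{\alpha}(G, \overline{G}_k)\right)_{k=1}^{K}`, and this embedding is then fed to a linear classifier.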

[53] C. Vincent-Cuaz, R. Flamary, M. Corneli, T. Vayer, N. Courty (2022). Template based graph neural network with optimal transport distances. Advances in Neural Information Processing Systems, 35.

"""

# Author: Sonia Mazelet <sonia.mazelet@ens-paris-saclay.fr>
# Rémi Flamary <remi.flamary@unice.fr>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 1

#%%

import matplotlib.pyplot as pl
import torch
import networkx as nx
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx, one_hot
from torch_geometric.utils import stochastic_blockmodel_graph as sbm
from torch_geometric.data import Data as GraphData
import torch.nn as nn
from torch_geometric.nn import Linear, GCNConv
from ot.gnn import TFGWPooling
from sklearn.manifold import TSNE


##############################################################################
# Generate data
# -------------

# parameters

# We create 2 classes of stochastic block models (SBM) graphs with 1 block and 2 blocks respectively.

torch.manual_seed(0)

n_graphs = 50
n_nodes = 10
n_node_classes = 2

# edge probabilities for the SBMs
P1 = [[0.8]]
P2 = [[0.9, 0.1], [0.1, 0.9]]

# block sizes
block_sizes1 = [n_nodes]
block_sizes2 = [n_nodes // 2, n_nodes // 2]

# node features
x1 = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
x1 = one_hot(x1, num_classes=n_node_classes)
x1 = torch.reshape(x1, (n_nodes, n_node_classes))

x2 = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
x2 = one_hot(x2, num_classes=n_node_classes)
x2 = torch.reshape(x2, (n_nodes, n_node_classes))

graphs1 = [GraphData(x=x1, edge_index=sbm(block_sizes1, P1), y=torch.tensor([0])) for i in range(n_graphs)]
graphs2 = [GraphData(x=x2, edge_index=sbm(block_sizes2, P2), y=torch.tensor([1])) for i in range(n_graphs)]

graphs = graphs1 + graphs2

# split the data into train and test sets
train_graphs, test_graphs = random_split(graphs, [n_graphs, n_graphs])

train_loader = DataLoader(train_graphs, batch_size=10, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=10, shuffle=False)


#%%

##############################################################################
# Plot data
# ---------

# plot one graph of each class

fontsize = 10

pl.figure(0, figsize=(8, 2.5))
pl.clf()
pl.subplot(121)
pl.axis('off')
pl.title('Graph of class 1', fontsize=fontsize)
G = to_networkx(graphs1[0], to_undirected=True)
pos = nx.spring_layout(G, seed=0)
nx.draw_networkx(G, pos, with_labels=False, node_color="tab:blue")

pl.subplot(122)
pl.axis('off')
pl.title('Graph of class 2', fontsize=fontsize)
G = to_networkx(graphs2[0], to_undirected=True)
pos = nx.spring_layout(G, seed=0)
nx.draw_networkx(G, pos, with_labels=False, nodelist=[0, 1, 2, 3, 4], node_color="tab:blue")
nx.draw_networkx(G, pos, with_labels=False, nodelist=[5, 6, 7, 8, 9], node_color="tab:red")

pl.tight_layout()
pl.show()

#%%

##############################################################################
# Pooling architecture using the TFGW layer
# ------------------------------------------


class pooling_TFGW(nn.Module):
    """
    Pooling architecture: one GCN layer, a TFGW pooling layer and a linear classifier.
    """

    def __init__(self, n_features, n_templates, n_template_nodes, n_classes, n_hidden_layers, feature_init_mean=0., feature_init_std=1.):
        """
        Build the model from the number of input node features, the number of
        templates and of nodes per template in the TFGW layer, the number of
        classes, and the hidden embedding dimension.
        """
        super().__init__()

        self.n_templates = n_templates
        self.n_template_nodes = n_template_nodes
        self.n_hidden_layers = n_hidden_layers
        self.n_features = n_features

        # a single GCN layer embeds the node features (n_hidden_layers is used here as the embedding dimension)
        self.conv = GCNConv(self.n_features, self.n_hidden_layers)

        # TFGW pooling: one fused Gromov-Wasserstein distance per learned template
        self.TFGW = TFGWPooling(self.n_hidden_layers, self.n_templates, self.n_template_nodes, feature_init_mean=feature_init_mean, feature_init_std=feature_init_std)

        self.linear = Linear(self.n_templates, n_classes)

    def forward(self, x, edge_index, batch=None):
        x = self.conv(x, edge_index)

        x = self.TFGW(x, edge_index, batch)

        x_latent = x  # save latent embeddings for visualization

        x = self.linear(x)

        return x, x_latent


##############################################################################
# Graph classification training
# -----------------------------


n_epochs = 25

# store latent embeddings and classes for TSNE visualization
embeddings_for_TSNE = []
classes = []

model = pooling_TFGW(n_features=2, n_templates=2, n_template_nodes=2, n_classes=2, n_hidden_layers=2, feature_init_mean=0.5, feature_init_std=0.5)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.0005)
criterion = torch.nn.CrossEntropyLoss()

all_accuracy = []
all_loss = []

for epoch in range(n_epochs):

    losses = []
    accs = []

    for data in train_loader:
        optimizer.zero_grad()  # reset gradients before each step
        out, latent_embedding = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()

        pred = out.argmax(dim=1)
        train_correct = pred == data.y
        train_acc = int(train_correct.sum()) / data.num_graphs

        accs.append(train_acc)
        losses.append(loss.item())

        # store last classes and embeddings for TSNE visualization
        if epoch == n_epochs - 1:
            embeddings_for_TSNE.append(latent_embedding)
            classes.append(data.y)

    print(f'Epoch: {epoch:03d}, Loss: {torch.mean(torch.tensor(losses)):.4f}, Train Accuracy: {torch.mean(torch.tensor(accs)):.4f}')

    all_accuracy.append(torch.mean(torch.tensor(accs)))
    all_loss.append(torch.mean(torch.tensor(losses)))


pl.figure(1, figsize=(8, 2.5))
pl.clf()
pl.subplot(121)
pl.plot(all_loss)
pl.xlabel('epochs')
pl.title('Loss')

pl.subplot(122)
pl.plot(all_accuracy)
pl.xlabel('epochs')
pl.title('Accuracy')

pl.tight_layout()
pl.show()

# Test

test_accs = []

for data in test_loader:
    out, latent_embedding = model(data.x, data.edge_index, data.batch)
    pred = out.argmax(dim=1)
    test_correct = pred == data.y
    test_acc = int(test_correct.sum()) / data.num_graphs
    test_accs.append(test_acc)
    embeddings_for_TSNE.append(latent_embedding)
    classes.append(data.y)

classes = torch.hstack(classes)

print(f'Test Accuracy: {torch.mean(torch.tensor(test_accs)):.4f}')

#%%
##############################################################################
# TSNE visualization of graph classification
# -------------------------------------------

indices = torch.randint(2 * n_graphs, (60,)) # select a subset of embeddings for TSNE visualization
latent_embeddings = torch.vstack(embeddings_for_TSNE).detach().numpy()[indices, :]

TSNE_embeddings = TSNE(n_components=2, perplexity=20, random_state=1).fit_transform(latent_embeddings)

class_0 = classes[indices] == 0
class_1 = classes[indices] == 1

TSNE_embeddings_0 = TSNE_embeddings[class_0, :]
TSNE_embeddings_1 = TSNE_embeddings[class_1, :]

pl.figure(2, figsize=(6, 2.5))
pl.scatter(TSNE_embeddings_0[:, 0], TSNE_embeddings_0[:, 1],
alpha=0.5, marker='o', label='class 1')
pl.scatter(TSNE_embeddings_1[:, 0], TSNE_embeddings_1[:, 1],
alpha=0.5, marker='o', label='class 2')
pl.legend()
pl.title('TSNE in the latent space after training')
pl.show()


# %%
24 changes: 24 additions & 0 deletions ot/gnn/__init__.py
@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
"""
Layers and functions for optimal transport in Graph Neural Networks.

.. warning::
    Note that by default the module is not imported in :mod:`ot`. In order to
    use it you need to import :mod:`ot.gnn` explicitly. This module depends on
    PyTorch Geometric, and its layers are compatible with the PyTorch Geometric API.

"""

# Author: Sonia Mazelet <sonia.mazelet@ens-paris-saclay.fr>
# Rémi Flamary <remi.flamary@unice.fr>
#
# License: MIT License

# All submodules and packages


from ._utils import (FGW_distance_to_templates, wasserstein_distance_to_templates)

from ._layers import (TFGWPooling, TWPooling)

__all__ = ['FGW_distance_to_templates', 'wasserstein_distance_to_templates', 'TFGWPooling', 'TWPooling']
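
As a quick illustration of the new module, here is a minimal usage sketch (not part of the PR; the toy tensors below are ours, while the imports and the TFGWPooling arguments follow the example above):

import torch
from ot.gnn import TFGWPooling  # ot.gnn is not pulled in by `import ot`, it must be imported explicitly

x = torch.rand(10, 3)  # a toy graph with 10 nodes and 3 node features
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])  # connectivity in PyTorch Geometric format
layer = TFGWPooling(3, 2, 4)  # 3 input features, 2 templates with 4 nodes each
embedding = layer(x, edge_index)  # one fused Gromov-Wasserstein distance per template

`TWPooling` is used the same way but relies on Wasserstein distances between node features only, without the structure term (the OT-GNN variant).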