diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 2d25f3efa..a6f6f94ba 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -42,6 +42,7 @@ The contributors to this library are:
 * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
 * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)
 * [Clément Bonet](https://clbonet.github.io) (Wasserstein on circle, Spherical Sliced-Wasserstein)
+* [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers)
 
 ## Acknowledgments
 
diff --git a/README.md b/README.md
index 00e2bb989..c3a802431 100644
--- a/README.md
+++ b/README.md
@@ -54,6 +54,7 @@ POT provides the following Machine Learning related solvers:
 * [Linear OT mapping](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) [14] and [Joint OT mapping estimation](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_mapping.html) [8].
 * [Wasserstein Discriminant Analysis](https://pythonot.github.io/auto_examples/others/plot_WDA.html) [11] (requires autograd + pymanopt).
 * [JCPOT algorithm for multi-source domain adaptation with target shift](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_jcpot.html) [27].
+* [Graph Neural Network OT layers TFGW and TW (OT-GNN)](https://pythonot.github.io/auto_examples/gromov/plot_gnn_TFGW.html) [53,54]
 
 Some other examples are available in the [documentation](https://pythonot.github.io/auto_examples/index.html).
 
@@ -314,3 +315,7 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
 [51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). [Gromov-wasserstein learning for graph matching and node embedding](http://proceedings.mlr.press/v97/xu19b.html). In International Conference on Machine Learning (ICML), 2019.
 
 [52] Collas, A., Vayer, T., Flamary, R., & Breloy, A. (2023). [Entropic Wasserstein Component Analysis](https://arxiv.org/abs/2303.05119). ArXiv.
+
+[53] Vincent-Cuaz, C., Flamary, R., Corneli, M., Vayer, T., & Courty, N. (2022). [Template based graph neural network with optimal transport distances](https://papers.nips.cc/paper_files/paper/2022/file/4d3525bc60ba1adc72336c0392d3d902-Paper-Conference.pdf). Advances in Neural Information Processing Systems, 35.
+
+[54] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020). [Optimal transport graph neural networks](https://arxiv.org/pdf/2006.04804). ArXiv.
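A minimal usage sketch of the new layers (illustrative only; based on the `ot.gnn` signatures added in this PR, with shapes and defaults as documented in `ot/gnn/_layers.py`):

```python
import torch
from ot.gnn import TFGWPooling, TWPooling

n_features, n_templates, n_template_nodes = 3, 2, 2

# TFGW pooling: embeds a graph as its vector of FGW distances to learnable templates
layer = TFGWPooling(n_features, n_templates, n_template_nodes)

x = torch.rand(10, n_features)                     # node features, one row per node
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # PyTorch Geometric edge indices, shape (2, n_edges)
emb = layer(x, edge_index)                         # differentiable embedding of shape (n_templates,)

# TW pooling (OT-GNN) uses the node features only
emb_w = TWPooling(n_features, n_templates, n_template_nodes)(x)
```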
diff --git a/RELEASES.md b/RELEASES.md
index bd74da059..bd5a618e8 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -3,6 +3,7 @@
 ## 0.9.1dev
 
 #### New features
+- Template-based Fused Gromov-Wasserstein GNN layer in `ot.gnn` (PR #488)
 - Make alpha parameter in semi-relaxed Fused Gromov Wasserstein differentiable (PR #483)
 - Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
 - Added the sparsity-constrained OT solver to `ot.smooth` and added `projection_sparse_simplex` to `ot.utils` (PR #459)
diff --git a/docs/source/all.rst b/docs/source/all.rst
index a9d7fe2bb..8750074c3 100644
--- a/docs/source/all.rst
+++ b/docs/source/all.rst
@@ -22,6 +22,7 @@ API and modules
    dr
    factored
    gaussian
+   gnn
    gromov
    lp
    optim
diff --git a/examples/gromov/plot_gnn_TFGW.py b/examples/gromov/plot_gnn_TFGW.py
new file mode 100644
index 000000000..de745031d
--- /dev/null
+++ b/examples/gromov/plot_gnn_TFGW.py
@@ -0,0 +1,256 @@
+# -*- coding: utf-8 -*-
+"""
+==================================================================
+Graph classification with Template Based Fused Gromov-Wasserstein
+==================================================================
+
+This example illustrates how to train a graph classification GNN based on the Template Fused Gromov-Wasserstein (TFGW) layer proposed in [53].
+
+[53] C. Vincent-Cuaz, R. Flamary, M. Corneli, T. Vayer, N. Courty (2022). Template based graph neural network with optimal transport distances. Advances in Neural Information Processing Systems, 35.
+
+"""
+
+# Author: Sonia Mazelet
+#         Rémi Flamary
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 1
+
+#%%
+
+import matplotlib.pyplot as pl
+import torch
+import networkx as nx
+from torch.utils.data import random_split
+from torch_geometric.loader import DataLoader
+from torch_geometric.utils import to_networkx, one_hot
+from torch_geometric.utils import stochastic_blockmodel_graph as sbm
+from torch_geometric.data import Data as GraphData
+import torch.nn as nn
+from torch_geometric.nn import Linear, GCNConv
+from ot.gnn import TFGWPooling
+from sklearn.manifold import TSNE
+
+
+##############################################################################
+# Generate data
+# -------------
+
+# parameters
+
+# We create 2 classes of stochastic block model (SBM) graphs with 1 block and 2 blocks respectively.
+
+torch.manual_seed(0)
+
+n_graphs = 50
+n_nodes = 10
+n_node_classes = 2
+
+#edge probabilities for the SBMs
+P1 = [[0.8]]
+P2 = [[0.9, 0.1], [0.1, 0.9]]
+
+#block sizes
+block_sizes1 = [n_nodes]
+block_sizes2 = [n_nodes // 2, n_nodes // 2]
+
+#node features
+x1 = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
+x1 = one_hot(x1, num_classes=n_node_classes)
+x1 = torch.reshape(x1, (n_nodes, n_node_classes))
+
+x2 = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
+x2 = one_hot(x2, num_classes=n_node_classes)
+x2 = torch.reshape(x2, (n_nodes, n_node_classes))
+
+graphs1 = [GraphData(x=x1, edge_index=sbm(block_sizes1, P1), y=torch.tensor([0])) for i in range(n_graphs)]
+graphs2 = [GraphData(x=x2, edge_index=sbm(block_sizes2, P2), y=torch.tensor([1])) for i in range(n_graphs)]
+
+graphs = graphs1 + graphs2
+
+#split the data into train and test sets
+train_graphs, test_graphs = random_split(graphs, [n_graphs, n_graphs])
+
+train_loader = DataLoader(train_graphs, batch_size=10, shuffle=True)
+test_loader = DataLoader(test_graphs, batch_size=10, shuffle=False)
+
+
+#%%
+
+##############################################################################
+# Plot data
+# ---------
+
+# plot one graph of each class
+
+fontsize = 10
+
+pl.figure(0, figsize=(8, 2.5))
+pl.clf()
+pl.subplot(121)
+pl.axis('off')
+pl.title('Graph of class 1', fontsize=fontsize)
+G = to_networkx(graphs1[0], to_undirected=True)
+pos = nx.spring_layout(G, seed=0)
+nx.draw_networkx(G, pos, with_labels=False, node_color="tab:blue")
+
+pl.subplot(122)
+pl.axis('off')
+pl.title('Graph of class 2', fontsize=fontsize)
+G = to_networkx(graphs2[0], to_undirected=True)
+pos = nx.spring_layout(G, seed=0)
+nx.draw_networkx(G, pos, with_labels=False, nodelist=[0, 1, 2, 3, 4], node_color="tab:blue")
+nx.draw_networkx(G, pos, with_labels=False, nodelist=[5, 6, 7, 8, 9], node_color="tab:red")
+
+pl.tight_layout()
+pl.show()
+
+#%%
+
+##############################################################################
+# Pooling architecture using the TFGW layer
+# -----------------------------------------
+
+
+class pooling_TFGW(nn.Module):
+    """
+    Pooling architecture using the TFGW layer.
+    """
+
+    def __init__(self, n_features, n_templates, n_template_nodes, n_classes, n_hidden_layers, feature_init_mean=0., feature_init_std=1.):
+        """
+        Pooling architecture using the TFGW layer.
+        """
+        super().__init__()
+
+        self.n_templates = n_templates
+        self.n_template_nodes = n_template_nodes
+        self.n_hidden_layers = n_hidden_layers
+        self.n_features = n_features
+
+        self.conv = GCNConv(self.n_features, self.n_hidden_layers)
+
+        self.TFGW = TFGWPooling(self.n_hidden_layers, self.n_templates, self.n_template_nodes, feature_init_mean=feature_init_mean, feature_init_std=feature_init_std)
+
+        self.linear = Linear(self.n_templates, n_classes)
+
+    def forward(self, x, edge_index, batch=None):
+        x = self.conv(x, edge_index)
+
+        x = self.TFGW(x, edge_index, batch)
+
+        x_latent = x  # save latent embeddings for visualization
+
+        x = self.linear(x)
+
+        return x, x_latent
+
+
+##############################################################################
+# Graph classification training
+# -----------------------------
+
+
+n_epochs = 25
+
+#store latent embeddings and classes for TSNE visualization
+embeddings_for_TSNE = []
+classes = []
+
+model = pooling_TFGW(n_features=2, n_templates=2, n_template_nodes=2, n_classes=2, n_hidden_layers=2, feature_init_mean=0.5, feature_init_std=0.5)
+
+optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.0005)
+criterion = torch.nn.CrossEntropyLoss()
+
+all_accuracy = []
+all_loss = []
+
+for epoch in range(n_epochs):
+
+    losses = []
+    accs = []
+
+    for data in train_loader:
+        optimizer.zero_grad()
+        out, latent_embedding = model(data.x, data.edge_index, data.batch)
+        loss = criterion(out, data.y)
+        loss.backward()
+        optimizer.step()
+
+        pred = out.argmax(dim=1)
+        train_correct = pred == data.y
+        train_acc = int(train_correct.sum()) / data.num_graphs
+
+        accs.append(train_acc)
+        losses.append(loss.item())
+
+        #store last classes and embeddings for TSNE visualization
+        if epoch == n_epochs - 1:
+            embeddings_for_TSNE.append(latent_embedding)
+            classes.append(data.y)
+
+    print(f'Epoch: {epoch:03d}, Loss: {torch.mean(torch.tensor(losses)):.4f}, Train Accuracy: {torch.mean(torch.tensor(accs)):.4f}')
+
+    all_accuracy.append(torch.mean(torch.tensor(accs)))
+    all_loss.append(torch.mean(torch.tensor(losses)))
+
+
+pl.figure(1, figsize=(8, 2.5))
+pl.clf()
+pl.subplot(121)
+pl.plot(all_loss)
+pl.xlabel('epochs')
+pl.title('Loss')
+
+pl.subplot(122)
+pl.plot(all_accuracy)
+pl.xlabel('epochs')
+pl.title('Accuracy')
+
+pl.tight_layout()
+pl.show()
+
+#Test
+
+test_accs = []
+
+for data in test_loader:
+    out, latent_embedding = model(data.x, data.edge_index, data.batch)
+    pred = out.argmax(dim=1)
+    test_correct = pred == data.y
+    test_acc = int(test_correct.sum()) / data.num_graphs
+    test_accs.append(test_acc)
+    embeddings_for_TSNE.append(latent_embedding)
+    classes.append(data.y)
+
+classes = torch.hstack(classes)
+
+print(f'Test Accuracy: {torch.mean(torch.tensor(test_accs)):.4f}')
+
+#%%
+##############################################################################
+# TSNE visualization of graph classification
+# -------------------------------------------
+
+indices = torch.randint(2 * n_graphs, (60,))  # select a subset of embeddings for TSNE visualization
+latent_embeddings = torch.vstack(embeddings_for_TSNE).detach().numpy()[indices, :]
+
+TSNE_embeddings = TSNE(n_components=2, perplexity=20, random_state=1).fit_transform(latent_embeddings)
+
+class_0 = classes[indices] == 0
+class_1 = classes[indices] == 1
+
+TSNE_embeddings_0 = TSNE_embeddings[class_0, :]
+TSNE_embeddings_1 = TSNE_embeddings[class_1, :]
+
+pl.figure(2, figsize=(6, 2.5))
+pl.scatter(TSNE_embeddings_0[:, 0], TSNE_embeddings_0[:, 1],
+           alpha=0.5, marker='o', label='class 1')
+pl.scatter(TSNE_embeddings_1[:, 0], TSNE_embeddings_1[:, 1],
+           alpha=0.5, marker='o',
+           label='class 2')
+pl.legend()
+pl.title('TSNE in the latent space after training')
+pl.show()
+
+
+# %%
diff --git a/ot/gnn/__init__.py b/ot/gnn/__init__.py
new file mode 100644
index 000000000..6a84100a1
--- /dev/null
+++ b/ot/gnn/__init__.py
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+"""
+Layers and functions for optimal transport in Graph Neural Networks.
+
+.. warning::
+    Note that by default the module is not imported in :mod:`ot`. In order to
+    use it you need to explicitly import :mod:`ot.gnn`. This module depends on
+    PyTorch Geometric, and the layers are compatible with its API.
+
+"""
+
+# Author: Sonia Mazelet
+#         Rémi Flamary
+#
+# License: MIT License
+
+# All submodules and packages
+
+
+from ._utils import (FGW_distance_to_templates, wasserstein_distance_to_templates)
+
+from ._layers import (TFGWPooling, TWPooling)
+
+__all__ = ['FGW_distance_to_templates', 'wasserstein_distance_to_templates', 'TFGWPooling', 'TWPooling']
\ No newline at end of file
diff --git a/ot/gnn/_layers.py b/ot/gnn/_layers.py
new file mode 100644
index 000000000..9e32dfdfc
--- /dev/null
+++ b/ot/gnn/_layers.py
@@ -0,0 +1,250 @@
+# -*- coding: utf-8 -*-
+"""
+Template Fused Gromov-Wasserstein layers
+"""
+
+# Author: Sonia Mazelet
+#         Rémi Flamary
+#
+# License: MIT License
+
+import torch
+import torch.nn as nn
+from ._utils import TFGW_template_initialization, FGW_distance_to_templates, wasserstein_distance_to_templates
+
+
+class TFGWPooling(nn.Module):
+    r"""
+    Template Fused Gromov-Wasserstein (TFGW) layer. This layer is a pooling layer for graph neural networks.
+    It computes the fused Gromov-Wasserstein distances between the graph and a set of templates:
+
+    .. math::
+        TFGW_{\overline{\mathcal{G}}, \alpha}(C, F, h) = [FGW_{\alpha}(C, F, h, \overline{C}_k, \overline{F}_k, \overline{h}_k)]_{k=1}^{K}
+
+    where :
+
+    - :math:`\overline{\mathcal{G}} = \{(\overline{C}_k, \overline{F}_k, \overline{h}_k)\}_{k \in \{1, ..., K\}}` is the set of :math:`K` templates characterized by their adjacency matrices :math:`\overline{C}_k`, their feature matrices :math:`\overline{F}_k` and their node weights :math:`\overline{h}_k`.
+    - :math:`C`, :math:`F` and :math:`h` are respectively the adjacency matrix, the feature matrix and the node weights of the graph.
+    - :math:`\alpha` is the trade-off parameter between features and structure for the Fused Gromov-Wasserstein distance.
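+
+    Each coordinate of the output is one FGW distance. Recalled here for convenience (this matches the squared-loss formulation solved internally by :func:`ot.gromov.fused_gromov_wasserstein2`, with :math:`M = d(F, \overline{F}_k)` the pairwise feature distance matrix):
+
+    .. math::
+        FGW_{\alpha}(C, F, h, \overline{C}_k, \overline{F}_k, \overline{h}_k) = \min_{T \in \Pi(h, \overline{h}_k)} \ (1 - \alpha) \langle T, M \rangle_F + \alpha \sum_{i, j, l, m} (C_{ij} - \overline{C}_{k, lm})^2 \, T_{il} \, T_{jm}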
+ "Template based graph neural network with optimal transport distances" + """ + + def __init__(self, n_features, n_tplt=2, n_tplt_nodes=2, alpha=None, train_node_weights=True, multi_alpha=False, feature_init_mean=0., feature_init_std=1.): + """ + Template Fused Gromov-Wasserstein (TFGW) layer. This layer is a pooling layer for graph neural networks. + Computes the fused Gromov-Wasserstein distances between the graph and a set of templates. + + + .. math:: + TFGW_{\overline{\mathcal{G}},\alpha}(C,F,h)=[FGW_{\alpha}(C,F,h,\overline{C}_k,\overline{F}_k,\overline{h}_k)]_{k=1}^{K} + + where : + + - :math:`\mathcal{G}=\{(\overline{C}_k,\overline{F}_k,\overline{h}_k) \}_{k \in \{1,...,K \}} }` is the set of :math:`K` templates charactersised by their adjacency matrices :math:`\overline{C}_k`, their feature matrices :math:`\overline{F}_k` and their node weights :math:`\overline{h}_k`. + - :math:`C`, :math:`F` and :math:`h` are respectively the adjacency matrix, the feature matrix and the node weights of the graph. + - :math:`\alpha` is the trade-off parameter between features and structure for the Fused Gromov-Wasserstein distance. + + + Parameters + ---------- + n_features : int + Feature dimension of the nodes. + n_tplt : int + Number of graph templates. + n_tplt_nodes : int + Number of nodes in each template. + alpha : float, optional + FGW trade-off parameter (0 < alpha < 1). If None alpha is trained, else it is fixed at the given value. + Weights features (alpha=0) and structure (alpha=1). + train_node_weights : bool, optional + If True, the templates node weights are learned. + Else, they are uniform. + multi_alpha: bool, optional + If True, the alpha parameter is a vector of size n_tplt. + feature_init_mean: float, optional + Mean of the random normal law to initialize the template features. + feature_init_std: float, optional + Standard deviation of the random normal law to initialize the template features. + + + References + ---------- + .. [53] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Template based graph neural network with optimal transport distances" + + """ + super().__init__() + + self.n_tplt = n_tplt + self.n_tplt_nodes = n_tplt_nodes + self.n_features = n_features + self.multi_alpha = multi_alpha + self.feature_init_mean = feature_init_mean + self.feature_init_std = feature_init_std + + tplt_adjacencies, tplt_features, self.q0 = TFGW_template_initialization(self.n_tplt, self.n_tplt_nodes, self.n_features, self.feature_init_mean, self.feature_init_std) + self.tplt_adjacencies = nn.Parameter(tplt_adjacencies) + self.tplt_features = nn.Parameter(tplt_features) + + self.softmax = nn.Softmax(dim=1) + + if train_node_weights: + self.q0 = nn.Parameter(self.q0) + + if alpha is None: + if multi_alpha: + self.alpha0 = torch.Tensor([0] * self.n_tplt) + else: + alpha0 = torch.Tensor([0]) + self.alpha0 = nn.Parameter(alpha0) + else: + if multi_alpha: + self.alpha0 = torch.Tensor([alpha] * self.n_tplt) + else: + self.alpha0 = torch.Tensor([alpha]) + self.alpha0 = torch.logit(alpha0) + + def forward(self, x, edge_index, batch=None): + """ + Parameters + ---------- + x : torch.Tensor + Node features. + edge_index : torch.Tensor + Edge indices. + batch : torch.Tensor, optional + Batch vector which assigns each node to its graph. 
+ """ + alpha = torch.sigmoid(self.alpha0) + q = self.softmax(self.q0) + x = FGW_distance_to_templates(edge_index, self.tplt_adjacencies, x, self.tplt_features, q, alpha, self.multi_alpha, batch) + return x + + +class TWPooling(nn.Module): + r""" + Template Wasserstein (TW) layer, also kown as OT-GNN layer. This layer is a pooling layer for graph neural networks. + Computes the Wasserstein distances between the features of the graph features and a set of templates. + + .. math:: + TW_{\overline{\mathcal{G}}}(C,F,h)=[W(F,h,\overline{F}_k,\overline{h}_k)]_{k=1}^{K} + + where : + + - :math:`\mathcal{G}=\{(\overline{F}_k,\overline{h}_k) \}_{k \in \{1,...,K \}} \}` is the set of :math:`K` templates charactersised by their feature matrices :math:`\overline{F}_k` and their node weights :math:`\overline{h}_k`. + - :math:`F` and :math:`h` are respectively the feature matrix and the node weights of the graph. + + Parameters + ---------- + n_features : int + Feature dimension of the nodes. + n_tplt : int + Number of graph templates. + n_tplt_nodes : int + Number of nodes in each template. + train_node_weights : bool, optional + If True, the templates node weights are learned. + Else, they are uniform. + feature_init_mean: float, optional + Mean of the random normal law to initialize the template features. + feature_init_std: float, optional + Standard deviation of the random normal law to initialize the template features. + + References + ---------- + .. [54] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020). [Optimal transport graph neural networks] + + """ + + def __init__(self, n_features, n_tplt=2, n_tplt_nodes=2, train_node_weights=True, feature_init_mean=0., feature_init_std=1.): + """ + Template Wasserstein (TW) layer, also kown as OT-GNN layer. This layer is a pooling layer for graph neural networks. + Computes the Wasserstein distances between the features of the graph features and a set of templates. + + .. math:: + TW_{\overline{\mathcal{G}}}(C,F,h)=[W(F,h,\overline{F}_k,\overline{h}_k)]_{k=1}^{K} + + where : + + - :math:`\mathcal{G}=\{(\overline{F}_k,\overline{h}_k) \}_{k \in \llbracket 1;K \rrbracket }` is the set of :math:`K` templates charactersised by their feature matrices :math:`\overline{F}_k` and their node weights :math:`\overline{h}_k`. + - :math:`F` and :math:`h` are respectively the feature matrix and the node weights of the graph. + + Parameters + ---------- + n_features : int + Feature dimension of the nodes. + n_tplt : int + Number of graph templates. + n_tplt_nodes : int + Number of nodes in each template. + train_node_weights : bool, optional + If True, the templates node weights are learned. + Else, they are uniform. + feature_init_mean: float, optional + Mean of the random normal law to initialize the template features. + feature_init_std: float, optional + Standard deviation of the random normal law to initialize the template features. + + References + ---------- + .. [54] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020). 
+
+    Parameters
+    ----------
+    n_features : int
+        Feature dimension of the nodes.
+    n_tplt : int
+        Number of graph templates.
+    n_tplt_nodes : int
+        Number of nodes in each template.
+    train_node_weights : bool, optional
+        If True, the template node weights are learned.
+        Else, they are uniform.
+    feature_init_mean : float, optional
+        Mean of the random normal law used to initialize the template features.
+    feature_init_std : float, optional
+        Standard deviation of the random normal law used to initialize the template features.
+
+    References
+    ----------
+    .. [54] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020).
+            "Optimal transport graph neural networks".
+    """
+
+    def __init__(self, n_features, n_tplt=2, n_tplt_nodes=2, train_node_weights=True, feature_init_mean=0., feature_init_std=1.):
+        r"""
+        Template Wasserstein (TW) layer, also known as OT-GNN layer. This layer is a pooling layer for graph neural networks.
+        It computes the Wasserstein distances between the features of the graph and a set of templates:
+
+        .. math::
+            TW_{\overline{\mathcal{G}}}(C, F, h) = [W(F, h, \overline{F}_k, \overline{h}_k)]_{k=1}^{K}
+
+        where :
+
+        - :math:`\overline{\mathcal{G}} = \{(\overline{F}_k, \overline{h}_k)\}_{k \in \{1, ..., K\}}` is the set of :math:`K` templates characterized by their feature matrices :math:`\overline{F}_k` and their node weights :math:`\overline{h}_k`.
+        - :math:`F` and :math:`h` are respectively the feature matrix and the node weights of the graph.
+
+        Parameters
+        ----------
+        n_features : int
+            Feature dimension of the nodes.
+        n_tplt : int
+            Number of graph templates.
+        n_tplt_nodes : int
+            Number of nodes in each template.
+        train_node_weights : bool, optional
+            If True, the template node weights are learned.
+            Else, they are uniform.
+        feature_init_mean : float, optional
+            Mean of the random normal law used to initialize the template features.
+        feature_init_std : float, optional
+            Standard deviation of the random normal law used to initialize the template features.
+
+        References
+        ----------
+        .. [54] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020).
+                "Optimal transport graph neural networks".
+        """
+        super().__init__()
+
+        self.n_tplt = n_tplt
+        self.n_tplt_nodes = n_tplt_nodes
+        self.n_features = n_features
+        self.feature_init_mean = feature_init_mean
+        self.feature_init_std = feature_init_std
+
+        _, tplt_features, self.q0 = TFGW_template_initialization(self.n_tplt, self.n_tplt_nodes, self.n_features, self.feature_init_mean, self.feature_init_std)
+        self.tplt_features = nn.Parameter(tplt_features)
+
+        self.softmax = nn.Softmax(dim=1)
+
+        if train_node_weights:
+            self.q0 = nn.Parameter(self.q0)
+
+    def forward(self, x, edge_index=None, batch=None):
+        """
+        Parameters
+        ----------
+        x : torch.Tensor
+            Node features.
+        edge_index : torch.Tensor, optional
+            Edge indices (unused, kept for API compatibility with TFGWPooling).
+        batch : torch.Tensor, optional
+            Batch vector which assigns each node to its graph.
+        """
+        q = self.softmax(self.q0)
+        x = wasserstein_distance_to_templates(x, self.tplt_features, q, batch)
+        return x
diff --git a/ot/gnn/_utils.py b/ot/gnn/_utils.py
new file mode 100644
index 000000000..18e32f627
--- /dev/null
+++ b/ot/gnn/_utils.py
@@ -0,0 +1,238 @@
+# -*- coding: utf-8 -*-
+"""
+GNN layers utils
+"""
+
+# Author: Sonia Mazelet
+#         Rémi Flamary
+#
+# License: MIT License
+
+import torch
+from ..utils import dist
+from ..gromov import fused_gromov_wasserstein2
+from ..lp import emd2
+from torch_geometric.utils import subgraph
+
+
+def TFGW_template_initialization(n_tplt, n_tplt_nodes, n_features, feature_init_mean=0., feature_init_std=1.):
+    """
+    Initializes templates for the Template Fused Gromov-Wasserstein layer.
+    Returns the adjacency matrices and the node features of the templates.
+    Adjacency matrices are initialized uniformly with values in :math:`[0,1]`.
+    Node features are initialized following a normal distribution.
+
+    Parameters
+    ----------
+    n_tplt : int
+        Number of templates.
+    n_tplt_nodes : int
+        Number of nodes per template.
+    n_features : int
+        Number of features for the nodes.
+    feature_init_mean : float, optional
+        Mean of the random normal law used to initialize the template features.
+    feature_init_std : float, optional
+        Standard deviation of the random normal law used to initialize the template features.
+
+    Returns
+    -------
+    tplt_adjacencies : torch.Tensor, shape (n_templates, n_template_nodes, n_template_nodes)
+        Adjacency matrices of the templates.
+    tplt_features : torch.Tensor, shape (n_templates, n_template_nodes, n_features)
+        Node features for each template.
+    q : torch.Tensor, shape (n_templates, n_template_nodes)
+        Weights on the template nodes.
+    """
+
+    tplt_adjacencies = torch.rand((n_tplt, n_tplt_nodes, n_tplt_nodes))
+    tplt_features = torch.Tensor(n_tplt, n_tplt_nodes, n_features)
+
+    torch.nn.init.normal_(tplt_features, mean=feature_init_mean, std=feature_init_std)
+
+    q = torch.zeros(n_tplt, n_tplt_nodes)
+
+    #symmetrize the adjacency matrices
+    tplt_adjacencies = 0.5 * (tplt_adjacencies + torch.transpose(tplt_adjacencies, 1, 2))
+
+    return tplt_adjacencies, tplt_features, q
+
+
+def FGW_distance_to_templates(G_edges, tplt_adjacencies, G_features, tplt_features, tplt_weights, alpha=0.5, multi_alpha=False, batch=None):
+    """
+    Computes the FGW distances between a graph and templates.
+
+    Parameters
+    ----------
+    G_edges : torch.Tensor, shape (2, n_edges)
+        Edge indices of the graph in the PyTorch Geometric format.
+    tplt_adjacencies : torch.Tensor, shape (n_templates, n_template_nodes, n_template_nodes)
+        Adjacency matrices of the templates.
+    G_features : torch.Tensor, shape (n_nodes, n_features)
+        Graph node features.
+    tplt_features : torch.Tensor, shape (n_templates, n_template_nodes, n_features)
+        Node features of the templates.
+    tplt_weights : torch.Tensor, shape (n_templates, n_template_nodes)
+        Weights on the nodes of the templates.
+    alpha : float, optional
+        Trade-off parameter (0 < alpha < 1).
+        It trades off features (alpha=0) against structure (alpha=1).
+    multi_alpha : bool, optional
+        If True, the alpha parameter is a vector of size n_templates.
+    batch : torch.Tensor, optional
+        Batch vector which assigns each node to its graph.
+
+    Returns
+    -------
+    distances : torch.Tensor, shape (n_templates) if batch=None, else shape (n_graphs, n_templates)
+        Vector of fused Gromov-Wasserstein distances between the graph and the templates.
+    """
+
+    if batch is None:
+
+        n, n_feat = G_features.shape
+        n_T, _, n_feat_T = tplt_features.shape
+
+        weights_G = torch.ones(n) / n
+
+        C = torch.sparse_coo_tensor(G_edges, torch.ones(len(G_edges[0])), size=(n, n)).type(torch.float)
+        C = C.to_dense()
+
+        if not n_feat == n_feat_T:
+            raise ValueError('The templates and the graphs must have the same feature dimension.')
+
+        distances = torch.zeros(n_T)
+
+        for j in range(n_T):
+
+            template_features = tplt_features[j].reshape(len(tplt_features[j]), n_feat_T)
+            M = dist(G_features, template_features).type(torch.float)
+
+            #if alpha is zero the emd distance is used
+            if multi_alpha and torch.any(alpha > 0):
+                embedding = fused_gromov_wasserstein2(M, C, tplt_adjacencies[j], weights_G, tplt_weights[j], alpha=alpha[j], symmetric=True, max_iter=50)
+            elif not multi_alpha and torch.all(alpha == 0):
+                embedding = emd2(weights_G, tplt_weights[j], M, numItermax=50)
+            elif not multi_alpha and alpha > 0:
+                embedding = fused_gromov_wasserstein2(M, C, tplt_adjacencies[j], weights_G, tplt_weights[j], alpha=alpha, symmetric=True, max_iter=50)
+            else:
+                embedding = emd2(weights_G, tplt_weights[j], M, numItermax=50)
+
+            distances[j] = embedding
+
+    else:
+
+        n_T, _, n_feat_T = tplt_features.shape
+
+        num_graphs = torch.max(batch) + 1
+        distances = torch.zeros(num_graphs, n_T)
+
+        #iterate over the graphs in the batch
+        for i in range(num_graphs):
+
+            nodes = torch.where(batch == i)[0]
+
+            G_edges_i, _ = subgraph(nodes, edge_index=G_edges, relabel_nodes=True)
+            G_features_i = G_features[nodes]
+
+            n, n_feat = G_features_i.shape
+
+            weights_G = torch.ones(n) / n
+
+            n_edges = len(G_edges_i[0])
+
+            C = torch.sparse_coo_tensor(G_edges_i, torch.ones(n_edges), size=(n, n)).type(torch.float)
+            C = C.to_dense()
+
+            if not n_feat == n_feat_T:
+                raise ValueError('The templates and the graphs must have the same feature dimension.')
+
+            for j in range(n_T):
+
+                template_features = tplt_features[j].reshape(len(tplt_features[j]), n_feat_T)
+                M = dist(G_features_i, template_features).type(torch.float)
+
+                #if alpha is zero the emd distance is used
+                if multi_alpha and torch.any(alpha > 0):
+                    embedding = fused_gromov_wasserstein2(M, C, tplt_adjacencies[j], weights_G, tplt_weights[j], alpha=alpha[j], symmetric=True, max_iter=50)
+                elif not multi_alpha and torch.all(alpha == 0):
+                    embedding = emd2(weights_G, tplt_weights[j], M, numItermax=50)
+                elif not multi_alpha and alpha > 0:
+                    embedding = fused_gromov_wasserstein2(M, C, tplt_adjacencies[j], weights_G, tplt_weights[j], alpha=alpha, symmetric=True, max_iter=50)
+                else:
+                    embedding = emd2(weights_G, tplt_weights[j], M, numItermax=50)
+
+                distances[i, j] = embedding
+
+    return distances
+
+
+def wasserstein_distance_to_templates(G_features, tplt_features, tplt_weights, batch=None):
+    """
+    Computes the Wasserstein distances between a graph and templates.
+
+    Parameters
+    ----------
+    G_features : torch.Tensor, shape (n_nodes, n_features)
+        Node features of the graph.
+    tplt_features : torch.Tensor, shape (n_templates, n_template_nodes, n_features)
+        Node features of the templates.
+    tplt_weights : torch.Tensor, shape (n_templates, n_template_nodes)
+        Weights on the nodes of the templates.
+    batch : torch.Tensor, optional
+        Batch vector which assigns each node to its graph.
+
+    Returns
+    -------
+    distances : torch.Tensor, shape (n_templates) if batch=None, else shape (n_graphs, n_templates)
+        Vector of Wasserstein distances between the graph and the templates.
+    """
+
+    if batch is None:
+
+        n, n_feat = G_features.shape
+        n_T, _, n_feat_T = tplt_features.shape
+
+        weights_G = torch.ones(n) / n
+
+        if not n_feat == n_feat_T:
+            raise ValueError('The templates and the graphs must have the same feature dimension.')
+
+        distances = torch.zeros(n_T)
+
+        for j in range(n_T):
+
+            template_features = tplt_features[j].reshape(len(tplt_features[j]), n_feat_T)
+            M = dist(G_features, template_features).type(torch.float)
+
+            distances[j] = emd2(weights_G, tplt_weights[j], M, numItermax=50)
+
+    else:
+
+        n_T, _, n_feat_T = tplt_features.shape
+
+        num_graphs = torch.max(batch) + 1
+        distances = torch.zeros(num_graphs, n_T)
+
+        #iterate over the graphs in the batch
+        for i in range(num_graphs):
+
+            nodes = torch.where(batch == i)[0]
+
+            G_features_i = G_features[nodes]
+
+            n, n_feat = G_features_i.shape
+
+            weights_G = torch.ones(n) / n
+
+            if not n_feat == n_feat_T:
+                raise ValueError('The templates and the graphs must have the same feature dimension.')
+
+            for j in range(n_T):
+
+                template_features = tplt_features[j].reshape(len(tplt_features[j]), n_feat_T)
+                M = dist(G_features_i, template_features).type(torch.float)
+
+                distances[i, j] = emd2(weights_G, tplt_weights[j], M, numItermax=50)
+
+    return distances
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 4952a2183..3641d7817 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -327,7 +327,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
     # ensure that same mass
     np.testing.assert_almost_equal(a.sum(0),
-                                   b.sum(0), err_msg='a and b vector must have the same sum')
+                                   b.sum(0), err_msg='a and b vector must have the same sum',
+                                   decimal=6)
     b = b * a.sum() / b.sum()
 
     asel = a != 0
diff --git a/requirements.txt b/requirements.txt
index 9be4deb42..f96e89285 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,4 +9,5 @@ torch
 jax
 jaxlib
 tensorflow
-pytest
\ No newline at end of file
+pytest
+torch_geometric
\ No newline at end of file
diff --git a/test/test_gnn.py b/test/test_gnn.py
new file mode 100644
index 000000000..c69170123
--- /dev/null
+++ b/test/test_gnn.py
@@ -0,0 +1,162 @@
+"""Tests for module gnn"""
+
+# Author: Sonia Mazelet
+#         Rémi Flamary
+#
+# License: MIT License
+
+
+import pytest
+
+try:  # test if pytorch_geometric is installed
+    import torch_geometric
+except ImportError:
+    torch_geometric = False
+
+
+@pytest.mark.skipif(not torch_geometric, reason="pytorch_geometric not installed")
+def test_TFGW():
+    # Test the TFGW layer by passing two graphs through the layer and doing backpropagation.
+
+    import torch
+    from torch_geometric.nn import Linear
+    from torch_geometric.data import Data as GraphData
+    from torch_geometric.loader import DataLoader
+    import torch.nn as nn
+    from ot.gnn import TFGWPooling
+
+    class pooling_TFGW(nn.Module):
+        """
+        Pooling architecture using the TFGW layer.
+        """
+
+        def __init__(self, n_features, n_templates, n_template_nodes):
+            """
+            Pooling architecture using the TFGW layer.
+            """
+            super().__init__()
+
+            self.n_features = n_features
+            self.n_templates = n_templates
+            self.n_template_nodes = n_template_nodes
+
+            self.TFGW = TFGWPooling(self.n_features, self.n_templates, self.n_template_nodes)
+
+            self.linear = Linear(self.n_templates, 1)
+
+        def forward(self, x, edge_index):
+
+            x = self.TFGW(x, edge_index)
+
+            x = self.linear(x)
+
+            return x
+
+    n_templates = 3
+    n_template_nodes = 3
+    n_nodes = 10
+    n_features = 3
+    n_epochs = 3
+
+    C1 = torch.randint(0, 2, size=(n_nodes, n_nodes))
+    C2 = torch.randint(0, 2, size=(n_nodes, n_nodes))
+
+    edge_index1 = torch.stack(torch.where(C1 == 1))
+    edge_index2 = torch.stack(torch.where(C2 == 1))
+
+    x1 = torch.rand(n_nodes, n_features)
+    x2 = torch.rand(n_nodes, n_features)
+
+    graph1 = GraphData(x=x1, edge_index=edge_index1, y=torch.tensor([0.]))
+    graph2 = GraphData(x=x2, edge_index=edge_index2, y=torch.tensor([1.]))
+
+    dataset = DataLoader([graph1, graph2], batch_size=1)
+
+    model_FGW = pooling_TFGW(n_features, n_templates, n_template_nodes)
+
+    optimizer = torch.optim.Adam(model_FGW.parameters(), lr=0.01)
+    criterion = torch.nn.CrossEntropyLoss()
+
+    model_FGW.train()
+
+    for i in range(n_epochs):
+        for data in dataset:
+
+            optimizer.zero_grad()
+            out = model_FGW(data.x, data.edge_index)
+            loss = criterion(out, data.y)
+            loss.backward()
+            optimizer.step()
+
+
+@pytest.mark.skipif(not torch_geometric, reason="pytorch_geometric not installed")
+def test_TW():
+    # Test the TW layer by passing two graphs through the layer and doing backpropagation.
+
+    import torch
+    from torch_geometric.nn import Linear
+    from torch_geometric.data import Data as GraphData
+    from torch_geometric.loader import DataLoader
+    import torch.nn as nn
+    from ot.gnn import TWPooling
+
+    class pooling_TW(nn.Module):
+        """
+        Pooling architecture using the TW layer.
+        """
+
+        def __init__(self, n_features, n_templates, n_template_nodes):
+            """
+            Pooling architecture using the TW layer.
+            """
+            super().__init__()
+
+            self.n_features = n_features
+            self.n_templates = n_templates
+            self.n_template_nodes = n_template_nodes
+
+            self.TW = TWPooling(self.n_features, self.n_templates, self.n_template_nodes)
+
+            self.linear = Linear(self.n_templates, 1)
+
+        def forward(self, x, edge_index):
+
+            x = self.TW(x, edge_index)
+
+            x = self.linear(x)
+
+            return x
+
+    n_templates = 3
+    n_template_nodes = 3
+    n_nodes = 10
+    n_features = 3
+    n_epochs = 3
+
+    C1 = torch.randint(0, 2, size=(n_nodes, n_nodes))
+    C2 = torch.randint(0, 2, size=(n_nodes, n_nodes))
+
+    edge_index1 = torch.stack(torch.where(C1 == 1))
+    edge_index2 = torch.stack(torch.where(C2 == 1))
+
+    x1 = torch.rand(n_nodes, n_features)
+    x2 = torch.rand(n_nodes, n_features)
+
+    graph1 = GraphData(x=x1, edge_index=edge_index1, y=torch.tensor([0.]))
+    graph2 = GraphData(x=x2, edge_index=edge_index2, y=torch.tensor([1.]))
+
+    dataset = DataLoader([graph1, graph2], batch_size=1)
+
+    model_W = pooling_TW(n_features, n_templates, n_template_nodes)
+
+    optimizer = torch.optim.Adam(model_W.parameters(), lr=0.01)
+    criterion = torch.nn.CrossEntropyLoss()
+
+    model_W.train()
+
+    for i in range(n_epochs):
+        for data in dataset:
+
+            optimizer.zero_grad()
+            out = model_W(data.x, data.edge_index)
+            loss = criterion(out, data.y)
+            loss.backward()
+            optimizer.step()
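+
+
+# The test below is an illustrative editorial addition (not part of the original
+# patch): it calls the distance functions exposed by ot.gnn directly and checks
+# the output shapes documented in ot/gnn/_utils.py.
+@pytest.mark.skipif(not torch_geometric, reason="pytorch_geometric not installed")
+def test_distance_to_templates_shapes():
+
+    import torch
+    from ot.gnn import FGW_distance_to_templates, wasserstein_distance_to_templates
+
+    n_nodes, n_features = 10, 3
+    n_tplt, n_tplt_nodes = 2, 4
+
+    # random graph in the PyTorch Geometric format and random templates
+    C = torch.randint(0, 2, size=(n_nodes, n_nodes))
+    edge_index = torch.stack(torch.where(C == 1))
+    x = torch.rand(n_nodes, n_features)
+
+    tplt_adjacencies = torch.rand(n_tplt, n_tplt_nodes, n_tplt_nodes)
+    tplt_features = torch.rand(n_tplt, n_tplt_nodes, n_features)
+    tplt_weights = torch.ones(n_tplt, n_tplt_nodes) / n_tplt_nodes
+
+    # one FGW / Wasserstein distance per template
+    dist_fgw = FGW_distance_to_templates(edge_index, tplt_adjacencies, x, tplt_features, tplt_weights, alpha=0.5)
+    dist_w = wasserstein_distance_to_templates(x, tplt_features, tplt_weights)
+
+    assert dist_fgw.shape == (n_tplt,)
+    assert dist_w.shape == (n_tplt,)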