Conversation
@SoniaMaz8 SoniaMaz8 commented Jul 7, 2023

Types of changes

I introduced a new module ot.gnn implementing the Template Fused Gromov-Wasserstein (TFGW) layer (Vincent-Cuaz, C., Flamary, R., Corneli, M., Vayer, T., & Courty, N. (2022). Template-based graph neural network with optimal transport distances. Advances in Neural Information Processing Systems, 35, 11800-11814). This layer is a pooling layer for GNNs that builds a graph embedding as the vector of FGW distances between the graph and a set of learned graph templates.

Motivation and context / Related issue

How has this been tested (if it applies)

I created a new test which constructs two small graphs, passes them through the layer, and performs backpropagation.
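The test described above could follow a pattern like this (a hedged sketch: a plain squared-error cost stands in for the real FGW distance so the example stays dependency-free, and all names and shapes are assumptions, not the PR's test code):

```python
import torch
import torch.nn as nn

n_templates, n_template_nodes, n_features = 2, 3, 4

# learned templates: structure matrices and node features (stand-ins)
C_t = nn.Parameter(torch.rand(n_templates, n_template_nodes, n_template_nodes))
F_t = nn.Parameter(torch.randn(n_templates, n_template_nodes, n_features))

def toy_pool(C, F):
    """Simplified pooling: one 'distance' per template.

    A squared-error cost replaces FGW here purely to keep the sketch
    self-contained; the real layer calls a differentiable FGW solver.
    """
    d_struct = ((C.unsqueeze(0) - C_t) ** 2).sum(dim=(1, 2))
    d_feat = ((F.unsqueeze(0) - F_t) ** 2).sum(dim=(1, 2))
    return d_struct + d_feat  # shape: (n_templates,)

# two small graphs with the same node count as the templates
for _ in range(2):
    C = torch.rand(n_template_nodes, n_template_nodes)
    F = torch.randn(n_template_nodes, n_features)
    loss = toy_pool(C, F).sum()
    loss.backward()  # gradients flow back into the templates

assert C_t.grad is not None and F_t.grad is not None
```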

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@rflamary rflamary changed the title Template Fused Gromov Wasserstein layer [WIP] Template Fused Gromov Wasserstein layer Jul 7, 2023
@rflamary rflamary marked this pull request as ready for review July 7, 2023 13:05
@rflamary rflamary left a comment

Thanks @SoniaMaz8, here are a few comments.

Please also add a quick example of the use of TFGW in the doc.

n_templates_nodes : int
Number of nodes in each template.
alpha0 : float, optional
Trade-off parameter (0 < alpha < 1). If None, alpha is trained; otherwise it is fixed at the given value.

FGW trade-off parameter, weighting features (alpha=0) against structures (alpha=1).

"""
Template Fused Gromov-Wasserstein (TFGW) layer. This layer is a pooling layer for graph neural networks.
It computes the fused Gromov-Wasserstein distances between the graph and a set of templates.

add a warning that it is the logit of alpha that is optimized (to avoid constrained optimization)

self.q0 = nn.Parameter(self.q0)

if alpha0 is None:
alpha0 = torch.Tensor([0])
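The unconstrained-optimization trick flagged above can be illustrated with a minimal NumPy sketch (an assumption: the review comment implies a sigmoid reparametrization, under which the zero init above corresponds to alpha = 0.5):

```python
import numpy as np

def sigmoid(x):
    """Map an unconstrained logit to the open interval (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))

# the raw parameter (logit) is what the optimizer updates, with no bounds...
logit_alpha = 0.0  # mirrors the torch.Tensor([0]) default init above
# ...while the effective trade-off always stays strictly inside (0, 1)
alpha = sigmoid(logit_alpha)
print(alpha)  # 0.5

for updated_logit in (-10.0, 0.3, 25.0):  # any unconstrained step is safe
    assert 0.0 < sigmoid(updated_logit) < 1.0
```

This is why no projection or clipping step is needed during training: the constraint 0 < alpha < 1 is enforced by the parametrization itself.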

handle alpha as array (add option for that)

ot/gnn/_utils.py Outdated
from ..gromov import fused_gromov_wasserstein2


def template_initialisation(n_templates, n_template_nodes, n_features, feature_init_mean=0., feature_init_std=1.):

TFGW_template_initialisation

test/test_gnn.py Outdated
Pooling architecture using the LTFGW layer.
Parameters
----------

can remove doc here

@rflamary rflamary changed the title [WIP] Template Fused Gromov Wasserstein layer [MRG] Template Fused Gromov Wasserstein layer Jul 18, 2023
@rflamary rflamary changed the title [MRG] Template Fused Gromov Wasserstein layer [FEAT] Template Fused Gromov Wasserstein layer Jul 18, 2023
@rflamary rflamary left a comment

Hello @SoniaMaz8,

The PR is looking very good. I have a few comments below and some higher-level points that need to be addressed before merging.

  • Add gnn to the docs/source/all.rst file so that the sub-module is properly documented (it might show some wrong formatting that needs to be corrected).
  • Express what is computed in math form when describing the pooling layer; it will help users understand what is done without needing to read the papers (see ot.emd2 as an example of how to add math in the doc format).
  • Add an item to the itemized list at the top of the README, such as: Graph Neural Network OT layers TFGW [52] and TW (OT-GNN) [53]
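The math-form request above could be rendered along these lines (a sketch following the notation of the TFGW paper; the symbols are assumptions, not the PR's final docstring):

```latex
v(G) = \left[ \mathrm{FGW}_{\alpha}\!\big( (C, F, h),\,
       (\overline{C}_k, \overline{F}_k, \overline{h}_k) \big) \right]_{k=1}^{K}
```

where \((C, F, h)\) are the structure matrix, node features, and node weights of the input graph \(G\), the barred quantities are those of the \(K\) learned templates, and \(\alpha\) is the FGW trade-off parameter.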

References
----------
.. [53] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020). [Optimal transport graph neural networks](https://arxiv.org/pdf/2006.04804)

bad formatting when compiled in doc, you can remove the link.



@pytest.mark.skipif(not torch_geometric, reason="pytorch_geometric not installed")
def test_TFGW():

split this into two tests please

@rflamary rflamary merged commit 88a1fb1 into PythonOT:master Jul 19, 2023