[FEAT] Template Fused Gromov Wasserstein layer #488
Conversation
Thanks @SoniaMaz8, here are a few comments.
Please also add a quick example of the use of TFGW in the doc.
ot/gnn/_layers.py
n_templates_nodes : int
    Number of nodes in each template.
alpha0 : float, optional
    Trade-off parameter (0 < alpha < 1). If None alpha is trained, else it is fixed at the given value.
FGW trade-off parameter: it weights features (alpha=0) and structures (alpha=1).
""" | ||
Template Fused Gromov-Wasserstein (TFGW) layer. This layer is a pooling layer for graph neural networks. | ||
It computes the fused Gromov-Wasserstein distances between the graph and a set of templates. | ||
Add a warning that it is the logit of alpha that is optimized (to avoid constrained optimization).
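The reparametrization the reviewer refers to can be sketched as follows (an illustrative sketch, not the PR's code): optimize an unconstrained real parameter `a` and recover `alpha = sigmoid(a)`, so plain gradient descent keeps alpha in (0, 1) without any explicit constraint:

```python
import numpy as np

def sigmoid(a):
    """Map an unconstrained real value to (0, 1)."""
    return 1.0 / (1.0 + np.exp(-a))

# Unconstrained parameter: this is what the optimizer actually updates.
a = 0.0

# alpha is recovered through the sigmoid, so it always lies in (0, 1)
# and no constrained optimization is needed.
alpha = sigmoid(a)
```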
ot/gnn/_layers.py
self.q0 = nn.Parameter(self.q0)

if alpha0 is None:
    alpha0 = torch.Tensor([0])
Handle alpha as an array (add an option for that).
ot/gnn/_utils.py
from ..gromov import fused_gromov_wasserstein2


def template_initialisation(n_templates, n_template_nodes, n_features, feature_init_mean=0., feature_init_std=1.):
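A hypothetical body for such an initializer (a sketch of what it might do, not the PR's actual implementation) could draw random symmetric template structures, Gaussian node features, and uniform node weights:

```python
import numpy as np

def template_initialisation(n_templates, n_template_nodes, n_features,
                            feature_init_mean=0., feature_init_std=1.):
    """Sketch: random templates for a TFGW layer (hypothetical implementation)."""
    rng = np.random.default_rng(0)
    # Random symmetric structure matrices with entries in [0, 1].
    C = rng.random((n_templates, n_template_nodes, n_template_nodes))
    C = 0.5 * (C + np.transpose(C, (0, 2, 1)))
    # Gaussian node features with the requested mean and standard deviation.
    F = feature_init_mean + feature_init_std * rng.standard_normal(
        (n_templates, n_template_nodes, n_features))
    # Uniform weights over the nodes of each template.
    q = np.full((n_templates, n_template_nodes), 1.0 / n_template_nodes)
    return C, F, q
```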
TFGW_template_initialisation
test/test_gnn.py
Pooling architecture using the LTFGW layer.

Parameters
----------
can remove doc here
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Hello @SoniaMaz8,
The PR is looking very good. I have a few comments below and some higher-level things that need to be done before merging.
- Add gnn in the docs/source/all.rst file so that the sub-module is properly documented (it might show some wrong formatting that needs to be corrected).
- Express what is computed in math form when describing the pooling layer; it will help users understand what is done without needing to read the papers (see ot.emd2 as an example of how to add math in the doc format).
- Add an item to the itemize at the top of the README, such as: Graph Neural Network OT layers TFGW [52] and TW (OT-GNN) [53].
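The math the reviewer asks for might look like the following (a sketch under the usual TFGW/FGW definitions; the notation is illustrative, not the PR's):

```latex
% TFGW embedding of a graph G = (C, F, h): the vector of FGW distances
% to K learned templates \overline{G}_k
\mathrm{TFGW}(G) = \big( \mathrm{FGW}_{\alpha}(G, \overline{G}_1),
                         \dots,
                         \mathrm{FGW}_{\alpha}(G, \overline{G}_K) \big)

% with, for a template \overline{G} = (\overline{C}, \overline{F}, \overline{h}),
\mathrm{FGW}_{\alpha}(G, \overline{G})
  = \min_{T \in \Pi(h, \overline{h})}
      (1 - \alpha) \, \langle M, T \rangle
      + \alpha \sum_{i,j,k,l} \big| C_{ik} - \overline{C}_{jl} \big|^2 \, T_{ij} T_{kl},
\quad M_{ij} = d(F_i, \overline{F}_j)^2
```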
ot/gnn/_layers.py
References
----------
.. [53] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020). [Optimal transport graph neural networks](https://arxiv.org/pdf/2006.04804)
bad formatting when compiled in doc, you can remove the link.
@pytest.mark.skipif(not torch_geometric, reason="pytorch_geometric not installed")
def test_TFGW():
Split this into two tests, please.
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Types of changes
I introduced a new module ot.gnn in which the Template Fused Gromov Wasserstein layer (Vincent-Cuaz, C., Flamary, R., Corneli, M., Vayer, T., & Courty, N. (2022). Template-based graph neural network with optimal transport distances. Advances in Neural Information Processing Systems, 35, 11800-11814.) is implemented. This layer is a pooling layer for GNNs that builds graph embeddings as vectors of FGW distances between the graph and learned graph templates.
Motivation and context / Related issue
How has this been tested (if it applies)
I created a new test that constructs two small graphs, passes them through the layer, and performs backpropagation.
PR checklist