Add a normalize parameter to dense_diff_pool (#4847)
* Modified `link_loss` so that normalization can be skipped

* Modified the changelog

* Update CHANGELOG.md

Co-authored-by: Jinu Sunil <jinu.sunil@gmail.com>

* update

Co-authored-by: Chuxuan Hu <chuxuan3@illinois.edu>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
Co-authored-by: Jinu Sunil <jinu.sunil@gmail.com>
4 people authored Jun 23, 2022
1 parent 97c50a0 commit 85cddb3
Showing 2 changed files with 7 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.0.5] - 2022-MM-DD
 ### Added
+- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
 - Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
 - Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
 - Added `GraphStore` support to `Data` and `HeteroData` ([#4816](https://github.com/pyg-team/pytorch_geometric/pull/4816))
8 changes: 6 additions & 2 deletions torch_geometric/nn/dense/diff_pool.py
@@ -3,7 +3,7 @@
 EPS = 1e-15
 
 
-def dense_diff_pool(x, adj, s, mask=None):
+def dense_diff_pool(x, adj, s, mask=None, normalize=True):
     r"""The differentiable pooling operator from the `"Hierarchical Graph
     Representation Learning with Differentiable Pooling"
     <https://arxiv.org/abs/1806.08804>`_ paper
@@ -44,6 +44,9 @@ def dense_diff_pool(x, adj, s, mask=None):
         mask (BoolTensor, optional): Mask matrix
             :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
             the valid nodes for each graph. (default: :obj:`None`)
+        normalize (bool, optional): If set to :obj:`False`, the link
+            prediction loss is not divided by :obj:`adj.numel()`.
+            (default: :obj:`True`)
 
     :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`,
         :class:`Tensor`)
@@ -66,7 +69,8 @@ def dense_diff_pool(x, adj, s, mask=None):
 
     link_loss = adj - torch.matmul(s, s.transpose(1, 2))
     link_loss = torch.norm(link_loss, p=2)
-    link_loss = link_loss / adj.numel()
+    if normalize is True:
+        link_loss = link_loss / adj.numel()
 
     ent_loss = (-s * torch.log(s + EPS)).sum(dim=-1).mean()
 
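To illustrate what the new flag controls, here is a standalone NumPy sketch of the auxiliary link prediction loss. The `link_loss` helper below is hypothetical (it is not the PyG implementation, which operates on PyTorch tensors inside `dense_diff_pool`); it only mirrors the `||A - S S^T||_F` computation and the optional division by the number of adjacency entries that the `normalize` parameter toggles.

```python
import numpy as np

def link_loss(adj, s, normalize=True):
    """Hypothetical sketch of DiffPool's link prediction loss.

    adj: (B, N, N) dense adjacency; s: (B, N, C) soft cluster assignments.
    Computes ||A - S S^T||_F, optionally divided by the number of entries
    in `adj` (the NumPy analogue of the `adj.numel()` normalization).
    """
    diff = adj - s @ s.transpose(0, 2, 1)  # A - S S^T, batched
    loss = np.linalg.norm(diff)            # 2-norm over all entries, like torch.norm(..., p=2)
    if normalize:
        loss = loss / adj.size             # skipped when normalize=False
    return loss

rng = np.random.default_rng(0)
adj = rng.random((2, 4, 4))
s = rng.random((2, 4, 3))

raw = link_loss(adj, s, normalize=False)
norm = link_loss(adj, s, normalize=True)
```

The unnormalized loss differs from the normalized one only by the constant factor `adj.size`, which is why skipping the division can matter when the loss is weighted against other training objectives.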