Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nn.dense_mincut_pool Add temperature parameter to mincut pool operator #5908

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
### Changed
- Support temperature with `float` or `int` in `dense_mincut_pool` ([#5908](https://github.com/pyg-team/pytorch_geometric/pull/5908))
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
- Fixed a bug in which `VirtualNode` mistakenly treated node features as edge features ([#5819](https://github.com/pyg-team/pytorch_geometric/pull/5819))
- Fixed `setter` and `getter` handling in `BaseStorage` ([#5815](https://github.com/pyg-team/pytorch_geometric/pull/5815))
- Fixed `path` in `hetero_conv_dblp.py` example ([#5686](https://github.com/pyg-team/pytorch_geometric/pull/5686))
Expand Down
11 changes: 5 additions & 6 deletions torch_geometric/nn/dense/mincut_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@


def dense_mincut_pool(
x: Tensor,
adj: Tensor,
s: Tensor,
mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
x: Tensor, adj: Tensor, s: Tensor, mask: Optional[Tensor] = None,
temp: float = 1.0) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
r"""The MinCut pooling operator from the `"Spectral Clustering in Graph
Neural Networks for Graph Pooling" <https://arxiv.org/abs/1907.00481>`_
paper
Expand Down Expand Up @@ -54,6 +51,8 @@ def dense_mincut_pool(
mask (BoolTensor, optional): Mask matrix
:math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
the valid nodes for each graph. (default: :obj:`None`)
temp (float): Temperature parameter for softmax
function. (default: :obj:`1.0`)

:rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`,
:class:`Tensor`)
Expand All @@ -65,7 +64,7 @@ def dense_mincut_pool(

(batch_size, num_nodes, _), k = x.size(), s.size(-1)

s = torch.softmax(s, dim=-1)
s = torch.softmax(s / temp, dim=-1)
rusty1s marked this conversation as resolved.
Show resolved Hide resolved

if mask is not None:
mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
Expand Down