Skip to content

Commit

Permalink
[Type Hints] utils.to_dense_adj and utils.get_laplacian (pyg-team…
Browse files Browse the repository at this point in the history
…#5682)

Co-authored-by: juyi.lin <juyi.lin@kaust.edu.sa>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
4 people authored and JakubPietrakIntel committed Nov 25, 2022
1 parent ac1d36d commit 651a619
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 8 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
### Changed
- Fixed `path` in `hetero_conv_dblp.py` example ([#5686](https://github.com/pyg-team/pytorch_geometric/pull/5686))
- Fix `auto_select_device` routine in GraphGym for PyTorch Lightning>=1.7 ([#5677](https://github.com/pyg-team/pytorch_geometric/pull/5677))
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down Expand Up @@ -158,7 +159,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added temporal sampling support to `NeighborLoader` ([#4025](https://github.com/pyg-team/pytorch_geometric/pull/4025))
- Added an example for unsupervised heterogeneous graph learning based on "Deep Multiplex Graph Infomax" ([#3189](https://github.com/pyg-team/pytorch_geometric/pull/3189))
### Changed
- Fixed `path` in `heter_conv_dblp` example ([#5686](https://github.com/pyg-team/pytorch_geometric/pull/5686))
- Changed docstring for `RandomLinkSplit` ([#5190](https://github.com/pyg-team/pytorch_geometric/issues/5190))
- Switched to PyTorch `scatter_reduce` implementation - experimental feature ([#5120](https://github.com/pyg-team/pytorch_geometric/pull/5120))
- Fixed `RGATConv` device mismatches for `f-scaled` mode ([#5187](https://github.com/pyg-team/pytorch_geometric/pull/5187)]
Expand Down
8 changes: 8 additions & 0 deletions test/utils/test_get_laplacian.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch

from torch_geometric.testing import is_full_test
from torch_geometric.utils import get_laplacian


Expand All @@ -11,6 +12,13 @@ def test_get_laplacian():
assert lap[0].tolist() == [[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]]
assert lap[1].tolist() == [-1, -2, -2, -4, 1, 4, 4]

if is_full_test():
jit = torch.jit.script(get_laplacian)
lap = jit(edge_index, edge_weight)
assert lap[0].tolist() == [[0, 1, 1, 2, 0, 1, 2],
[1, 0, 2, 1, 0, 1, 2]]
assert lap[1].tolist() == [-1, -2, -2, -4, 1, 4, 4]

lap_sym = get_laplacian(edge_index, edge_weight, normalization='sym')
assert lap_sym[0].tolist() == lap[0].tolist()
assert lap_sym[1].tolist() == [-0.5, -1, -0.5, -1, 1, 1, 1]
Expand Down
8 changes: 8 additions & 0 deletions test/utils/test_to_dense_adj.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch

from torch_geometric.testing import is_full_test
from torch_geometric.utils import to_dense_adj


Expand All @@ -15,6 +16,13 @@ def test_to_dense_adj():
assert adj[0].tolist() == [[1, 1, 0], [1, 0, 0], [0, 0, 0]]
assert adj[1].tolist() == [[0, 1, 0], [0, 0, 1], [1, 0, 0]]

if is_full_test():
jit = torch.jit.script(to_dense_adj)
adj = jit(edge_index, batch)
assert adj.size() == (2, 3, 3)
assert adj[0].tolist() == [[1, 1, 0], [1, 0, 0], [0, 0, 0]]
assert adj[1].tolist() == [[0, 1, 0], [0, 0, 1], [1, 0, 0]]

adj = to_dense_adj(edge_index, batch, max_num_nodes=5)
assert adj.size() == (2, 5, 5)
assert adj[0][:3, :3].tolist() == [[1, 1, 0], [1, 0, 0], [0, 0, 0]]
Expand Down
15 changes: 10 additions & 5 deletions torch_geometric/utils/get_laplacian.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
from typing import Optional
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch_scatter import scatter_add

from torch_geometric.typing import OptTensor
from torch_geometric.utils import add_self_loops, remove_self_loops

from .num_nodes import maybe_num_nodes


def get_laplacian(edge_index, edge_weight: Optional[torch.Tensor] = None,
normalization: Optional[str] = None,
dtype: Optional[int] = None,
num_nodes: Optional[int] = None):
def get_laplacian(
edge_index: Tensor,
edge_weight: OptTensor = None,
normalization: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
num_nodes: Optional[int] = None,
) -> Tuple[Tensor, OptTensor]:
r""" Computes the graph Laplacian of the graph given by :obj:`edge_index`
and optional :obj:`edge_weight`.
Expand Down
12 changes: 11 additions & 1 deletion torch_geometric/utils/to_dense_adj.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
from typing import Optional

import torch
from torch import Tensor
from torch_scatter import scatter

from torch_geometric.typing import OptTensor


def to_dense_adj(edge_index, batch=None, edge_attr=None, max_num_nodes=None):
def to_dense_adj(
edge_index: Tensor,
batch: OptTensor = None,
edge_attr: OptTensor = None,
max_num_nodes: Optional[int] = None,
) -> Tensor:
r"""Converts batched sparse adjacency matrices given by edge indices and
edge attributes to a single dense batched adjacency matrix.
Expand Down

0 comments on commit 651a619

Please sign in to comment.