Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Type Hints] utils.is_undirected and utils.to_undirected #5767

Merged
merged 3 commits into from
Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down
18 changes: 18 additions & 0 deletions test/utils/test_undirected.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch
from torch import Tensor

from torch_geometric.testing import is_full_test
from torch_geometric.utils import is_undirected, to_undirected


Expand All @@ -18,10 +20,26 @@ def test_is_undirected():

assert not is_undirected(torch.stack([row, col], dim=0))

if is_full_test():

@torch.jit.script
def jit(edge_index: Tensor) -> bool:
return is_undirected(edge_index)

assert not jit(torch.stack([row, col], dim=0))


def test_to_undirected():
row = torch.tensor([0, 1, 1])
col = torch.tensor([1, 0, 2])

edge_index = to_undirected(torch.stack([row, col], dim=0))
assert edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]

if is_full_test():

@torch.jit.script
def jit(edge_index: Tensor) -> Tensor:
return to_undirected(edge_index)

assert torch.equal(jit(torch.stack([row, col], dim=0)), edge_index)
11 changes: 6 additions & 5 deletions torch_geometric/nn/dense/diff_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@


def dense_diff_pool(
x: Tensor, adj: Tensor, s: Tensor, mask: Optional[Tensor] = None,
normalize: Optional[bool] = True
x: Tensor,
adj: Tensor,
s: Tensor,
mask: Optional[Tensor] = None,
normalize: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
r"""The differentiable pooling operator from the `"Hierarchical Graph
Representation Learning with Differentiable Pooling"
Expand Down Expand Up @@ -76,8 +79,6 @@ def dense_diff_pool(
if normalize is True:
link_loss = link_loss / adj.numel()

# Moved EPS from global to local variable for TorchScript support
EPS = 1e-15
ent_loss = (-s * torch.log(s + EPS)).sum(dim=-1).mean()
ent_loss = (-s * torch.log(s + 1e-15)).sum(dim=-1).mean()

return out, out_adj, link_loss, ent_loss
32 changes: 27 additions & 5 deletions torch_geometric/utils/sort_edge_index.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,35 @@
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor

from .num_nodes import maybe_num_nodes


@torch.jit._overload
def sort_edge_index(edge_index, edge_attr=None, num_nodes=None,
sort_by_row=True):
# type: (Tensor, Optional[bool], Optional[int], bool) -> Tensor # noqa
pass


@torch.jit._overload
def sort_edge_index(edge_index, edge_attr=None, num_nodes=None,
sort_by_row=True):
# type: (Tensor, Tensor, Optional[int], bool) -> Tuple[Tensor, Tensor] # noqa
pass


@torch.jit._overload
def sort_edge_index(edge_index, edge_attr=None, num_nodes=None,
sort_by_row=True):
# type: (Tensor, List[Tensor], Optional[int], bool) -> Tuple[Tensor, List[Tensor]] # noqa
pass


def sort_edge_index(
edge_index: Tensor,
edge_attr: Optional[Union[Tensor, List[Tensor]]] = None,
edge_attr: Union[Optional[Tensor], List[Tensor]] = None,
num_nodes: Optional[int] = None,
sort_by_row: bool = True,
) -> Union[Tensor, Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
Expand Down Expand Up @@ -53,9 +75,9 @@ def sort_edge_index(

edge_index = edge_index[:, perm]

if edge_attr is None:
return edge_index
elif isinstance(edge_attr, Tensor):
if isinstance(edge_attr, Tensor):
return edge_index, edge_attr[perm]
else:
elif isinstance(edge_attr, (list, tuple)):
return edge_index, [e[perm] for e in edge_attr]
else:
return edge_index
69 changes: 53 additions & 16 deletions torch_geometric/utils/undirected.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,21 @@
from .num_nodes import maybe_num_nodes


@torch.jit._overload
def is_undirected(edge_index, edge_attr=None, num_nodes=None):
# type: (Tensor, Optional[Tensor], Optional[int]) -> bool # noqa
pass


@torch.jit._overload
def is_undirected(edge_index, edge_attr=None, num_nodes=None):
# type: (Tensor, List[Tensor], Optional[int]) -> bool # noqa
pass


def is_undirected(
edge_index: Tensor,
edge_attr: Optional[Union[Tensor, List[Tensor]]] = None,
edge_attr: Union[Optional[Tensor], List[Tensor]] = None,
num_nodes: Optional[int] = None,
) -> bool:
r"""Returns :obj:`True` if the graph given by :attr:`edge_index` is
Expand Down Expand Up @@ -42,31 +54,56 @@ def is_undirected(
"""
num_nodes = maybe_num_nodes(edge_index, num_nodes)

edge_attr = [] if edge_attr is None else edge_attr
edge_attr = [edge_attr] if isinstance(edge_attr, Tensor) else edge_attr
edge_attrs: List[Tensor] = []
if isinstance(edge_attr, Tensor):
edge_attrs.append(edge_attr)
elif isinstance(edge_attr, (list, tuple)):
edge_attrs = edge_attr

edge_index1, edge_attr1 = sort_edge_index(
edge_index1, edge_attrs1 = sort_edge_index(
edge_index,
edge_attr,
edge_attrs,
num_nodes=num_nodes,
sort_by_row=True,
)
edge_index2, edge_attr2 = sort_edge_index(
edge_index1,
edge_attr1,
edge_index2, edge_attrs2 = sort_edge_index(
edge_index,
edge_attrs,
num_nodes=num_nodes,
sort_by_row=False,
)

return (bool(torch.all(edge_index1[0] == edge_index2[1]))
and bool(torch.all(edge_index1[1] == edge_index2[0])) and all([
torch.all(e == e_T) for e, e_T in zip(edge_attr1, edge_attr2)
]))
if not torch.equal(edge_index1[0], edge_index2[1]):
return False
if not torch.equal(edge_index1[1], edge_index2[0]):
return False
for edge_attr1, edge_attr2 in zip(edge_attrs1, edge_attrs2):
if not torch.equal(edge_attr1, edge_attr2):
return False
return True


@torch.jit._overload
def to_undirected(edge_index, edge_attr=None, num_nodes=None, reduce="add"):
# type: (Tensor, Optional[bool], Optional[int], str) -> Tensor # noqa
pass


@torch.jit._overload
def to_undirected(edge_index, edge_attr=None, num_nodes=None, reduce="add"):
# type: (Tensor, Tensor, Optional[int], str) -> Tuple[Tensor, Tensor] # noqa
pass


@torch.jit._overload
def to_undirected(edge_index, edge_attr=None, num_nodes=None, reduce="add"):
# type: (Tensor, List[Tensor], Optional[int], str) -> Tuple[Tensor, List[Tensor]] # noqa
pass


def to_undirected(
edge_index: Tensor,
edge_attr: Optional[Union[Tensor, List[Tensor]]] = None,
edge_attr: Union[Optional[Tensor], List[Tensor]] = None,
num_nodes: Optional[int] = None,
reduce: str = "add",
) -> Union[Tensor, Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
Expand Down Expand Up @@ -116,13 +153,13 @@ def to_undirected(
edge_attr = None
num_nodes = edge_attr

row, col = edge_index
row, col = edge_index[0], edge_index[1]
row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
edge_index = torch.stack([row, col], dim=0)

if edge_attr is not None and isinstance(edge_attr, Tensor):
if isinstance(edge_attr, Tensor):
edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
elif edge_attr is not None:
elif isinstance(edge_attr, (list, tuple)):
edge_attr = [torch.cat([e, e], dim=0) for e in edge_attr]

return coalesce(edge_index, edge_attr, num_nodes, reduce)