Skip to content

Commit c084ef3

Browse files
rusty1sJakubPietrakIntel
authored andcommitted
[Type Hints] utils.is_undirected and utils.to_undirected (pyg-team#5767)
1 parent 43358f8 commit c084ef3

File tree

5 files changed

+105
-27
lines changed

5 files changed

+105
-27
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4646
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
4747
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
4848
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
49-
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757))
49+
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767))
5050
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
5151
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
5252
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))

test/utils/test_undirected.py

+18
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import torch
2+
from torch import Tensor
23

4+
from torch_geometric.testing import is_full_test
35
from torch_geometric.utils import is_undirected, to_undirected
46

57

@@ -18,10 +20,26 @@ def test_is_undirected():
1820

1921
assert not is_undirected(torch.stack([row, col], dim=0))
2022

23+
if is_full_test():
24+
25+
@torch.jit.script
26+
def jit(edge_index: Tensor) -> bool:
27+
return is_undirected(edge_index)
28+
29+
assert not jit(torch.stack([row, col], dim=0))
30+
2131

2232
def test_to_undirected():
2333
row = torch.tensor([0, 1, 1])
2434
col = torch.tensor([1, 0, 2])
2535

2636
edge_index = to_undirected(torch.stack([row, col], dim=0))
2737
assert edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
38+
39+
if is_full_test():
40+
41+
@torch.jit.script
42+
def jit(edge_index: Tensor) -> Tensor:
43+
return to_undirected(edge_index)
44+
45+
assert torch.equal(jit(torch.stack([row, col], dim=0)), edge_index)

torch_geometric/nn/dense/diff_pool.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55

66

77
def dense_diff_pool(
8-
x: Tensor, adj: Tensor, s: Tensor, mask: Optional[Tensor] = None,
9-
normalize: Optional[bool] = True
8+
x: Tensor,
9+
adj: Tensor,
10+
s: Tensor,
11+
mask: Optional[Tensor] = None,
12+
normalize: bool = True,
1013
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
1114
r"""The differentiable pooling operator from the `"Hierarchical Graph
1215
Representation Learning with Differentiable Pooling"
@@ -76,8 +79,6 @@ def dense_diff_pool(
7679
if normalize is True:
7780
link_loss = link_loss / adj.numel()
7881

79-
# Moved EPS from global to local variable for TorchScript support
80-
EPS = 1e-15
81-
ent_loss = (-s * torch.log(s + EPS)).sum(dim=-1).mean()
82+
ent_loss = (-s * torch.log(s + 1e-15)).sum(dim=-1).mean()
8283

8384
return out, out_adj, link_loss, ent_loss

torch_geometric/utils/sort_edge_index.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,35 @@
11
from typing import List, Optional, Tuple, Union
22

3+
import torch
34
from torch import Tensor
45

56
from .num_nodes import maybe_num_nodes
67

78

9+
@torch.jit._overload
10+
def sort_edge_index(edge_index, edge_attr=None, num_nodes=None,
11+
sort_by_row=True):
12+
# type: (Tensor, Optional[bool], Optional[int], bool) -> Tensor # noqa
13+
pass
14+
15+
16+
@torch.jit._overload
17+
def sort_edge_index(edge_index, edge_attr=None, num_nodes=None,
18+
sort_by_row=True):
19+
# type: (Tensor, Tensor, Optional[int], bool) -> Tuple[Tensor, Tensor] # noqa
20+
pass
21+
22+
23+
@torch.jit._overload
24+
def sort_edge_index(edge_index, edge_attr=None, num_nodes=None,
25+
sort_by_row=True):
26+
# type: (Tensor, List[Tensor], Optional[int], bool) -> Tuple[Tensor, List[Tensor]] # noqa
27+
pass
28+
29+
830
def sort_edge_index(
931
edge_index: Tensor,
10-
edge_attr: Optional[Union[Tensor, List[Tensor]]] = None,
32+
edge_attr: Union[Optional[Tensor], List[Tensor]] = None,
1133
num_nodes: Optional[int] = None,
1234
sort_by_row: bool = True,
1335
) -> Union[Tensor, Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
@@ -53,9 +75,9 @@ def sort_edge_index(
5375

5476
edge_index = edge_index[:, perm]
5577

56-
if edge_attr is None:
57-
return edge_index
58-
elif isinstance(edge_attr, Tensor):
78+
if isinstance(edge_attr, Tensor):
5979
return edge_index, edge_attr[perm]
60-
else:
80+
elif isinstance(edge_attr, (list, tuple)):
6181
return edge_index, [e[perm] for e in edge_attr]
82+
else:
83+
return edge_index

torch_geometric/utils/undirected.py

+53-16
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,21 @@
88
from .num_nodes import maybe_num_nodes
99

1010

11+
@torch.jit._overload
12+
def is_undirected(edge_index, edge_attr=None, num_nodes=None):
13+
# type: (Tensor, Optional[Tensor], Optional[int]) -> bool # noqa
14+
pass
15+
16+
17+
@torch.jit._overload
18+
def is_undirected(edge_index, edge_attr=None, num_nodes=None):
19+
# type: (Tensor, List[Tensor], Optional[int]) -> bool # noqa
20+
pass
21+
22+
1123
def is_undirected(
1224
edge_index: Tensor,
13-
edge_attr: Optional[Union[Tensor, List[Tensor]]] = None,
25+
edge_attr: Union[Optional[Tensor], List[Tensor]] = None,
1426
num_nodes: Optional[int] = None,
1527
) -> bool:
1628
r"""Returns :obj:`True` if the graph given by :attr:`edge_index` is
@@ -42,31 +54,56 @@ def is_undirected(
4254
"""
4355
num_nodes = maybe_num_nodes(edge_index, num_nodes)
4456

45-
edge_attr = [] if edge_attr is None else edge_attr
46-
edge_attr = [edge_attr] if isinstance(edge_attr, Tensor) else edge_attr
57+
edge_attrs: List[Tensor] = []
58+
if isinstance(edge_attr, Tensor):
59+
edge_attrs.append(edge_attr)
60+
elif isinstance(edge_attr, (list, tuple)):
61+
edge_attrs = edge_attr
4762

48-
edge_index1, edge_attr1 = sort_edge_index(
63+
edge_index1, edge_attrs1 = sort_edge_index(
4964
edge_index,
50-
edge_attr,
65+
edge_attrs,
5166
num_nodes=num_nodes,
5267
sort_by_row=True,
5368
)
54-
edge_index2, edge_attr2 = sort_edge_index(
55-
edge_index1,
56-
edge_attr1,
69+
edge_index2, edge_attrs2 = sort_edge_index(
70+
edge_index,
71+
edge_attrs,
5772
num_nodes=num_nodes,
5873
sort_by_row=False,
5974
)
6075

61-
return (bool(torch.all(edge_index1[0] == edge_index2[1]))
62-
and bool(torch.all(edge_index1[1] == edge_index2[0])) and all([
63-
torch.all(e == e_T) for e, e_T in zip(edge_attr1, edge_attr2)
64-
]))
76+
if not torch.equal(edge_index1[0], edge_index2[1]):
77+
return False
78+
if not torch.equal(edge_index1[1], edge_index2[0]):
79+
return False
80+
for edge_attr1, edge_attr2 in zip(edge_attrs1, edge_attrs2):
81+
if not torch.equal(edge_attr1, edge_attr2):
82+
return False
83+
return True
84+
85+
86+
@torch.jit._overload
87+
def to_undirected(edge_index, edge_attr=None, num_nodes=None, reduce="add"):
88+
# type: (Tensor, Optional[bool], Optional[int], str) -> Tensor # noqa
89+
pass
90+
91+
92+
@torch.jit._overload
93+
def to_undirected(edge_index, edge_attr=None, num_nodes=None, reduce="add"):
94+
# type: (Tensor, Tensor, Optional[int], str) -> Tuple[Tensor, Tensor] # noqa
95+
pass
96+
97+
98+
@torch.jit._overload
99+
def to_undirected(edge_index, edge_attr=None, num_nodes=None, reduce="add"):
100+
# type: (Tensor, List[Tensor], Optional[int], str) -> Tuple[Tensor, List[Tensor]] # noqa
101+
pass
65102

66103

67104
def to_undirected(
68105
edge_index: Tensor,
69-
edge_attr: Optional[Union[Tensor, List[Tensor]]] = None,
106+
edge_attr: Union[Optional[Tensor], List[Tensor]] = None,
70107
num_nodes: Optional[int] = None,
71108
reduce: str = "add",
72109
) -> Union[Tensor, Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
@@ -116,13 +153,13 @@ def to_undirected(
116153
edge_attr = None
117154
num_nodes = edge_attr
118155

119-
row, col = edge_index
156+
row, col = edge_index[0], edge_index[1]
120157
row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
121158
edge_index = torch.stack([row, col], dim=0)
122159

123-
if edge_attr is not None and isinstance(edge_attr, Tensor):
160+
if isinstance(edge_attr, Tensor):
124161
edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
125-
elif edge_attr is not None:
162+
elif isinstance(edge_attr, (list, tuple)):
126163
edge_attr = [torch.cat([e, e], dim=0) for e in edge_attr]
127164

128165
return coalesce(edge_index, edge_attr, num_nodes, reduce)

0 commit comments

Comments
 (0)