Skip to content

Commit 8cbbd72

Browse files
nihal-raorusty1s
andauthored
[Type Hints] utils.grid (#5724)
This is a draft PR, as the test in `test_grid.py` is failing because of issues with the `coalesce` function. Currently, the code uses `torch.sparse.coalesce` , which does not have any type hints(hence cant be used with jit). On @rusty1s suggestion, I tried using `torch_geometric.utils.coalesce` instead. The test (on running `FULL_TEST=1 pytest test/utils/test_grid.py`) is still failing, with the traceback as below : ``` E RuntimeError: E 'Union[Tensor, List[Tensor], NoneType]' object is not subscriptable: E File "/home/nihal/opensource/pytorch_geometric/torch_geometric/utils/coalesce.py", line 78 E edge_index = edge_index[:, perm] E if edge_attr is not None and isinstance(edge_attr, Tensor): E edge_attr = edge_attr[perm] E ~~~~~~~~~~~~~~~ <--- HERE E elif edge_attr is not None: E edge_attr = [e[perm] for e in edge_attr] E 'coalesce' is being compiled since it was called from 'grid_index' E File "/home/nihal/opensource/pytorch_geometric/torch_geometric/utils/grid.py", line 60 E E edge_index = torch.stack([row, col], dim=0) E edge_index = coalesce(edge_index, None, height * width) E ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE E print('------------------------',type(edge_index)) E return edge_index E 'grid_index' is being compiled since it was called from 'grid' E File "/home/nihal/opensource/pytorch_geometric/torch_geometric/utils/grid.py", line 38 E """ E E edge_index = grid_index(height, width, device) E ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE E pos = grid_pos(height, width, dtype, device) E return edge_index, pos ../../miniconda3/envs/pygenv/lib/python3.8/site-packages/torch/jit/_script.py:1343: RuntimeError ``` Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
1 parent 7cb8c07 commit 8cbbd72

File tree

4 files changed

+68
-14
lines changed

4 files changed

+68
-14
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4646
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
4747
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
4848
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
49-
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757))
49+
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757))
5050
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
5151
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
5252
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))

test/utils/test_grid.py

+10
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import torch
2+
3+
from torch_geometric.testing import is_full_test
14
from torch_geometric.utils import grid
25

36

@@ -14,3 +17,10 @@ def test_grid():
1417
assert row.tolist() == expected_row
1518
assert col.tolist() == expected_col
1619
assert pos.tolist() == expected_pos
20+
21+
if is_full_test():
22+
jit = torch.jit.script(grid)
23+
(row, col), pos = jit(height=3, width=2)
24+
assert row.tolist() == expected_row
25+
assert col.tolist() == expected_col
26+
assert pos.tolist() == expected_pos

torch_geometric/utils/coalesce.py

+32-6
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,30 @@
77
from .num_nodes import maybe_num_nodes
88

99

10+
@torch.jit._overload
11+
def coalesce(edge_index, edge_attr=None, num_nodes=None, reduce="add",
12+
is_sorted=False, sort_by_row=True):
13+
# type: (Tensor, Optional[bool], Optional[int], str, bool, bool) -> Tensor # noqa
14+
pass
15+
16+
17+
@torch.jit._overload
18+
def coalesce(edge_index, edge_attr=None, num_nodes=None, reduce="add",
19+
is_sorted=False, sort_by_row=True):
20+
# type: (Tensor, Tensor, Optional[int], str, bool, bool) -> Tuple[Tensor, Tensor] # noqa
21+
pass
22+
23+
24+
@torch.jit._overload
25+
def coalesce(edge_index, edge_attr=None, num_nodes=None, reduce="add",
26+
is_sorted=False, sort_by_row=True):
27+
# type: (Tensor, List[Tensor], Optional[int], str, bool, bool) -> Tuple[Tensor, List[Tensor]] # noqa
28+
pass
29+
30+
1031
def coalesce(
1132
edge_index: Tensor,
12-
edge_attr: Optional[Union[Tensor, List[Tensor]]] = None,
33+
edge_attr: Union[Optional[Tensor], List[Tensor]] = None,
1334
num_nodes: Optional[int] = None,
1435
reduce: str = "add",
1536
is_sorted: bool = False,
@@ -74,18 +95,21 @@ def coalesce(
7495
if not is_sorted:
7596
idx[1:], perm = idx[1:].sort()
7697
edge_index = edge_index[:, perm]
77-
if edge_attr is not None and isinstance(edge_attr, Tensor):
98+
if isinstance(edge_attr, Tensor):
7899
edge_attr = edge_attr[perm]
79-
elif edge_attr is not None:
100+
elif isinstance(edge_attr, (list, tuple)):
80101
edge_attr = [e[perm] for e in edge_attr]
81102

82103
mask = idx[1:] > idx[:-1]
83104

84105
# Only perform expensive merging in case there exists duplicates:
85106
if mask.all():
86-
return edge_index if edge_attr is None else (edge_index, edge_attr)
107+
if isinstance(edge_attr, (Tensor, list, tuple)):
108+
return edge_index, edge_attr
109+
return edge_index
87110

88111
edge_index = edge_index[:, mask]
112+
89113
if edge_attr is None:
90114
return edge_index
91115

@@ -95,9 +119,11 @@ def coalesce(
95119

96120
if isinstance(edge_attr, Tensor):
97121
edge_attr = scatter(edge_attr, idx, 0, None, dim_size, reduce)
98-
else:
122+
return edge_index, edge_attr
123+
elif isinstance(edge_attr, (list, tuple)):
99124
edge_attr = [
100125
scatter(e, idx, 0, None, dim_size, reduce) for e in edge_attr
101126
]
127+
return edge_index, edge_attr
102128

103-
return edge_index, edge_attr
129+
return edge_index

torch_geometric/utils/grid.py

+25-7
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
1+
from typing import Optional, Tuple
2+
13
import torch
2-
from torch_sparse import coalesce
4+
from torch import Tensor
5+
6+
from torch_geometric.utils.coalesce import coalesce
37

48

5-
def grid(height, width, dtype=None, device=None):
9+
def grid(
10+
height: int,
11+
width: int,
12+
dtype: Optional[torch.dtype] = None,
13+
device: Optional[torch.device] = None,
14+
) -> Tuple[Tensor, Tensor]:
615
r"""Returns the edge indices of a two-dimensional grid graph with height
716
:attr:`height` and width :attr:`width` and its node positions.
817
@@ -29,13 +38,17 @@ def grid(height, width, dtype=None, device=None):
2938
[0., 0.],
3039
[1., 0.]])
3140
"""
32-
3341
edge_index = grid_index(height, width, device)
3442
pos = grid_pos(height, width, dtype, device)
3543
return edge_index, pos
3644

3745

38-
def grid_index(height, width, device=None):
46+
def grid_index(
47+
height: int,
48+
width: int,
49+
device: Optional[torch.device] = None,
50+
) -> Tensor:
51+
3952
w = width
4053
kernel = [-w - 1, -1, w - 1, -w, 0, w, -w + 1, 1, w + 1]
4154
kernel = torch.tensor(kernel, device=device)
@@ -51,12 +64,17 @@ def grid_index(height, width, device=None):
5164
row, col = row[mask], col[mask]
5265

5366
edge_index = torch.stack([row, col], dim=0)
54-
edge_index, _ = coalesce(edge_index, None, height * width, height * width)
55-
67+
edge_index = coalesce(edge_index, None, height * width)
5668
return edge_index
5769

5870

59-
def grid_pos(height, width, dtype=None, device=None):
71+
def grid_pos(
72+
height: int,
73+
width: int,
74+
dtype: Optional[torch.dtype] = None,
75+
device: Optional[torch.device] = None,
76+
) -> Tensor:
77+
6078
dtype = torch.float if dtype is None else dtype
6179
x = torch.arange(width, dtype=dtype, device=device)
6280
y = (height - 1) - torch.arange(height, dtype=dtype, device=device)

0 commit comments

Comments
 (0)