Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Type Hints] Missing type hints in transforms.* #5753

Merged
merged 5 commits into from
Oct 17, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), , [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753))
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down
5 changes: 4 additions & 1 deletion torch_geometric/transforms/add_self_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ def __init__(self, attr: Optional[str] = 'edge_weight',
self.attr = attr
self.fill_value = fill_value

def __call__(self, data: Union[Data, HeteroData]):
def __call__(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
for store in data.edge_stores:
if store.is_bipartite() or 'edge_index' not in store:
continue
Expand Down
5 changes: 4 additions & 1 deletion torch_geometric/transforms/center.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
class Center(BaseTransform):
r"""Centers node positions :obj:`pos` around the origin
(functional name: :obj:`center`)."""
def __call__(self, data: Union[Data, HeteroData]):
def __call__(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
for store in data.node_stores:
if hasattr(store, 'pos'):
store.pos = store.pos - store.pos.mean(dim=-2, keepdim=True)
Expand Down
5 changes: 4 additions & 1 deletion torch_geometric/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ class Compose(BaseTransform):
def __init__(self, transforms: List[Callable]):
self.transforms = transforms

def __call__(self, data: Union[Data, HeteroData]):
def __call__(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
for transform in self.transforms:
if isinstance(data, (list, tuple)):
data = [transform(d) for d in data]
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/transforms/delaunay.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import scipy.spatial
import torch

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform

Expand All @@ -9,7 +10,7 @@
class Delaunay(BaseTransform):
r"""Computes the delaunay triangulation of a set of points
(functional name: :obj:`delaunay`)."""
def __call__(self, data):
def __call__(self, data: Data) -> Data:
if data.pos.size(0) < 2:
data.edge_index = torch.tensor([], dtype=torch.long,
device=data.pos.device).view(2, 0)
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/transforms/gcn_norm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform

Expand All @@ -18,7 +19,7 @@ class GCNNorm(BaseTransform):
def __init__(self, add_self_loops: bool = True):
self.add_self_loops = add_self_loops

def __call__(self, data):
def __call__(self, data: Data) -> Data:
gcn_norm = torch_geometric.nn.conv.gcn_conv.gcn_norm
assert 'edge_index' in data or 'adj_t' in data

Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/transforms/generate_mesh_normals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn.functional as F
from torch_scatter import scatter_add

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform

Expand All @@ -10,7 +11,7 @@
class GenerateMeshNormals(BaseTransform):
r"""Generate normal vectors for each mesh node based on neighboring
faces (functional name: :obj:`generate_mesh_normals`)."""
def __call__(self, data):
def __call__(self, data: Data) -> Data:
assert 'face' in data
pos, face = data.pos, data.face

Expand Down
5 changes: 4 additions & 1 deletion torch_geometric/transforms/linear_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ def __init__(self, matrix: Tensor):
# We do this to enable post-multiplication in `__call__`.
self.matrix = matrix.t()

def __call__(self, data: Union[Data, HeteroData]):
def __call__(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
for store in data.node_stores:
if not hasattr(store, 'pos'):
continue
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/transforms/local_degree_profile.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_std

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import degree
Expand All @@ -20,7 +21,7 @@ class LocalDegreeProfile(BaseTransform):
to the node features, where :math:`DN(i) = \{ \deg(j) \mid j \in
\mathcal{N}(i) \}`.
"""
def __call__(self, data):
def __call__(self, data: Data) -> Data:
row, col = data.edge_index
N = data.num_nodes

Expand Down
8 changes: 4 additions & 4 deletions torch_geometric/transforms/normalize_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ class NormalizeFeatures(BaseTransform):
def __init__(self, attrs: List[str] = ["x"]):
self.attrs = attrs

def __call__(self, data: Union[Data, HeteroData]):
def __call__(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
for store in data.stores:
for key, value in store.items(*self.attrs):
value = value - value.min()
value.div_(value.sum(dim=-1, keepdim=True).clamp_(min=1.))
store[key] = value
return data

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 2 additions & 1 deletion torch_geometric/transforms/normalize_rotation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform

Expand All @@ -23,7 +24,7 @@ def __init__(self, max_points: int = -1, sort: bool = False):
self.max_points = max_points
self.sort = sort

def __call__(self, data):
def __call__(self, data: Data) -> Data:
pos = data.pos

if self.max_points > 0 and pos.size(0) > self.max_points:
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/transforms/normalize_scale.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform, Center

Expand All @@ -10,7 +11,7 @@ class NormalizeScale(BaseTransform):
def __init__(self):
self.center = Center()

def __call__(self, data):
def __call__(self, data: Data) -> Data:
data = self.center(data)

scale = (1 / data.pos.abs().max()) * 0.999999
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/transforms/radius_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform

Expand Down Expand Up @@ -36,7 +37,7 @@ def __init__(
self.flow = flow
self.num_workers = num_workers

def __call__(self, data):
def __call__(self, data: Data) -> Data:
data.edge_attr = None
batch = data.batch if 'batch' in data else None

Expand Down
23 changes: 18 additions & 5 deletions torch_geometric/transforms/random_link_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,10 @@ def __init__(
self.edge_types = edge_types
self.rev_edge_types = rev_edge_types

def __call__(self, data: Union[Data, HeteroData]):
def __call__(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
edge_types = self.edge_types
rev_edge_types = self.rev_edge_types

Expand Down Expand Up @@ -246,8 +249,13 @@ def __call__(self, data: Union[Data, HeteroData]):

return train_data, val_data, test_data

def _split(self, store: EdgeStorage, index: Tensor, is_undirected: bool,
rev_edge_type: EdgeType):
def _split(
self,
store: EdgeStorage,
index: Tensor,
is_undirected: bool,
rev_edge_type: EdgeType,
) -> EdgeStorage:

for key, value in store.items():
if key == 'edge_index':
Expand Down Expand Up @@ -276,8 +284,13 @@ def _split(self, store: EdgeStorage, index: Tensor, is_undirected: bool,

return store

def _create_label(self, store: EdgeStorage, index: Tensor,
neg_edge_index: Tensor, out: EdgeStorage):
def _create_label(
self,
store: EdgeStorage,
index: Tensor,
neg_edge_index: Tensor,
out: EdgeStorage,
) -> EdgeStorage:

edge_index = store.edge_index[:, index]

Expand Down
5 changes: 4 additions & 1 deletion torch_geometric/transforms/random_node_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ def __init__(
self.num_test = num_test
self.key = key

def __call__(self, data: Union[Data, HeteroData]):
def __call__(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
for store in data.node_stores:
if self.key is not None and not hasattr(store, self.key):
continue
Expand Down
5 changes: 4 additions & 1 deletion torch_geometric/transforms/remove_isolated_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
class RemoveIsolatedNodes(BaseTransform):
r"""Removes isolated nodes from the graph
(functional name: :obj:`remove_isolated_nodes`)."""
def __call__(self, data: Union[Data, HeteroData]):
def __call__(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
# Gather all nodes that occur in at least one edge (across all types):
n_id_dict = defaultdict(list)
for store in data.edge_stores:
Expand Down
5 changes: 4 additions & 1 deletion torch_geometric/transforms/to_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ def __init__(
self.attrs = attrs or []
self.non_blocking = non_blocking

def __call__(self, data: Union[Data, HeteroData]):
def __call__(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
return data.to(self.device, *self.attrs,
non_blocking=self.non_blocking)

Expand Down
8 changes: 4 additions & 4 deletions torch_geometric/transforms/to_sparse_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ def __init__(self, attr: Optional[str] = 'edge_weight',
self.remove_edge_index = remove_edge_index
self.fill_cache = fill_cache

def __call__(self, data: Union[Data, HeteroData]):
def __call__(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
for store in data.edge_stores:
if 'edge_index' not in store:
continue
Expand Down Expand Up @@ -75,6 +78,3 @@ def __call__(self, data: Union[Data, HeteroData]):
store.adj_t.storage.csr2csc()

return data

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 4 additions & 4 deletions torch_geometric/transforms/to_undirected.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ def __init__(self, reduce: str = "add", merge: bool = True):
self.reduce = reduce
self.merge = merge

def __call__(self, data: Union[Data, HeteroData]):
def __call__(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
for store in data.edge_stores:
if 'edge_index' not in store:
continue
Expand Down Expand Up @@ -74,6 +77,3 @@ def __call__(self, data: Union[Data, HeteroData]):
store[key] = value

return data

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 2 additions & 1 deletion torch_geometric/transforms/two_hop.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from torch_sparse import coalesce, spspmm

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import remove_self_loops
Expand All @@ -10,7 +11,7 @@
class TwoHop(BaseTransform):
r"""Adds the two hop edges to the edge indices
(functional name: :obj:`two_hop`)."""
def __call__(self, data):
def __call__(self, data: Data) -> Data:
edge_index, edge_attr = data.edge_index, data.edge_attr
N = data.num_nodes

Expand Down