From b49c7b01a91998474f6203f9012568b0fa82b187 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Mon, 17 Oct 2022 22:53:55 +0800 Subject: [PATCH 1/5] typehints --- torch_geometric/transforms/add_self_loops.py | 3 ++- torch_geometric/transforms/center.py | 3 ++- torch_geometric/transforms/compose.py | 3 ++- torch_geometric/transforms/delaunay.py | 3 ++- torch_geometric/transforms/gcn_norm.py | 3 ++- .../transforms/generate_mesh_normals.py | 3 ++- .../transforms/linear_transformation.py | 3 ++- .../transforms/local_degree_profile.py | 3 ++- .../transforms/normalize_features.py | 3 ++- .../transforms/normalize_rotation.py | 3 ++- torch_geometric/transforms/normalize_scale.py | 3 ++- torch_geometric/transforms/radius_graph.py | 3 ++- .../transforms/random_link_split.py | 21 ++++++++++++++----- .../transforms/random_node_split.py | 3 ++- .../transforms/remove_isolated_nodes.py | 3 ++- torch_geometric/transforms/to_device.py | 3 ++- .../transforms/to_sparse_tensor.py | 6 ++---- torch_geometric/transforms/to_undirected.py | 6 ++---- torch_geometric/transforms/two_hop.py | 3 ++- 19 files changed, 52 insertions(+), 29 deletions(-) diff --git a/torch_geometric/transforms/add_self_loops.py b/torch_geometric/transforms/add_self_loops.py index c540f5887890..58bfd145a2dd 100644 --- a/torch_geometric/transforms/add_self_loops.py +++ b/torch_geometric/transforms/add_self_loops.py @@ -32,7 +32,8 @@ def __init__(self, attr: Optional[str] = 'edge_weight', self.attr = attr self.fill_value = fill_value - def __call__(self, data: Union[Data, HeteroData]): + def __call__(self, data: Union[Data, + HeteroData]) -> Union[Data, HeteroData]: for store in data.edge_stores: if store.is_bipartite() or 'edge_index' not in store: continue diff --git a/torch_geometric/transforms/center.py b/torch_geometric/transforms/center.py index 5c05b569e891..9055ea6311d6 100644 --- a/torch_geometric/transforms/center.py +++ b/torch_geometric/transforms/center.py @@ -9,7 +9,8 @@ class Center(BaseTransform): r"""Centers node positions :obj:`pos` around the origin (functional name: :obj:`center`).""" - def __call__(self, data: Union[Data, HeteroData]): + def __call__(self, data: Union[Data, + HeteroData]) -> Union[Data, HeteroData]: for store in data.node_stores: if hasattr(store, 'pos'): store.pos = store.pos - store.pos.mean(dim=-2, keepdim=True) diff --git a/torch_geometric/transforms/compose.py b/torch_geometric/transforms/compose.py index d7a0b7d228df..665baa35c15b 100644 --- a/torch_geometric/transforms/compose.py +++ b/torch_geometric/transforms/compose.py @@ -13,7 +13,8 @@ class Compose(BaseTransform): def __init__(self, transforms: List[Callable]): self.transforms = transforms - def __call__(self, data: Union[Data, HeteroData]): + def __call__(self, data: Union[Data, + HeteroData]) -> Union[Data, HeteroData]: for transform in self.transforms: if isinstance(data, (list, tuple)): data = [transform(d) for d in data] diff --git a/torch_geometric/transforms/delaunay.py b/torch_geometric/transforms/delaunay.py index 563069a29a26..8a25fe151642 100644 --- a/torch_geometric/transforms/delaunay.py +++ b/torch_geometric/transforms/delaunay.py @@ -1,6 +1,7 @@ import scipy.spatial import torch +from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @@ -9,7 +10,7 @@ class Delaunay(BaseTransform): r"""Computes the delaunay triangulation of a set of points (functional name: :obj:`delaunay`).""" - def __call__(self, data): + def __call__(self, data: Data) -> Data: if data.pos.size(0) < 2: data.edge_index = torch.tensor([], dtype=torch.long, device=data.pos.device).view(2, 0) diff --git a/torch_geometric/transforms/gcn_norm.py b/torch_geometric/transforms/gcn_norm.py index fc6d7a5429e3..aa52a4f020aa 100644 --- a/torch_geometric/transforms/gcn_norm.py +++ b/torch_geometric/transforms/gcn_norm.py @@ -1,4 +1,5 @@ import torch_geometric +from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @@ -18,7 +19,7 @@ class GCNNorm(BaseTransform): def __init__(self, add_self_loops: bool = True): self.add_self_loops = add_self_loops - def __call__(self, data): + def __call__(self, data: Data) -> Data: gcn_norm = torch_geometric.nn.conv.gcn_conv.gcn_norm assert 'edge_index' in data or 'adj_t' in data diff --git a/torch_geometric/transforms/generate_mesh_normals.py b/torch_geometric/transforms/generate_mesh_normals.py index 84f14aef73ec..0718edc14a1b 100644 --- a/torch_geometric/transforms/generate_mesh_normals.py +++ b/torch_geometric/transforms/generate_mesh_normals.py @@ -2,6 +2,7 @@ import torch.nn.functional as F from torch_scatter import scatter_add +from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @@ -10,7 +11,7 @@ class GenerateMeshNormals(BaseTransform): r"""Generate normal vectors for each mesh node based on neighboring faces (functional name: :obj:`generate_mesh_normals`).""" - def __call__(self, data): + def __call__(self, data: Data) -> Data: assert 'face' in data pos, face = data.pos, data.face diff --git a/torch_geometric/transforms/linear_transformation.py b/torch_geometric/transforms/linear_transformation.py index 0c83cc42fae8..796746dd5061 100644 --- a/torch_geometric/transforms/linear_transformation.py +++ b/torch_geometric/transforms/linear_transformation.py @@ -29,7 +29,8 @@ def __init__(self, matrix: Tensor): # We do this to enable post-multiplication in `__call__`. self.matrix = matrix.t() - def __call__(self, data: Union[Data, HeteroData]): + def __call__(self, data: Union[Data, + HeteroData]) -> Union[Data, HeteroData]: for store in data.node_stores: if not hasattr(store, 'pos'): continue diff --git a/torch_geometric/transforms/local_degree_profile.py b/torch_geometric/transforms/local_degree_profile.py index 4f28c2895af8..6fa0a47f4c41 100644 --- a/torch_geometric/transforms/local_degree_profile.py +++ b/torch_geometric/transforms/local_degree_profile.py @@ -1,6 +1,7 @@ import torch from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_std +from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import degree @@ -20,7 +21,7 @@ class LocalDegreeProfile(BaseTransform): to the node features, where :math:`DN(i) = \{ \deg(j) \mid j \in \mathcal{N}(i) \}`. """ - def __call__(self, data): + def __call__(self, data: Data) -> Data: row, col = data.edge_index N = data.num_nodes diff --git a/torch_geometric/transforms/normalize_features.py b/torch_geometric/transforms/normalize_features.py index 3372442dd8ae..640b8db8ac2e 100644 --- a/torch_geometric/transforms/normalize_features.py +++ b/torch_geometric/transforms/normalize_features.py @@ -17,7 +17,8 @@ class NormalizeFeatures(BaseTransform): def __init__(self, attrs: List[str] = ["x"]): self.attrs = attrs - def __call__(self, data: Union[Data, HeteroData]): + def __call__(self, data: Union[Data, + HeteroData]) -> Union[Data, HeteroData]: for store in data.stores: for key, value in store.items(*self.attrs): value = value - value.min() diff --git a/torch_geometric/transforms/normalize_rotation.py b/torch_geometric/transforms/normalize_rotation.py index b0f3e8f687c9..bb4b36f6c95f 100644 --- a/torch_geometric/transforms/normalize_rotation.py +++ b/torch_geometric/transforms/normalize_rotation.py @@ -1,6 +1,7 @@ import torch import torch.nn.functional as F +from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @@ -23,7 +24,7 @@ def __init__(self, max_points: int = -1, sort: bool = False): self.max_points = max_points self.sort = sort - def __call__(self, data): + def __call__(self, data: Data) -> Data: pos = data.pos if self.max_points > 0 and pos.size(0) > self.max_points: diff --git a/torch_geometric/transforms/normalize_scale.py b/torch_geometric/transforms/normalize_scale.py index d13e74662c71..165f8198ba03 100644 --- a/torch_geometric/transforms/normalize_scale.py +++ b/torch_geometric/transforms/normalize_scale.py @@ -1,3 +1,4 @@ +from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform, Center @@ -10,7 +11,7 @@ class NormalizeScale(BaseTransform): def __init__(self): self.center = Center() - def __call__(self, data): + def __call__(self, data: Data) -> Data: data = self.center(data) scale = (1 / data.pos.abs().max()) * 0.999999 diff --git a/torch_geometric/transforms/radius_graph.py b/torch_geometric/transforms/radius_graph.py index 43d9bf617e84..f2418567dc3a 100644 --- a/torch_geometric/transforms/radius_graph.py +++ b/torch_geometric/transforms/radius_graph.py @@ -1,4 +1,5 @@ import torch_geometric +from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @@ -36,7 +37,7 @@ def __init__( self.flow = flow self.num_workers = num_workers - def __call__(self, data): + def __call__(self, data: Data) -> Data: data.edge_attr = None batch = data.batch if 'batch' in data else None diff --git a/torch_geometric/transforms/random_link_split.py b/torch_geometric/transforms/random_link_split.py index ad23d160ee2b..3bc3f236bdcc 100644 --- a/torch_geometric/transforms/random_link_split.py +++ b/torch_geometric/transforms/random_link_split.py @@ -116,7 +116,8 @@ def __init__( self.edge_types = edge_types self.rev_edge_types = rev_edge_types - def __call__(self, data: Union[Data, HeteroData]): + def __call__(self, data: Union[Data, + HeteroData]) -> Union[Data, HeteroData]: edge_types = self.edge_types rev_edge_types = self.rev_edge_types @@ -246,8 +247,13 @@ def __call__(self, data: Union[Data, HeteroData]): return train_data, val_data, test_data - def _split(self, store: EdgeStorage, index: Tensor, is_undirected: bool, - rev_edge_type: EdgeType): + def _split( + self, + store: EdgeStorage, + index: Tensor, + is_undirected: bool, + rev_edge_type: EdgeType, + ) -> EdgeStorage: for key, value in store.items(): if key == 'edge_index': @@ -276,8 +282,13 @@ def _split(self, store: EdgeStorage, index: Tensor, is_undirected: bool, return store - def _create_label(self, store: EdgeStorage, index: Tensor, - neg_edge_index: Tensor, out: EdgeStorage): + def _create_label( + self, + store: EdgeStorage, + index: Tensor, + neg_edge_index: Tensor, + out: EdgeStorage, + ) -> EdgeStorage: edge_index = store.edge_index[:, index] diff --git a/torch_geometric/transforms/random_node_split.py b/torch_geometric/transforms/random_node_split.py index 7d1ffa281c0a..93bcb377b44e 100644 --- a/torch_geometric/transforms/random_node_split.py +++ b/torch_geometric/transforms/random_node_split.py @@ -69,7 +69,8 @@ def __init__( self.num_test = num_test self.key = key - def __call__(self, data: Union[Data, HeteroData]): + def __call__(self, data: Union[Data, + HeteroData]) -> Union[Data, HeteroData]: for store in data.node_stores: if self.key is not None and not hasattr(store, self.key): continue diff --git a/torch_geometric/transforms/remove_isolated_nodes.py b/torch_geometric/transforms/remove_isolated_nodes.py index bda19f68b8e7..444363aba763 100644 --- a/torch_geometric/transforms/remove_isolated_nodes.py +++ b/torch_geometric/transforms/remove_isolated_nodes.py @@ -12,7 +12,8 @@ class RemoveIsolatedNodes(BaseTransform): r"""Removes isolated nodes from the graph (functional name: :obj:`remove_isolated_nodes`).""" - def __call__(self, data: Union[Data, HeteroData]): + def __call__(self, data: Union[Data, + HeteroData]) -> Union[Data, HeteroData]: # Gather all nodes that occur in at least one edge (across all types): n_id_dict = defaultdict(list) for store in data.edge_stores: diff --git a/torch_geometric/transforms/to_device.py b/torch_geometric/transforms/to_device.py index 9492341df4b2..d05b86d0540c 100644 --- a/torch_geometric/transforms/to_device.py +++ b/torch_geometric/transforms/to_device.py @@ -29,7 +29,8 @@ def __init__( self.attrs = attrs or [] self.non_blocking = non_blocking - def __call__(self, data: Union[Data, HeteroData]): + def __call__(self, data: Union[Data, + HeteroData]) -> Union[Data, HeteroData]: return data.to(self.device, *self.attrs, non_blocking=self.non_blocking) diff --git a/torch_geometric/transforms/to_sparse_tensor.py b/torch_geometric/transforms/to_sparse_tensor.py index 7abe38cd7877..b629fd728c46 100644 --- a/torch_geometric/transforms/to_sparse_tensor.py +++ b/torch_geometric/transforms/to_sparse_tensor.py @@ -38,7 +38,8 @@ def __init__(self, attr: Optional[str] = 'edge_weight', self.remove_edge_index = remove_edge_index self.fill_cache = fill_cache - def __call__(self, data: Union[Data, HeteroData]): + def __call__(self, data: Union[Data, + HeteroData]) -> Union[Data, HeteroData]: for store in data.edge_stores: if 'edge_index' not in store: continue @@ -75,6 +76,3 @@ def __call__(self, data: Union[Data, HeteroData]): store.adj_t.storage.csr2csc() return data - - def __repr__(self) -> str: - return f'{self.__class__.__name__}()' diff --git a/torch_geometric/transforms/to_undirected.py b/torch_geometric/transforms/to_undirected.py index c8fea3b2198d..61a6b5749e76 100644 --- a/torch_geometric/transforms/to_undirected.py +++ b/torch_geometric/transforms/to_undirected.py @@ -34,7 +34,8 @@ def __init__(self, reduce: str = "add", merge: bool = True): self.reduce = reduce self.merge = merge - def __call__(self, data: Union[Data, HeteroData]): + def __call__(self, data: Union[Data, + HeteroData]) -> Union[Data, HeteroData]: for store in data.edge_stores: if 'edge_index' not in store: continue @@ -74,6 +75,3 @@ def __call__(self, data: Union[Data, HeteroData]): store[key] = value return data - - def __repr__(self) -> str: - return f'{self.__class__.__name__}()' diff --git a/torch_geometric/transforms/two_hop.py b/torch_geometric/transforms/two_hop.py index b7a32c5e7348..86fdf07f2f9c 100644 --- a/torch_geometric/transforms/two_hop.py +++ b/torch_geometric/transforms/two_hop.py @@ -1,6 +1,7 @@ import torch from torch_sparse import coalesce, spspmm +from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import remove_self_loops @@ -10,7 +11,7 @@ class TwoHop(BaseTransform): r"""Adds the two hop edges to the edge indices (functional name: :obj:`two_hop`).""" - def __call__(self, data): + def __call__(self, data: Data) -> Data: edge_index, edge_attr = data.edge_index, data.edge_attr N = data.num_nodes From b58aade8be4b6ae0b18fb75be279d508de1894e0 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Mon, 17 Oct 2022 23:03:13 +0800 Subject: [PATCH 2/5] typehints --- torch_geometric/transforms/add_self_loops.py | 6 ++++-- torch_geometric/transforms/center.py | 6 ++++-- torch_geometric/transforms/linear_transformation.py | 6 ++++-- torch_geometric/transforms/normalize_features.py | 9 ++++----- torch_geometric/transforms/random_link_split.py | 6 ++++-- torch_geometric/transforms/random_node_split.py | 6 ++++-- torch_geometric/transforms/remove_isolated_nodes.py | 6 ++++-- torch_geometric/transforms/to_device.py | 6 ++++-- torch_geometric/transforms/to_sparse_tensor.py | 6 ++++-- torch_geometric/transforms/to_undirected.py | 6 ++++-- 10 files changed, 40 insertions(+), 23 deletions(-) diff --git a/torch_geometric/transforms/add_self_loops.py b/torch_geometric/transforms/add_self_loops.py index 58bfd145a2dd..79dcbb17b52c 100644 --- a/torch_geometric/transforms/add_self_loops.py +++ b/torch_geometric/transforms/add_self_loops.py @@ -32,8 +32,10 @@ def __init__(self, attr: Optional[str] = 'edge_weight', self.attr = attr self.fill_value = fill_value - def __call__(self, data: Union[Data, - HeteroData]) -> Union[Data, HeteroData]: + def __call__( + self, + data: Union[Data, HeteroData], + ) -> Union[Data, HeteroData]: for store in data.edge_stores: if store.is_bipartite() or 'edge_index' not in store: continue diff --git a/torch_geometric/transforms/center.py b/torch_geometric/transforms/center.py index 9055ea6311d6..9f73034b6365 100644 --- a/torch_geometric/transforms/center.py +++ b/torch_geometric/transforms/center.py @@ -9,8 +9,10 @@ class Center(BaseTransform): r"""Centers node positions :obj:`pos` around the origin (functional name: :obj:`center`).""" - def __call__(self, data: Union[Data, - HeteroData]) -> Union[Data, HeteroData]: + def __call__( + self, + data: Union[Data, HeteroData], + ) -> Union[Data, HeteroData]: for store in data.node_stores: if hasattr(store, 'pos'): store.pos = store.pos - store.pos.mean(dim=-2, keepdim=True) diff --git a/torch_geometric/transforms/linear_transformation.py b/torch_geometric/transforms/linear_transformation.py index 796746dd5061..79335514c4b2 100644 --- a/torch_geometric/transforms/linear_transformation.py +++ b/torch_geometric/transforms/linear_transformation.py @@ -29,8 +29,10 @@ def __init__(self, matrix: Tensor): # We do this to enable post-multiplication in `__call__`. self.matrix = matrix.t() - def __call__(self, data: Union[Data, - HeteroData]) -> Union[Data, HeteroData]: + def __call__( + self, + data: Union[Data, HeteroData], + ) -> Union[Data, HeteroData]: for store in data.node_stores: if not hasattr(store, 'pos'): continue diff --git a/torch_geometric/transforms/normalize_features.py b/torch_geometric/transforms/normalize_features.py index 640b8db8ac2e..5b7a0fa8bb68 100644 --- a/torch_geometric/transforms/normalize_features.py +++ b/torch_geometric/transforms/normalize_features.py @@ -17,14 +17,13 @@ class NormalizeFeatures(BaseTransform): def __init__(self, attrs: List[str] = ["x"]): self.attrs = attrs - def __call__(self, data: Union[Data, - HeteroData]) -> Union[Data, HeteroData]: + def __call__( + self, + data: Union[Data, HeteroData], + ) -> Union[Data, HeteroData]: for store in data.stores: for key, value in store.items(*self.attrs): value = value - value.min() value.div_(value.sum(dim=-1, keepdim=True).clamp_(min=1.)) store[key] = value return data - - def __repr__(self) -> str: - return f'{self.__class__.__name__}()' diff --git a/torch_geometric/transforms/random_link_split.py b/torch_geometric/transforms/random_link_split.py index 3bc3f236bdcc..85c059d13710 100644 --- a/torch_geometric/transforms/random_link_split.py +++ b/torch_geometric/transforms/random_link_split.py @@ -116,8 +116,10 @@ def __init__( self.edge_types = edge_types self.rev_edge_types = rev_edge_types - def __call__(self, data: Union[Data, - HeteroData]) -> Union[Data, HeteroData]: + def __call__( + self, + data: Union[Data, HeteroData], + ) -> Union[Data, HeteroData]: edge_types = self.edge_types rev_edge_types = self.rev_edge_types diff --git a/torch_geometric/transforms/random_node_split.py b/torch_geometric/transforms/random_node_split.py index 93bcb377b44e..2c9e05fc64dc 100644 --- a/torch_geometric/transforms/random_node_split.py +++ b/torch_geometric/transforms/random_node_split.py @@ -69,8 +69,10 @@ def __init__( self.num_test = num_test self.key = key - def __call__(self, data: Union[Data, - HeteroData]) -> Union[Data, HeteroData]: + def __call__( + self, + data: Union[Data, HeteroData], + ) -> Union[Data, HeteroData]: for store in data.node_stores: if self.key is not None and not hasattr(store, self.key): continue diff --git a/torch_geometric/transforms/remove_isolated_nodes.py b/torch_geometric/transforms/remove_isolated_nodes.py index 444363aba763..2bf51b64254e 100644 --- a/torch_geometric/transforms/remove_isolated_nodes.py +++ b/torch_geometric/transforms/remove_isolated_nodes.py @@ -12,8 +12,10 @@ class RemoveIsolatedNodes(BaseTransform): r"""Removes isolated nodes from the graph (functional name: :obj:`remove_isolated_nodes`).""" - def __call__(self, data: Union[Data, - HeteroData]) -> Union[Data, HeteroData]: + def __call__( + self, + data: Union[Data, HeteroData], + ) -> Union[Data, HeteroData]: # Gather all nodes that occur in at least one edge (across all types): n_id_dict = defaultdict(list) for store in data.edge_stores: diff --git a/torch_geometric/transforms/to_device.py b/torch_geometric/transforms/to_device.py index d05b86d0540c..f1230bea2715 100644 --- a/torch_geometric/transforms/to_device.py +++ b/torch_geometric/transforms/to_device.py @@ -29,8 +29,10 @@ def __init__( self.attrs = attrs or [] self.non_blocking = non_blocking - def __call__(self, data: Union[Data, - HeteroData]) -> Union[Data, HeteroData]: + def __call__( + self, + data: Union[Data, HeteroData], + ) -> Union[Data, HeteroData]: return data.to(self.device, *self.attrs, non_blocking=self.non_blocking) diff --git a/torch_geometric/transforms/to_sparse_tensor.py b/torch_geometric/transforms/to_sparse_tensor.py index b629fd728c46..d4b8f1df3d7b 100644 --- a/torch_geometric/transforms/to_sparse_tensor.py +++ b/torch_geometric/transforms/to_sparse_tensor.py @@ -38,8 +38,10 @@ def __init__(self, attr: Optional[str] = 'edge_weight', self.remove_edge_index = remove_edge_index self.fill_cache = fill_cache - def __call__(self, data: Union[Data, - HeteroData]) -> Union[Data, HeteroData]: + def __call__( + self, + data: Union[Data, HeteroData], + ) -> Union[Data, HeteroData]: for store in data.edge_stores: if 'edge_index' not in store: continue diff --git a/torch_geometric/transforms/to_undirected.py b/torch_geometric/transforms/to_undirected.py index 61a6b5749e76..4705e239ef18 100644 --- a/torch_geometric/transforms/to_undirected.py +++ b/torch_geometric/transforms/to_undirected.py @@ -34,8 +34,10 @@ def __init__(self, reduce: str = "add", merge: bool = True): self.reduce = reduce self.merge = merge - def __call__(self, data: Union[Data, - HeteroData]) -> Union[Data, HeteroData]: + def __call__( + self, + data: Union[Data, HeteroData], + ) -> Union[Data, HeteroData]: for store in data.edge_stores: if 'edge_index' not in store: continue From 27c4dd7d85d0c0ce68808e468491ce75ded7f96d Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Mon, 17 Oct 2022 23:08:21 +0800 Subject: [PATCH 3/5] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 24b8f8df2259..cf0c58d9f846 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,7 +45,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641)) - Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642)) - Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610)) -- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752)) +- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), , [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753)) - Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601)) - Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614)) - Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602)) From 80318e8f3bc3f5d3a6139f763cc97cc9bac9ec26 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Mon, 17 Oct 2022 23:19:30 +0800 Subject: [PATCH 4/5] typehints --- torch_geometric/transforms/compose.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch_geometric/transforms/compose.py b/torch_geometric/transforms/compose.py index 665baa35c15b..168ac00565b3 100644 --- a/torch_geometric/transforms/compose.py +++ b/torch_geometric/transforms/compose.py @@ -13,8 +13,10 @@ class Compose(BaseTransform): def __init__(self, transforms: List[Callable]): self.transforms = transforms - def __call__(self, data: Union[Data, - HeteroData]) -> Union[Data, HeteroData]: + def __call__( + self, + data: Union[Data, HeteroData], + ) -> Union[Data, HeteroData]: for transform in self.transforms: if isinstance(data, (list, tuple)): data = [transform(d) for d in data] From 9b76686f701d726fe9cf45245e921128f8c20473 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Mon, 17 Oct 2022 19:02:44 +0200 Subject: [PATCH 5/5] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cf0c58d9f846..d351ac27a222 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,7 +45,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641)) - Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642)) - Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610)) -- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), , [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753)) +- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753)) - Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601)) - Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614)) - Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))