Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Type Hints] transforms.Distance #5685

Merged
merged 20 commits into from
Oct 14, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
68ab956
add type hints dense_sage_conv
sachinsharma9780 Oct 12, 2022
1f8aa17
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 12, 2022
0dc28d6
adding PR number for improved type hint support
sachinsharma9780 Oct 12, 2022
6c3aa8c
Merge branch 'master' of https://github.com/sachinsharma9780/pytorch_…
sachinsharma9780 Oct 12, 2022
590b0fb
adding BoolTensor type to mask argument
sachinsharma9780 Oct 12, 2022
87e85ab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 12, 2022
1b07460
add type hints dense_gcn_conv
sachinsharma9780 Oct 12, 2022
0b049e3
Merge branch 'master' of https://github.com/sachinsharma9780/pytorch_…
sachinsharma9780 Oct 12, 2022
fd858a6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 12, 2022
c1cb37b
Merge branch 'master' into master
rusty1s Oct 13, 2022
2ab0361
update
rusty1s Oct 13, 2022
fdf9f54
Merge branch 'master' into master
rusty1s Oct 13, 2022
5ac1b73
Changelog
rusty1s Oct 13, 2022
1ee2cf8
add type hints transforms/distance
sachinsharma9780 Oct 13, 2022
0aad980
Merge branch 'master' of https://github.com/sachinsharma9780/pytorch_…
sachinsharma9780 Oct 13, 2022
ca2064f
adding type hints support in transform/distance
sachinsharma9780 Oct 13, 2022
f0c210c
Merge branch 'master' into master
rusty1s Oct 14, 2022
5f3a2b2
Update torch_geometric/transforms/distance.py
rusty1s Oct 14, 2022
ef7e505
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2022
9335739
Merge branch 'master' into master
rusty1s Oct 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down
9 changes: 5 additions & 4 deletions torch_geometric/transforms/distance.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Optional
from torch import tensor
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
import torch

from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform

from torch_geometric.data import Data

@functional_transform('distance')
class Distance(BaseTransform):
Expand All @@ -18,12 +19,12 @@ class Distance(BaseTransform):
cat (bool, optional): If set to :obj:`False`, all existing edge
attributes will be replaced. (default: :obj:`True`)
"""
def __init__(self, norm=True, max_value=None, cat=True):
def __init__(self, norm: bool = True, max_value: Optional[float] = None, cat: bool = True):
self.norm = norm
self.max = max_value
self.cat = cat

def __call__(self, data):
def __call__(self, data: Data) -> Data:
(row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr

dist = torch.norm(pos[col] - pos[row], p=2, dim=-1).view(-1, 1)
Expand Down