[TypeHints] Node2Vec #5669

Merged (8 commits) on Oct 13, 2022
Changes from all commits
CHANGELOG.md (2 changes: 1 addition & 1 deletion)
@@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
-- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668))
+- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
test/nn/models/test_node2vec.py (17 changes: 11 additions & 6 deletions)
@@ -1,7 +1,7 @@
import torch

from torch_geometric.nn import Node2Vec
-from torch_geometric.testing import withPackage
+from torch_geometric.testing import is_full_test, withPackage


@withPackage('torch_cluster')
@@ -12,14 +12,19 @@ def test_node2vec():
context_size=2)
assert model.__repr__() == 'Node2Vec(3, 16)'

-    z = model(torch.arange(3))
-    assert z.size() == (3, 16)
+    assert model(torch.arange(3)).size() == (3, 16)

pos_rw, neg_rw = model.sample(torch.arange(3))

-    loss = model.loss(pos_rw, neg_rw)
-    assert 0 <= loss.item()
+    assert float(model.loss(pos_rw, neg_rw)) >= 0

acc = model.test(torch.ones(20, 16), torch.randint(10, (20, )),
torch.ones(20, 16), torch.randint(10, (20, )))
assert 0 <= acc and acc <= 1

+    if is_full_test():
+        jit = torch.jit.export(model)

+        assert jit(torch.arange(3)).size() == (3, 16)

+        pos_rw, neg_rw = jit.sample(torch.arange(3))
+        assert float(jit.loss(pos_rw, neg_rw)) >= 0
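
For context, here is a minimal usage sketch of the API surface this test exercises, assuming the optional `torch_cluster` dependency (required by the `@withPackage('torch_cluster')` marker) is installed; the toy `edge_index` below is illustrative and not the fixture used in the test:

```python
import torch

from torch_geometric.nn import Node2Vec

# Illustrative 3-node path graph (not the graph used in the test above).
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])

# The constructor arguments now carry explicit type hints
# (edge_index: Tensor, embedding_dim: int, p/q: float, ...).
model = Node2Vec(edge_index, embedding_dim=16, walk_length=3, context_size=2)

z = model(torch.arange(3))                      # forward(batch) -> Tensor of shape [3, 16]
pos_rw, neg_rw = model.sample(torch.arange(3))  # sample(batch) -> Tuple[Tensor, Tensor]
loss = model.loss(pos_rw, neg_rw)               # loss(pos_rw, neg_rw) -> non-negative scalar Tensor
```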
torch_geometric/nn/models/node2vec.py (49 changes: 36 additions & 13 deletions)
@@ -1,8 +1,12 @@
+from typing import Optional, Tuple

import torch
+from torch import Tensor
from torch.nn import Embedding
from torch.utils.data import DataLoader
from torch_sparse import SparseTensor

+from torch_geometric.typing import OptTensor
from torch_geometric.utils.num_nodes import maybe_num_nodes

try:
@@ -46,9 +50,19 @@ class Node2Vec(torch.nn.Module):
sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the
weight matrix will be sparse. (default: :obj:`False`)
"""
-    def __init__(self, edge_index, embedding_dim, walk_length, context_size,
-                 walks_per_node=1, p=1, q=1, num_negative_samples=1,
-                 num_nodes=None, sparse=False):
+    def __init__(
+        self,
+        edge_index: Tensor,
+        embedding_dim: int,
+        walk_length: int,
+        context_size: int,
+        walks_per_node: int = 1,
+        p: float = 1.0,
+        q: float = 1.0,
+        num_negative_samples: int = 1,
+        num_nodes: Optional[int] = None,
+        sparse: bool = False,
+    ):
super().__init__()

if random_walk is None:
@@ -76,20 +90,20 @@ def __init__(self, edge_index, embedding_dim, walk_length, context_size,
def reset_parameters(self):
self.embedding.reset_parameters()

-    def forward(self, batch=None):
+    def forward(self, batch: OptTensor = None) -> Tensor:
"""Returns the embeddings for the nodes in :obj:`batch`."""
emb = self.embedding.weight
return emb if batch is None else emb.index_select(0, batch)

-    def loader(self, **kwargs):
+    def loader(self, **kwargs) -> DataLoader:
return DataLoader(range(self.adj.sparse_size(0)),
collate_fn=self.sample, **kwargs)

-    def pos_sample(self, batch):
+    def pos_sample(self, batch: Tensor) -> Tensor:
batch = batch.repeat(self.walks_per_node)
rowptr, col, _ = self.adj.csr()
rw = random_walk(rowptr, col, batch, self.walk_length, self.p, self.q)
-        if not isinstance(rw, torch.Tensor):
+        if not isinstance(rw, Tensor):
rw = rw[0]

walks = []
@@ -98,7 +112,7 @@ def pos_sample(self, batch):
walks.append(rw[:, j:j + self.context_size])
return torch.cat(walks, dim=0)

-    def neg_sample(self, batch):
+    def neg_sample(self, batch: Tensor) -> Tensor:
batch = batch.repeat(self.walks_per_node * self.num_negative_samples)

rw = torch.randint(self.adj.sparse_size(0),
@@ -111,12 +125,12 @@ def neg_sample(self, batch):
walks.append(rw[:, j:j + self.context_size])
return torch.cat(walks, dim=0)

-    def sample(self, batch):
-        if not isinstance(batch, torch.Tensor):
+    def sample(self, batch: Tensor) -> Tuple[Tensor, Tensor]:
+        if not isinstance(batch, Tensor):
batch = torch.tensor(batch)
return self.pos_sample(batch), self.neg_sample(batch)

-    def loss(self, pos_rw, neg_rw):
+    def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> Tensor:
r"""Computes the loss given positive and negative random walks."""

# Positive loss.
@@ -143,8 +157,17 @@ def loss(self, pos_rw, neg_rw):

return pos_loss + neg_loss

-    def test(self, train_z, train_y, test_z, test_y, solver='lbfgs',
-             multi_class='auto', *args, **kwargs):
+    def test(
+        self,
+        train_z: Tensor,
+        train_y: Tensor,
+        test_z: Tensor,
+        test_y: Tensor,
+        solver: str = 'lbfgs',
+        multi_class: str = 'auto',
+        *args,
+        **kwargs,
+    ) -> float:
r"""Evaluates latent space quality via a logistic regression downstream
task."""
from sklearn.linear_model import LogisticRegression
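
To round out the picture, here is a hedged end-to-end sketch of the annotated API (training via `loader()`/`loss()` and evaluation via `test()`); the ring graph, labels, optimizer choice, and `max_iter` value are illustrative assumptions, and the example relies on the optional `torch_cluster` and scikit-learn dependencies:

```python
import torch

from torch_geometric.nn import Node2Vec

# Synthetic example data: a bidirectional ring graph on 100 nodes with random labels.
num_nodes = 100
row = torch.arange(num_nodes)
col = (row + 1) % num_nodes
edge_index = torch.stack([torch.cat([row, col]), torch.cat([col, row])], dim=0)
y = torch.randint(5, (num_nodes, ))

model = Node2Vec(edge_index, embedding_dim=16, walk_length=10, context_size=5,
                 walks_per_node=2, num_negative_samples=1,
                 num_nodes=num_nodes, sparse=True)

# loader() wraps sample() in a DataLoader that yields (pos_rw, neg_rw) batches.
loader = model.loader(batch_size=32, shuffle=True)
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

for pos_rw, neg_rw in loader:
    optimizer.zero_grad()
    loss = model.loss(pos_rw, neg_rw)  # scalar Tensor
    loss.backward()
    optimizer.step()

# test(...) -> float: accuracy of a scikit-learn LogisticRegression fitted on
# the learned embeddings; extra kwargs such as max_iter are forwarded to it.
z = model().detach()
acc = model.test(z[:80], y[:80], z[80:], y[80:], max_iter=150)
```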