Implement neighbor sampler #76

Merged · 9 commits · Aug 26, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added support for PyTorch 1.12 ([#57](https://github.com/pyg-team/pyg-lib/pull/57), [#58](https://github.com/pyg-team/pyg-lib/pull/58))
- Added `grouped_matmul` and `segment_matmul` CUDA implementations via `cutlass` ([#51](https://github.com/pyg-team/pyg-lib/pull/51), [#56](https://github.com/pyg-team/pyg-lib/pull/56), [#61](https://github.com/pyg-team/pyg-lib/pull/61), [#64](https://github.com/pyg-team/pyg-lib/pull/64), [#69](https://github.com/pyg-team/pyg-lib/pull/69))
- Added `pyg::sampler::neighbor_sample` interface ([#54](https://github.com/pyg-team/pyg-lib/pull/54))
- Added `pyg::sampler::neighbor_sample` implementation ([#54](https://github.com/pyg-team/pyg-lib/pull/54), [#76](https://github.com/pyg-team/pyg-lib/pull/76))
- Added `pyg::sampler::Mapper` utility for mapping global to local node indices ([#45](https://github.com/pyg-team/pyg-lib/pull/45))
- Added benchmark script ([#45](https://github.com/pyg-team/pyg-lib/pull/45))
- Added download script for benchmark data ([#44](https://github.com/pyg-team/pyg-lib/pull/44))
23 changes: 17 additions & 6 deletions pyg_lib/csrc/sampler/cpu/mapper.h
@@ -23,14 +23,25 @@ class Mapper {
to_local_vec = std::vector<scalar_t>(num_nodes, -1);
}

void fill(const scalar_t* nodes_data, const scalar_t size) {
std::pair<scalar_t, bool> insert(const scalar_t& node) {
std::pair<scalar_t, bool> res;
if (use_vec) {
for (scalar_t i = 0; i < size; ++i)
to_local_vec[nodes_data[i]] = i;
auto old = to_local_vec[node];
res = std::pair<scalar_t, bool>(old == -1 ? curr : old, old == -1);
if (res.second)
to_local_vec[node] = curr;
} else {
for (scalar_t i = 0; i < size; ++i)
to_local_map.insert({nodes_data[i], i});
auto out = to_local_map.insert({node, curr});
res = std::pair<scalar_t, bool>(out.first->second, out.second);
}
if (res.second)
curr++;
return res;
}

void fill(const scalar_t* nodes_data, const scalar_t size) {
for (size_t i = 0; i < size; ++i)
insert(nodes_data[i]);
}

void fill(const at::Tensor& nodes) {
@@ -54,7 +65,7 @@
}

private:
scalar_t num_nodes, num_entries;
scalar_t num_nodes, num_entries, curr = 0;

bool use_vec;
std::vector<scalar_t> to_local_vec;
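For reference, here is a minimal Python sketch (not part of the PR) of the semantics the new `Mapper::insert` provides: it maps a global node index to a local index and reports whether the node was seen for the first time. The `PyMapper` name is purely illustrative.

```python
class PyMapper:
    """Illustrative Python model of pyg::sampler::Mapper::insert (not part of the PR)."""
    def __init__(self):
        self.to_local = {}  # global node id -> local index
        self.curr = 0       # next local index to hand out

    def insert(self, node):
        # Returns (local_index, inserted): `inserted` is True only the first
        # time `node` is seen, mirroring the pair returned by Mapper::insert.
        if node in self.to_local:
            return self.to_local[node], False
        self.to_local[node] = self.curr
        self.curr += 1
        return self.curr - 1, True


mapper = PyMapper()
assert mapper.insert(10) == (0, True)
assert mapper.insert(42) == (1, True)
assert mapper.insert(10) == (0, False)  # already mapped to local index 0
```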
138 changes: 135 additions & 3 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
@@ -1,21 +1,153 @@
#include <ATen/ATen.h>
#include <torch/library.h>

#include "parallel_hashmap/phmap.h"

#include "pyg_lib/csrc/random/cpu/rand_engine.h"
#include "pyg_lib/csrc/sampler/cpu/mapper.h"
#include "pyg_lib/csrc/utils/cpu/convert.h"

namespace pyg {
namespace sampler {

namespace {

template <bool replace, bool directed>
std::tuple<at::Tensor, at::Tensor, at::Tensor, c10::optional<at::Tensor>>
sample(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors) {
return AT_DISPATCH_INTEGRAL_TYPES(seed.scalar_type(), "sample_kernel", [&] {
const auto num_nodes = rowptr.size(0) - 1;

const auto* rowptr_data = rowptr.data_ptr<scalar_t>();
const auto* col_data = col.data_ptr<scalar_t>();
const auto* seed_data = seed.data_ptr<scalar_t>();

pyg::random::RandintEngine<scalar_t> eng;

// Initialize some data structures for the sampling process:
std::vector<scalar_t> samples, rows, cols, edges;
// TODO (matthias) Approximate number of sampled entries for mapper.
auto mapper = pyg::sampler::Mapper<scalar_t>(num_nodes, seed.size(0));

for (size_t i = 0; i < seed.numel(); i++) {
samples.push_back(seed_data[i]);
mapper.insert(seed_data[i]);
}

size_t begin = 0, end = samples.size();
for (size_t ell = 0; ell < num_neighbors.size(); ++ell) {
const auto num_samples = num_neighbors[ell];

for (size_t i = begin; i < end; i++) {
const auto v = samples[i];
const auto row_start = rowptr_data[v];
const auto row_end = rowptr_data[v + 1];
const auto row_count = row_end - row_start;

if (row_count == 0)
continue;

if ((num_samples < 0) || (!replace && (num_samples >= row_count))) {
for (scalar_t e = row_start; e < row_end; ++e) {
const auto w = col_data[e];
const auto res = mapper.insert(w);
if (res.second)
samples.push_back(w);
if (directed) {
rows.push_back(i);
cols.push_back(res.first);
edges.push_back(e);
}
}
} else if (replace) {
for (size_t j = 0; j < num_samples; ++j) {
const scalar_t e = eng(row_start, row_end);
const auto w = col_data[e];
const auto res = mapper.insert(w);
if (res.second)
samples.push_back(w);
if (directed) {
rows.push_back(i);
cols.push_back(res.first);
edges.push_back(e);
}
}
} else {
std::unordered_set<scalar_t> rnd_indices;
for (scalar_t j = row_count - num_samples; j < row_count; ++j) {
scalar_t rnd = eng(0, j);
if (!rnd_indices.insert(rnd).second) {
rnd = j;
rnd_indices.insert(j);
}
const scalar_t e = row_start + rnd;
const auto w = col_data[e];
const auto res = mapper.insert(w);
if (res.second)
samples.push_back(w);
if (directed) {
rows.push_back(i);
cols.push_back(res.first);
edges.push_back(e);
}
}
}
}
begin = end, end = samples.size();
}

if (!directed) {
// TODO (matthias) Use pyg::sampler::subgraph() for this.
for (size_t i = 0; i < samples.size(); ++i) {
const auto v = samples[i];
const auto row_start = rowptr_data[v];
const auto row_end = rowptr_data[v + 1];
for (scalar_t e = row_start; e < row_end; ++e) {
const auto local_node = mapper.map(col_data[e]);
if (local_node != -1) {
rows.push_back(i);
cols.push_back(local_node);
edges.push_back(e);
}
}
}
}

return std::make_tuple(pyg::utils::from_vector<scalar_t>(rows),
pyg::utils::from_vector<scalar_t>(cols),
pyg::utils::from_vector<scalar_t>(samples),
pyg::utils::from_vector<scalar_t>(edges));
});
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, c10::optional<at::Tensor>>
neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t> num_neighbors,
const std::vector<int64_t>& num_neighbors,
bool replace,
bool directed,
bool isolated,
bool disjoint,
bool return_edge_id) {
return std::make_tuple(rowptr, col, seed, at::nullopt);
if (disjoint) {
AT_ERROR("Disjoint subgraphs are currently not supported");
}
if (!return_edge_id) {
AT_ERROR("The indices of edges of the original graph must be returned");
}

if (replace && directed) {
return sample<true, true>(rowptr, col, seed, num_neighbors);
} else if (replace && !directed) {
return sample<true, false>(rowptr, col, seed, num_neighbors);
} else if (!replace && directed) {
return sample<false, true>(rowptr, col, seed, num_neighbors);
} else {
return sample<false, false>(rowptr, col, seed, num_neighbors);
}
}

} // namespace
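The `!replace` branch above selects `num_samples` distinct neighbor indices with a Robert Floyd-style trick over `[0, row_count)`. Below is a minimal Python sketch of that index-selection step; it is illustrative only and assumes `0 < num_samples < row_count` (which the surrounding branches guarantee) and that `eng(0, j)` draws from the half-open range `[0, j)`.

```python
import random


def floyd_sample_indices(row_count, num_samples, rng=random):
    # Draw `num_samples` distinct indices from [0, row_count) without
    # materializing a permutation, mirroring the `!replace` branch above.
    # Assumes 0 < num_samples < row_count (guaranteed by the caller).
    rnd_indices = set()
    for j in range(row_count - num_samples, row_count):
        rnd = rng.randrange(j)  # analogous to eng(0, j), assumed to be [0, j)
        if rnd in rnd_indices:  # collision: fall back to j itself,
            rnd = j             # which cannot be in the set yet
        rnd_indices.add(rnd)
    return rnd_indices


picked = floyd_sample_indices(row_count=10, num_samples=3)
assert len(picked) == 3 and all(0 <= i < 10 for i in picked)
```

In the kernel, each selected index is then offset by `row_start` to obtain the edge position `e` used for `col_data[e]`.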
17 changes: 12 additions & 5 deletions pyg_lib/csrc/sampler/neighbor.cpp
@@ -10,23 +10,30 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, c10::optional<at::Tensor>>
neighbor_sample(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t> num_neighbors,
const std::vector<int64_t>& num_neighbors,
bool replace,
bool directed,
bool isolated,
bool disjoint,
bool return_edge_id) {
// TODO (matthias) Add TensorArg definitions.
at::TensorArg rowptr_t{rowptr, "rowptr", 1};
at::TensorArg col_t{col, "col", 1};
at::TensorArg seed_t{seed, "seed", 1};

at::CheckedFrom c = "neighbor_sample";
at::checkAllDefined(c, {rowptr_t, col_t, seed_t});
at::checkAllSameType(c, {rowptr_t, col_t, seed_t});

static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::neighbor_sample", "")
.typed<decltype(neighbor_sample)>();
return op.call(rowptr, col, seed, num_neighbors, replace, directed, isolated,
return op.call(rowptr, col, seed, num_neighbors, replace, directed, disjoint,
return_edge_id);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] "
"num_neighbors, bool replace, bool directed, bool isolated, bool "
"num_neighbors, bool replace, bool directed, bool disjoint, bool "
"return_edge_id) -> (Tensor, Tensor, Tensor, Tensor?)"));
}

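Once the schema above is registered, the operator can also be invoked directly through the dispatcher from Python, with the positional arguments in exactly the order the schema declares. A small sketch (assuming a built and importable `pyg_lib`, whose import registers the op):

```python
import torch

import pyg_lib  # noqa: F401  # importing registers torch.ops.pyg.neighbor_sample

# Tiny CSR graph with 3 nodes and edges 0->1, 1->0, 1->2:
rowptr = torch.tensor([0, 1, 3, 3])
col = torch.tensor([1, 0, 2])
seed = torch.tensor([0])

# (rowptr, col, seed, num_neighbors, replace, directed, disjoint, return_edge_id)
row, out_col, node_id, edge_id = torch.ops.pyg.neighbor_sample(
    rowptr, col, seed, [2], False, True, False, True)

assert node_id.tolist() == [0, 1]  # seed first, then its sampled neighbor
```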
6 changes: 3 additions & 3 deletions pyg_lib/csrc/sampler/neighbor.h
@@ -6,18 +6,18 @@
namespace pyg {
namespace sampler {

// Recursively samples neighbors from all nodes indices in `seed`
// Recursively samples neighbors from all node indices in `seed`
// in the graph given by `(rowptr, col)`.
// Returns: (row, col, node_id, edge_id)
PYG_API
std::tuple<at::Tensor, at::Tensor, at::Tensor, c10::optional<at::Tensor>>
neighbor_sample(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t> num_neighbors,
const std::vector<int64_t>& num_neighbors,
bool replace = false,
bool directed = true,
bool isolated = true,
bool disjoint = false,
bool return_edge_id = true);

} // namespace sampler
47 changes: 44 additions & 3 deletions pyg_lib/sampler/__init__.py
@@ -1,9 +1,50 @@
from typing import Tuple, Optional
from typing import List, Tuple, Optional

import torch
from torch import Tensor


def neighbor_sample(
rowptr: Tensor,
col: Tensor,
seed: Tensor,
num_neighbors: List[int],
replace: bool = False,
directed: bool = True,
disjoint: bool = False,
return_edge_id: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor]]:
r"""Recursively samples neighbors from all node indices in :obj:`seed`
in the graph given by :obj:`(rowptr, col)`.

Args:
rowptr (torch.Tensor): Compressed source node indices.
col (torch.Tensor): Target node indices.
seed (torch.Tensor): The seed node indices.
num_neighbors (List[int]): The number of neighbors to sample for each
node in each iteration. If an entry is set to :obj:`-1`, all
neighbors will be included.
replace (bool, optional): If set to :obj:`True`, will sample with
replacement. (default: :obj:`False`)
directed (bool, optional): If set to :obj:`False`, will include all
edges between all sampled nodes. (default: :obj:`True`)
disjoint (bool, optional): If set to :obj:`True`, will create disjoint
subgraphs for every seed node. (default: :obj:`False`)
return_edge_id (bool, optional): If set to :obj:`False`, will not
return the indices of edges of the original graph.
(default: :obj:`True`)

Returns:
(torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]):
Row indices and column indices of the returned subtree/subgraph, as well as
the original node indices of all sampled nodes.
In addition, may return the indices of edges of the original graph.
"""
return torch.ops.pyg.neighbor_sample(rowptr, col, seed, num_neighbors,
replace, directed, disjoint,
return_edge_id)


def subgraph(
rowptr: Tensor,
col: Tensor,
@@ -24,8 +65,7 @@ def subgraph(
Returns:
(torch.Tensor, torch.Tensor, Optional[torch.Tensor]): Compressed source
node indices and target node indices of the induced subgraph.
In addition, may return the indices of edges of the original graph
contained in the induced subgraph.
In addition, may return the indices of edges of the original graph.
"""
return torch.ops.pyg.subgraph(rowptr, col, nodes, return_edge_id)

@@ -56,6 +96,7 @@ def random_walk(rowptr: Tensor, col: Tensor, seed: Tensor, walk_length: int,


__all__ = [
'neighbor_sample',
'subgraph',
'random_walk',
]
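A brief usage sketch of the new public wrapper (the toy graph below is hypothetical): `row` and `col` in the output are local indices into the returned node tensor, so the sampled edges can be mapped back to global node indices, and `edge_id` (when requested) holds positions into the original CSR `col`.

```python
import torch

import pyg_lib

# Hypothetical toy graph in CSR form with edges 0->1, 0->2, 1->0, 2->0:
rowptr = torch.tensor([0, 2, 3, 4])
col = torch.tensor([1, 2, 0, 0])
seed = torch.tensor([0])

row, col_local, node_id, edge_id = pyg_lib.sampler.neighbor_sample(
    rowptr, col, seed, num_neighbors=[-1])

# Map the local indices back to the original node ids:
global_edge_index = torch.stack([node_id[row], node_id[col_local]])
print(global_edge_index)  # sampled edges in terms of original node ids
print(edge_id)            # positions of those edges in the original `col`
```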
49 changes: 49 additions & 0 deletions test/sampler/test_neighbor_sample.py
@@ -0,0 +1,49 @@
import torch
from torch_sparse import SparseTensor

import pyg_lib

neighbor_sample = pyg_lib.sampler.neighbor_sample


def test_neighbor_sample():
adj = SparseTensor.from_edge_index(torch.tensor([[1], [0]]))
rowptr, col, _ = adj.csr()

# Sampling should work:
out = neighbor_sample(rowptr, col, torch.tensor([1]), [1], False, False,
False, True)
assert out[0].tolist() == [0]
assert out[1].tolist() == [1]
assert out[2].tolist() == [1, 0]

# Sampling in a non-directed way should not sample in wrong direction:
out = neighbor_sample(rowptr, col, torch.tensor([0]), [1], False, False,
False, True)
assert out[0].tolist() == []
assert out[1].tolist() == []
assert out[2].tolist() == [0]

# Sampling with more hops:
out = neighbor_sample(rowptr, col, torch.tensor([1]), [1, 1], False, False,
False, True)
assert out[0].tolist() == [0]
assert out[1].tolist() == [1]
assert out[2].tolist() == [1, 0]


def test_neighbor_sample_seed():
rowptr = torch.tensor([0, 3, 5])
col = torch.tensor([0, 1, 2, 0, 1, 0, 2])
input_nodes = torch.tensor([0, 1])

torch.manual_seed(42)
out1 = neighbor_sample(rowptr, col, input_nodes, [1], True, False, False,
True)

torch.manual_seed(42)
out2 = neighbor_sample(rowptr, col, input_nodes, [1], True, False, False,
True)

for data1, data2 in zip(out1, out2):
assert data1.tolist() == data2.tolist()
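As an additional illustrative check (not part of the PR's test suite), the returned `edge_id` can be validated against the original CSR structure: each entry is a position into `col`, so `col[edge_id]` must equal the global ids of the sampled target nodes.

```python
import torch
from torch_sparse import SparseTensor

import pyg_lib

adj = SparseTensor.from_edge_index(torch.tensor([[1], [0]]))
rowptr, col, _ = adj.csr()

row, col_local, node_id, edge_id = pyg_lib.sampler.neighbor_sample(
    rowptr, col, torch.tensor([1]), [1], False, True, False, True)

# `edge_id` indexes into the original `col`, so this must hold:
assert torch.equal(col[edge_id], node_id[col_local])
```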