Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement neighbor sampler #76

Merged
merged 9 commits into from
Aug 26, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 128 additions & 2 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,21 +1,147 @@
#include <ATen/ATen.h>
#include <torch/library.h>
#include <torch/script.h>

#include "../neighbor.h"
#include "pyg_lib/csrc/sampler/cpu/mapper.h"
#include "utils.h"

#include "parallel_hashmap/phmap.h"

#ifdef _WIN32
kgajdamo marked this conversation as resolved.
Show resolved Hide resolved
#include <process.h>
#endif

namespace pyg {
namespace sampler {

namespace {

// Samples a multi-hop neighborhood around the `seed` nodes of the CSR graph
// given by `(rowptr, col)`.
//
// Template parameters:
//   replace  - draw neighbors with replacement.
//   directed - if true, only the edges traversed while sampling are emitted;
//              if false, no edge is emitted during sampling and instead every
//              edge of the input graph whose two endpoints were both sampled
//              is emitted afterwards.
//
// `num_neighbors[ell]` is the number of neighbors drawn for each frontier node
// in hop `ell`; a negative value keeps all neighbors of the node.
//
// Returns `(rows, cols, samples, edges)`:
//   samples - global node ids: the seeds first, followed by newly discovered
//             nodes in discovery order.
//   rows    - per emitted edge, the position in `samples` of the node whose
//             CSR row was read.
//   cols    - per emitted edge, the position in `samples` of the neighbor.
//   edges   - per emitted edge, its offset into `col`.
template <bool replace, bool directed>
std::tuple<at::Tensor, at::Tensor, at::Tensor, c10::optional<at::Tensor>>
sample(const at::Tensor& rowptr,
       const at::Tensor& col,
       const at::Tensor& seed,
       const std::vector<int64_t>& num_neighbors) {
  // Initialize some data structures for the sampling process:
  std::vector<int64_t> samples, rows, cols, edges;
  phmap::flat_hash_map<int64_t, int64_t> to_local_node;

  AT_DISPATCH_INTEGRAL_TYPES(seed.scalar_type(), "neighbor_kernel", [&] {
    auto* rowptr_data = rowptr.data_ptr<scalar_t>();
    auto* col_data = col.data_ptr<scalar_t>();
    auto* seed_data = seed.data_ptr<scalar_t>();

    for (int64_t i = 0; i < seed.numel(); i++) {
      const auto& v = seed_data[i];
      samples.push_back(v);
      to_local_node.insert({v, i});
    }

    // Registers the neighbor stored at `col[offset]` of the node at local
    // position `src`. A single helper is shared by all sampling strategies so
    // that they cannot disagree on the (row, col) orientation of an edge.
    const auto add_neighbor = [&](int64_t src, int64_t offset) {
      const int64_t v = col_data[offset];
      const auto res = to_local_node.insert({v, samples.size()});
      if (res.second)
        samples.push_back(v);
      if (directed) {
        rows.push_back(src);
        cols.push_back(res.first->second);
        edges.push_back(offset);
      }
    };

    int64_t begin = 0, end = samples.size();
    for (const int64_t num_samples : num_neighbors) {
      // `[begin, end)` is the frontier discovered in the previous hop.
      for (int64_t i = begin; i < end; i++) {
        // Copy by value: `samples` grows below, which would invalidate a
        // reference into it.
        const int64_t w = samples[i];
        const int64_t row_start = rowptr_data[w];
        const int64_t row_end = rowptr_data[w + 1];
        const int64_t row_count = row_end - row_start;

        if (row_count == 0)
          continue;

        if ((num_samples < 0) || (!replace && (num_samples >= row_count))) {
          // Keep every neighbor.
          for (int64_t offset = row_start; offset < row_end; offset++)
            add_neighbor(i, offset);
        } else if (replace) {
          // Independent draws; the same neighbor may be picked repeatedly.
          for (int64_t j = 0; j < num_samples; j++)
            add_neighbor(i, row_start + uniform_randint(row_count));
        } else {
          // Robert Floyd's algorithm for sampling without replacement. The
          // draw must cover the inclusive range [0, j]; `uniform_randint`
          // excludes its upper bound, hence `j + 1`. Drawing from [0, j)
          // would, e.g., never select the last neighbor when
          // `num_samples == 1`.
          std::unordered_set<int64_t> rnd_indices;
          for (int64_t j = row_count - num_samples; j < row_count; j++) {
            int64_t rnd = uniform_randint(j + 1);
            if (!rnd_indices.insert(rnd).second) {
              rnd = j;
              rnd_indices.insert(j);
            }
            add_neighbor(i, row_start + rnd);
          }
        }
      }
      begin = end, end = samples.size();
    }

    if (!directed) {
      // Emit every edge of the input graph that connects two sampled nodes.
      for (int64_t i = 0; i < static_cast<int64_t>(samples.size()); i++) {
        const int64_t w = samples[i];
        const int64_t row_start = rowptr_data[w];
        const int64_t row_end = rowptr_data[w + 1];
        for (int64_t offset = row_start; offset < row_end; offset++) {
          const auto iter = to_local_node.find(col_data[offset]);
          if (iter != to_local_node.end()) {
            cols.push_back(iter->second);
            rows.push_back(i);
            edges.push_back(offset);
          }
        }
      }
    }
  });

  return std::make_tuple(from_vector<int64_t>(rows), from_vector<int64_t>(cols),
                         from_vector<int64_t>(samples),
                         from_vector<int64_t>(edges));
}

// CPU entry point for neighbor sampling. Selects the `sample` instantiation
// matching the runtime `replace` / `directed` flags and forwards the graph
// `(rowptr, col)`, the `seed` nodes and the per-hop `num_neighbors`.
//
// Returns the `(rows, cols, samples, edges)` tuple produced by `sample`; the
// edge ids are replaced by `nullopt` when `return_edge_id` is false.
//
// NOTE(review): `isolated` is accepted for interface compatibility but is not
// consulted by this kernel.
std::tuple<at::Tensor, at::Tensor, at::Tensor, c10::optional<at::Tensor>>
neighbor_sample_kernel(const at::Tensor& rowptr,
                       const at::Tensor& col,
                       const at::Tensor& seed,
                       const std::vector<int64_t>& num_neighbors,
                       bool replace,
                       bool directed,
                       bool isolated,
                       bool return_edge_id) {
  (void)isolated;

  auto out = [&] {
    if (replace && directed) {
      return sample<true, true>(rowptr, col, seed, num_neighbors);
    } else if (replace && !directed) {
      return sample<true, false>(rowptr, col, seed, num_neighbors);
    } else if (!replace && directed) {
      return sample<false, true>(rowptr, col, seed, num_neighbors);
    } else {
      return sample<false, false>(rowptr, col, seed, num_neighbors);
    }
  }();

  // Honor the caller's request not to receive edge ids.
  if (!return_edge_id)
    std::get<3>(out) = c10::nullopt;
  return out;
}

} // namespace
Expand Down
36 changes: 36 additions & 0 deletions pyg_lib/csrc/sampler/cpu/utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#pragma once

#include <torch/torch.h>

#include "parallel_hashmap/phmap.h"

// Wraps the contents of `vec` in a 1-D tensor whose dtype corresponds to
// `scalar_t`. With `inplace == true` the returned tensor aliases the vector's
// storage (no copy is made, so `vec` must outlive the tensor); otherwise the
// data is cloned into tensor-owned memory.
template <typename scalar_t>
inline torch::Tensor from_vector(const std::vector<scalar_t>& vec,
                                 bool inplace = false) {
  const auto numel = static_cast<int64_t>(vec.size());
  // `from_blob` takes a mutable pointer even though the data is only viewed.
  auto* data = const_cast<scalar_t*>(vec.data());
  const auto view = torch::from_blob(
      data, {numel}, c10::CppTypeToScalarType<scalar_t>::value);
  if (inplace)
    return view;
  return view.clone();
}

// Converts every vector stored in `vec_dict` into a tensor via the
// single-vector `from_vector` overload and collects the results in a
// `c10::Dict` under the same keys. `inplace` is forwarded unchanged.
template <typename key_t, typename scalar_t>
inline c10::Dict<key_t, torch::Tensor> from_vector(
    const phmap::flat_hash_map<key_t, std::vector<scalar_t>>& vec_dict,
    bool inplace = false) {
  c10::Dict<key_t, torch::Tensor> tensors;
  for (auto it = vec_dict.begin(); it != vec_dict.end(); ++it) {
    tensors.insert(it->first, from_vector<scalar_t>(it->second, inplace));
  }
  return tensors;
}

inline int64_t uniform_randint(int64_t low, int64_t high) {
kgajdamo marked this conversation as resolved.
Show resolved Hide resolved
CHECK_LT(low, high);
auto options = torch::TensorOptions().dtype(torch::kInt64);
auto ret = torch::randint(low, high, {1}, options);
auto ptr = ret.data_ptr<int64_t>();
return *ptr;
}

// Convenience overload: same as `uniform_randint(0, high)`.
inline int64_t uniform_randint(int64_t high) {
  constexpr int64_t kLow = 0;
  return uniform_randint(kLow, high);
}
18 changes: 15 additions & 3 deletions pyg_lib/csrc/sampler/neighbor.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
#include "neighbor.h"
#ifdef WITH_PYTHON
#include <Python.h>
#endif

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include <torch/script.h>

#include "neighbor.h"

namespace pyg {
namespace sampler {
Expand All @@ -10,12 +15,19 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, c10::optional<at::Tensor>>
neighbor_sample(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t> num_neighbors,
const std::vector<int64_t>& num_neighbors,
bool replace,
bool directed,
bool isolated,
bool return_edge_id) {
// TODO (matthias) Add TensorArg definitions.
at::TensorArg rowptr_t{rowptr, "rowtpr", 1};
at::TensorArg col_t{col, "col", 1};
at::TensorArg seed_t{seed, "seed", 1};

at::CheckedFrom c = "neighbor_sample";
at::checkAllDefined(c, {rowptr_t, col_t, seed_t});
at::checkAllSameType(c, {rowptr_t, col_t, seed_t});

static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::neighbor_sample", "")
.typed<decltype(neighbor_sample)>();
Expand Down
2 changes: 1 addition & 1 deletion pyg_lib/csrc/sampler/neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, c10::optional<at::Tensor>>
neighbor_sample(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t> num_neighbors,
const std::vector<int64_t>& num_neighbors,
bool replace = false,
bool directed = true,
bool isolated = true,
Expand Down
36 changes: 36 additions & 0 deletions pyg_lib/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,41 @@
from torch import Tensor


def neighbor_sample(
    rowptr: Tensor, col: Tensor, seed: Tensor, num_neighbors: list[int],
    replace: bool = False, directed: bool = True, isolated: bool = True,
    return_edge_id: bool = True
) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor]]:
    r"""Recursively samples neighbors from all node indices in :obj:`seed`
    in the graph given by :obj:`(rowptr, col)`.

    Args:
        rowptr (torch.Tensor): Compressed source node indices.
        col (torch.Tensor): Target node indices.
        seed (torch.Tensor): The node indices from which sampling starts.
        num_neighbors (list[int]): The number of neighbors to sample for each
            node in each iteration. If an entry is set to :obj:`-1`, all
            neighbors will be included.
        replace (bool, optional): If set to :obj:`True`, will sample with
            replacement. (default: :obj:`False`)
        directed (bool, optional): If set to :obj:`False`, will include all
            edges between all sampled nodes. (default: :obj:`True`)
        isolated (bool, optional): TODO(kgajdamo): add description.
            (default: :obj:`True`)
        return_edge_id (bool, optional): If set to :obj:`False`, will not
            return the indices of edges of the original graph.
            (default: :obj:`True`)

    Returns:
        (torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]):
        The output of the :obj:`pyg::neighbor_sample` operator, passed
        through unchanged. Presumably row indices, column indices, sampled
        node indices and (optionally) the indices of the sampled edges in the
        original graph. TODO(kgajdamo): confirm and extend the description.
    """
    return torch.ops.pyg.neighbor_sample(rowptr, col, seed, num_neighbors,
                                         replace, directed, isolated,
                                         return_edge_id)


def subgraph(
rowptr: Tensor,
col: Tensor,
Expand Down Expand Up @@ -56,6 +91,7 @@ def random_walk(rowptr: Tensor, col: Tensor, seed: Tensor, walk_length: int,


# Names exported by `from pyg_lib.sampler import *`.
__all__ = [
'neighbor_sample',
'subgraph',
'random_walk',
]
49 changes: 49 additions & 0 deletions test/csrc/sampler/test_neighbor_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torch
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great if we can just test in C++. WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but unfortunately, when I want to compile pyg-lib with unit tests I am receiving undefined reference error. I wonder if this is only happening on my side?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you show me the log?

Copy link
Contributor Author

@kgajdamo kgajdamo Aug 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/home/kgajdamo/miniconda3/envs/pyg/bin/../lib/gcc/x86_64-conda-linux-gnu/9.3.0/../../../../x86_64-conda-linux-gnu/bin/ld: CMakeFiles/test_biased_random.dir/test/csrc/random/test_biased_random.cpp.o: in function `BiasedSamplingCDFConversionTest_BasicAssertions_Test::TestBody()':
test_biased_random.cpp:(.text._ZN52BiasedSamplingCDFConversionTest_BasicAssertions_Test8TestBodyEv+0x344): undefined reference to `testing::internal::GetBoolAssertionFailureMessage[abi:cxx11](testing::AssertionResult const&, char const*, char const*, char const*)'
/home/kgajdamo/miniconda3/envs/pyg/bin/../lib/gcc/x86_64-conda-linux-gnu/9.3.0/../../../../x86_64-conda-linux-gnu/bin/ld: test_biased_random.cpp:(.text._ZN52BiasedSamplingCDFConversionTest_BasicAssertions_Test8TestBodyEv+0x479): undefined reference to `testing::internal::GetBoolAssertionFailureMessage[abi:cxx11](testing::AssertionResult const&, char const*, char const*, char const*)'
/home/kgajdamo/miniconda3/envs/pyg/bin/../lib/gcc/x86_64-conda-linux-gnu/9.3.0/../../../../x86_64-conda-linux-gnu/bin/ld: CMakeFiles/test_biased_random.dir/test/csrc/random/test_biased_random.cpp.o: in function `BiasedSamplingAliasConversionTest_BasicAssertions_Test::TestBody()':
test_biased_random.cpp:(.text._ZN54BiasedSamplingAliasConversionTest_BasicAssertions_Test8TestBodyEv+0x90b): undefined reference to `testing::internal::GetBoolAssertionFailureMessage[abi:cxx11](testing::AssertionResult const&, char const*, char const*, char const*)'
/home/kgajdamo/miniconda3/envs/pyg/bin/../lib/gcc/x86_64-conda-linux-gnu/9.3.0/../../../../x86_64-conda-linux-gnu/bin/ld: test_biased_random.cpp:(.text._ZN54BiasedSamplingAliasConversionTest_BasicAssertions_Test8TestBodyEv+0x9e3): undefined reference to `testing::internal::GetBoolAssertionFailureMessage[abi:cxx11](testing::AssertionResult const&, char const*, char const*, char const*)'
/home/kgajdamo/miniconda3/envs/pyg/bin/../lib/gcc/x86_64-conda-linux-gnu/9.3.0/../../../../x86_64-conda-linux-gnu/bin/ld: test_biased_random.cpp:(.text._ZN54BiasedSamplingAliasConversionTest_BasicAssertions_Test8TestBodyEv+0xabd): undefined reference to `testing::internal::GetBoolAssertionFailureMessage[abi:cxx11](testing::AssertionResult const&, char const*, char const*, char const*)'
/home/kgajdamo/miniconda3/envs/pyg/bin/../lib/gcc/x86_64-conda-linux-gnu/9.3.0/../../../../x86_64-conda-linux-gnu/bin/ld: CMakeFiles/test_biased_random.dir/test/csrc/random/test_biased_random.cpp.o:test_biased_random.cpp:(.text._ZN54BiasedSamplingAliasConversionTest_BasicAssertions_Test8TestBodyEv+0xb9d): more undefined references to `testing::internal::GetBoolAssertionFailureMessage[abi:cxx11](testing::AssertionResult const&, char const*, char const*, char const*)' follow
/home/kgajdamo/miniconda3/envs/pyg/bin/../lib/gcc/x86_64-conda-linux-gnu/9.3.0/../../../../x86_64-conda-linux-gnu/bin/ld: CMakeFiles/test_biased_random.dir/test/csrc/random/test_biased_random.cpp.o: in function `testing::AssertionResult testing::internal::CmpHelperOpFailure<int, double>(char const*, char const*, int const&, double const&, char const*)':
test_biased_random.cpp:(.text._ZN7testing8internal18CmpHelperOpFailureIidEENS_15AssertionResultEPKcS4_RKT_RKT0_S4_[_ZN7testing8internal18CmpHelperOpFailureIidEENS_15AssertionResultEPKcS4_RKT_RKT0_S4_]+0xa4): undefined reference to `testing::Message::GetString[abi:cxx11]() const'
/home/kgajdamo/miniconda3/envs/pyg/bin/../lib/gcc/x86_64-conda-linux-gnu/9.3.0/../../../../x86_64-conda-linux-gnu/bin/ld: test_biased_random.cpp:(.text._ZN7testing8internal18CmpHelperOpFailureIidEENS_15AssertionResultEPKcS4_RKT_RKT0_S4_[_ZN7testing8internal18CmpHelperOpFailureIidEENS_15AssertionResultEPKcS4_RKT_RKT0_S4_]+0x15d): undefined reference to `testing::Message::GetString[abi:cxx11]() const'
/home/kgajdamo/miniconda3/envs/pyg/bin/../lib/gcc/x86_64-conda-linux-gnu/9.3.0/../../../../x86_64-conda-linux-gnu/bin/ld: test_biased_random.cpp:(.text._ZN7testing8internal18CmpHelperOpFailureIidEENS_15AssertionResultEPKcS4_RKT_RKT0_S4_[_ZN7testing8internal18CmpHelperOpFailureIidEENS_15AssertionResultEPKcS4_RKT_RKT0_S4_]+0x1ff): undefined reference to `testing::Message::GetString[abi:cxx11]() const'
/home/kgajdamo/miniconda3/envs/pyg/bin/../lib/gcc/x86_64-conda-linux-gnu/9.3.0/../../../../x86_64-conda-linux-gnu/bin/ld: test_biased_random.cpp:(.text._ZN7testing8internal18CmpHelperOpFailureIidEENS_15AssertionResultEPKcS4_RKT_RKT0_S4_[_ZN7testing8internal18CmpHelperOpFailureIidEENS_15AssertionResultEPKcS4_RKT_RKT0_S4_]+0x2b0): undefined reference to `testing::Message::GetString[abi:cxx11]() const'
/home/kgajdamo/miniconda3/envs/pyg/bin/../lib/gcc/x86_64-conda-linux-gnu/9.3.0/../../../../x86_64-conda-linux-gnu/bin/ld: test_biased_random.cpp:(.text._ZN7testing8internal18CmpHelperOpFailureIidEENS_15AssertionResultEPKcS4_RKT_RKT0_S4_[_ZN7testing8internal18CmpHelperOpFailureIidEENS_15AssertionResultEPKcS4_RKT_RKT0_S4_]+0x352): undefined reference to `testing::Message::GetString[abi:cxx11]() const'
/home/kgajdamo/miniconda3/envs/pyg/bin/../lib/gcc/x86_64-conda-linux-gnu/9.3.0/../../../../x86_64-conda-linux-gnu/bin/ld: CMakeFiles/test_biased_random.dir/test/csrc/random/test_biased_random.cpp.o:test_biased_random.cpp:(.text._ZN7testing8internal18CmpHelperOpFailureIidEENS_15AssertionResultEPKcS4_RKT_RKT0_S4_[_ZN7testing8internal18CmpHelperOpFailureIidEENS_15AssertionResultEPKcS4_RKT_RKT0_S4_]+0x40b): more undefined references to `testing::Message::GetString[abi:cxx11]() const' follow
collect2: error: ld returned 1 exit status
make[2]: *** [CMakeFiles/test_biased_random.dir/build.make:102: test_biased_random] Error 1
make[1]: *** [CMakeFiles/Makefile2:204: CMakeFiles/test_biased_random.dir/all] Error 2
make: *** [Makefile:146: all] Error 2

Copy link
Member

@rusty1s rusty1s Aug 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ZenoTan Did you see this before? If not, let's ignore C++ tests for now :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't meet this but it looks like the linkage to gtest failed. It's fine to ignore though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you also get this error during compilation or is it just me?

Copy link
Member

@ZenoTan ZenoTan Aug 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not get exactly the same, but similar issues when I tried to build other libraries.

Copy link
Member

@DamianSzwichtenberg DamianSzwichtenberg Aug 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kgajdamo @rusty1s I noticed this error occurs when we have a mix of CF/anaconda c++ packages in our toolchain. Compiling cpp tests in simple venv works well.

from torch_sparse import SparseTensor

import pyg_lib

neighbor_sample = pyg_lib.sampler.neighbor_sample


def test_neighbor_sample():
    """Checks sampled (row, col, node) outputs on a single-edge graph 0 -> 1."""
    adj = SparseTensor.from_edge_index(torch.tensor([[0], [1]]))
    rowptr, col, _ = adj.csr()

    # Sampling in a non-directed way should not sample in wrong direction:
    out = neighbor_sample(rowptr, col, torch.tensor([1]), [1], replace=False,
                          directed=False, isolated=False, return_edge_id=True)
    assert out[0].tolist() == []
    assert out[1].tolist() == []
    assert out[2].tolist() == [1]

    # Sampling should work:
    out = neighbor_sample(rowptr, col, torch.tensor([0]), [1], replace=False,
                          directed=False, isolated=False, return_edge_id=True)
    assert out[0].tolist() == [0]
    assert out[1].tolist() == [1]
    assert out[2].tolist() == [0, 1]

    # Sampling with more hops:
    out = neighbor_sample(rowptr, col, torch.tensor([0]), [1, 1],
                          replace=False, directed=False, isolated=False,
                          return_edge_id=True)
    assert out[0].tolist() == [0]
    assert out[1].tolist() == [1]
    assert out[2].tolist() == [0, 1]


def test_neighbor_sample_seed():
    """Checks that sampling is reproducible under `torch.manual_seed`."""
    # `col` has seven entries and references node 2, so `rowptr` needs one
    # entry per node plus one and must end at `col.numel()`. With only
    # `[0, 3, 5]`, sampling node 2 makes the kernel read `rowptr[3]`, which is
    # out of bounds.
    rowptr = torch.tensor([0, 3, 5, 7])
    col = torch.tensor([0, 1, 2, 0, 1, 0, 2])
    input_nodes = torch.tensor([0, 1])

    torch.manual_seed(42)
    out1 = neighbor_sample(rowptr, col, input_nodes, [1], True, False, False,
                           True)

    torch.manual_seed(42)
    out2 = neighbor_sample(rowptr, col, input_nodes, [1], True, False, False,
                           True)

    for data1, data2 in zip(out1, out2):
        assert data1.tolist() == data2.tolist()