Skip to content

Commit

Permalink
delete utils file, add mapper to neighbor kernel, replace rand function
Browse files Browse the repository at this point in the history
  • Loading branch information
kgajdamo committed Aug 25, 2022
1 parent bb092a9 commit 6168329
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 66 deletions.
10 changes: 10 additions & 0 deletions pyg_lib/csrc/sampler/cpu/mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ class Mapper {
fill(nodes.data_ptr<scalar_t>(), nodes.numel());
}

bool insert_to_local_map(const scalar_t& node, const scalar_t pos) {
if (use_vec) {
const auto res = to_local_vec.insert(to_local_vec.begin() + node, pos);
return res != to_local_vec.end();
} else {
const auto res = to_local_map.insert({node, pos});
return res.second;
}
}

bool exists(const scalar_t& node) {
if (use_vec)
return to_local_vec[node] >= 0;
Expand Down
61 changes: 31 additions & 30 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@
#include <torch/script.h>

#include "../neighbor.h"
#include "pyg_lib/csrc/random/cpu/rand_engine.h"
#include "pyg_lib/csrc/sampler/cpu/mapper.h"
#include "utils.h"
#include "pyg_lib/csrc/utils/cpu/convert.h"

#include "parallel_hashmap/phmap.h"

#ifdef _WIN32
#include <process.h>
#endif

namespace pyg {
namespace sampler {

Expand All @@ -23,26 +20,29 @@ sample(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors) {
using namespace pyg::utils;

// Initialize some data structures for the sampling process:
std::vector<int64_t> samples, rows, cols, edges;
phmap::flat_hash_map<int64_t, int64_t> to_local_node;

const auto num_nodes = rowptr.size(0) - 1;
AT_DISPATCH_INTEGRAL_TYPES(seed.scalar_type(), "neighbor_kernel", [&] {
auto mapper = pyg::sampler::Mapper<scalar_t>(num_nodes, seed.size(0));
mapper.fill(seed);

auto* rowptr_data = rowptr.data_ptr<scalar_t>();
auto* col_data = col.data_ptr<scalar_t>();
auto* seed_data = seed.data_ptr<scalar_t>();

for (int64_t i = 0; i < seed.numel(); i++) {
const auto& v = seed_data[i];
samples.push_back(v);
to_local_node.insert({v, i});
}

int64_t begin = 0, end = samples.size();
for (int64_t ell = 0; ell < static_cast<int64_t>(num_neighbors.size());
ell++) {
const auto& num_samples = num_neighbors[ell];
const auto num_samples = num_neighbors[ell];
for (int64_t i = begin; i < end; i++) {
const auto& w = samples[i];
const auto& row_start = rowptr_data[w];
Expand All @@ -53,46 +53,48 @@ sample(const at::Tensor& rowptr,
continue;

if ((num_samples < 0) || (!replace && (num_samples >= row_count))) {
for (int64_t offset = row_start; offset < row_end; offset++) {
const int64_t& v = col_data[offset];
const auto res = to_local_node.insert({v, samples.size()});
if (res.second)
for (scalar_t offset = row_start; offset < row_end; offset++) {
const scalar_t& v = col_data[offset];
const auto res = mapper.insert_to_local_map(v, samples.size());
if (res)
samples.push_back(v);
if (directed) {
rows.push_back(i);
cols.push_back(res.first->second);
cols.push_back(mapper.map(v));
edges.push_back(offset);
}
}
} else if (replace) {
for (int64_t j = 0; j < num_samples; j++) {
const int64_t offset = row_start + uniform_randint(row_count);
const int64_t& v = col_data[offset];
const auto res = to_local_node.insert({v, samples.size()});
if (res.second)
pyg::random::RandintEngine<scalar_t> eng;
const scalar_t offset = row_start + eng(0, row_count);
const scalar_t& v = col_data[offset];
const auto res = mapper.insert_to_local_map(v, samples.size());
if (res)
samples.push_back(v);
if (directed) {
cols.push_back(i);
rows.push_back(res.first->second);
rows.push_back(mapper.map(v));
edges.push_back(offset);
}
}
} else {
std::unordered_set<int64_t> rnd_indices;
std::unordered_set<scalar_t> rnd_indices;
for (int64_t j = row_count - num_samples; j < row_count; j++) {
int64_t rnd = uniform_randint(j);
pyg::random::RandintEngine<scalar_t> eng;
scalar_t rnd = eng(0, j);
if (!rnd_indices.insert(rnd).second) {
rnd = j;
rnd_indices.insert(j);
}
const int64_t offset = row_start + rnd;
const int64_t& v = col_data[offset];
const auto res = to_local_node.insert({v, samples.size()});
if (res.second)
const scalar_t offset = row_start + rnd;
const scalar_t& v = col_data[offset];
const auto res = mapper.insert_to_local_map(v, samples.size());
if (res)
samples.push_back(v);
if (directed) {
rows.push_back(i);
cols.push_back(res.first->second);
cols.push_back(mapper.map(v));
edges.push_back(offset);
}
}
Expand All @@ -102,16 +104,15 @@ sample(const at::Tensor& rowptr,
}

if (!directed) {
phmap::flat_hash_map<int64_t, int64_t>::iterator iter;
for (int64_t i = 0; i < static_cast<int64_t>(samples.size()); i++) {
const auto& w = samples[i];
const auto& row_start = rowptr_data[w];
const auto& row_end = rowptr_data[w + 1];
for (int64_t offset = row_start; offset < row_end; offset++) {
for (scalar_t offset = row_start; offset < row_end; offset++) {
const auto& v = col_data[offset];
iter = to_local_node.find(v);
if (iter != to_local_node.end()) {
cols.push_back(iter->second);
const auto global_node = mapper.map(v);
if (global_node != -1) {
cols.push_back(global_node);
rows.push_back(i);
edges.push_back(offset);
}
Expand Down
36 changes: 0 additions & 36 deletions pyg_lib/csrc/sampler/cpu/utils.h

This file was deleted.

0 comments on commit 6168329

Please sign in to comment.