Skip to content

Commit

Permalink
Optimization for graph_sample_neighbors API (PaddlePaddle#41447) (Pad…
Browse files Browse the repository at this point in the history
…dlePaddle#41897)

* add eids result for graph_sample_neighbors

* fix bug

* move fisher_yates sample to warp

* add cpu eid output

* delete comment

* delete comment

* change nullptr placeholder

* optimize sample kernel

* fix mutable_data
  • Loading branch information
DesmonDay authored Apr 19, 2022
1 parent 6449a23 commit 6115b01
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 78 deletions.
100 changes: 89 additions & 11 deletions paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,42 @@ void SampleUniqueNeighbors(
}
}

// Partially shuffles two parallel ranges in lock-step so that, on return, the
// first `num_samples` positions of [src_begin, src_end) hold the sampled
// neighbors and the same positions of [eid_begin, eid_end) hold their matching
// edge ids. Every swap applied to the neighbor range is mirrored on the edge-id
// range, so element i of one range always stays paired with element i of the
// other.
//
// Parameters:
//   src_begin, src_end  - neighbor range, reordered in place.
//   eid_begin, eid_end  - edge-id range, reordered in place with the same swaps.
//   num_samples         - number of leading positions to fill; values larger
//                         than the usable range length are clamped, values
//                         <= 0 leave both ranges untouched.
//   rng                 - random engine, advanced once per pick.
//   dice_distribution   - source of random offsets; each draw is reduced modulo
//                         the number of not-yet-picked elements.
//                         NOTE(review): assumes the distribution only yields
//                         non-negative values - confirm at the call site.
template <class bidiiter>
void SampleUniqueNeighborsWithEids(
    bidiiter src_begin,
    bidiiter src_end,
    bidiiter eid_begin,
    bidiiter eid_end,
    int num_samples,
    std::mt19937& rng,
    std::uniform_int_distribution<int>& dice_distribution) {
  const int src_len = static_cast<int>(std::distance(src_begin, src_end));
  const int eid_len = static_cast<int>(std::distance(eid_begin, eid_end));
  // Only positions present in BOTH ranges can be swapped safely.
  int left_num = (eid_len < src_len) ? eid_len : src_len;
  // Clamp the pick count: without this, `left_num` would reach zero while the
  // loop is still running and `% left_num` would divide by zero.
  const int picks = (num_samples < left_num) ? num_samples : left_num;
  for (int i = 0; i < picks; i++) {
    bidiiter r1 = src_begin, r2 = eid_begin;
    // Offset of the chosen element inside the not-yet-picked tail.
    const int random_step = dice_distribution(rng) % left_num;
    std::advance(r1, random_step);
    std::advance(r2, random_step);
    // Move the chosen pair to the front of the tail, then shrink the tail.
    std::swap(*src_begin, *r1);
    std::swap(*eid_begin, *r2);
    ++src_begin;
    ++eid_begin;
    --left_num;
  }
}

template <typename T>
void SampleNeighbors(const T* row,
const T* col_ptr,
const T* eids,
const T* input,
std::vector<T>* output,
std::vector<int>* output_count,
std::vector<T>* output_eids,
int sample_size,
int bs) {
// Allocate the memory of output
// Collect the neighbors size
int bs,
bool return_eids) {
std::vector<std::vector<T>> out_src_vec;
std::vector<std::vector<T>> out_eids_vec;
// `sample_cumsum_sizes` record the start position and end position
// after sampling.
std::vector<int> sample_cumsum_sizes(bs + 1);
Expand All @@ -65,10 +90,18 @@ void SampleNeighbors(const T* row,
std::vector<T> out_src;
out_src.resize(cap);
out_src_vec.emplace_back(out_src);
if (return_eids) {
std::vector<T> out_eids;
out_eids.resize(cap);
out_eids_vec.emplace_back(out_eids);
}
}

output_count->resize(bs);
output->resize(total_neighbors);
if (return_eids) {
output_eids->resize(total_neighbors);
}

std::random_device rd;
std::mt19937 rng{rd()};
Expand All @@ -85,15 +118,28 @@ void SampleNeighbors(const T* row,
int cap = end - begin;
if (sample_size < cap) {
std::copy(row + begin, row + end, out_src_vec[i].begin());
// TODO(daisiming): Check whether is correct.
SampleUniqueNeighbors(out_src_vec[i].begin(),
out_src_vec[i].end(),
sample_size,
rng,
dice_distribution);
if (return_eids) {
std::copy(eids + begin, eids + end, out_eids_vec[i].begin());
SampleUniqueNeighborsWithEids(out_src_vec[i].begin(),
out_src_vec[i].end(),
out_eids_vec[i].begin(),
out_eids_vec[i].end(),
sample_size,
rng,
dice_distribution);
} else {
SampleUniqueNeighbors(out_src_vec[i].begin(),
out_src_vec[i].end(),
sample_size,
rng,
dice_distribution);
}
*(output_count->data() + i) = sample_size;
} else {
std::copy(row + begin, row + end, out_src_vec[i].begin());
if (return_eids) {
std::copy(eids + begin, eids + end, out_eids_vec[i].begin());
}
*(output_count->data() + i) = cap;
}
}
Expand All @@ -107,6 +153,11 @@ void SampleNeighbors(const T* row,
std::copy(out_src_vec[i].begin(),
out_src_vec[i].begin() + k,
output->data() + sample_cumsum_sizes[i]);
if (return_eids) {
std::copy(out_eids_vec[i].begin(),
out_eids_vec[i].begin() + k,
output_eids->data() + sample_cumsum_sizes[i]);
}
}
}

Expand All @@ -131,8 +182,35 @@ void GraphSampleNeighborsKernel(

std::vector<T> output;
std::vector<int> output_count;
SampleNeighbors<T>(
row_data, col_ptr_data, x_data, &output, &output_count, sample_size, bs);

if (return_eids) {
const T* eids_data = eids.get_ptr()->data<T>();
std::vector<T> output_eids;
SampleNeighbors<T>(row_data,
col_ptr_data,
eids_data,
x_data,
&output,
&output_count,
&output_eids,
sample_size,
bs,
return_eids);
out_eids->Resize({static_cast<int>(output_eids.size())});
T* out_eids_data = dev_ctx.template Alloc<T>(out_eids);
std::copy(output_eids.begin(), output_eids.end(), out_eids_data);
} else {
SampleNeighbors<T>(row_data,
col_ptr_data,
nullptr,
x_data,
&output,
&output_count,
nullptr,
sample_size,
bs,
return_eids);
}
out->Resize({static_cast<int>(output.size())});
T* out_data = dev_ctx.template Alloc<T>(out);
std::copy(output.begin(), output.end(), out_data);
Expand Down
Loading

0 comments on commit 6115b01

Please sign in to comment.