add get all neighbor (PaddlePaddle#35)
* add whether load node and edge parallel flag

* add whether load node and edge parallel flag

* add whether load node and edge parallel flag

* add get all neighbor

* add get all neighbor

* add get all neighbor

* add get all neighbor

Co-authored-by: root <root@yq01-inf-hic-k8s-a100-ab2-0009.yq01.baidu.com>
miaoli06 and root authored Jun 15, 2022
1 parent 0c33297 commit 0969831
Showing 7 changed files with 107 additions and 1 deletion.
47 changes: 47 additions & 0 deletions paddle/fluid/distributed/ps/table/common_graph_table.cc
@@ -1638,6 +1638,30 @@ std::vector<std::vector<uint64_t>> GraphTable::get_all_id(int type_id, int slice
return res;
}

std::vector<std::vector<uint64_t>> GraphTable::get_all_neighbor_id(int type_id, int slice_num) {
std::vector<std::vector<uint64_t>> res(slice_num);
auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
std::vector<std::future<std::vector<uint64_t>>> tasks;
for (size_t idx = 0; idx < search_shards.size(); idx++) {
for (size_t j = 0; j < search_shards[idx].size(); j++) {
tasks.push_back(_shards_task_pool[j % task_pool_size_]->enqueue(
[&search_shards, idx, j]() -> std::vector<uint64_t> {
return search_shards[idx][j]->get_all_neighbor_id();
}));
}
}
for (size_t i = 0; i < tasks.size(); ++i) {
tasks[i].wait();
}
for (size_t i = 0; i < tasks.size(); i++) {
auto ids = tasks[i].get();
for (auto &id : ids) {
res[id % slice_num].push_back(id);
}
}
return res;
}

std::vector<std::vector<uint64_t>> GraphTable::get_all_id(int type_id, int idx,
int slice_num) {
std::vector<std::vector<uint64_t>> res(slice_num);
@@ -1661,6 +1685,29 @@ std::vector<std::vector<uint64_t>> GraphTable::get_all_id(int type_id, int idx,
return res;
}

std::vector<std::vector<uint64_t>> GraphTable::get_all_neighbor_id(int type_id, int idx,
int slice_num) {
std::vector<std::vector<uint64_t>> res(slice_num);
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
std::vector<std::future<std::vector<uint64_t>>> tasks;
VLOG(0) << "begin task, task_pool_size_[" << task_pool_size_ << "]";
for (size_t i = 0; i < search_shards.size(); i++) {
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[&search_shards, i]() -> std::vector<uint64_t> {
return search_shards[i]->get_all_neighbor_id();
}));
}
for (size_t i = 0; i < tasks.size(); ++i) {
tasks[i].wait();
}
VLOG(0) << "end task, task_pool_size_[" << task_pool_size_ << "]";
for (size_t i = 0; i < tasks.size(); i++) {
auto ids = tasks[i].get();
for (auto &id : ids) res[id % slice_num].push_back(id);
}
return res;
}

int GraphTable::get_all_feature_ids(int type_id, int idx, int slice_num,
std::vector<std::vector<uint64_t>>* output) {
output->resize(slice_num);
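
Both GraphTable::get_all_neighbor_id overloads follow the same pattern: each shard's neighbor ids are collected in parallel on the shard task pool, and the merged ids are bucketed into slice_num output vectors by id % slice_num, so a given id always lands in the same slice. Below is a minimal standalone sketch of that gather-then-slice pattern, with std::async standing in for the table's _shards_task_pool; it is an illustrative simplification, not the Paddle implementation.

#include <cstdint>
#include <future>
#include <vector>

// Sketch: gather per-shard id lists in parallel, then bucket them into
// slice_num slices by id % slice_num, mirroring get_all_neighbor_id.
std::vector<std::vector<uint64_t>> gather_and_slice(
    const std::vector<std::vector<uint64_t>> &shards, int slice_num) {
  std::vector<std::future<std::vector<uint64_t>>> tasks;
  for (const auto &shard : shards) {
    // The real code enqueues onto _shards_task_pool; std::async is a stand-in.
    tasks.push_back(std::async(std::launch::async,
                               [&shard]() -> std::vector<uint64_t> { return shard; }));
  }
  std::vector<std::vector<uint64_t>> res(slice_num);
  for (auto &t : tasks) {
    for (uint64_t id : t.get()) {
      res[id % slice_num].push_back(id);  // same id always maps to the same slice
    }
  }
  return res;
}
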
16 changes: 16 additions & 0 deletions paddle/fluid/distributed/ps/table/common_graph_table.h
@@ -70,6 +70,19 @@ class GraphShard {
}
return res;
}
std::vector<uint64_t> get_all_neighbor_id() {
std::vector<uint64_t> res;
std::unordered_set<uint64_t> uset;
for (size_t i = 0; i < bucket.size(); i++) {
size_t neighbor_size = bucket[i]->get_neighbor_size();
for (size_t j = 0; j < neighbor_size; j++) {
uset.emplace(bucket[i]->get_neighbor_id(j));
}
}
res.assign(uset.begin(), uset.end());
return res;
}
std::set<uint64_t> get_all_feature_ids() {
std::set<uint64_t> total_res;
std::set<uint64_t> res;
@@ -486,8 +499,11 @@ class GraphTable : public Table {
const std::string &edge_type);

std::vector<std::vector<uint64_t>> get_all_id(int type, int slice_num);
std::vector<std::vector<uint64_t>> get_all_neighbor_id(int type, int slice_num);
std::vector<std::vector<uint64_t>> get_all_id(int type, int idx,
int slice_num);
std::vector<std::vector<uint64_t>> get_all_neighbor_id(int type_id, int idx,
int slice_num);
int get_all_feature_ids(int type, int idx,
int slice_num, std::vector<std::vector<uint64_t>>* output);
int32_t load_nodes(const std::string &path, std::string node_type);
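
GraphShard::get_all_neighbor_id deduplicates through an std::unordered_set before copying into the result vector, because the same neighbor id can be reachable from many nodes in the bucket. The self-contained sketch below demonstrates that behavior; MockNode and its accessors are hypothetical stand-ins for the real node type.

#include <cassert>
#include <cstdint>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for the real graph node type.
struct MockNode {
  std::vector<uint64_t> neighbors;
  size_t get_neighbor_size() const { return neighbors.size(); }
  uint64_t get_neighbor_id(size_t j) const { return neighbors[j]; }
};

// Same dedup pattern as GraphShard::get_all_neighbor_id.
std::vector<uint64_t> all_neighbor_ids(const std::vector<MockNode> &bucket) {
  std::unordered_set<uint64_t> uset;
  for (const auto &node : bucket) {
    for (size_t j = 0; j < node.get_neighbor_size(); j++) {
      uset.emplace(node.get_neighbor_id(j));
    }
  }
  return {uset.begin(), uset.end()};
}

int main() {
  // Nodes 0 and 1 share neighbor 7; it appears only once in the output.
  std::vector<MockNode> bucket = {{{7, 8}}, {{7, 9}}};
  assert(all_neighbor_ids(bucket).size() == 3);
  return 0;
}
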
12 changes: 12 additions & 0 deletions paddle/fluid/framework/data_set.cc
@@ -14,6 +14,7 @@

#include "paddle/fluid/framework/data_set.h"
#include "google/protobuf/text_format.h"
#include "gflags/gflags.h"
#if (defined PADDLE_WITH_DISTRIBUTE) && (defined PADDLE_WITH_PSCORE)
#include "paddle/fluid/distributed/index_dataset/index_sampler.h"
#endif
@@ -34,6 +35,8 @@
#endif

USE_INT_STAT(STAT_total_feasign_num_in_mem);
DECLARE_bool(graph_get_neighbor_id);

namespace paddle {
namespace framework {

@@ -469,6 +472,7 @@ void DatasetImpl<T>::LoadIntoMemory() {
type_total_key[i].push_back(gpu_graph_device_keys[i][j]);
}
}

for (size_t i = 0; i < readers_.size(); i++) {
readers_[i]->SetDeviceKeys(&type_total_key[i], node_idx);
readers_[i]->SetGpuGraphMode(gpu_graph_mode_);
@@ -505,6 +509,14 @@ void DatasetImpl<T>::LoadIntoMemory() {
gpu_graph_total_keys_.push_back(gpu_graph_device_keys[i][j]);
}
}
if (FLAGS_graph_get_neighbor_id) {
auto gpu_graph_neighbor_keys = gpu_graph_ptr->get_all_neighbor_id(0, edge_idx, thread_num_);
for (size_t i = 0; i < gpu_graph_neighbor_keys.size(); i++) {
for (size_t k = 0; k < gpu_graph_neighbor_keys[i].size(); k++) {
gpu_graph_total_keys_.push_back(gpu_graph_neighbor_keys[i][k]);
}
}
}
}

} else {
15 changes: 14 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
@@ -34,13 +34,27 @@ std::vector<std::vector<uint64_t>> GraphGpuWrapper::get_all_id(int type,
->cpu_graph_table_->get_all_id(type, slice_num);
}

std::vector<std::vector<uint64_t>> GraphGpuWrapper::get_all_neighbor_id(int type,
int slice_num) {
return ((GpuPsGraphTable *)graph_table)
->cpu_graph_table_->get_all_neighbor_id(type, slice_num);
}

std::vector<std::vector<uint64_t>> GraphGpuWrapper::get_all_id(int type,
int idx,
int slice_num) {
return ((GpuPsGraphTable *)graph_table)
->cpu_graph_table_->get_all_id(type, idx, slice_num);
}

std::vector<std::vector<uint64_t>> GraphGpuWrapper::get_all_neighbor_id(int type,
int idx,
int slice_num) {
return ((GpuPsGraphTable *)graph_table)
->cpu_graph_table_->get_all_neighbor_id(type, idx, slice_num);
}

int GraphGpuWrapper::get_all_feature_ids(int type, int idx, int slice_num,
std::vector<std::vector<uint64_t>>* output) {
return ((GpuPsGraphTable *)graph_table)
@@ -196,7 +210,6 @@ void GraphGpuWrapper::upload_batch(int idx,
for (int i = 0; i < ids.size(); i++) {
GpuPsCommGraph sub_graph =
g->cpu_graph_table_->make_gpu_ps_graph(idx, ids[i]);
// sub_graph.display_on_cpu();
g->build_graph_on_single_gpu(sub_graph, i, idx);
sub_graph.release_on_cpu();
VLOG(0) << "sub graph on gpu " << i << " is built";
3 changes: 3 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
@@ -53,8 +53,11 @@ class GraphGpuWrapper {
void set_search_level(int level);
void init_search_level(int level);
std::vector<std::vector<uint64_t>> get_all_id(int type, int slice_num);
std::vector<std::vector<uint64_t>> get_all_neighbor_id(int type, int slice_num);
std::vector<std::vector<uint64_t>> get_all_id(int type, int idx,
int slice_num);
std::vector<std::vector<uint64_t>> get_all_neighbor_id(int type, int idx,
int slice_num);
int get_all_feature_ids(int type, int idx, int slice_num,
std::vector<std::vector<uint64_t>>* output);
NodeQueryResult query_node_list(int gpu_id, int idx, int start,
13 changes: 13 additions & 0 deletions paddle/fluid/platform/flags.cc
@@ -739,6 +739,19 @@ PADDLE_DEFINE_EXPORTED_bool(
graph_load_in_parallel, false,
"It controls whether load graph node and edge with mutli threads parallely.");

/**
* Distributed related FLAG
* Name: FLAGS_graph_get_neighbor_id
* Since Version: 2.2.0
* Value Range: bool, default=false
* Example: FLAGS_graph_get_neighbor_id=true means all neighbor ids will be collected.
* Note: Controls whether all neighbor ids are fetched when running on a
* sub-partition of the graph. If it is not set, neighbor ids are not fetched,
* which is sufficient when running on the whole graph.
*/
PADDLE_DEFINE_EXPORTED_bool(
graph_get_neighbor_id, false,
"It controls get all neighbor id when running sub part graph.");

/**
* KP kernel related FLAG
* Name: FLAGS_run_kp_kernel
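
flags.cc defines the new flag with PADDLE_DEFINE_EXPORTED_bool, while data_set.cc only declares it with DECLARE_bool(graph_get_neighbor_id) before reading FLAGS_graph_get_neighbor_id. The sketch below shows the same define/declare split with plain gflags; it is a simplification of Paddle's exported-flag macro, and the file names and helper function are illustrative only.

// flag_def.cc -- one translation unit defines the flag.
#include "gflags/gflags.h"
DEFINE_bool(graph_get_neighbor_id, false,
            "It controls whether to get all neighbor ids when running on a "
            "sub-partition of the graph.");

// flag_use.cc -- other translation units only declare it before reading it.
#include "gflags/gflags.h"
DECLARE_bool(graph_get_neighbor_id);

bool should_collect_neighbor_ids() {
  // Reads the current value; with plain gflags it can be set at process start,
  // e.g. --graph_get_neighbor_id=true.
  return FLAGS_graph_get_neighbor_id;
}
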
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/fleet_py.cc
@@ -357,6 +357,8 @@ void BindGraphGpuWrapper(py::module* m) {
&GraphGpuWrapper::upload_batch))
.def("get_all_id", py::overload_cast<int, int, int>(&GraphGpuWrapper::get_all_id))
.def("get_all_id", py::overload_cast<int, int>(&GraphGpuWrapper::get_all_id))
.def("get_all_neighbor_id", py::overload_cast<int, int, int>(&GraphGpuWrapper::get_all_neighbor_id))
.def("get_all_neighbor_id", py::overload_cast<int, int>(&GraphGpuWrapper::get_all_neighbor_id))
.def("load_next_partition", &GraphGpuWrapper::load_next_partition)
.def("make_partitions", &GraphGpuWrapper::make_partitions)
.def("make_complementary_graph",
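
Because GraphGpuWrapper::get_all_neighbor_id now has two overloads, the bindings use py::overload_cast to select the (int, int) and (int, int, int) versions explicitly. The minimal pybind11 sketch below shows the same technique; the Demo class is hypothetical, and only the binding pattern mirrors the diff.

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <vector>

namespace py = pybind11;

// Hypothetical class with two overloads, like GraphGpuWrapper::get_all_neighbor_id.
struct Demo {
  std::vector<int> get_all_neighbor_id(int type, int slice_num) {
    return {type, slice_num};
  }
  std::vector<int> get_all_neighbor_id(int type, int idx, int slice_num) {
    return {type, idx, slice_num};
  }
};

PYBIND11_MODULE(demo, m) {
  py::class_<Demo>(m, "Demo")
      .def(py::init<>())
      // overload_cast picks the overload by its argument types.
      .def("get_all_neighbor_id",
           py::overload_cast<int, int>(&Demo::get_all_neighbor_id))
      .def("get_all_neighbor_id",
           py::overload_cast<int, int, int>(&Demo::get_all_neighbor_id));
}

In Python, Demo().get_all_neighbor_id(0, 8) then resolves to the two-argument overload and Demo().get_all_neighbor_id(0, 1, 8) to the three-argument one.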
