Skip to content

Commit

Permalink
recover limiao pr of edge_type_limit (PaddlePaddle#317)
Browse files Browse the repository at this point in the history
Co-authored-by: root <root@yq01-inf-hic-k8s-a100-ab2-0009.yq01.baidu.com>
  • Loading branch information
huwei02 and root authored Jun 23, 2023
1 parent a450ac1 commit f33ac34
Show file tree
Hide file tree
Showing 10 changed files with 159 additions and 10 deletions.
75 changes: 75 additions & 0 deletions paddle/fluid/distributed/ps/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ DECLARE_bool(graph_get_neighbor_id);
DECLARE_int32(gpugraph_storage_mode);
DECLARE_uint64(gpugraph_slot_feasign_max_num);
DECLARE_bool(graph_metapath_split_opt);
DECLARE_double(graph_neighbor_size_percent);

PADDLE_DEFINE_EXPORTED_bool(graph_edges_split_only_by_src_id,
false,
Expand Down Expand Up @@ -3135,6 +3136,80 @@ void GraphTable::build_graph_total_keys() {
<< graph_total_keys_.size();
}

// Computes a per-edge-type cap on neighbor-list size used by GPU sampling.
// For every edge type, gathers the neighbor count of each source node
// (sharded across task_pool_size_ worker threads), sorts the counts, and
// takes the value at the FLAGS_graph_neighbor_size_percent quantile as the
// limit. Results are stored in type_to_neighbor_limit_[edge_idx].
void GraphTable::calc_edge_type_limit() {
  std::vector<std::vector<int>> neighbor_size_array(task_pool_size_);
  const double neighbor_size_percent = FLAGS_graph_neighbor_size_percent;

  for (auto &it : this->edge_to_id) {
    auto edge_type = it.first;
    auto edge_idx = it.second;
    for (int i = 0; i < task_pool_size_; i++) {
      neighbor_size_array[i].clear();
    }

    // Collect all source-node keys of this edge type.
    std::vector<std::vector<uint64_t>> keys;
    this->get_all_id(GraphTableType::EDGE_TABLE, edge_idx, 1, &keys);
    std::vector<uint64_t> graph_type_keys_ = std::move(keys[0]);

    // Bucket keys by shard id so each worker only touches its own shard.
    std::vector<std::vector<uint64_t>> bags(task_pool_size_);
    for (int i = 0; i < task_pool_size_; i++) {
      auto predsize = graph_type_keys_.size() / task_pool_size_;
      bags[i].reserve(predsize * 1.2);  // slack for uneven modulo spread
    }
    for (auto x : graph_type_keys_) {
      int location = x % task_pool_size_;
      bags[location].push_back(x);
    }

    // Query every node's neighbor count in parallel, one task per shard.
    std::vector<std::future<int>> tasks;
    for (size_t i = 0; i < bags.size(); i++) {
      if (bags[i].empty()) {
        continue;
      }
      tasks.push_back(
          _shards_task_pool[i]->enqueue([&, i, edge_idx, this]() -> int {
            neighbor_size_array[i].reserve(bags[i].size());
            for (size_t j = 0; j < bags[i].size(); j++) {
              auto node_id = bags[i][j];
              Node *v =
                  find_node(GraphTableType::EDGE_TABLE, edge_idx, node_id);
              if (v != nullptr) {
                neighbor_size_array[i].push_back(v->get_neighbor_size());
              } else {
                VLOG(0) << "node id:" << node_id
                        << ", not find in type: " << edge_idx;
              }
            }
            return 0;
          }));
    }
    for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();

    // Merge per-shard counts and pick the requested quantile.
    std::vector<int> neighbor_sizes;
    for (int i = 0; i < task_pool_size_; i++) {
      neighbor_sizes.insert(neighbor_sizes.end(),
                            neighbor_size_array[i].begin(),
                            neighbor_size_array[i].end());
    }
    std::sort(neighbor_sizes.begin(), neighbor_sizes.end());

    // BUGFIX: max_neighbor_size was previously read uninitialized (UB) when
    // an edge type had no nodes; default both values to 0 in that case.
    int max_neighbor_size = 0;
    int neighbor_size_limit = 0;
    if (!neighbor_sizes.empty()) {
      max_neighbor_size = neighbor_sizes.back();
      size_t size_limit =
          static_cast<size_t>(neighbor_sizes.size() * neighbor_size_percent);
      if (size_limit < neighbor_sizes.size() - 1) {
        neighbor_size_limit = neighbor_sizes[size_limit];
      } else {
        // Quantile falls on (or past) the last element: use the maximum.
        neighbor_size_limit = max_neighbor_size;
      }
    }
    type_to_neighbor_limit_[edge_idx] = neighbor_size_limit;
    VLOG(0) << "edge_type: " << edge_type << ", edge_idx[" << edge_idx
            << "] max neighbor_size: " << max_neighbor_size
            << ", neighbor_size_limit: " << neighbor_size_limit;
  }
}

void GraphTable::build_graph_type_keys() {
VLOG(0) << "begin build_graph_type_keys, feature size="
<< this->feature_to_id.size();
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/distributed/ps/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,7 @@ class GraphTable : public Table {

void build_graph_total_keys();
void build_graph_type_keys();
void calc_edge_type_limit();
void build_node_iter_type_keys();
bool is_key_for_self_rank(const uint64_t &id);
void graph_partition(bool is_edge);
Expand All @@ -768,6 +769,7 @@ class GraphTable : public Table {
std::vector<std::vector<uint64_t>> graph_type_keys_;
std::unordered_map<int, int> type_to_index_;
robin_hood::unordered_set<uint64_t> unique_all_edge_keys_;
std::unordered_map<int, int> type_to_neighbor_limit_;

std::vector<std::vector<GraphShard *>> edge_shards, feature_shards,
node_shards;
Expand Down
28 changes: 24 additions & 4 deletions paddle/fluid/framework/data_feed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2708,6 +2708,7 @@ int FillWalkBuf(const std::vector<uint64_t> &h_device_keys_len,
// 获取全局采样状态
auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
auto &type_to_index = gpu_graph_ptr->get_graph_type_to_index();
auto &edge_neighbor_size_limit = gpu_graph_ptr->get_type_to_neighbor_limit();
auto &cursor = gpu_graph_ptr->cursor_[conf.thread_id];
size_t node_type_len = first_node_type.size();
int remain_size = conf.buf_size - conf.walk_degree *
Expand Down Expand Up @@ -2804,11 +2805,11 @@ int FillWalkBuf(const std::vector<uint64_t> &h_device_keys_len,
// end sampling current epoch
cursor = 0;
*epoch_finish_ptr = true;
VLOG(0) << "sample epoch finish!";
VLOG(1) << "sample epoch finish!";
break;
} else if (sample_command == EVENT_WALKBUF_FULL) {
// end sampling current pass
VLOG(0) << "sample pass finish!";
VLOG(1) << "sample pass finish!";
break;
} else if (sample_command == EVENT_CONTINUE_SAMPLE) {
// continue sampling
Expand All @@ -2828,11 +2829,18 @@ int FillWalkBuf(const std::vector<uint64_t> &h_device_keys_len,
}

NeighborSampleQuery q;
if (edge_neighbor_size_limit.find(path[0]) == edge_neighbor_size_limit.end()) {
VLOG(0) << "Fail to find edge[" << path[0] << "] in edge_neighbor_size_limit";
assert(false);
break;
}
auto neighbor_size_limit = edge_neighbor_size_limit[path[0]];
q.initialize(conf.gpuid,
path[0],
(uint64_t)(d_type_keys + start),
conf.walk_degree,
tmp_len,
neighbor_size_limit,
step);
auto sample_res = gpu_graph_ptr->graph_neighbor_sample_v3(
q, false, true, conf.weighted_sample);
Expand Down Expand Up @@ -2867,7 +2875,7 @@ int FillWalkBuf(const std::vector<uint64_t> &h_device_keys_len,
table,
host_vec_ptr,
stream) != 0) {
VLOG(2) << "gpu:" << conf.gpuid
VLOG(0) << "gpu:" << conf.gpuid
<< " in step 0, insert key stage, table is full";
update = false;
assert(false);
Expand All @@ -2882,7 +2890,7 @@ int FillWalkBuf(const std::vector<uint64_t> &h_device_keys_len,
table,
host_vec_ptr,
stream) != 0) {
VLOG(2) << "gpu:" << conf.gpuid
VLOG(0) << "gpu:" << conf.gpuid
<< " in step 0, insert sample res, table is full";
update = false;
assert(false);
Expand Down Expand Up @@ -2942,11 +2950,18 @@ int FillWalkBuf(const std::vector<uint64_t> &h_device_keys_len,
sample_keys_ptr = reinterpret_cast<uint64_t *>(sample_key_mem->ptr());
}
int edge_type_id = path[(step - 1) % path_len];
if (edge_neighbor_size_limit.find(edge_type_id) == edge_neighbor_size_limit.end()) {
VLOG(0) << "Fail to find edge[" << path[0] << "] in edge_neighbor_size_limit";
assert(false);
break;
}
neighbor_size_limit = edge_neighbor_size_limit[edge_type_id];
q.initialize(conf.gpuid,
edge_type_id,
(uint64_t)sample_keys_ptr,
1,
sample_res.total_sample_size,
neighbor_size_limit,
step);
int sample_key_len = sample_res.total_sample_size;
sample_res = gpu_graph_ptr->graph_neighbor_sample_v3(
Expand Down Expand Up @@ -3102,6 +3117,7 @@ int FillWalkBufMultiPath(

// 获取全局采样状态
auto &cur_metapath = gpu_graph_ptr->cur_metapath_;
auto &edge_neighbor_size_limit = gpu_graph_ptr->get_type_to_neighbor_limit();
auto &path = gpu_graph_ptr->cur_parse_metapath_;
auto &cur_metapath_start = gpu_graph_ptr->cur_metapath_start_[conf.gpuid];
auto &type_to_index = gpu_graph_ptr->get_graph_type_to_index();
Expand Down Expand Up @@ -3142,11 +3158,13 @@ int FillWalkBufMultiPath(
VLOG(2) << "sample edge type: " << path[0] << " step: " << 1;

NeighborSampleQuery q;
auto neighbor_size_limit = edge_neighbor_size_limit[path[0]];
q.initialize(conf.gpuid,
path[0],
(uint64_t)(d_type_keys + start),
conf.walk_degree,
tmp_len,
neighbor_size_limit,
step);
auto sample_res = gpu_graph_ptr->graph_neighbor_sample_v3(
q, false, true, conf.weighted_sample);
Expand Down Expand Up @@ -3229,11 +3247,13 @@ int FillWalkBufMultiPath(
reinterpret_cast<uint64_t *>(sample_key_mem->ptr());
int edge_type_id = path[(step - 1) % path_len];
VLOG(2) << "sample edge type: " << edge_type_id << " step: " << step;
neighbor_size_limit = edge_neighbor_size_limit[edge_type_id];
q.initialize(conf.gpuid,
edge_type_id,
(uint64_t)sample_keys_ptr,
1,
sample_res.total_sample_size,
neighbor_size_limit,
step);
int sample_key_len = sample_res.total_sample_size;
sample_res = gpu_graph_ptr->graph_neighbor_sample_v3(
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,15 @@ struct NeighborSampleQuery {
int len;
int sample_size;
int sample_step;
int neighbor_size_limit;
void initialize(
int gpu_id, int table_idx, uint64_t src_nodes, int sample_size, int len, int sample_step=1) {
int gpu_id, int table_idx, uint64_t src_nodes, int sample_size, int len, int neighbor_size_limit, int sample_step=1) {
this->table_idx = table_idx;
this->gpu_id = gpu_id;
this->src_nodes = reinterpret_cast<uint64_t *>(src_nodes);
this->len = len;
this->sample_size = sample_size;
this->neighbor_size_limit = neighbor_size_limit;
this->sample_step = sample_step;
}
void display() {
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,14 @@ class GpuPsGraphTable
NeighborSampleResult graph_neighbor_sample(int gpu_id,
uint64_t *key,
int sample_size,
int len);
int len,
int neighbor_size_limit);
NeighborSampleResult graph_neighbor_sample_v2(int gpu_id,
int idx,
uint64_t *key,
int sample_size,
int len,
int neighbor_size_limit,
bool cpu_query_switch,
bool compress,
bool weighted);
Expand All @@ -128,6 +130,7 @@ class GpuPsGraphTable
uint64_t *key,
int sample_size,
int len,
int neighbor_size_limit,
bool cpu_query_switch,
bool compress,
bool weighted);
Expand Down
21 changes: 18 additions & 3 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ __global__ void neighbor_sample_kernel_walking(GpuPsCommGraph graph,
uint64_t* res,
int sample_len,
int n,
int neighbor_size_limit,
int default_value) {
// graph: The corresponding edge table.
// node_info_list: The input node query, duplicate nodes allowed.
Expand Down Expand Up @@ -243,7 +244,13 @@ __global__ void neighbor_sample_kernel_walking(GpuPsCommGraph graph,
res[offset + j] = j;
}
__syncwarp();
for (int j = sample_len + threadIdx.x; j < neighbor_len; j += WARP_SIZE) {
int neighbor_num;
if (neighbor_len > neighbor_size_limit) {
neighbor_num = neighbor_size_limit;
} else {
neighbor_num = neighbor_len;
}
for (int j = sample_len + threadIdx.x; j < neighbor_num; j += WARP_SIZE) {
const int num = curand(&rng) % (j + 1);
if (num < sample_len) {
atomicMax(reinterpret_cast<unsigned int*>(res + offset + num),
Expand Down Expand Up @@ -1969,6 +1976,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v3(
q.src_nodes,
q.sample_size,
q.len,
q.neighbor_size_limit,
cpu_switch,
compress,
weighted);
Expand All @@ -1980,6 +1988,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v3(
q.src_nodes,
q.sample_size,
q.len,
q.neighbor_size_limit,
cpu_switch,
compress,
weighted);
Expand All @@ -1992,6 +2001,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v3(
q.src_nodes,
q.sample_size,
q.len,
q.neighbor_size_limit,
cpu_switch,
compress,
weighted);
Expand All @@ -2002,9 +2012,10 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v3(
// Convenience wrapper around graph_neighbor_sample_v2 for edge table 0 with
// the CPU-query fallback disabled, result compression enabled, and
// unweighted sampling.
NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(
    int gpu_id,
    uint64_t* key,
    int sample_size,
    int len,
    int neighbor_size_limit) {
  return graph_neighbor_sample_v2(gpu_id,
                                  0,
                                  key,
                                  sample_size,
                                  len,
                                  neighbor_size_limit,
                                  false,   // cpu_query_switch
                                  true,    // compress
                                  false);  // weighted
}

NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_all2all(
Expand All @@ -2014,6 +2025,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_all2all(
uint64_t* d_keys,
int sample_size,
int len,
int neighbor_size_limit,
bool cpu_query_switch,
bool compress,
bool weighted) {
Expand All @@ -2034,6 +2046,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_all2all(
loc.d_merged_keys,
sample_size,
pull_size,
neighbor_size_limit,
cpu_query_switch,
compress,
weighted);
Expand Down Expand Up @@ -2142,6 +2155,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
uint64_t* key,
int sample_size,
int len,
int neighbor_size_limit,
bool cpu_query_switch,
bool compress,
bool weighted) {
Expand Down Expand Up @@ -2288,6 +2302,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
sample_array,
sample_size,
shard_len,
neighbor_size_limit,
default_value);
} else {
// Weighted sample.
Expand Down
Loading

0 comments on commit f33ac34

Please sign in to comment.