[multi-node] Add multi node degree (PaddlePaddle#308)
* change get_node_degree return type

* add multi-node node degree

* fix degree norm
DesmonDay authored and danleifeng committed Sep 12, 2023
1 parent f736018 commit fb932ac
Showing 5 changed files with 97 additions and 19 deletions.
10 changes: 8 additions & 2 deletions paddle/fluid/framework/data_feed.cu
@@ -2579,10 +2579,16 @@ std::shared_ptr<phi::Allocation> GetNodeDegree(
           phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
   auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
   auto edge_to_id = gpu_graph_ptr->edge_to_id;
+  int* node_degree_ptr = reinterpret_cast<int*>(node_degree->ptr());
   for (auto &iter : edge_to_id) {
     int edge_idx = iter.second;
-    gpu_graph_ptr->get_node_degree(
-        conf.gpuid, edge_idx, node_ids, len, node_degree);
+    auto sub_node_degree = gpu_graph_ptr->get_node_degree(
+        conf.gpuid, edge_idx, node_ids, len);
+    int* sub_node_degree_ptr = reinterpret_cast<int*>(sub_node_degree->ptr());
+    cudaMemcpy(node_degree_ptr + edge_idx * len,
+               sub_node_degree_ptr,
+               sizeof(int) * len,
+               cudaMemcpyDeviceToDevice);
   }
   return node_degree;
 }
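The cudaMemcpy above packs the per-edge-type results into one concatenated buffer: the degrees for edge type edge_idx occupy positions [edge_idx * len, (edge_idx + 1) * len), so node_degree presumably holds edge_to_id.size() * len ints (its allocation sits just above this hunk and is not shown in the diff). Below is a minimal host-side sketch of reading one value back under that layout assumption; ReadDegree, node_degree_dev, edge_idx, and node_pos are illustrative names, not part of the commit.

#include <cuda_runtime.h>

// Sketch only: read the degree of the node at position node_pos (within the
// original node_ids array) for edge type edge_idx, out of the concatenated
// device buffer that GetNodeDegree fills above.
int ReadDegree(const int* node_degree_dev, int edge_idx, int node_pos, int len) {
  int degree = 0;
  // One int per (edge type, node) pair, laid out edge type by edge type.
  cudaMemcpy(&degree,
             node_degree_dev + edge_idx * len + node_pos,
             sizeof(int),
             cudaMemcpyDeviceToHost);
  return degree;
}

Production code would batch such reads or keep the degrees on the device; the sketch only makes the buffer layout concrete.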
13 changes: 10 additions & 3 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
@@ -189,11 +189,18 @@ class GpuPsGraphTable
       bool return_weight);
   std::vector<std::shared_ptr<phi::Allocation>> get_edge_type_graph(
       int gpu_id, int edge_type_len);
-  void get_node_degree(int gpu_id,
+  std::shared_ptr<phi::Allocation> get_node_degree(int gpu_id,
                        int edge_idx,
                        uint64_t *key,
-                       int len,
-                       std::shared_ptr<phi::Allocation> node_degree);
+                       int len);
+  std::shared_ptr<phi::Allocation> get_node_degree_all2all(int gpu_id,
+                                                           int edge_idx,
+                                                           uint64_t *key,
+                                                           int len);
+  std::shared_ptr<phi::Allocation> get_node_degree_single(int gpu_id,
+                                                          int edge_idx,
+                                                          uint64_t *key,
+                                                          int len);
   int get_feature_of_nodes(int gpu_id,
                            uint64_t *d_walk,
                            uint64_t *d_offset,
77 changes: 72 additions & 5 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
@@ -3019,18 +3019,84 @@ NeighborSampleResultV2 GpuPsGraphTable::graph_neighbor_sample_all_edge_type(
   return result;
 }
 
-void GpuPsGraphTable::get_node_degree(
+std::shared_ptr<phi::Allocation> GpuPsGraphTable::get_node_degree(
     int gpu_id,
     int edge_idx,
     uint64_t* key,
-    int len,
-    std::shared_ptr<phi::Allocation> node_degree) {
-  int* node_degree_ptr =
-      reinterpret_cast<int*>(node_degree->ptr()) + edge_idx * len;
+    int len) {
+  if (multi_node_ && FLAGS_enable_graph_multi_node_sampling) {
+    // multi node mode
+    auto node_degree = get_node_degree_all2all(
+        gpu_id,
+        edge_idx,
+        key,
+        len);
+    return node_degree;
+  } else {
+    auto node_degree = get_node_degree_single(
+        gpu_id,
+        edge_idx,
+        key,
+        len);
+    return node_degree;
+  }
+}
+
+std::shared_ptr<phi::Allocation> GpuPsGraphTable::get_node_degree_all2all(
+    int gpu_id,
+    int edge_idx,
+    uint64_t* key,
+    int len) {
+  platform::CUDADeviceGuard guard(gpu_id);
+  auto &loc = storage_[gpu_id];
+  platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id));
+  auto stream = resource_->local_stream(gpu_id, 0);
+
+  loc.alloc(len, sizeof(int));
+
+  // all2all mode begins, init resource, partition keys, pull vals by all2all.
+
+  auto pull_size = gather_inter_keys_by_all2all(gpu_id, len, key, stream);
+  VLOG(2) << "gather_inter_keys_by_all2all sage get_degree finish, pull_size=" << pull_size << ", len=" << len;
+
+  // do single-node multi-card get_node_degree
+  auto result = get_node_degree_single(gpu_id,
+                                       edge_idx,
+                                       loc.d_merged_keys,
+                                       pull_size);
+
+  auto node_degree =
+      memory::AllocShared(place,
+                          len * sizeof(int),
+                          phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
+
+  // all2all mode finish, scatter degree values by all2all
+  scatter_inter_vals_by_all2all_common(gpu_id,
+                                       len,
+                                       sizeof(int),  // value_bytes
+                                       reinterpret_cast<const int*>(result->ptr()),  // in
+                                       reinterpret_cast<int*>(node_degree->ptr()),   // out
+                                       reinterpret_cast<int*>(loc.d_merged_vals),    // tmp hbm
+                                       stream);
+  return node_degree;
+}
+
+std::shared_ptr<phi::Allocation> GpuPsGraphTable::get_node_degree_single(
+    int gpu_id,
+    int edge_idx,
+    uint64_t* key,
+    int len) {
   int total_gpu = resource_->total_device();
   platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id));
   platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id));
   auto stream = resource_->local_stream(gpu_id, 0);
+
+  auto node_degree =
+      memory::AllocShared(place,
+                          len * sizeof(int),
+                          phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
+  int* node_degree_ptr = reinterpret_cast<int*>(node_degree->ptr());
+
   int grid_size = (len - 1) / block_size_ + 1;
   int h_left[total_gpu];   // NOLINT
   int h_right[total_gpu];  // NOLINT
@@ -3142,6 +3208,7 @@ void GpuPsGraphTable::get_node_degree(
     destroy_storage(gpu_id, i);
   }
   device_mutex_[gpu_id]->unlock();
+  return node_degree;
 }
 
 NodeQueryResult GpuPsGraphTable::graph_node_sample(int gpu_id,
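The new all2all path is a gather/compute/scatter round trip: gather_inter_keys_by_all2all routes the len query keys to the cards/nodes that own them (pull_size of them land locally), get_node_degree_single computes degrees for those locally owned keys, and scatter_inter_vals_by_all2all_common sends the int results back so the output lines up with the original key order. Below is a purely conceptual, CPU-only sketch of that round trip; the key % num_ranks ownership rule and the local_degree callback are illustrative assumptions, not the real implementation.

#include <cstdint>
#include <vector>

// Conceptual illustration of the gather/compute/scatter pattern used by
// get_node_degree_all2all; not the real GPU/all2all code.
std::vector<int> GetDegreeAll2AllSketch(
    const std::vector<uint64_t>& keys,
    int num_ranks,
    int (*local_degree)(uint64_t key)) {  // stand-in for the single-node path
  // 1. "Gather": route each key to the rank that owns it (here: key % num_ranks).
  std::vector<std::vector<uint64_t>> per_rank(num_ranks);
  std::vector<std::vector<size_t>> origin(num_ranks);  // remember original slots
  for (size_t i = 0; i < keys.size(); ++i) {
    int owner = static_cast<int>(keys[i] % num_ranks);
    per_rank[owner].push_back(keys[i]);
    origin[owner].push_back(i);
  }
  // 2. "Compute": each owner computes degrees for the keys it received.
  // 3. "Scatter": results are written back to the querying positions.
  std::vector<int> degree(keys.size(), 0);
  for (int r = 0; r < num_ranks; ++r) {
    for (size_t j = 0; j < per_rank[r].size(); ++j) {
      degree[origin[r][j]] = local_degree(per_rank[r][j]);
    }
  }
  return degree;
}

In the committed code all three phases run on the GPU in bulk; the sketch only makes explicit how query order is preserved across the exchange.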
7 changes: 3 additions & 4 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
@@ -1051,14 +1051,13 @@ GraphGpuWrapper::get_edge_type_graph(int gpu_id, int edge_type_len) {
       ->get_edge_type_graph(gpu_id, edge_type_len);
 }
 
-void GraphGpuWrapper::get_node_degree(
+std::shared_ptr<phi::Allocation> GraphGpuWrapper::get_node_degree(
     int gpu_id,
     int edge_idx,
     uint64_t *key,
-    int len,
-    std::shared_ptr<phi::Allocation> node_degree) {
+    int len) {
   return (reinterpret_cast<GpuPsGraphTable *>(graph_table))
-      ->get_node_degree(gpu_id, edge_idx, key, len, node_degree);
+      ->get_node_degree(gpu_id, edge_idx, key, len);
 }
 void GraphGpuWrapper::set_infer_mode(bool infer_mode) {
   if (graph_table != nullptr) {
9 changes: 4 additions & 5 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
@@ -166,11 +166,10 @@ class GraphGpuWrapper {
       std::vector<std::shared_ptr<phi::Allocation>> edge_type_graphs,
       bool weighted,
       bool return_weight);
-  void get_node_degree(int gpu_id,
-                       int edge_idx,
-                       uint64_t* key,
-                       int len,
-                       std::shared_ptr<phi::Allocation> node_degree);
+  std::shared_ptr<phi::Allocation> get_node_degree(int gpu_id,
+                                                   int edge_idx,
+                                                   uint64_t* key,
+                                                   int len);
   gpuStream_t get_local_stream(int gpuid);
   std::vector<uint64_t> graph_neighbor_sample(
       int gpu_id,
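With the declaration above, callers receive the degree buffer instead of pre-allocating and passing one in, which is exactly how the data_feed.cu hunk at the top consumes it. Below is a minimal usage sketch, assuming the Paddle GPU graph engine headers and an initialized graph table; QueryDegreesSketch and its parameter names are illustrative, not part of the commit.

// Sketch only: query per-node degrees for one edge type through the wrapper.
// d_node_ids is a device pointer to num_nodes node ids; the returned
// allocation owns num_nodes ints of degrees and must be kept alive while the
// raw device pointer is in use.
std::shared_ptr<phi::Allocation> QueryDegreesSketch(int gpu_id,
                                                    int edge_idx,
                                                    uint64_t* d_node_ids,
                                                    int num_nodes,
                                                    const int** d_degree_out) {
  auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
  auto degree =
      gpu_graph_ptr->get_node_degree(gpu_id, edge_idx, d_node_ids, num_nodes);
  *d_degree_out = reinterpret_cast<const int*>(degree->ptr());
  return degree;
}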
