Gpugraph.0621 (PaddlePaddle#59)
Optimize CUDA thread parallelism in the MergeGrad phase, cutting its runtime from 6.5 h to 3.25 h.
lxsbupt authored Jul 4, 2022
1 parent e7f3193 commit 1ce2252
Showing 6 changed files with 280 additions and 32 deletions.
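At a high level, the change makes the MergeGrad phase optionally run in two passes: gradients are first merged within fixed-size segments (capped by FLAGS_gpugraph_merge_grads_segment_size), which bounds the work any single CUDA thread spends on a hot key, and the segment results are then merged globally. The host-side sketch below illustrates only the segment-splitting idea; SplitSegments and its types are hypothetical names, not the PaddlePaddle API.

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical illustration: a key occurring n times is merged in
// ceil(n / segment_size) independent segments, so each reduction touches at
// most segment_size gradients. In the commit itself this bookkeeping is done
// on the GPU by split_segments / expand_segments inside segment_merge_grad.
std::vector<uint32_t> SplitSegments(const std::vector<uint32_t>& fea_num_per_key,
                                    uint32_t segment_size) {
  std::vector<uint32_t> segment_fea_num;
  for (uint32_t n : fea_num_per_key) {
    while (n > 0) {
      const uint32_t take = std::min(n, segment_size);
      segment_fea_num.push_back(take);  // one merge segment of `take` gradients
      n -= take;
    }
  }
  return segment_fea_num;  // its size is the total number of segments
}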
18 changes: 11 additions & 7 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm.h
@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/platform/timer.h"
#include "thrust/pair.h"
#elif defined(PADDLE_WITH_XPU_KP)
-// #include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
#include <xpu/runtime.h>
#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#endif
@@ -55,16 +54,20 @@ class HeterComm {
HeterComm& operator=(const HeterComm&) = delete;

void split_input_to_shard(KeyType* d_keys, int* d_idx_ptr, size_t len,
                            int* left, int* right, int gpu_num);
void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
                  int& uniq_len);  // NOLINT
void dynamic_merge_grad(int gpu_num, KeyType* d_keys, float* d_grads,
-                          size_t len, int& uniq_len);
+                          size_t len, int& uniq_len, size_t& segment_len,
+                          bool enable_segment_merge_grad);
+  void segment_merge_grad(int gpu_num, KeyType* d_keys, float* d_grads,
+                          const uint32_t* d_index, size_t len,
+                          const uint32_t* d_fea_num_info,
+                          size_t uniq_len, size_t& segment_len);
void pull_sparse(int num, KeyType* d_keys, float* d_vals, size_t len);
void build_ps(int num, KeyType* h_keys, ValType* h_vals, size_t len,
                size_t chunk_size, int stream_num, int offset = -1);
void build_ps(int num, KeyType* h_keys, char* pool, size_t len,
                size_t feature_value_size, size_t chunk_size, int stream_num);
void dump();
void show_one_table(int gpu_num);
void show_table_collisions();
@@ -237,7 +240,6 @@ class HeterComm {
char* src_val, size_t val_size);


-  CommonFeatureValueAccessor feature_value_accessor_;
protected:
using Table = HashTable<KeyType, ValType>;
using PtrTable = HashTable<KeyType, float*>;
@@ -249,6 +251,8 @@ class HeterComm {
int block_size_{256};
std::unique_ptr<HeterCommKernel> heter_comm_kernel_;

+  CommonFeatureValueAccessor feature_value_accessor_;

private:
int topo_aware_{0};
std::vector<LocalStorage> storage_;
168 changes: 146 additions & 22 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
@@ -24,6 +24,8 @@ limitations under the License. */

DECLARE_double(gpugraph_hbm_table_load_factor);
DECLARE_bool(gpugraph_enable_gpu_direct_access);
+DECLARE_bool(gpugraph_enable_segment_merge_grads);
+DECLARE_uint64(gpugraph_merge_grads_segment_size);

namespace paddle {
namespace framework {
@@ -621,31 +623,25 @@ void HeterComm<KeyType, ValType, GradType>::merge_grad(
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::dynamic_merge_grad(
int gpu_num, KeyType* d_keys, float* d_grads, size_t len,
-    int& uniq_len) {
+    int& uniq_len, size_t& segment_len, bool enable_segment_merge_grad) {
int dev_id = resource_->dev_id(gpu_num);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->local_stream(gpu_num, 0);

size_t temp_storage_bytes;

size_t grad_dim = max_mf_dim_;
size_t grad_value_size = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_));

auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType));
KeyType* d_merge_keys_ptr = reinterpret_cast<KeyType*>(d_merge_keys->ptr());

-  auto d_merge_grads = memory::Alloc(place, len * grad_value_size);
-  float* d_merge_grads_ptr =
-      reinterpret_cast<float*>(d_merge_grads->ptr());

auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 3 + 1));
uint32_t* d_fea_num_info_ptr =
reinterpret_cast<uint32_t*>(d_fea_num_info->ptr());
uint32_t* d_index = (uint32_t*)&d_fea_num_info_ptr[len];
uint32_t* d_idx = (uint32_t*)&d_index[len];
int* d_merged_size = (int*)&d_idx[len];
-  int grid_size = (len - 1) / block_size_ + 1;
heter_comm_kernel_->fill_idx(d_idx, len, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
NULL, temp_storage_bytes, d_keys, d_merge_keys_ptr, d_idx, d_index, len,
@@ -686,13 +682,135 @@ void HeterComm<KeyType, ValType, GradType>::dynamic_merge_grad(
d_temp_storage->ptr(), temp_storage_bytes, d_fea_num_info_ptr, d_offset,
uniq_len, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));

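  // Assumed semantics: when enable_segment_merge_grad is set, gradients are
  // merged per segment (at most FLAGS_gpugraph_merge_grads_segment_size
  // occurrences of a key each), and segment_len returns the number of
  // segment-merged keys left in d_keys; otherwise a single global merge runs.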
if (enable_segment_merge_grad) {
segment_merge_grad(
gpu_num,
d_merge_keys_ptr, d_grads, d_index, len,
d_fea_num_info_ptr, uniq_len,
segment_len);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_keys, d_merge_keys_ptr,
sizeof(KeyType) * segment_len,
cudaMemcpyDeviceToDevice, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
} else {
auto d_merge_grads = memory::Alloc(place, len * grad_value_size);
float* d_merge_grads_ptr = reinterpret_cast<float*>(d_merge_grads->ptr());

heter_comm_kernel_->merge_gradient(
d_keys, d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads,
(char*)d_merge_grads_ptr, uniq_len, grad_dim, grad_value_size, merger_, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_merge_grads_ptr,
grad_value_size * uniq_len,
cudaMemcpyDeviceToDevice, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
}
}
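The merge_gradient kernel itself is not part of this diff, so the following is only a self-contained CUDA sketch of a segment-parallel merge, under the assumptions of one thread per segment, a flat float gradient layout, and a sum merger; SegmentMergeGradKernel and its parameters are illustrative names, not the heter_comm_kernel_ interface.

#include <cstddef>
#include <cstdint>

// Illustrative only: one thread reduces one segment, so a hot key with many
// occurrences is handled by several threads in parallel rather than by one
// thread looping over all of its gradients.
__global__ void SegmentMergeGradKernel(
    const float* d_grads,              // raw gradients, len x grad_dim
    const uint32_t* d_index,           // sorted position -> raw position
    const uint32_t* d_seg_fea_offset,  // per-segment start in sorted order
    const uint32_t* d_seg_fea_num,     // per-segment gradient count
    float* d_out,                      // merged output, segments_num x grad_dim
    size_t segments_num, size_t grad_dim) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= segments_num) return;
  const uint32_t start = d_seg_fea_offset[i];
  const uint32_t num = d_seg_fea_num[i];
  for (size_t d = 0; d < grad_dim; ++d) {
    float sum = 0.f;
    for (uint32_t j = 0; j < num; ++j) {
      // gather each gradient by its raw position and accumulate
      sum += d_grads[static_cast<size_t>(d_index[start + j]) * grad_dim + d];
    }
    d_out[i * grad_dim + d] = sum;
  }
}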

template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::segment_merge_grad(
    int gpu_num,                     // the device number
    KeyType* d_keys,                 // the sorted key list; overwritten with the merged keys
    float* d_grads,                  // the raw gradient list; overwritten with the merged gradients
    const uint32_t* d_index,         // the original position of each sorted key; its length is len
    size_t len,                      // the number of raw input keys
    const uint32_t* d_fea_num_info,  // per-key prefix-sum array; its length is uniq_len + 1
    size_t uniq_len,                 // the number of unique keys
    size_t& segments_num) {          // out: the number of segment-merged keys

int dev_id = resource_->dev_id(gpu_num);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->local_stream(gpu_num, 0);

auto grad_dim = max_mf_dim_;
auto grad_value_size = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_));

auto d_buffer1 = memory::Alloc(place, sizeof(uint32_t) * len);
auto d_segments = reinterpret_cast<uint32_t*>(d_buffer1->ptr());
auto d_buffer2 = memory::Alloc(place, sizeof(uint32_t) * len);
auto d_segments_offset = reinterpret_cast<uint32_t*>(d_buffer2->ptr());
auto d_buffer3 = memory::Alloc(place, sizeof(uint32_t) * len);
auto d_segments_fea_num_info = reinterpret_cast<uint32_t*>(d_buffer3->ptr());
auto d_buffer4 = memory::Alloc(place, sizeof(uint32_t) * len);
auto d_segments_fea_num_offset = reinterpret_cast<uint32_t*>(d_buffer4->ptr());
auto d_buffer5 = memory::Alloc(place, sizeof(uint32_t));
auto d_segments_num = reinterpret_cast<uint32_t*>(d_buffer5->ptr());
CUDA_CHECK(cudaMemsetAsync(d_segments_num, 0, sizeof(uint32_t), stream));

uint32_t segment_size = FLAGS_gpugraph_merge_grads_segment_size;
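  // Assumed split_segments semantics: d_segments[i] = ceil(count_i / segment_size),
  // the number of merge segments for the i-th unique key; the reduction below
  // sums these into segments_num.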
heter_comm_kernel_->split_segments(
d_fea_num_info, uniq_len,
d_segments,
d_segments_num,
segment_size, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));

size_t temp_storage_bytes = 0;
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Sum(
NULL, temp_storage_bytes, d_segments, d_segments_num,
uniq_len, stream));
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Sum(
d_temp_storage->ptr(), temp_storage_bytes, d_segments, d_segments_num,
uniq_len, stream));
CUDA_CHECK(cudaMemcpyAsync(&segments_num, d_segments_num, sizeof(uint32_t),
cudaMemcpyDeviceToHost, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));

temp_storage_bytes = 0;
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum(
NULL, temp_storage_bytes, d_segments, d_segments_offset,
uniq_len, stream));
if (d_temp_storage->size() < temp_storage_bytes) {
d_temp_storage = NULL;
d_temp_storage = memory::Alloc(place, temp_storage_bytes);
}
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum(
d_temp_storage->ptr(), temp_storage_bytes, d_segments, d_segments_offset,
uniq_len, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));

heter_comm_kernel_->expand_segments(
d_fea_num_info,
d_segments_offset, uniq_len,
d_segments_fea_num_info, segment_size, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
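  // Assumed expand_segments semantics: d_segments_fea_num_info[j] is the
  // gradient count of segment j (segment_size for full segments, the remainder
  // for a key's last segment); its exclusive sum below gives each segment's
  // start offset in the sorted key/gradient arrays.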

PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum(
NULL, temp_storage_bytes, d_segments_fea_num_info, d_segments_fea_num_offset,
segments_num, stream));
if (d_temp_storage->size() < temp_storage_bytes) {
d_temp_storage = NULL;
d_temp_storage = memory::Alloc(place, temp_storage_bytes);
}
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum(
d_temp_storage->ptr(), temp_storage_bytes, d_segments_fea_num_info, d_segments_fea_num_offset,
segments_num, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));

auto d_segments_keys = memory::Alloc(place, sizeof(KeyType) * segments_num);
auto d_segments_keys_ptr = reinterpret_cast<KeyType*>(d_segments_keys->ptr());
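  // Assumed shrink_keys semantics: pick d_keys at each segment's start offset,
  // yielding one representative key per segment in d_segments_keys_ptr.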
heter_comm_kernel_->shrink_keys(
d_keys, d_segments_fea_num_offset,
d_segments_keys_ptr, segments_num,
stream);
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));

auto d_segment_grads = memory::Alloc(place, segments_num * grad_value_size);
auto d_segment_grads_ptr = reinterpret_cast<float*>(d_segment_grads->ptr());
heter_comm_kernel_->merge_gradient(
-      d_keys, d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads,
-      (char*)d_merge_grads_ptr, uniq_len, grad_dim, grad_value_size, merger_, stream);
+      d_segments_keys_ptr, d_segments_fea_num_offset, d_segments_fea_num_info, d_index,
+      (char*)d_grads, (char*)d_segment_grads_ptr, segments_num,
+      grad_dim, grad_value_size, merger_, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
-  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_merge_grads_ptr,
-                                             grad_value_size * uniq_len,
-                                             cudaMemcpyDeviceToDevice, stream));

+  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_keys, d_segments_keys_ptr,
+                                             sizeof(KeyType) * segments_num,
+                                             cudaMemcpyDeviceToDevice, stream));
+  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_segment_grads_ptr,
+                                             grad_value_size * segments_num,
+                                             cudaMemcpyDeviceToDevice, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
}
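For intuition, here is a worked trace of the buffers above under the assumed kernel semantics, with segment_size = 2 and three unique keys k0, k1, k2 occurring {5, 1, 3} times among len = 9 sorted keys:

// d_fea_num_info (per-key counts) : {5, 1, 3}                    uniq_len = 3
// split_segments                  : d_segments = {3, 1, 2}        (ceil(n / 2))
// cub::DeviceReduce::Sum          : segments_num = 6
// ExclusiveSum(d_segments)        : d_segments_offset = {0, 3, 4}
// expand_segments                 : d_segments_fea_num_info = {2, 2, 1, 1, 2, 1}
// ExclusiveSum(fea_num_info)      : d_segments_fea_num_offset = {0, 2, 4, 5, 6, 8}
// shrink_keys                     : d_segments_keys = {k0, k0, k0, k1, k2, k2}
// merge_gradient then reduces the 9 raw gradients into 6 segment gradients;
// the second dynamic_merge_grad pass later collapses the duplicated keys.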

@@ -715,21 +833,17 @@ void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
auto d_shard_index_tmp = memory::Alloc(place, len * sizeof(int));
int* d_shard_index_tmp_ptr = reinterpret_cast<int*>(d_shard_index_tmp->ptr());

-  // int grid_size = (len - 1) / block_size_ + 1;

heter_comm_kernel_->fill_idx(d_idx_tmp_ptr, len, stream);
heter_comm_kernel_->calc_shard_index(d_keys, len, d_shard_index_tmp_ptr,
total_device, stream);

size_t temp_storage_bytes;
const int num_bits = 1 + log2i(total_device);

heter_comm_kernel_->sort_pairs(
NULL, temp_storage_bytes, d_shard_index_tmp_ptr, d_shard_index_ptr,
d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream);

auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);

heter_comm_kernel_->sort_pairs(
d_temp_storage->ptr(), temp_storage_bytes, d_shard_index_tmp_ptr,
d_shard_index_ptr, d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream);
@@ -856,8 +970,10 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
sync_stream(node.out_stream);
}
}

heter_comm_kernel_->dy_mf_fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len,
val_type_size, stream);

sync_stream(stream);
if (!FLAGS_gpugraph_enable_gpu_direct_access) {
for (int i = 0; i < total_device; ++i) {
@@ -930,9 +1046,20 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
d_shard_grads_ptr = reinterpret_cast<float*>(d_shard_grads->ptr());

int uniq_len = len;
-  dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len);

-  int grid_size = (uniq_len - 1) / block_size_ + 1;
+  size_t segment_len = 0;
+  if (FLAGS_gpugraph_enable_segment_merge_grads) {
+    // Merge gradients in two passes:
+    // 1st pass: segmented merge, with at most
+    //           FLAGS_gpugraph_merge_grads_segment_size gradients per segment;
+    // 2nd pass: global merge, collapsing the duplicate keys that segmented
+    //           merging leaves behind (one per segment of a hot key).
+    dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len, segment_len, true);
+    len = segment_len;
+    uniq_len = 0;
+    segment_len = 0;
+    dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len, segment_len, false);
+  } else {
+    // Merge gradients in a single global pass.
+    dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len, segment_len, false);
+  }

split_input_to_shard(d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr,
dev_num);
@@ -1067,8 +1194,6 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
int uniq_len = len;
merge_grad(dev_num, d_keys, d_grads, len, uniq_len);

-  // int grid_size = (uniq_len - 1) / block_size_ + 1;

split_input_to_shard(d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr,
dev_num);

@@ -1242,7 +1367,6 @@ int HeterComm<KeyType, ValType, GradType>::gather_one_node_grad(
cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);

-  // int grid_size = (h_node_len[i] - 1) / block_size_ + 1;
heter_comm_kernel_->fill_shard_grads(
storage.local_keys + merge_num, storage.all_keys + index,
storage.local_grads + merge_num, storage.all_grads + index,
(Diffs for the remaining four changed files were not loaded in this view.)
