From 0b424ce432419ba99838d8b0f3dab763869290ff Mon Sep 17 00:00:00 2001
From: Thunderbrook <52529258+Thunderbrook@users.noreply.github.com>
Date: Mon, 15 Aug 2022 08:45:46 +0800
Subject: [PATCH] [GPUPS] opt Push merge (#50)

* add merge grad atomic kernel

* push merge grad

Co-authored-by: yaoxuefeng6
---
 .../framework/fleet/heter_ps/heter_comm.h     |   7 +
 .../framework/fleet/heter_ps/heter_comm_inl.h | 395 +++++++++++-------
 2 files changed, 246 insertions(+), 156 deletions(-)

diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h
index 60558f29616658..c65e9a9b210572 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h
@@ -47,6 +47,12 @@ struct CustomGradMerger {
     return out;
   }
 
+  template <typename GPUAccessor>
+  __device__ __forceinline__
+  void copy_all_field(float* output, const float* input, GPUAccessor& gpu_accessor) {
+    gpu_accessor.PushValueFill(output, input);
+  }
+
   template <typename GPUAccessor>
   __device__ __forceinline__
   void copy_basic_field(float* output, const float* input, GPUAccessor& gpu_accessor) {
@@ -241,6 +247,7 @@ class HeterComm {
   std::vector<std::shared_ptr<cub::CachingDeviceAllocator>> allocators_;
   int multi_mf_dim_{8};
   int max_mf_dim_ = 8;
+  int use_merge_atomic_ = 1;
 };
 
 }  // end namespace framework

diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
index b19221230fc2a4..3dc3e683279c13 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #ifdef PADDLE_WITH_HETERPS
 //#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
 #include <queue>
+#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 
 namespace paddle {
 namespace framework {
@@ -86,12 +87,9 @@ __global__ void fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys,
 }
 
 template <typename KeyType, typename T, typename GPUAccessor>
-__global__ void dy_mf_fill_shard_grads(KeyType* d_shard_keys,
-                                       KeyType* d_keys,
-                                       float* d_shard_grads,
-                                       float* d_grads,
-                                       T* idx,
-                                       size_t len,
+__global__ void dy_mf_fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys,
+                                       float* d_shard_grads, float* d_grads,
+                                       T* idx, size_t len,
                                        size_t grad_value_size,
                                        GPUAccessor gpu_accessor) {
   const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
@@ -106,14 +104,11 @@ __global__ void dy_mf_fill_shard_grads(KeyType* d_shard_keys,
 
 // optimized version
 template <>
-__global__ void dy_mf_fill_shard_grads(FeatureKey* d_shard_keys,
-                                       FeatureKey* d_keys,
-                                       float* d_shard_grads,
-                                       float* d_grads,
-                                       int* idx,
-                                       size_t len,
-                                       size_t grad_value_size,
-                                       CommonFeatureValueAccessor gpu_accessor) {
+__global__ void
+dy_mf_fill_shard_grads(
+    FeatureKey* d_shard_keys, FeatureKey* d_keys, float* d_shard_grads,
+    float* d_grads, int* idx, size_t len, size_t grad_value_size,
+    CommonFeatureValueAccessor gpu_accessor) {
   const size_t i = blockIdx.x * blockDim.y + threadIdx.y;
   const size_t k = threadIdx.x;
   if (i < len) {
@@ -121,21 +116,17 @@ __global__ void dy_mf_fill_shard_grads
 }
 
 template <typename GPUAccessor>
-__global__ void merge_gradient_basic_kernel(const uint32_t* offset,
-                                            const uint32_t* fea_num,
-                                            const uint32_t* index,
-                                            const char* input,
-                                            char* output,
-                                            int n,
-                                            size_t grad_value_size,
-                                            CustomGradMerger& merger,
-                                            GPUAccessor gpu_accessor) {
+__global__ void merge_gradient_basic_kernel(
+    const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
+    const char* input, char* output, int n, size_t grad_value_size,
+    CustomGradMerger& merger, GPUAccessor gpu_accessor) {
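+  // one thread per unique key: merge the basic (non-embedx) gradient fields
+  // of all duplicates of key i into output slot i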
   const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
   if (i < n) {
     uint32_t start = offset[i];
@@ -153,16 +144,11 @@ __global__ void merge_gradient_basic_kernel(const uint32_t* offset,
 }
 
 template <typename GPUAccessor>
-__global__ void merge_gradient_embedx_kernel(const uint32_t* offset,
-                                             const uint32_t* fea_num,
-                                             const uint32_t* index,
-                                             const char* input,
-                                             char* output,
-                                             int n,
-                                             size_t grad_dim,
-                                             size_t grad_value_size,
-                                             CustomGradMerger& merger,
-                                             GPUAccessor gpu_accessor) {
+__global__ void merge_gradient_embedx_kernel(
+    const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
+    const char* input, char* output, int n, size_t grad_dim,
+    size_t grad_value_size, CustomGradMerger& merger,
+    GPUAccessor gpu_accessor) {
   const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
   if (i < n) {
     size_t value_idx = i / grad_dim;
@@ -181,6 +167,69 @@ __global__ void merge_gradient_embedx_kernel(const uint32_t* offset,
   }
 }
 
+template <typename GPUAccessor>
+__global__ void merge_gradient_atomic_kernel(
+    int uniq_len, const uint32_t* offset, const uint32_t* fea_num,
+    const uint32_t* index, const char* input, char* output, int n,
+    size_t grad_value_size, CustomGradMerger& merger, int hidden_size,
+    GPUAccessor gpu_accessor) {
+  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < n) {
+    size_t value_idx = idx / hidden_size;
+    size_t off = idx % hidden_size;
+    int dst = 0;
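+    // offset[] is the exclusive prefix sum of fea_num over the sorted unique
+    // keys; binary-search it to find which unique key slot (dst) owns the
+    // sorted gradient position value_idx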
+    if (value_idx < offset[1]) {
+      dst = 0;
+    } else {
+      int high = uniq_len - 1;
+      int low = 1;
+      while (low < high) {
+        int mid = (low + high) / 2;
+        if (value_idx < offset[mid + 1]) {
+          high = mid;
+        } else {
+          low = mid + 1;
+        }
+      }
+      dst = low;
+    }
+
+    int ori_index = index[value_idx];
+    const float* rhs = (float*)(input + size_t(ori_index) * grad_value_size);
+    float* lhs = (float*)(output + dst * grad_value_size);
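+    // hidden_size = max_mf_dim_ + 3 lanes per value: lane 0 merges
+    // slot/mf_dim and accumulates show, lane 1 click, lane 2 embed_g, and
+    // lanes >= 3 the embedx gradients; accumulated fields use atomicAdd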
+    switch (off) {
+      case 0:
+        lhs[gpu_accessor.common_push_value.SlotIndex()] =
+            rhs[gpu_accessor.common_push_value.SlotIndex()];
+        lhs[gpu_accessor.common_push_value.MfDimIndex()] =
+            rhs[gpu_accessor.common_push_value.MfDimIndex()];
+        paddle::platform::CudaAtomicAdd(
+            &lhs[gpu_accessor.common_push_value.ShowIndex()],
+            rhs[gpu_accessor.common_push_value.ShowIndex()]);
+        break;
+      case 1:
+        paddle::platform::CudaAtomicAdd(
+            &lhs[gpu_accessor.common_push_value.ClickIndex()],
+            rhs[gpu_accessor.common_push_value.ClickIndex()]);
+        break;
+      case 2:
+        paddle::platform::CudaAtomicAdd(
+            &lhs[gpu_accessor.common_push_value.EmbedGIndex()],
+            rhs[gpu_accessor.common_push_value.EmbedGIndex()]);
+        break;
+      default:
+        int embedx_idx = off - 3;
+        int mf_dim = rhs[gpu_accessor.common_push_value.MfDimIndex()];
+        if (embedx_idx < mf_dim) {
+          // embedx gradients start at EmbedxGIndex(), one past embed_g
+          paddle::platform::CudaAtomicAdd(
+              &lhs[gpu_accessor.common_push_value.EmbedxGIndex() + embedx_idx],
+              rhs[gpu_accessor.common_push_value.EmbedxGIndex() + embedx_idx]);
+        }
+        break;
+    }
+  }
+}
+
 template <typename ValType, typename T>
 __global__ void fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx,
                            size_t len) {
@@ -191,11 +240,8 @@ __global__ void fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx,
 }
 
 template <typename T, typename GPUAccessor>
-__global__ void dy_mf_fill_dvals(float* d_shard_vals,
-                                 float* d_vals,
-                                 T* idx,
-                                 size_t len,
-                                 size_t val_size,
+__global__ void dy_mf_fill_dvals(float* d_shard_vals, float* d_vals, T* idx,
+                                 size_t len, size_t val_size,
                                  GPUAccessor gpu_accessor) {
   const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
   if (i < len) {
@@ -208,12 +254,9 @@ __global__ void dy_mf_fill_dvals(float* d_shard_vals,
 
 // optimized version
 template <>
-__global__ void dy_mf_fill_dvals(float* d_shard_vals,
-                                 float* d_vals,
-                                 int* idx,
-                                 size_t len,
-                                 size_t val_size,
-                                 CommonFeatureValueAccessor gpu_accessor) {
+__global__ void dy_mf_fill_dvals(
+    float* d_shard_vals, float* d_vals, int* idx, size_t len, size_t val_size,
+    CommonFeatureValueAccessor gpu_accessor) {
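+  // warp-style layout: blockIdx.x/threadIdx.y select one value row, the 32
+  // threadIdx.x lanes copy that value's fields in parallel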
   const size_t i = blockIdx.x * blockDim.y + threadIdx.y;
   const size_t k = threadIdx.x;
   if (i < len) {
@@ -224,9 +267,11 @@ __global__ void dy_mf_fill_dvals(float* d_shard
   }
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 HeterComm<KeyType, ValType, GradType, GPUAccessor>::HeterComm(
-    size_t capacity, std::shared_ptr<HeterPsResource> resource, GPUAccessor& gpu_accessor) {
+    size_t capacity, std::shared_ptr<HeterPsResource> resource,
+    GPUAccessor& gpu_accessor) {
   VLOG(1) << "Construct new HeterComm";
   resource_ = resource;
   storage_.resize(resource_->total_gpu());
@@ -238,18 +283,17 @@ HeterComm::HeterComm(
     //     2, 1, 20, (size_t)-1, false, false));  // NOLINT
     allocators_.push_back(std::make_shared<cub::CachingDeviceAllocator>(
         8, 1, (unsigned int)-1, (size_t)-1, false, false));
-      max_mf_dim_ = resource->max_mf_dim();
-      auto accessor_wrapper_ptr =
-          GlobalAccessorFactory::GetInstance().GetAccessorWrapper();
-      size_t val_type_size =
-          accessor_wrapper_ptr->GetFeatureValueSize(max_mf_dim_);
-      size_t grad_type_size =
-          accessor_wrapper_ptr->GetPushValueSize(max_mf_dim_);
-      VLOG(3) << " HeterComm init, max feature_value_size:" << val_type_size
-              << ", feature_value_push_size:" << grad_type_size;
-      auto ptr_table = new PtrTable(capacity / load_factor_);
-      ptr_table->set_feature_value_size(val_type_size, grad_type_size);
-      ptr_tables_.push_back(ptr_table);
+    max_mf_dim_ = resource->max_mf_dim();
+    auto accessor_wrapper_ptr =
+        GlobalAccessorFactory::GetInstance().GetAccessorWrapper();
+    size_t val_type_size =
+        accessor_wrapper_ptr->GetFeatureValueSize(max_mf_dim_);
+    size_t grad_type_size = accessor_wrapper_ptr->GetPushValueSize(max_mf_dim_);
+    VLOG(3) << " HeterComm init, max feature_value_size:" << val_type_size
+            << ", feature_value_push_size:" << grad_type_size;
+    auto ptr_table = new PtrTable(capacity / load_factor_);
+    ptr_table->set_feature_value_size(val_type_size, grad_type_size);
+    ptr_tables_.push_back(ptr_table);
     if (multi_node_) {
       storage_[i].init(feanum_, resource_->dev_id(i));
     }
@@ -257,7 +301,8 @@ HeterComm::HeterComm(
   init_path();
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 void HeterComm<KeyType, ValType, GradType, GPUAccessor>::init_path() {
   int total_gpu = resource_->total_gpu();
   path_.resize(total_gpu);
@@ -311,11 +356,10 @@ void HeterComm::init_path() {
   VLOG(0) << "HeterComm init_path done";
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
-void HeterComm<KeyType, ValType, GradType, GPUAccessor>::create_storage(int start_index,
-                                                                        int end_index,
-                                                                        size_t keylen,
-                                                                        size_t vallen) {
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
+void HeterComm<KeyType, ValType, GradType, GPUAccessor>::create_storage(
+    int start_index, int end_index, size_t keylen, size_t vallen) {
   auto& allocator = allocators_[start_index];
   auto& nodes = path_[start_index][end_index].nodes_;
   for (size_t i = 0; i < nodes.size(); ++i) {
@@ -334,22 +378,24 @@ void HeterComm::create_storage(int star
   }
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
-void HeterComm<KeyType, ValType, GradType, GPUAccessor>::destroy_storage(int start_index,
-                                                                         int end_index) {
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
+void HeterComm<KeyType, ValType, GradType, GPUAccessor>::destroy_storage(
+    int start_index, int end_index) {
   auto& allocator = allocators_[start_index];
   auto& nodes = path_[start_index][end_index].nodes_;
   for (size_t i = 0; i < nodes.size(); ++i) {
     platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].gpu_num));
-    PADDLE_ENFORCE_GPU_SUCCESS(allocator->DeviceFree(resource_->dev_id(nodes[i].gpu_num),
-                                                     nodes[i].key_storage));
-    PADDLE_ENFORCE_GPU_SUCCESS(allocator->DeviceFree(resource_->dev_id(nodes[i].gpu_num),
-                                                     nodes[i].val_storage));
+    PADDLE_ENFORCE_GPU_SUCCESS(allocator->DeviceFree(
+        resource_->dev_id(nodes[i].gpu_num), nodes[i].key_storage));
+    PADDLE_ENFORCE_GPU_SUCCESS(allocator->DeviceFree(
+        resource_->dev_id(nodes[i].gpu_num), nodes[i].val_storage));
   }
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 void HeterComm<KeyType, ValType, GradType, GPUAccessor>::walk_to_dest(
     int start_index, int gpu_num, int* h_left, int* h_right, KeyType* src_key,
     GradType* src_val) {
   std::queue<CopyTask> que;
@@ -401,7 +447,8 @@ void HeterComm::walk_to_dest(
   }
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 void HeterComm<KeyType, ValType, GradType, GPUAccessor>::walk_to_dest(
     int start_index, int gpu_num, int* h_left, int* h_right, KeyType* src_key,
     char* src_val, size_t val_size) {
@@ -453,7 +500,8 @@ void HeterComm::walk_to_dest(
   }
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 void HeterComm<KeyType, ValType, GradType, GPUAccessor>::walk_to_src(
     int start_index, int gpu_num, int* h_left, int* h_right, ValType* src_val) {
   std::queue<CopyTask> que;
@@ -503,9 +551,11 @@ void HeterComm::walk_to_src(
   }
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 void HeterComm<KeyType, ValType, GradType, GPUAccessor>::walk_to_src(
-    int start_index, int gpu_num, int* h_left, int* h_right, char* src_val, size_t val_size) {
+    int start_index, int gpu_num, int* h_left, int* h_right, char* src_val,
+    size_t val_size) {
   std::queue<CopyTask> que;
   for (int i = 0; i < gpu_num; i++) {
     if (h_left[i] == -1 || h_right[i] == -1) {
@@ -553,7 +603,8 @@ void HeterComm::walk_to_src(
   }
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 HeterComm<KeyType, ValType, GradType, GPUAccessor>::~HeterComm() {
   if (!multi_mf_dim_) {
     for (auto& table : tables_) {
@@ -569,11 +620,13 @@ HeterComm::~HeterComm() {
       delete table;
       table = nullptr;
     }
-    }
+  }
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
-void HeterComm<KeyType, ValType, GradType, GPUAccessor>::show_one_table(int gpu_num) {
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
+void HeterComm<KeyType, ValType, GradType, GPUAccessor>::show_one_table(
+    int gpu_num) {
   if (!multi_mf_dim_) {
     tables_[gpu_num]->show();
   }
@@ -581,7 +634,8 @@ void HeterComm::show_one_table(int gpu_
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 int HeterComm<KeyType, ValType, GradType, GPUAccessor>::log2i(int x) {
   unsigned res = 0;
   while (x >>= 1) {
@@ -590,12 +644,15 @@ int HeterComm::log2i(int x) {
   return res;
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
-int HeterComm<KeyType, ValType, GradType, GPUAccessor>::get_index_by_devid(int devid) {
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
+int HeterComm<KeyType, ValType, GradType, GPUAccessor>::get_index_by_devid(
+    int devid) {
   return resource_->get_index_by_devid(devid);
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 void HeterComm<KeyType, ValType, GradType, GPUAccessor>::set_sparse_sgd(
     const OptimizerConfig& optimizer_config) {
   for (int i = 0; i < resource_->total_gpu(); ++i) {
@@ -608,7 +665,8 @@ void HeterComm::set_sparse_sgd(
   }
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 void HeterComm<KeyType, ValType, GradType, GPUAccessor>::set_embedx_sgd(
     const OptimizerConfig& optimizer_config) {
   for (int i = 0; i < resource_->total_gpu(); ++i) {
@@ -622,12 +680,17 @@ void HeterComm::set_embedx_sgd(
 }
 
 /*
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
-void HeterComm<KeyType, ValType, GradType, GPUAccessor>::build_ps(int num, KeyType* h_keys,
-                                                                  ValType* h_vals,
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
+void HeterComm<KeyType, ValType, GradType, GPUAccessor>::build_ps(int num,
+KeyType* h_keys,
+                                                                  ValType*
+h_vals,
                                                      size_t len,
-                                                     size_t chunk_size,
-                                                     int stream_num) {
+                                                                  size_t
+chunk_size,
+                                                                  int stream_num)
+{
   if (len <= 0) {
     return;
   }
@@ -676,14 +739,11 @@ void HeterComm::build_ps(int num, KeyTyp
 }
 */
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
-void HeterComm<KeyType, ValType, GradType, GPUAccessor>::build_ps(int num,
-                                                                  KeyType* h_keys,
-                                                                  char* pool,
-                                                                  size_t len,
-                                                                  size_t feature_value_size,
-                                                                  size_t chunk_size,
-                                                                  int stream_num) {
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
+void HeterComm<KeyType, ValType, GradType, GPUAccessor>::build_ps(
+    int num, KeyType* h_keys, char* pool, size_t len, size_t feature_value_size,
+    size_t chunk_size, int stream_num) {
  if (len <= 0) {
     return;
   }
@@ -724,7 +784,8 @@ void HeterComm::build_ps(int num,
   }
 }
 /*
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 void HeterComm<KeyType, ValType, GradType, GPUAccessor>::merge_grad(
     int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
     int& uniq_len) {  // NOLINT
@@ -776,12 +837,11 @@ void HeterComm::merge_grad(
 }
 */
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
-void HeterComm<KeyType, ValType, GradType, GPUAccessor>::merge_grad(int gpu_num,
-                                                                    KeyType* d_keys,
-                                                                    GradType* d_grads,
-                                                                    float* mf, size_t len,
-                                                                    int& uniq_len) {
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
+void HeterComm<KeyType, ValType, GradType, GPUAccessor>::merge_grad(
+    int gpu_num, KeyType* d_keys, GradType* d_grads, float* mf, size_t len,
+    int& uniq_len) {
   platform::Timer timeline;
   timeline.Start();
   int dev_id = resource_->dev_id(gpu_num);
@@ -798,11 +858,9 @@ void HeterComm::merge_grad(int gpu_num,
   KeyType* d_merge_keys_ptr = reinterpret_cast<KeyType*>(d_merge_keys->ptr());
 
   auto d_merge_grads = memory::Alloc(place, len * grad_value_size);
-  float* d_merge_grads_ptr =
-      reinterpret_cast<float*>(d_merge_grads->ptr());
+  float* d_merge_grads_ptr = reinterpret_cast<float*>(d_merge_grads->ptr());
 
-  auto d_fea_num_info =
-      memory::Alloc(place, sizeof(uint32_t) * (len * 3 + 1));
+  auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 3 + 1));
   uint32_t* d_fea_num_info_ptr =
       reinterpret_cast<uint32_t*>(d_fea_num_info->ptr());
@@ -866,35 +924,53 @@ void HeterComm::merge_grad(int gpu_num,
   PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
   timeline.Pause();
   timeline.Start();
-  grid_size = (uniq_len - 1) / block_size_ + 1;
+  if (!use_merge_atomic_) {
+    grid_size = (uniq_len - 1) / block_size_ + 1;
+
+    merge_gradient_basic_kernel<<<grid_size, block_size_, 0, stream>>>(
+        d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads,
+        (char*)d_merge_grads_ptr, uniq_len, grad_value_size, merger_,
+        gpu_accessor_);
+
+    const size_t grad_dim = max_mf_dim_;
+    if (grad_dim > 0) {
+      int grid_size2 = (uniq_len * grad_dim - 1) / block_size_ + 1;
+      merge_gradient_embedx_kernel<<<grid_size2, block_size_, 0, stream>>>(
+          d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads,
+          (char*)d_merge_grads_ptr, uniq_len * grad_dim, grad_dim,
+          grad_value_size, merger_, gpu_accessor_);
+    }
+  } else {
+    // use merge atomic
+    int hidden_size = max_mf_dim_ + 3;
+    int N = len * hidden_size;
+    grid_size = (N - 1) / block_size_ + 1;
 
-  merge_gradient_basic_kernel<<<grid_size, block_size_, 0, stream>>>(
-      d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads,
-      (char*)d_merge_grads_ptr, uniq_len, grad_value_size, merger_, gpu_accessor_);
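+    // the merged output is accumulated with atomicAdd, so zero it first; one
+    // thread handles one (gradient, lane) pair, N = len * hidden_size total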
+    cudaMemsetAsync(d_merge_grads_ptr, 0, uniq_len * grad_value_size, stream);
 
-  const size_t grad_dim = max_mf_dim_;
-  if (grad_dim > 0) {
-    int grid_size2 = (uniq_len * grad_dim - 1) / block_size_ + 1;
-    merge_gradient_embedx_kernel<<<grid_size2, block_size_, 0, stream>>>(
-        d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads, (char*)d_merge_grads_ptr, uniq_len * grad_dim, grad_dim, grad_value_size, merger_, gpu_accessor_);
+    merge_gradient_atomic_kernel<<<grid_size, block_size_, 0, stream>>>(
+        uniq_len, d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads,
+        (char*)d_merge_grads_ptr, N, grad_value_size, merger_, hidden_size,
+        gpu_accessor_);
   }
   PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
+  timeline.Pause();
   timeline.Start();
 
-  PADDLE_ENFORCE_GPU_SUCCESS(
-      cudaMemcpyAsync(d_grads, d_merge_grads_ptr, grad_value_size * uniq_len,
-                      cudaMemcpyDeviceToDevice, stream));
+  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_merge_grads_ptr,
+                                             grad_value_size * uniq_len,
+                                             cudaMemcpyDeviceToDevice, stream));
   PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
   timeline.Pause();
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 void HeterComm<KeyType, ValType, GradType, GPUAccessor>::split_input_to_shard(
     KeyType* d_keys, int* d_idx_ptr, size_t len, int* left, int* right,
     int gpu_num) {
-  int total_gpu = resource_->total_gpu();
   int dev_id = resource_->dev_id(gpu_num);
   platform::CUDAPlace place = platform::CUDAPlace(dev_id);
@@ -933,11 +1009,10 @@ void HeterComm::split_input_to_shard(
   cudaStreamSynchronize(stream);
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
-void HeterComm<KeyType, ValType, GradType, GPUAccessor>::pull_sparse(int num,
-                                                                     KeyType* d_keys,
-                                                                     ValType* d_vals,
-                                                                     size_t len) {
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
+void HeterComm<KeyType, ValType, GradType, GPUAccessor>::pull_sparse(
+    int num, KeyType* d_keys, ValType* d_vals, size_t len) {
   if (len == 0) {
     return;
   }
@@ -950,8 +1025,10 @@ void HeterComm::pull_sparse(int num,
 
   int grid_size = (len - 1) / block_size_ + 1;
 
-  auto h_left_alloc = memory::Alloc(phi::GPUPinnedPlace(), sizeof(int) * total_gpu);
-  auto h_right_alloc = memory::Alloc(phi::GPUPinnedPlace(), sizeof(int) * total_gpu);
+  auto h_left_alloc =
+      memory::Alloc(phi::GPUPinnedPlace(), sizeof(int) * total_gpu);
+  auto h_right_alloc =
+      memory::Alloc(phi::GPUPinnedPlace(), sizeof(int) * total_gpu);
 
   int* h_left = reinterpret_cast<int*>(h_left_alloc->ptr());
   int* h_right = reinterpret_cast<int*>(h_right_alloc->ptr());
@@ -981,9 +1058,9 @@ void HeterComm::pull_sparse(int num,
       d_keys, d_idx_ptr, len);
 
   cudaMemcpyAsync(h_left, d_left_ptr, total_gpu * sizeof(int),
-      cudaMemcpyDeviceToHost, stream);
+                  cudaMemcpyDeviceToHost, stream);
   cudaMemcpyAsync(h_right, d_right_ptr, total_gpu * sizeof(int),
-      cudaMemcpyDeviceToHost, stream);
+                  cudaMemcpyDeviceToHost, stream);
   cudaStreamSynchronize(stream);
 
   for (int i = 0; i < total_gpu; ++i) {
@@ -1012,8 +1089,7 @@ void HeterComm::pull_sparse(int num,
     ptr_tables_[i]->rwlock_->RDLock();
     ptr_tables_[i]->get(reinterpret_cast<KeyType*>(node.key_storage),
                         node.val_storage, h_right[i] - h_left[i] + 1,
-                        resource_->remote_stream(i, num),
-                        gpu_accessor_);
+                        resource_->remote_stream(i, num), gpu_accessor_);
   }
   for (int i = 0; i < total_gpu; ++i) {
     cudaStreamSynchronize(resource_->remote_stream(i, num));
@@ -1024,7 +1100,8 @@ void HeterComm::pull_sparse(int num,
     time_lines[i].Pause();
   }
 
-  walk_to_src(num, total_gpu, h_left, h_right, reinterpret_cast<char*>(d_shard_vals_ptr), val_type_size);
+  walk_to_src(num, total_gpu, h_left, h_right,
+              reinterpret_cast<char*>(d_shard_vals_ptr), val_type_size);
 
   for (int i = 0; i < total_gpu; ++i) {
     auto& node = path_[num][i].nodes_.front();
@@ -1033,9 +1110,9 @@ void HeterComm::pull_sparse(int num,
   dim3 block_dims(32, 32);
   const size_t grid_size_ = (len - 1) / 32 + 1;
   dim3 grid_dims(grid_size_);
-  
+
   dy_mf_fill_dvals<<<grid_dims, block_dims, 0, stream>>>(
-    d_shard_vals_ptr, d_vals, d_idx_ptr, len, val_type_size, gpu_accessor_);
+      d_shard_vals_ptr, d_vals, d_idx_ptr, len, val_type_size, gpu_accessor_);
   cudaStreamSynchronize(stream);
 
   for (int i = 0; i < total_gpu; ++i) {
@@ -1046,13 +1123,12 @@ void HeterComm::pull_sparse(int num,
   }
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 template <typename Sgd>
-void HeterComm<KeyType, ValType, GradType, GPUAccessor>::push_sparse(int gpu_num,
-                                                                     KeyType* d_keys,
-                                                                     GradType* d_grads,
-                                                                     size_t len,
-                                                                     Sgd& sgd) {  // NOLINT
+void HeterComm<KeyType, ValType, GradType, GPUAccessor>::push_sparse(
+    int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
+    Sgd& sgd) {  // NOLINT
   if (len == 0) {
     return;
   }
@@ -1069,8 +1145,10 @@ void HeterComm::push_sparse(int gpu_num
 
   // int h_left[total_gpu];   // NOLINT
   // int h_right[total_gpu];  // NOLINT
-  auto h_left_alloc = memory::Alloc(phi::GPUPinnedPlace(), sizeof(int) * total_gpu);
-  auto h_right_alloc = memory::Alloc(phi::GPUPinnedPlace(), sizeof(int) * total_gpu);
+  auto h_left_alloc =
+      memory::Alloc(phi::GPUPinnedPlace(), sizeof(int) * total_gpu);
+  auto h_right_alloc =
+      memory::Alloc(phi::GPUPinnedPlace(), sizeof(int) * total_gpu);
 
   int* h_left = reinterpret_cast<int*>(h_left_alloc->ptr());
   int* h_right = reinterpret_cast<int*>(h_right_alloc->ptr());
@@ -1100,17 +1178,17 @@ void HeterComm::push_sparse(int gpu_num
                        gpu_num);
 
   dim3 block_dims(32, 32);
-  const size_t grid_size_ = (uniq_len - 1) / 32 + 1; 
+  const size_t grid_size_ = (uniq_len - 1) / 32 + 1;
   dim3 grid_dims(grid_size_);
 
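+  // one 32-lane row per deduplicated key: lanes stride the pushed value's
+  // fields in parallel (same layout as dy_mf_fill_dvals above)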
   dy_mf_fill_shard_grads<<<grid_dims, block_dims, 0, stream>>>(
-      d_shard_keys_ptr, d_keys, d_shard_grads_ptr, d_grads, d_idx_ptr,
-      uniq_len, grad_value_size, gpu_accessor_);
+      d_shard_keys_ptr, d_keys, d_shard_grads_ptr, d_grads, d_idx_ptr, uniq_len,
+      grad_value_size, gpu_accessor_);
 
   cudaMemcpyAsync(h_left, d_left_ptr, total_gpu * sizeof(int),
-      cudaMemcpyDeviceToHost, stream);
+                  cudaMemcpyDeviceToHost, stream);
   cudaMemcpyAsync(h_right, d_right_ptr, total_gpu * sizeof(int),
-      cudaMemcpyDeviceToHost, stream);
+                  cudaMemcpyDeviceToHost, stream);
   cudaStreamSynchronize(stream);
 
   for (int i = 0; i < total_gpu; ++i) {
@@ -1161,7 +1239,8 @@ void HeterComm::push_sparse(int gpu_num
   }
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 template <typename Sgd>
 void HeterComm<KeyType, ValType, GradType, GPUAccessor>::update_one_table(
     int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
@@ -1179,7 +1258,8 @@ void HeterComm::update_one_table(
   cudaStreamSynchronize(resource_->remote_stream(gpu_num, gpu_num));
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 template <typename Sgd>
 void HeterComm<KeyType, ValType, GradType, GPUAccessor>::push_sparse_multi_node(
     int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
@@ -1200,7 +1280,8 @@ void HeterComm::push_sparse_multi_node(
                    storage_[gpu_num].local_grads, uniq_len, sgd);
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 int HeterComm<KeyType, ValType, GradType, GPUAccessor>::gather_one_node_grad(
     int gpu_num, KeyType* d_keys, GradType* d_grads, int len) {
   int total_gpu = resource_->total_gpu();
@@ -1286,7 +1367,8 @@ int HeterComm::gather_one_node_grad(
   return ret;
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 int HeterComm<KeyType, ValType, GradType, GPUAccessor>::gather_multi_node_grad(
    int gpu_num, KeyType* d_keys, GradType* d_grads, int len) {
   int dev_id = resource_->dev_id(gpu_num);
@@ -1346,7 +1428,8 @@ int HeterComm::gather_multi_node_grad(
   return ret;
 }
 
-template <typename KeyType, typename ValType, typename GradType, typename GPUAccessor>
+template <typename KeyType, typename ValType, typename GradType,
+          typename GPUAccessor>
 void HeterComm<KeyType, ValType, GradType, GPUAccessor>::end_pass() {
   int total_gpu = resource_->total_gpu();
   std::vector<std::thread> threads;