From 184c631fff642e4444fe2b4ef50700df6a9a3159 Mon Sep 17 00:00:00 2001
From: SmallBirdLiao <105655932+SmallBirdLiao@users.noreply.github.com>
Date: Thu, 11 Aug 2022 10:27:56 +0800
Subject: [PATCH] Lxch curand bug fix (#48)

* pull sparse-ptr asyn

* fix curand bug

Co-authored-by: liaoxiaochao
---
 .../framework/fleet/heter_ps/hashtable.h      |  2 +-
 .../framework/fleet/heter_ps/hashtable_inl.h  | 34 ++++++++++++-------
 .../framework/fleet/heter_ps/heter_comm_inl.h |  2 +-
 3 files changed, 23 insertions(+), 15 deletions(-)

diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h
index 5c98a253c124d9..f40f89529a9cae 100755
--- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h
+++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h
@@ -73,7 +73,7 @@ class HashTable {
 
   template <typename Sgd>
   void update(const KeyType* d_keys, const char* d_grads, size_t len, Sgd& sgd,
-              gpuStream_t stream);
+              gpuStream_t stream, int dev_id);
 
   int size() { return container_->size(); }
 
diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h b/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
index f34acd05219c1c..b3534e7deaba46 100644
--- a/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
+++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
@@ -136,6 +136,10 @@ __global__ void curand_init_kernel(curandState* p_value, int len) {
 
 class CuRandState {
  public:
+  struct CallBackInfo {
+    std::shared_ptr<CuRandState>* obj;
+    int dev_id;
+  };
   CuRandState() = default;
   CuRandState(const CuRandState&) = delete;
   CuRandState(CuRandState&&) = delete;
@@ -165,24 +169,28 @@ class CuRandState {
     return states_;
   }
 
-  static HeterObjectPool<CuRandState>& pool() {
-    static HeterObjectPool<CuRandState> p;
-    return p;
+  static HeterObjectPool<CuRandState>& pool(int dev_id) {
+    static HeterObjectPool<CuRandState> p[100];
+    return p[dev_id];
   }
 
-  static std::shared_ptr<CuRandState> get() {
-    return pool().Get();
+  static std::shared_ptr<CuRandState> get(int dev_id) {
+    return pool(dev_id).Get();
   }
 
-  static void CUDART_CB pushback_cu_rand_state(void *data) {
-    auto state = static_cast<std::shared_ptr<CuRandState>*>(data);
-    pool().Push(std::move(*state));
+  static void CUDART_CB pushback_cu_rand_state(void* data) {
+    auto state = static_cast<CallBackInfo*>(data);
+    pool(state->dev_id).Push(std::move(*(state->obj)));
+    delete state->obj;
     delete state;
   }
 
-  static void push(std::shared_ptr<CuRandState> state, gpuStream_t stream) {
+  static void push(std::shared_ptr<CuRandState> state, gpuStream_t stream, int dev_id) {
+    CallBackInfo* obj = new CallBackInfo();
+    obj->dev_id = dev_id;
+    obj->obj = new std::shared_ptr<CuRandState>(std::move(state));
     CHECK(cudaLaunchHostFunc(stream, pushback_cu_rand_state,
-        new std::shared_ptr<CuRandState>(std::move(state))) == cudaSuccess);
+        obj) == cudaSuccess);
   }
  private:
   size_t size_ = 0;
@@ -382,16 +390,16 @@ template <typename KeyType, typename ValType>
 template <typename Sgd>
 void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                          const char* d_grads, size_t len,
-                                         Sgd& sgd, gpuStream_t stream) {
+                                         Sgd& sgd, gpuStream_t stream, int dev_id) {
   if (len == 0) {
     return;
   }
-  auto state = CuRandState::get();
+  auto state = CuRandState::get(dev_id);
   auto d_state = state->get(len, stream);
   const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
   dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
       container_, *device_optimizer_config_, d_keys, d_grads, d_state, len, sgd, push_grad_value_size_);
-  CuRandState::push(state, stream);
+  CuRandState::push(state, stream, dev_id);
 }
 
 template <typename KeyType, typename ValType>
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
index a692d1f0a51e77..b19221230fc2a4 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
@@ -1140,7 +1140,7 @@ void HeterComm::push_sparse(int gpu_num
     ptr_tables_[i]->rwlock_->WRLock();
     ptr_tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
                            node.val_storage, h_right[i] - h_left[i] + 1, sgd,
-                           resource_->remote_stream(i, gpu_num));
+                           resource_->remote_stream(i, gpu_num), i);
   }
   for (int i = 0; i < total_gpu; ++i) {
     cudaStreamSynchronize(resource_->remote_stream(i, gpu_num));