Skip to content

Commit

Permalink
Lxch curand bug fix (PaddlePaddle#48)
Browse files Browse the repository at this point in the history
* pull sparse-ptr asyn

* fix curand bug

Co-authored-by: liaoxiaochao <liaoxiaochao@baidu.com>
  • Loading branch information
SmallBirdLiao and liaoxiaochao-bb authored Aug 11, 2022
1 parent df8f7c0 commit 184c631
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 15 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/heter_ps/hashtable.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class HashTable {

template <typename Sgd>
void update(const KeyType* d_keys, const char* d_grads, size_t len, Sgd& sgd,
gpuStream_t stream);
gpuStream_t stream, int dev_id);

int size() { return container_->size(); }

Expand Down
34 changes: 21 additions & 13 deletions paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ __global__ void curand_init_kernel(curandState* p_value, int len) {

class CuRandState {
public:
struct CallBackInfo {
std::shared_ptr<CuRandState>* obj;
int dev_id;
};
CuRandState() = default;
CuRandState(const CuRandState&) = delete;
CuRandState(CuRandState&&) = delete;
Expand Down Expand Up @@ -165,24 +169,28 @@ class CuRandState {
return states_;
}

static HeterObjectPool<CuRandState>& pool() {
static HeterObjectPool<CuRandState> p;
return p;
static HeterObjectPool<CuRandState>& pool(int dev_id) {
static HeterObjectPool<CuRandState> p[100];
return p[dev_id];
}

static std::shared_ptr<CuRandState> get() {
return pool().Get();
static std::shared_ptr<CuRandState> get(int dev_id) {
return pool(dev_id).Get();
}

static void CUDART_CB pushback_cu_rand_state(void *data) {
auto state = static_cast<std::shared_ptr<CuRandState>*>(data);
pool().Push(std::move(*state));
static void CUDART_CB pushback_cu_rand_state(void* data) {
auto state = static_cast<CallBackInfo*>(data);
pool(state->dev_id).Push(std::move(*(state->obj)));
delete state->obj;
delete state;
}

static void push(std::shared_ptr<CuRandState> state, gpuStream_t stream) {
static void push(std::shared_ptr<CuRandState> state, gpuStream_t stream, int dev_id) {
CallBackInfo* obj = new CallBackInfo();
obj->dev_id = dev_id;
obj->obj = new std::shared_ptr<CuRandState>(std::move(state));
CHECK(cudaLaunchHostFunc(stream, pushback_cu_rand_state,
new std::shared_ptr<CuRandState>(std::move(state))) == cudaSuccess);
obj) == cudaSuccess);
}
private:
size_t size_ = 0;
Expand Down Expand Up @@ -382,16 +390,16 @@ template <typename KeyType, typename ValType>
template <typename Sgd>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
const char* d_grads, size_t len,
Sgd& sgd, gpuStream_t stream) {
Sgd& sgd, gpuStream_t stream, int dev_id) {
if (len == 0) {
return;
}
auto state = CuRandState::get();
auto state = CuRandState::get(dev_id);
auto d_state = state->get(len, stream);
const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
container_, *device_optimizer_config_, d_keys, d_grads, d_state, len, sgd, push_grad_value_size_);
CuRandState::push(state, stream);
CuRandState::push(state, stream, dev_id);
}

template <typename KeyType, typename ValType>
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1140,7 +1140,7 @@ void HeterComm<KeyType, ValType, GradType, GPUAccessor>::push_sparse(int gpu_num
ptr_tables_[i]->rwlock_->WRLock();
ptr_tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
node.val_storage, h_right[i] - h_left[i] + 1, sgd,
resource_->remote_stream(i, gpu_num));
resource_->remote_stream(i, gpu_num), i);
}
for (int i = 0; i < total_gpu; ++i) {
cudaStreamSynchronize(resource_->remote_stream(i, gpu_num));
Expand Down

0 comments on commit 184c631

Please sign in to comment.