【HETERPS】add cuda remote_streams #34276

Merged: 4 commits, Jul 26, 2021
Changes from all commits
paddle/fluid/framework/fleet/heter_ps/hashtable.h: 4 additions & 1 deletion
@@ -23,8 +23,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/table/depends/large_scale_kv.h"
#endif
#include "paddle/fluid/framework/rw_lock.h"
#include "thrust/pair.h"
//#include "cudf/concurrent_unordered_map.cuh.h"
// #include "cudf/concurrent_unordered_map.cuh.h"
#include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/platform/type_defs.h"
@@ -63,6 +64,8 @@ class HashTable {

int size() { return container_->size(); }

std::unique_ptr<RWLock> rwlock_{nullptr};

private:
TableContainer<KeyType, ValType>* container_;
int BLOCK_SIZE_{256};
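The new rwlock_ member gives every HashTable shard its own reader-writer lock: concurrent get() calls (pulls) can share the table, while update() (pushes) holds it exclusively. As a rough illustration, assuming Paddle's RWLock from rw_lock.h exposes the RDLock/WRLock/UNLock calls used later in this diff with the usual pthread semantics, a minimal stand-in could look like this (ToyRWLock is an illustrative name, not Paddle's class):

// Minimal reader-writer lock sketch built on pthreads. Many readers may hold
// the lock at once, a writer holds it alone, and a single UNLock call releases
// either kind of hold, mirroring how heter_comm_inl.h uses rwlock_ below.
#include <pthread.h>

class ToyRWLock {
 public:
  ToyRWLock() { pthread_rwlock_init(&lock_, nullptr); }
  ~ToyRWLock() { pthread_rwlock_destroy(&lock_); }
  void RDLock() { pthread_rwlock_rdlock(&lock_); }  // shared: concurrent pulls
  void WRLock() { pthread_rwlock_wrlock(&lock_); }  // exclusive: one push
  void UNLock() { pthread_rwlock_unlock(&lock_); }  // releases either hold

 private:
  pthread_rwlock_t lock_;
};
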
paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ __global__ void update_kernel(Table* table,
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
container_ = new TableContainer<KeyType, ValType>(capacity);
rwlock_.reset(new RWLock);
}

template <typename KeyType, typename ValType>
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h: 13 additions & 6 deletions
@@ -525,12 +525,15 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
auto& node = path_[num][i].nodes_.back();
cudaStreamSynchronize(node.in_stream);
platform::CUDADeviceGuard guard(resource_->dev_id(i));
tables_[i]->rwlock_->RDLock();
tables_[i]->get(reinterpret_cast<KeyType*>(node.key_storage),
reinterpret_cast<ValType*>(node.val_storage),
h_right[i] - h_left[i] + 1, resource_->remote_stream(i));
h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, num));
}
for (int i = 0; i < total_gpu; ++i) {
cudaStreamSynchronize(resource_->remote_stream(i));
cudaStreamSynchronize(resource_->remote_stream(i, num));
tables_[i]->rwlock_->UNLock();
}

walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr);
@@ -621,13 +624,15 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
cudaStreamSynchronize(node.in_stream);

platform::CUDADeviceGuard guard(resource_->dev_id(i));
tables_[i]->rwlock_->WRLock();
tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
reinterpret_cast<GradType*>(node.val_storage),
h_right[i] - h_left[i] + 1, sgd,
resource_->remote_stream(i));
resource_->remote_stream(i, gpu_num));
}
for (int i = 0; i < total_gpu; ++i) {
cudaStreamSynchronize(resource_->remote_stream(i));
cudaStreamSynchronize(resource_->remote_stream(i, gpu_num));
tables_[i]->rwlock_->UNLock();
}
}

@@ -641,9 +646,11 @@ void HeterComm<KeyType, ValType, GradType>::update_one_table(

int dev_id = resource_->dev_id(gpu_num);
platform::CUDADeviceGuard guard(dev_id);
tables_[gpu_num]->rwlock_->WRLock();
tables_[gpu_num]->update(d_keys, d_grads, len, sgd,
resource_->remote_stream(gpu_num));
cudaStreamSynchronize(resource_->remote_stream(gpu_num));
resource_->remote_stream(gpu_num, gpu_num));
tables_[gpu_num]->rwlock_->UNLock();
cudaStreamSynchronize(resource_->remote_stream(gpu_num, gpu_num));
}

template <typename KeyType, typename ValType, typename GradType>
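Both pull_sparse and push_sparse now follow the same shape: lock the table, launch the asynchronous get/update on every target GPU using the remote stream reserved for the calling GPU (num / gpu_num), then synchronize and unlock in a second loop so the per-GPU work can overlap. The standalone CUDA sketch below shows only that issue-then-sync pattern; the kernel, the sizes, and the use of a single device with several streams are toy assumptions, not Paddle code.

// Issue work on every stream first, then wait in a second pass, so the
// launches overlap instead of running back to back.
#include <cuda_runtime.h>
#include <vector>

__global__ void toy_lookup(float* data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] += 1.0f;  // stands in for a hashtable get/update
}

int main() {
  const int kShards = 4;   // plays the role of total_gpu
  const int kLen = 1 << 20;
  std::vector<cudaStream_t> streams(kShards);
  std::vector<float*> bufs(kShards);
  for (int i = 0; i < kShards; ++i) {  // first loop: launch only, no waiting
    cudaStreamCreateWithFlags(&streams[i], cudaStreamNonBlocking);
    cudaMalloc(reinterpret_cast<void**>(&bufs[i]), kLen * sizeof(float));
    toy_lookup<<<(kLen + 1023) / 1024, 1024, 0, streams[i]>>>(bufs[i], kLen);
  }
  for (int i = 0; i < kShards; ++i) {  // second loop: synchronize and clean up
    cudaStreamSynchronize(streams[i]);
    cudaFree(bufs[i]);
    cudaStreamDestroy(streams[i]);
  }
  return 0;
}
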
paddle/fluid/framework/fleet/heter_ps/heter_resource.cc: 8 additions & 6 deletions
@@ -27,16 +27,16 @@ GPUResource::GPUResource(std::vector<int>& dev_ids, int index) {
platform::CUDADeviceGuard guard(dev_id_);
local_streams_.resize(dev_ids_.size());
comm_streams_.resize(dev_ids_.size());
remote_streams_.resize(dev_ids_.size());

for (size_t i = 0; i < dev_ids_.size(); ++i) {
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamCreateWithFlags(&local_streams_[i], cudaStreamNonBlocking));
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamCreateWithFlags(&comm_streams_[i], cudaStreamNonBlocking));
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamCreateWithFlags(&remote_streams_[i], cudaStreamNonBlocking));
}

PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamCreateWithFlags(&remote_stream_, cudaStreamNonBlocking));
}

GPUResource::~GPUResource() {
@@ -47,7 +47,9 @@ GPUResource::~GPUResource() {
for (size_t i = 0; i < comm_streams_.size(); ++i) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(comm_streams_[i]));
}
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(remote_stream_));
for (size_t i = 0; i < remote_streams_.size(); ++i) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(remote_streams_[i]));
}
}

void HeterPsResource::enable_p2p() {
@@ -90,8 +92,8 @@ cudaStream_t HeterPsResource::local_stream(int gpu_num, int stream_num) {
return resources_[gpu_num]->local_stream(stream_num);
}

cudaStream_t HeterPsResource::remote_stream(int gpu_num) {
return resources_[gpu_num]->remote_stream();
cudaStream_t HeterPsResource::remote_stream(int gpu_num, int stream_num) {
return resources_[gpu_num]->remote_stream(stream_num);
}

int HeterPsResource::dev_id(int num) { return dev_ids_[num]; }
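Each GPUResource now owns one remote stream per peer device instead of a single remote_stream_, created and destroyed alongside the local and comm streams. A simplified standalone version of such a per-peer stream table is sketched below; ToyGPUResource and the unchecked CUDA calls are illustrative, whereas the real code wraps every call in PADDLE_ENFORCE_CUDA_SUCCESS.

// Simplified per-peer stream table: device dev_id keeps one non-blocking
// stream for every possible source GPU, so requests coming from different
// source GPUs never contend for the same stream.
#include <cuda_runtime.h>
#include <vector>

class ToyGPUResource {
 public:
  ToyGPUResource(int dev_id, int n_gpus) : remote_streams_(n_gpus) {
    cudaSetDevice(dev_id);
    for (auto& s : remote_streams_) {
      cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking);
    }
  }
  ~ToyGPUResource() {
    for (auto& s : remote_streams_) cudaStreamDestroy(s);
  }
  // stream reserved for work issued on behalf of src_gpu
  cudaStream_t remote_stream(int src_gpu) { return remote_streams_[src_gpu]; }

 private:
  std::vector<cudaStream_t> remote_streams_;
};
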
paddle/fluid/framework/fleet/heter_ps/heter_resource.h: 3 additions & 3 deletions
@@ -35,13 +35,13 @@ class GPUResource {
int dev_id() const { return dev_id_; }
int index() const { return index_; }
gpuStream_t local_stream(int num) { return local_streams_[num]; }
gpuStream_t remote_stream() { return remote_stream_; }
gpuStream_t remote_stream(int num) { return remote_streams_[num]; }
gpuStream_t comm_stream(int num) { return comm_streams_[num]; }

int dev_id_;
int index_;
std::vector<int> dev_ids_;
gpuStream_t remote_stream_;
std::vector<gpuStream_t> remote_streams_;
std::vector<gpuStream_t> local_streams_;
std::vector<gpuStream_t> comm_streams_;
};
@@ -57,7 +57,7 @@ class HeterPsResource {
int get_index_by_devid(int devid);
int dev_id(int num);
gpuStream_t local_stream(int gpu_num, int stream_num);
gpuStream_t remote_stream(int gpu_num);
gpuStream_t remote_stream(int gpu_num, int stream_num);
gpuStream_t comm_stream(int gpu_num, int stream_num);

std::vector<std::shared_ptr<GPUResource>> resources_;
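With the header change, remote_stream takes two indices. As used in heter_comm_inl.h above, the first selects the GPU that owns the table shard and the second selects the stream reserved for the calling GPU, so concurrent pulls from different callers land on different streams. A sketch of a call site, with placeholder loop variables:

// i   : index of the GPU holding the shard being read or updated
// num : index of the GPU issuing the request (e.g. pull_sparse's caller)
cudaStream_t s = resource_->remote_stream(/*gpu_num=*/i, /*stream_num=*/num);
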
paddle/fluid/framework/fleet/ps_gpu_wrapper.cu: 3 additions & 3 deletions
@@ -121,7 +121,7 @@ void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
cudaMemcpyHostToDevice);

PullCopy<<<(total_length + 512 - 1) / 512, 512, 0, stream>>>(
PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
gpu_values, total_values_gpu, gpu_len, hidden_size, slot_num,
total_length, gpu_keys);
cudaStreamSynchronize(stream);
@@ -135,7 +135,7 @@ void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place,
platform::DeviceContextPool::Instance().Get(
BOOST_GET_CONST(platform::CUDAPlace, place)))
->stream();
CopyKeysKernel<<<(total_len + 512 - 1) / 512, 512, 0, stream>>>(
CopyKeysKernel<<<(total_len + 1024 - 1) / 1024, 1024, 0, stream>>>(
origin_keys, total_keys, gpu_len, slot_num, total_len);
cudaStreamSynchronize(stream);
}
@@ -173,7 +173,7 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
cudaMemcpy(d_slot_vector, slot_vector_.data(),
slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);

PushCopy<<<(total_length + 512 - 1) / 512, 512, 0, stream>>>(
PushCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
total_grad_values_gpu, gpu_values, gpu_len, hidden_size,
slot_lengths.size(), total_length, batch_size, d_slot_vector);
cudaStreamSynchronize(stream);
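The three copy kernels now launch with 1024-thread blocks, the per-block thread limit on current NVIDIA GPUs, instead of 512, which halves the block count for the same element count; the grid size remains the usual ceiling division. A small standalone check of the launch arithmetic (plain C++, independent of Paddle):

// Grid size = ceil(total_length / block_size); the "+ block - 1" trick avoids
// floating point and rounds any partial block up to a full one.
#include <cstdio>

int main() {
  const long long kBlock = 1024;  // was 512 before this PR
  const long long lengths[] = {1, 1023, 1024, 1025, 1000000};
  for (long long total_length : lengths) {
    long long grid = (total_length + kBlock - 1) / kBlock;
    std::printf("total_length=%lld -> %lld blocks of %lld threads\n",
                total_length, grid, kBlock);
  }
  return 0;
}
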