Skip to content

Commit

Permalink
apply shard_num (PaddlePaddle#55)
Browse files Browse the repository at this point in the history
* save performance;test=develop

* increase shard_num
  • Loading branch information
danleifeng authored Jul 21, 2022
1 parent ce40d68 commit e540156
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
4 changes: 4 additions & 0 deletions paddle/fluid/distributed/ps/table/memory_sparse_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,11 @@ int32_t MemorySparseTable::Save(const std::string& dirname,

size_t file_start_idx = _avg_local_shard_num * _shard_idx;

#ifdef PADDLE_WITH_GPU_GRAPH
int thread_num = _real_local_shard_num;
#else
int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
#endif
omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
for (size_t i = 0; i < _real_local_shard_num; ++i) {
Expand Down
23 changes: 15 additions & 8 deletions paddle/fluid/framework/fleet/ps_gpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,6 @@ class PSGPUWrapper {
PSGPUWrapper() {
HeterPs_ = NULL;
sleep_seconds_before_fail_exit_ = 300;
pull_thread_pool_.resize(thread_keys_shard_num_);
for (size_t i = 0; i < pull_thread_pool_.size(); i++) {
pull_thread_pool_[i].reset(new ::ThreadPool(1));
}
hbm_thread_pool_.resize(thread_keys_shard_num_);
for (size_t i = 0; i < hbm_thread_pool_.size(); i++) {
hbm_thread_pool_[i].reset(new ::ThreadPool(1));
}
}

void PullSparse(const paddle::platform::Place& place, const int table_id,
Expand Down Expand Up @@ -392,6 +384,21 @@ class PSGPUWrapper {
void InitializeGPUServer(paddle::distributed::PSParameter ps_param) {
auto sparse_table =
ps_param.server_param().downpour_server_param().downpour_table_param(0);
// set build thread_num and shard_num
thread_keys_thread_num_ = sparse_table.shard_num();
thread_keys_shard_num_ = sparse_table.shard_num();
VLOG(1) << "ps_gpu build phase thread_num:" << thread_keys_thread_num_
<< " shard_num:" << thread_keys_shard_num_;

pull_thread_pool_.resize(thread_keys_shard_num_);
for (size_t i = 0; i < pull_thread_pool_.size(); i++) {
pull_thread_pool_[i].reset(new ::ThreadPool(1));
}
hbm_thread_pool_.resize(thread_keys_shard_num_);
for (size_t i = 0; i < hbm_thread_pool_.size(); i++) {
hbm_thread_pool_[i].reset(new ::ThreadPool(1));
}

auto sparse_table_accessor = sparse_table.accessor();
auto sparse_table_accessor_parameter =
sparse_table_accessor.ctr_accessor_param();
Expand Down

0 comments on commit e540156

Please sign in to comment.