diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc
index ee6a801fa9183..aae676244bbb6 100644
--- a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc
+++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc
@@ -240,7 +240,11 @@ int32_t MemorySparseTable::Save(const std::string& dirname,
   size_t file_start_idx = _avg_local_shard_num * _shard_idx;
 
+#ifdef PADDLE_WITH_GPU_GRAPH
+  int thread_num = _real_local_shard_num;
+#else
   int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
+#endif
   omp_set_num_threads(thread_num);
 #pragma omp parallel for schedule(dynamic)
   for (size_t i = 0; i < _real_local_shard_num; ++i) {
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
index a8fab77238e4e..98a27670a1af6 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
@@ -151,14 +151,6 @@ class PSGPUWrapper {
   PSGPUWrapper() {
     HeterPs_ = NULL;
     sleep_seconds_before_fail_exit_ = 300;
-    pull_thread_pool_.resize(thread_keys_shard_num_);
-    for (size_t i = 0; i < pull_thread_pool_.size(); i++) {
-      pull_thread_pool_[i].reset(new ::ThreadPool(1));
-    }
-    hbm_thread_pool_.resize(thread_keys_shard_num_);
-    for (size_t i = 0; i < hbm_thread_pool_.size(); i++) {
-      hbm_thread_pool_[i].reset(new ::ThreadPool(1));
-    }
   }
   void PullSparse(const paddle::platform::Place& place,
                   const int table_id,
@@ -392,6 +384,21 @@ class PSGPUWrapper {
   void InitializeGPUServer(paddle::distributed::PSParameter ps_param) {
     auto sparse_table =
         ps_param.server_param().downpour_server_param().downpour_table_param(0);
+    // set build thread_num and shard_num
+    thread_keys_thread_num_ = sparse_table.shard_num();
+    thread_keys_shard_num_ = sparse_table.shard_num();
+    VLOG(1) << "ps_gpu build phase thread_num:" << thread_keys_thread_num_
+            << " shard_num:" << thread_keys_shard_num_;
+
+    pull_thread_pool_.resize(thread_keys_shard_num_);
+    for (size_t i = 0; i < pull_thread_pool_.size(); i++) {
+      pull_thread_pool_[i].reset(new ::ThreadPool(1));
+    }
+    hbm_thread_pool_.resize(thread_keys_shard_num_);
+    for (size_t i = 0; i < hbm_thread_pool_.size(); i++) {
+      hbm_thread_pool_[i].reset(new ::ThreadPool(1));
+    }
+
     auto sparse_table_accessor = sparse_table.accessor();
     auto sparse_table_accessor_parameter =
         sparse_table_accessor.ctr_accessor_param();