diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.cc b/paddle/fluid/distributed/ps/table/common_graph_table.cc index 08bef6261cf34..2268f5253660c 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.cc +++ b/paddle/fluid/distributed/ps/table/common_graph_table.cc @@ -1079,14 +1079,44 @@ std::string GraphTable::get_inverse_etype(std::string &etype) { return res; } -int32_t GraphTable::load_node_and_edge_file(std::string etype, - std::string ntype, - std::string epath, - std::string npath, +int32_t GraphTable::parse_type_to_typepath(std::string &type2files, + std::string graph_data_local_path, + std::vector &res_type, + std::unordered_map &res_type2path) { + auto type2files_split = paddle::string::split_string(type2files, ","); + if (type2files_split.size() == 0) { + return -1; + } + for (auto one_type2file : type2files_split) { + auto one_type2file_split = paddle::string::split_string(one_type2file, ":"); + if (one_type2file_split.size() < 2) return -1; /* entry is not "type:dir"; fail instead of indexing out of range */ + const auto &type = one_type2file_split[0]; + res_type.push_back(type); + res_type2path[type] = graph_data_local_path + "/" + one_type2file_split[1]; + } + return 0; +} + +int32_t GraphTable::load_node_and_edge_file(std::string etype2files, + std::string ntype2files, + std::string graph_data_local_path, int part_num, bool reverse) { - auto etypes = paddle::string::split_string(etype, ","); - auto ntypes = paddle::string::split_string(ntype, ","); + std::vector etypes; + std::unordered_map edge_to_edgedir; + int res = parse_type_to_typepath(etype2files, graph_data_local_path, etypes, edge_to_edgedir); + if (res != 0) { + VLOG(0) << "parse edge type and edgedir failed!"; + return -1; + } + std::vector ntypes; + std::unordered_map node_to_nodedir; + res = parse_type_to_typepath(ntype2files, graph_data_local_path, ntypes, node_to_nodedir); + if (res != 0) { + VLOG(0) << "parse node type and nodedir failed!"; + return -1; + } + VLOG(0) << "etypes size: " << etypes.size(); VLOG(0) << "whether reverse: " << reverse; 
is_load_reverse_edge = reverse; @@ -1098,7 +1128,7 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype, tasks.push_back( _shards_task_pool[i % task_pool_size_]->enqueue([&, i, this]() -> int { if (i < etypes.size()) { - std::string etype_path = epath + "/" + etypes[i]; + std::string etype_path = edge_to_edgedir.at(etypes[i]); /* at(): read-only lookup, safe to call from concurrent tasks */ auto etype_path_list = paddle::framework::localfs_list(etype_path); std::string etype_path_str; if (part_num > 0 && part_num < (int)etype_path_list.size()) { @@ -1116,6 +1146,7 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype, this->load_edges(etype_path_str, true, r_etype); } } else { + std::string npath = node_to_nodedir.at(ntypes[0]); /* at(): never inserts into the shared map */ auto npath_list = paddle::framework::localfs_list(npath); std::string npath_str; if (part_num > 0 && part_num < (int)npath_list.size()) { diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.h b/paddle/fluid/distributed/ps/table/common_graph_table.h index 0855babec83c1..9fe1231987f84 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.h +++ b/paddle/fluid/distributed/ps/table/common_graph_table.h @@ -537,14 +537,18 @@ class GraphTable : public Table { virtual int32_t Initialize(const GraphParameter &config); int32_t Load(const std::string &path, const std::string &param); - int32_t load_node_and_edge_file(std::string etype, - std::string ntype, - std::string epath, - std::string npath, + int32_t load_node_and_edge_file(std::string etype2files, + std::string ntype2files, + std::string graph_data_local_path, int part_num, bool reverse); std::string get_inverse_etype(std::string &etype); + + int32_t parse_type_to_typepath(std::string &type2files, + std::string graph_data_local_path, + std::vector &res_type, + std::unordered_map &res_type2path); int32_t load_edges(const std::string &path, bool reverse, diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu index 4cc1b746a558f..e2c6df3102aeb 
100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu @@ -155,15 +155,14 @@ void GraphGpuWrapper::load_node_file(std::string name, std::string filepath) { } } -void GraphGpuWrapper::load_node_and_edge(std::string etype, - std::string ntype, - std::string epath, - std::string npath, +void GraphGpuWrapper::load_node_and_edge(std::string etype2files, + std::string ntype2files, + std::string graph_data_local_path, int part_num, bool reverse) { - ((GpuPsGraphTable *)graph_table) - ->cpu_graph_table_->load_node_and_edge_file( - etype, ntype, epath, npath, part_num, reverse); + ((GpuPsGraphTable *)graph_table) + ->cpu_graph_table_->load_node_and_edge_file( + etype2files, ntype2files, graph_data_local_path, part_num, reverse); } void GraphGpuWrapper::add_table_feat_conf(std::string table_name, diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h index b41303b85e0df..8ca3ee5899279 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h @@ -48,10 +48,9 @@ class GraphGpuWrapper { int feat_shape); void load_edge_file(std::string name, std::string filepath, bool reverse); void load_node_file(std::string name, std::string filepath); - void load_node_and_edge(std::string etype, - std::string ntype, - std::string epath, - std::string npath, + void load_node_and_edge(std::string etype2files, + std::string ntype2files, + std::string graph_data_local_path, int part_num, bool reverse); int32_t load_next_partition(int idx); diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 91eb11d9dbdbc..9ccd724d53301 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -40,6 +40,7 @@ limitations under the License. 
*/ #endif DECLARE_int32(gpugraph_dedup_pull_push_mode); +DECLARE_int32(gpugraph_sparse_table_storage_mode); namespace paddle { namespace framework { @@ -336,19 +337,12 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { platform::Timer timeline; - std::vector> task_futures; - int device_num = heter_devices_.size(); - auto& local_keys = gpu_task->feature_keys_; - auto& local_ptr = gpu_task->value_ptr_; auto& local_dim_keys = gpu_task->feature_dim_keys_; auto& local_dim_ptr = gpu_task->value_dim_ptr_; - auto& device_keys = gpu_task->device_keys_; - auto& device_vals = gpu_task->device_values_; auto& device_dim_keys = gpu_task->device_dim_keys_; auto& device_dim_ptr = gpu_task->device_dim_ptr_; - auto& device_dim_mutex = gpu_task->dim_mutex_; for (size_t dev = 0; dev < device_dim_keys.size(); dev++) { device_dim_keys[dev].resize(multi_mf_dim_); @@ -461,7 +455,10 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { } }; + //fleet_ptr->pslib_ptr_->_worker_ptr->acquire_table_mutex(this->table_id_); threads.resize(thread_keys_shard_num_ * multi_mf_dim_); + + std::vector> task_futures; for (int i = 0; i < thread_keys_shard_num_; i++) { for (int j = 0; j < multi_mf_dim_; j++) { task_futures.emplace_back( @@ -471,6 +468,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { for (auto& f : task_futures) { f.wait(); } + //fleet_ptr->pslib_ptr_->_worker_ptr->release_table_mutex(this->table_id_); task_futures.clear(); timeline.Pause(); VLOG(0) << "pull sparse from CpuPS into GpuPS cost " << timeline.ElapsedSec() @@ -484,203 +482,231 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { gloo_wrapper->Barrier(); } - timeline.Start(); - std::vector>> pass_values; - - bool record_status = false; - auto& device_task_keys = gpu_task->device_task_keys_; - auto& device_task_ptrs = gpu_task->device_task_ptr_; - auto build_pull_dynamic_mf_func = [this, - device_num, - &local_dim_keys, - &local_dim_ptr, 
- &device_dim_keys, - &device_dim_ptr, - &device_dim_mutex](int i, int j) { - std::vector> task_keys(device_num); -#ifdef PADDLE_WITH_PSLIB - std::vector> task_ptrs( - device_num); -#endif +} -#ifdef PADDLE_WITH_PSCORE - std::vector> task_ptrs( - device_num); -#endif - for (size_t k = 0; k < local_dim_keys[i][j].size(); k++) { - int shard = local_dim_keys[i][j][k] % device_num; - task_keys[shard].push_back(local_dim_keys[i][j][k]); - task_ptrs[shard].push_back(local_dim_ptr[i][j][k]); - } - // allocate local keys to devices - for (int dev = 0; dev < device_num; dev++) { - device_dim_mutex[dev][j]->lock(); - int len = task_keys[dev].size(); - int cur = device_dim_keys[dev][j].size(); - device_dim_keys[dev][j].resize(device_dim_keys[dev][j].size() + len); - device_dim_ptr[dev][j].resize(device_dim_ptr[dev][j].size() + len); - for (int k = 0; k < len; ++k) { - device_dim_keys[dev][j][cur + k] = task_keys[dev][k]; - device_dim_ptr[dev][j][cur + k] = task_ptrs[dev][k]; +void PSGPUWrapper::PrepareGPUTask(std::shared_ptr gpu_task) { + platform::Timer timeline; + int device_num = heter_devices_.size(); + std::vector threads; + std::vector> task_futures; + auto& local_keys = gpu_task->feature_keys_; + auto& local_ptr = gpu_task->value_ptr_; + auto& local_dim_keys = gpu_task->feature_dim_keys_; + auto& local_dim_ptr = gpu_task->value_dim_ptr_; + + auto& device_keys = gpu_task->device_keys_; + auto& device_vals = gpu_task->device_values_; + auto& device_dim_keys = gpu_task->device_dim_keys_; + auto& device_dim_ptr = gpu_task->device_dim_ptr_; + auto& device_dim_mutex = gpu_task->dim_mutex_; + //auto& device_mutex = gpu_task->mutex_; + + if (multi_mf_dim_) { + for (size_t dev = 0; dev < device_dim_keys.size(); dev++) { + device_dim_keys[dev].resize(multi_mf_dim_); + device_dim_ptr[dev].resize(multi_mf_dim_); } - device_dim_mutex[dev][j]->unlock(); } - }; - auto build_func = [device_num, - record_status, - &pass_values, - &local_keys, - &local_ptr, - &device_task_keys, - 
&device_task_ptrs](int i) { - auto& task_keys = device_task_keys[i]; -#ifdef PADDLE_WITH_PSLIB - auto& task_ptrs = device_task_ptrs[i]; -#endif -#ifdef PADDLE_WITH_PSCORE - auto& task_ptrs = device_task_ptrs[i]; -#endif - for (size_t j = 0; j < local_keys[i].size(); j++) { - int shard = local_keys[i][j] % device_num; - task_keys[shard].push_back(local_keys[i][j]); - task_ptrs[shard].push_back(local_ptr[i][j]); - } -#ifdef PADDLE_WITH_PSLIB - if (record_status) { - size_t local_keys_size = local_keys.size(); - size_t pass_values_size = pass_values.size(); - for (size_t j = 0; j < pass_values_size; j += local_keys_size) { - auto& shard_values = pass_values[j]; - for (size_t pair_idx = 0; pair_idx < pass_values[j].size(); - pair_idx++) { - auto& cur_pair = shard_values[pair_idx]; - int shard = cur_pair.first % device_num; - task_keys[shard].push_back(cur_pair.first); - task_ptrs[shard].push_back( - (paddle::ps::DownpourFixedFeatureValue*)cur_pair.second); + timeline.Start(); + std::vector>> pass_values; + + bool record_status = false; + auto& device_task_keys = gpu_task->device_task_keys_; + auto& device_task_ptrs = gpu_task->device_task_ptr_; + auto build_pull_dynamic_mf_func = [this, + device_num, + &local_dim_keys, + &local_dim_ptr, + &device_dim_keys, + &device_dim_ptr, + &device_dim_mutex](int i, int j) { + std::vector> task_keys(device_num); + #ifdef PADDLE_WITH_PSLIB + std::vector> task_ptrs( + device_num); + #endif + + #ifdef PADDLE_WITH_PSCORE + std::vector> task_ptrs( + device_num); + #endif + for (size_t k = 0; k < local_dim_keys[i][j].size(); k++) { + int shard = local_dim_keys[i][j][k] % device_num; + task_keys[shard].push_back(local_dim_keys[i][j][k]); + task_ptrs[shard].push_back(local_dim_ptr[i][j][k]); + } + // allocate local keys to devices + for (int dev = 0; dev < device_num; dev++) { + device_dim_mutex[dev][j]->lock(); + int len = task_keys[dev].size(); + int cur = device_dim_keys[dev][j].size(); + 
device_dim_keys[dev][j].resize(device_dim_keys[dev][j].size() + len); + device_dim_ptr[dev][j].resize(device_dim_ptr[dev][j].size() + len); + for (int k = 0; k < len; ++k) { + device_dim_keys[dev][j][cur + k] = task_keys[dev][k]; + device_dim_ptr[dev][j][cur + k] = task_ptrs[dev][k]; } + device_dim_mutex[dev][j]->unlock(); } + }; + auto build_func = [device_num, + record_status, + &pass_values, + &local_keys, + &local_ptr, + &device_task_keys, + &device_task_ptrs](int i) { + auto& task_keys = device_task_keys[i]; + #ifdef PADDLE_WITH_PSLIB + auto& task_ptrs = device_task_ptrs[i]; + #endif + + #ifdef PADDLE_WITH_PSCORE + auto& task_ptrs = device_task_ptrs[i]; + #endif + + for (size_t j = 0; j < local_keys[i].size(); j++) { + int shard = local_keys[i][j] % device_num; + task_keys[shard].push_back(local_keys[i][j]); + task_ptrs[shard].push_back(local_ptr[i][j]); + } + #ifdef PADDLE_WITH_PSLIB + if (record_status) { + size_t local_keys_size = local_keys.size(); + size_t pass_values_size = pass_values.size(); + for (size_t j = 0; j < pass_values_size; j += local_keys_size) { + auto& shard_values = pass_values[j]; + for (size_t pair_idx = 0; pair_idx < pass_values[j].size(); + pair_idx++) { + auto& cur_pair = shard_values[pair_idx]; + int shard = cur_pair.first % device_num; + task_keys[shard].push_back(cur_pair.first); + task_ptrs[shard].push_back( + (paddle::ps::DownpourFixedFeatureValue*)cur_pair.second); + } + } + } + #endif + }; + if (!multi_mf_dim_) { + for (int i = 0; i < thread_keys_shard_num_; i++) { + task_futures.emplace_back(hbm_thread_pool_[i]->enqueue(build_func, i)); + } + for (auto& f : task_futures) { + f.wait(); + } + task_futures.clear(); + VLOG(0) << "GpuPs build hbmps done"; } -#endif - }; - if (!multi_mf_dim_) { - for (int i = 0; i < thread_keys_shard_num_; i++) { - task_futures.emplace_back(hbm_thread_pool_[i]->enqueue(build_func, i)); - } - for (auto& f : task_futures) { - f.wait(); - } - task_futures.clear(); - VLOG(0) << "GpuPs build hbmps 
done"; - } - std::vector> prefix_sum; - prefix_sum.resize(device_num); - for (int i = 0; i < device_num; i++) { - prefix_sum[i].resize(thread_keys_shard_num_ + 1); - prefix_sum[i][0] = 0; - } - auto calc_prefix_func = [this, - &prefix_sum, - &device_keys, - &device_vals, - &device_task_keys](int device_num) { - for (int j = 0; j < thread_keys_shard_num_; j++) { - prefix_sum[device_num][j + 1] = - prefix_sum[device_num][j] + device_task_keys[j][device_num].size(); - } - device_keys[device_num].resize( - prefix_sum[device_num][thread_keys_shard_num_]); - device_vals[device_num].resize( - prefix_sum[device_num][thread_keys_shard_num_]); - }; - if (!multi_mf_dim_) { + std::vector> prefix_sum; + prefix_sum.resize(device_num); for (int i = 0; i < device_num; i++) { - task_futures.emplace_back( - hbm_thread_pool_[i]->enqueue(calc_prefix_func, i)); + prefix_sum[i].resize(thread_keys_shard_num_ + 1); + prefix_sum[i][0] = 0; } - for (auto& f : task_futures) { - f.wait(); + auto calc_prefix_func = [this, + &prefix_sum, + &device_keys, + &device_vals, + &device_task_keys](int device_num) { + for (int j = 0; j < thread_keys_shard_num_; j++) { + prefix_sum[device_num][j + 1] = + prefix_sum[device_num][j] + device_task_keys[j][device_num].size(); + } + device_keys[device_num].resize( + prefix_sum[device_num][thread_keys_shard_num_]); + device_vals[device_num].resize( + prefix_sum[device_num][thread_keys_shard_num_]); + }; + if (!multi_mf_dim_) { + for (int i = 0; i < device_num; i++) { + task_futures.emplace_back( + hbm_thread_pool_[i]->enqueue(calc_prefix_func, i)); + } + for (auto& f : task_futures) { + f.wait(); + } + task_futures.clear(); } - task_futures.clear(); - } - VLOG(0) << "prefix done"; - auto prepare_dev_value_func = [device_num, - &prefix_sum, - &device_keys, - &device_vals, - &device_task_keys, - &device_task_ptrs](int dev, int shard_id) { - // auto& task_keys = device_task_keys[shard_id]; -#ifdef PADDLE_WITH_PSLIB - auto& task_ptrs = device_task_ptrs[shard_id]; 
-#endif - - // #ifdef PADDLE_WITH_PSCORE - // auto& task_ptrs = device_task_ptrs[shard_id]; - // #endif - - // int len = prefix_sum[dev][shard_id + 1] - prefix_sum[dev][shard_id]; - // int cur = prefix_sum[dev][shard_id]; -#ifdef PADDLE_WITH_PSLIB - for (int j = 0; j < len; ++j) { - device_keys[dev][cur + j] = task_keys[dev][j]; - float* ptr_val = task_ptrs[dev][j]->data(); - FeatureValue& val = device_vals[dev][cur + j]; - size_t dim = task_ptrs[dev][j]->size(); - - val.delta_score = ptr_val[1]; - val.show = ptr_val[2]; - val.clk = ptr_val[3]; - val.slot = ptr_val[6]; - val.lr = ptr_val[4]; - val.lr_g2sum = ptr_val[5]; - val.cpu_ptr = (uint64_t)(task_ptrs[dev][j]); - - if (dim > 7) { - val.mf_size = MF_DIM + 1; - for (int x = 0; x < val.mf_size; x++) { - val.mf[x] = ptr_val[x + 7]; + VLOG(0) << "prefix done"; + auto prepare_dev_value_func = [device_num, + &prefix_sum, + &device_keys, + &device_vals, + &device_task_keys, + &device_task_ptrs](int dev, int shard_id) { + // auto& task_keys = device_task_keys[shard_id]; + #ifdef PADDLE_WITH_PSLIB + auto& task_ptrs = device_task_ptrs[shard_id]; + #endif + + // #ifdef PADDLE_WITH_PSCORE + // auto& task_ptrs = device_task_ptrs[shard_id]; + // #endif + + // int len = prefix_sum[dev][shard_id + 1] - prefix_sum[dev][shard_id]; + // int cur = prefix_sum[dev][shard_id]; + #ifdef PADDLE_WITH_PSLIB + for (int j = 0; j < len; ++j) { + device_keys[dev][cur + j] = task_keys[dev][j]; + float* ptr_val = task_ptrs[dev][j]->data(); + FeatureValue& val = device_vals[dev][cur + j]; + size_t dim = task_ptrs[dev][j]->size(); + + val.delta_score = ptr_val[1]; + val.show = ptr_val[2]; + val.clk = ptr_val[3]; + val.slot = ptr_val[6]; + val.lr = ptr_val[4]; + val.lr_g2sum = ptr_val[5]; + val.cpu_ptr = (uint64_t)(task_ptrs[dev][j]); + + if (dim > 7) { + val.mf_size = MF_DIM + 1; + for (int x = 0; x < val.mf_size; x++) { + val.mf[x] = ptr_val[x + 7]; + } + } else { + val.mf_size = 0; + for (int x = 0; x < MF_DIM + 1; x++) { + val.mf[x] = 0; + } 
} - } else { - val.mf_size = 0; - for (int x = 0; x < MF_DIM + 1; x++) { - val.mf[x] = 0; + } + #endif + VLOG(3) << "GpuPs build hbmps done"; + }; + if (multi_mf_dim_) { + threads.resize(thread_keys_shard_num_ * multi_mf_dim_); + for (int i = 0; i < thread_keys_shard_num_; i++) { + for (int j = 0; j < multi_mf_dim_; j++) { + threads[i * multi_mf_dim_ + j] = + std::thread(build_pull_dynamic_mf_func, i, j); } } - } -#endif - VLOG(3) << "GpuPs build hbmps done"; - }; - - if (multi_mf_dim_) { - for (int i = 0; i < thread_keys_shard_num_; i++) { - for (int j = 0; j < multi_mf_dim_; j++) { - threads[i * multi_mf_dim_ + j] = - std::thread(build_pull_dynamic_mf_func, i, j); + for (std::thread& t : threads) { + t.join(); } - } - for (std::thread& t : threads) { - t.join(); - } - } else { - for (int i = 0; i < thread_keys_shard_num_; i++) { - for (int j = 0; j < device_num; j++) { - task_futures.emplace_back( - hbm_thread_pool_[i]->enqueue(prepare_dev_value_func, j, i)); + } else { + for (int i = 0; i < thread_keys_shard_num_; i++) { + for (int j = 0; j < device_num; j++) { + task_futures.emplace_back( + hbm_thread_pool_[i]->enqueue(prepare_dev_value_func, j, i)); + } } + for (auto& f : task_futures) { + f.wait(); + } + task_futures.clear(); } - for (auto& f : task_futures) { - f.wait(); - } - task_futures.clear(); - } - timeline.Pause(); - VLOG(0) << "GpuPs prepare for build hbm cost " << timeline.ElapsedSec() - << " seconds."; + timeline.Pause(); + VLOG(0) << "GpuPs prepare for build hbm cost " << timeline.ElapsedSec() + << " seconds."; } + void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { int device_num = heter_devices_.size(); platform::Timer timeline; @@ -909,6 +935,7 @@ void PSGPUWrapper::start_build_thread() { running_ = true; VLOG(3) << "start build CPU ps thread."; pre_build_threads_ = std::thread([this] { pre_build_thread(); }); + buildpull_threads_ = std::thread([this] { build_pull_thread(); }); } void PSGPUWrapper::pre_build_thread() { @@ -931,6 +958,25 
@@ void PSGPUWrapper::pre_build_thread() { VLOG(3) << "build cpu thread end"; } +void PSGPUWrapper::build_pull_thread() { + while (running_) { + std::shared_ptr gpu_task = nullptr; + if (!buildcpu_ready_channel_->Get(gpu_task)) { + continue; + } + VLOG(3) << "thread build pull start."; + platform::Timer timer; + timer.Start(); + /* pull stage: run BuildPull on the task handed over by pre_build_thread */ + BuildPull(gpu_task); + timer.Pause(); + VLOG(1) << "thread BuildPull end, cost time: " << timer.ElapsedSec() + << "s"; + buildpull_ready_channel_->Put(gpu_task); + } + VLOG(3) << "build pull thread end"; +} + void PSGPUWrapper::build_task() { // build_task: build_pull + build_gputask std::shared_ptr gpu_task = nullptr; @@ -939,18 +985,18 @@ void PSGPUWrapper::build_task() { return; } // ins and pre_build end - if (!buildcpu_ready_channel_->Get(gpu_task)) { + if (!buildpull_ready_channel_->Get(gpu_task)) { return; } - VLOG(0) << "BuildPull start."; + VLOG(0) << "PrepareGPUTask start."; platform::Timer timer; timer.Start(); - BuildPull(gpu_task); + PrepareGPUTask(gpu_task); BuildGPUTask(gpu_task); timer.Pause(); - VLOG(0) << "BuildPull + BuildGPUTask end, cost time: " << timer.ElapsedSec() - << "s"; + VLOG(0) << "PrepareGPUTask + BuildGPUTask end, cost time: " << timer.ElapsedSec() + << "s"; current_task_ = gpu_task; } @@ -1000,6 +1046,8 @@ void PSGPUWrapper::EndPass() { int thread_num = 8; auto accessor_wrapper_ptr = GlobalAccessorFactory::GetInstance().GetAccessorWrapper(); + //auto fleet_ptr = FleetWrapper::GetInstance(); + //fleet_ptr->pslib_ptr_->_worker_ptr->acquire_table_mutex(this->table_id_); auto dump_pool_to_cpu_func = [this, thread_num, &accessor_wrapper_ptr]( int i, int j, int z) { PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(this->resource_->dev_id(i))); @@ -1075,6 +1123,7 @@ void PSGPUWrapper::EndPass() { gpu_task_pool_.Push(current_task_); current_task_ = nullptr; gpu_free_channel_->Put(current_task_); + //fleet_ptr->pslib_ptr_->_worker_ptr->release_table_mutex(this->table_id_); timer.Pause(); VLOG(1) 
<< "EndPass end, cost time: " << timer.ElapsedSec() << "s"; } diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index c48cf3347573a..54f4abd97e831 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -191,11 +191,13 @@ class PSGPUWrapper { void BuildGPUTask(std::shared_ptr gpu_task); void PreBuildTask(std::shared_ptr gpu_task); void BuildPull(std::shared_ptr gpu_task); + void PrepareGPUTask(std::shared_ptr gpu_task); void LoadIntoMemory(bool is_shuffle); void BeginPass(); void EndPass(); void start_build_thread(); void pre_build_thread(); + void build_pull_thread(); void build_task(); void Finalize() { @@ -205,10 +207,13 @@ class PSGPUWrapper { } data_ready_channel_->Close(); buildcpu_ready_channel_->Close(); + buildpull_ready_channel_->Close(); gpu_free_channel_->Close(); running_ = false; VLOG(3) << "begin stop pre_build_threads_"; pre_build_threads_.join(); + VLOG(3) << "begin stop buildpull_threads_"; + buildpull_threads_.join(); s_instance_ = nullptr; VLOG(3) << "PSGPUWrapper Finalize Finished."; HeterPs_->show_table_collisions(); @@ -278,6 +283,8 @@ class PSGPUWrapper { data_ready_channel_->SetCapacity(3); buildcpu_ready_channel_->Open(); buildcpu_ready_channel_->SetCapacity(3); + buildpull_ready_channel_->Open(); + buildpull_ready_channel_->SetCapacity(1); gpu_free_channel_->Open(); gpu_free_channel_->SetCapacity(1); @@ -712,8 +719,13 @@ class PSGPUWrapper { paddle::framework::ChannelObject>> gpu_free_channel_ = paddle::framework::MakeChannel>(); + std::shared_ptr< + paddle::framework::ChannelObject>> + buildpull_ready_channel_ = + paddle::framework::MakeChannel>(); std::shared_ptr current_task_ = nullptr; std::thread pre_build_threads_; + std::thread buildpull_threads_; bool running_ = false; std::vector> pull_thread_pool_; std::vector> hbm_thread_pool_; diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 
4933cbb6cf74b..a26ed5dbdad8c 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -968,7 +968,10 @@ PADDLE_DEFINE_EXPORTED_int32( PADDLE_DEFINE_EXPORTED_bool(gpugraph_load_node_list_into_hbm, true, "enable load_node_list_into_hbm, default true"); - +PADDLE_DEFINE_EXPORTED_int32( + gpugraph_sparse_table_storage_mode, + 0, + "sparse_table_storage_mode, default 0"); /** * ProcessGroupNCCL related FLAG * Name: nccl_blocking_wait