From 2efdedda47e9ff7270cb284ce0588da069489db5 Mon Sep 17 00:00:00 2001
From: danleifeng <52735331+danleifeng@users.noreply.github.com>
Date: Wed, 16 Nov 2022 16:01:07 +0800
Subject: [PATCH] fix slotfea (#156)

Fix the slot_feature warm-start bug when running SSD (secondary storage)
mode.
---
 .../fluid/framework/fleet/ps_gpu_wrapper.cc | 641 +++++++++---------
 1 file changed, 339 insertions(+), 302 deletions(-)

diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
index 5dda38fc4e0cc..9762822e196db 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
@@ -196,7 +196,12 @@ void PSGPUWrapper::add_key_to_gputask(std::shared_ptr<HeterContext> gpu_task) {
   VLOG(0) << "GpuPs task add keys cost " << timeline.ElapsedSec()
           << " seconds.";
   timeline.Start();
-  gpu_task->UniqueKeys();
+  size_t slot_num = slot_vector_.size() - 1;
+  // in no-slot_fea mode and whole_hbm mode, keep only one unique/sort pass
+  if (slot_num > 0 && FLAGS_gpugraph_storage_mode !=
+                          paddle::framework::GpuGraphStorageMode::WHOLE_HBM) {
+    gpu_task->UniqueKeys();
+  }
   timeline.Pause();
   VLOG(0) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds.";
 }
@@ -208,7 +213,7 @@ void PSGPUWrapper::resize_gputask(std::shared_ptr<HeterContext> gpu_task) {
       gpu_task->feature_dim_keys_[i][j].push_back(0);
     }
     gpu_task->value_dim_ptr_[i][j].resize(
-          gpu_task->feature_dim_keys_[i][j].size());
+        gpu_task->feature_dim_keys_[i][j].size());
   }
  }
 }
@@ -350,287 +355,312 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
 }
 
 void PSGPUWrapper::add_slot_feature(std::shared_ptr<HeterContext> gpu_task) {
-    platform::Timer timeline;
-    platform::Timer time_stage;
-    timeline.Start();
-    // shard data across the 8 GPU cards
-    size_t device_num = heter_devices_.size();
-    std::vector<std::thread> threads;
-    size_t slot_num = slot_vector_.size() - 1;  //node slot 9008 in slot_vector
-    auto& local_dim_keys = gpu_task->feature_dim_keys_;
-    double divide_nodeid_cost = 0;
-    double get_feature_id_cost = 0;
-    double add_feature_to_set_cost = 0;
-    double add_feature_to_key_cost = 0;
-
-    std::vector<std::vector<uint64_t>> node_ids(device_num);
-    size_t node_num = 0;
-    for (int i = 0; i < thread_keys_shard_num_; i++) {
-      for (int j = 0; j < multi_mf_dim_; j++) {
-        node_num += local_dim_keys[i][j].size();
-      }
-    }
-    for (auto &node_id_vector : node_ids){
-      node_id_vector.reserve(node_num * 1.2 / device_num);
-    }
-    auto& device_dim_mutex = gpu_task->dim_mutex_;
-
-    auto divide_nodeid_to_device = [this,
-                                    device_num,
-                                    &local_dim_keys,
-                                    &node_ids,
-                                    &device_dim_mutex](int i, int j) {
-      std::vector<std::vector<uint64_t>> task_keys(device_num);
-      size_t batch = 10000;
-      for (size_t k = 0; k < device_num; k++) {
-        task_keys[k].reserve(batch * 1.2 / device_num);
-      }
-      std::vector<int> shuffle_device = shuffle_int_vector(device_num);
-      size_t start = 0;
-      while (start < local_dim_keys[i][j].size()) {
-        if (batch + start > local_dim_keys[i][j].size()) {
-          batch = local_dim_keys[i][j].size() - start;
-        }
-        for (size_t k = start; k < (start + batch); k++) {
-          int shard = local_dim_keys[i][j][k] % device_num;
-          task_keys[shard].push_back(local_dim_keys[i][j][k]);
-        }
-        // allocate local keys to devices
-        for (auto dev : shuffle_device) {
-          device_dim_mutex[dev][0]->lock();
-          int len = task_keys[dev].size();
-          for (int k = 0; k < len; ++k) {
-            node_ids[dev].push_back(task_keys[dev][k]);
-          }
-          device_dim_mutex[dev][0]->unlock();
-          task_keys[dev].clear();
-        }
-        start += batch;
-      }
-    };
-    threads.resize(thread_keys_shard_num_ * multi_mf_dim_);
-    time_stage.Start();
-    for (int i = 0; i < thread_keys_shard_num_; i++) {
-      for (int j = 0; j < multi_mf_dim_; j++) {
-        threads[i * multi_mf_dim_ + j] =
-            std::thread(divide_nodeid_to_device, i, j);
-      }
-    }
-    for (std::thread& t : threads) {
-      t.join();
-    }
-    threads.clear();
-    time_stage.Pause();
-    divide_nodeid_cost = time_stage.ElapsedSec();
-    gpu_task->sub_graph_feas = (void *) (new std::vector<GpuPsCommGraphFea>);
-    std::vector<GpuPsCommGraphFea> &sub_graph_feas =
-        *((std::vector<GpuPsCommGraphFea> *) gpu_task->sub_graph_feas);
-    std::vector<std::vector<uint64_t>> feature_ids(device_num);
-    std::vector<uint64_t *> feature_list(device_num);
-    std::vector<size_t> feature_list_size(device_num);
-    size_t batch = 40000;
-
-    time_stage.Start();
-    if (FLAGS_gpugraph_storage_mode == paddle::framework::GpuGraphStorageMode::MEM_EMB_AND_GPU_GRAPH) {
-        auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
-        auto h_slot_feature_num_map = gpu_graph_ptr->slot_feature_num_map();
-        int fea_num_per_node = 0;
-        for (size_t i = 0; i < slot_num; ++i) {
-          fea_num_per_node += h_slot_feature_num_map[i];
-        }
-
-        auto get_feature_id = [this, slot_num, batch, fea_num_per_node, &h_slot_feature_num_map, &node_ids, &feature_ids](int i) {
-          platform::CUDADeviceGuard guard(resource_->dev_id(i));
-          int * d_slot_feature_num_map;
-          uint64_t * d_node_list_ptr;
-          uint64_t * d_feature_list_ptr;
-          CUDA_CHECK(cudaMalloc((void**)&d_slot_feature_num_map, slot_num * sizeof(int)));
-          CUDA_CHECK(cudaMemcpy(d_slot_feature_num_map, h_slot_feature_num_map.data(),
-                     sizeof(int) * slot_num, cudaMemcpyHostToDevice));
-          CUDA_CHECK(cudaMalloc((void**)&d_node_list_ptr, batch * sizeof(uint64_t)));
-          CUDA_CHECK(cudaMalloc((void**)&d_feature_list_ptr, batch * fea_num_per_node * sizeof(uint64_t)));
-          auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
-          uint64_t pos = 0;
-          size_t real_batch = 0;
-          feature_ids[i].resize(node_ids[i].size() * fea_num_per_node);
-          while (pos < node_ids[i].size()) {
-            real_batch = (pos + batch) <= node_ids[i].size() ?
-                         batch : node_ids[i].size() - pos;
-            CUDA_CHECK(cudaMemcpy(d_node_list_ptr,
-                                  node_ids[i].data() + pos,
-                                  real_batch * sizeof(uint64_t),
-                                  cudaMemcpyHostToDevice));
-            int ret = gpu_graph_ptr->get_feature_of_nodes(i,
-                                                          d_node_list_ptr,
-                                                          d_feature_list_ptr,
-                                                          real_batch,
-                                                          slot_num,
-                                                          d_slot_feature_num_map,
-                                                          fea_num_per_node);
-            PADDLE_ENFORCE_EQ(
-                ret,
-                0,
-                platform::errors::PreconditionNotMet(
-                    "get_feature_of_nodes error"));
-
-            CUDA_CHECK(cudaMemcpy(feature_ids[i].data() + pos * fea_num_per_node,
-                                  d_feature_list_ptr,
-                                  real_batch * fea_num_per_node * sizeof(uint64_t),
-                                  cudaMemcpyDeviceToHost));
-            pos += real_batch;
-          }
-          cudaFree(d_slot_feature_num_map);
-          cudaFree(d_node_list_ptr);
-          cudaFree(d_feature_list_ptr);
-        };
-
-        threads.resize(device_num);
-        for (size_t i = 0; i < device_num; i++) {
-          threads[i] = std::thread(get_feature_id, i);
-        }
-        for (std::thread& t : threads) {
-          t.join();
-        }
-        threads.clear();
-        for (size_t i = 0; i < device_num; i++) {
-          feature_list[i] = feature_ids[i].data();
-          feature_list_size[i] = feature_ids[i].size();
-        }
-    }
-    else if (FLAGS_gpugraph_storage_mode == paddle::framework::GpuGraphStorageMode::MEM_EMB_FEATURE_AND_GPU_GRAPH
-        || FLAGS_gpugraph_storage_mode == paddle::framework::GpuGraphStorageMode::SSD_EMB_AND_MEM_FEATURE_GPU_GRAPH) {
-      auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
-      sub_graph_feas = gpu_graph_ptr->get_sub_graph_fea(node_ids, slot_num);
-      for (size_t i = 0; i < device_num; i++) {
-        feature_list[i] = sub_graph_feas[i].feature_list;
-        feature_list_size[i] = sub_graph_feas[i].feature_size;
-      }
-    }
-    else {
-      PADDLE_ENFORCE_EQ(
-          1,
-          0,
-          "FLAGS_gpugraph_storage_mode is not supported");
-    }
-    time_stage.Pause();
-    get_feature_id_cost = time_stage.ElapsedSec();
-    size_t feature_num = 0;
-    for (size_t i = 0; i < device_num; i++) {
-      feature_num += feature_list_size[i];
-    }
-    VLOG(0) << "feature_num is " << feature_num << " node_num is " << node_num;
-
-    size_t set_num = device_num * 8;
-    std::vector<std::unordered_set<uint64_t>> feature_id_set(set_num);
-    std::vector<std::mutex> set_mutex(set_num);
-
-    auto add_feature_to_set = [this, set_num, &feature_list, &feature_id_set, &set_mutex] (int dev, size_t start, size_t end) {
-      size_t batch = 10000 * set_num;
-      std::vector<std::vector<uint64_t>> feature_list_tmp(set_num);
-      for (size_t i = 0; i < set_num; i++) {
-        feature_list_tmp[i].reserve((batch * 1.2) / set_num);
-      }
-      std::vector<int> shuffle_set_index = shuffle_int_vector(set_num);
-      size_t pos = start;
-      size_t real_batch = 0;
-      while (pos < end) {
-        real_batch = (pos + batch <= end) ?
-                     batch : end - pos;
-        for (size_t i = pos; i < pos + real_batch; i++) {
-          if (feature_list[dev][i] == 0) {
-            continue;
-          }
-          int shard_num = feature_list[dev][i] % set_num;
-          feature_list_tmp[shard_num].push_back(feature_list[dev][i]);
-        }
-        // uniq in local
-        for (size_t i = 0; i < set_num; i++) {
-          std::sort(feature_list_tmp[i].begin(), feature_list_tmp[i].end());
-          size_t idx = 0;
-          size_t total = feature_list_tmp[i].size();
-          for (size_t j = 0; j < total; j++) {
-            auto &k = feature_list_tmp[i][j];
-            if (idx > 0 && feature_list_tmp[i][idx - 1] == k) {
-              continue;
-            }
-            feature_list_tmp[i][idx] = k;
-            ++idx;
-          }
-          feature_list_tmp[i].resize(idx);
-        }
-        // uniq in global
-        for (auto set_index : shuffle_set_index) {
-          set_mutex[set_index].lock();
-          for (auto feature_id : feature_list_tmp[set_index]) {
-            feature_id_set[set_index].insert(feature_id);
-          }
-          set_mutex[set_index].unlock();
-          feature_list_tmp[set_index].clear();
-        }
-        pos += real_batch;
-      }
-    };
-    size_t device_thread_num = 8;
-    threads.resize(device_num * device_thread_num);
-    time_stage.Start();
-    for (size_t i = 0; i < device_num; i++) {
-      size_t start = 0;
-      for (size_t j = 0; j < device_thread_num; j++) {
-        size_t batch = feature_list_size[i] / device_thread_num;
-        if (j < feature_list_size[i] % device_thread_num) {
-          batch += 1;
-        }
-        threads[i * device_thread_num + j] = std::thread(add_feature_to_set, i, start, start + batch);
-        start += batch;
-      }
-    }
-    for (std::thread& t : threads) {
-      t.join();
-    }
-    threads.clear();
-    time_stage.Pause();
-    add_feature_to_set_cost = time_stage.ElapsedSec();
-    auto add_feature_to_key = [this, device_num, &feature_id_set, &local_dim_keys, set_num](int dev) {
-      // set_num = device_num * 8; if a % set_num == b, then a = set_num * m + b, so a % device_num == b % device_num
-      size_t key_num = 0;
-      for (size_t i = dev; i < set_num; i += device_num) {
-        key_num += feature_id_set[i].size();
-      }
-      VLOG(0) << " feature_num is " << key_num << " for device: " << dev;
-      local_dim_keys[dev][0].reserve(local_dim_keys[dev][0].size() + key_num);
-      for (size_t i = dev; i < set_num; i += device_num) {
-        for (auto it = feature_id_set[i].begin(); it != feature_id_set[i].end(); it++) {
-          local_dim_keys[dev][0].push_back(*it);
-        }
-        feature_id_set[i].clear();
-      }
-    };
-    time_stage.Start();
-    threads.resize(device_num);
-    for (size_t i = 0; i < device_num; i++) {
-      threads[i] = std::thread(add_feature_to_key, i);
-    }
-    for (std::thread& t : threads) {
-      t.join();
-    }
-    time_stage.Pause();
-    add_feature_to_key_cost = time_stage.ElapsedSec();
-    threads.clear();
-    timeline.Pause();
-    VLOG(0) << " add_slot_feature costs: " << timeline.ElapsedSec() << " s."
-            << " divide_nodeid_cost " << divide_nodeid_cost
-            << " get_feature_id_cost " << get_feature_id_cost
-            << " add_feature_to_set_cost " << add_feature_to_set_cost
-            << " add_feature_to_key_cost " << add_feature_to_key_cost;
+  platform::Timer timeline;
+  platform::Timer time_stage;
+  timeline.Start();
+  // shard data across the 8 GPU cards
+  size_t device_num = heter_devices_.size();
+  std::vector<std::thread> threads;
+  size_t slot_num = slot_vector_.size() - 1;  // node slot 9008 in slot_vector
+  auto& local_dim_keys = gpu_task->feature_dim_keys_;  // [shard_num, 0, keys]
+  double divide_nodeid_cost = 0;
+  double get_feature_id_cost = 0;
+  double add_feature_to_set_cost = 0;
+  double add_feature_to_key_cost = 0;
+
+  std::vector<std::vector<uint64_t>> node_ids(device_num);
+  size_t node_num = 0;
+  for (int i = 0; i < thread_keys_shard_num_; i++) {
+    for (int j = 0; j < multi_mf_dim_; j++) {
+      node_num += local_dim_keys[i][j].size();
+    }
+  }
+  for (auto& node_id_vector : node_ids) {
+    node_id_vector.reserve(node_num * 1.2 / device_num);
+  }
+
+  auto& device_dim_mutex = gpu_task->dim_mutex_;
+
+  auto divide_nodeid_to_device =
+      [this, device_num, &local_dim_keys, &node_ids, &device_dim_mutex](int i,
+                                                                        int j) {
+        std::vector<std::vector<uint64_t>> task_keys(device_num);
+        size_t batch = 10000;
+        for (size_t k = 0; k < device_num; k++) {
+          task_keys[k].reserve(batch * 1.2 / device_num);
+        }
+        std::vector<int> shuffle_device = shuffle_int_vector(device_num);
+        size_t start = 0;
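+        // walk this shard's keys in fixed-size batches; the final batch is
+        // clamped below to the keys that remain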
+        while (start < local_dim_keys[i][j].size()) {
+          if (batch + start > local_dim_keys[i][j].size()) {
+            batch = local_dim_keys[i][j].size() - start;
+          }
+          for (size_t k = start; k < (start + batch); k++) {
+            int shard = local_dim_keys[i][j][k] % device_num;
+            task_keys[shard].push_back(local_dim_keys[i][j][k]);
+          }
+          // allocate local keys to devices
+          for (auto dev : shuffle_device) {
+            device_dim_mutex[dev][0]->lock();
+            int len = task_keys[dev].size();
+            for (int k = 0; k < len; ++k) {
+              node_ids[dev].push_back(task_keys[dev][k]);
+            }
+            device_dim_mutex[dev][0]->unlock();
+            task_keys[dev].clear();
+          }
+          start += batch;
+        }
+      };
+  threads.resize(thread_keys_shard_num_ * multi_mf_dim_);
+  time_stage.Start();
+
+  for (int i = 0; i < thread_keys_shard_num_; i++) {
+    for (int j = 0; j < multi_mf_dim_; j++) {
+      threads[i * multi_mf_dim_ + j] =
+          std::thread(divide_nodeid_to_device, i, j);
+    }
+  }
+  for (std::thread& t : threads) {
+    t.join();
+  }
+  threads.clear();
+  time_stage.Pause();
+  divide_nodeid_cost = time_stage.ElapsedSec();
+  gpu_task->sub_graph_feas = (void*)(new std::vector<GpuPsCommGraphFea>);
+  std::vector<GpuPsCommGraphFea>& sub_graph_feas =
+      *((std::vector<GpuPsCommGraphFea>*)gpu_task->sub_graph_feas);
+  std::vector<std::vector<uint64_t>> feature_ids(device_num);
+  std::vector<uint64_t*> feature_list(device_num);
+  std::vector<size_t> feature_list_size(device_num);
+  size_t batch = 40000;
+
+  time_stage.Start();
+  if (FLAGS_gpugraph_storage_mode ==
+      paddle::framework::GpuGraphStorageMode::MEM_EMB_AND_GPU_GRAPH) {
+    auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
+    auto h_slot_feature_num_map = gpu_graph_ptr->slot_feature_num_map();
+    int fea_num_per_node = 0;
+    for (size_t i = 0; i < slot_num; ++i) {
+      fea_num_per_node += h_slot_feature_num_map[i];
+    }
+
+    auto get_feature_id = [this,
+                           slot_num,
+                           batch,
+                           fea_num_per_node,
+                           &h_slot_feature_num_map,
+                           &node_ids,
+                           &feature_ids](int i) {
+      platform::CUDADeviceGuard guard(resource_->dev_id(i));
+      int* d_slot_feature_num_map;
+      uint64_t* d_node_list_ptr;
+      uint64_t* d_feature_list_ptr;
+      CUDA_CHECK(
+          cudaMalloc((void**)&d_slot_feature_num_map, slot_num * sizeof(int)));
+      CUDA_CHECK(cudaMemcpy(d_slot_feature_num_map,
+                            h_slot_feature_num_map.data(),
+                            sizeof(int) * slot_num,
+                            cudaMemcpyHostToDevice));
+      CUDA_CHECK(
+          cudaMalloc((void**)&d_node_list_ptr, batch * sizeof(uint64_t)));
+      CUDA_CHECK(cudaMalloc((void**)&d_feature_list_ptr,
+                            batch * fea_num_per_node * sizeof(uint64_t)));
+      auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
+      uint64_t pos = 0;
+      size_t real_batch = 0;
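+      // flat host buffer: each node owns exactly fea_num_per_node feature
+      // slots, so node k's features start at offset k * fea_num_per_node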
+      feature_ids[i].resize(node_ids[i].size() * fea_num_per_node);
+      while (pos < node_ids[i].size()) {
+        real_batch = (pos + batch) <= node_ids[i].size()
+                         ? batch
+                         : node_ids[i].size() - pos;
+        CUDA_CHECK(cudaMemcpy(d_node_list_ptr,
+                              node_ids[i].data() + pos,
+                              real_batch * sizeof(uint64_t),
+                              cudaMemcpyHostToDevice));
+        int ret = gpu_graph_ptr->get_feature_of_nodes(i,
+                                                      d_node_list_ptr,
+                                                      d_feature_list_ptr,
+                                                      real_batch,
+                                                      slot_num,
+                                                      d_slot_feature_num_map,
+                                                      fea_num_per_node);
+        PADDLE_ENFORCE_EQ(
+            ret,
+            0,
+            platform::errors::PreconditionNotMet("get_feature_of_nodes error"));
+
+        CUDA_CHECK(cudaMemcpy(feature_ids[i].data() + pos * fea_num_per_node,
+                              d_feature_list_ptr,
+                              real_batch * fea_num_per_node * sizeof(uint64_t),
+                              cudaMemcpyDeviceToHost));
+        pos += real_batch;
+      }
+      cudaFree(d_slot_feature_num_map);
+      cudaFree(d_node_list_ptr);
+      cudaFree(d_feature_list_ptr);
+    };
+
+    threads.resize(device_num);
+    for (size_t i = 0; i < device_num; i++) {
+      threads[i] = std::thread(get_feature_id, i);
+    }
+    for (std::thread& t : threads) {
+      t.join();
+    }
+    threads.clear();
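+    // expose each device's flat feature buffer to the dedup stage below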
+    for (size_t i = 0; i < device_num; i++) {
+      feature_list[i] = feature_ids[i].data();
+      feature_list_size[i] = feature_ids[i].size();
+    }
+  } else if (FLAGS_gpugraph_storage_mode ==
+                 paddle::framework::GpuGraphStorageMode::
+                     MEM_EMB_FEATURE_AND_GPU_GRAPH ||
+             FLAGS_gpugraph_storage_mode ==
+                 paddle::framework::GpuGraphStorageMode::
+                     SSD_EMB_AND_MEM_FEATURE_GPU_GRAPH) {
+    auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
+    sub_graph_feas = gpu_graph_ptr->get_sub_graph_fea(node_ids, slot_num);
+    for (size_t i = 0; i < device_num; i++) {
+      feature_list[i] = sub_graph_feas[i].feature_list;
+      feature_list_size[i] = sub_graph_feas[i].feature_size;
+    }
+  } else {
+    PADDLE_ENFORCE_EQ(1, 0, "FLAGS_gpugraph_storage_mode is not supported");
+  }
+  time_stage.Pause();
+  get_feature_id_cost = time_stage.ElapsedSec();
+  size_t feature_num = 0;
+  for (size_t i = 0; i < device_num; i++) {
+    feature_num += feature_list_size[i];
+  }
+  VLOG(0) << "feature_num is " << feature_num << " node_num is "
+          << node_num;
+
+  size_t set_num = thread_keys_shard_num_;
+  std::vector<std::unordered_set<uint64_t>> feature_id_set(set_num);
+  std::vector<std::mutex> set_mutex(set_num);
+
+  auto add_feature_to_set =
+      [this, set_num, &feature_list, &feature_id_set, &set_mutex](
+          int dev, size_t start, size_t end) {
+        size_t batch = 10000 * set_num;
+        std::vector<std::vector<uint64_t>> feature_list_tmp(set_num);
+        for (size_t i = 0; i < set_num; i++) {
+          feature_list_tmp[i].reserve((batch * 1.2) / set_num);
+        }
+        std::vector<int> shuffle_set_index = shuffle_int_vector(set_num);
+        size_t pos = start;
+        size_t real_batch = 0;
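+        // two-level dedup: bucket ids by id % set_num, sort and unique each
+        // local bucket, then merge into the shared per-bucket sets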
+        while (pos < end) {
+          real_batch = (pos + batch <= end) ? batch : end - pos;
+          for (size_t i = pos; i < pos + real_batch; i++) {
+            if (feature_list[dev][i] == 0) {
+              continue;
+            }
+            int shard_num = feature_list[dev][i] % set_num;
+            feature_list_tmp[shard_num].push_back(feature_list[dev][i]);
+          }
+          // uniq in local
+          for (size_t i = 0; i < set_num; i++) {
+            std::sort(feature_list_tmp[i].begin(), feature_list_tmp[i].end());
+            size_t idx = 0;
+            size_t total = feature_list_tmp[i].size();
+            for (size_t j = 0; j < total; j++) {
+              auto& k = feature_list_tmp[i][j];
+              if (idx > 0 && feature_list_tmp[i][idx - 1] == k) {
+                continue;
+              }
+              feature_list_tmp[i][idx] = k;
+              ++idx;
+            }
+            feature_list_tmp[i].resize(idx);
+          }
+          // uniq in global
+          for (auto set_index : shuffle_set_index) {
+            set_mutex[set_index].lock();
+            for (auto feature_id : feature_list_tmp[set_index]) {
+              feature_id_set[set_index].insert(feature_id);
+            }
+            set_mutex[set_index].unlock();
+            feature_list_tmp[set_index].clear();
+          }
+          pos += real_batch;
+        }
+      };
+  size_t device_thread_num = 8;
+  threads.resize(device_num * device_thread_num);
+  time_stage.Start();
+  for (size_t i = 0; i < device_num; i++) {
+    size_t start = 0;
+    for (size_t j = 0; j < device_thread_num; j++) {
+      size_t batch = feature_list_size[i] / device_thread_num;
+      if (j < feature_list_size[i] %
+              device_thread_num) {
+        batch += 1;
+      }
+      threads[i * device_thread_num + j] =
+          std::thread(add_feature_to_set, i, start, start + batch);
+      start += batch;
+    }
+  }
+  for (std::thread& t : threads) {
+    t.join();
+  }
+  threads.clear();
+  time_stage.Pause();
+  add_feature_to_set_cost = time_stage.ElapsedSec();
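+  // set_num equals thread_keys_shard_num_, so bucket i feeds shard i directly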
+  auto add_feature_to_key = [this,
+                             device_num,
+                             &feature_id_set,
+                             &local_dim_keys,
+                             set_num](int shard_num, int j) {
+    local_dim_keys[shard_num][j].reserve(local_dim_keys[shard_num][j].size() +
+                                         feature_id_set[shard_num].size());
+    for (auto it = feature_id_set[shard_num].begin();
+         it != feature_id_set[shard_num].end();
+         it++) {
+      local_dim_keys[shard_num][j].push_back(*it);
+    }
+    feature_id_set[shard_num].clear();
+  };
+  time_stage.Start();
+  threads.resize(thread_keys_shard_num_ * multi_mf_dim_);
+  for (int i = 0; i < thread_keys_shard_num_; i++) {
+    for (int j = 0; j < multi_mf_dim_; j++) {
+      threads[i * multi_mf_dim_ + j] = std::thread(add_feature_to_key, i, j);
+    }
+  }
+  for (std::thread& t : threads) {
+    t.join();
+  }
+  time_stage.Pause();
+  add_feature_to_key_cost = time_stage.ElapsedSec();
+  threads.clear();
+  timeline.Pause();
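+  // report each stage's cost alongside the total below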
+  VLOG(0) << " add_slot_feature costs: " << timeline.ElapsedSec() << " s."
+          << " divide_nodeid_cost " << divide_nodeid_cost
+          << " get_feature_id_cost " << get_feature_id_cost
+          << " add_feature_to_set_cost " << add_feature_to_set_cost
+          << " add_feature_to_key_cost " << add_feature_to_key_cost;
 }
 
 void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
   platform::Timer timeline;
-  size_t slot_num = slot_vector_.size() - 1; //node slot 9008 in slot_vector
-  if (slot_num > 0 && FLAGS_gpugraph_storage_mode != paddle::framework::GpuGraphStorageMode::WHOLE_HBM) {
+  size_t slot_num = slot_vector_.size() - 1;  // node slot 9008 in slot_vector
+  if (slot_num > 0 && FLAGS_gpugraph_storage_mode !=
+                          paddle::framework::GpuGraphStorageMode::WHOLE_HBM) {
     add_slot_feature(gpu_task);
   }
   resize_gputask(gpu_task);
+
+  platform::Timer time_stage;
+  time_stage.Start();
+  gpu_task->UniqueKeys();
+  time_stage.Pause();
+  VLOG(0) << "BuildPull slot feature uniq and sort cost time: "
+          << time_stage.ElapsedSec();
+
   auto& local_dim_keys = gpu_task->feature_dim_keys_;
   auto& local_dim_ptr = gpu_task->value_dim_ptr_;
@@ -928,10 +958,10 @@ void PSGPUWrapper::PrepareGPUTask(std::shared_ptr<HeterContext> gpu_task) {
     prefix_sum[i][0] = 0;
   }
   auto calc_prefix_func = [this,
-                              &prefix_sum,
-                              &device_keys,
-                              &device_vals,
-                              &device_task_keys](int device_num) {
+                           &prefix_sum,
+                           &device_keys,
+                           &device_vals,
+                           &device_task_keys](int device_num) {
     for (int j = 0; j < thread_keys_shard_num_; j++) {
       prefix_sum[device_num][j + 1] =
           prefix_sum[device_num][j] + device_task_keys[j][device_num].size();
@@ -953,42 +983,42 @@ void PSGPUWrapper::PrepareGPUTask(std::shared_ptr<HeterContext> gpu_task) {
   }
   VLOG(0) << "prefix done";
   auto prepare_dev_value_func = [device_num,
-                                    &prefix_sum,
-                                    &device_keys,
-                                    &device_vals,
-                                    &device_task_keys,
-                                    &device_task_ptrs](int dev, int shard_id) {
+                                 &prefix_sum,
+                                 &device_keys,
+                                 &device_vals,
+                                 &device_task_keys,
+                                 &device_task_ptrs](int dev, int shard_id) {
 #ifdef PADDLE_WITH_PSLIB
-    auto& task_ptrs = device_task_ptrs[shard_id];
-
-    for (int j = 0; j < len; ++j) {
-      device_keys[dev][cur + j] = task_keys[dev][j];
-      float* ptr_val = task_ptrs[dev][j]->data();
-      FeatureValue& val = device_vals[dev][cur + j];
-      size_t dim = task_ptrs[dev][j]->size();
-
-      val.delta_score = ptr_val[1];
-      val.show = ptr_val[2];
-      val.clk = ptr_val[3];
-      val.slot = ptr_val[6];
-      val.lr = ptr_val[4];
-      val.lr_g2sum = ptr_val[5];
-      val.cpu_ptr = (uint64_t)(task_ptrs[dev][j]);
-
-      if (dim > 7) {
-        val.mf_size = MF_DIM + 1;
-        for (int x = 0; x < val.mf_size; x++) {
-          val.mf[x] = ptr_val[x + 7];
-        }
-      } else {
-        val.mf_size = 0;
-        for (int x = 0; x < MF_DIM + 1; x++) {
-          val.mf[x] = 0;
-        }
-      }
-    }
- #endif
-    VLOG(3) << "GpuPs build hbmps done";
+  auto& task_ptrs = device_task_ptrs[shard_id];
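+  // ptr_val layout: [1] delta_score, [2] show, [3] clk, [4] lr,
+  // [5] lr_g2sum, [6] slot; when dim > 7 the mf weights follow from index 7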
+
+  for (int j = 0; j < len; ++j) {
+    device_keys[dev][cur + j] = task_keys[dev][j];
+    float* ptr_val = task_ptrs[dev][j]->data();
+    FeatureValue& val = device_vals[dev][cur + j];
+    size_t dim = task_ptrs[dev][j]->size();
+
+    val.delta_score = ptr_val[1];
+    val.show = ptr_val[2];
+    val.clk = ptr_val[3];
+    val.slot = ptr_val[6];
+    val.lr = ptr_val[4];
+    val.lr_g2sum = ptr_val[5];
+    val.cpu_ptr = (uint64_t)(task_ptrs[dev][j]);
+
+    if (dim > 7) {
+      val.mf_size = MF_DIM + 1;
+      for (int x = 0; x < val.mf_size; x++) {
+        val.mf[x] = ptr_val[x + 7];
+      }
+    } else {
+      val.mf_size = 0;
+      for (int x = 0; x < MF_DIM + 1; x++) {
+        val.mf[x] = 0;
+      }
+    }
+  }
+#endif
+  VLOG(3) << "GpuPs build hbmps done";
   };
   if (!multi_mf_dim_) {
     for (int i = 0; i < thread_keys_shard_num_; i++) {
@@ -1004,10 +1034,9 @@ void PSGPUWrapper::PrepareGPUTask(std::shared_ptr<HeterContext> gpu_task) {
   }
   timeline.Pause();
   VLOG(0) << "GpuPs prepare for build hbm cost " << timeline.ElapsedSec()
-      << " seconds.";
+          << " seconds.";
 }
-
 void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
   int device_num = heter_devices_.size();
   platform::Timer stagetime;
@@ -1150,16 +1179,21 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
     };
     threads[j] = std::thread(build_ps_thread, i, j, len, feature_value_size);
   }
-  //build feature table
-  size_t slot_num = slot_vector_.size() - 1; //node slot 9008 in slot_vector
-  if (slot_num > 0 && (FLAGS_gpugraph_storage_mode == paddle::framework::GpuGraphStorageMode::MEM_EMB_FEATURE_AND_GPU_GRAPH
-      || FLAGS_gpugraph_storage_mode == paddle::framework::GpuGraphStorageMode::SSD_EMB_AND_MEM_FEATURE_GPU_GRAPH)) {
-    auto build_feature_table = [this, &gpu_task](int i) {
-      auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
-      std::vector<GpuPsCommGraphFea> * tmp = (std::vector<GpuPsCommGraphFea> *) gpu_task->sub_graph_feas;
-      gpu_graph_ptr->build_gpu_graph_fea((*tmp)[i], i);
-    };
-    threads.push_back(std::thread(build_feature_table, i));
+  // build feature table
+  size_t slot_num = slot_vector_.size() - 1;  // node slot 9008 in slot_vector
+  if (slot_num > 0 &&
+      (FLAGS_gpugraph_storage_mode == paddle::framework::GpuGraphStorageMode::
+           MEM_EMB_FEATURE_AND_GPU_GRAPH ||
+       FLAGS_gpugraph_storage_mode ==
+           paddle::framework::GpuGraphStorageMode::
+               SSD_EMB_AND_MEM_FEATURE_GPU_GRAPH)) {
+    auto build_feature_table = [this, &gpu_task](int i) {
+      auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
+      std::vector<GpuPsCommGraphFea>* tmp =
+          (std::vector<GpuPsCommGraphFea>*)gpu_task->sub_graph_feas;
+      gpu_graph_ptr->build_gpu_graph_fea((*tmp)[i], i);
+    };
+    threads.push_back(std::thread(build_feature_table, i));
   }
 
   struct task_info task;
@@ -1227,9 +1261,12 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
     f.wait();
   }
   gpu_task_futures.clear();
-  if (FLAGS_gpugraph_storage_mode == paddle::framework::GpuGraphStorageMode::MEM_EMB_FEATURE_AND_GPU_GRAPH
-      || FLAGS_gpugraph_storage_mode == paddle::framework::GpuGraphStorageMode::SSD_EMB_AND_MEM_FEATURE_GPU_GRAPH) {
-    std::vector<GpuPsCommGraphFea> * tmp = (std::vector<GpuPsCommGraphFea> *) gpu_task->sub_graph_feas;
+  if (FLAGS_gpugraph_storage_mode == paddle::framework::GpuGraphStorageMode::
+          MEM_EMB_FEATURE_AND_GPU_GRAPH ||
+      FLAGS_gpugraph_storage_mode == paddle::framework::GpuGraphStorageMode::
+          SSD_EMB_AND_MEM_FEATURE_GPU_GRAPH) {
+    std::vector<GpuPsCommGraphFea>* tmp =
+        (std::vector<GpuPsCommGraphFea>*)gpu_task->sub_graph_feas;
     delete tmp;
     gpu_task->sub_graph_feas = NULL;
   }
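
Reviewer note (illustrative, not part of the patch): the dedup inside
add_feature_to_set is a reusable two-level pattern that can be read in
isolation. A minimal standalone sketch, assuming only the C++ standard
library; the function name dedup_into_buckets and the single-threaded framing
are simplifications (the patch additionally shuffles the bucket visit order
to reduce mutex contention):

#include <algorithm>
#include <cstdint>
#include <mutex>
#include <unordered_set>
#include <vector>

// Bucket ids by id % bucket count, sort + unique each thread-local bucket,
// then merge into the shared per-bucket sets under their mutexes. Intended
// to be called concurrently from several threads over disjoint id ranges.
void dedup_into_buckets(const std::vector<uint64_t>& ids,
                        std::vector<std::unordered_set<uint64_t>>* buckets,
                        std::vector<std::mutex>* bucket_mutex) {
  const size_t set_num = buckets->size();
  std::vector<std::vector<uint64_t>> local(set_num);
  for (uint64_t id : ids) {
    if (id == 0) continue;  // 0 marks an empty feature slot
    local[id % set_num].push_back(id);
  }
  for (size_t b = 0; b < set_num; ++b) {
    auto& v = local[b];
    std::sort(v.begin(), v.end());
    v.erase(std::unique(v.begin(), v.end()), v.end());
    std::lock_guard<std::mutex> guard((*bucket_mutex)[b]);
    (*buckets)[b].insert(v.begin(), v.end());
  }
}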