diff --git a/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc b/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc
index 3e0f631ed41bc..05f7c5c5780ea 100644
--- a/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc
+++ b/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc
@@ -34,6 +34,7 @@ int32_t SSDSparseTable::Initialize() {
   MemorySparseTable::Initialize();
   _db = paddle::distributed::RocksDBHandler::GetInstance();
   _db->initialize(FLAGS_rocksdb_path, _real_local_shard_num);
+  VLOG(0) << "initialize SSDSparseTable succ";
   return 0;
 }
 
@@ -549,7 +550,11 @@ int32_t SSDSparseTable::Save(const std::string& path,
   std::string table_path = TableDir(path);
   _afs_client.remove(paddle::string::format_string(
       "%s/part-%03d-*", table_path.c_str(), _shard_idx));
+#ifdef PADDLE_WITH_GPU_GRAPH
+  int thread_num = _real_local_shard_num;
+#else
   int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
+#endif
 
   // std::atomic feasign_size;
   std::atomic feasign_size_all{0};
diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc
index c88c91f166112..bd8432b0d694a 100644
--- a/paddle/fluid/framework/data_feed.cc
+++ b/paddle/fluid/framework/data_feed.cc
@@ -2118,6 +2118,10 @@ void SlotRecordInMemoryDataFeed::Init(const DataFeedDesc& data_feed_desc) {
   }
 }
 
+void SlotRecordInMemoryDataFeed::InitGraphResource() {
+  gpu_graph_data_generator_.AllocResource(thread_id_, feed_vec_);
+}
+
 void SlotRecordInMemoryDataFeed::LoadIntoMemory() {
   VLOG(3) << "SlotRecord LoadIntoMemory() begin, thread_id=" << thread_id_;
   if (!so_parser_name_.empty()) {
@@ -2654,7 +2658,7 @@ bool SlotRecordInMemoryDataFeed::Start() {
   pack_ = BatchGpuPackMgr().get(this->GetPlace(), used_slots_info_);
 #endif
 #if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
-  gpu_graph_data_generator_.AllocResource(this->place_, feed_vec_);
+  gpu_graph_data_generator_.SetFeedVec(feed_vec_);
 #endif
   return true;
 }
diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu
index 3c4f2c5bbc74d..ed4954e4bb027 100644
--- a/paddle/fluid/framework/data_feed.cu
+++ b/paddle/fluid/framework/data_feed.cu
@@ -25,8 +25,10 @@ limitations under the License.
*/ #include "cub/cub.cuh" #include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h" #include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h" +#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" DECLARE_bool(enable_opt_get_features); +DECLARE_int32(gpugraph_storage_mode); namespace paddle { namespace framework { @@ -348,6 +350,9 @@ int GraphDataGenerator::FillInsBuf() { buf_state_.Debug(); if (total_instance == 0) { + if (FLAGS_gpugraph_storage_mode == GpuGraphStorageMode::CPU) { + return -1; + } int res = FillWalkBuf(d_walk_); if (!res) { // graph iterate complete @@ -393,9 +398,9 @@ int GraphDataGenerator::FillInsBuf() { uint64_t *ins_buf = reinterpret_cast(d_ins_buf_->ptr()); int *random_row = reinterpret_cast(d_random_row_->ptr()); int *d_pair_num = reinterpret_cast(d_pair_num_->ptr()); - cudaMemsetAsync(d_pair_num, 0, sizeof(int), stream_); + cudaMemsetAsync(d_pair_num, 0, sizeof(int), train_stream_); int len = buf_state_.len; - GraphFillIdKernel<<>>( + GraphFillIdKernel<<>>( ins_buf + ins_buf_pair_len_ * 2, d_pair_num, walk, @@ -405,16 +410,22 @@ int GraphDataGenerator::FillInsBuf() { len, walk_len_); int h_pair_num; - cudaMemcpyAsync( - &h_pair_num, d_pair_num, sizeof(int), cudaMemcpyDeviceToHost, stream_); + cudaMemcpyAsync(&h_pair_num, + d_pair_num, + sizeof(int), + cudaMemcpyDeviceToHost, + train_stream_); if (!FLAGS_enable_opt_get_features && slot_num_ > 0) { uint64_t *feature_buf = reinterpret_cast(d_feature_buf_->ptr()); uint64_t *feature = reinterpret_cast(d_feature_->ptr()); - cudaMemsetAsync(d_pair_num, 0, sizeof(int), stream_); + cudaMemsetAsync(d_pair_num, 0, sizeof(int), train_stream_); int len = buf_state_.len; VLOG(2) << "feature_buf start[" << ins_buf_pair_len_ * 2 * slot_num_ << "] len[" << len << "]"; - GraphFillFeatureKernel<<>>( + GraphFillFeatureKernel<<>>( feature_buf + ins_buf_pair_len_ * 2 * slot_num_, d_pair_num, walk, @@ -427,7 +438,7 @@ int GraphDataGenerator::FillInsBuf() { slot_num_); } - cudaStreamSynchronize(stream_); + cudaStreamSynchronize(train_stream_); ins_buf_pair_len_ += h_pair_num; if (debug_mode_) { @@ -463,21 +474,23 @@ int GraphDataGenerator::GenerateBatch() { int total_instance = 0; platform::CUDADeviceGuard guard(gpuid_); int res = 0; + auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); if (!gpu_graph_training_) { + auto &infer_node_type_start = gpu_graph_ptr->infer_node_type_start_[gpuid_]; while (cursor_ < h_device_keys_.size()) { size_t device_key_size = h_device_keys_[cursor_]->size(); - if (infer_node_type_start_[cursor_] >= device_key_size) { + if (infer_node_type_start[cursor_] >= device_key_size) { cursor_++; continue; } total_instance = - (infer_node_type_start_[cursor_] + batch_size_ <= device_key_size) + (infer_node_type_start[cursor_] + batch_size_ <= device_key_size) ? 
batch_size_ - : device_key_size - infer_node_type_start_[cursor_]; + : device_key_size - infer_node_type_start[cursor_]; uint64_t *d_type_keys = reinterpret_cast(d_device_keys_[cursor_]->ptr()); - d_type_keys += infer_node_type_start_[cursor_]; - infer_node_type_start_[cursor_] += total_instance; + d_type_keys += infer_node_type_start[cursor_]; + infer_node_type_start[cursor_] += total_instance; VLOG(1) << "in graph_data generator:batch_size = " << batch_size_ << " instance = " << total_instance; total_instance *= 2; @@ -490,16 +503,16 @@ int GraphDataGenerator::GenerateBatch() { CopyDuplicateKeys<<>>( + train_stream_>>>( id_tensor_ptr_, d_type_keys, total_instance / 2); GraphFillCVMKernel<<>>(show_tensor_ptr_, total_instance); + train_stream_>>>(show_tensor_ptr_, total_instance); GraphFillCVMKernel<<>>(clk_tensor_ptr_, total_instance); + train_stream_>>>(clk_tensor_ptr_, total_instance); break; } if (total_instance == 0) { @@ -520,6 +533,8 @@ int GraphDataGenerator::GenerateBatch() { ins_buf_pair_len_ < batch_size_ ? ins_buf_pair_len_ : batch_size_; total_instance *= 2; + VLOG(2) << "total ins: " << total_instance << " gpuid: " << gpuid_ + << " feed_vec: " << feed_vec_[0]; id_tensor_ptr_ = feed_vec_[0]->mutable_data({total_instance, 1}, this->place_); show_tensor_ptr_ = @@ -542,12 +557,12 @@ int GraphDataGenerator::GenerateBatch() { slot_tensor_ptr_, sizeof(uint64_t *) * slot_num_, cudaMemcpyHostToDevice, - stream_); + train_stream_); cudaMemcpyAsync(d_slot_lod_tensor_ptr_->ptr(), slot_lod_tensor_ptr_, sizeof(uint64_t *) * slot_num_, cudaMemcpyHostToDevice, - stream_); + train_stream_); } } @@ -563,16 +578,16 @@ int GraphDataGenerator::GenerateBatch() { ins_cursor, sizeof(uint64_t) * total_instance, cudaMemcpyDeviceToDevice, - stream_); + train_stream_); GraphFillCVMKernel<<>>(show_tensor_ptr_, total_instance); + train_stream_>>>(show_tensor_ptr_, total_instance); GraphFillCVMKernel<<>>(clk_tensor_ptr_, total_instance); + train_stream_>>>(clk_tensor_ptr_, total_instance); } else { ins_cursor = (uint64_t *)id_tensor_ptr_; } @@ -608,15 +623,16 @@ int GraphDataGenerator::GenerateBatch() { GraphFillSlotKernel<<>>((uint64_t *)d_slot_tensor_ptr_->ptr(), - feature_buf, - total_instance * slot_num_, - total_instance, - slot_num_); + train_stream_>>>( + (uint64_t *)d_slot_tensor_ptr_->ptr(), + feature_buf, + total_instance * slot_num_, + total_instance, + slot_num_); GraphFillSlotLodKernelOpt<<>>( + train_stream_>>>( (uint64_t *)d_slot_lod_tensor_ptr_->ptr(), (total_instance + 1) * slot_num_, total_instance + 1); @@ -633,13 +649,13 @@ int GraphDataGenerator::GenerateBatch() { &feature_buf[feature_buf_offset + j * slot_num_], sizeof(uint64_t) * 2, cudaMemcpyDeviceToDevice, - stream_); + train_stream_); } GraphFillSlotLodKernel<<>>(slot_lod_tensor_ptr_[i], - total_instance + 1); + train_stream_>>>(slot_lod_tensor_ptr_[i], + total_instance + 1); } } } @@ -655,7 +671,7 @@ int GraphDataGenerator::GenerateBatch() { } } - cudaStreamSynchronize(stream_); + cudaStreamSynchronize(train_stream_); if (!gpu_graph_training_) return 1; ins_buf_pair_len_ -= total_instance / 2; if (debug_mode_) { @@ -773,45 +789,50 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids, d_actual_sample_size, d_prefix_sum + 1, len, - stream_)); - auto d_temp_storage = memory::Alloc(place_, temp_storage_bytes); + sample_stream_)); + auto d_temp_storage = memory::Alloc( + place_, + temp_storage_bytes, + phi::Stream(reinterpret_cast(sample_stream_))); CUDA_CHECK(cub::DeviceScan::InclusiveSum(d_temp_storage->ptr(), 
                                           temp_storage_bytes,
                                           d_actual_sample_size,
                                           d_prefix_sum + 1,
                                           len,
-                                          stream_));
+                                          sample_stream_));
 
-  cudaStreamSynchronize(stream_);
+  cudaStreamSynchronize(sample_stream_);
   if (step == 1) {
-    GraphFillFirstStepKernel<<>>(
-        d_prefix_sum,
-        d_tmp_sampleidx2row,
-        walk,
-        d_start_ids,
-        len,
-        walk_degree_,
-        walk_len_,
-        d_actual_sample_size,
-        d_neighbors,
-        d_sample_keys);
+    GraphFillFirstStepKernel<<>>(d_prefix_sum,
+                                  d_tmp_sampleidx2row,
+                                  walk,
+                                  d_start_ids,
+                                  len,
+                                  walk_degree_,
+                                  walk_len_,
+                                  d_actual_sample_size,
+                                  d_neighbors,
+                                  d_sample_keys);
   } else {
     GraphFillSampleKeysKernel<<>>(d_neighbors,
-                                  d_sample_keys,
-                                  d_prefix_sum,
-                                  d_sampleidx2row,
-                                  d_tmp_sampleidx2row,
-                                  d_actual_sample_size,
-                                  cur_degree,
-                                  len);
-
-    GraphDoWalkKernel<<>>(
+                              sample_stream_>>>(d_neighbors,
+                                                d_sample_keys,
+                                                d_prefix_sum,
+                                                d_sampleidx2row,
+                                                d_tmp_sampleidx2row,
+                                                d_actual_sample_size,
+                                                cur_degree,
+                                                len);
+
+    GraphDoWalkKernel<<>>(
         d_neighbors,
         walk,
         d_prefix_sum,
@@ -849,7 +870,7 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids,
     delete[] h_offset2idx;
     delete[] h_sample_keys;
   }
-  cudaStreamSynchronize(stream_);
+  cudaStreamSynchronize(sample_stream_);
   cur_sampleidx2row_ = 1 - cur_sampleidx2row_;
 }
 
@@ -899,38 +920,51 @@ int GraphDataGenerator::FillWalkBuf(std::shared_ptr d_walk) {
   uint64_t *walk = reinterpret_cast(d_walk->ptr());
   int *len_per_row = reinterpret_cast(d_len_per_row_->ptr());
   uint64_t *d_sample_keys = reinterpret_cast(d_sample_keys_->ptr());
-  cudaMemsetAsync(walk, 0, buf_size_ * sizeof(uint64_t), stream_);
-  cudaMemsetAsync(
-      len_per_row, 0, once_max_sample_keynum * sizeof(int), stream_);
+  cudaMemsetAsync(walk, 0, buf_size_ * sizeof(uint64_t), sample_stream_);
+  VLOG(2) << "FillWalkBuf: cleared walk buffer on sample stream, gpuid = "
+          << gpuid_;
+  // cudaMemsetAsync(
+  //     len_per_row, 0, once_max_sample_keynum * sizeof(int), sample_stream_);
   int i = 0;
   int total_row = 0;
-  size_t node_type_len = first_node_type_.size();
+
+  // Get the global sampling state maintained by GraphGpuWrapper.
+  auto &first_node_type = gpu_graph_ptr->first_node_type_;
+  auto &meta_path = gpu_graph_ptr->meta_path_;
+  auto &node_type_start = gpu_graph_ptr->node_type_start_[gpuid_];
+  auto &finish_node_type = gpu_graph_ptr->finish_node_type_[gpuid_];
+
+  size_t node_type_len = first_node_type.size();
   int remain_size =
       buf_size_ - walk_degree_ * once_sample_startid_len_ * walk_len_;
 
   while (i <= remain_size) {
     int cur_node_idx = cursor_ % node_type_len;
-    int node_type = first_node_type_[cur_node_idx];
-    auto &path = meta_path_[cur_node_idx];
-    size_t start = node_type_start_[node_type];
-    // auto node_query_result = gpu_graph_ptr->query_node_list(
-    //     gpuid_, node_type, start, once_sample_startid_len_);
-
-    // int tmp_len = node_query_result.actual_sample_size;
+    int node_type = first_node_type[cur_node_idx];
+    VLOG(2) << "cur_node_idx = " << cur_node_idx
+            << " meta_path.size = " << meta_path.size();
+    auto &path = meta_path[cur_node_idx];
+    size_t start = node_type_start[node_type];
+    auto node_query_result = gpu_graph_ptr->query_node_list(
+        gpuid_, node_type, start, once_sample_startid_len_);
+
+    int tmp_len = node_query_result.actual_sample_size;
     VLOG(2) << "choose start type: " << node_type;
-    int type_index = type_to_index_[node_type];
-    size_t device_key_size = h_device_keys_[type_index]->size();
-    VLOG(2) << "type: " << node_type << " size: " << device_key_size
+    // int type_index = type_to_index_[node_type];
+    // size_t device_key_size = h_device_keys_[type_index]->size();
+    VLOG(2) << "type: " << node_type << " size: " << tmp_len
            << " start: " << start;
-    uint64_t *d_type_keys =
reinterpret_cast(d_device_keys_[type_index]->ptr()); - int tmp_len = start + once_sample_startid_len_ > device_key_size - ? device_key_size - start - : once_sample_startid_len_; - node_type_start_[node_type] = tmp_len + start; + uint64_t *d_type_keys = node_query_result.val; + // uint64_t *d_type_keys = + // reinterpret_cast(d_device_keys_[type_index]->ptr()); + // int tmp_len = start + once_sample_startid_len_ > device_key_size + // ? device_key_size - start + // : once_sample_startid_len_; + node_type_start[node_type] = tmp_len + start; if (tmp_len == 0) { - finish_node_type_.insert(node_type); - if (finish_node_type_.size() == node_type_start_.size()) { + finish_node_type.insert(node_type); + VLOG(2) << "finish_node_type size: " << finish_node_type.size() + << " node_type_start size: " << node_type_start.size(); + if (finish_node_type.size() == node_type_start.size()) { break; } cursor_ += 1; @@ -942,20 +976,24 @@ int GraphDataGenerator::FillWalkBuf(std::shared_ptr d_walk) { VLOG(2) << "i = " << i << " buf_size_ = " << buf_size_ << " tmp_len = " << tmp_len << " cursor = " << cursor_ << " once_max_sample_keynum = " << once_max_sample_keynum; + VLOG(2) << "gpuid = " << gpuid_ << " path[0] = " << path[0]; uint64_t *cur_walk = walk + i; NeighborSampleQuery q; - q.initialize(gpuid_, - path[0], - (uint64_t)(d_type_keys + start), - walk_degree_, - tmp_len); + q.initialize( + gpuid_, path[0], (uint64_t)(d_type_keys), walk_degree_, tmp_len); + VLOG(2) << "tag aaa"; auto sample_res = gpu_graph_ptr->graph_neighbor_sample_v3(q, false); + VLOG(2) << "tag bbb"; int step = 1; VLOG(2) << "sample edge type: " << path[0] << " step: " << 1; jump_rows_ = sample_res.total_sample_size; - FillOneStep(d_type_keys + start, + if (jump_rows_ == 0) { + cursor_ += 1; + continue; + } + FillOneStep(d_type_keys, cur_walk, tmp_len, sample_res, @@ -990,7 +1028,7 @@ int GraphDataGenerator::FillWalkBuf(std::shared_ptr d_walk) { sample_res.total_sample_size); sample_res = gpu_graph_ptr->graph_neighbor_sample_v3(q, false); - FillOneStep(d_type_keys + start, + FillOneStep(d_type_keys, cur_walk, sample_res.total_sample_size, sample_res, @@ -1014,7 +1052,7 @@ int GraphDataGenerator::FillWalkBuf(std::shared_ptr d_walk) { int *d_random_row = reinterpret_cast(d_random_row_->ptr()); thrust::random::default_random_engine engine(shuffle_seed_); - const auto &exec_policy = thrust::cuda::par.on(stream_); + const auto &exec_policy = thrust::cuda::par.on(sample_stream_); thrust::counting_iterator cnt_iter(0); thrust::shuffle_copy(exec_policy, cnt_iter, @@ -1022,7 +1060,7 @@ int GraphDataGenerator::FillWalkBuf(std::shared_ptr d_walk) { thrust::device_pointer_cast(d_random_row), engine); - cudaStreamSynchronize(stream_); + cudaStreamSynchronize(sample_stream_); shuffle_seed_ = engine(); if (debug_mode_) { @@ -1044,68 +1082,82 @@ int GraphDataGenerator::FillWalkBuf(std::shared_ptr d_walk) { return total_row != 0; } -void GraphDataGenerator::AllocResource(const paddle::platform::Place &place, - std::vector feed_vec) { - place_ = place; - gpuid_ = place_.GetDeviceId(); - VLOG(3) << "gpuid " << gpuid_; - stream_ = dynamic_cast( - platform::DeviceContextPool::Instance().Get(place)) - ->stream(); +GraphDataGenerator::~GraphDataGenerator() { + CUDA_CHECK(cudaStreamDestroy(sample_stream_)); +} + +void GraphDataGenerator::SetFeedVec(std::vector feed_vec) { feed_vec_ = feed_vec; +} +void GraphDataGenerator::AllocResource(int thread_id, + std::vector feed_vec) { + auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); + gpuid_ = 
gpu_graph_ptr->device_id_mapping[thread_id]; + place_ = platform::CUDAPlace(gpuid_); + + platform::CUDADeviceGuard guard(gpuid_); + + VLOG(3) << "AllocResource gpuid " << gpuid_; + train_stream_ = dynamic_cast( + platform::DeviceContextPool::Instance().Get(place_)) + ->stream(); + CUDA_CHECK(cudaStreamCreateWithFlags(&sample_stream_, cudaStreamNonBlocking)); + // feed_vec_ = feed_vec; slot_num_ = (feed_vec_.size() - 3) / 2; - // d_device_keys_.resize(h_device_keys_.size()); VLOG(2) << "h_device_keys size: " << h_device_keys_.size(); - infer_node_type_start_ = std::vector(h_device_keys_.size(), 0); - for (size_t i = 0; i < h_device_keys_.size(); i++) { - for (size_t j = 0; j < h_device_keys_[i]->size(); j++) { - VLOG(3) << "h_device_keys_[" << i << "][" << j - << "] = " << (*(h_device_keys_[i]))[j]; - } - auto buf = memory::AllocShared( - place_, h_device_keys_[i]->size() * sizeof(uint64_t)); - d_device_keys_.push_back(buf); - CUDA_CHECK(cudaMemcpyAsync(buf->ptr(), - h_device_keys_[i]->data(), - h_device_keys_[i]->size() * sizeof(uint64_t), - cudaMemcpyHostToDevice, - stream_)); - } - // h_device_keys_ = h_device_keys; - // device_key_size_ = h_device_keys_->size(); - // d_device_keys_ = - // memory::AllocShared(place_, device_key_size_ * sizeof(int64_t)); - // CUDA_CHECK(cudaMemcpyAsync(d_device_keys_->ptr(), h_device_keys_->data(), - // device_key_size_ * sizeof(int64_t), - // cudaMemcpyHostToDevice, stream_)); + // infer_node_type_start_ = std::vector(h_device_keys_.size(), 0); + // for (size_t i = 0; i < h_device_keys_.size(); i++) { + // for (size_t j = 0; j < h_device_keys_[i]->size(); j++) { + // VLOG(3) << "h_device_keys_[" << i << "][" << j + // << "] = " << (*(h_device_keys_[i]))[j]; + // } + // auto buf = memory::AllocShared( + // place_, h_device_keys_[i]->size() * sizeof(uint64_t)); + // d_device_keys_.push_back(buf); + // CUDA_CHECK(cudaMemcpyAsync(buf->ptr(), + // h_device_keys_[i]->data(), + // h_device_keys_[i]->size() * sizeof(uint64_t), + // cudaMemcpyHostToDevice, + // stream_)); + // } size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_; - d_prefix_sum_ = - memory::AllocShared(place_, (once_max_sample_keynum + 1) * sizeof(int)); + d_prefix_sum_ = memory::AllocShared( + place_, + (once_max_sample_keynum + 1) * sizeof(int), + phi::Stream(reinterpret_cast(sample_stream_))); int *d_prefix_sum_ptr = reinterpret_cast(d_prefix_sum_->ptr()); - cudaMemsetAsync( - d_prefix_sum_ptr, 0, (once_max_sample_keynum + 1) * sizeof(int), stream_); + cudaMemsetAsync(d_prefix_sum_ptr, + 0, + (once_max_sample_keynum + 1) * sizeof(int), + sample_stream_); cursor_ = 0; jump_rows_ = 0; - d_walk_ = memory::AllocShared(place_, buf_size_ * sizeof(uint64_t)); - cudaMemsetAsync(d_walk_->ptr(), 0, buf_size_ * sizeof(uint64_t), stream_); - if (!FLAGS_enable_opt_get_features && slot_num_ > 0) { - d_feature_ = - memory::AllocShared(place_, buf_size_ * slot_num_ * sizeof(uint64_t)); - cudaMemsetAsync( - d_feature_->ptr(), 0, buf_size_ * sizeof(uint64_t), stream_); - } - d_sample_keys_ = - memory::AllocShared(place_, once_max_sample_keynum * sizeof(uint64_t)); + d_walk_ = memory::AllocShared( + place_, + buf_size_ * sizeof(uint64_t), + phi::Stream(reinterpret_cast(sample_stream_))); + cudaMemsetAsync( + d_walk_->ptr(), 0, buf_size_ * sizeof(uint64_t), sample_stream_); + d_sample_keys_ = memory::AllocShared( + place_, + once_max_sample_keynum * sizeof(uint64_t), + phi::Stream(reinterpret_cast(sample_stream_))); - d_sampleidx2rows_.push_back( - memory::AllocShared(place_, 
once_max_sample_keynum * sizeof(int))); - d_sampleidx2rows_.push_back( - memory::AllocShared(place_, once_max_sample_keynum * sizeof(int))); + d_sampleidx2rows_.push_back(memory::AllocShared( + place_, + once_max_sample_keynum * sizeof(int), + phi::Stream(reinterpret_cast(sample_stream_)))); + d_sampleidx2rows_.push_back(memory::AllocShared( + place_, + once_max_sample_keynum * sizeof(int), + phi::Stream(reinterpret_cast(sample_stream_)))); cur_sampleidx2row_ = 0; - d_len_per_row_ = - memory::AllocShared(place_, once_max_sample_keynum * sizeof(int)); + d_len_per_row_ = memory::AllocShared( + place_, + once_max_sample_keynum * sizeof(int), + phi::Stream(reinterpret_cast(sample_stream_))); for (int i = -window_; i < 0; i++) { window_step_.push_back(i); } @@ -1115,7 +1167,8 @@ void GraphDataGenerator::AllocResource(const paddle::platform::Place &place, buf_state_.Init(batch_size_, walk_len_, &window_step_); d_random_row_ = memory::AllocShared( place_, - (once_sample_startid_len_ * walk_degree_ * repeat_time_) * sizeof(int)); + (once_sample_startid_len_ * walk_degree_ * repeat_time_) * sizeof(int), + phi::Stream(reinterpret_cast(sample_stream_))); shuffle_seed_ = 0; ins_buf_pair_len_ = 0; @@ -1133,7 +1186,7 @@ void GraphDataGenerator::AllocResource(const paddle::platform::Place &place, memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *)); } - cudaStreamSynchronize(stream_); + cudaStreamSynchronize(sample_stream_); } void GraphDataGenerator::SetConfig( @@ -1161,40 +1214,7 @@ void GraphDataGenerator::SetConfig( std::string first_node_type = graph_config.first_node_type(); std::string meta_path = graph_config.meta_path(); auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); - auto edge_to_id = gpu_graph_ptr->edge_to_id; - auto node_to_id = gpu_graph_ptr->feature_to_id; - // parse first_node_type - auto node_types = - paddle::string::split_string(first_node_type, ";"); - VLOG(2) << "node_types: " << first_node_type; - finish_node_type_.clear(); - node_type_start_.clear(); - for (auto &type : node_types) { - auto iter = node_to_id.find(type); - PADDLE_ENFORCE_NE( - iter, - node_to_id.end(), - platform::errors::NotFound("(%s) is not found in node_to_id.", type)); - VLOG(2) << "node_to_id[" << type << "] = " << iter->second; - first_node_type_.push_back(iter->second); - node_type_start_[iter->second] = 0; - } - meta_path_.resize(first_node_type_.size()); - auto meta_paths = paddle::string::split_string(meta_path, ";"); - - for (size_t i = 0; i < meta_paths.size(); i++) { - auto path = meta_paths[i]; - auto nodes = paddle::string::split_string(path, "-"); - for (auto &node : nodes) { - auto iter = edge_to_id.find(node); - PADDLE_ENFORCE_NE( - iter, - edge_to_id.end(), - platform::errors::NotFound("(%s) is not found in edge_to_id.", node)); - VLOG(2) << "edge_to_id[" << node << "] = " << iter->second; - meta_path_[i].push_back(iter->second); - } - } + gpu_graph_ptr->init_conf(first_node_type, meta_path); }; } // namespace framework diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index ba9d5d0546791..5874f4534e18a 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -893,10 +893,10 @@ struct BufState { class GraphDataGenerator { public: GraphDataGenerator(){}; - virtual ~GraphDataGenerator(){}; + virtual ~GraphDataGenerator(); void SetConfig(const paddle::framework::DataFeedDesc& data_feed_desc); - void AllocResource(const paddle::platform::Place& place, - std::vector feed_vec); + void AllocResource(int thread_id, 
std::vector feed_vec); + void SetFeedVec(std::vector feed_vec); int AcquireInstance(BufState* state); int GenerateBatch(); int FillWalkBuf(std::shared_ptr d_walk); @@ -933,7 +933,8 @@ class GraphDataGenerator { int64_t* id_tensor_ptr_; int64_t* show_tensor_ptr_; int64_t* clk_tensor_ptr_; - cudaStream_t stream_; + cudaStream_t train_stream_; + cudaStream_t sample_stream_; paddle::platform::Place place_; std::vector feed_vec_; std::vector offset_; @@ -951,10 +952,6 @@ class GraphDataGenerator { std::shared_ptr d_sample_keys_; int sample_keys_len_; - std::set finish_node_type_; - std::unordered_map node_type_start_; - std::vector infer_node_type_start_; - std::shared_ptr d_ins_buf_; std::shared_ptr d_feature_buf_; std::shared_ptr d_pair_num_; @@ -970,8 +967,6 @@ class GraphDataGenerator { int slot_num_; int shuffle_seed_; int debug_mode_; - std::vector first_node_type_; - std::vector> meta_path_; bool gpu_graph_training_; }; @@ -1037,6 +1032,7 @@ class DataFeed { virtual void SetParseLogKey(bool parse_logkey) {} virtual void SetEnablePvMerge(bool enable_pv_merge) {} virtual void SetCurrentPhase(int current_phase) {} + virtual void InitGraphResource() {} virtual void SetDeviceKeys(std::vector* device_keys, int type) { #if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) gpu_graph_data_generator_.SetDeviceKeys(device_keys, type); @@ -1637,6 +1633,7 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed { // CustomParser* parser) {} virtual void PutToFeedVec(const std::vector& ins_vec) {} + virtual void InitGraphResource(void); virtual void LoadIntoMemoryByCommand(void); virtual void LoadIntoMemoryByLib(void); virtual void LoadIntoMemoryByLine(void); diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 1d70ef6a1c78b..6a0f4101dfbed 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -37,6 +37,7 @@ USE_INT_STAT(STAT_total_feasign_num_in_mem); DECLARE_bool(graph_get_neighbor_id); +DECLARE_int32(gpugraph_storage_mode); namespace paddle { namespace framework { @@ -473,60 +474,44 @@ void DatasetImpl::LoadIntoMemory() { auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); auto node_to_id = gpu_graph_ptr->feature_to_id; auto edge_to_id = gpu_graph_ptr->edge_to_id; - graph_all_type_total_keys_.resize(node_to_id.size()); - int cnt = 0; - // set sample start node - for (auto& iter : node_to_id) { - int node_idx = iter.second; - std::vector> gpu_graph_device_keys; - gpu_graph_ptr->get_all_id( - 1, node_idx, thread_num_, &gpu_graph_device_keys); - auto& type_total_key = graph_all_type_total_keys_[cnt]; - type_total_key.resize(thread_num_); - for (size_t i = 0; i < gpu_graph_device_keys.size(); i++) { - VLOG(2) << "node type: " << node_idx << ", gpu_graph_device_keys[" << i - << "] = " << gpu_graph_device_keys[i].size(); - for (size_t j = 0; j < gpu_graph_device_keys[i].size(); j++) { - type_total_key[i].push_back(gpu_graph_device_keys[i][j]); - } - } - - for (size_t i = 0; i < readers_.size(); i++) { - readers_[i]->SetDeviceKeys(&type_total_key[i], node_idx); - readers_[i]->SetGpuGraphMode(gpu_graph_mode_); - } - cnt++; - } - // add node embedding id - std::vector> gpu_graph_device_keys; - gpu_graph_ptr->get_node_embedding_ids(thread_num_, &gpu_graph_device_keys); - for (size_t i = 0; i < gpu_graph_device_keys.size(); i++) { - for (size_t j = 0; j < gpu_graph_device_keys[i].size(); j++) { - gpu_graph_total_keys_.push_back(gpu_graph_device_keys[i][j]); - } + for (size_t i = 0; i < readers_.size(); i++) { + 
readers_[i]->SetGpuGraphMode(gpu_graph_mode_); } - - // add feature embedding id - VLOG(2) << "begin add feature_id into gpu_graph_total_keys_ size[" - << gpu_graph_total_keys_.size() << "]"; - for (auto& iter : node_to_id) { + if (FLAGS_gpugraph_storage_mode == GpuGraphStorageMode::HBM) { + // add node embedding id std::vector> gpu_graph_device_keys; - int node_idx = iter.second; - gpu_graph_ptr->get_all_feature_ids( - 1, node_idx, thread_num_, &gpu_graph_device_keys); + gpu_graph_ptr->get_node_embedding_ids(thread_num_, + &gpu_graph_device_keys); for (size_t i = 0; i < gpu_graph_device_keys.size(); i++) { - VLOG(2) << "begin node type: " << node_idx << ", gpu_graph_device_keys[" - << i << "] = " << gpu_graph_device_keys[i].size(); for (size_t j = 0; j < gpu_graph_device_keys[i].size(); j++) { gpu_graph_total_keys_.push_back(gpu_graph_device_keys[i][j]); } - VLOG(2) << "end node type: " << node_idx << ", gpu_graph_device_keys[" - << i << "] = " << gpu_graph_device_keys[i].size(); } + + // add feature embedding id + VLOG(2) << "begin add feature_id into gpu_graph_total_keys_ size[" + << gpu_graph_total_keys_.size() << "]"; + for (auto& iter : node_to_id) { + std::vector> gpu_graph_device_keys; + int node_idx = iter.second; + gpu_graph_ptr->get_all_feature_ids( + 1, node_idx, thread_num_, &gpu_graph_device_keys); + for (size_t i = 0; i < gpu_graph_device_keys.size(); i++) { + VLOG(2) << "begin node type: " << node_idx + << ", gpu_graph_device_keys[" << i + << "] = " << gpu_graph_device_keys[i].size(); + for (size_t j = 0; j < gpu_graph_device_keys[i].size(); j++) { + gpu_graph_total_keys_.push_back(gpu_graph_device_keys[i][j]); + } + VLOG(2) << "end node type: " << node_idx << ", gpu_graph_device_keys[" + << i << "] = " << gpu_graph_device_keys[i].size(); + } + } + VLOG(2) << "end add feature_id into gpu_graph_total_keys_ size[" + << gpu_graph_total_keys_.size() << "]"; + } else if (FLAGS_gpugraph_storage_mode == GpuGraphStorageMode::CPU) { } - VLOG(2) << "end add feature_id into gpu_graph_total_keys_ size[" - << gpu_graph_total_keys_.size() << "]"; #endif } else { for (int64_t i = 0; i < thread_num_; ++i) { @@ -1780,6 +1765,7 @@ void SlotRecordDataset::CreateReaders() { readers_[i]->SetParseLogKey(parse_logkey_); readers_[i]->SetEnablePvMerge(enable_pv_merge_); readers_[i]->SetCurrentPhase(current_phase_); + readers_[i]->InitGraphResource(); if (input_channel_ != nullptr) { readers_[i]->SetInputChannel(input_channel_.get()); } diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu index 3693277a75d39..11865e83aa273 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu @@ -698,8 +698,14 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( int h_left[total_gpu]; // NOLINT int h_right[total_gpu]; // NOLINT - auto d_left = memory::Alloc(place, total_gpu * sizeof(int)); - auto d_right = memory::Alloc(place, total_gpu * sizeof(int)); + auto d_left = + memory::Alloc(place, + total_gpu * sizeof(int), + phi::Stream(reinterpret_cast(stream))); + auto d_right = + memory::Alloc(place, + total_gpu * sizeof(int), + phi::Stream(reinterpret_cast(stream))); int* d_left_ptr = reinterpret_cast(d_left->ptr()); int* d_right_ptr = reinterpret_cast(d_right->ptr()); int default_value = 0; @@ -710,15 +716,26 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( CUDA_CHECK(cudaMemsetAsync(d_left_ptr, -1, 
total_gpu * sizeof(int), stream)); CUDA_CHECK(cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream)); // - auto d_idx = memory::Alloc(place, len * sizeof(int)); + auto d_idx = + memory::Alloc(place, + len * sizeof(int), + phi::Stream(reinterpret_cast(stream))); int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); - auto d_shard_keys = memory::Alloc(place, len * sizeof(uint64_t)); + auto d_shard_keys = + memory::Alloc(place, + len * sizeof(uint64_t), + phi::Stream(reinterpret_cast(stream))); uint64_t* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); auto d_shard_vals = - memory::Alloc(place, sample_size * len * sizeof(uint64_t)); + memory::Alloc(place, + sample_size * len * sizeof(uint64_t), + phi::Stream(reinterpret_cast(stream))); uint64_t* d_shard_vals_ptr = reinterpret_cast(d_shard_vals->ptr()); - auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int)); + auto d_shard_actual_sample_size = + memory::Alloc(place, + len * sizeof(int), + phi::Stream(reinterpret_cast(stream))); int* d_shard_actual_sample_size_ptr = reinterpret_cast(d_shard_actual_sample_size->ptr()); @@ -919,8 +936,10 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( int total_sample_size = thrust::reduce(t_actual_sample_size.begin(), t_actual_sample_size.end()); - result.actual_val_mem = - memory::AllocShared(place, total_sample_size * sizeof(uint64_t)); + result.actual_val_mem = memory::AllocShared( + place, + total_sample_size * sizeof(uint64_t), + phi::Stream(reinterpret_cast(stream))); result.actual_val = (uint64_t*)(result.actual_val_mem)->ptr(); result.set_total_sample_size(total_sample_size); @@ -1011,23 +1030,40 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, int total_gpu = resource_->total_device(); auto stream = resource_->local_stream(gpu_id, 0); - auto d_left = memory::Alloc(place, total_gpu * sizeof(int)); - auto d_right = memory::Alloc(place, total_gpu * sizeof(int)); + auto d_left = + memory::Alloc(place, + total_gpu * sizeof(int), + phi::Stream(reinterpret_cast(stream))); + auto d_right = + memory::Alloc(place, + total_gpu * sizeof(int), + phi::Stream(reinterpret_cast(stream))); int* d_left_ptr = reinterpret_cast(d_left->ptr()); int* d_right_ptr = reinterpret_cast(d_right->ptr()); CUDA_CHECK(cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream)); CUDA_CHECK(cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream)); // - auto d_idx = memory::Alloc(place, node_num * sizeof(int)); + auto d_idx = + memory::Alloc(place, + node_num * sizeof(int), + phi::Stream(reinterpret_cast(stream))); int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); - auto d_shard_keys = memory::Alloc(place, node_num * sizeof(uint64_t)); + auto d_shard_keys = + memory::Alloc(place, + node_num * sizeof(uint64_t), + phi::Stream(reinterpret_cast(stream))); uint64_t* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); auto d_shard_vals = - memory::Alloc(place, slot_num * node_num * sizeof(uint64_t)); + memory::Alloc(place, + slot_num * node_num * sizeof(uint64_t), + phi::Stream(reinterpret_cast(stream))); uint64_t* d_shard_vals_ptr = reinterpret_cast(d_shard_vals->ptr()); - auto d_shard_actual_size = memory::Alloc(place, node_num * sizeof(int)); + auto d_shard_actual_size = + memory::Alloc(place, + node_num * sizeof(int), + phi::Stream(reinterpret_cast(stream))); int* d_shard_actual_size_ptr = reinterpret_cast(d_shard_actual_size->ptr()); diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu 
b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu index e2c6df3102aeb..8fa365e2c14a2 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu @@ -28,6 +28,72 @@ void GraphGpuWrapper::set_device(std::vector ids) { } } +void GraphGpuWrapper::init_conf(const std::string &first_node_type, + const std::string &meta_path) { + static std::mutex mutex; + { + std::lock_guard lock(mutex); + if (conf_initialized_) { + return; + } + VLOG(2) << "init path config"; + conf_initialized_ = true; + auto node_types = + paddle::string::split_string(first_node_type, ";"); + VLOG(2) << "node_types: " << first_node_type; + for (auto &type : node_types) { + auto iter = feature_to_id.find(type); + PADDLE_ENFORCE_NE(iter, + feature_to_id.end(), + platform::errors::NotFound( + "(%s) is not found in feature_to_id.", type)); + VLOG(2) << "feature_to_id[" << type << "] = " << iter->second; + first_node_type_.push_back(iter->second); + } + meta_path_.resize(first_node_type_.size()); + auto meta_paths = paddle::string::split_string(meta_path, ";"); + + for (size_t i = 0; i < meta_paths.size(); i++) { + auto path = meta_paths[i]; + auto nodes = paddle::string::split_string(path, "-"); + for (auto &node : nodes) { + auto iter = edge_to_id.find(node); + PADDLE_ENFORCE_NE(iter, + edge_to_id.end(), + platform::errors::NotFound( + "(%s) is not found in edge_to_id.", node)); + VLOG(2) << "edge_to_id[" << node << "] = " << iter->second; + meta_path_[i].push_back(iter->second); + } + } + int max_dev_id = 0; + for (size_t i = 0; i < device_id_mapping.size(); i++) { + if (device_id_mapping[i] > max_dev_id) { + max_dev_id = device_id_mapping[i]; + } + } + finish_node_type_.resize(max_dev_id + 1); + node_type_start_.resize(max_dev_id + 1); + infer_node_type_start_.resize(max_dev_id + 1); + for (size_t i = 0; i < device_id_mapping.size(); i++) { + int dev_id = device_id_mapping[i]; + auto &node_type_start = node_type_start_[i]; + auto &infer_node_type_start = infer_node_type_start_[i]; + auto &finish_node_type = finish_node_type_[i]; + finish_node_type.clear(); + // for (auto& kv : feature_to_id) { + // node_type_start[kv.second] = 0; + // infer_node_type_start[kv.second] = 0; + // } + for (auto &type : node_types) { + auto iter = feature_to_id.find(type); + node_type_start[iter->second] = 0; + infer_node_type_start[iter->second] = 0; + } + } + } +} + int GraphGpuWrapper::get_all_id(int type, int slice_num, std::vector> *output) { @@ -160,9 +226,9 @@ void GraphGpuWrapper::load_node_and_edge(std::string etype2files, std::string graph_data_local_path, int part_num, bool reverse) { - ((GpuPsGraphTable *)graph_table) - ->cpu_graph_table_->load_node_and_edge_file( - etype2files, ntype2files, graph_data_local_path, part_num, reverse); + ((GpuPsGraphTable *)graph_table) + ->cpu_graph_table_->load_node_and_edge_file( + etype2files, ntype2files, graph_data_local_path, part_num, reverse); } void GraphGpuWrapper::add_table_feat_conf(std::string table_name, diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h index 8ca3ee5899279..5998272e72539 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h @@ -22,6 +22,9 @@ namespace paddle { namespace framework { #ifdef PADDLE_WITH_HETERPS + +enum GpuGraphStorageMode { HBM = 1, CPU, MULTINODE }; + class GraphGpuWrapper { public: static std::shared_ptr 
GetInstance() {
@@ -31,6 +34,8 @@ class GraphGpuWrapper {
     return s_instance_;
   }
   static std::shared_ptr s_instance_;
+  void init_conf(const std::string& first_node_type,
+                 const std::string& meta_path);
   void initialize();
   void finalize();
   void set_device(std::vector ids);
@@ -116,6 +121,14 @@ class GraphGpuWrapper {
   int upload_num = 8;
   std::shared_ptr<::ThreadPool> upload_task_pool;
   std::string feature_separator_ = std::string(" ");
+  //
+  bool conf_initialized_ = false;
+  std::vector first_node_type_;
+  std::vector> meta_path_;
+
+  std::vector> finish_node_type_;
+  std::vector> node_type_start_;
+  std::vector> infer_node_type_start_;
 };
 #endif
 }  // namespace framework
diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc
index a26ed5dbdad8c..b271b44434dc9 100644
--- a/paddle/fluid/platform/flags.cc
+++ b/paddle/fluid/platform/flags.cc
@@ -833,6 +833,18 @@ PADDLE_DEFINE_EXPORTED_bool(
     false,
     "It controls whether exit trainer when an worker has no ins.");
 
+/**
+ * Distributed related FLAG
+ * Name: FLAGS_gpugraph_storage_mode
+ * Since Version: 2.2.0
+ * Value Range: int32, default=1
+ * Example:
+ * Note: Controls the gpugraph storage mode: 1 for full HBM, 2 for HBM + mem + SSD.
+ */
+PADDLE_DEFINE_EXPORTED_int32(gpugraph_storage_mode,
+                             1,
+                             "gpugraph storage mode, default 1");
+
 /**
  * KP kernel related FLAG
  * Name: FLAGS_run_kp_kernel
diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py
index d58770dd714ff..93acfb8042fbe 100755
--- a/python/paddle/distributed/fleet/base/distributed_strategy.py
+++ b/python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -531,7 +531,9 @@ def fleet_desc_configs(self, configs):
            'embed_sparse_beta2_decay_rate', 'embedx_sparse_optimizer', 'embedx_sparse_learning_rate', \
            'embedx_sparse_weight_bounds', 'embedx_sparse_initial_range', 'embedx_sparse_initial_g2sum', \
            'embedx_sparse_beta1_decay_rate', 'embedx_sparse_beta2_decay_rate', 'feature_learning_rate', 'nodeid_slot']
-        support_sparse_table_class = ['DownpourSparseTable']
+        support_sparse_table_class = [
+            'DownpourSparseTable', 'DownpourSparseSSDTable'
+        ]
         support_sparse_accessor_class = [
             'DownpourSparseValueAccessor', 'DownpourCtrAccessor',
             'DownpourCtrDoubleAccessor', 'DownpourUnitAccessor',
@@ -623,9 +625,12 @@ def set_sparse_table_config(table_data, config):
                                     "DownpourSparseTable")
             if table_class not in support_sparse_table_class:
                 raise ValueError(
-                    "support sparse_table_class: ['DownpourSparseTable'], but actual %s"
+                    "support sparse_table_class: ['DownpourSparseTable', 'DownpourSparseSSDTable'], but actual %s"
                     % (table_class))
-            table_data.table_class = 'MemorySparseTable'
+            if table_class == "DownpourSparseSSDTable":
+                table_data.table_class = 'SSDSparseTable'
+            else:
+                table_data.table_class = 'MemorySparseTable'
             table_data.shard_num = config.get('sparse_shard_num', 1000)
             accessor_class = config.get("sparse_accessor_class",
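
For reference, a minimal configuration sketch for the pieces this patch touches. The keys sparse_table_class and sparse_shard_num are the ones read by set_sparse_table_config() above, and DownpourSparseSSDTable is mapped to the C++ SSDSparseTable; passing them as a flat dict to fleet_desc_configs and setting FLAGS_gpugraph_storage_mode through the environment are assumptions and may need adjusting for a given deployment.

import os

# Assumed: exported Paddle flags are picked up from the environment at startup.
# Per the flag note above: 1 = full HBM, 2 = HBM + host memory + SSD.
os.environ["FLAGS_gpugraph_storage_mode"] = "2"

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
# Keys consumed by set_sparse_table_config(); 'DownpourSparseSSDTable' selects
# the SSD-backed table, other accessor-related keys follow the existing lists.
strategy.fleet_desc_configs = {
    "sparse_table_class": "DownpourSparseSSDTable",
    "sparse_shard_num": 1000,
    "sparse_accessor_class": "DownpourCtrAccessor",
}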