From 2ae4c26f5fa680f60f9c0216247adfcecd31623c Mon Sep 17 00:00:00 2001
From: Siming Dai <908660116@qq.com>
Date: Mon, 26 Sep 2022 20:07:48 +0800
Subject: [PATCH] add graphsage slot feature (#126)

---
 paddle/fluid/framework/data_feed.cu | 113 ++++++++++++++++------------
 1 file changed, 64 insertions(+), 49 deletions(-)

diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu
index b43118e358008..336a0f412bb88 100644
--- a/paddle/fluid/framework/data_feed.cu
+++ b/paddle/fluid/framework/data_feed.cu
@@ -737,6 +737,8 @@ int GraphDataGenerator::GenerateBatch() {
   int total_instance = 0;
   platform::CUDADeviceGuard guard(gpuid_);
   int res = 0;
+
+  std::shared_ptr<phi::Allocation> final_sage_nodes;
   if (!gpu_graph_training_) {
     while (cursor_ < h_device_keys_.size()) {
       size_t device_key_size = h_device_keys_[cursor_]->size();
@@ -776,7 +778,6 @@ int GraphDataGenerator::GenerateBatch() {
                            0,
                            stream_>>>(clk_tensor_ptr_, total_instance);
       } else {
-
        auto node_buf = memory::AllocShared(
            place_, total_instance * sizeof(uint64_t));
        int64_t* node_buf_ptr = reinterpret_cast<int64_t*>(node_buf->ptr());
@@ -789,7 +790,7 @@ int GraphDataGenerator::GenerateBatch() {
        phi::DenseTensor inverse_;
        VLOG(1) << "generate sample graph";
        uint64_t* node_buf_ptr_ = reinterpret_cast<uint64_t*>(node_buf->ptr());
-       std::shared_ptr<phi::Allocation> final_infer_nodes =
+       final_sage_nodes =
           GenerateSampleGraph(node_buf_ptr_, total_instance, &uniq_instance_,
                               &inverse_);
        id_tensor_ptr_ =
@@ -803,7 +804,7 @@ int GraphDataGenerator::GenerateBatch() {
           feed_vec_[index_offset]->mutable_data<int>({total_instance}, this->place_);
        VLOG(1) << "copy id and index";
-       cudaMemcpy(id_tensor_ptr_, final_infer_nodes->ptr(),
+       cudaMemcpy(id_tensor_ptr_, final_sage_nodes->ptr(),
                  sizeof(int64_t) * uniq_instance_,
                  cudaMemcpyDeviceToDevice);
        cudaMemcpy(index_tensor_ptr_, inverse_.data<int>(), sizeof(int) * total_instance,
@@ -840,31 +841,7 @@ int GraphDataGenerator::GenerateBatch() {
     total_instance *= 2;
   }
 
-  int64_t *slot_tensor_ptr_[slot_num_];
-  int64_t *slot_lod_tensor_ptr_[slot_num_];
-  if (slot_num_ > 0) {
-    for (int i = 0; i < slot_num_; ++i) {
-      slot_tensor_ptr_[i] = feed_vec_[3 + 2 * i]->mutable_data<int64_t>(
-          {total_instance * h_slot_feature_num_map_[i], 1}, this->place_);
-      slot_lod_tensor_ptr_[i] = feed_vec_[3 + 2 * i + 1]->mutable_data<int64_t>(
-          {total_instance + 1}, this->place_);
-    }
-    if (FLAGS_enable_opt_get_features || !gpu_graph_training_) {
-      cudaMemcpyAsync(d_slot_tensor_ptr_->ptr(),
-                      slot_tensor_ptr_,
-                      sizeof(uint64_t *) * slot_num_,
-                      cudaMemcpyHostToDevice,
-                      stream_);
-      cudaMemcpyAsync(d_slot_lod_tensor_ptr_->ptr(),
-                      slot_lod_tensor_ptr_,
-                      sizeof(uint64_t *) * slot_num_,
-                      cudaMemcpyHostToDevice,
-                      stream_);
-    }
-  }
-
   uint64_t *ins_cursor, *ins_buf;
-  std::shared_ptr<phi::Allocation> final_nodes;
   phi::DenseTensor inverse;
   if (gpu_graph_training_) {
     VLOG(2) << "total_instance: " << total_instance
@@ -893,7 +870,7 @@ int GraphDataGenerator::GenerateBatch() {
                        stream_>>>(clk_tensor_ptr_, total_instance);
   } else {
     VLOG(2) << gpuid_ << " " << "Ready to enter GenerateSampleGraph";
-    final_nodes = GenerateSampleGraph(ins_cursor, total_instance, &uniq_instance_,
+    final_sage_nodes = GenerateSampleGraph(ins_cursor, total_instance, &uniq_instance_,
                                       &inverse);
     VLOG(2) << "Copy Final Results";
     id_tensor_ptr_ =
@@ -907,7 +884,7 @@ int GraphDataGenerator::GenerateBatch() {
         feed_vec_[index_offset]->mutable_data<int>({total_instance}, this->place_);
     cudaMemcpyAsync(id_tensor_ptr_,
-                    final_nodes->ptr(),
+                    final_sage_nodes->ptr(),
                     sizeof(int64_t) * uniq_instance_,
                     cudaMemcpyDeviceToDevice,
                     stream_);
@@ -930,23 +907,60 @@ int GraphDataGenerator::GenerateBatch() {
     ins_cursor = (uint64_t *)id_tensor_ptr_;  // NOLINT
   }
 
+  int64_t *slot_tensor_ptr_[slot_num_];
+  int64_t *slot_lod_tensor_ptr_[slot_num_];
   if (slot_num_ > 0) {
+    int slot_instance = sage_mode_ == true ? uniq_instance_ : total_instance;
+    for (int i = 0; i < slot_num_; ++i) {
+      slot_tensor_ptr_[i] = feed_vec_[3 + 2 * i]->mutable_data<int64_t>(
+          {slot_instance * h_slot_feature_num_map_[i], 1}, this->place_);
+      slot_lod_tensor_ptr_[i] = feed_vec_[3 + 2 * i + 1]->mutable_data<int64_t>(
+          {slot_instance + 1}, this->place_);
+    }
+    if (FLAGS_enable_opt_get_features || !gpu_graph_training_) {
+      cudaMemcpyAsync(d_slot_tensor_ptr_->ptr(),
+                      slot_tensor_ptr_,
+                      sizeof(uint64_t *) * slot_num_,
+                      cudaMemcpyHostToDevice,
+                      stream_);
+      cudaMemcpyAsync(d_slot_lod_tensor_ptr_->ptr(),
+                      slot_lod_tensor_ptr_,
+                      sizeof(uint64_t *) * slot_num_,
+                      cudaMemcpyHostToDevice,
+                      stream_);
+    }
+    if (sage_mode_) {
+      d_feature_buf_ =
+          memory::AllocShared(place_, slot_instance * slot_num_ * sizeof(uint64_t));
+    }
     uint64_t *feature_buf = reinterpret_cast<uint64_t *>(d_feature_buf_->ptr());
     if (FLAGS_enable_opt_get_features || !gpu_graph_training_) {
-      FillFeatureBuf(ins_cursor, feature_buf, total_instance);
-      // FillFeatureBuf(id_tensor_ptr_, feature_buf, total_instance);
+      if (!sage_mode_) {
+        FillFeatureBuf(ins_cursor, feature_buf, slot_instance);
+      } else {
+        uint64_t* sage_nodes_ptr = reinterpret_cast<uint64_t*>(final_sage_nodes->ptr());
+        FillFeatureBuf(sage_nodes_ptr, feature_buf, slot_instance);
+      }
       if (debug_mode_) {
-        uint64_t h_walk[total_instance];  // NOLINT
-        cudaMemcpy(h_walk,
-                   ins_cursor,
-                   total_instance * sizeof(uint64_t),
-                   cudaMemcpyDeviceToHost);
-        uint64_t h_feature[total_instance * fea_num_per_node_];
+        uint64_t h_walk[slot_instance];
+        if (!sage_mode_) {
+          cudaMemcpy(h_walk,
+                     ins_cursor,
+                     slot_instance * sizeof(uint64_t),
+                     cudaMemcpyDeviceToHost);
+        } else {
+          uint64_t* sage_nodes_ptr = reinterpret_cast<uint64_t*>(final_sage_nodes->ptr());
+          cudaMemcpy(h_walk,
+                     sage_nodes_ptr,
+                     slot_instance * sizeof(uint64_t),
+                     cudaMemcpyDeviceToHost);
+        }
+        uint64_t h_feature[slot_instance * fea_num_per_node_];
         cudaMemcpy(h_feature,
                    feature_buf,
-                   total_instance * fea_num_per_node_ * sizeof(uint64_t),
+                   slot_instance * fea_num_per_node_ * sizeof(uint64_t),
                    cudaMemcpyDeviceToHost);
-        for (int i = 0; i < total_instance; ++i) {
+        for (int i = 0; i < slot_instance; ++i) {
           std::stringstream ss;
           for (int j = 0; j < fea_num_per_node_; ++j) {
             ss << h_feature[i * fea_num_per_node_ + j] << " ";
@@ -957,26 +971,25 @@ int GraphDataGenerator::GenerateBatch() {
                   << "] = " << ss.str();
         }
       }
-
-      GraphFillSlotKernel<<<GET_BLOCKS(total_instance * fea_num_per_node_),
+      GraphFillSlotKernel<<<GET_BLOCKS(slot_instance * fea_num_per_node_),
                             CUDA_NUM_THREADS,
                             0,
                             stream_>>>((uint64_t *)d_slot_tensor_ptr_->ptr(),
                                        feature_buf,
-                                       total_instance * fea_num_per_node_,
-                                       total_instance,
+                                       slot_instance * fea_num_per_node_,
+                                       slot_instance,
                                        slot_num_,
                                        (int*)d_slot_feature_num_map_->ptr(),
                                        fea_num_per_node_,
                                        (int*)d_actual_slot_id_map_->ptr(),
                                        (int*)d_fea_offset_map_->ptr());
-      GraphFillSlotLodKernelOpt<<<GET_BLOCKS((total_instance + 1) * slot_num_),
-                                  CUDA_NUM_THREADS,
-                                  0,
-                                  stream_>>>(
-          (uint64_t *)d_slot_lod_tensor_ptr_->ptr(),  // NOLINT
-          (total_instance + 1) * slot_num_,
-          total_instance + 1,
+      GraphFillSlotLodKernelOpt<<<GET_BLOCKS((slot_instance + 1) * slot_num_),
+                                  CUDA_NUM_THREADS,
+                                  0,
+                                  stream_>>>(
+          (uint64_t *)d_slot_lod_tensor_ptr_->ptr(),
+          (slot_instance + 1) * slot_num_,
+          slot_instance + 1,
           (int*)d_slot_feature_num_map_->ptr());
     } else {
       for (int i = 0; i < slot_num_; ++i) {
@@ -1519,8 +1532,10 @@ void GraphDataGenerator::AllocResource(
   d_ins_buf_ =
       memory::AllocShared(place_, (batch_size_ * 2 * 2) * sizeof(uint64_t));
   if (slot_num_ > 0) {
-    d_feature_buf_ = memory::AllocShared(
-        place_, (batch_size_ * 2 * 2) * slot_num_ * sizeof(uint64_t));
+    if (!sage_mode_) {
+      d_feature_buf_ = memory::AllocShared(
+          place_, (batch_size_ * 2 * 2) * slot_num_ * sizeof(uint64_t));
+    }
   }
   d_pair_num_ = memory::AllocShared(place_, sizeof(int));
   if (FLAGS_enable_opt_get_features && slot_num_ > 0) {
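
Note for reviewers (not part of the patch): a minimal standalone sketch of the sizing
rule this change introduces. With sage_mode_ on, the slot feature tensors and the
staging buffer are sized by the deduplicated node count that GenerateSampleGraph
returns (uniq_instance_) rather than by the walk-based total_instance; since that
count is only known after sampling, d_feature_buf_ can no longer be pre-allocated in
AllocResource and is instead allocated per batch inside GenerateBatch. SlotBufferPlan
and PlanSlotBuffers below are hypothetical names used for illustration only.

    #include <cstddef>
    #include <cstdint>

    struct SlotBufferPlan {
      int slot_instance;         // rows of each slot tensor in the feed
      size_t feature_buf_bytes;  // size of the per-batch feature staging buffer
    };

    // Mirrors the patch: slot_instance = sage_mode ? uniq_instance : total_instance,
    // and the staging buffer holds slot_instance * slot_num uint64 feature ids.
    // In sage mode uniq_instance is produced by sampling, so this plan can only be
    // computed per batch, not once up front.
    inline SlotBufferPlan PlanSlotBuffers(bool sage_mode,
                                          int total_instance,
                                          int uniq_instance,
                                          int slot_num) {
      SlotBufferPlan plan;
      plan.slot_instance = sage_mode ? uniq_instance : total_instance;
      plan.feature_buf_bytes =
          static_cast<size_t>(plan.slot_instance) * slot_num * sizeof(uint64_t);
      return plan;
    }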