From 2d413d46d2cf5241647d693d04649b00045da4f5 Mon Sep 17 00:00:00 2001 From: yangjunchao Date: Wed, 26 Jan 2022 21:03:25 +0800 Subject: [PATCH] add join test and update test --- paddle/fluid/framework/data_feed.cc | 32 +++++++++++++--------- paddle/fluid/framework/data_set.cc | 3 +- paddle/fluid/framework/data_set.h | 2 +- paddle/fluid/framework/fleet/box_wrapper.h | 5 ++-- paddle/fluid/pybind/box_helper_py.cc | 2 ++ 5 files changed, 27 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 441a84d1df975..7a9eb53262bd6 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -363,7 +363,7 @@ InMemoryDataFeed::InMemoryDataFeed() { this->parse_content_ = false; this->parse_logkey_ = false; this->enable_pv_merge_ = false; - this->current_phase_ = 1; // 1:join ;0:update + this->current_phase_ = 1; // 1:join ;0:update; 3:join_test; 2:update_test this->input_channel_ = nullptr; this->output_channel_ = nullptr; this->consume_channel_ = nullptr; @@ -1660,9 +1660,10 @@ bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() { bool PaddleBoxDataFeed::Start() { #ifdef _LINUX - int phase = GetCurrentPhase(); // join: 1, update: 0 + // join: 1, update: 0, join_test: 3, update_test: 2 + int phase = GetCurrentPhase(); this->CheckSetFileList(); - if (enable_pv_merge_ && phase == 1) { + if (enable_pv_merge_ && (phase == 1 || phase == 3)) { // join phase : input_pv_channel to output_pv_channel if (output_pv_channel_->Size() == 0 && input_pv_channel_->Size() != 0) { std::vector data; @@ -1691,9 +1692,10 @@ bool PaddleBoxDataFeed::Start() { int PaddleBoxDataFeed::Next() { #ifdef _LINUX - int phase = GetCurrentPhase(); // join: 1, update: 0 + // join: 1, update: 0, join_test: 3, update_test: 2 + int phase = GetCurrentPhase(); this->CheckStart(); - if (enable_pv_merge_ && phase == 1) { + if (enable_pv_merge_ && (phase == 1 || phase == 3)) { // join phase : output_pv_channel to consume_pv_channel CHECK(output_pv_channel_ != nullptr); CHECK(consume_pv_channel_ != nullptr); @@ -1824,8 +1826,9 @@ void PaddleBoxDataFeed::GetRankOffset(const std::vector& pv_vec, void PaddleBoxDataFeed::AssignFeedVar(const Scope& scope) { MultiSlotInMemoryDataFeed::AssignFeedVar(scope); // set rank offset memory - int phase = GetCurrentPhase(); // join: 1, update: 0 - if (enable_pv_merge_ && phase == 1) { + // join: 1, update: 0, join_test: 3, update_test: 2 + int phase = GetCurrentPhase(); + if (enable_pv_merge_ && (phase == 1 || phase == 3)) { rank_offset_ = scope.FindVar(rank_offset_name_)->GetMutable(); } } @@ -2324,13 +2327,14 @@ bool SlotPaddleBoxDataFeed::Start() { return true; } int SlotPaddleBoxDataFeed::Next() { - int phase = GetCurrentPhase(); // join: 1, update: 0 + // join: 1, update: 0, join_test: 3, update_test: 2 + int phase = GetCurrentPhase(); this->CheckStart(); if (offset_index_ >= static_cast(batch_offsets_.size())) { return 0; } auto& batch = batch_offsets_[offset_index_++]; - if (enable_pv_merge_ && phase == 1) { + if (enable_pv_merge_ && (phase == 1 || phase == 3)) { // join phase : output_pv_channel to consume_pv_channel this->batch_size_ = batch.second; if (this->batch_size_ != 0) { @@ -2346,7 +2350,7 @@ int SlotPaddleBoxDataFeed::Next() { batch_timer_.Resume(); PutToFeedSlotVec(&records_[batch.first], this->batch_size_); // update set join q value - if (phase == 0 && FLAGS_padbox_slotrecord_extend_dim > 0) { + if ((phase == 0 || phase == 2) && FLAGS_padbox_slotrecord_extend_dim > 0) { // pcoc pack_->pack_qvalue(); } @@ -2355,7 +2359,8 @@ int SlotPaddleBoxDataFeed::Next() { } } bool SlotPaddleBoxDataFeed::EnablePvMerge(void) { - return (enable_pv_merge_ && GetCurrentPhase() == 1); + return (enable_pv_merge_ && + (GetCurrentPhase() == 1 || GetCurrentPhase() == 3)); } int SlotPaddleBoxDataFeed::GetPackInstance(SlotRecord** ins) { if (offset_index_ >= static_cast(batch_offsets_.size())) { @@ -2380,8 +2385,9 @@ void SlotPaddleBoxDataFeed::AssignFeedVar(const Scope& scope) { scope.FindVar(used_slots_info_[i].slot)->GetMutable(); } // set rank offset memory - int phase = GetCurrentPhase(); // join: 1, update: 0 - if (enable_pv_merge_ && phase == 1) { + // join: 1, update: 0, join_test: 3, update_test: 2 + int phase = GetCurrentPhase(); + if (enable_pv_merge_ && (phase == 1 || phase == 3)) { rank_offset_ = scope.FindVar(rank_offset_name_)->GetMutable(); } } diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 1234d941710ae..4f5065ac340f8 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -2366,7 +2366,8 @@ void PadBoxSlotDataset::PrepareTrain(void) { std::vector> offset; // join or aucrunner mode enable pv - if (enable_pv_merge_ && (box_ptr->Phase() == 1 || box_ptr->Mode() == 1)) { + if (enable_pv_merge_ && (box_ptr->Phase() == 1 || + box_ptr->Phase() == 3 || box_ptr->Mode() == 1)) { std::shuffle(input_pv_ins_.begin(), input_pv_ins_.end(), BoxWrapper::LocalRandomEngine()); // 分数据到各线程里面 diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index 74a54fb9fe247..f06aa6299b2a8 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -304,7 +304,7 @@ class DatasetImpl : public Dataset { bool parse_logkey_; bool merge_by_sid_; bool enable_pv_merge_; // True means to merge pv - int current_phase_; // 1 join, 0 update + int current_phase_; // 1 join, 0 update, 3 join_test, 2 update_test size_t merge_size_; bool slots_shuffle_fea_eval_ = false; bool gen_uni_feasigns_ = false; diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index eccf16fa506b1..9977a52910013 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -522,8 +522,8 @@ class BoxWrapper { } else if (s_instance_->feature_type_ == static_cast(boxps::FEATURE_PCOC)) { s_instance_->cvm_offset_ = 8; - } else if (s_instance_->feature_type_ == - static_cast(boxps::FEATURE_CONV)) { + } else if (s_instance_->feature_type_ + == static_cast(boxps::FEATURE_CONV)) { s_instance_->cvm_offset_ = 4; } else { s_instance_->cvm_offset_ = 3; @@ -622,6 +622,7 @@ class BoxWrapper { int Phase() const { return phase_; } int PhaseNum() const { return phase_num_; } void FlipPhase() { phase_ = (phase_ + 1) % phase_num_; } + void SetPhase(int phase) {phase_ = phase; } const std::map GetLRMap() const { return lr_map_; } std::map& GetMetricList() { return metric_lists_; } diff --git a/paddle/fluid/pybind/box_helper_py.cc b/paddle/fluid/pybind/box_helper_py.cc index 5e76e4fa6e27e..ecc4ab871fbb4 100644 --- a/paddle/fluid/pybind/box_helper_py.cc +++ b/paddle/fluid/pybind/box_helper_py.cc @@ -98,6 +98,8 @@ void BindBoxWrapper(py::module* m) { py::call_guard()) .def("flip_phase", &framework::BoxWrapper::FlipPhase, py::call_guard()) + .def("set_phase", &framework::BoxWrapper::SetPhase, + py::call_guard()) .def("init_afs_api", &framework::BoxWrapper::InitAfsAPI, py::call_guard()) .def("finalize", &framework::BoxWrapper::Finalize,