Skip to content

Commit

Permalink
Merge pull request #28 from chao9527/paddlebox
Browse files Browse the repository at this point in the history
add join test and update test
  • Loading branch information
qingshui authored Jan 27, 2022
2 parents 6ba0254 + 2d413d4 commit 349bc2d
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 17 deletions.
32 changes: 19 additions & 13 deletions paddle/fluid/framework/data_feed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
this->parse_content_ = false;
this->parse_logkey_ = false;
this->enable_pv_merge_ = false;
this->current_phase_ = 1; // 1:join ;0:update
this->current_phase_ = 1; // 1:join ;0:update; 3:join_test; 2:update_test
this->input_channel_ = nullptr;
this->output_channel_ = nullptr;
this->consume_channel_ = nullptr;
Expand Down Expand Up @@ -1660,9 +1660,10 @@ bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() {

bool PaddleBoxDataFeed::Start() {
#ifdef _LINUX
int phase = GetCurrentPhase(); // join: 1, update: 0
// join: 1, update: 0, join_test: 3, update_test: 2
int phase = GetCurrentPhase();
this->CheckSetFileList();
if (enable_pv_merge_ && phase == 1) {
if (enable_pv_merge_ && (phase == 1 || phase == 3)) {
// join phase : input_pv_channel to output_pv_channel
if (output_pv_channel_->Size() == 0 && input_pv_channel_->Size() != 0) {
std::vector<PvInstance> data;
Expand Down Expand Up @@ -1691,9 +1692,10 @@ bool PaddleBoxDataFeed::Start() {

int PaddleBoxDataFeed::Next() {
#ifdef _LINUX
int phase = GetCurrentPhase(); // join: 1, update: 0
// join: 1, update: 0, join_test: 3, update_test: 2
int phase = GetCurrentPhase();
this->CheckStart();
if (enable_pv_merge_ && phase == 1) {
if (enable_pv_merge_ && (phase == 1 || phase == 3)) {
// join phase : output_pv_channel to consume_pv_channel
CHECK(output_pv_channel_ != nullptr);
CHECK(consume_pv_channel_ != nullptr);
Expand Down Expand Up @@ -1824,8 +1826,9 @@ void PaddleBoxDataFeed::GetRankOffset(const std::vector<PvInstance>& pv_vec,
void PaddleBoxDataFeed::AssignFeedVar(const Scope& scope) {
MultiSlotInMemoryDataFeed::AssignFeedVar(scope);
// set rank offset memory
int phase = GetCurrentPhase(); // join: 1, update: 0
if (enable_pv_merge_ && phase == 1) {
// join: 1, update: 0, join_test: 3, update_test: 2
int phase = GetCurrentPhase();
if (enable_pv_merge_ && (phase == 1 || phase == 3)) {
rank_offset_ = scope.FindVar(rank_offset_name_)->GetMutable<LoDTensor>();
}
}
Expand Down Expand Up @@ -2324,13 +2327,14 @@ bool SlotPaddleBoxDataFeed::Start() {
return true;
}
int SlotPaddleBoxDataFeed::Next() {
int phase = GetCurrentPhase(); // join: 1, update: 0
// join: 1, update: 0, join_test: 3, update_test: 2
int phase = GetCurrentPhase();
this->CheckStart();
if (offset_index_ >= static_cast<int>(batch_offsets_.size())) {
return 0;
}
auto& batch = batch_offsets_[offset_index_++];
if (enable_pv_merge_ && phase == 1) {
if (enable_pv_merge_ && (phase == 1 || phase == 3)) {
// join phase : output_pv_channel to consume_pv_channel
this->batch_size_ = batch.second;
if (this->batch_size_ != 0) {
Expand All @@ -2346,7 +2350,7 @@ int SlotPaddleBoxDataFeed::Next() {
batch_timer_.Resume();
PutToFeedSlotVec(&records_[batch.first], this->batch_size_);
// update set join q value
if (phase == 0 && FLAGS_padbox_slotrecord_extend_dim > 0) {
if ((phase == 0 || phase == 2) && FLAGS_padbox_slotrecord_extend_dim > 0) {
// pcoc
pack_->pack_qvalue();
}
Expand All @@ -2355,7 +2359,8 @@ int SlotPaddleBoxDataFeed::Next() {
}
}
bool SlotPaddleBoxDataFeed::EnablePvMerge(void) {
return (enable_pv_merge_ && GetCurrentPhase() == 1);
return (enable_pv_merge_ &&
(GetCurrentPhase() == 1 || GetCurrentPhase() == 3));
}
int SlotPaddleBoxDataFeed::GetPackInstance(SlotRecord** ins) {
if (offset_index_ >= static_cast<int>(batch_offsets_.size())) {
Expand All @@ -2380,8 +2385,9 @@ void SlotPaddleBoxDataFeed::AssignFeedVar(const Scope& scope) {
scope.FindVar(used_slots_info_[i].slot)->GetMutable<LoDTensor>();
}
// set rank offset memory
int phase = GetCurrentPhase(); // join: 1, update: 0
if (enable_pv_merge_ && phase == 1) {
// join: 1, update: 0, join_test: 3, update_test: 2
int phase = GetCurrentPhase();
if (enable_pv_merge_ && (phase == 1 || phase == 3)) {
rank_offset_ = scope.FindVar(rank_offset_name_)->GetMutable<LoDTensor>();
}
}
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/framework/data_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2366,7 +2366,8 @@ void PadBoxSlotDataset::PrepareTrain(void) {

std::vector<std::pair<int, int>> offset;
// join or aucrunner mode enable pv
if (enable_pv_merge_ && (box_ptr->Phase() == 1 || box_ptr->Mode() == 1)) {
if (enable_pv_merge_ && (box_ptr->Phase() == 1 ||
box_ptr->Phase() == 3 || box_ptr->Mode() == 1)) {
std::shuffle(input_pv_ins_.begin(), input_pv_ins_.end(),
BoxWrapper::LocalRandomEngine());
// Distribute the data across the threads
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/data_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ class DatasetImpl : public Dataset {
bool parse_logkey_;
bool merge_by_sid_;
bool enable_pv_merge_; // True means to merge pv
int current_phase_; // 1 join, 0 update
int current_phase_; // 1 join, 0 update, 3 join_test, 2 update_test
size_t merge_size_;
bool slots_shuffle_fea_eval_ = false;
bool gen_uni_feasigns_ = false;
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/framework/fleet/box_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -522,8 +522,8 @@ class BoxWrapper {
} else if (s_instance_->feature_type_ ==
static_cast<int>(boxps::FEATURE_PCOC)) {
s_instance_->cvm_offset_ = 8;
} else if (s_instance_->feature_type_ ==
static_cast<int>(boxps::FEATURE_CONV)) {
} else if (s_instance_->feature_type_
== static_cast<int>(boxps::FEATURE_CONV)) {
s_instance_->cvm_offset_ = 4;
} else {
s_instance_->cvm_offset_ = 3;
Expand Down Expand Up @@ -622,6 +622,7 @@ class BoxWrapper {
// Returns the current phase id stored in phase_.
// NOTE(review): presumably uses the same encoding as the data feed's
// current_phase_ (1 = join, 0 = update, 3 = join_test, 2 = update_test) —
// confirm against the callers.
int Phase() const { return phase_; }
// Returns phase_num_, the modulus FlipPhase() applies when advancing phase_.
int PhaseNum() const { return phase_num_; }
// Advances phase_ to the next value, wrapping around at phase_num_.
// NOTE(review): assumes phase_num_ is nonzero (the modulo would otherwise be
// undefined) — confirm where phase_num_ is initialized. If phase_ was set to a
// value >= phase_num_, the result is still reduced modulo phase_num_.
void FlipPhase() { phase_ = (phase_ + 1) % phase_num_; }
// Overwrites phase_ with the given value. No range check is done here, so the
// caller is responsible for passing a phase id the rest of the pipeline
// understands.
void SetPhase(int phase) {phase_ = phase; }
// Returns lr_map_ by value, so every call copies the map.
// NOTE(review): the const-qualified by-value return also prevents callers from
// moving out of the result; a const reference may be preferable if no caller
// relies on getting a snapshot — confirm before changing.
const std::map<std::string, float> GetLRMap() const { return lr_map_; }
// Returns a mutable reference to metric_lists_; callers can read or modify the
// container in place.
std::map<std::string, MetricMsg*>& GetMetricList() { return metric_lists_; }

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/box_helper_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ void BindBoxWrapper(py::module* m) {
py::call_guard<py::gil_scoped_release>())
.def("flip_phase", &framework::BoxWrapper::FlipPhase,
py::call_guard<py::gil_scoped_release>())
.def("set_phase", &framework::BoxWrapper::SetPhase,
py::call_guard<py::gil_scoped_release>())
.def("init_afs_api", &framework::BoxWrapper::InitAfsAPI,
py::call_guard<py::gil_scoped_release>())
.def("finalize", &framework::BoxWrapper::Finalize,
Expand Down

0 comments on commit 349bc2d

Please sign in to comment.