From 633c2d00808a14207e5d2db838fa6f180ce6ad46 Mon Sep 17 00:00:00 2001 From: YaoCheng8667 <951766143@qq.com> Date: Thu, 21 Mar 2024 02:37:56 +0000 Subject: [PATCH] modify interface for compress push --- paddle/fluid/framework/data_set.cc | 18 ++++++++++++++---- paddle/fluid/framework/fleet/box_wrapper.cc | 2 +- paddle/fluid/framework/fleet/box_wrapper.h | 2 +- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 66e4a6f665ed4..9fc170c13ed87 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -2808,18 +2808,23 @@ void PadBoxSlotDataset::PrepareTrain(void) { ->AddBatchOffset(offset[i]); } #ifdef PADDLE_WITH_XPU_KP - using BatchData = std::vector>; + using BatchData = std::vector>>; // devid -> dev_batch_data + VLOG(0) << "PadBoxSlotDataset::PrepareTrain with pv_merge offset size:" << offset.size() << ", thread_num:" << thread_num_; auto data_func = [this, offset] (int batch_idx, BatchData * out_data) { BatchData & batch_data = *out_data; batch_data.clear(); + batch_data.resize(thread_num_); + int offset_idx = batch_idx * thread_num_; CHECK(offset_idx + thread_num_ <= (int)offset.size()) << "offset_idx:" << offset_idx << ", thread_num_:" << thread_num_ << "offset.size:" << offset.size(); for (int j = 0; j < thread_num_; j++) { + int dev_id = j; + auto & dev_batch_data = batch_data[dev_id]; auto & offset_pair = offset[offset_idx + j]; for (int k = 0; k < offset_pair.second; k++) { auto & pv_ins = input_pv_ins_[offset_pair.first + k]->ads; @@ -2827,7 +2832,7 @@ void PadBoxSlotDataset::PrepareTrain(void) { for (auto & rec : pv_ins) { for (auto& idx : used_fea_index_) { uint64_t* feas = rec->slot_uint64_feasigns_.get_values(idx, &num); - batch_data.push_back(std::make_pair(feas, num)); + dev_batch_data.push_back(std::make_pair(feas, num)); } } } @@ -2857,25 +2862,30 @@ void PadBoxSlotDataset::PrepareTrain(void) { ->AddBatchOffset(offset[i]); } #ifdef PADDLE_WITH_XPU_KP - using BatchData = std::vector>; + using BatchData = std::vector>>; // devid -> dev_batch_data VLOG(0) << "PadBoxSlotDataset::PrepareTrain offset size:" << offset.size() << ", thread_num:" << thread_num_; auto data_func = [this, offset] (int batch_idx, BatchData * out_data) { BatchData & batch_data = *out_data; batch_data.clear(); + batch_data.resize(thread_num_); + int offset_idx = batch_idx * thread_num_; CHECK(offset_idx + thread_num_ <= (int)offset.size()) << "offset_idx:" << offset_idx << ", thread_num_:" << thread_num_ << "offset.size:" << offset.size(); for (int j = 0; j < thread_num_; j++) { + int dev_id = j; + auto & dev_batch_data = batch_data[dev_id]; + auto & offset_pair = offset[offset_idx + j]; for (int k = 0; k < offset_pair.second; k++) { auto & rec = input_records_[offset_pair.first + k]; size_t num = 0; for (auto& idx : used_fea_index_) { uint64_t* feas = rec->slot_uint64_feasigns_.get_values(idx, &num); - batch_data.push_back(std::make_pair(feas, num)); + dev_batch_data.push_back(std::make_pair(feas, num)); } } } diff --git a/paddle/fluid/framework/fleet/box_wrapper.cc b/paddle/fluid/framework/fleet/box_wrapper.cc index c5a740c118b6e..7386030ae217f 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.cc +++ b/paddle/fluid/framework/fleet/box_wrapper.cc @@ -1237,7 +1237,7 @@ void BoxWrapper::GetFeatureOffsetInfo(void) { #ifdef PADDLE_WITH_XPU_KP void BoxWrapper::SetDataFuncForCacheManager(int batch_num, - std::function>*)> data_func) { + std::function>>*)> data_func) { boxps_ptr_->SetDataFuncForCacheManager(batch_num, data_func, &fid2sign_map_); } diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index 1f6be13e904e4..4593b2b1eea24 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -682,7 +682,7 @@ class BoxWrapper { #ifdef PADDLE_WITH_XPU_KP void SetDataFuncForCacheManager(int batch_num, - std::function>*)> data_func); + std::function>>*)> data_func); int PrepareNextBatch(int dev_id); std::vector * GetFid2SginMap() { return fid2sign_map_; } #endif