diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index c68ba524cebef..91a03663a7f35 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -2691,19 +2691,18 @@ void SlotPaddleBoxDataFeed::LoadIntoMemoryByLine(void) { std::vector record_vec; platform::Timer timeline; timeline.Start(); - const int max_fetch_num = 10000; int offset = 0; int old_offset = 0; - SlotRecordPool().get(&record_vec, max_fetch_num); + SlotRecordPool().get(&record_vec, OBJPOOL_BLOCK_SIZE); // get slotrecord object function - auto record_func = [this, &offset, &record_vec, &max_fetch_num, - &old_offset](std::vector& vec, int num) { + auto record_func = [this, &offset, &record_vec, &old_offset]( + std::vector& vec, int num) { vec.resize(num); - if (offset + num > max_fetch_num) { + if (offset + num > OBJPOOL_BLOCK_SIZE) { input_channel_->WriteMove(offset, &record_vec[0]); SlotRecordPool().get(&record_vec[0], offset); - record_vec.resize(max_fetch_num); + record_vec.resize(OBJPOOL_BLOCK_SIZE); offset = 0; old_offset = 0; } @@ -2715,8 +2714,8 @@ void SlotPaddleBoxDataFeed::LoadIntoMemoryByLine(void) { offset = offset + num; }; - line_func = [this, &parser, &record_vec, &offset, &max_fetch_num, &filename, - &record_func, &old_offset](const std::string& line) { + line_func = [this, &parser, &record_vec, &offset, &filename, &record_func, + &old_offset](const std::string& line) { old_offset = offset; if (!parser->ParseOneInstance(line, record_func)) { offset = old_offset; @@ -2724,10 +2723,10 @@ void SlotPaddleBoxDataFeed::LoadIntoMemoryByLine(void) { << line << "]"; return false; } - if (offset >= max_fetch_num) { + if (offset >= OBJPOOL_BLOCK_SIZE) { input_channel_->Write(std::move(record_vec)); record_vec.clear(); - SlotRecordPool().get(&record_vec, max_fetch_num); + SlotRecordPool().get(&record_vec, OBJPOOL_BLOCK_SIZE); offset = 0; } return true; @@ -2756,8 +2755,9 @@ void SlotPaddleBoxDataFeed::LoadIntoMemoryByLine(void) { } while (line_reader.is_error()); if (offset > 0) { input_channel_->WriteMove(offset, &record_vec[0]); - if (offset < max_fetch_num) { - SlotRecordPool().put(&record_vec[offset], (max_fetch_num - offset)); + if (offset < OBJPOOL_BLOCK_SIZE) { + SlotRecordPool().put(&record_vec[offset], + (OBJPOOL_BLOCK_SIZE - offset)); } } else { SlotRecordPool().put(&record_vec); @@ -2881,8 +2881,7 @@ void SlotPaddleBoxDataFeed::LoadIntoMemoryByCommand(void) { std::vector record_vec; platform::Timer timeline; timeline.Start(); - int max_fetch_num = 10000; - SlotRecordPool().get(&record_vec, max_fetch_num); + SlotRecordPool().get(&record_vec, OBJPOOL_BLOCK_SIZE); int offset = 0; do { @@ -2898,8 +2897,7 @@ void SlotPaddleBoxDataFeed::LoadIntoMemoryByCommand(void) { lines = line_reader.read_file( this->fp_.get(), - [this, &record_vec, &offset, &max_fetch_num, - &filename](const std::string& line) { + [this, &record_vec, &offset, &filename](const std::string& line) { if (ParseOneInstance(line, &record_vec[offset])) { ++offset; } else { @@ -2907,10 +2905,10 @@ void SlotPaddleBoxDataFeed::LoadIntoMemoryByCommand(void) { << "] item error, line:[" << line << "]"; return false; } - if (offset >= max_fetch_num) { + if (offset >= OBJPOOL_BLOCK_SIZE) { input_channel_->Write(std::move(record_vec)); record_vec.clear(); - SlotRecordPool().get(&record_vec, max_fetch_num); + SlotRecordPool().get(&record_vec, OBJPOOL_BLOCK_SIZE); offset = 0; } return true; @@ -2919,8 +2917,9 @@ void SlotPaddleBoxDataFeed::LoadIntoMemoryByCommand(void) { } while (line_reader.is_error()); if (offset > 0) { input_channel_->WriteMove(offset, &record_vec[0]); - if (offset < max_fetch_num) { - SlotRecordPool().put(&record_vec[offset], (max_fetch_num - offset)); + if (offset < OBJPOOL_BLOCK_SIZE) { + SlotRecordPool().put(&record_vec[offset], + (OBJPOOL_BLOCK_SIZE - offset)); } } else { SlotRecordPool().put(&record_vec); @@ -3096,7 +3095,7 @@ void SlotPaddleBoxDataFeedWithGpuReplicaCache::LoadIntoMemoryByLib(void) { std::vector record_vec; platform::Timer timeline; timeline.Start(); - const int max_fetch_num = 10000; + const int max_fetch_num = OBJPOOL_BLOCK_SIZE; int offset = 0; SlotRecordPool().get(&record_vec, max_fetch_num); @@ -3206,7 +3205,7 @@ void SlotPaddleBoxDataFeedWithGpuReplicaCache::LoadIntoMemoryByCommand(void) { timeline.Start(); int offset = 0; int gpu_cache_offset; - int max_fetch_num = 10000; + int max_fetch_num = OBJPOOL_BLOCK_SIZE; SlotRecordPool().get(&record_vec, max_fetch_num); do { if (box_ptr->UseAfsApi()) { @@ -3401,7 +3400,7 @@ void InputTableDataFeed::LoadIntoMemoryByLib() { std::vector record_vec; platform::Timer timeline; timeline.Start(); - const int max_fetch_num = 10000; + const int max_fetch_num = OBJPOOL_BLOCK_SIZE; int offset = 0; SlotRecordPool().get(&record_vec, max_fetch_num); diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 777b2974269f9..0894778e4ddfc 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -929,11 +929,12 @@ class SlotObjAllocator { Node* free_nodes_; // a list size_t capacity_; }; - +static const int OBJPOOL_BLOCK_SIZE = 10000; class SlotObjPool { public: SlotObjPool() : max_capacity_(FLAGS_padbox_record_pool_max_size) { ins_chan_ = MakeChannel(); + ins_chan_->SetBlockSize(OBJPOOL_BLOCK_SIZE); for (int i = 0; i < FLAGS_padbox_slotpool_thread_num; ++i) { threads_.push_back(std::thread([this]() { run(); })); } @@ -976,10 +977,9 @@ class SlotObjPool { } put(&(*input)[0], size); input->clear(); - input->shrink_to_fit(); } void put(SlotRecord* input, size_t size) { - ins_chan_->WriteMove(size, input); + CHECK(ins_chan_->WriteMove(size, input) == size); } void run(void) { std::vector input; diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index fc1160e6457b1..691cf001f43f6 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -1437,11 +1437,11 @@ PadBoxSlotDataset::~PadBoxSlotDataset() {} void PadBoxSlotDataset::CreateChannel() { if (input_channel_ == nullptr) { input_channel_ = MakeChannel(); - input_channel_->SetBlockSize(10240); + input_channel_->SetBlockSize(OBJPOOL_BLOCK_SIZE); } if (shuffle_channel_ == nullptr) { shuffle_channel_ = MakeChannel(); - shuffle_channel_->SetBlockSize(10240); + shuffle_channel_->SetBlockSize(OBJPOOL_BLOCK_SIZE); } } // set filelist, file_idx_ will reset to zero. @@ -1479,26 +1479,6 @@ inline paddle::framework::ThreadPool* GetShufflePool(int thread_num) { } return thread_pool.get(); } -int PadBoxSlotDataset::GetMaxShuffleThreadId(void) { - double rate = static_cast(shuffle_thread_num_) / - static_cast(thread_num_); - int thread_num = static_cast(rate * read_ins_ref_); - int half_num = static_cast(shuffle_thread_num_ >> 1); - if (thread_num < half_num) { - return half_num; - } - return thread_num; -} -int PadBoxSlotDataset::GetMaxMergeThreadId(void) { - double rate = - static_cast(merge_thread_num_) / static_cast(thread_num_); - int half_num = static_cast(merge_thread_num_ >> 1); - int thread_num = static_cast(rate * read_ins_ref_); - if (thread_num < half_num) { - return half_num; - } - return thread_num; -} void PadBoxSlotDataset::CheckThreadPool(void) { wait_futures_.clear(); if (thread_pool_ != nullptr && merge_pool_ != nullptr) { @@ -1511,10 +1491,10 @@ void PadBoxSlotDataset::CheckThreadPool(void) { // read ins thread thread_pool_ = GetThreadPool(thread_num_); // merge thread - merge_pool_ = GetMergePool(merge_thread_num_); + merge_pool_ = GetMergePool(merge_thread_num_ * 2); // shuffle thread if (!FLAGS_padbox_dataset_disable_shuffle && mpi_size_ > 1) { - shuffle_pool_ = GetShufflePool(shuffle_thread_num_); + shuffle_pool_ = GetShufflePool(shuffle_thread_num_ * 2); } std::vector& cores = boxps::get_readins_cores(); @@ -1639,17 +1619,18 @@ void PadBoxSlotDataset::LoadIntoMemory() { void PadBoxSlotDataset::MergeInsKeys(const Channel& in) { merge_ins_ref_ = merge_thread_num_; input_records_.clear(); + min_merge_ins_span_ = 1000; CHECK(p_agent_ != nullptr); for (int tid = 0; tid < merge_thread_num_; ++tid) { wait_futures_.emplace_back(merge_pool_->Run([this, &in, tid]() { // VLOG(0) << "merge thread id: " << tid << "start"; platform::Timer timer; - timer.Start(); auto feed_obj = reinterpret_cast(readers_[0].get()); size_t num = 0; std::vector datas; - while (in->ReadOnce(datas, 10240)) { + while (in->ReadOnce(datas, OBJPOOL_BLOCK_SIZE)) { + timer.Resume(); for (auto& rec : datas) { for (auto& idx : used_fea_index_) { uint64_t* feas = rec->slot_uint64_feasigns_.get_values(idx, &num); @@ -1666,18 +1647,15 @@ void PadBoxSlotDataset::MergeInsKeys(const Channel& in) { } merge_mutex_.unlock(); datas.clear(); - if (tid > GetMaxMergeThreadId()) { - break; - } + timer.Pause(); } datas.shrink_to_fit(); - timer.Pause(); double span = timer.ElapsedSec(); if (max_merge_ins_span_ < span) { max_merge_ins_span_ = span; } - if (min_merge_ins_span_ == 0 || min_merge_ins_span_ > span) { + if (min_merge_ins_span_ > span) { min_merge_ins_span_ = span; } // end merge thread @@ -1771,8 +1749,10 @@ void PadBoxSlotDataset::ShuffleData(int thread_num) { CHECK_GT(thread_num, 0); VLOG(3) << "start global shuffle threads, num = " << thread_num; shuffle_counter_ = thread_num; + min_shuffle_span_ = 1000; for (int tid = 0; tid < thread_num; ++tid) { wait_futures_.emplace_back(shuffle_pool_->Run([this, tid]() { + platform::Timer timer; std::vector data; std::vector loc_datas; std::vector releases; @@ -1780,7 +1760,8 @@ void PadBoxSlotDataset::ShuffleData(int thread_num) { PadBoxSlotDataConsumer* handler = reinterpret_cast(data_consumer_); ShuffleResultWaitGroup wg; - while (input_channel_->ReadOnce(data, 10240)) { + while (input_channel_->Read(data)) { + timer.Resume(); for (auto& t : data) { int client_id = 0; if (enable_pv_merge_) { // shuffle by pv @@ -1800,8 +1781,8 @@ void PadBoxSlotDataset::ShuffleData(int thread_num) { } SlotRecordPool().put(&releases); releases.clear(); - - shuffle_channel_->Write(std::move(loc_datas)); + size_t loc_len = loc_datas.size(); + CHECK(shuffle_channel_->Write(std::move(loc_datas)) == loc_len); wg.wait(); wg.add(mpi_size_); @@ -1821,13 +1802,21 @@ void PadBoxSlotDataset::ShuffleData(int thread_num) { data.clear(); loc_datas.clear(); - if (tid > GetMaxShuffleThreadId()) { - break; - } + timer.Pause(); } + timer.Resume(); wg.wait(); + timer.Pause(); - VLOG(3) << "end shuffle thread id = " << tid; + double span = timer.ElapsedSec(); + if (span > max_shuffle_span_) { + max_shuffle_span_ = span; + } + if (span < min_shuffle_span_) { + min_shuffle_span_ = span; + } + VLOG(3) << "passid = " << pass_id_ << ", end shuffle thread id=" << tid + << ", span: " << span; // only one thread send finish notify if (--shuffle_counter_ == 0) { // send closed @@ -1840,6 +1829,10 @@ void PadBoxSlotDataset::ShuffleData(int thread_num) { handler->send_message_callback(i, NULL, 0, &wg); } wg.wait(); + // end shuffle thread + LOG(WARNING) << "passid = " << pass_id_ + << ", end shuffle span max:" << max_shuffle_span_ + << ", min:" << min_shuffle_span_; // local closed channel if (--finished_counter_ == 0) { while (receiver_cnt_ > 0) { @@ -1867,6 +1860,7 @@ void PadBoxSlotDataset::ReceiveSuffleData(int client_id, const char* buf, --receiver_cnt_; if (finished_counter_ == 0) { + usleep(10000); while (receiver_cnt_ > 0) { usleep(100); } @@ -1881,8 +1875,8 @@ void PadBoxSlotDataset::ReceiveSuffleData(int client_id, const char* buf, paddle::framework::BinaryArchive ar; ar.SetReadBuffer(const_cast(buf), len, nullptr); + static const int max_fetch_num = OBJPOOL_BLOCK_SIZE / mpi_size_; int offset = 0; - const int max_fetch_num = 1000; std::vector data; SlotRecordPool().get(&data, max_fetch_num); while (ar.Cursor() < ar.Finish()) { diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index 082099c2f1b2e..a9fbb1832bf53 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -401,8 +401,6 @@ class PadBoxSlotDataset : public DatasetImpl { protected: void MergeInsKeys(const Channel& in); void CheckThreadPool(void); - int GetMaxShuffleThreadId(void); - int GetMaxMergeThreadId(void); protected: Channel shuffle_channel_ = nullptr; @@ -431,6 +429,8 @@ class PadBoxSlotDataset : public DatasetImpl { paddle::framework::ThreadPool* merge_pool_ = nullptr; paddle::framework::ThreadPool* shuffle_pool_ = nullptr; uint16_t pass_id_ = 0; + double max_shuffle_span_ = 0; + double min_shuffle_span_ = 0; }; class InputTableDataset : public PadBoxSlotDataset { diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index bb62400b9fce3..ddc39bf04b173 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -52,6 +52,7 @@ limitations under the License. */ DECLARE_int32(fix_dayid); DECLARE_bool(padbox_auc_runner_mode); DECLARE_bool(enable_dense_nccl_barrier); +DECLARE_int32(padbox_dataset_shuffle_thread_num); namespace paddle { namespace framework { @@ -581,7 +582,7 @@ class BoxWrapper { if (boxps::MPICluster::Ins().size() > 1) { data_shuffle_.reset(boxps::PaddleShuffler::New()); - data_shuffle_->init(10); + data_shuffle_->init(FLAGS_padbox_dataset_shuffle_thread_num); } } else { if (nullptr == s_instance_->boxps_ptr_) { diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 34ab1e0766bd5..94a379d93808f 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -481,7 +481,7 @@ DEFINE_int32(padbox_dataset_shuffle_thread_num, 20, "PadBoxSlotDataset shuffle thread num"); DEFINE_int32(padbox_dataset_merge_thread_num, 20, "PadBoxSlotDataset shuffle thread num"); -DEFINE_int32(padbox_slotpool_thread_num, 20, +DEFINE_int32(padbox_slotpool_thread_num, 1, "PadBoxSlotDataset slot pool thread num"); DEFINE_bool(use_gpu_replica_cache, false, "if true ,will open use_gpu_replica_cache"); diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index 1a17663828bf8..ccfbccbfdf3dd 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -299,7 +299,9 @@ void GpuMemcpyAsync(void *dst, const void *src, size_t count, void GpuMemcpySync(void *dst, const void *src, size_t count, enum cudaMemcpyKind kind) { - CHECK(cudaMemcpy(dst, src, count, kind) == cudaSuccess); + CHECK(cudaMemcpy(dst, src, count, kind) == cudaSuccess) + << "dst:" << dst << ", src:" << src << ", count:" << count + << ", kind:" << kind; } void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,