Skip to content

Commit

Permalink
Fix add key filter only used slots data (PaddlePaddle#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
qingshui authored Jun 13, 2020
1 parent 4fb0239 commit dcbb4b9
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 20 deletions.
9 changes: 4 additions & 5 deletions paddle/fluid/framework/data_feed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1794,22 +1794,21 @@ void SlotPaddleBoxDataFeed::Init(const DataFeedDesc& data_feed_desc) {
pv_batch_size_ = data_feed_desc.pv_batch_size();
// fprintf(stdout, "rank_offset_name: [%s]\n", rank_offset_name_.c_str());
}
void SlotPaddleBoxDataFeed::GetUsedSlot(
std::vector<bool>* used_slot_index_ptr) {
void SlotPaddleBoxDataFeed::GetUsedSlotIndex(
std::vector<int>* used_slot_index) {
auto boxps_ptr = BoxWrapper::GetInstance();
// get feasigns that FeedPass doesn't need
const std::unordered_set<std::string>& slot_name_omited_in_feedpass_ =
boxps_ptr->GetOmitedSlot();
std::vector<bool>& used_slot_index = (*used_slot_index_ptr);
used_slot_index.assign(use_slot_size_, false);
used_slot_index->clear();
for (int i = 0; i < use_slot_size_; ++i) {
auto& info = used_slots_info_[i];
if (info.type[0] != 'u') {
continue;
}
if (slot_name_omited_in_feedpass_.find(info.slot) ==
slot_name_omited_in_feedpass_.end()) {
used_slot_index[info.slot_value_idx] = true;
used_slot_index->push_back(info.slot_value_idx);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/data_feed.h
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,7 @@ class SlotPaddleBoxDataFeed : public DataFeed {
void AddBatchOffset(const std::pair<int, int>& off) {
batch_offsets_.push_back(off);
}
void GetUsedSlot(std::vector<bool>* used_slot_index);
void GetUsedSlotIndex(std::vector<int>* used_slot_index);
// expand values
void ExpandSlotRecord(SlotRecord* ins);

Expand Down
25 changes: 11 additions & 14 deletions paddle/fluid/framework/data_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1492,8 +1492,10 @@ void PadBoxSlotDataset::MergeInsKeys(const Channel<SlotRecord>& in) {

std::vector<std::thread> feed_threads;
auto boxps_ptr = BoxWrapper::GetInstance();
// std::vector<bool> used_fea_index;
// ((SlotPaddleBoxDataFeed*)readers_[0].get())->GetUsedSlot(used_fea_index);

std::vector<int> used_fea_index;
(reinterpret_cast<SlotPaddleBoxDataFeed*>(readers_[0].get()))
->GetUsedSlotIndex(&used_fea_index);

input_records_.clear();
boxps::PSAgentBase* agent = boxps_ptr->GetAgent();
Expand All @@ -1504,24 +1506,19 @@ void PadBoxSlotDataset::MergeInsKeys(const Channel<SlotRecord>& in) {
}
std::mutex mutex;
for (int tid = 0; tid < thread_num; ++tid) {
feed_threads.push_back(std::thread([this, &in, agent, tid,
&mutex /**, &used_fea_index*/]() {
feed_threads.push_back(std::thread([this, &in, agent, tid, &mutex,
&used_fea_index]() {
SetCPUAffinity(tid, false);
size_t num = 0;
std::vector<SlotRecord> datas;
auto feed_obj =
reinterpret_cast<SlotPaddleBoxDataFeed*>(readers_[tid].get());
while (in->Read(datas)) {
for (auto& rec : datas) {
agent->AddKeys(rec->slot_uint64_feasigns_.slot_values.data(),
rec->slot_uint64_feasigns_.slot_values.size(), tid);
// for (size_t i = 0; i < rec->slot_uint64_feasigns_.size();
// ++i) {
// if (!used_fea_index[i]) {
// continue;
// }
// auto &feas = rec->slot_uint64_feasigns_[i];
// agent->AddKeys(feas.data(), feas.size(), tid);
// }
for (auto& idx : used_fea_index) {
uint64_t* feas = rec->slot_uint64_feasigns_.get_values(idx, &num);
agent->AddKeys(feas, num, tid);
}
feed_obj->ExpandSlotRecord(&rec);
}

Expand Down

0 comments on commit dcbb4b9

Please sign in to comment.