Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PSLIB] Add Metrics Module, Support User-defined Add Metric #38789

Merged
merged 18 commits into from
Jan 28, 2022
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions paddle/fluid/framework/data_feed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
this->thread_id_ = 0;
this->thread_num_ = 1;
this->parse_ins_id_ = false;
this->parse_uid_ = false;
this->parse_content_ = false;
this->parse_logkey_ = false;
this->enable_pv_merge_ = false;
Expand Down Expand Up @@ -498,6 +499,11 @@ void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) {
parse_ins_id_ = parse_ins_id;
}

template <typename T>
void InMemoryDataFeed<T>::SetParseUid(bool parse_uid) {
parse_uid_ = parse_uid;
}

template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
#ifdef _LINUX
Expand Down Expand Up @@ -1047,6 +1053,7 @@ void MultiSlotInMemoryDataFeed::Init(
use_slots_shape_.push_back(local_shape);
}
}
uid_slot_ = multi_slot_desc.uid_slot();
feed_vec_.resize(use_slots_.size());
const int kEstimatedFeasignNumPerSlot = 5; // Magic Number
for (size_t i = 0; i < all_slot_num; i++) {
Expand Down Expand Up @@ -1160,6 +1167,19 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
"\nWe detect the feasign number of this slot is %d, "
"which is illegal.",
str, i, num));
#ifdef PADDLE_WITH_PSLIB
if (parse_uid_ && all_slots_[i] == uid_slot_) {
PADDLE_ENFORCE(num == 1 && all_slots_type_[i][0] == 'u',
platform::errors::PreconditionNotMet(
"The uid has to be uint64 and single.\n"
"please check this error line: %s",
str));

char* uidptr = endptr;
uint64_t feasign = (uint64_t)strtoull(uidptr, &uidptr, 10);
instance->uid_ = feasign;
}
#endif
if (idx != -1) {
if (all_slots_type_[i][0] == 'f') { // float
for (int j = 0; j < num; ++j) {
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/framework/data_feed.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ struct Record {
uint64_t search_id;
uint32_t rank;
uint32_t cmatch;
std::string uid_;
};

inline SlotRecord make_slotrecord() {
Expand Down Expand Up @@ -559,6 +560,7 @@ class DataFeed {
virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default
virtual void SetParseInsId(bool parse_ins_id) {}
virtual void SetParseUid(bool parse_uid) {}
virtual void SetParseContent(bool parse_content) {}
virtual void SetParseLogKey(bool parse_logkey) {}
virtual void SetEnablePvMerge(bool enable_pv_merge) {}
Expand Down Expand Up @@ -642,6 +644,7 @@ class DataFeed {
std::vector<std::string> ins_id_vec_;
std::vector<std::string> ins_content_vec_;
platform::Place place_;
std::string uid_slot_;

// The input type of pipe reader, 0 for one sample, 1 for one batch
int input_type_;
Expand Down Expand Up @@ -706,6 +709,7 @@ class InMemoryDataFeed : public DataFeed {
virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseUid(bool parse_uid);
virtual void SetParseContent(bool parse_content);
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
Expand Down Expand Up @@ -734,6 +738,7 @@ class InMemoryDataFeed : public DataFeed {
int thread_id_;
int thread_num_;
bool parse_ins_id_;
bool parse_uid_;
bool parse_content_;
bool parse_logkey_;
bool enable_pv_merge_;
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/framework/data_feed.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ message Slot {
repeated int32 shape = 5; // we can define N-D Tensor
}

message MultiSlotDesc { repeated Slot slots = 1; }
message MultiSlotDesc {
repeated Slot slots = 1;
optional string uid_slot = 2;
}

message DataFeedDesc {
optional string name = 1;
Expand Down
19 changes: 16 additions & 3 deletions paddle/fluid/framework/data_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ DatasetImpl<T>::DatasetImpl() {
parse_logkey_ = false;
preload_thread_num_ = 0;
global_index_ = 0;
shuffle_by_uid_ = false;
parse_uid_ = false;
}

// set filelist, file_idx_ will reset to zero.
Expand Down Expand Up @@ -150,6 +152,12 @@ void DatasetImpl<T>::SetMergeBySid(bool is_merge) {
merge_by_sid_ = is_merge;
}

template <typename T>
void DatasetImpl<T>::SetShuffleByUid(bool enable_shuffle_uid) {
shuffle_by_uid_ = enable_shuffle_uid;
parse_uid_ = true;
}

template <typename T>
void DatasetImpl<T>::SetEnablePvMerge(bool enable_pv_merge) {
enable_pv_merge_ = enable_pv_merge;
Expand Down Expand Up @@ -664,11 +672,14 @@ void MultiSlotDataset::GlobalShuffle(int thread_num) {
<< input_channel_->Size();

auto get_client_id = [this, fleet_ptr](const Record& data) -> size_t {
if (!this->merge_by_insid_) {
return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
} else {
if (this->merge_by_insid_) {
return XXH64(data.ins_id_.data(), data.ins_id_.length(), 0) %
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. 此处 globalshuffle 在二次开发时需要新增 if/else,建议按照 neg sampling #36564 方法重构;
  2. 类、函数、变量命名按统一ps规范来 https://ku.baidu-int.com/knowledge/HFVrC7hq1Q/pKzJfZczuc/UQuex_mf2m/sTOYzTS6so3HUG
  3. 文件是否可以放入统一ps的目录中,在统一ps未开发完之前也能正常调用?
  4. 此处的 BasicAucCalculator 和 paddlebox 中的 BasicAucCalculator 是什么关系?
  5. 需要新增 metrics 基类,方便后续开发

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

先进行功能合入,后续再考虑放入统一ps

this->trainer_num_;
} else if (this->shuffle_by_uid_) {
return XXH64(data.uid_.data(), data.uid_.length(), 0) %
this->trainer_num_;
} else {
return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
}
};

Expand Down Expand Up @@ -902,6 +913,7 @@ void DatasetImpl<T>::CreateReaders() {
readers_[i]->SetFeaNum(&total_fea_num_);
readers_[i]->SetFileList(filelist_);
readers_[i]->SetParseInsId(parse_ins_id_);
readers_[i]->SetParseUid(parse_uid_);
readers_[i]->SetParseContent(parse_content_);
readers_[i]->SetParseLogKey(parse_logkey_);
readers_[i]->SetEnablePvMerge(enable_pv_merge_);
Expand Down Expand Up @@ -972,6 +984,7 @@ void DatasetImpl<T>::CreatePreLoadReaders() {
preload_readers_[i]->SetFeaNumMutex(&mutex_for_fea_num_);
preload_readers_[i]->SetFeaNum(&total_fea_num_);
preload_readers_[i]->SetParseInsId(parse_ins_id_);
preload_readers_[i]->SetParseUid(parse_uid_);
preload_readers_[i]->SetParseContent(parse_content_);
preload_readers_[i]->SetParseLogKey(parse_logkey_);
preload_readers_[i]->SetEnablePvMerge(enable_pv_merge_);
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/data_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class Dataset {
virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
virtual bool EnablePvMerge() = 0;
virtual void SetMergeBySid(bool is_merge) = 0;
virtual void SetShuffleByUid(bool enable_shuffle_uid) = 0;
// set merge by ins id
virtual void SetMergeByInsId(int merge_size) = 0;
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
Expand Down Expand Up @@ -189,6 +190,7 @@ class DatasetImpl : public Dataset {
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
virtual void SetMergeBySid(bool is_merge);
virtual void SetShuffleByUid(bool enable_shuffle_uid);

virtual void SetMergeByInsId(int merge_size);
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
Expand Down Expand Up @@ -307,6 +309,8 @@ class DatasetImpl : public Dataset {
bool parse_content_;
bool parse_logkey_;
bool merge_by_sid_;
bool shuffle_by_uid_;
bool parse_uid_;
bool enable_pv_merge_; // True means to merge pv
int current_phase_; // 1 join, 0 update
size_t merge_size_;
Expand Down
26 changes: 25 additions & 1 deletion paddle/fluid/framework/downpour_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/fleet/metrics.h"
#include "paddle/fluid/platform/cpu_helper.h"

namespace paddle {
Expand All @@ -29,7 +30,6 @@ class Variable;

namespace paddle {
namespace framework {

void DownpourWorker::Initialize(const TrainerDesc& desc) {
param_ = desc.downpour_param();
for (int i = 0; i < param_.sparse_table_size(); ++i) {
Expand Down Expand Up @@ -737,6 +737,23 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
}

#ifdef PADDLE_WITH_PSLIB
/**
* @brief add auc monitor
*/
inline void AddAucMonitor(const Scope* scope, const platform::Place& place) {
auto metric_ptr = Metric::GetInstance();
auto& metric_list = metric_ptr->GetMetricList();
for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) {
auto* metric_msg = iter->second;
if (metric_ptr->Phase() != metric_msg->MetricPhase()) {
continue;
}
metric_msg->add_data(scope, place);
}
}
#endif

void DownpourWorker::TrainFiles() {
VLOG(3) << "Begin to train files";
platform::SetNumThreads(1);
Expand Down Expand Up @@ -834,6 +851,13 @@ void DownpourWorker::TrainFiles() {
}
}

#ifdef PADDLE_WITH_PSLIB
// add data for MetricMsg
if (Metric::GetInstance() != nullptr) {
AddAucMonitor(thread_scope_, place_);
}
#endif

// check inf and nan
for (std::string& var_name : check_nan_var_names_) {
Variable* var = thread_scope_->FindVar(var_name);
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/fleet/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ endif(WITH_BOX_PS)

if(WITH_GLOO)
cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope gloo)
cc_library(metrics SRCS metrics.cc DEPS gloo_wrapper)
else()
cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope)
cc_library(metrics SRCS metrics.cc DEPS gloo_wrapper)
endif(WITH_GLOO)

if(WITH_PSLIB)
Expand Down
Loading