Skip to content

Commit

Permalink
[PSLIB] Add Metrics Module, Support User-defined Add Metric (#38789)
Browse files Browse the repository at this point in the history
* [PSLIB] Add Metrics Module, Support User-defined Add Metric

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI Coverage

* modify role_maker

* update CMakeLists.txt
  • Loading branch information
WorgenZhang authored Jan 28, 2022
1 parent 3ef2922 commit 2e6be88
Show file tree
Hide file tree
Showing 21 changed files with 1,440 additions and 10 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ if(WITH_DISTRIBUTE)
ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell
fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer
fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper metrics lodtensor_printer
lod_rank_table feed_fetch_method collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto timer monitor
heter_service_proto fleet_executor ${BRPC_DEP})
Expand All @@ -315,7 +315,7 @@ if(WITH_DISTRIBUTE)
pull_dense_worker.cc section_worker.cc heter_section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
index_sampler index_wrapper sampler index_dataset_proto
lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method
lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor heter_service_proto fleet heter_server brpc fleet_executor)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
Expand All @@ -336,7 +336,7 @@ if(WITH_DISTRIBUTE)
ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor fleet_executor)
endif()
elseif(WITH_PSLIB)
Expand Down
20 changes: 20 additions & 0 deletions paddle/fluid/framework/data_feed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
this->thread_id_ = 0;
this->thread_num_ = 1;
this->parse_ins_id_ = false;
this->parse_uid_ = false;
this->parse_content_ = false;
this->parse_logkey_ = false;
this->enable_pv_merge_ = false;
Expand Down Expand Up @@ -498,6 +499,11 @@ void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) {
parse_ins_id_ = parse_ins_id;
}

// Stores the uid-parsing switch for this feed. The flag is read by
// MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe (PADDLE_WITH_PSLIB
// builds) to decide whether the uid slot is extracted into the record.
template <typename T>
void InMemoryDataFeed<T>::SetParseUid(bool parse_uid) {
  this->parse_uid_ = parse_uid;
}

template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
#ifdef _LINUX
Expand Down Expand Up @@ -1047,6 +1053,7 @@ void MultiSlotInMemoryDataFeed::Init(
use_slots_shape_.push_back(local_shape);
}
}
uid_slot_ = multi_slot_desc.uid_slot();
feed_vec_.resize(use_slots_.size());
const int kEstimatedFeasignNumPerSlot = 5; // Magic Number
for (size_t i = 0; i < all_slot_num; i++) {
Expand Down Expand Up @@ -1160,6 +1167,19 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
"\nWe detect the feasign number of this slot is %d, "
"which is illegal.",
str, i, num));
#ifdef PADDLE_WITH_PSLIB
// When uid parsing is enabled and the current slot is the configured uid
// slot, capture the uid on the record. The slot must carry exactly one
// value and be declared with a type starting with 'u'.
if (parse_uid_ && all_slots_[i] == uid_slot_) {
  PADDLE_ENFORCE(num == 1 && all_slots_type_[i][0] == 'u',
                 platform::errors::PreconditionNotMet(
                     "The uid has to be uint64 and single.\n"
                     "please check this error line: %s",
                     str));

  // Parse through a copy of the cursor so that endptr itself is not advanced
  // by this extra read.
  char* uidptr = endptr;
  const uint64_t feasign =
      static_cast<uint64_t>(strtoull(uidptr, &uidptr, 10));
  // Record::uid_ is a std::string. Assigning the integer directly selects
  // std::string::operator=(char) and silently keeps only the low byte of the
  // uid, so store the full decimal representation instead.
  instance->uid_ = std::to_string(feasign);
}
#endif
if (idx != -1) {
if (all_slots_type_[i][0] == 'f') { // float
for (int j = 0; j < num; ++j) {
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/framework/data_feed.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ struct Record {
uint64_t search_id;
uint32_t rank;
uint32_t cmatch;
std::string uid_;
};

inline SlotRecord make_slotrecord() {
Expand Down Expand Up @@ -562,6 +563,7 @@ class DataFeed {
virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default
virtual void SetParseInsId(bool parse_ins_id) {}
virtual void SetParseUid(bool parse_uid) {}
virtual void SetParseContent(bool parse_content) {}
virtual void SetParseLogKey(bool parse_logkey) {}
virtual void SetEnablePvMerge(bool enable_pv_merge) {}
Expand Down Expand Up @@ -645,6 +647,7 @@ class DataFeed {
std::vector<std::string> ins_id_vec_;
std::vector<std::string> ins_content_vec_;
platform::Place place_;
std::string uid_slot_;

// The input type of pipe reader, 0 for one sample, 1 for one batch
int input_type_;
Expand Down Expand Up @@ -709,6 +712,7 @@ class InMemoryDataFeed : public DataFeed {
virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseUid(bool parse_uid);
virtual void SetParseContent(bool parse_content);
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
Expand Down Expand Up @@ -737,6 +741,7 @@ class InMemoryDataFeed : public DataFeed {
int thread_id_;
int thread_num_;
bool parse_ins_id_;
bool parse_uid_;
bool parse_content_;
bool parse_logkey_;
bool enable_pv_merge_;
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/framework/data_feed.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ message Slot {
repeated int32 shape = 5; // we can define N-D Tensor
}

message MultiSlotDesc { repeated Slot slots = 1; }
// Description of the slots consumed by a multi-slot data feed.
message MultiSlotDesc {
repeated Slot slots = 1;
// Identifier of the slot that carries the uid. MultiSlotInMemoryDataFeed::Init
// copies it into uid_slot_, which is compared against the entries of
// all_slots_ while parsing when uid parsing is enabled.
optional string uid_slot = 2;
}

message DataFeedDesc {
optional string name = 1;
Expand Down
19 changes: 16 additions & 3 deletions paddle/fluid/framework/data_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ DatasetImpl<T>::DatasetImpl() {
parse_logkey_ = false;
preload_thread_num_ = 0;
global_index_ = 0;
shuffle_by_uid_ = false;
parse_uid_ = false;
}

// set filelist, file_idx_ will reset to zero.
Expand Down Expand Up @@ -150,6 +152,12 @@ void DatasetImpl<T>::SetMergeBySid(bool is_merge) {
merge_by_sid_ = is_merge;
}

// Records whether GlobalShuffle should pick the destination trainer by
// hashing Record::uid_ (see the get_client_id lambda), and turns on uid
// parsing, which is later forwarded to the readers via SetParseUid.
template <typename T>
void DatasetImpl<T>::SetShuffleByUid(bool enable_shuffle_uid) {
  // NOTE(review): uid parsing is enabled unconditionally here, even when
  // enable_shuffle_uid is false -- confirm this is intended.
  this->parse_uid_ = true;
  this->shuffle_by_uid_ = enable_shuffle_uid;
}

template <typename T>
void DatasetImpl<T>::SetEnablePvMerge(bool enable_pv_merge) {
enable_pv_merge_ = enable_pv_merge;
Expand Down Expand Up @@ -664,11 +672,14 @@ void MultiSlotDataset::GlobalShuffle(int thread_num) {
<< input_channel_->Size();

auto get_client_id = [this, fleet_ptr](const Record& data) -> size_t {
if (!this->merge_by_insid_) {
return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
} else {
if (this->merge_by_insid_) {
return XXH64(data.ins_id_.data(), data.ins_id_.length(), 0) %
this->trainer_num_;
} else if (this->shuffle_by_uid_) {
return XXH64(data.uid_.data(), data.uid_.length(), 0) %
this->trainer_num_;
} else {
return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
}
};

Expand Down Expand Up @@ -902,6 +913,7 @@ void DatasetImpl<T>::CreateReaders() {
readers_[i]->SetFeaNum(&total_fea_num_);
readers_[i]->SetFileList(filelist_);
readers_[i]->SetParseInsId(parse_ins_id_);
readers_[i]->SetParseUid(parse_uid_);
readers_[i]->SetParseContent(parse_content_);
readers_[i]->SetParseLogKey(parse_logkey_);
readers_[i]->SetEnablePvMerge(enable_pv_merge_);
Expand Down Expand Up @@ -972,6 +984,7 @@ void DatasetImpl<T>::CreatePreLoadReaders() {
preload_readers_[i]->SetFeaNumMutex(&mutex_for_fea_num_);
preload_readers_[i]->SetFeaNum(&total_fea_num_);
preload_readers_[i]->SetParseInsId(parse_ins_id_);
preload_readers_[i]->SetParseUid(parse_uid_);
preload_readers_[i]->SetParseContent(parse_content_);
preload_readers_[i]->SetParseLogKey(parse_logkey_);
preload_readers_[i]->SetEnablePvMerge(enable_pv_merge_);
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/data_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class Dataset {
virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
virtual bool EnablePvMerge() = 0;
virtual void SetMergeBySid(bool is_merge) = 0;
virtual void SetShuffleByUid(bool enable_shuffle_uid) = 0;
// set merge by ins id
virtual void SetMergeByInsId(int merge_size) = 0;
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
Expand Down Expand Up @@ -189,6 +190,7 @@ class DatasetImpl : public Dataset {
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
virtual void SetMergeBySid(bool is_merge);
virtual void SetShuffleByUid(bool enable_shuffle_uid);

virtual void SetMergeByInsId(int merge_size);
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
Expand Down Expand Up @@ -307,6 +309,8 @@ class DatasetImpl : public Dataset {
bool parse_content_;
bool parse_logkey_;
bool merge_by_sid_;
bool shuffle_by_uid_;
bool parse_uid_;
bool enable_pv_merge_; // True means to merge pv
int current_phase_; // 1 join, 0 update
size_t merge_size_;
Expand Down
26 changes: 25 additions & 1 deletion paddle/fluid/framework/downpour_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/fleet/metrics.h"
#include "paddle/fluid/platform/cpu_helper.h"

namespace pten {
Expand All @@ -32,7 +33,6 @@ class Variable;

namespace paddle {
namespace framework {

void DownpourWorker::Initialize(const TrainerDesc& desc) {
param_ = desc.downpour_param();
for (int i = 0; i < param_.sparse_table_size(); ++i) {
Expand Down Expand Up @@ -740,6 +740,23 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
}

#ifdef PADDLE_WITH_PSLIB
/**
 * @brief Calls add_data(scope, place) on every entry of the Metric instance's
 *        metric list whose MetricPhase() equals the instance's Phase().
 *
 * @param scope  scope forwarded unchanged to each metric's add_data
 * @param place  place forwarded unchanged to each metric's add_data
 *
 * The result of Metric::GetInstance() is dereferenced without a null check;
 * the call site in DownpourWorker::TrainFiles checks it before calling.
 */
inline void AddAucMonitor(const Scope* scope, const platform::Place& place) {
  auto metric_ptr = Metric::GetInstance();
  auto& registered_metrics = metric_ptr->GetMetricList();
  for (auto& entry : registered_metrics) {
    auto* metric_msg = entry.second;
    // Only metrics registered for the currently active phase receive data.
    if (metric_ptr->Phase() == metric_msg->MetricPhase()) {
      metric_msg->add_data(scope, place);
    }
  }
}
#endif

void DownpourWorker::TrainFiles() {
VLOG(3) << "Begin to train files";
platform::SetNumThreads(1);
Expand Down Expand Up @@ -837,6 +854,13 @@ void DownpourWorker::TrainFiles() {
}
}

#ifdef PADDLE_WITH_PSLIB
// add data for MetricMsg
if (Metric::GetInstance() != nullptr) {
AddAucMonitor(thread_scope_, place_);
}
#endif

// check inf and nan
for (std::string& var_name : check_nan_var_names_) {
Variable* var = thread_scope_->FindVar(var_name);
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/fleet/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ endif(WITH_BOX_PS)

# gloo_wrapper links against gloo only when WITH_GLOO is enabled.
if(WITH_GLOO)
    cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope gloo)
else()
    cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope)
endif(WITH_GLOO)
# The metrics library is declared identically in both configurations, so it is
# defined once after gloo_wrapper.
cc_library(metrics SRCS metrics.cc DEPS gloo_wrapper)

if(WITH_PSLIB)
Expand Down
Loading

0 comments on commit 2e6be88

Please sign in to comment.