support multi-node (#35396)
yaoxuefeng6 authored Sep 7, 2021
1 parent 8307b0c commit c6e0ced
Showing 16 changed files with 680 additions and 76 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/framework/CMakeLists.txt
@@ -354,10 +354,10 @@ cc_library(executor_cache SRCS executor_cache.cc DEPS parallel_executor)
if(WITH_PSCORE)
  get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
  cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS
-          conditional_block_op executor ${RPC_DEPS})
+          conditional_block_op executor gloo_wrapper ${RPC_DEPS})
else()
  cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS
-          conditional_block_op executor)
+          conditional_block_op executor gloo_wrapper)
endif()
cc_library(prune SRCS prune.cc DEPS framework_proto boost)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
182 changes: 157 additions & 25 deletions paddle/fluid/framework/data_feed.cc
@@ -257,6 +257,11 @@ bool InMemoryDataFeed<T>::Start() {
output_channel_->Write(std::move(data));
}
#endif
if (batch_offsets_.size() > 0) {
VLOG(3) << "batch_size offsets: " << batch_offsets_.size();
enable_heterps_ = true;
this->offset_index_ = 0;
}
this->finish_start_ = true;
return true;
}
@@ -265,34 +270,64 @@ template <typename T>
int InMemoryDataFeed<T>::Next() {
#ifdef _LINUX
  this->CheckStart();
-  CHECK(output_channel_ != nullptr);
-  CHECK(consume_channel_ != nullptr);
-  VLOG(3) << "output_channel_ size=" << output_channel_->Size()
-          << ", consume_channel_ size=" << consume_channel_->Size()
-          << ", thread_id=" << thread_id_;
-  int index = 0;
-  T instance;
-  std::vector<T> ins_vec;
-  ins_vec.reserve(this->default_batch_size_);
-  while (index < this->default_batch_size_) {
-    if (output_channel_->Size() == 0) {
-      break;
-    }
-    output_channel_->Get(instance);
-    ins_vec.push_back(instance);
-    ++index;
-    consume_channel_->Put(std::move(instance));
-  }
-  this->batch_size_ = index;
-  VLOG(3) << "batch_size_=" << this->batch_size_
-          << ", thread_id=" << thread_id_;
-  if (this->batch_size_ != 0) {
-    PutToFeedVec(ins_vec);
-  } else {
-    VLOG(3) << "finish reading, output_channel_ size="
-            << output_channel_->Size()
-            << ", consume_channel_ size=" << consume_channel_->Size()
-            << ", thread_id=" << thread_id_;
+  if (!enable_heterps_) {
+    CHECK(output_channel_ != nullptr);
+    CHECK(consume_channel_ != nullptr);
+    VLOG(3) << "output_channel_ size=" << output_channel_->Size()
+            << ", consume_channel_ size=" << consume_channel_->Size()
+            << ", thread_id=" << thread_id_;
+    int index = 0;
+    T instance;
+    std::vector<T> ins_vec;
+    ins_vec.reserve(this->default_batch_size_);
+    while (index < this->default_batch_size_) {
+      if (output_channel_->Size() == 0) {
+        break;
+      }
+      output_channel_->Get(instance);
+      ins_vec.push_back(instance);
+      ++index;
+      consume_channel_->Put(std::move(instance));
+    }
+    this->batch_size_ = index;
+    VLOG(3) << "batch_size_=" << this->batch_size_
+            << ", thread_id=" << thread_id_;
+    if (this->batch_size_ != 0) {
+      PutToFeedVec(ins_vec);
+    } else {
+      VLOG(3) << "finish reading, output_channel_ size="
+              << output_channel_->Size()
+              << ", consume_channel_ size=" << consume_channel_->Size()
+              << ", thread_id=" << thread_id_;
+    }
+  } else {
+    VLOG(3) << "enable heter NEXT: " << offset_index_
+            << " batch_offsets: " << batch_offsets_.size();
+    if (offset_index_ >= batch_offsets_.size()) {
+      VLOG(3) << "offset_index: " << offset_index_
+              << " batch_offsets: " << batch_offsets_.size();
+      return 0;
+    }
+    auto& batch = batch_offsets_[offset_index_++];
+    this->batch_size_ = batch.second;
+    VLOG(3) << "batch_size_=" << this->batch_size_
+            << ", thread_id=" << thread_id_;
+    if (this->batch_size_ != 0) {
+      PutToFeedVec(&records_[batch.first], this->batch_size_);
+    } else {
+      VLOG(3) << "finish reading for heterps, batch size zero, thread_id="
+              << thread_id_;
+    }
+    /*
+    if (offset_index_ == batch_offsets_.size() - 1) {
+      std::vector<Record> data;
+      output_channel_->ReadAll(data);
+      consume_channel_->Write(std::move(data));
+    }
+    */
+    VLOG(3) << "#15 enable heter NEXT: " << offset_index_
+            << " batch_offsets: " << batch_offsets_.size()
+            << " batch_size: " << this->batch_size_;
  }
return this->batch_size_;
#else
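
Note (editor's addition): in the enable_heterps_ path above, Next() no longer pulls instances from the channel; it walks pre-computed (start index, length) pairs in batch_offsets_ over the flat records_ buffer and hands each slice to PutToFeedVec. Below is a minimal, self-contained C++ sketch of that slicing pattern; the Record type, data, and driver loop are illustrative only, not Paddle code.

// Sketch: consuming (start, length) batch descriptors over a flat record array.
#include <cstdio>
#include <utility>
#include <vector>

struct Record {
  int id;
};

int main() {
  std::vector<Record> records(10);
  for (int i = 0; i < 10; ++i) records[i].id = i;

  // Pre-computed batches: {start index, batch size}, e.g. a batch size of 4.
  std::vector<std::pair<int, int>> batch_offsets = {{0, 4}, {4, 4}, {8, 2}};

  size_t offset_index = 0;
  while (offset_index < batch_offsets.size()) {
    const auto& batch = batch_offsets[offset_index++];
    const Record* batch_begin = &records[batch.first];
    int batch_size = batch.second;
    // In the data feed, this is where PutToFeedVec(batch_begin, batch_size)
    // would copy the slice into the feed tensors.
    std::printf("batch %zu: start=%d size=%d first_id=%d\n", offset_index,
                batch.first, batch_size, batch_begin->id);
  }
  return 0;
}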
@@ -1141,6 +1176,103 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) {
return false;
}

void MultiSlotInMemoryDataFeed::PutToFeedVec(const Record* ins_vec, int num) {
#ifdef _LINUX
for (size_t i = 0; i < batch_float_feasigns_.size(); ++i) {
batch_float_feasigns_[i].clear();
batch_uint64_feasigns_[i].clear();
offset_[i].clear();
offset_[i].push_back(0);
}
ins_content_vec_.clear();
ins_content_vec_.reserve(num);
ins_id_vec_.clear();
ins_id_vec_.reserve(num);
for (int i = 0; i < num; ++i) {
auto& r = ins_vec[i];
ins_id_vec_.push_back(r.ins_id_);
ins_content_vec_.push_back(r.content_);
for (auto& item : r.float_feasigns_) {
batch_float_feasigns_[item.slot()].push_back(item.sign().float_feasign_);
visit_[item.slot()] = true;
}
for (auto& item : r.uint64_feasigns_) {
batch_uint64_feasigns_[item.slot()].push_back(
item.sign().uint64_feasign_);
visit_[item.slot()] = true;
}
for (size_t j = 0; j < use_slots_.size(); ++j) {
const auto& type = all_slots_type_[j];
if (visit_[j]) {
visit_[j] = false;
} else {
// fill slot value with default value 0
if (type[0] == 'f') { // float
batch_float_feasigns_[j].push_back(0.0);
} else if (type[0] == 'u') { // uint64
batch_uint64_feasigns_[j].push_back(0);
}
}
// get offset of this ins in this slot
if (type[0] == 'f') { // float
offset_[j].push_back(batch_float_feasigns_[j].size());
} else if (type[0] == 'u') { // uint64
offset_[j].push_back(batch_uint64_feasigns_[j].size());
}
}
}

for (size_t i = 0; i < use_slots_.size(); ++i) {
if (feed_vec_[i] == nullptr) {
continue;
}
int total_instance = offset_[i].back();
const auto& type = all_slots_type_[i];
if (type[0] == 'f') { // float
float* feasign = batch_float_feasigns_[i].data();
float* tensor_ptr =
feed_vec_[i]->mutable_data<float>({total_instance, 1}, this->place_);
CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(float));
} else if (type[0] == 'u') { // uint64
// no uint64_t type in paddlepaddle
uint64_t* feasign = batch_uint64_feasigns_[i].data();
int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
{total_instance, 1}, this->place_);
CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(int64_t));
}
auto& slot_offset = offset_[i];
if (this->input_type_ == 0) {
LoD data_lod{slot_offset};
feed_vec_[i]->set_lod(data_lod);
} else if (this->input_type_ == 1) {
if (!use_slots_is_dense_[i]) {
std::vector<size_t> tmp_offset;
PADDLE_ENFORCE_EQ(slot_offset.size(), 2,
platform::errors::InvalidArgument(
"In batch reader, the sparse tensor lod size "
"must be 2, but received %d.",
slot_offset.size()));
const auto& max_size = slot_offset[1];
tmp_offset.reserve(max_size + 1);
for (unsigned int k = 0; k <= max_size; k++) {
tmp_offset.emplace_back(k);
}
slot_offset = tmp_offset;
LoD data_lod{slot_offset};
feed_vec_[i]->set_lod(data_lod);
}
}
if (use_slots_is_dense_[i]) {
if (inductive_shape_index_[i] != -1) {
use_slots_shape_[i][inductive_shape_index_[i]] =
total_instance / total_dims_without_inductive_[i];
}
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
}
}
#endif
}
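
Note (editor's addition): the PutToFeedVec overload above keeps, for each slot, a running offset vector that starts at 0 and is appended with the slot's cumulative value count after every instance (offset_[j].push_back(...size())); that vector becomes the tensor's LoD, so offsets[k+1] - offsets[k] is instance k's length in the slot. A minimal standalone sketch of that bookkeeping for a single sparse slot, with made-up values:

// Sketch: building LoD-style cumulative offsets for one slot across a batch.
#include <cstdio>
#include <vector>

int main() {
  // Two instances: instance 0 has 3 feature ids, instance 1 has 1.
  std::vector<std::vector<unsigned long long>> instances = {{101, 102, 103},
                                                            {205}};

  std::vector<unsigned long long> slot_values;  // flattened feature values
  std::vector<size_t> offsets = {0};            // LoD-style offsets

  for (const auto& ins : instances) {
    for (auto v : ins) slot_values.push_back(v);
    offsets.push_back(slot_values.size());  // cumulative length per instance
  }

  // offsets is now {0, 3, 4}: the lod attached to the tensor in PutToFeedVec.
  for (size_t k = 0; k + 1 < offsets.size(); ++k) {
    std::printf("instance %zu length = %zu\n", k, offsets[k + 1] - offsets[k]);
  }
  return 0;
}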

void MultiSlotInMemoryDataFeed::PutToFeedVec(
const std::vector<Record>& ins_vec) {
#ifdef _LINUX
18 changes: 15 additions & 3 deletions paddle/fluid/framework/data_feed.h
@@ -167,7 +167,7 @@ class DLManager {
}

  paddle::framework::CustomParser* Load(const std::string& name,
-                                        std::vector<SlotConf>& conf) {
+                                        const std::vector<SlotConf>& conf) {
#ifdef _LINUX
std::lock_guard<std::mutex> lock(mutex_);
DLHandle handle;
@@ -195,7 +195,7 @@ class DLManager {
}

  paddle::framework::CustomParser* ReLoad(const std::string& name,
-                                          std::vector<SlotConf>& conf) {
+                                          const std::vector<SlotConf>& conf) {
Close(name);
return Load(name, conf);
}
@@ -422,6 +422,7 @@ class InMemoryDataFeed : public DataFeed {
virtual void ParseOneInstanceFromSo(const char* str, T* instance,
CustomParser* parser) {}
virtual void PutToFeedVec(const std::vector<T>& ins_vec) = 0;
virtual void PutToFeedVec(const T* ins_vec, int num) = 0;

int thread_id_;
int thread_num_;
@@ -439,6 +440,11 @@
  paddle::framework::ChannelObject<PvInstance>* input_pv_channel_;
paddle::framework::ChannelObject<PvInstance>* input_pv_channel_;
paddle::framework::ChannelObject<PvInstance>* output_pv_channel_;
paddle::framework::ChannelObject<PvInstance>* consume_pv_channel_;

std::vector<std::pair<int, int>> batch_offsets_;
uint64_t offset_index_ = 0;
bool enable_heterps_ = false;
T* records_ = nullptr;
};

// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
@@ -601,7 +607,7 @@ paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
for (size_t& x : offset) {
uint64_t t;
ar >> t;
-    x = (size_t)t;
+    x = static_cast<size_t>(t);
}
#endif
ar >> ins.MutableFloatData();
@@ -777,6 +783,11 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
MultiSlotInMemoryDataFeed() {}
virtual ~MultiSlotInMemoryDataFeed() {}
virtual void Init(const DataFeedDesc& data_feed_desc);
void SetRecord(Record* records) { records_ = records; }
int GetDefaultBatchSize() { return default_batch_size_; }
void AddBatchOffset(const std::pair<int, int>& offset) {
batch_offsets_.push_back(offset);
}

protected:
virtual bool ParseOneInstance(Record* instance);
@@ -786,6 +797,7 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
virtual void PutToFeedVec(const std::vector<Record>& ins_vec);
virtual void GetMsgFromLogKey(const std::string& log_key, uint64_t* search_id,
uint32_t* cmatch, uint32_t* rank);
virtual void PutToFeedVec(const Record* ins_vec, int num);
std::vector<std::vector<float>> batch_float_feasigns_;
std::vector<std::vector<uint64_t>> batch_uint64_feasigns_;
std::vector<std::vector<size_t>> offset_;
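
Note (editor's addition): a plausible way the new MultiSlotInMemoryDataFeed hooks (SetRecord, GetDefaultBatchSize, AddBatchOffset) could be driven from the dataset side: point the feed at a flat record buffer, then register one (start, length) pair per batch. The Feed struct and driver loop below are hypothetical stand-ins, not Paddle code; only the three method bodies mirror the interface added in this commit.

// Sketch: wiring a flat record buffer and batch offsets into a feed.
#include <algorithm>
#include <utility>
#include <vector>

struct Record {};

struct Feed {
  void SetRecord(Record* records) { records_ = records; }
  int GetDefaultBatchSize() const { return default_batch_size_; }
  void AddBatchOffset(const std::pair<int, int>& offset) {
    batch_offsets_.push_back(offset);
  }

  Record* records_ = nullptr;
  int default_batch_size_ = 4;
  std::vector<std::pair<int, int>> batch_offsets_;
};

int main() {
  std::vector<Record> records(10);
  Feed feed;
  feed.SetRecord(records.data());  // feed reads batches out of this buffer

  const int bs = feed.GetDefaultBatchSize();
  const int total = static_cast<int>(records.size());
  for (int begin = 0; begin < total; begin += bs) {
    int len = std::min(bs, total - begin);
    feed.AddBatchOffset({begin, len});  // one (start, length) pair per batch
  }
  return 0;
}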