[Heterps] Refactor heterogeneous worker #37244

Merged: 25 commits, merged on Nov 17, 2021
70 changes: 51 additions & 19 deletions paddle/fluid/distributed/service/heter_server.h
@@ -192,13 +192,24 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {

virtual ~RequestSendAndRecvHandler() {}

// void SetMiniScopes(SharedMiniScope mini_scopes) {
// mini_scopes_ = mini_scopes;
// num_minibatch_ = mini_scopes_->size();
//}
void SetMiniScopes(SharedMiniScope mini_scopes) {
mini_scopes_ = mini_scopes;
num_minibatch_ = mini_scopes_->size();
}

void SetMicroScopes(SharedMicroScope micro_scopes) {
micro_scopes_ = micro_scopes;
num_microbatch_ = micro_scopes_->size();
for (auto& scope_pair : (*micro_scopes_)) {
// auto mini_idx = scope_pair.first;
auto& micro_scopes = scope_pair.second;
num_microbatch_ = micro_scopes->size();
break;
}
}

int GetThreadNum() {
std::unique_lock<std::mutex> lk(scope_mutex_);
return (*task_queue_).size();
}

void SetTaskQueue(SharedTaskQueue task_queue) { task_queue_ = task_queue; }
@@ -235,25 +246,43 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {
int minibatch_index = micro_id / 10;
int microbatch_index = micro_id % 10;

// PADDLE_ENFORCE_EQ(
// (*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end(), 1,
// platform::errors::InvalidArgument(
// "minibatch index should in current trainer"));
PADDLE_ENFORCE_EQ(
(*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(), 1,
platform::errors::InvalidArgument(
"minibatch index should in current trainer"));
// check minibatch_index is in mini_scopes_
std::unique_lock<std::mutex> lk(scope_mutex_);
if ((*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end()) {
lk.unlock();
// PADDLE_ENFORCE_EQ(
// (*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end(), 1,
// platform::errors::InvalidArgument(
// "minibatch index should in current trainer"));
PADDLE_ENFORCE_EQ(
(*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(), 1,
platform::errors::InvalidArgument(
"minibatch index should in current trainer"));

} else {
// create mini scope & micro scopes
auto* minibatch_scope = &(scope_->NewScope());
(*mini_scopes_)[minibatch_index] = minibatch_scope;
(*micro_scopes_)[minibatch_index].reset(
new std::vector<paddle::framework::Scope*>{});
for (int i = 0; i < num_microbatch_; i++) {
auto* micro_scope = &(minibatch_scope->NewScope());
(*((*micro_scopes_)[minibatch_index])).push_back(micro_scope);
}
(*task_queue_)[minibatch_index].reset(
new ::paddle::framework::BlockingQueue<
std::pair<std::string, int>>());
lk.unlock();
}

auto* micro_scope =
(*((*micro_scopes_)[minibatch_index]))[microbatch_index];

distributed::DeserializeFromMultiVarMsgAndIOBuf(
*request, &request_io_buffer, *dev_ctx_, micro_scope);

// blocking queue handles multi thread
(*task_queue_)[minibatch_index]->Push(
std::make_pair(message_name, microbatch_index));

auto response_var_nums = request->recv_var_names_size();
std::vector<std::string> response_var_names(response_var_nums),
empty_var_names{};
@@ -269,11 +298,12 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {

private:
// share with HeterPipelineTrainer
// SharedMiniScope mini_scopes_{nullptr};
SharedMiniScope mini_scopes_{nullptr};
SharedMicroScope micro_scopes_{nullptr};

int num_microbatch_;
int num_minibatch_;
std::mutex scope_mutex_;

bool is_first_stage_ = false;
bool is_last_stage_ = false;
@@ -321,14 +351,16 @@ class HeterServer {
request_handler_ = request_handler;
}

// void SetMiniBatchScopes(SharedMiniScope mini_scopes) {
// request_handler_->SetMiniScopes(mini_scopes);
//}
void SetMiniBatchScopes(SharedMiniScope mini_scopes) {
request_handler_->SetMiniScopes(mini_scopes);
}

void SetMicroBatchScopes(SharedMicroScope micro_scopes) {
request_handler_->SetMicroScopes(micro_scopes);
}

int GetThreadNum() { return request_handler_->GetThreadNum(); }

void SetTaskQueue(SharedTaskQueue task_queue) {
request_handler_->SetTaskQueue(task_queue);
}
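For context, the handler change above replaces the eager scope setup with lazy, mutex-guarded creation: the first request that arrives for a minibatch index builds that minibatch scope and its microbatch scopes, and later requests reuse them. Below is a minimal, self-contained sketch of that pattern; Scope, MiniScopes, and MicroScopes are simplified stand-ins for paddle::framework::Scope and the SharedMiniScope/SharedMicroScope typedefs, not the real Paddle types.

#include <map>
#include <memory>
#include <mutex>
#include <vector>

// Simplified stand-in for paddle::framework::Scope.
struct Scope {
  Scope& NewScope() {
    children.emplace_back(new Scope());
    return *children.back();
  }
  std::vector<std::unique_ptr<Scope>> children;
};

using MiniScopes = std::map<int, Scope*>;
using MicroScopes = std::map<int, std::shared_ptr<std::vector<Scope*>>>;

class LazyScopeHandler {
 public:
  explicit LazyScopeHandler(int num_microbatch)
      : num_microbatch_(num_microbatch) {}

  // First call for a minibatch index creates its scope tree under the lock;
  // later calls only look it up.
  Scope* GetMicroScope(int minibatch_index, int microbatch_index) {
    std::lock_guard<std::mutex> lk(scope_mutex_);
    if (mini_scopes_.find(minibatch_index) == mini_scopes_.end()) {
      // First request for this minibatch: build its scope tree once.
      Scope* minibatch_scope = &root_.NewScope();
      mini_scopes_[minibatch_index] = minibatch_scope;
      auto micros = std::make_shared<std::vector<Scope*>>();
      for (int i = 0; i < num_microbatch_; ++i) {
        micros->push_back(&minibatch_scope->NewScope());
      }
      micro_scopes_[minibatch_index] = micros;
    }
    return (*micro_scopes_[minibatch_index])[microbatch_index];
  }

 private:
  Scope root_;
  MiniScopes mini_scopes_;
  MicroScopes micro_scopes_;
  std::mutex scope_mutex_;
  int num_microbatch_;
};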
7 changes: 7 additions & 0 deletions paddle/fluid/framework/device_worker.h
@@ -631,10 +631,17 @@ class HeterSectionWorker : public DeviceWorker {
std::shared_ptr<std::vector<Scope*>> GetMicrobatchScopes() {
return microbatch_scopes_;
}
void SetMicrobatchScopes(
std::shared_ptr<std::vector<Scope*>> microbatch_scopes) {
microbatch_scopes_ = microbatch_scopes;
}
using SHARED_THREAD_QUEUE = std::shared_ptr<
::paddle::framework::BlockingQueue<std::pair<std::string, int>>>;

SHARED_THREAD_QUEUE GetThreadQueue() { return thread_queue_; }
void SetThreadQueue(SHARED_THREAD_QUEUE thread_queue) {
thread_queue_ = thread_queue;
}
void CopyParameters(int microbatch_id, const ProgramDesc& program,
const platform::Place& place);
void SetMinibatchScope(Scope* scope) { minibatch_scope_ = scope; }
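The new SetMicrobatchScopes/SetThreadQueue setters exist so the trainer can hand an already-running worker the shared state that the RPC handler fills: the handler pushes (message_name, microbatch_index) pairs into a per-minibatch queue and the worker pops them. A rough, self-contained sketch of that producer/consumer protocol follows; the BlockingQueue below is a minimal stand-in, not paddle::framework::BlockingQueue.

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <utility>

// Minimal stand-in for paddle::framework::BlockingQueue.
template <typename T>
class BlockingQueue {
 public:
  void Push(T v) {
    {
      std::lock_guard<std::mutex> g(mutex_);
      queue_.push(std::move(v));
    }
    cv_.notify_one();
  }
  T Pop() {
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [this] { return !queue_.empty(); });
    T v = std::move(queue_.front());
    queue_.pop();
    return v;
  }

 private:
  std::queue<T> queue_;
  std::mutex mutex_;
  std::condition_variable cv_;
};

int main() {
  // One queue per minibatch; each task is a (message_name, microbatch_index) pair.
  BlockingQueue<std::pair<std::string, int>> task_queue;

  // Producer: plays the role of the RPC handler receiving micro-batch sends.
  std::thread handler([&] {
    for (int micro = 0; micro < 4; ++micro) {
      task_queue.Push({"forward", micro});
    }
  });

  // Consumer: plays the role of the section worker's run loop.
  for (int i = 0; i < 4; ++i) {
    auto task = task_queue.Pop();
    std::cout << task.first << " on microbatch " << task.second << "\n";
  }
  handler.join();
  return 0;
}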
130 changes: 100 additions & 30 deletions paddle/fluid/framework/heter_pipeline_trainer.cc
@@ -59,12 +59,6 @@ void HeterPipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
thread_num_ = trainer_desc.thread_num();
ParseDumpConfig(trainer_desc);
SetDebug(trainer_desc.debug());
// for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size();
// i++) {
// need_merge_var_names_.push_back(
// trainer_desc.downpour_param().stat_var_names(i));
//}
// get filelist from trainer_desc here
const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders();
VLOG(3) << "readers num: " << readers.size();
@@ -83,34 +77,51 @@ void HeterPipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
trainers_.push_back(trainer_num);
}
int cpu_trainer_num = trainers_[0];
int cur_stage_trainer_num = trainers_[pipeline_stage_];
int global_thread_num = cpu_trainer_num * thread_num_;
int previous_trainers = 0;
for (int i = 0; i < pipeline_stage_; i++) previous_trainers += trainers_[i];
int stage_trainer_id =
trainer_id_ - previous_trainers; // trainer id in current stage
int cnt = -1;
for (int i = stage_trainer_id; i < global_thread_num;
i += cur_stage_trainer_num) {
cnt++;
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
// int cur_stage_trainer_num = trainers_[pipeline_stage_];
// int global_thread_num = cpu_trainer_num * thread_num_;
// int previous_trainers = 0;
// for (int i = 0; i < pipeline_stage_; i++) previous_trainers +=
// trainers_[i];
// int stage_trainer_id =
// trainer_id_ - previous_trainers; // trainer id in current stage

if (pipeline_stage_ == 0) { // for cpu trainer
int cnt = -1;
int real_thread_id = trainer_id_;
for (int i = 0; i < thread_num_; i++) {
cnt++;
workers_[real_thread_id] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
auto this_worker =
std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
workers_[real_thread_id]);
this_worker->SetDebug(debug_);
this_worker->SetNeedDumpField(need_dump_field_);
this_worker->SetNeedDumpParam(need_dump_param_);
this_worker->SetDumpFieldVector(dump_fields_);
this_worker->SetDumpParamVector(dump_param_);
this_worker->InitRandomDumpConfig(trainer_desc);
this_worker->SetDeviceIndex(real_thread_id);
real_thread_id += cpu_trainer_num;
// if (pipeline_stage_ == 0) {
this_worker->SetDataFeed(readers[cnt]);
//}
this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
}
} else { // for heter_trainer
// heter trainer with thread_id == -1 is not for
// real training
workers_[-1] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
auto this_worker =
std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
workers_[i]);
this_worker->SetDebug(debug_);
this_worker->SetNeedDumpField(need_dump_field_);
this_worker->SetNeedDumpParam(need_dump_param_);
this_worker->SetDumpFieldVector(dump_fields_);
this_worker->SetDumpParamVector(dump_param_);
this_worker->InitRandomDumpConfig(trainer_desc);
this_worker->SetDeviceIndex(i);
if (pipeline_stage_ == 0) {
this_worker->SetDataFeed(readers[cnt]);
}
workers_[-1]);
this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
this_worker->SetDeviceIndex(-1);
}
}
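To make the indexing above concrete: on the CPU stage each trainer owns the global worker ids trainer_id_, trainer_id_ + cpu_trainer_num, trainer_id_ + 2 * cpu_trainer_num, and so on, while heter stages create a single placeholder worker under id -1. A tiny standalone illustration of that id assignment; the numeric values are made up for the example.

#include <iostream>
#include <vector>

int main() {
  // Example values only: 2 CPU trainers, 3 worker threads per trainer.
  const int cpu_trainer_num = 2;  // trainers_[0]
  const int thread_num = 3;       // thread_num_
  const int trainer_id = 1;       // this CPU trainer's id

  std::vector<int> worker_ids;
  int real_thread_id = trainer_id;
  for (int i = 0; i < thread_num; ++i) {
    worker_ids.push_back(real_thread_id);
    real_thread_id += cpu_trainer_num;
  }

  // Prints "1 3 5": interleaved with the other CPU trainer's ids 0, 2, 4.
  for (int id : worker_ids) std::cout << id << " ";
  std::cout << "\n";
  return 0;
}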

@@ -183,7 +194,7 @@ void HeterPipelineTrainer::Run() {
}
auto heter_server = paddle::distributed::HeterServer::GetInstance();
heter_server->WaitServerReady();
// heter_server->SetMiniBatchScopes(mini_scopes_);
heter_server->SetMiniBatchScopes(mini_scopes_);
heter_server->SetMicroBatchScopes(micro_scopes_);
heter_server->SetTaskQueue(task_queue_);
// main training logic
@@ -199,6 +210,7 @@
}
}
} else { // for heter worker
// start thread_worker with thread_id = -1
for (auto& worker_pair : workers_) {
auto device_worker = worker_pair.second;
if (!debug_) {
@@ -209,6 +221,60 @@
device_worker.get()));
}
}
bool epoch_finish = false;
auto heter_server = paddle::distributed::HeterServer::GetInstance();
while (!epoch_finish) {
if (heter_server->IsStop()) {
epoch_finish = true;
continue;
}
// create new thread_worker
// size_t thread_num = (*micro_scopes_).size();
// size_t thread_num = (*task_queue_).size();
size_t thread_num = heter_server->GetThreadNum();
while (thread_num > threads_.size()) {
for (auto& worker_pair : (*micro_scopes_)) {
auto worker_index = worker_pair.first;
if (workers_.find(worker_index) != workers_.end()) continue;
workers_[worker_index] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc_.device_worker_name());
auto this_worker =
std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
workers_[worker_index]);
this_worker->SetDebug(debug_);
this_worker->SetNeedDumpField(need_dump_field_);
this_worker->SetNeedDumpParam(need_dump_param_);
this_worker->SetDumpFieldVector(dump_fields_);
this_worker->SetDumpParamVector(dump_param_);
this_worker->InitRandomDumpConfig(trainer_desc_);
this_worker->SetDeviceIndex(worker_index);
this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
this_worker->SetPlace(place_);
this_worker->Initialize(trainer_desc_);
this_worker->SetRootScope(root_scope_);

// generate mini_batch scope for every worker
// auto* minibatch_scope = &root_scope_->NewScope();
auto* minibatch_scope = (*mini_scopes_)[worker_index];
// (*mini_scopes_)[worker_index] = minibatch_scope;
this_worker->SetMinibatchScope(minibatch_scope);
// after set micro num & mini batch scope
this_worker->SetMicrobatchScopes((*micro_scopes_)[worker_index]);
this_worker->CreateMicrobatchScopes();
// this_worker->SetMicrobatchScopes((*micro_scopes_)[worker_index]);
this_worker->SetThreadQueue((*task_queue_)[worker_index]);
if (!debug_) {
threads_.push_back(
std::thread(&DeviceWorker::TrainFiles, this_worker.get()));
} else {
threads_.push_back(std::thread(
&DeviceWorker::TrainFilesWithProfiler, this_worker.get()));
}
}
}
}
}
for (auto& th : threads_) {
th.join();
@@ -234,7 +300,11 @@ void HeterPipelineTrainer::Finalize() {
}

Scope* HeterPipelineTrainer::GetWorkerScope(int thread_id) {
return workers_[thread_id]->GetThreadScope();
if (workers_.find(thread_id) != workers_.end()) {
return workers_[thread_id]->GetThreadScope();
} else {
return nullptr;
}
}

} // end namespace framework
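The heter-stage Run() loop above starts worker threads on demand: it polls the server for how many minibatch task queues have appeared and spins up one HeterSectionWorker per newly seen minibatch index, wiring in the shared scopes and queue. A much-simplified, self-contained sketch of that control flow follows; Server and TrainFiles are trivial stand-ins rather than the real HeterServer or device worker, and the contiguous 0..N-1 index assumption is for illustration only.

#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>
#include <vector>

struct Server {
  std::atomic<bool> stop{false};
  std::atomic<size_t> thread_num{0};  // grows as minibatch task queues appear
};

// Stand-in for DeviceWorker::TrainFiles: would drain one minibatch's queue.
void TrainFiles(int minibatch_index) {
  std::cout << "worker for minibatch " << minibatch_index << " started\n";
}

void RunHeterStage(Server& server) {
  std::vector<std::thread> threads;
  while (!server.stop.load()) {
    size_t wanted = server.thread_num.load();
    while (wanted > threads.size()) {
      // A new minibatch queue showed up: create and start a worker for it.
      int minibatch_index = static_cast<int>(threads.size());
      threads.emplace_back(TrainFiles, minibatch_index);
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
  }
  for (auto& t : threads) t.join();
}

int main() {
  Server server;
  std::thread stage(RunHeterStage, std::ref(server));
  server.thread_num = 2;  // pretend two minibatch queues appeared
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  server.stop = true;
  stage.join();
  return 0;
}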
3 changes: 1 addition & 2 deletions paddle/fluid/framework/heter_pipeline_trainer_test.cc
@@ -12,8 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \
(defined PADDLE_WITH_PSCORE)
#if (defined PADDLE_WITH_CUDA) && (defined PADDLE_WITH_PSCORE)
#include "gtest/gtest.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"