[hybrid performance] pipeline cache trainer (#33998)
FeixLiu authored Jul 9, 2021
1 parent dfff52e commit 98c7191
Showing 4 changed files with 31 additions and 14 deletions.
3 changes: 3 additions & 0 deletions paddle/fluid/framework/device_worker.h
@@ -581,6 +581,7 @@ class SectionWorker : public DeviceWorker {
   void RunUpdate(
       std::unique_ptr<GarbageCollector>&,
       std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
+  void PrepareUnusedVar();
 
  protected:
   int section_id_;
@@ -595,6 +596,8 @@ class SectionWorker : public DeviceWorker {
 
   std::vector<std::unique_ptr<OperatorBase>> ops_;
   std::shared_ptr<framework::ProgramDesc> program_;
+  std::unordered_map<const OperatorBase*, std::vector<std::string>>
+      unused_vars_;
   static uint64_t batch_id_;
 
   platform::DeviceContext* dev_ctx_ = nullptr;
19 changes: 14 additions & 5 deletions paddle/fluid/framework/pipeline_trainer.cc
@@ -113,19 +113,28 @@ void PipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program,
   this_worker->SetRootScope(root_scope_);
   this_worker->SetMinibatchScope(minibatch_scope_);
   this_worker->SetMicrobatchScopes(microbatch_scopes_);
+  this_worker->PrepareUnusedVar();
 }
 
 void PipelineTrainer::Run() {
   VLOG(5) << "Going to run PipelineTrainer::Run()";
-  section_thread_ = std::async(&DeviceWorker::TrainFiles, worker_.get());
-}
-
-void PipelineTrainer::Finalize() {
   try {
-    section_thread_.get();
+    worker_->TrainFiles();
   } catch (platform::EOFException& e) {
     std::rethrow_exception(std::current_exception());
   }
+  for (auto* micro_scop : microbatch_scopes_) {
+    // By default, we should delete all kid scopes after run executor because
+    // some operators may create local scope when running, such as while_op.
+    // But when while_op also create a local executor to run it's sub block,
+    // the sub scopes it created should not be dropped immediately, because
+    // while_grad_op will use some variables created during while_op run, so
+    // we need to keep the kids and wait for the outer executor to drop them.
+    micro_scop->DropKids();
+  }
+}
+
+void PipelineTrainer::Finalize() {
   if (need_dump_field_) {
     FinalizeDumpEnv();
   }
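With this change the section worker's TrainFiles() runs synchronously inside Run(), and the micro-batch child scopes are dropped after every run, leaving Finalize() with only end-of-training work such as the dump environment. That is what allows a cached trainer to be run many times before it is finalized. A minimal Python sketch of the reworked control flow, assuming a worker that exposes train_files() and scopes that expose drop_kids() (illustrative names, not Paddle's API):

class PipelineTrainerSketch:
    def __init__(self, worker, microbatch_scopes):
        self.worker = worker
        self.microbatch_scopes = microbatch_scopes

    def run(self):
        # TrainFiles now executes synchronously in Run() instead of on a
        # std::async thread that Finalize() had to join.
        try:
            self.worker.train_files()
        except EOFError:
            raise  # rethrow, mirroring the EOFException path
        # Per-run cleanup: drop the child scopes of every micro-batch scope
        # so the same trainer can be run again without leaking scopes.
        for scope in self.microbatch_scopes:
            scope.drop_kids()

    def finalize(self):
        # Only end-of-training work remains here (e.g. dumping), so run()
        # can be called repeatedly before finalize().
        pass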
6 changes: 5 additions & 1 deletion paddle/fluid/framework/section_worker.cc
@@ -96,12 +96,16 @@ void SectionWorker::RunUpdate(
   }
 }
 
+void SectionWorker::PrepareUnusedVar() {
+  VLOG(5) << "begin prepare the unsed vars";
+  unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
+}
+
 void SectionWorker::TrainFiles() {
   VLOG(5) << "begin section_worker TrainFiles";
 
   int64_t max_memory_size = GetEagerDeletionThreshold();
   std::unique_ptr<GarbageCollector> gc;
-  auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
   if (max_memory_size >= 0) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   if (platform::is_gpu_place(place_)) {
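The unused-variable map that eager deletion relies on is now computed once in PrepareUnusedVar(), at trainer-initialization time, and stored in the unused_vars_ member instead of being rebuilt by GetUnusedVars at the top of every TrainFiles() call. A rough Python analogue of that compute-once caching, using a stand-in analysis callable (hypothetical names, not the Paddle implementation):

from typing import Callable, Dict, List

class SectionWorkerSketch:
    def __init__(self, ops: List[str], skip_vars: List[str],
                 analyze: Callable[[List[str], List[str]], Dict[str, List[str]]]):
        self.ops = ops
        self.skip_vars = skip_vars
        self.analyze = analyze                      # stand-in for GetUnusedVars
        self.unused_vars: Dict[str, List[str]] = {}
        self._prepared = False

    def prepare_unused_var(self) -> None:
        # Run the liveness analysis once, at init time, mirroring
        # SectionWorker::PrepareUnusedVar().
        self.unused_vars = self.analyze(self.ops, self.skip_vars)
        self._prepared = True

    def train_files(self) -> None:
        # Each training pass reuses the cached map instead of recomputing it,
        # which is the point of hoisting GetUnusedVars out of TrainFiles().
        assert self._prepared, "call prepare_unused_var() first"
        for op in self.ops:
            _ = self.unused_vars.get(op, [])  # vars safe to free after this op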
17 changes: 9 additions & 8 deletions python/paddle/fluid/executor.py
@@ -1638,8 +1638,12 @@ def _get_real_program_fetch_list():
         dataset._dynamic_adjust_before_train(trainer.proto_desc.thread_num)
 
         trainer_desc = trainer._desc()  # slow, cache
-        ctx = [trainer_desc, dataset, scope, real_fetch_list]
+        trainer_instance = self._default_executor.init_for_dataset(
+            program.desc, trainer_desc, scope, dataset.dataset)
+
+        ctx = [scope, real_fetch_list, trainer_instance]
+        if use_program_cache: self._add_ctx_cache(cache_key, ctx)
+
         return ctx
 
     def _run_pipeline(self,
@@ -1654,20 +1658,17 @@ def _run_pipeline(self,
                       print_period=100,
                       fetch_handler=None,
                       use_program_cache=False):
-        trainer_desc, dataset, scope, real_fetch_list = \
+        scope, real_fetch_list, trainer_instance = \
            self._prepare_pipeline_ctx(program, dataset, scope, thread,
                                       is_infer, debug, fetch_list, fetch_info,
                                       print_period, fetch_handler,
                                       use_program_cache)
 
-        trainer_instance = self._default_executor.init_for_dataset(
-            program.desc, trainer_desc, scope, dataset.dataset)
-
         self._default_executor.run_from_dataset(trainer_instance)
-        self._default_executor.release_trainer(trainer_instance)
 
         dataset._dynamic_adjust_after_train()
         dataset._finish_to_run()
+        if not use_program_cache:
+            self._default_executor.release_trainer(trainer_instance)
 
         if real_fetch_list:
             arr = scope.find_var('fetch').get_fetch_list()
             tensors = arr._move_to_list()
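On the Python side, the core trainer is now created inside _prepare_pipeline_ctx and stored in the cached context when use_program_cache is set, and _run_pipeline only releases the trainer when it is not cached, so later steps reuse the already-initialized trainer instead of paying the init_for_dataset cost again. A simplified sketch of that caching flow, with a stub standing in for the C++ core executor (names here are illustrative assumptions, not Paddle's API):

class PipelineRunnerSketch:
    def __init__(self, core):
        self.core = core            # stand-in for self._default_executor
        self._ctx_cache = {}        # program-keyed cache, like _add_ctx_cache

    def _prepare_ctx(self, key, program, dataset, scope, use_program_cache):
        if use_program_cache and key in self._ctx_cache:
            return self._ctx_cache[key]        # reuse the initialized trainer
        trainer = self.core.init_for_dataset(program, dataset, scope)
        ctx = (scope, trainer)
        if use_program_cache:
            self._ctx_cache[key] = ctx
        return ctx

    def run(self, key, program, dataset, scope, use_program_cache=False):
        scope, trainer = self._prepare_ctx(key, program, dataset, scope,
                                           use_program_cache)
        self.core.run_from_dataset(trainer)
        # A cached trainer stays alive for the next step; an uncached one is
        # released immediately, matching the new `if not use_program_cache`.
        if not use_program_cache:
            self.core.release_trainer(trainer)
        return scope

Under this scheme the first pipeline step pays the trainer-initialization cost and subsequent steps skip it, which is where the "[hybrid performance]" win in the commit title comes from.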
