From f43b31e8f32ea1a5e91dd209d481ae3226b000bb Mon Sep 17 00:00:00 2001
From: yuanlehome
Date: Mon, 28 Nov 2022 04:29:29 +0000
Subject: [PATCH] update

---
 paddle/fluid/framework/naive_executor.cc      |  5 +-
 paddle/fluid/framework/naive_executor.h       |  2 -
 .../fluid/inference/api/analysis_predictor.cc | 58 +++++++++++++++++++
 3 files changed, 61 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc
index e310f30c8cd7e..57e9a175b16f2 100644
--- a/paddle/fluid/framework/naive_executor.cc
+++ b/paddle/fluid/framework/naive_executor.cc
@@ -157,8 +157,8 @@ NaiveExecutor::~NaiveExecutor() {
 #endif
 }
 
-#ifdef PADDLE_WITH_TENSORRT
 void NaiveExecutor::ResetTrtOps(int num) {
+#ifdef PADDLE_WITH_TENSORRT
   for (auto &op : ops_) {
     if (op->Type() == "tensorrt_engine") {
       operators::TensorRTEngineOp *trtop =
@@ -193,7 +193,8 @@ void NaiveExecutor::ResetTrtOps(int num) {
       }
     }
   }
-}
 #endif
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/naive_executor.h b/paddle/fluid/framework/naive_executor.h
index bcc51897e4e48..882f50b451a29 100644
--- a/paddle/fluid/framework/naive_executor.h
+++ b/paddle/fluid/framework/naive_executor.h
@@ -67,9 +67,7 @@ class NaiveExecutor {
 
   Scope* GetScope() { return scope_; }
 
-#ifdef PADDLE_WITH_TENSORRT
   void ResetTrtOps(int num);
-#endif
 
   void RegisterOutputHook(const HookFunc& hookfunc);
 
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index fb5b407a094a5..7c7df22ff64d0 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1921,6 +1921,64 @@ bool AnalysisPredictor::LoadProgramDesc() {
   return true;
 }
 
+bool AnalysisPredictor::LoadParameters() {
+  PADDLE_ENFORCE_NOT_NULL(inference_program_.get(),
+                          platform::errors::PreconditionNotMet(
+                              "The inference program should be loaded first."));
+
+  const auto &global_block = inference_program_->MutableBlock(0);
+
+  // create a temporary program to load parameters.
+
+  std::unique_ptr<framework::ProgramDesc> load_program(
+      new framework::ProgramDesc());
+  framework::BlockDesc *load_block = load_program->MutableBlock(0);
+  std::vector<std::string> params;
+
+  for (auto *var : global_block->AllVars()) {
+    if (IsPersistable(var)) {
+      VLOG(3) << "persistable variable's name: " << var->Name();
+
+      framework::VarDesc *new_var = load_block->Var(var->Name());
+      new_var->SetShape(var->GetShape());
+      new_var->SetDataType(var->GetDataType());
+      new_var->SetType(var->GetType());
+      new_var->SetLoDLevel(var->GetLoDLevel());
+      new_var->SetPersistable(true);
+
+      if (!config_.params_file().empty()) {
+        params.push_back(new_var->Name());
+      } else {
+        // append_op
+        framework::OpDesc *op = load_block->AppendOp();
+        op->SetType("load");
+        op->SetOutput("Out", {new_var->Name()});
+        op->SetAttr("file_path", {config_.model_dir() + "/" + new_var->Name()});
+        op->CheckAttrs();
+      }
+    }
+  }
+
+  if (!config_.params_file().empty()) {
+    // sort paramlist to have consistent ordering
+    std::sort(params.begin(), params.end());
+    // append just the load_combine op
+    framework::OpDesc *op = load_block->AppendOp();
+    op->SetType("load_combine");
+    op->SetOutput("Out", params);
+    op->SetAttr("file_path", {config_.params_file()});
+    op->CheckAttrs();
+  }
+
+  // Use NaiveExecutor to Load parameters.
+  framework::NaiveExecutor e(place_);
+  e.Prepare(scope_.get(), *load_program, 0, false);
+  e.Run();
+  VLOG(3) << "get " << scope_->LocalVarNames().size() << " vars after load";
+
+  return true;
+}
+
 uint64_t AnalysisPredictor::TryShrinkMemory() {
   ClearIntermediateTensor();
   return paddle::memory::Release(place_);