Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome committed Nov 28, 2022
1 parent bc9c8b7 commit f43b31e
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 4 deletions.
5 changes: 3 additions & 2 deletions paddle/fluid/framework/naive_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ NaiveExecutor::~NaiveExecutor() {
#endif
}

#ifdef PADDLE_WITH_TENSORRT
void NaiveExecutor::ResetTrtOps(int num) {
#ifdef PADDLE_WITH_TENSORRT
for (auto &op : ops_) {
if (op->Type() == "tensorrt_engine") {
operators::TensorRTEngineOp *trtop =
Expand Down Expand Up @@ -193,7 +193,8 @@ void NaiveExecutor::ResetTrtOps(int num) {
}
}
}
}
#endif
}

} // namespace framework
} // namespace paddle
2 changes: 0 additions & 2 deletions paddle/fluid/framework/naive_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ class NaiveExecutor {

Scope* GetScope() { return scope_; }

#ifdef PADDLE_WITH_TENSORRT
void ResetTrtOps(int num);
#endif

void RegisterOutputHook(const HookFunc& hookfunc);

Expand Down
58 changes: 58 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1921,6 +1921,64 @@ bool AnalysisPredictor::LoadProgramDesc() {
return true;
}

bool AnalysisPredictor::LoadParameters() {
PADDLE_ENFORCE_NOT_NULL(inference_program_.get(),
platform::errors::PreconditionNotMet(
"The inference program should be loaded first."));

const auto &global_block = inference_program_->MutableBlock(0);

// create a temporary program to load parameters.

std::unique_ptr<framework::ProgramDesc> load_program(
new framework::ProgramDesc());
framework::BlockDesc *load_block = load_program->MutableBlock(0);
std::vector<std::string> params;

for (auto *var : global_block->AllVars()) {
if (IsPersistable(var)) {
VLOG(3) << "persistable variable's name: " << var->Name();

framework::VarDesc *new_var = load_block->Var(var->Name());
new_var->SetShape(var->GetShape());
new_var->SetDataType(var->GetDataType());
new_var->SetType(var->GetType());
new_var->SetLoDLevel(var->GetLoDLevel());
new_var->SetPersistable(true);

if (!config_.params_file().empty()) {
params.push_back(new_var->Name());
} else {
// append_op
framework::OpDesc *op = load_block->AppendOp();
op->SetType("load");
op->SetOutput("Out", {new_var->Name()});
op->SetAttr("file_path", {config_.model_dir() + "/" + new_var->Name()});
op->CheckAttrs();
}
}
}

if (!config_.params_file().empty()) {
// sort paramlist to have consistent ordering
std::sort(params.begin(), params.end());
// append just the load_combine op
framework::OpDesc *op = load_block->AppendOp();
op->SetType("load_combine");
op->SetOutput("Out", params);
op->SetAttr("file_path", {config_.params_file()});
op->CheckAttrs();
}

// Use NaiveExecutor to Load parameters.
framework::NaiveExecutor e(place_);
e.Prepare(scope_.get(), *load_program, 0, false);
e.Run();
VLOG(3) << "get " << scope_->LocalVarNames().size() << " vars after load";

return true;
}

uint64_t AnalysisPredictor::TryShrinkMemory() {
ClearIntermediateTensor();
return paddle::memory::Release(place_);
Expand Down

0 comments on commit f43b31e

Please sign in to comment.