Skip to content

Commit

Permalink
Support control flow for static build [Step 3: support while] (Paddle…
Browse files Browse the repository at this point in the history
…Paddle#57616)

* add conditional_block to OperatorBasesHandledInStaticBuild

* run op in FakeInitializeOutputsForOperatorBase

* add init_success judge

* fix build error

* fix

* add SetSubBlockCore func

* add PreStaticRun func

* add PreStaticRun to interpreter_base and new_ir_inter

* recover codes

* add PreStaticBuild and BlockCanBeStaticBuilt

* fix logic about RunPreStaticBuild

* change CreateOpFromOpDesc type

* fix build error

* fix build error

* remove IsOperatorBasesHandledInStaticBuild

* recover BlockCanBeStaticBuilt

* add logic about conditional_block run static build

* recover codes

* recover BlockCanBeStaticBuilt

* support static build for conditional block op when the conditional block is the last op in the block

* fix error

* fix logic about last op

* fit for sub block can't open static build

* add IsStaticBuild

* fix build error

* fit logic when sub block can't open static build

* close static build when sub_block don't support static_build

* recover third party

* add is_skip_fake_init logic

* set the backend of the lamb

* change start index

* add if conditional for cal is_skip_fake_init

* change name

* close static_build for test_conditional_block

* add static build support for conditional block in the case where the output's dtype/place is changed but the following ops do not use this output

* fix logic error

* fix timeout error

* fix

* remove useless codes

* fix

* fix

* fix build error

* move GetVarsInfo and RunPreStaticBuild from operator to static_build

* fix lamb backend registration

* fix build error

* fix build error

* remove lamb op test from new_ir_op_test_white_list

* fix

* move generating following_input_vars logic to static_build.cc

* remove HasInfo

* fix build error

* recover codes and turn off the flag

* add support for while

* fix
  • Loading branch information
AndSonder authored Sep 23, 2023
1 parent 21a9d41 commit a68d9a9
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 43 deletions.
216 changes: 207 additions & 9 deletions paddle/fluid/framework/new_executor/interpreter/static_build.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,20 @@
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/framework/new_executor/standalone_executor.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/controlflow/control_flow_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/platform/flags.h"

#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

PHI_DECLARE_bool(cache_inference_while_scope);

// These ops are OperatorBase, but they are already handled in static build
std::set<std::string> OperatorBasesHandledInStaticBuild = {"read",
"conditional_block"};
std::set<std::string> OperatorBasesHandledInStaticBuild = {
"read", "conditional_block", "while"};

std::set<std::string> OperatorBasesMustRunInStaticBuild = {
"create_double_buffer_reader", "create_py_reader"};
Expand Down Expand Up @@ -386,9 +391,9 @@ void FakeInitializeTensorBase(const platform::DeviceContext& dev_ctx,
}
}

void RunPreStaticBuild(const framework::Scope& scope,
const platform::Place& dev_place,
const OperatorBase& op) {
void RunConditionalBlockPreStaticBuild(const framework::Scope& scope,
const platform::Place& dev_place,
const OperatorBase& op) {
auto* scope_var = scope.FindVar(op.Output("Scope"));
PADDLE_ENFORCE_NOT_NULL(
scope_var,
Expand Down Expand Up @@ -434,6 +439,193 @@ void RunPreStaticBuild(const framework::Scope& scope,
core->Build({}, &op_func_nodes);
}

// Pre-builds the sub-block of a `while` op for static build.
//
// An InterpreterCore is created over the op's "sub_block" attribute, a step
// scope is prepared under `scope`, and `core->Build` is invoked on it. The
// scope preparation differs by the op's "is_test" attribute:
//   * is_test == false: a new step scope is pushed to Output(StepScopes),
//     eligible Input(X) dense tensors are copied into it under a temporary
//     name, and the original places of Input(X) tensors are restored after
//     the build.
//   * is_test == true: a new scope is created, LoD / tensor-array contents of
//     its local variables are cleared before the build, and the scope is
//     deleted afterwards unless FLAGS_cache_inference_while_scope is set.
//
// scope:     scope holding the while op's inputs and outputs.
// dev_place: place used to fetch the device context and to build the core.
// op:        the `while` operator.
//
// Raises NotFound if Input(Condition) or Output(StepScopes) is missing from
// `scope`, and PreconditionNotMet if the step scopes could not be emptied.
void RunWhileBlockPreStaticBuild(const framework::Scope& scope,
                                 const platform::Place& dev_place,
                                 const OperatorBase& op) {
  PADDLE_ENFORCE_NOT_NULL(
      scope.FindVar(op.Input("Condition")),
      platform::errors::NotFound("Input(Condition) of WhileOp is not found."));

#ifdef PADDLE_WITH_DNNL
  // Executor on being destroyed clears oneDNN cache and resets
  // registered model data layout. This is unwanted for nested
  // Executors (executors declared inside control ops)
  platform::DontClearMKLDNNCache(dev_place);
#endif
  auto* block = op.Attr<framework::BlockDesc*>("sub_block");

  // get device context from pool
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto& dev_ctx = *pool.Get(dev_place);

  const bool is_test = op.Attr<bool>("is_test");

  std::set<std::string> no_copy_var_names;
  if (!is_test) {
    // set all persistable parameters into no_copy_var_names.
    auto* global_block = block;
    while (global_block->ID() != 0) {
      global_block = global_block->ParentBlock();
    }
    for (framework::VarDesc* var : global_block->AllVars()) {
      if (var->IsParameter()) {
        no_copy_var_names.insert(var->Name());
      }
    }

    // Variables that an op in the sub-block both reads and writes are not
    // copied either.
    const std::vector<framework::OpDesc*>& all_ops = block->AllOps();
    for (const framework::OpDesc* item : all_ops) {
      const framework::VariableNameMap& input_var_names = item->Inputs();
      const framework::VariableNameMap& output_var_names = item->Outputs();
      for (const auto& ipt : input_var_names) {
        for (const std::string& var_name : ipt.second) {
          if (operators::StrInVaraiableNameMap(var_name, output_var_names)) {
            no_copy_var_names.insert(var_name);
          }
        }
      }
    }
  }

  // Guard the lookup: dereferencing a missing StepScopes variable would crash.
  auto* step_scopes_var = scope.FindVar(op.Output("StepScopes"));
  PADDLE_ENFORCE_NOT_NULL(
      step_scopes_var,
      platform::errors::NotFound(
          "Output(StepScopes) of WhileOp is not found."));
  auto* step_scopes =
      step_scopes_var->GetMutable<std::vector<framework::Scope*>>();

  if (!step_scopes->empty()) {
    // Wait for pending device work before dropping scopes left over from an
    // earlier run.
    dev_ctx.Wait();
    for (auto& s : *step_scopes) {
      if (scope.HasKid(s)) {
        scope.DeleteScope(s);
      }
    }
    step_scopes->clear();
  }

  PADDLE_ENFORCE_EQ(step_scopes->size(),
                    0,
                    platform::errors::PreconditionNotMet(
                        "The Output(StepScope) of WhileOp should be empty."));

  auto& skip_vars =
      op.Attr<std::vector<std::string>>("skip_eager_deletion_vars");

  // note(lvyongkang): The assign op in while loop may change the place of
  // variable. However, InterpreterCore fixes the kernel of every op during its
  // first run. A cpu tensor may become a gpu tensor after the first run. This
  // will lead to a segmentation fault when it's used in a cpu kernel. Here we
  // record the place of every input and restore their place after the build.
  std::map<std::string, phi::Place> input_var_original_places;
  for (const auto& in_name : op.Inputs("X")) {
    framework::Variable* var = scope.FindVar(in_name);
    if (var == nullptr) {
      VLOG(4) << "[while op]"
              << "input not found:" << in_name;
      // Nothing to back up for a missing input; falling through would
      // dereference a null pointer.
      continue;
    }

    if (var->Type() == framework::proto::VarType::LOD_TENSOR) {
      input_var_original_places[in_name] =
          (var->Get<phi::DenseTensor>()).place();
    } else {
      VLOG(10) << "[while op]"
               << "skip backup input " << in_name << " type:"
               << framework::TransToPhiDataType(
                      framework::ToVarType(var->Type()));
    }
  }

  LOG_FIRST_N(INFO, 1) << "[ControlFlow][WhileOp] New Executor is Running.";

  framework::Scope placeholder;  // Don't care if it's valid, just for
                                 // initialize InterpreterCore
  framework::interpreter::ExecutionConfig execution_config;
  execution_config.create_local_scope = false;
  execution_config.used_for_control_flow_op = true;
  execution_config.skip_gc_vars =
      std::set<std::string>(skip_vars.begin(), skip_vars.end());

  auto core = std::make_unique<framework::InterpreterCore>(
      dev_place, *block, &placeholder, execution_config);

  if (!is_test) {
    auto& current_scope = scope.NewScope();
    step_scopes->push_back(&current_scope);

    // Suffix appended to the temporary copies of Input(X) created below.
    const std::string tmp_copy_suffix = "@TMP_COPY";

    std::vector<std::string> rename_vars;
    for (const std::string& input_var_name : op.Inputs("X")) {
      if (no_copy_var_names.find(input_var_name) != no_copy_var_names.end()) {
        continue;
      }
      framework::Variable* input_var = scope.FindVar(input_var_name);
      // Skip inputs that are missing or are not dense tensors.
      if (input_var == nullptr || !input_var->IsType<phi::DenseTensor>()) {
        continue;
      }
      const std::string input_var_rename = input_var_name + tmp_copy_suffix;
      rename_vars.push_back(input_var_rename);
      const auto& input_var_tensor = input_var->Get<phi::DenseTensor>();
      auto* rename_input_var_tensor =
          current_scope.Var(input_var_rename)->GetMutable<phi::DenseTensor>();
      framework::TensorCopy(
          input_var_tensor, dev_place, rename_input_var_tensor);
      rename_input_var_tensor->set_lod(input_var_tensor.lod());
    }

    operators::BuildScopeForControlFlowOp(*core, *block, &current_scope);
    core->reset_scope(&current_scope);

    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
    core->Build({}, &op_func_nodes);

    // restore inputs place
    for (const auto& n : input_var_original_places) {
      const std::string& in_name = n.first;
      const phi::Place& original_place = n.second;
      // input vars exist in `scope` not `current_scope`
      operators::TransferVariablePlace(
          &scope, in_name, original_place, dev_ctx);
    }

    // Give the temporary copies their original names inside the step scope.
    for (const auto& var_rename : rename_vars) {
      const std::string input_var_name =
          var_rename.substr(0, var_rename.size() - tmp_copy_suffix.size());
      current_scope.Rename(var_rename, input_var_name);
    }
  } else {
    // The scope is created and wired up the same way whether or not
    // FLAGS_cache_inference_while_scope is set; the flag only decides whether
    // the scope is deleted after the build.
    framework::Scope* current_scope = &(scope.NewScope());
    operators::BuildScopeForControlFlowOp(*core, *block, current_scope);
    core->reset_scope(current_scope);

    for (auto& name : current_scope->LocalVarNames()) {
      auto* var = current_scope->Var(name);
      if (var->IsType<phi::DenseTensor>()) {
        // Clear all lod information for all lod_tensors.
        auto* t = var->GetMutable<phi::DenseTensor>();
        framework::LoD empty_lod;
        t->set_lod(empty_lod);
      } else if (var->IsType<framework::LoDTensorArray>()) {
        // Clear elements of all tensor arrays.
        auto* t = var->GetMutable<framework::LoDTensorArray>();
        t->clear();
      }
    }

    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
    core->Build({}, &op_func_nodes);

    if (!FLAGS_cache_inference_while_scope) {
      scope.DeleteScope(current_scope);
    }
  }
}

void FakeInitializeOutputsForOperatorBase(
const OperatorBase& op,
const phi::Place& place,
Expand All @@ -447,7 +639,7 @@ void FakeInitializeOutputsForOperatorBase(
phi::DeviceContext* dev_ctx =
platform::DeviceContextPool::Instance().Get(place);

if (op_type == "conditional_block") {
if (op_type == "conditional_block" || op_type == "while") {
// Note(sonder): skip fake init for conditional_block when there is no
// op with kernel after it.
bool skip_fake_init = true;
Expand All @@ -456,7 +648,7 @@ void FakeInitializeOutputsForOperatorBase(
for (size_t i = 0; i < following_ops.size(); ++i) {
if (dynamic_cast<framework::OperatorWithKernel*>(
following_ops[i].get()) != nullptr) {
VLOG(4) << "Find op with kernel after conditional_block : "
VLOG(4) << "Find op with kernel after " << op_type << ": "
<< following_ops[i]->Type();
skip_fake_init = false;
auto input_vars_info = GetVarsInfo(
Expand All @@ -474,7 +666,12 @@ void FakeInitializeOutputsForOperatorBase(
const std::vector<VarMetaInfo> out_var_info_before_build =
GetVarsInfo(scope, op.Outputs(), op);

RunPreStaticBuild(*scope, place, op);
if (op_type == "conditional_block") {
RunConditionalBlockPreStaticBuild(*scope, place, op);
} else {
RunWhileBlockPreStaticBuild(*scope, place, op);
}

const std::vector<VarMetaInfo> out_var_info_after_build =
GetVarsInfo(scope, op.Outputs(), op);

Expand All @@ -487,10 +684,11 @@ void FakeInitializeOutputsForOperatorBase(
auto var_name = out_var_info_before_build[i].name_;
if (following_input_vars.count(var_name)) {
PADDLE_THROW(phi::errors::PreconditionNotMet(
"The output %s s' dtype/place of conditional_block is "
"The output %s s' dtype/place of %s is "
"changed after static build. Befer static build, the "
"dtype is %s, place is %s. After static "
"build, the dtype is %s, place is %s.",
op_type,
var_name,
out_var_info_before_build[i].dtype_,
out_var_info_before_build[i].place_,
Expand Down
34 changes: 0 additions & 34 deletions paddle/fluid/operators/controlflow/while_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,40 +51,6 @@ static std::string GetSkipEagerDeletionVarsDebugString(
return str;
}

// Moves the dense tensor stored in variable `var_name` to `dst_place`.
//
// The variable is looked up through `scope->FindVar`. The function returns
// without touching anything when the variable is not found, when its type is
// not LOD_TENSOR, or when the tensor already lives on `dst_place`. Otherwise
// the tensor is copied to `dst_place`, `dev_ctx.Wait()` is called, and the
// original tensor takes over the copy's meta and holder.
static void TransferVariablePlace(const framework::Scope *scope,
                                  const std::string &var_name,
                                  const phi::Place &dst_place,
                                  const platform::DeviceContext &dev_ctx) {
  framework::Variable *var = scope->FindVar(var_name);
  if (var == nullptr) {
    VLOG(4) << "[TransferVariablePlace]"
            << "lost in_var: " << var_name;
    return;
  }
  if (var->Type() != framework::proto::VarType::LOD_TENSOR) {
    VLOG(10) << "[TransferVariablePlace]" << var_name << " type changed:"
             << framework::TransToPhiDataType(
                    framework::ToVarType(var->Type()));
    return;
  }
  phi::DenseTensor *t = var->GetMutable<phi::DenseTensor>();
  if (t->place() == dst_place) {
    VLOG(10) << "[TransferVariablePlace]"
             << "no need transfer: " << var_name;
    return;
  }

  // Use a stack temporary for the copy. The previous heap-allocated tensor was
  // never deleted, leaking the object and keeping its holder alive forever.
  phi::DenseTensor new_t;
  framework::TensorCopy(*t, dst_place, &new_t);
  dev_ctx.Wait();

  t->set_meta(new_t.meta());
  t->ResetHolder(new_t.Holder());

  VLOG(4) << "[TransferVariablePlace]" << var_name
          << " place: " << new_t.place();
}

} // namespace

class WhileOp : public framework::OperatorBase {
Expand Down
34 changes: 34 additions & 0 deletions paddle/fluid/operators/controlflow/while_op_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,5 +250,39 @@ bool StrInVaraiableNameMap(const std::string &name,
return false;
}

// Moves the dense tensor stored in variable `var_name` to `dst_place`.
//
// The variable is looked up through `scope->FindVar`. The function returns
// without touching anything when the variable is not found, when its type is
// not LOD_TENSOR, or when the tensor already lives on `dst_place`. Otherwise
// the tensor is copied to `dst_place`, `dev_ctx.Wait()` is called, and the
// original tensor takes over the copy's meta and holder.
void TransferVariablePlace(const framework::Scope *scope,
                           const std::string &var_name,
                           const phi::Place &dst_place,
                           const platform::DeviceContext &dev_ctx) {
  framework::Variable *var = scope->FindVar(var_name);
  if (var == nullptr) {
    VLOG(4) << "[TransferVariablePlace]"
            << "lost in_var: " << var_name;
    return;
  }
  if (var->Type() != framework::proto::VarType::LOD_TENSOR) {
    VLOG(10) << "[TransferVariablePlace]" << var_name << " type changed:"
             << framework::TransToPhiDataType(
                    framework::ToVarType(var->Type()));
    return;
  }
  phi::DenseTensor *t = var->GetMutable<phi::DenseTensor>();
  if (t->place() == dst_place) {
    VLOG(10) << "[TransferVariablePlace]"
             << "no need transfer: " << var_name;
    return;
  }

  // Use a stack temporary for the copy. The previous heap-allocated tensor was
  // never deleted, leaking the object and keeping its holder alive forever.
  phi::DenseTensor new_t;
  framework::TensorCopy(*t, dst_place, &new_t);
  dev_ctx.Wait();

  t->set_meta(new_t.meta());
  t->ResetHolder(new_t.Holder());

  VLOG(4) << "[TransferVariablePlace]" << var_name
          << " place: " << new_t.place();
}

} // namespace operators
} // namespace paddle
6 changes: 6 additions & 0 deletions paddle/fluid/operators/controlflow/while_op_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/controlflow/op_variant.h"

namespace phi {
Expand Down Expand Up @@ -58,5 +59,10 @@ bool GetCondData(const phi::DenseTensor &cond);
bool StrInVaraiableNameMap(const std::string &,
const framework::VariableNameMap &);

// Transfers the dense tensor held by variable `var_name` to `dst_place`.
// The variable is looked up in `scope`. The call does nothing when the
// variable is missing, is not a LOD_TENSOR, or is already on `dst_place`.
// After the copy, `dev_ctx.Wait()` is called before the variable's tensor is
// updated with the copy's meta and holder.
void TransferVariablePlace(const framework::Scope *scope,
const std::string &var_name,
const phi::Place &dst_place,
const platform::DeviceContext &dev_ctx);

} // namespace operators
} // namespace paddle

0 comments on commit a68d9a9

Please sign in to comment.