Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Support if op exe with inplace pass #57949

Merged
merged 15 commits into from
Oct 10, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ void GetInputIds(pir::Operation* op,
"input should in name map, [%d] 'th input of [%s] op",
i,
"if op"));
std::vector<int> inputs_id = GetValueIds(value, value_exec_info);
input_ids->emplace(value, inputs_id);
input_ids->emplace(value, GetValueIds(value, value_exec_info));
}
}
}
Expand All @@ -92,9 +91,7 @@ void GetOutsideOpInputs(
"input should in name map, [%d] 'th input of [%s] op",
i,
op->name()));
std::vector<int> inputs_id = GetValueIds(value, value_exec_info);

input_ids->emplace(value, inputs_id);
input_ids->emplace(value, GetValueIds(value, value_exec_info));
}
}
}
Expand Down Expand Up @@ -181,14 +178,22 @@ CondInstruction::CondInstruction(size_t id,
"input should in name map, [%d] 'th input of [%s] op",
i,
"if op"));
std::vector<int> outputs_id = GetValueIds(value, *value_exec_info);
outputs.emplace(value, outputs_id);
outputs.emplace(value, GetValueIds(value, *value_exec_info));
}
}
SetOutputs(outputs);
VLOG(6) << "finish process inputs outputs index";
}

CondInstruction::~CondInstruction() {
if (true_branch_inter_ != nullptr) {
delete true_branch_inter_;
}
if (false_branch_inter_ != nullptr) {
delete false_branch_inter_;
}
}

void CondInstruction::CopyBranchOutput(
const std::vector<std::string>& var_names, const NewIRInterpreter* inter) {
for (size_t i = 0; i < var_names.size(); ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class CondInstruction : public InstructionBase {
::pir::Operation* op,
ValueExecutionInfo* value_exe_info);

~CondInstruction();

void Run() override;

const std::string& Name() const override { return cond_name_; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <vector>

#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/event.h"
#include "paddle/pir/core/builtin_attribute.h"
Expand Down Expand Up @@ -148,7 +150,8 @@ OpFuncType AnalyseOpFuncType(pir::Operation* op, const platform::Place& place) {

auto& op_attributes = op->attributes();

if ((op->dialect()->name() == "pd_kernel") &&
if ((op->dialect()->name().compare(paddle::dialect::KernelDialect::name()) ==
0) &&
(op_attributes.count("kernel_key") > 0)) {
auto kernel_key = op_attributes.at("kernel_key")
.dyn_cast<dialect::KernelAttribute>()
Expand Down Expand Up @@ -179,7 +182,7 @@ OpFuncType AnalyseOpFuncType(pir::Operation* op, const platform::Place& place) {
return OpFuncType::kGpuSync;
}

if (op_name == "pd_op.shape") {
if (op_name.compare(paddle::dialect::ShapeOp::name()) == 0) {
return OpFuncType::kGpuSync;
}
}
Expand Down
30 changes: 23 additions & 7 deletions paddle/fluid/pir/transforms/inplace_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/trait/inplace.h"
#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h"
Expand Down Expand Up @@ -121,11 +122,10 @@ static std::unordered_set<pir::Value> GetSkipDeletionValues(pir::Block* block) {
// NOTE(zhangbo): For inplace Pass, currently only the kernel_dialect operator
// is supported. Therefore, this function only returns the values in the
// kernel_dialect operator that can be eager deleted.
static std::unordered_map<pir::Operation*, std::unordered_set<pir::Value>>
GetEagerDeletionValues(pir::Block* block) {
std::unordered_set<pir::Value> skip_dels = GetSkipDeletionValues(block);

std::unordered_map<pir::Value, pir::Operation*> del_value_2_op;
static void GetEagerDelValueOfOp(
pir::Block* block,
const std::unordered_set<pir::Value>& skip_dels,
std::unordered_map<pir::Value, pir::Operation*>* del_value_2_op) {
for (auto& op : *block) {
std::string upper_op_name = op->name();
if (op->dialect()->name().compare(paddle::dialect::KernelDialect::name()) ==
Expand All @@ -150,16 +150,32 @@ GetEagerDeletionValues(pir::Block* block) {
VLOG(8) << " -- is no_need_buffer: " << IsNoNeedBuffer(op, input);
continue;
}
del_value_2_op[input] = op;
(*del_value_2_op)[input] = op;
}

for (size_t i = 0; i < op->num_results(); ++i) {
pir::Value output = op->result(i);
if (output && CanBeDeleted(output)) {
del_value_2_op[output] = op;
(*del_value_2_op)[output] = op;
}
}

if (op->isa<paddle::dialect::IfOp>()) {
auto if_op = op->dyn_cast<paddle::dialect::IfOp>();
GetEagerDelValueOfOp(if_op.true_block(), skip_dels, del_value_2_op);
VLOG(8) << "GetEagerDelValueOfOp for IfOp true block";
GetEagerDelValueOfOp(if_op.false_block(), skip_dels, del_value_2_op);
VLOG(8) << "GetEagerDelValueOfOp for IfOp false block";
}
}
}

static std::unordered_map<pir::Operation*, std::unordered_set<pir::Value>>
GetEagerDeletionValues(pir::Block* block) {
std::unordered_set<pir::Value> skip_dels = GetSkipDeletionValues(block);

std::unordered_map<pir::Value, pir::Operation*> del_value_2_op;
GetEagerDelValueOfOp(block, skip_dels, &del_value_2_op);

std::unordered_map<pir::Operation*, std::unordered_set<pir::Value>>
eager_dels;
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def false_func():
np.asarray(ret[1]), np.full((2, 3), True, bool), rtol=1e-05
)

@test_and_compare_with_new_ir()
def test_pass_and_modify_var(self):
"""
pseudocode:
Expand Down