Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Refine and fix pir exe #60443

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,8 @@ IfInstruction::IfInstruction(size_t id,
GetInputIds(op, *value_exec_info, &inputs);
auto true_outside_inputs =
GetExternalInputs(&true_branch_block, *value_exec_info, &inputs);
std::vector<pir::Value> false_outside_inputs;
auto& false_branch_block = if_op.false_block();
false_outside_inputs =
auto false_outside_inputs =
GetExternalInputs(&false_branch_block, *value_exec_info, &inputs);
// NOTE(chenxi67): the variable corresponding to a container value is of
// <VariableRefArray> Type. It will recursively get the ID of internal
Expand Down Expand Up @@ -107,9 +106,14 @@ IfInstruction::IfInstruction(size_t id,
}
}
InsertTuplePushContinerToOuts(&true_branch_block, *value_exec_info, &outputs);

InsertTuplePushContinerToOuts(
&if_op.false_block(), *value_exec_info, &outputs);

InsertInplacedExternalInputsToOuts(
&true_branch_block, true_outside_inputs, *value_exec_info, &outputs);
InsertInplacedExternalInputsToOuts(
&false_branch_block, false_outside_inputs, *value_exec_info, &outputs);

for (auto& item : outputs) {
auto& var_vec = item.second;
for (auto it = var_vec.begin(); it != var_vec.end();) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,33 +47,25 @@ WhileInstruction::WhileInstruction(
ValueExecutionInfo* parent_exe_info,
interpreter::ExecutionConfig execution_config)
: InstructionBase(id, place) {
op_ = op;
VLOG(6) << "finish process dist attributes";

SetKernelType(AnalyseOpFuncType(op, place));
VLOG(6) << "finish process analyse kernel type";

VLOG(6) << "finish process inputs outputs index";

PADDLE_ENFORCE(op->isa<paddle::dialect::WhileOp>(),
phi::errors::PreconditionNotMet(
"While instruction only support While op"));

op_ = op;
auto while_op = op->dyn_cast<paddle::dialect::WhileOp>();
body_block_ = &while_op.body();

cond_var_ = parent_exe_info->GetVarByValue(while_op.operand_source(0));
SetKernelType(AnalyseOpFuncType(op, place));
VLOG(6) << "finish process analyse kernel type";

cond_var_ = parent_exe_info->GetVarByValue(while_op.operand_source(0));
for (size_t i = 1; i < while_op.num_operands(); ++i) {
inputs_.push_back(
parent_exe_info->GetVarByValue(while_op.operand_source(i)));
}

for (size_t i = 0; i < while_op.num_results(); ++i) {
outputs_.push_back(parent_exe_info->GetVarByValue(while_op.result(i)));
}

body_block_ = &while_op.body();

std::unordered_map<pir::Value, std::vector<int>> inputs;
GetInputIds(op, *parent_exe_info, &inputs);
auto body_outside_inputs =
Expand All @@ -94,8 +86,10 @@ WhileInstruction::WhileInstruction(
std::vector<int> outputs_id = GetValueIds(value, *parent_exe_info);
outputs.emplace(value, outputs_id);
}
InsertTuplePushContinerToOuts(body_block_, *parent_exe_info, &outputs);
}
InsertTuplePushContinerToOuts(body_block_, *parent_exe_info, &outputs);
InsertInplacedExternalInputsToOuts(
body_block_, body_outside_inputs, *parent_exe_info, &outputs);
SetOutputs(outputs);

Scope* body_scope = &(parent_exe_info->GetScope()->NewScope());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,55 @@ void InsertTuplePushContinerToOuts(
}
}

// For every op in `block` that is flagged as inplace, find each result whose
// inplaced input value originates outside the block (i.e. appears in
// `external_inputs`) and record that result in `outputs`, mapped to the
// variable ids of the external input it aliases.
void InsertInplacedExternalInputsToOuts(
    pir::Block* block,
    const std::vector<pir::Value>& external_inputs,
    const ValueExecutionInfo& value_exec_info,
    std::unordered_map<pir::Value, std::vector<int>>* outputs) {
  for (auto& inner_op : *block) {
    const auto& attrs = inner_op.attributes();
    // Only ops explicitly marked "is_inplace" are of interest; skip the rest.
    auto inplace_iter = attrs.find("is_inplace");
    if (inplace_iter == attrs.end() ||
        !inplace_iter->second.dyn_cast<pir::BoolAttribute>().data()) {
      continue;
    }

    // The canonical op name may be overridden by an "op_name" attribute
    // (e.g. for wrapped/legacy kernels); fall back to the op's own name.
    std::string op_name = inner_op.name();
    auto name_iter = attrs.find("op_name");
    if (name_iter != attrs.end()) {
      op_name = name_iter->second.dyn_cast<pir::StrAttribute>().AsString();
    }

    // Parse the op's YAML info to learn its output->input inplace mapping.
    pir::OpInfo op_info =
        pir::IrContext::Instance()->GetRegisteredOpInfo(op_name);
    paddle::dialect::OpYamlInfoParser yaml_parser(
        op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>()
            ->get_op_info_(op_name),
        paddle::dialect::IsLegacyOp(op_name));

    for (size_t idx = 0; idx < inner_op.num_results(); ++idx) {
      pir::Value result = inner_op.result(idx);
      if (!IsInvalid(result)) {
        VLOG(8) << "Number " << idx << " result of " << op_name
                << " is not invalid, so skip build a variable.";
        continue;
      }
      std::string result_name = yaml_parser.OutputNames()[idx];
      if (!yaml_parser.HasInplace(result_name)) {
        continue;
      }
      // Resolve the operand this result is inplaced onto.
      const std::string& inplace_name = yaml_parser.InplaceName(result_name);
      pir::Value inplace_value =
          inner_op.operand_source(yaml_parser.InputName2Id().at(inplace_name));
      // Only record the result if the aliased input comes from outside the
      // block; internal values are already handled elsewhere.
      bool aliases_external = std::find(external_inputs.begin(),
                                        external_inputs.end(),
                                        inplace_value) != external_inputs.end();
      if (aliases_external) {
        outputs->emplace(result, GetValueIds(inplace_value, value_exec_info));
      }
    }
  }
}

bool GetCondData(const phi::DenseTensor& cond) {
if (paddle::platform::is_cpu_place(cond.place())) {
return cond.data<bool>()[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ void InsertTuplePushContinerToOuts(
const ValueExecutionInfo& value_exec_info,
std::unordered_map<pir::Value, std::vector<int>>* outputs);

void InsertInplacedExternalInputsToOuts(
pir::Block* block,
const std::vector<pir::Value>& external_inputs,
const ValueExecutionInfo& value_exec_info,
std::unordered_map<pir::Value, std::vector<int>>* outputs);

bool GetCondData(const phi::DenseTensor& cond);
} // namespace framework
} // namespace paddle
30 changes: 1 addition & 29 deletions test/legacy_test/test_while_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import numpy
from utils import compare_legacy_with_pt

import paddle
from paddle import base, set_flags
from paddle import base
from paddle.base import core
from paddle.base.backward import append_backward
from paddle.base.executor import Executor
Expand Down Expand Up @@ -82,7 +81,6 @@ def simple_net(self):
loss = paddle.mean(sum_result)
return loss, sum_result

# TODO(winter-wang): Support pir test in (FLAGS_enable_pir_in_executor_trace_run = False && FLAGS_new_executor_serial_run == False).
@test_with_pir_api
def test_simple_net(self):
main_program = base.Program()
Expand All @@ -92,14 +90,6 @@ def test_simple_net(self):

append_backward(loss)

if in_pir_mode():
flag_1 = "FLAGS_enable_pir_in_executor_trace_run"
flag_2 = "FLAGS_new_executor_serial_run"
os.environ[flag_1] = 'True'
os.environ[flag_2] = 'True'
set_flags({flag_1: True})
set_flags({flag_2: True})

cpu = core.CPUPlace()
exe = Executor(cpu)
d = []
Expand All @@ -111,14 +101,8 @@ def test_simple_net(self):
feed={'d0': d[0], 'd1': d[1], 'd2': d[2]},
fetch_list=[sum_result],
)
if in_pir_mode():
del os.environ[flag_1]
del os.environ[flag_2]
set_flags({flag_1: False})
set_flags({flag_2: False})
self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01)

# TODO(winter-wang): Support pir test in (FLAGS_enable_pir_in_executor_trace_run = False && FLAGS_new_executor_serial_run == False).
@test_with_pir_api
def test_simple_net_forward(self):
main_program = base.Program()
Expand All @@ -136,20 +120,8 @@ def test_simple_net_forward(self):
for i in range(3):
d.append(numpy.random.random(size=[10]).astype('float32'))

if in_pir_mode():
flag_1 = "FLAGS_enable_pir_in_executor_trace_run"
flag_2 = "FLAGS_new_executor_serial_run"
os.environ[flag_1] = 'True'
os.environ[flag_2] = 'True'
set_flags({flag_1: True})
set_flags({flag_2: True})
for _ in range(2):
exe.run(binary, feed={'d0': d[0], 'd1': d[1], 'd2': d[2]})
if in_pir_mode():
del os.environ[flag_1]
del os.environ[flag_2]
set_flags({flag_1: False})
set_flags({flag_2: False})

@compare_legacy_with_pt
@test_with_pir_api
Expand Down