[Inference] Support constant_folding_pass on PIR #58753

Merged (9 commits) on Nov 9, 2023
14 changes: 12 additions & 2 deletions paddle/fluid/inference/api/CMakeLists.txt
@@ -64,8 +64,18 @@ if(WIN32)
target_link_libraries(paddle_inference_api phi)
endif()

set(inference_deps ${analysis_deps} paddle_inference_api analysis
analysis_config naive_executor ${GLOB_PASS_LIB})
set(PIR_PASS_DEPS
pd_constant_folding_pass dead_code_elimination_pass pd_op_to_kernel_pass
pd_inplace_pass replace_fetch_with_shadow_output_pass)

set(inference_deps
${analysis_deps}
paddle_inference_api
analysis
analysis_config
naive_executor
${GLOB_PASS_LIB}
${PIR_PASS_DEPS})

if(WITH_GPU AND TENSORRT_FOUND)
set(inference_deps ${inference_deps} tensorrt_engine tensorrt_converter)
4 changes: 3 additions & 1 deletion paddle/fluid/inference/api/analysis_predictor.cc
@@ -103,6 +103,7 @@
#endif

#include "paddle/fluid/ir_adaptor/translator/translate.h"
#include "paddle/fluid/pir/transforms/constant_folding_pass.h"
#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h"
#include "paddle/fluid/pir/transforms/inplace_pass.h"
#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
@@ -731,10 +732,11 @@ bool AnalysisPredictor::PrepareExecutor() {
paddle::TranslateLegacyProgramToProgram(*inference_program_));

::pir::PassManager pm(::pir::IrContext::Instance(), 2);
pm.AddPass(::pir::CreateConstantFoldingPass(place_, sub_scope_));
pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass());
pm.AddPass(::pir::CreateDeadCodeEliminationPass());

pm.EnableIRPrinting();
// pm.EnableIRPrinting();
pm.Run(pir_program_.get());

pir_program_ = std::move(
174 changes: 85 additions & 89 deletions paddle/fluid/pir/transforms/constant_folding_pass.cc
@@ -17,20 +17,21 @@
#include <memory>
#include <string>
#include <unordered_map>

// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/pir/dialect/CMakeLists.txt.
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include <vector>

#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"

#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/core/op_result.h"
@@ -46,21 +47,32 @@ namespace {

class ConstantFoldingPattern : public pir::RewritePattern {
public:
ConstantFoldingPattern(pir::IrContext* context,
paddle::framework::Scope* scope,
pir::PatternBenefit benefit = 1,
const std::vector<std::string>& generated_names = {})
: RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names),
scope_(scope) {}
ConstantFoldingPattern(
pir::IrContext* context,
size_t* suffix,
const phi::Place& place,
paddle::framework::Scope* scope,
paddle::framework::interpreter::ExecutionConfig* exe_config,
std::vector<std::string>* deleted_vars)
: RewritePattern(MatchAnyOpTypeTag(),
1 /*benefit*/,
context,
{} /*generated_names*/),
counter_(suffix),
place_(place),
scope_(scope),
exe_config_(exe_config),
deleted_vars_(deleted_vars) {
exe_config_->create_local_scope = false;
}

bool Match(pir::Operation* op) const override {
// TODO(liuyuanle): Use trait to improve robustness.
if (op->isa<pir::GetParameterOp>() || op->isa<pir::SetParameterOp>() ||
op->isa<paddle::dialect::FetchOp>() ||
op->isa<paddle::dialect::ShadowOutputOp>())
op->isa<pir::ShadowOutputOp>() || op->isa<paddle::dialect::FetchOp>() ||
op->isa<paddle::dialect::FeedOp>())
return false;

// Inputs must come from get parameter op.
// inputs must come from get parameter op
for (uint32_t i = 0; i < op->num_operands(); ++i)
if (!pir::GetDefiningOpForInput(op, i)->isa<pir::GetParameterOp>())
return false;
@@ -69,73 +81,34 @@ class ConstantFoldingPattern : public pir::RewritePattern {

void Rewrite(pir::Operation* op,
pir::PatternRewriter& rewriter) const override { // NOLINT
pir::Program* program = op->GetParentProgram();
auto temp_program = BuildProgramFromOperation(op);

std::vector<std::string> fetch_var_names;
auto block = temp_program->block();
for (auto it = block->begin(); it != block->end(); ++it) {
if ((*it)->isa<paddle::dialect::FetchOp>()) {
size_t index = (*it)
->attributes()
.at("col")
.dyn_cast<pir::Int32Attribute>()
.data();

if (fetch_var_names.size() < index + 1) {
fetch_var_names.resize(index + 1);
}

fetch_var_names[index] = (*it)
->attributes()
.at("name")
.dyn_cast<pir::StrAttribute>()
.AsString() +
"@fetch";
}
}
VLOG(4) << "constant_folding_pass applies on [" << op->name() << "] op";
pir::Program new_program(ir_context());
auto output_var_name = BuildProgramFromOperation(op, &new_program);

// Execute program
exe_config_.create_local_scope = false;
// execute program
exe_config_->skip_gc_vars.insert(output_var_name);
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(temp_program.get());
paddle::framework::InterpreterCore core(phi::CPUPlace{},
fetch_var_names,
kernel_program->block(),
scope_,
exe_config_);

paddle::framework::FetchList fetch_list = core.Run({});

// TODO(liuyuanle): Support multiple output.
auto out_tensor = PADDLE_GET_CONST(phi::DenseTensor, fetch_list[0]);
std::unique_ptr<pir::Parameter> parameter =
std::make_unique<pir::Parameter>(
reinterpret_cast<void*>(out_tensor.data()),
out_tensor.numel() * phi::SizeOf(out_tensor.dtype()),
op->result(0).type());

std::string param_name =
"@constant_folding_pass@_" + std::to_string(suffix_++);
exe_config_.skip_gc_vars.insert(param_name);

auto* param_var = scope_->Var(param_name);
auto* param_tensor = param_var->GetMutable<phi::DenseTensor>();
*param_tensor = out_tensor;
program->SetParameter(param_name, std::move(parameter));
// rewriter.SetInsertionPoint(op);
auto get_parameter_op =
rewriter.Build<pir::GetParameterOp>(param_name, op->result(0).type());
paddle::dialect::PdOpLowerToKernelPass(&new_program, place_);
paddle::framework::InterpreterCore core(
place_, {}, kernel_program->block(), scope_, *exe_config_);

core.Run({});

// TODO(liuyuanle): support multiple output
auto get_parameter_op = rewriter.Build<pir::GetParameterOp>(
output_var_name, op->result(0).type());
get_parameter_op->set_attribute(
kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)}));

VLOG(4) << "constant_folding_pass applied on [" << op->name() << "] op";
rewriter.ReplaceAllUsesWith(op->result(0), get_parameter_op->result(0));
rewriter.EraseOp(op);
}

private:
std::unique_ptr<pir::Program> BuildProgramFromOperation(
pir::Operation* op) const {
auto program = std::make_unique<pir::Program>(ir_context());
pir::Builder builder = pir::Builder(ir_context(), program->block());
std::string BuildProgramFromOperation(pir::Operation* op,
pir::Program* new_program) const {
pir::Builder builder = pir::Builder(ir_context(), new_program->block());

// prepare op inputs
std::vector<pir::Value> op_inputs;
@@ -146,15 +119,13 @@ class ConstantFoldingPattern : public pir::RewritePattern {
phi::errors::InvalidArgument(
"Op's input must be a dense tensor type."));

auto [param_name, param] =
pir::GetParameterFromValue(op->operand_source(i));
program->SetParameter(param_name,
std::make_unique<pir::Parameter>(*param));

const auto& param_name =
pir::GetParameterNameFromValue(op->operand_source(i));
auto* param_var = scope_->FindVar(param_name);
PADDLE_ENFORCE_NOT_NULL(
param_var,
phi::errors::InvalidArgument("Parameter var not in scope."));
deleted_vars_->push_back(param_name);
Contributor: Do we need to check here whether `param` is still used by any other op before deleting it?

Contributor (Author): done, thanks!

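For readers following the thread: a minimal sketch of the guard being discussed, presumably addressed in one of the later commits of this PR. It assumes `pir::Value` exposes a `use_count()` accessor (the PR already relies on `use_empty()` in the dead-code-elimination pass, so some use-tracking API exists); this is not the committed fix.

    // Sketch only: skip scope deletion when the parameter value feeding
    // this op still has other consumers.
    pir::Value input = op->operand_source(i);
    if (input.use_count() == 1) {  // assumes pir::Value::use_count() exists
      // This op is the sole user, so the backing variable can be safely
      // erased from the scope once the pass finishes.
      deleted_vars_->push_back(param_name);
    }
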
auto get_parameter_op = builder.Build<pir::GetParameterOp>(
param_name, op->operand_source(i).type());
@@ -170,60 +141,85 @@
auto* temp_op =
builder.Build(op_inputs, op->attributes(), output_types, op->info());

// TODO(liuyuanle): Support multiple output.
// TODO(liuyuanle): support multiple output
// for (uint32_t i = 0; i < op->num_results(); i++) {
PADDLE_ENFORCE_EQ(
temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true,
phi::errors::InvalidArgument(
"Op's output must be a dense tensor type."));

builder.Build<paddle::dialect::FetchOp>(
temp_op->result(0), "fetch_" + std::to_string(suffix_++), 0);
std::stringstream ss;
ss << std::chrono::high_resolution_clock::now().time_since_epoch().count();
std::string output_var_name =
"constant_folding@_" + ss.str() + std::to_string((*counter_)++);

builder.Build<pir::ShadowOutputOp>(temp_op->result(0), output_var_name);
// }

return program;
return output_var_name;
}

private:
size_t* counter_{nullptr};
phi::Place place_;
paddle::framework::Scope* scope_{nullptr};
inline static size_t suffix_{0};
inline static paddle::framework::interpreter::ExecutionConfig exe_config_{};
paddle::framework::interpreter::ExecutionConfig* exe_config_{nullptr};
std::vector<std::string>* deleted_vars_{nullptr};
};

class ConstantFoldingPass : public pir::Pass {
public:
ConstantFoldingPass() : pir::Pass("constant_folding_pass", 1) {}
ConstantFoldingPass(const phi::Place& place, paddle::framework::Scope* scope)
: pir::Pass("constant_folding_pass", 1), place_(place), scope_(scope) {
PADDLE_ENFORCE_NOT_NULL(
scope_, phi::errors::InvalidArgument("scope can not be nullptr"));
}

private:
bool Initialize(pir::IrContext* context) override {
pir::RewritePatternSet ps(context);
ps.Add<ConstantFoldingPattern>(context, &scope_);
ps.Add<ConstantFoldingPattern>(
context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);
patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
return true;
}

void Run(pir::Operation* op) override {
size_t op_nums = op->GetParentProgram()->block()->size();
pir::GreedyRewriteConfig cfg;
cfg.use_top_down_traversal = true;
cfg.max_iterations = 10;
pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);

// delete old parameter var
scope_->EraseVars(deleted_vars_);
LOG(INFO) << " ------ constant_folding_pass done: [" << counter_ << "/"
<< op_nums << "]";
}

bool CanApplyOn(pir::Operation* op) const override {
// TODO(liuyuanle): remove op->isa<::pir::ModuleOp>()
return op->isa<::pir::ModuleOp>() && op->num_regions() > 0;
}

private:
size_t counter_{0};
phi::Place place_;
paddle::framework::Scope* scope_{nullptr};
paddle::framework::interpreter::ExecutionConfig exe_config_{};
std::vector<std::string> deleted_vars_;

pir::FrozenRewritePatternSet patterns_;
paddle::framework::Scope scope_;
};

} // namespace

namespace pir {

std::unique_ptr<Pass> CreateConstantFoldingPass() {
return std::make_unique<ConstantFoldingPass>();
std::unique_ptr<Pass> CreateConstantFoldingPass(
const phi::Place& place, paddle::framework::Scope* scope) {
return std::make_unique<ConstantFoldingPass>(place, scope);
}

} // namespace pir
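
Taken together, the pattern's effect on the IR is easiest to see as a before/after sketch (hand-written; the value and parameter names are made up):

    // Before: every operand of %2 is produced by a GetParameterOp, so Match()
    // fires and Rewrite() executes the op once via InterpreterCore.
    //   %0 = pir::GetParameterOp("conv1_w")
    //   %1 = pir::GetParameterOp("conv1_b")
    //   %2 = pd_op.add(%0, %1)
    //
    // The pattern clones %2 into a one-op temp program, appends a
    // pir::ShadowOutputOp named "constant_folding@_<timestamp><counter>",
    // lowers it with PdOpLowerToKernelPass, and runs it against scope_.
    //
    // After: the arithmetic is gone; its result lives in the scope and is
    // read back as an ordinary parameter.
    //   %2 = pir::GetParameterOp("constant_folding@_<timestamp><counter>")
    //
    // "conv1_w"/"conv1_b" land in deleted_vars_ and are erased from the scope
    // at the end of ConstantFoldingPass::Run() (subject to the use-count
    // caveat raised in the review thread above).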
10 changes: 9 additions & 1 deletion paddle/fluid/pir/transforms/constant_folding_pass.h
@@ -15,12 +15,20 @@
#pragma once

#include <memory>
#include "paddle/phi/common/place.h"
#include "paddle/pir/core/dll_decl.h"

namespace paddle {
namespace framework {
class Scope;
}
} // namespace paddle

namespace pir {

class Pass;

IR_API std::unique_ptr<Pass> CreateConstantFoldingPass();
IR_API std::unique_ptr<Pass> CreateConstantFoldingPass(
const phi::Place& place, paddle::framework::Scope* scope);

} // namespace pir
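
A caller-side sketch of the new factory signature, mirroring the `PrepareExecutor` wiring earlier in this diff. The function and variable names are illustrative, and the `PassManager` include path is assumed from the repo layout at the time:

    #include "paddle/fluid/pir/transforms/constant_folding_pass.h"
    #include "paddle/pir/pass/pass_manager.h"  // assumed path

    // Run constant folding over a PIR program. The pass now needs a concrete
    // place to execute the folded sub-programs on and a live scope that holds
    // the parameter tensors (and receives the folded results).
    void RunConstantFolding(pir::Program* program,
                            const phi::Place& place,
                            paddle::framework::Scope* scope) {
      pir::PassManager pm(pir::IrContext::Instance(), /*opt_level=*/2);
      pm.AddPass(pir::CreateConstantFoldingPass(place, scope));
      pm.Run(program);
    }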
7 changes: 3 additions & 4 deletions paddle/fluid/pir/transforms/dead_code_elimination_pass.cc
@@ -35,16 +35,15 @@ class DeadCodeEliminationPattern : public pir::RewritePattern {
}

bool Match(pir::Operation* op) const override {
if (op->isa<paddle::dialect::FetchOp>() ||
op->isa<paddle::dialect::ShadowOutputOp>())
if (op->isa<paddle::dialect::FetchOp>() || op->isa<pir::ShadowOutputOp>()) {
return false;

}
return op->use_empty();
}

void Rewrite(pir::Operation* op,
pir::PatternRewriter& rewriter) const override { // NOLINT
if (op->dyn_cast<pir::GetParameterOp>()) {
if (op->isa<pir::GetParameterOp>()) {
// Delete parameter from program.
pir::GetParameterOp get_parameter_op =
op->dyn_cast<pir::GetParameterOp>();
8 changes: 2 additions & 6 deletions paddle/fluid/pir/transforms/transform_general_functions.cc
@@ -22,8 +22,7 @@

namespace pir {

std::pair<std::string, pir::Parameter*> GetParameterFromValue(
pir::Value value) {
std::string GetParameterNameFromValue(pir::Value value) {
pir::GetParameterOp op =
value.dyn_cast<OpResult>().owner()->dyn_cast<pir::GetParameterOp>();
PADDLE_ENFORCE_NOT_NULL(
@@ -37,10 +36,7 @@ std::pair<std::string, pir::Parameter*> GetParameterFromValue(
.at(op.attributes_name[0])
.dyn_cast<pir::StrAttribute>()
.AsString();
pir::Parameter* param = program->GetParameter(name);
PADDLE_ENFORCE_NOT_NULL(
param, phi::errors::InvalidArgument("Parameter should not be null."));
return {name, param};
return name;
}

const phi::DDim& GetShapeFromValue(pir::Value value) {
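
With the `pir::Parameter*` half of the return value dropped, call sites resolve the tensor through the executor scope rather than through `Program::GetParameter`. A minimal sketch of the new call shape, mirroring the constant-folding pattern above:

    // Resolve the parameter's variable via the scope, not the program.
    const std::string param_name =
        pir::GetParameterNameFromValue(op->operand_source(i));
    auto* param_var = scope_->FindVar(param_name);
    PADDLE_ENFORCE_NOT_NULL(
        param_var, phi::errors::InvalidArgument("Parameter var not in scope."));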