[Inference] Support constant_folding_pass on PIR #58753
Merged
Changes from 3 commits

Commits (9):
b6d78ef  fix shadow_output_op  (yuanlehome)
e7cd7be  update  (yuanlehome)
a30a22a  rewrite constant_folding_pass  (yuanlehome)
1185d56  update  (yuanlehome)
700ebc7  update  (yuanlehome)
44c1af0  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…  (yuanlehome)
a6900b3  fix compile  (yuanlehome)
400762d  fix dce  (yuanlehome)
b4e4a1c  enhance judgement  (yuanlehome)
@@ -17,20 +17,21 @@
#include <memory>
#include <string>
#include <unordered_map>

// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/pir/dialect/CMakeLists.txt.
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include <vector>

#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"

#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/core/op_result.h"

@@ -46,21 +47,32 @@ namespace {

class ConstantFoldingPattern : public pir::RewritePattern {
 public:
  ConstantFoldingPattern(pir::IrContext* context,
                         paddle::framework::Scope* scope,
                         pir::PatternBenefit benefit = 1,
                         const std::vector<std::string>& generated_names = {})
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names),
        scope_(scope) {}
  ConstantFoldingPattern(
      pir::IrContext* context,
      size_t* suffix,
      const phi::Place& place,
      paddle::framework::Scope* scope,
      paddle::framework::interpreter::ExecutionConfig* exe_config,
      std::vector<std::string>* deleted_vars)
      : RewritePattern(MatchAnyOpTypeTag(),
                       1 /*benefit*/,
                       context,
                       {} /*generated_names*/),
        counter_(suffix),
        place_(place),
        scope_(scope),
        exe_config_(exe_config),
        deleted_vars_(deleted_vars) {
    exe_config_->create_local_scope = false;
  }

  bool Match(pir::Operation* op) const override {
    // TODO(liuyuanle): Use trait to improve robustness.
    if (op->isa<pir::GetParameterOp>() || op->isa<pir::SetParameterOp>() ||
        op->isa<paddle::dialect::FetchOp>() ||
        op->isa<paddle::dialect::ShadowOutputOp>())
        op->isa<pir::ShadowOutputOp>() || op->isa<paddle::dialect::FetchOp>() ||
        op->isa<paddle::dialect::FeedOp>())
      return false;

    // Inputs must come from get parameter op.
    // inputs must come from get parameter op
    for (uint32_t i = 0; i < op->num_operands(); ++i)
      if (!pir::GetDefiningOpForInput(op, i)->isa<pir::GetParameterOp>())
        return false;

@@ -69,73 +81,36 @@ class ConstantFoldingPattern : public pir::RewritePattern {

  void Rewrite(pir::Operation* op,
               pir::PatternRewriter& rewriter) const override {  // NOLINT
    pir::Program* program = op->GetParentProgram();
    auto temp_program = BuildProgramFromOperation(op);

    std::vector<std::string> fetch_var_names;
    auto block = temp_program->block();
    for (auto it = block->begin(); it != block->end(); ++it) {
      if ((*it)->isa<paddle::dialect::FetchOp>()) {
        size_t index = (*it)
                           ->attributes()
                           .at("col")
                           .dyn_cast<pir::Int32Attribute>()
                           .data();

        if (fetch_var_names.size() < index + 1) {
          fetch_var_names.resize(index + 1);
        }

        fetch_var_names[index] = (*it)
                                     ->attributes()
                                     .at("name")
                                     .dyn_cast<pir::StrAttribute>()
                                     .AsString() +
                                 "@fetch";
      }
    }
    VLOG(4) << "constant_folding_pass applys on [" << op->name() << "] op";
    pir::Program new_program(ir_context());
    auto output_var_name = BuildProgramFromOperation(op, &new_program);

    // Execute program
    exe_config_.create_local_scope = false;
    // execute program
    exe_config_->skip_gc_vars.insert(output_var_name);
    auto kernel_program =
        paddle::dialect::PdOpLowerToKernelPass(temp_program.get());
    paddle::framework::InterpreterCore core(phi::CPUPlace{},
                                            fetch_var_names,
                                            kernel_program->block(),
                                            scope_,
                                            exe_config_);

    paddle::framework::FetchList fetch_list = core.Run({});

    // TODO(liuyuanle): Support multiple output.
    auto out_tensor = PADDLE_GET_CONST(phi::DenseTensor, fetch_list[0]);
    std::unique_ptr<pir::Parameter> parameter =
        std::make_unique<pir::Parameter>(
            reinterpret_cast<void*>(out_tensor.data()),
            out_tensor.numel() * phi::SizeOf(out_tensor.dtype()),
            op->result(0).type());

    std::string param_name =
        "@constant_folding_pass@_" + std::to_string(suffix_++);
    exe_config_.skip_gc_vars.insert(param_name);

    auto* param_var = scope_->Var(param_name);
    auto* param_tensor = param_var->GetMutable<phi::DenseTensor>();
    *param_tensor = out_tensor;
    program->SetParameter(param_name, std::move(parameter));
    // rewriter.SetInsertionPoint(op);
    auto get_parameter_op =
        rewriter.Build<pir::GetParameterOp>(param_name, op->result(0).type());
        paddle::dialect::PdOpLowerToKernelPass(&new_program, place_);
    paddle::framework::InterpreterCore core(
        place_, {}, kernel_program->block(), scope_, *exe_config_);

    core.Run({});

    // TODO(liuyuanle): support multiple output
    auto get_parameter_op = rewriter.Build<pir::GetParameterOp>(
        output_var_name, op->result(0).type());
    get_parameter_op->set_attribute(
        kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)}));

    VLOG(4) << "constant_folding_pass applied on [" << op->name() << "] op";
    rewriter.ReplaceAllUsesWith(op->result(0), get_parameter_op->result(0));
    rewriter.EraseOp(op);
  }

 private:
  std::unique_ptr<pir::Program> BuildProgramFromOperation(
      pir::Operation* op) const {
    auto program = std::make_unique<pir::Program>(ir_context());
    pir::Builder builder = pir::Builder(ir_context(), program->block());
  std::string BuildProgramFromOperation(pir::Operation* op,
                                        pir::Program* new_program) const {
    pir::Builder builder = pir::Builder(ir_context(), new_program->block());
    std::string output_var_name =
        "constant_folding@" + std::to_string((*counter_)++);

    // prepare op inputs
    std::vector<pir::Value> op_inputs;

@@ -146,15 +121,14 @@ class ConstantFoldingPattern : public pir::RewritePattern {
          phi::errors::InvalidArgument(
              "Op's input must be a dense tensor type."));

      auto [param_name, param] =
          pir::GetParameterFromValue(op->operand_source(i));
      program->SetParameter(param_name,
                            std::make_unique<pir::Parameter>(*param));

      const auto& param_name =
          pir::GetParameterNameFromValue(op->operand_source(i));
      auto* param_var = scope_->FindVar(param_name);
      PADDLE_ENFORCE_NOT_NULL(
          param_var,
          phi::errors::InvalidArgument("Parameter var not in scope."));
      output_var_name = output_var_name + "_" + param_name;
      deleted_vars_->push_back(param_name);
Inline review comment: Do we need to check here whether the param is still used by other ops?
Reply: done, thanks!
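For illustration only, this is roughly the kind of guard the reviewer is asking about; the use_count() call on pir::Value is an assumption on my part, and the actual fix landed in a later commit may look different:

// Hypothetical sketch: only schedule the parameter var for deletion when this
// op is the sole user of the value produced by its GetParameterOp.
// Note: another GetParameterOp elsewhere could still load the same parameter,
// so this check is only an approximation.
if (op->operand_source(i).use_count() == 1) {
  deleted_vars_->push_back(param_name);
}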

      auto get_parameter_op = builder.Build<pir::GetParameterOp>(
          param_name, op->operand_source(i).type());

@@ -170,60 +144,80 @@ class ConstantFoldingPattern : public pir::RewritePattern {
    auto* temp_op =
        builder.Build(op_inputs, op->attributes(), output_types, op->info());

    // TODO(liuyuanle): Support multiple output.
    // TODO(liuyuanle): support multiple output
    // for (uint32_t i = 0; i < op->num_results(); i++) {
    PADDLE_ENFORCE_EQ(
        temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(),
        true,
        phi::errors::InvalidArgument(
            "Op's output must be a dense tensor type."));

    builder.Build<paddle::dialect::FetchOp>(
        temp_op->result(0), "fetch_" + std::to_string(suffix_++), 0);
    builder.Build<pir::ShadowOutputOp>(temp_op->result(0), output_var_name);
    // }

    return program;
    return output_var_name;
  }

 private:
  size_t* counter_{nullptr};
  phi::Place place_;
  paddle::framework::Scope* scope_{nullptr};
  inline static size_t suffix_{0};
  inline static paddle::framework::interpreter::ExecutionConfig exe_config_{};
  paddle::framework::interpreter::ExecutionConfig* exe_config_{nullptr};
  std::vector<std::string>* deleted_vars_{nullptr};
};

class ConstantFoldingPass : public pir::Pass {
 public:
  ConstantFoldingPass() : pir::Pass("constant_folding_pass", 1) {}
  ConstantFoldingPass(const phi::Place& place, paddle::framework::Scope* scope)
      : pir::Pass("constant_folding_pass", 1), place_(place), scope_(scope) {
    PADDLE_ENFORCE_NOT_NULL(
        scope_, phi::errors::InvalidArgument("scope can not be nullptr"));
  }

 private:
  bool Initialize(pir::IrContext* context) override {
    pir::RewritePatternSet ps(context);
    ps.Add<ConstantFoldingPattern>(context, &scope_);
    ps.Add<ConstantFoldingPattern>(
        context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);
    patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
    return true;
  }

  void Run(pir::Operation* op) override {
    size_t op_nums = op->GetParentProgram()->block()->size();
    pir::GreedyRewriteConfig cfg;
    cfg.use_top_down_traversal = true;
    cfg.max_iterations = 10;
    pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);

    // delete old parameter var
    scope_->EraseVars(deleted_vars_);
    LOG(INFO) << " ------ constant_folding_pass done: [" << counter_ << "/"
              << op_nums << "]";
  }

  bool CanApplyOn(pir::Operation* op) const override {
    // TODO(liuyuanle): remove op->isa<::pir::ModuleOp>()
    return op->isa<::pir::ModuleOp>() && op->num_regions() > 0;
  }

 private:
  size_t counter_{0};
  phi::Place place_;
  paddle::framework::Scope* scope_{nullptr};
  paddle::framework::interpreter::ExecutionConfig exe_config_{};
  std::vector<std::string> deleted_vars_;

  pir::FrozenRewritePatternSet patterns_;
  paddle::framework::Scope scope_;
};

}  // namespace

namespace pir {

std::unique_ptr<Pass> CreateConstantFoldingPass() {
  return std::make_unique<ConstantFoldingPass>();
std::unique_ptr<Pass> CreateConstantFoldingPass(
    const phi::Place& place, paddle::framework::Scope* scope) {
  return std::make_unique<ConstantFoldingPass>(place, scope);
}

}  // namespace pir
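For context, a rough sketch of how the updated factory might be wired into a PIR pass pipeline. The header paths, the ApplyConstantFolding helper, and the PassManager usage are assumptions based on other PIR code of the same era; they are not part of this diff:

#include "paddle/fluid/pir/transforms/constant_folding_pass.h"  // assumed header
#include "paddle/pir/pass/pass_manager.h"                        // assumed header

// Illustrative helper: the pass now needs the execution place and the scope
// that holds the parameter tensors, matching the new
// CreateConstantFoldingPass(place, scope) signature introduced in this PR.
void ApplyConstantFolding(pir::Program* program,
                          const phi::Place& place,
                          paddle::framework::Scope* scope) {
  pir::PassManager pm(pir::IrContext::Instance());
  pm.AddPass(pir::CreateConstantFoldingPass(place, scope));
  pm.Run(program);
}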
Review comment: Using the counter alone as the unique identifier for the output var name doesn't feel very safe.

Reply: There is a further step below that appends the names of the folded-away parameters to it:
![image](https://private-user-images.githubusercontent.com/23653004/281304297-46f57e5b-458c-41b7-877f-4459580462dd.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MzkwMzI4MTIsIm5iZiI6MTczOTAzMjUxMiwicGF0aCI6Ii8yMzY1MzAwNC8yODEzMDQyOTctNDZmNTdlNWItNDU4Yy00MWI3LTg3N2YtNDQ1OTU4MDQ2MmRkLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMDglMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjA4VDE2MzUxMlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTFiYjBmMTk5N2E0Y2UwY2E1ZjJjNWZiMGU3MTFlMzQ5ZmZiNWIwMWM2YTFmMWUxOTc2NzdiNDJhZmI0ZWViNzQmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.l6V1_ZoQdzUNP4_Ge-tTG0pk8qN8hbF1bbdsIE3b86E)

Reply: That reminds me of a potential problem here: as the folding depth grows, the concatenated name keeps getting longer. I'll fix that.
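One possible way to keep the name bounded, sketched purely as an illustration (hashing the concatenated parameter names instead of appending them verbatim); the FoldedVarName helper is hypothetical and the actual follow-up fix may differ:

#include <functional>
#include <string>

// Illustrative helper (not from the PR): derive a fixed-width suffix from the
// folded parameter names so the output var name no longer grows with folding
// depth, while still reflecting which parameters were folded.
std::string FoldedVarName(size_t counter, const std::string& joined_param_names) {
  size_t digest = std::hash<std::string>{}(joined_param_names);
  return "constant_folding@" + std::to_string(counter) + "_" +
         std::to_string(digest);
}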