[PIR][Inference] open pir amp pass and fix some bugs #67822

Merged: 25 commits, Oct 17, 2024

Changes from all commits
9a62e91
register for auto_mixed_precision_pass.cc
aooxin Jul 18, 2024
06f8cbe
modified: paddle/fluid/pir/transforms/general/auto_mixed_precision_…
aooxin Aug 28, 2024
1af5328
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into develop
aooxin Aug 29, 2024
950f526
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into develop
aooxin Sep 9, 2024
7b3d121
fix some bugs for amp pass
aooxin Sep 9, 2024
1af63c4
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into develop
aooxin Sep 9, 2024
7c24382
fix conflict for cse and amp pass
aooxin Sep 9, 2024
faa032a
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into develop
aooxin Sep 10, 2024
bc8263f
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into develop
aooxin Sep 11, 2024
e458f8c
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into develop
aooxin Sep 23, 2024
efe2c0e
fix amp pass for en_table_structure,mask_rcnn_r50_1x_coco,rec_mtb_nrt…
aooxin Sep 23, 2024
2f875ed
refine pir amp_pass
yuanlehome Sep 24, 2024
24bf9be
update
yuanlehome Sep 24, 2024
b955fd4
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into develop
aooxin Sep 26, 2024
5fbd930
Merge branch 'develop' of github.com:aooxin/Paddle into develop
aooxin Sep 26, 2024
380142e
fix amp bugs about SetOptimizationLevel(3)
aooxin Sep 26, 2024
01bc0b2
add some config settings for AutoMixedPrecisionPass
aooxin Sep 27, 2024
cfdcc2a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aooxin Sep 27, 2024
7193616
fix fused_bias_residual_layernorm op bugs for amp pass.
aooxin Sep 29, 2024
a6fddaf
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aooxin Sep 29, 2024
7281951
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aooxin Oct 8, 2024
430a3d6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aooxin Oct 9, 2024
0eea6d0
delete repeat func.
aooxin Oct 9, 2024
656bb2f
code style
aooxin Oct 9, 2024
ad916bf
fix code style
yuanlehome Oct 9, 2024
28 changes: 25 additions & 3 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -115,6 +115,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/pir/serialize_deserialize/include/interface.h"
#include "paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.h"
#include "paddle/fluid/pir/transforms/general/common_subexpression_elimination_pass.h"
#include "paddle/fluid/pir/transforms/general/constant_folding_pass.h"
#include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h"
@@ -922,6 +923,7 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
// set attr
for (const auto &pass : pass_pm.passes()) {
pass->SetNotOwned(pir::Pass::kParamScopeAttr, sub_scope_);
pass->SetNotOwned(pir::Pass::kPlaceAttr, &place_);
if (pass->name() == "matmul_add_act_fuse_pass" ||
pass->name() == "conv2d_add_act_fuse_pass" ||
pass->name() == "conv2d_add_fuse_pass") {
@@ -960,6 +962,28 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
config_.deleted_passes_.end()) {
basic_pass_pm.AddPass(std::move(common_subexpression_elimination_pass));
}
if (config_.enable_gpu_mixed_) {
auto auto_mixed_precision_pass = ::pir::CreateAutoMixedPrecisionPass();
if (std::find(config_.deleted_passes_.begin(),
config_.deleted_passes_.end(),
auto_mixed_precision_pass->name()) ==
config_.deleted_passes_.end()) {
auto_mixed_precision_pass->SetNotOwned(pir::Pass::kPlaceAttr, &place_);
auto_mixed_precision_pass->Set("__mixed_precision_mode__",
new phi::DataType(paddle::ConvertPrecision(
config_.mixed_precision_mode_)));
auto_mixed_precision_pass->Set(
"__enable_low_precision_io__",
new bool(config_.enable_low_precision_io_));
auto_mixed_precision_pass->Set(
"__mixed_black_list__",
new std::unordered_set<std::string>(config_.mixed_black_list_));
auto_mixed_precision_pass->Set(
"__mixed_white_list__",
new std::unordered_set<std::string>(config_.mixed_white_list_));
basic_pass_pm.AddPass(std::move(auto_mixed_precision_pass));
}
}
auto params_sync_among_devices_pass =
::pir::CreateParamsSyncAmongDevicesPass();
if (std::find(config_.deleted_passes_.begin(),
@@ -2227,9 +2251,7 @@ void AnalysisPredictor::PrepareArgument() {
pass_builder->AppendPass("simplify_with_basic_ops_pass");
pass_builder->AppendPass("is_test_pass");
pass_builder->AppendPass("constant_folding_pass");
}
pass_builder->AppendPass("auto_mixed_precision_pass");
if (!config_.new_ir_enabled()) {
pass_builder->AppendPass("auto_mixed_precision_pass");
pass_builder->AppendPass("inplace_op_var_pass");
}
LOG(INFO) << "This model run in GPU mixed precision mode with no ir "
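For reference, the predictor-side wiring above is normally driven from the user-facing inference Config. Below is a minimal sketch, assuming the paddle_infer C++ API (EnableUseGpu with a precision argument, EnableLowPrecisionIO, and Exp_DisableMixedPrecisionOps; exact names and availability may vary by release), with a hypothetical model path:

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("./model_dir");  // hypothetical model path
  // Request FP16 mixed precision on GPU; with PIR enabled, this is the
  // setting that routes through the auto_mixed_precision_pass added above.
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/500, /*device_id=*/0,
                      paddle_infer::PrecisionType::kHalf);
  config.EnableLowPrecisionIO(true);  // feeds __enable_low_precision_io__
  // Ops listed here are kept in FP32 (feeds __mixed_black_list__).
  config.Exp_DisableMixedPrecisionOps({"softmax"});
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}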
83 changes: 74 additions & 9 deletions paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.cc
@@ -25,6 +25,7 @@
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
@@ -47,6 +48,7 @@
#include "paddle/pir/include/core/parameter.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pass/pass_registry.h"
#include "paddle/pir/include/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/pir/include/pattern_rewrite/pattern_match.h"
#include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h"
@@ -61,8 +63,6 @@ class AutoMixedPrecisionPass : public pir::Pass {
precision_mode_(phi::DataType::FLOAT16),
enable_low_precision_io_(false),
context_(nullptr),
black_list_(),
white_list_(),
op_run_low_precision_(),
op_should_not_handle_(),
cached_cast_ops_() {}
@@ -84,10 +84,39 @@ class AutoMixedPrecisionPass : public pir::Pass {
"required!"
"Use Set method to set the scope attribute."));

PADDLE_ENFORCE_EQ(Has("__enable_low_precision_io__"),
true,
common::errors::InvalidArgument(
"Pass initialize failed. "
"When using AutoMixedPrecisionPass, the "
"enable_low_precision_io attribute is required! "
"Use the Set method to set the "
"enable_low_precision_io attribute."));

PADDLE_ENFORCE_EQ(
Has("__mixed_black_list__"),
true,
common::errors::InvalidArgument(
"Pass initialize failed. "
"When using AutoMixedPrecisionPass, the mixed_black_list "
"attribute is required! "
"Use the Set method to set the mixed_black_list attribute."));

PADDLE_ENFORCE_EQ(
Has("__mixed_white_list__"),
true,
common::errors::InvalidArgument(
"Pass initialize failed. "
"When using AutoMixedPrecisionPass, the mixed_white_list "
"attribute is required! "
"Use the Set method to set the mixed_white_list attribute."));

place_ = Get<phi::Place>(pir::Pass::kPlaceAttr);
precision_mode_ = Get<phi::DataType>("__mixed_precision_mode__");
context_ = context;
enable_low_precision_io_ = false;
enable_low_precision_io_ = Get<bool>("__enable_low_precision_io__");
black_list_ = Get<std::unordered_set<std::string>>("__mixed_black_list__");
white_list_ = Get<std::unordered_set<std::string>>("__mixed_white_list__");
SetDefaultBlacklist();
return true;
}
@@ -102,6 +131,7 @@ class AutoMixedPrecisionPass : public pir::Pass {
ProcessBlock(&block, builder);
}
}
cached_cast_ops_.clear();
}

bool CanApplyOn(pir::Operation* op) const override {
@@ -134,6 +164,7 @@ class AutoMixedPrecisionPass : public pir::Pass {
paddle::dialect::SumOp::name(),
paddle::dialect::SigmoidCrossEntropyWithLogitsOp::name(),
paddle::dialect::CrossEntropyWithSoftmax_Op::name(),
paddle::dialect::ArrayToTensorOp::name(),
});
}

@@ -164,6 +195,10 @@ class AutoMixedPrecisionPass : public pir::Pass {
auto backend = ConvertPlaceToBackend(place_);
support_low_precision =
OpSupportPrecision(op_type, backend, precision_mode_);
if (op->isa<paddle::dialect::ScaleOp>() && !OpHasFloatResult(op)) {
support_low_precision = false;
op_should_not_handle_.insert(op);
}
} else { // pd op without float result
support_low_precision = false;
op_should_not_handle_.insert(op);
@@ -436,8 +471,10 @@ class AutoMixedPrecisionPass : public pir::Pass {
if (result.type().isa<paddle::dialect::DenseTensorType>() &&
IsDenseTensorTypeFloat(
result.type().dyn_cast<paddle::dialect::DenseTensorType>())) {
return true;
} else if (result.type().isa<pir::VectorType>() &&
IsVectorTypeFloat(result.type().dyn_cast<pir::VectorType>())) {
return true;
}
}
return false;
@@ -480,6 +517,9 @@ class AutoMixedPrecisionPass : public pir::Pass {
return operand.type() &&
operand.type().isa<paddle::dialect::DenseTensorType>();
}
bool IsOperandHasDenseTensorVectorType(pir::OpOperand operand) const {
return operand.type() && operand.type().isa<pir::VectorType>();
}

void DoInsertCastOp(pir::Operation* op,
pir::OpOperand operand,
@@ -585,7 +625,6 @@ class AutoMixedPrecisionPass : public pir::Pass {
SetResultDataType(op->result(0), precision_mode_, builder.ir_context());
return;
}

// Other pd ops
if (OpRunLowPrecision(op)) {
auto phi_kernel =
@@ -658,11 +697,35 @@ class AutoMixedPrecisionPass : public pir::Pass {
auto phi_dtype = phi::DataType::FLOAT32;
for (size_t i = 0; i < op->num_operands(); i++) {
auto operand = op->operand(i);
if (!IsOperandHasDenseTensorType(operand)) continue;
auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand);
if (IsPhiDataTypeFloat(operand_phi_dtype) &&
operand_phi_dtype == precision_mode_) {
DoInsertCastOp(op, operand, phi_dtype, builder);
if (IsOperandHasDenseTensorType(operand)) {
auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand);
if (IsPhiDataTypeFloat(operand_phi_dtype) &&
operand_phi_dtype == precision_mode_) {
DoInsertCastOp(op, operand, phi_dtype, builder);
}
} else if (IsOperandHasDenseTensorVectorType(operand)) {
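// This operand is a VectorType (typically produced by a builtin.combine op):
// cast any low-precision float inputs of the combine back to FP32, then
// rebuild the combine's VectorType result so it matches the new input types.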
auto defining_op_ = operand.source().defining_op();
if (defining_op_->isa<pir::CombineOp>()) {
auto input_num = defining_op_->num_operands();
for (size_t i = 0; i < input_num; ++i) {
auto operand = defining_op_->operand(i);
auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand);
if (IsPhiDataTypeFloat(operand_phi_dtype) &&
operand_phi_dtype != phi::DataType::FLOAT32) {
DoInsertCastOp(
defining_op_, operand, phi::DataType::FLOAT32, builder);
}
}
std::vector<pir::Type> inputs_type(input_num);
for (size_t idx = 0; idx < input_num; ++idx) {
inputs_type[idx] = defining_op_->operand(idx).type();
}
auto new_vec_type =
pir::VectorType::get(builder.ir_context(), inputs_type);
defining_op_->result(0).set_type(new_vec_type);
}
} else {
continue;
}
}
}
@@ -677,3 +740,5 @@ std::unique_ptr<Pass> CreateAutoMixedPrecisionPass() {
}

} // namespace pir

REGISTER_IR_PASS(auto_mixed_precision_pass, AutoMixedPrecisionPass);
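
Because Initialize() now hard-requires the four attributes above, running the registered pass standalone means setting them before Run(). A minimal sketch under those assumptions (the attribute strings mirror what Initialize() checks; the surrounding program/place setup is hypothetical):

#include <unordered_set>
#include "paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.h"
#include "paddle/pir/include/core/ir_context.h"
#include "paddle/pir/include/pass/pass_manager.h"

void RunAutoMixedPrecision(pir::Program* program, phi::Place* place) {
  auto pass = pir::CreateAutoMixedPrecisionPass();
  // kPlaceAttr is set not-owned, so *place must outlive pm.Run().
  pass->SetNotOwned(pir::Pass::kPlaceAttr, place);
  pass->Set("__mixed_precision_mode__",
            new phi::DataType(phi::DataType::FLOAT16));
  pass->Set("__enable_low_precision_io__", new bool(false));
  pass->Set("__mixed_black_list__", new std::unordered_set<std::string>());
  pass->Set("__mixed_white_list__", new std::unordered_set<std::string>());

  pir::PassManager pm(pir::IrContext::Instance());
  pm.AddPass(std::move(pass));
  pm.Run(program);
}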
1 change: 1 addition & 0 deletions paddle/fluid/pir/transforms/passes.h
@@ -46,6 +46,7 @@ USE_PIR_PASS(delete_weight_dequant_linear_op_pass);
USE_PIR_PASS(delete_quant_dequant_linear_op_pass);
USE_PIR_PASS(transfer_layout_pass);
USE_PIR_PASS(fused_rotary_position_embedding_pass);
USE_PIR_PASS(auto_mixed_precision_pass);
USE_PIR_PASS(horizontal_fuse_pass);
USE_PIR_PASS(common_subexpression_elimination_pass);
USE_PIR_PASS(add_shadow_output_after_dead_parameter_pass);
6 changes: 6 additions & 0 deletions paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu
@@ -1226,6 +1226,8 @@ PD_REGISTER_KERNEL(fused_bias_residual_layernorm,
float,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
@@ -1237,6 +1239,8 @@ PD_REGISTER_KERNEL(fused_bias_residual_layernorm,
phi::fusion::FusedLayerNormKernel,
float,
phi::dtype::float16) {
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
@@ -1249,6 +1253,8 @@ PD_REGISTER_KERNEL(fused_bias_residual_layernorm,
phi::fusion::FusedLayerNormKernel,
float,
phi::dtype::float16) {
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
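The registration change above is what the amp pass relies on for fused_bias_residual_layernorm: inputs 3 and 4 (the layer-norm scale and bias, going by the kernel's argument order) are pinned to FP32, so they stay in full precision even when the kernel itself is selected for FP16/BF16 execution.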