[PIR][Inference] open pir amp pass and fix some bugs (#67822)
* register for auto_mixed_precision_pass.cc

* modified:   paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.cc

* fix some bugs for amp pass

* fix conflict for cse and amp pass

* fix amp pass for en_table_structure, mask_rcnn_r50_1x_coco, rec_mtb_nrtr in auto_mixed_precision_pass.cc

* refine pir amp_pass

* update

* fix amp bugs about SetOptimizationLevel(3)

* add some config settings for AutoMixedPrecisionPass

* fix fused_bias_residual_layernorm op bugs for amp pass.

* delete repeat func.

* code style

* fix code style

---------

Co-authored-by: yuanlehome <yuanlehome@163.com>
aooxin and yuanlehome authored Oct 17, 2024
1 parent 47b9213 commit b01f759
Showing 4 changed files with 106 additions and 12 deletions.
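As context for the pipeline changes below, here is a minimal usage sketch (not part of the commit) showing how a caller reaches this pass: with PIR enabled, asking the inference Config for GPU half precision is what makes enable_gpu_mixed_ true in AnalysisPredictor. The model paths are hypothetical, and the availability of EnableNewIR and of the precision argument to EnableUseGpu depends on the Paddle release.

#include "paddle_inference_api.h"  // header path varies by install layout

int main() {
  // Hypothetical model files; PIR programs are typically saved as .json.
  paddle_infer::Config config("model.json", "model.pdiparams");
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/256, /*gpu_id=*/0,
                      paddle_infer::PrecisionType::kHalf);
  config.EnableNewIR(true);  // run the PIR pass pipeline, including AMP
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}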
28 changes: 25 additions & 3 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -115,6 +115,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/pir/serialize_deserialize/include/interface.h"
#include "paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.h"
#include "paddle/fluid/pir/transforms/general/common_subexpression_elimination_pass.h"
#include "paddle/fluid/pir/transforms/general/constant_folding_pass.h"
#include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h"
@@ -940,6 +941,7 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
// set attr
for (const auto &pass : pass_pm.passes()) {
pass->SetNotOwned(pir::Pass::kParamScopeAttr, sub_scope_);
pass->SetNotOwned(pir::Pass::kPlaceAttr, &place_);
if (pass->name() == "matmul_add_act_fuse_pass" ||
pass->name() == "conv2d_add_act_fuse_pass" ||
pass->name() == "conv2d_add_fuse_pass") {
@@ -978,6 +980,28 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
config_.deleted_passes_.end()) {
basic_pass_pm.AddPass(std::move(common_subexpression_elimination_pass));
}
if (config_.enable_gpu_mixed_) {
auto auto_mixed_precision_pass = ::pir::CreateAutoMixedPrecisionPass();
if (std::find(config_.deleted_passes_.begin(),
config_.deleted_passes_.end(),
auto_mixed_precision_pass->name()) ==
config_.deleted_passes_.end()) {
auto_mixed_precision_pass->SetNotOwned(pir::Pass::kPlaceAttr, &place_);
auto_mixed_precision_pass->Set("__mixed_precision_mode__",
new phi::DataType(paddle::ConvertPrecision(
config_.mixed_precision_mode_)));
auto_mixed_precision_pass->Set(
"__enable_low_precision_io__",
new bool(config_.enable_low_precision_io_));
auto_mixed_precision_pass->Set(
"__mixed_black_list__",
new std::unordered_set<std::string>(config_.mixed_black_list_));
auto_mixed_precision_pass->Set(
"__mixed_white_list__",
new std::unordered_set<std::string>(config_.mixed_white_list_));
basic_pass_pm.AddPass(std::move(auto_mixed_precision_pass));
}
}
auto params_sync_among_devices_pass =
::pir::CreateParamsSyncAmongDevicesPass();
if (std::find(config_.deleted_passes_.begin(),
@@ -2245,9 +2269,7 @@ void AnalysisPredictor::PrepareArgument() {
pass_builder->AppendPass("simplify_with_basic_ops_pass");
pass_builder->AppendPass("is_test_pass");
pass_builder->AppendPass("constant_folding_pass");
}
pass_builder->AppendPass("auto_mixed_precision_pass");
if (!config_.new_ir_enabled()) {
pass_builder->AppendPass("auto_mixed_precision_pass");
pass_builder->AppendPass("inplace_op_var_pass");
}
LOG(INFO) << "This model run in GPU mixed precision mode with no ir "
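For orientation, the hunk above follows the standard PIR pass-attribute pattern. Below is a standalone sketch of the same contract outside AnalysisPredictor; the helper RunAmpPass is hypothetical, while the attribute keys and the Set/SetNotOwned calls mirror the diff.

#include <memory>
#include <string>
#include <unordered_set>

#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/pass/pass_manager.h"

// Hypothetical helper: scope and place come from the caller, as in
// AnalysisPredictor above.
void RunAmpPass(pir::Program* program,
                pir::IrContext* ctx,
                paddle::framework::Scope* scope,
                phi::Place* place) {
  pir::PassManager pm(ctx);
  auto amp_pass = pir::CreateAutoMixedPrecisionPass();
  // Initialize() (next file) enforces that every one of these is present.
  amp_pass->SetNotOwned(pir::Pass::kParamScopeAttr, scope);
  amp_pass->SetNotOwned(pir::Pass::kPlaceAttr, place);
  amp_pass->Set("__mixed_precision_mode__",
                new phi::DataType(phi::DataType::FLOAT16));
  amp_pass->Set("__enable_low_precision_io__", new bool(false));
  amp_pass->Set("__mixed_black_list__", new std::unordered_set<std::string>());
  amp_pass->Set("__mixed_white_list__", new std::unordered_set<std::string>());
  pm.AddPass(std::move(amp_pass));
  pm.Run(program);
}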
83 changes: 74 additions & 9 deletions paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.cc
@@ -25,6 +25,7 @@
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
@@ -47,6 +48,7 @@
#include "paddle/pir/include/core/parameter.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pass/pass_registry.h"
#include "paddle/pir/include/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/pir/include/pattern_rewrite/pattern_match.h"
#include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h"
@@ -61,8 +63,6 @@ class AutoMixedPrecisionPass : public pir::Pass {
precision_mode_(phi::DataType::FLOAT16),
enable_low_precision_io_(false),
context_(nullptr),
-  black_list_(),
-  white_list_(),
op_run_low_precision_(),
op_should_not_handle_(),
cached_cast_ops_() {}
@@ -84,10 +84,39 @@ class AutoMixedPrecisionPass : public pir::Pass {
"required!"
"Use Set method to set the scope attribute."));

PADDLE_ENFORCE_EQ(Has("__enable_low_precision_io__"),
true,
common::errors::InvalidArgument(
"Pass initialize failed."
"When using AutoMixedPrecisionPass, "
"enable_low_precision_io attribute is "
"required!"
"Use Set method to set the scope attribute."));

PADDLE_ENFORCE_EQ(
Has("__mixed_black_list__"),
true,
common::errors::InvalidArgument(
"Pass initialize failed."
"When using AutoMixedPrecisionPass, mixed_black_list attribute is "
"required!"
"Use Set method to set the scope attribute."));

PADDLE_ENFORCE_EQ(
Has("__mixed_white_list__"),
true,
common::errors::InvalidArgument(
"Pass initialize failed."
"When using AutoMixedPrecisionPass, mixed_white_list attribute is "
"required!"
"Use Set method to set the scope attribute."));

place_ = Get<phi::Place>(pir::Pass::kPlaceAttr);
precision_mode_ = Get<phi::DataType>("__mixed_precision_mode__");
context_ = context;
-  enable_low_precision_io_ = false;
+  enable_low_precision_io_ = Get<bool>("__enable_low_precision_io__");
black_list_ = Get<std::unordered_set<std::string>>("__mixed_black_list__");
white_list_ = Get<std::unordered_set<std::string>>("__mixed_white_list__");
SetDefaultBlacklist();
return true;
}
@@ -102,6 +131,7 @@ class AutoMixedPrecisionPass : public pir::Pass {
ProcessBlock(&block, builder);
}
}
cached_cast_ops_.clear();
}

bool CanApplyOn(pir::Operation* op) const override {
@@ -134,6 +164,7 @@
paddle::dialect::SumOp::name(),
paddle::dialect::SigmoidCrossEntropyWithLogitsOp::name(),
paddle::dialect::CrossEntropyWithSoftmax_Op::name(),
paddle::dialect::ArrayToTensorOp::name(),
});
}

@@ -164,6 +195,10 @@
auto backend = ConvertPlaceToBackend(place_);
support_low_precision =
OpSupportPrecision(op_type, backend, precision_mode_);
if (op->isa<paddle::dialect::ScaleOp>() && !OpHasFloatResult(op)) {
support_low_precision = false;
op_should_not_handle_.insert(op);
}
} else { // pd op without float result
support_low_precision = false;
op_should_not_handle_.insert(op);
@@ -436,8 +471,10 @@
if (result.type().isa<paddle::dialect::DenseTensorType>() &&
IsDenseTensorTypeFloat(
result.type().dyn_cast<paddle::dialect::DenseTensorType>())) {
return true;
} else if (result.type().isa<pir::VectorType>() &&
IsVectorTypeFloat(result.type().dyn_cast<pir::VectorType>())) {
return true;
}
}
return false;
@@ -480,6 +517,9 @@
return operand.type() &&
operand.type().isa<paddle::dialect::DenseTensorType>();
}
bool IsOperandHasDenseTensorVectorType(pir::OpOperand operand) const {
return operand.type() && operand.type().isa<pir::VectorType>();
}

void DoInsertCastOp(pir::Operation* op,
pir::OpOperand operand,
@@ -585,7 +625,6 @@
SetResultDataType(op->result(0), precision_mode_, builder.ir_context());
return;
}

// Other pd ops
if (OpRunLowPrecision(op)) {
auto phi_kernel =
@@ -658,11 +697,35 @@
auto phi_dtype = phi::DataType::FLOAT32;
for (size_t i = 0; i < op->num_operands(); i++) {
auto operand = op->operand(i);
-  if (!IsOperandHasDenseTensorType(operand)) continue;
-  auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand);
-  if (IsPhiDataTypeFloat(operand_phi_dtype) &&
-      operand_phi_dtype == precision_mode_) {
-    DoInsertCastOp(op, operand, phi_dtype, builder);
-  }
if (IsOperandHasDenseTensorType(operand)) {
auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand);
if (IsPhiDataTypeFloat(operand_phi_dtype) &&
operand_phi_dtype == precision_mode_) {
DoInsertCastOp(op, operand, phi_dtype, builder);
}
} else if (IsOperandHasDenseTensorVectorType(operand)) {
auto defining_op_ = operand.source().defining_op();
if (defining_op_->isa<pir::CombineOp>()) {
auto input_num = defining_op_->num_operands();
for (size_t i = 0; i < input_num; ++i) {
auto operand = defining_op_->operand(i);
auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand);
if (IsPhiDataTypeFloat(operand_phi_dtype) &&
operand_phi_dtype != phi::DataType::FLOAT32) {
DoInsertCastOp(
defining_op_, operand, phi::DataType::FLOAT32, builder);
}
}
std::vector<pir::Type> inputs_type(input_num);
for (size_t idx = 0; idx < input_num; ++idx) {
inputs_type[idx] = defining_op_->operand(idx).type();
}
auto new_vec_type =
pir::VectorType::get(builder.ir_context(), inputs_type);
defining_op_->result(0).set_type(new_vec_type);
}
} else {
continue;
}
}
}
Expand All @@ -677,3 +740,5 @@ std::unique_ptr<Pass> CreateAutoMixedPrecisionPass() {
}

} // namespace pir

REGISTER_IR_PASS(auto_mixed_precision_pass, AutoMixedPrecisionPass);
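Background note: the support checks in this file (OpSupportPrecision, and the phi kernel lookup near the "Other pd ops" branch) ultimately reduce to whether a low-precision kernel is registered for the op. A hedged sketch of such a lookup follows; it assumes phi's KernelFactory::HasKernel has this signature, and the pass's actual query may differ.

#include <string>

#include "paddle/phi/core/kernel_factory.h"

// Returns true if kernel_name has a float16 GPU registration. Assumes
// KernelFactory::HasKernel(name, key); verify against your Paddle tree.
bool SupportsFp16OnGpu(const std::string& kernel_name) {
  phi::KernelKey key(phi::Backend::GPU,
                     phi::DataLayout::ALL_LAYOUT,
                     phi::DataType::FLOAT16);
  return phi::KernelFactory::Instance().HasKernel(kernel_name, key);
}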
1 change: 1 addition & 0 deletions paddle/fluid/pir/transforms/passes.h
@@ -46,6 +46,7 @@ USE_PIR_PASS(delete_weight_dequant_linear_op_pass);
USE_PIR_PASS(delete_quant_dequant_linear_op_pass);
USE_PIR_PASS(transfer_layout_pass);
USE_PIR_PASS(fused_rotary_position_embedding_pass);
USE_PIR_PASS(auto_mixed_precision_pass);
USE_PIR_PASS(horizontal_fuse_pass);
USE_PIR_PASS(auto_layout_simplify_pass);
USE_PIR_PASS(auto_layout_pass);
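The USE_PIR_PASS declaration pairs with the REGISTER_IR_PASS call added in the previous file: registration records a factory for the pass in the global registry, and the USE_ macro keeps the linker from discarding the object file that performs that registration. A sketch of looking the pass up by name afterwards, assuming the PassRegistry API declared in pass_registry.h:

#include <memory>

#include "paddle/pir/include/pass/pass_registry.h"

// Assumes pir::PassRegistry exposes Instance() and Get(name); Get returns a
// fresh instance built from the factory that REGISTER_IR_PASS recorded.
std::unique_ptr<pir::Pass> amp =
    pir::PassRegistry::Instance().Get("auto_mixed_precision_pass");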
6 changes: 6 additions & 0 deletions paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu
@@ -1226,6 +1226,8 @@ PD_REGISTER_KERNEL(fused_bias_residual_layernorm,
float,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
@@ -1237,6 +1239,8 @@
phi::fusion::FusedLayerNormKernel,
float,
phi::dtype::float16) {
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
@@ -1249,6 +1253,8 @@
phi::fusion::FusedLayerNormKernel,
float,
phi::dtype::float16) {
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
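All three hunks in this file apply one idiom: the block trailing PD_REGISTER_KERNEL customizes the kernel's argument definitions, and pinning InputAt(3) and InputAt(4) (the layer-norm scale and bias, by our reading of this kernel's argument order) to FLOAT32 stops the AMP pass from handing those weights half-precision tensors. A minimal sketch of the idiom with hypothetical names:

// Hypothetical registration: my_fused_kernel / MyFusedKernel are
// illustrative, not part of this commit.
PD_REGISTER_KERNEL(my_fused_kernel,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::MyFusedKernel,
                   float,
                   phi::dtype::float16) {
  // Even when the float16 instantiation is selected, inputs 3/4 and
  // outputs 2/3 stay float32 in the kernel signature.
  kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
  kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
  kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
  kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
}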
