From 2f875edc562ec061c4bdc9225d66d9bc9e31f89b Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Tue, 24 Sep 2024 12:10:03 +0000 Subject: [PATCH] refine pir amp_pass --- .../fluid/inference/api/analysis_predictor.cc | 19 +++++-- .../general/auto_mixed_precision_pass.cc | 52 ++++++++++++++++--- paddle/fluid/pir/transforms/passes.h | 1 + 3 files changed, 63 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 9777c6003c1eb..5048b4ec4d5a9 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -115,6 +115,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/fluid/pir/serialize_deserialize/include/interface.h" +#include "paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.h" #include "paddle/fluid/pir/transforms/general/common_subexpression_elimination_pass.h" #include "paddle/fluid/pir/transforms/general/constant_folding_pass.h" #include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h" @@ -922,6 +923,7 @@ void AnalysisPredictor::OptimizeInferencePirProgram() { // set attr for (const auto &pass : pass_pm.passes()) { pass->SetNotOwned(pir::Pass::kParamScopeAttr, sub_scope_); + pass->SetNotOwned(pir::Pass::kPlaceAttr, &place_); if (pass->name() == "matmul_add_act_fuse_pass" || pass->name() == "conv2d_add_act_fuse_pass" || pass->name() == "conv2d_add_fuse_pass") { @@ -960,6 +962,19 @@ void AnalysisPredictor::OptimizeInferencePirProgram() { config_.deleted_passes_.end()) { basic_pass_pm.AddPass(std::move(common_subexpression_elimination_pass)); } + if (config_.enable_gpu_mixed_) { + auto auto_mixed_precision_pass = ::pir::CreateAutoMixedPrecisionPass(); + if (std::find(config_.deleted_passes_.begin(), + config_.deleted_passes_.end(), + auto_mixed_precision_pass->name()) == + config_.deleted_passes_.end()) { + auto_mixed_precision_pass->SetNotOwned(pir::Pass::kPlaceAttr, &place_); + auto_mixed_precision_pass->Set("__mixed_precision_mode__", + new phi::DataType(paddle::ConvertPrecision( + config_.mixed_precision_mode_))); + basic_pass_pm.AddPass(std::move(auto_mixed_precision_pass)); + } + } auto params_sync_among_devices_pass = ::pir::CreateParamsSyncAmongDevicesPass(); if (std::find(config_.deleted_passes_.begin(), @@ -2227,9 +2242,7 @@ void AnalysisPredictor::PrepareArgument() { pass_builder->AppendPass("simplify_with_basic_ops_pass"); pass_builder->AppendPass("is_test_pass"); pass_builder->AppendPass("constant_folding_pass"); - } - pass_builder->AppendPass("auto_mixed_precision_pass"); - if (!config_.new_ir_enabled()) { + pass_builder->AppendPass("auto_mixed_precision_pass"); pass_builder->AppendPass("inplace_op_var_pass"); } LOG(INFO) << "This model run in GPU mixed precision mode with no ir " diff --git a/paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.cc index 4ce136e78ec95..2a450e4592bff 100644 --- a/paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.cc @@ -47,6 +47,7 @@ #include "paddle/pir/include/core/parameter.h" #include "paddle/pir/include/core/program.h" #include "paddle/pir/include/pass/pass.h" +#include "paddle/pir/include/pass/pass_registry.h" #include "paddle/pir/include/pattern_rewrite/frozen_rewrite_pattern_set.h" #include "paddle/pir/include/pattern_rewrite/pattern_match.h" #include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h" @@ -102,6 +103,7 @@ class AutoMixedPrecisionPass : public pir::Pass { ProcessBlock(&block, builder); } } + cached_cast_ops_.clear(); } bool CanApplyOn(pir::Operation* op) const override { @@ -134,6 +136,7 @@ class AutoMixedPrecisionPass : public pir::Pass { paddle::dialect::SumOp::name(), paddle::dialect::SigmoidCrossEntropyWithLogitsOp::name(), paddle::dialect::CrossEntropyWithSoftmax_Op::name(), + "pd_op.array_to_tensor", }); } @@ -164,6 +167,10 @@ class AutoMixedPrecisionPass : public pir::Pass { auto backend = ConvertPlaceToBackend(place_); support_low_precision = OpSupportPrecision(op_type, backend, precision_mode_); + if (op_name == "pd_op.scale" && !OpHasFloatResult(op)) { + support_low_precision = false; + op_should_not_handle_.insert(op); + } } else { // pd op without float result support_low_precision = false; op_should_not_handle_.insert(op); @@ -436,8 +443,10 @@ class AutoMixedPrecisionPass : public pir::Pass { if (result.type().isa() && IsDenseTensorTypeFloat( result.type().dyn_cast())) { + return true; } else if (result.type().isa() && IsVectorTypeFloat(result.type().dyn_cast())) { + return true; } } return false; @@ -480,6 +489,9 @@ class AutoMixedPrecisionPass : public pir::Pass { return operand.type() && operand.type().isa(); } + bool IsOperandHasDenseTensorVectorType(pir::OpOperand operand) const { + return operand.type() && operand.type().isa(); + } void DoInsertCastOp(pir::Operation* op, pir::OpOperand operand, @@ -585,7 +597,6 @@ class AutoMixedPrecisionPass : public pir::Pass { SetResultDataType(op->result(0), precision_mode_, builder.ir_context()); return; } - // Other pd ops if (OpRunLowPrecision(op)) { auto phi_kernel = @@ -658,11 +669,38 @@ class AutoMixedPrecisionPass : public pir::Pass { auto phi_dtype = phi::DataType::FLOAT32; for (size_t i = 0; i < op->num_operands(); i++) { auto operand = op->operand(i); - if (!IsOperandHasDenseTensorType(operand)) continue; - auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand); - if (IsPhiDataTypeFloat(operand_phi_dtype) && - operand_phi_dtype == precision_mode_) { - DoInsertCastOp(op, operand, phi_dtype, builder); + if (IsOperandHasDenseTensorType(operand)) { + auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand); + if (IsPhiDataTypeFloat(operand_phi_dtype) && + operand_phi_dtype == precision_mode_) { + DoInsertCastOp(op, operand, phi_dtype, builder); + } + } else if (IsOperandHasDenseTensorVectorType(operand)) { + LOG(INFO) << "IsOperandHasDenseTensorVectorType(operand)"; + LOG(INFO) << operand.source().defining_op()->name(); + auto defining_op_ = operand.source().defining_op(); + if (defining_op_->isa()) { + auto input_num = defining_op_->num_operands(); + for (size_t i = 0; i < input_num; ++i) { + auto operand = defining_op_->operand(i); + auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand); + if (IsPhiDataTypeFloat(operand_phi_dtype) && + operand_phi_dtype != phi::DataType::FLOAT32) { + DoInsertCastOp( + defining_op_, operand, phi::DataType::FLOAT32, builder); + LOG(INFO) << "DoInsertCastOp"; + } + } + std::vector inputs_type(input_num); + for (size_t idx = 0; idx < input_num; ++idx) { + inputs_type[idx] = defining_op_->operand(idx).type(); + } + auto new_vec_type = + pir::VectorType::get(builder.ir_context(), inputs_type); + defining_op_->result(0).set_type(new_vec_type); + } + } else { + continue; } } } @@ -677,3 +715,5 @@ std::unique_ptr CreateAutoMixedPrecisionPass() { } } // namespace pir + +REGISTER_IR_PASS(auto_mixed_precision_pass, AutoMixedPrecisionPass); diff --git a/paddle/fluid/pir/transforms/passes.h b/paddle/fluid/pir/transforms/passes.h index 3d04309b9cddf..2b3b98a663b28 100644 --- a/paddle/fluid/pir/transforms/passes.h +++ b/paddle/fluid/pir/transforms/passes.h @@ -46,6 +46,7 @@ USE_PIR_PASS(delete_weight_dequant_linear_op_pass); USE_PIR_PASS(delete_quant_dequant_linear_op_pass); USE_PIR_PASS(transfer_layout_pass); USE_PIR_PASS(fused_rotary_position_embedding_pass); +USE_PIR_PASS(auto_mixed_precision_pass); USE_PIR_PASS(horizontal_fuse_pass); USE_PIR_PASS(common_subexpression_elimination_pass);