Commit 2f875ed: refine pir amp_pass
yuanlehome committed Sep 24, 2024 (1 parent: 5c9a673)
Showing 3 changed files with 63 additions and 9 deletions.

19 changes: 16 additions & 3 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -115,6 +115,7 @@
 #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
 #include "paddle/fluid/pir/dialect/operator/utils/utils.h"
 #include "paddle/fluid/pir/serialize_deserialize/include/interface.h"
+#include "paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.h"
 #include "paddle/fluid/pir/transforms/general/common_subexpression_elimination_pass.h"
 #include "paddle/fluid/pir/transforms/general/constant_folding_pass.h"
 #include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h"
@@ -922,6 +923,7 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
     // set attr
     for (const auto &pass : pass_pm.passes()) {
       pass->SetNotOwned(pir::Pass::kParamScopeAttr, sub_scope_);
+      pass->SetNotOwned(pir::Pass::kPlaceAttr, &place_);
       if (pass->name() == "matmul_add_act_fuse_pass" ||
           pass->name() == "conv2d_add_act_fuse_pass" ||
           pass->name() == "conv2d_add_fuse_pass") {
@@ -960,6 +962,19 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
         config_.deleted_passes_.end()) {
       basic_pass_pm.AddPass(std::move(common_subexpression_elimination_pass));
     }
+    if (config_.enable_gpu_mixed_) {
+      auto auto_mixed_precision_pass = ::pir::CreateAutoMixedPrecisionPass();
+      if (std::find(config_.deleted_passes_.begin(),
+                    config_.deleted_passes_.end(),
+                    auto_mixed_precision_pass->name()) ==
+          config_.deleted_passes_.end()) {
+        auto_mixed_precision_pass->SetNotOwned(pir::Pass::kPlaceAttr, &place_);
+        auto_mixed_precision_pass->Set("__mixed_precision_mode__",
+                                       new phi::DataType(paddle::ConvertPrecision(
+                                           config_.mixed_precision_mode_)));
+        basic_pass_pm.AddPass(std::move(auto_mixed_precision_pass));
+      }
+    }
     auto params_sync_among_devices_pass =
         ::pir::CreateParamsSyncAmongDevicesPass();
     if (std::find(config_.deleted_passes_.begin(),
@@ -2227,9 +2242,7 @@ void AnalysisPredictor::PrepareArgument() {
       pass_builder->AppendPass("simplify_with_basic_ops_pass");
       pass_builder->AppendPass("is_test_pass");
       pass_builder->AppendPass("constant_folding_pass");
-    }
-    pass_builder->AppendPass("auto_mixed_precision_pass");
-    if (!config_.new_ir_enabled()) {
+      pass_builder->AppendPass("auto_mixed_precision_pass");
       pass_builder->AppendPass("inplace_op_var_pass");
     }
     LOG(INFO) << "This model run in GPU mixed precision mode with no ir "
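
Usage note: the hunks above only schedule the PIR auto_mixed_precision_pass when config_.enable_gpu_mixed_ is set. A minimal sketch of a config that turns this on from the C++ inference API (the three-argument EnableUseGpu overload and EnableNewIR are assumed public API and are not part of this diff):

#include <memory>
#include <string>

#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Sketch: request FP16 GPU inference, which is assumed to set
// enable_gpu_mixed_ and thereby enable the pass added above.
std::shared_ptr<paddle_infer::Predictor> MakeFp16Predictor(
    const std::string& model_dir) {
  paddle_infer::Config config(model_dir);
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/500, /*device_id=*/0,
                      paddle_infer::PrecisionType::kHalf);
  config.EnableNewIR(true);  // route through OptimizeInferencePirProgram()
  return paddle_infer::CreatePredictor(config);
}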
52 changes: 46 additions & 6 deletions paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.cc
@@ -47,6 +47,7 @@
 #include "paddle/pir/include/core/parameter.h"
 #include "paddle/pir/include/core/program.h"
 #include "paddle/pir/include/pass/pass.h"
+#include "paddle/pir/include/pass/pass_registry.h"
 #include "paddle/pir/include/pattern_rewrite/frozen_rewrite_pattern_set.h"
 #include "paddle/pir/include/pattern_rewrite/pattern_match.h"
 #include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h"
@@ -102,6 +103,7 @@ class AutoMixedPrecisionPass : public pir::Pass {
         ProcessBlock(&block, builder);
       }
     }
+    cached_cast_ops_.clear();
   }

   bool CanApplyOn(pir::Operation* op) const override {
@@ -134,6 +136,7 @@ class AutoMixedPrecisionPass : public pir::Pass {
         paddle::dialect::SumOp::name(),
         paddle::dialect::SigmoidCrossEntropyWithLogitsOp::name(),
         paddle::dialect::CrossEntropyWithSoftmax_Op::name(),
+        "pd_op.array_to_tensor",
     });
   }

@@ -164,6 +167,10 @@ class AutoMixedPrecisionPass : public pir::Pass {
       auto backend = ConvertPlaceToBackend(place_);
       support_low_precision =
           OpSupportPrecision(op_type, backend, precision_mode_);
+      if (op_name == "pd_op.scale" && !OpHasFloatResult(op)) {
+        support_low_precision = false;
+        op_should_not_handle_.insert(op);
+      }
     } else {  // pd op without float result
       support_low_precision = false;
       op_should_not_handle_.insert(op);
@@ -436,8 +443,10 @@ class AutoMixedPrecisionPass : public pir::Pass {
       if (result.type().isa<paddle::dialect::DenseTensorType>() &&
           IsDenseTensorTypeFloat(
               result.type().dyn_cast<paddle::dialect::DenseTensorType>())) {
+        return true;
       } else if (result.type().isa<pir::VectorType>() &&
                  IsVectorTypeFloat(result.type().dyn_cast<pir::VectorType>())) {
+        return true;
       }
     }
     return false;
@@ -480,6 +489,9 @@ class AutoMixedPrecisionPass : public pir::Pass {
     return operand.type() &&
            operand.type().isa<paddle::dialect::DenseTensorType>();
   }
+  bool IsOperandHasDenseTensorVectorType(pir::OpOperand operand) const {
+    return operand.type() && operand.type().isa<pir::VectorType>();
+  }

   void DoInsertCastOp(pir::Operation* op,
                       pir::OpOperand operand,
@@ -585,7 +597,6 @@ class AutoMixedPrecisionPass : public pir::Pass {
       SetResultDataType(op->result(0), precision_mode_, builder.ir_context());
       return;
     }
-
     // Other pd ops
     if (OpRunLowPrecision(op)) {
       auto phi_kernel =
@@ -658,11 +669,38 @@ class AutoMixedPrecisionPass : public pir::Pass {
     auto phi_dtype = phi::DataType::FLOAT32;
     for (size_t i = 0; i < op->num_operands(); i++) {
       auto operand = op->operand(i);
-      if (!IsOperandHasDenseTensorType(operand)) continue;
-      auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand);
-      if (IsPhiDataTypeFloat(operand_phi_dtype) &&
-          operand_phi_dtype == precision_mode_) {
-        DoInsertCastOp(op, operand, phi_dtype, builder);
+      if (IsOperandHasDenseTensorType(operand)) {
+        auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand);
+        if (IsPhiDataTypeFloat(operand_phi_dtype) &&
+            operand_phi_dtype == precision_mode_) {
+          DoInsertCastOp(op, operand, phi_dtype, builder);
+        }
+      } else if (IsOperandHasDenseTensorVectorType(operand)) {
+        LOG(INFO) << "IsOperandHasDenseTensorVectorType(operand)";
+        LOG(INFO) << operand.source().defining_op()->name();
+        auto defining_op_ = operand.source().defining_op();
+        if (defining_op_->isa<pir::CombineOp>()) {
+          auto input_num = defining_op_->num_operands();
+          for (size_t i = 0; i < input_num; ++i) {
+            auto operand = defining_op_->operand(i);
+            auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand);
+            if (IsPhiDataTypeFloat(operand_phi_dtype) &&
+                operand_phi_dtype != phi::DataType::FLOAT32) {
+              DoInsertCastOp(
+                  defining_op_, operand, phi::DataType::FLOAT32, builder);
+              LOG(INFO) << "DoInsertCastOp";
+            }
+          }
+          std::vector<pir::Type> inputs_type(input_num);
+          for (size_t idx = 0; idx < input_num; ++idx) {
+            inputs_type[idx] = defining_op_->operand(idx).type();
+          }
+          auto new_vec_type =
+              pir::VectorType::get(builder.ir_context(), inputs_type);
+          defining_op_->result(0).set_type(new_vec_type);
+        }
+      } else {
+        continue;
       }
     }
   }
@@ -677,3 +715,5 @@ std::unique_ptr<Pass> CreateAutoMixedPrecisionPass() {
 }

 }  // namespace pir
+
+REGISTER_IR_PASS(auto_mixed_precision_pass, AutoMixedPrecisionPass);
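
With REGISTER_IR_PASS in place, the pass can also be driven outside the predictor. A minimal sketch (the helper below and the hard-coded FLOAT16 are assumptions, not part of this commit); it mirrors the attribute setup done in OptimizeInferencePirProgram() above:

#include "paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/pir/include/core/ir_context.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pass/pass_manager.h"

// Sketch: rewrite `program` to FP16 on the given place. `place` is attached
// via SetNotOwned, so it must outlive the pass manager run.
bool RunAutoMixedPrecision(pir::Program* program, phi::Place* place) {
  pir::PassManager pm(pir::IrContext::Instance());
  auto amp_pass = pir::CreateAutoMixedPrecisionPass();
  amp_pass->SetNotOwned(pir::Pass::kPlaceAttr, place);
  amp_pass->Set("__mixed_precision_mode__",
                new phi::DataType(phi::DataType::FLOAT16));  // pass takes ownership
  pm.AddPass(std::move(amp_pass));
  return pm.Run(program);
}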
1 change: 1 addition & 0 deletions paddle/fluid/pir/transforms/passes.h
@@ -46,6 +46,7 @@ USE_PIR_PASS(delete_weight_dequant_linear_op_pass);
 USE_PIR_PASS(delete_quant_dequant_linear_op_pass);
 USE_PIR_PASS(transfer_layout_pass);
 USE_PIR_PASS(fused_rotary_position_embedding_pass);
+USE_PIR_PASS(auto_mixed_precision_pass);
 USE_PIR_PASS(horizontal_fuse_pass);
 USE_PIR_PASS(common_subexpression_elimination_pass);
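
The new USE_PIR_PASS entry links the registration above into any binary that includes passes.h, so the pass is also obtainable by its string name. A small sketch, assuming the pir::PassRegistry::Instance().Get() interface used elsewhere in the predictor code:

#include <memory>

#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pass/pass_registry.h"

// Sketch: fetch a fresh instance of the registered pass by name instead of
// calling pir::CreateAutoMixedPrecisionPass() directly.
std::unique_ptr<pir::Pass> GetAmpPassByName() {
  return pir::PassRegistry::Instance().Get("auto_mixed_precision_pass");
}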

