Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Inference] Add conv_fuse_pass, support conv2d+bn -> conv2d #58724

Merged
merged 38 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
00dd869
add conv2d_bn_fuse_pass
Xinyu302 Oct 11, 2023
d0c410e
resolve conflict
Xinyu302 Oct 11, 2023
9d0d589
change LOG(INFO) to LOG(4) & change header
Xinyu302 Oct 14, 2023
786df53
delete useless headers
Xinyu302 Oct 14, 2023
b268223
modify header
Xinyu302 Oct 18, 2023
9fbbc47
rename conv2d_bn_fuse to conv2d_fuse
Xinyu302 Oct 19, 2023
331e010
modify pybind CMakeLists.txt
Xinyu302 Oct 20, 2023
fa331f8
handle conflict
Xinyu302 Oct 20, 2023
eb27a09
move to transforms dir
Xinyu302 Oct 20, 2023
7e18a38
modify pass name
Xinyu302 Oct 20, 2023
f93e991
cancel adding to pir.cc
Xinyu302 Oct 20, 2023
d75fc92
modify pass test
Xinyu302 Oct 20, 2023
419b0bc
change conv2d_fuse pass opt_level to 2
Xinyu302 Oct 20, 2023
32c1d38
add paddle/pir/pass/pass_registry.h in header
Xinyu302 Oct 20, 2023
c320035
Merge commit 'refs/pull/58252/head' of https://github.com/PaddlePaddl…
bukejiyu Nov 6, 2023
9a8c5c0
add conv2d_bn_pass
bukejiyu Nov 6, 2023
58964a5
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
bukejiyu Nov 7, 2023
981dfd9
add pass_test for pir pass test
bukejiyu Nov 10, 2023
4395cd8
code style
bukejiyu Nov 13, 2023
bd56c4f
code style
bukejiyu Nov 13, 2023
3cc9f33
bug fix
bukejiyu Nov 13, 2023
ec7da07
code style
bukejiyu Nov 13, 2023
cf36d32
bug fix
bukejiyu Nov 13, 2023
ae985a2
code style
bukejiyu Nov 13, 2023
dbabc5c
fix windows ci
bukejiyu Nov 14, 2023
63fba8c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
bukejiyu Nov 14, 2023
943b9c8
fix bug
bukejiyu Nov 14, 2023
134209e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
bukejiyu Nov 16, 2023
47a5a9f
code style and fix
bukejiyu Nov 16, 2023
5a83fe6
bug fix and code style
bukejiyu Nov 16, 2023
4d00ccf
code style
bukejiyu Nov 16, 2023
69e7622
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
bukejiyu Nov 16, 2023
3a468e2
code style
bukejiyu Nov 16, 2023
328f084
code style
bukejiyu Nov 17, 2023
5e13f0c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
bukejiyu Nov 17, 2023
e8a37de
code style
bukejiyu Nov 17, 2023
9f5947c
add cpu test
bukejiyu Nov 17, 2023
96be8c4
add for pass ci
bukejiyu Nov 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
#include "paddle/fluid/ir_adaptor/translator/translate.h"
#include "paddle/fluid/pir/transforms/constant_folding_pass.h"
#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h"
#include "paddle/fluid/pir/transforms/fusion/conv2d_fuse_pass.h"
#include "paddle/fluid/pir/transforms/inplace_pass.h"
#include "paddle/fluid/pir/transforms/params_sync_among_devices_pass.h"
#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
Expand Down Expand Up @@ -745,10 +746,10 @@ bool AnalysisPredictor::PrepareExecutor() {
::pir::PassManager pm(::pir::IrContext::Instance(), 2);
// TODO(liuyuanle): Uncomment constant_folding_pass after fix it
// pm.AddPass(::pir::CreateConstantFoldingPass(sub_scope_));
pm.AddPass(::pir::CreateConv2dFusePass());
pm.AddPass(::pir::CreateDeadCodeEliminationPass());
pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass());

// pm.EnableIRPrinting();
pm.EnableIRPrinting();
pm.Run(pir_program_.get());

pir_program_ = std::move(
Expand Down
207 changes: 207 additions & 0 deletions paddle/fluid/pir/transforms/fusion/conv2d_fuse_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pir/core/op_info.h"
#include "paddle/pir/core/parameter.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/core/value.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_manager.h"
bukejiyu marked this conversation as resolved.
Show resolved Hide resolved
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/pir/pattern_rewrite/pattern_applicator.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/fusion/conv2d_fuse_pass.h"

#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/phi/core/ddim.h"

namespace {

class Conv2dBnFusePattern
    : public pir::OpRewritePattern<paddle::dialect::BatchNormOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::BatchNormOp>::OpRewritePattern;

  // Folds a conv2d + batch_norm pair into a single conv2d by scaling the
  // filter and adding a reshaped bias:
  //   alpha      = scale / sqrt(variance + epsilon)
  //   new_filter = filter * reshape(alpha, [C_out, 1, ..., 1])
  //   new_out    = conv2d(input, new_filter) + reshape(bias - mean * alpha)
  bool MatchAndRewrite(
      paddle::dialect::BatchNormOp op,
      pir::PatternRewriter &rewriter) const override {  // NOLINT
    // ---- match phase: perform ALL checks before mutating the IR ----
    // The prev op should be conv2d op.
    paddle::dialect::Conv2dOp conv2d_op =
        pir::GetDefiningOpForInput(op, 0)
            ->dyn_cast<paddle::dialect::Conv2dOp>();
    if (!conv2d_op) return false;

    pir::OpResult conv2d_out = conv2d_op.out();
    // Only fuse when batch_norm is the sole consumer of the conv2d output.
    if (!conv2d_out.HasOneUse()) return false;

    // BUG FIX: this check previously ran after several new ops had already
    // been built; returning false at that point left orphan ops inserted in
    // the program. The new conv2d copies conv2d_op's attributes verbatim, so
    // checking the original op's data_format up front is equivalent.
    std::string data_format =
        conv2d_op.attribute<pir::StrAttribute>("data_format").AsString();
    if (data_format != "NCHW") {
      return false;
    }

    pir::Value conv2d_filter = conv2d_op.filter();

    pir::OpResult conv2d_filter_result =
        conv2d_filter.dyn_cast<pir::OpResult>();
    IR_ENFORCE(conv2d_filter_result);

    pir::Value bn_input = op.x();
    IR_ENFORCE(bn_input == conv2d_out);

    pir::Value bn_mean = op.mean();
    pir::Value bn_variance = op.variance();
    pir::Value bn_scale = op.scale();
    pir::Value bn_bias = op.bias();

    // ---- rewrite phase ----
    // --- deal with filter ---
    rewriter.set_insertion_point(op);
    phi::DDim bn_variance_shape =
        bn_variance.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
    float epsilon = op.attribute<pir::FloatAttribute>("epsilon").data();
    // alpha = scale / sqrt(variance + epsilon)
    paddle::dialect::FullOp full_op = rewriter.Build<paddle::dialect::FullOp>(
        phi::vectorize(bn_variance_shape), epsilon);
    paddle::dialect::AddOp add_op = rewriter.Build<paddle::dialect::AddOp>(
        bn_variance.dyn_cast<pir::OpResult>(), full_op.out());
    paddle::dialect::SqrtOp sqrt_op =
        rewriter.Build<paddle::dialect::SqrtOp>(add_op.out());
    paddle::dialect::DivideOp div_op =
        rewriter.Build<paddle::dialect::DivideOp>(
            bn_scale.dyn_cast<pir::OpResult>(), sqrt_op.out());
    // reshape scale to [C_out, 1, ..., 1] so it broadcasts over the filter
    phi::DDim conv2d_filter_shape = pir::GetShapeFromValue(conv2d_filter);
    phi::DDim bn_scale_shape =
        bn_scale.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
    std::vector<int64_t> bn_scale_new_shape(conv2d_filter_shape.size(), 1);
    bn_scale_new_shape[0] = bn_scale_shape[0];
    paddle::dialect::ReshapeOp reshape_scale_op =
        rewriter.Build<paddle::dialect::ReshapeOp>(div_op.out(),
                                                   bn_scale_new_shape);
    // new filter --> mul_op.out()
    paddle::dialect::MultiplyOp mul_op =
        rewriter.Build<paddle::dialect::MultiplyOp>(conv2d_filter_result,
                                                    reshape_scale_op.out());

    auto conv2d_attributes = conv2d_op->attributes();
    auto new_conv2d_op = rewriter.Build<paddle::dialect::Conv2dOp>(
        conv2d_op.input().dyn_cast<pir::OpResult>(),
        mul_op.out(),
        conv2d_attributes);

    // --- deal with bias ---
    paddle::dialect::MultiplyOp mul_bias_op =
        rewriter.Build<paddle::dialect::MultiplyOp>(
            bn_mean.dyn_cast<pir::OpResult>(), div_op.out());
    // new bias --> sub_op.out()
    paddle::dialect::SubtractOp sub_op =
        rewriter.Build<paddle::dialect::SubtractOp>(
            bn_bias.dyn_cast<pir::OpResult>(), mul_bias_op.out());
    // reshape new bias so it broadcasts along the channel axis (NCHW: axis 1)
    phi::DDim new_conv2d_out_shape =
        pir::GetShapeFromValue(new_conv2d_op.out());
    std::vector<int64_t> new_bias_new_shape(new_conv2d_out_shape.size(), 1);
    new_bias_new_shape[1] = new_conv2d_out_shape[1];
    paddle::dialect::ReshapeOp reshape_bias_op =
        rewriter.Build<paddle::dialect::ReshapeOp>(sub_op.out(),
                                                   new_bias_new_shape);
    paddle::dialect::AddOp add_bias_op = rewriter.Build<paddle::dialect::AddOp>(
        new_conv2d_op.out(), reshape_bias_op.out());

    rewriter.ReplaceAllUsesWith(op.out(), add_bias_op.out());

    rewriter.EraseOp(op);
    rewriter.EraseOp(conv2d_op);
    return true;
  }
};
class BatchNormReplacePattern
    : public pir::OpRewritePattern<paddle::dialect::BatchNorm_Op> {
 public:
  using pir::OpRewritePattern<paddle::dialect::BatchNorm_Op>::OpRewritePattern;

  // Rewrites the in-place batch_norm_ op into the non-in-place batch_norm op,
  // forwarding all operands and attributes unchanged and redirecting users of
  // the old result to the new one.
  bool MatchAndRewrite(
      paddle::dialect::BatchNorm_Op op,
      pir::PatternRewriter &rewriter) const override {  // NOLINT
    // Gather the operands of the in-place op; operand order must match the
    // batch_norm builder's signature.
    pir::OpResult x = op.x().dyn_cast<pir::OpResult>();
    pir::OpResult mean = op.mean().dyn_cast<pir::OpResult>();
    pir::OpResult variance = op.variance().dyn_cast<pir::OpResult>();
    pir::OpResult scale = op.scale().dyn_cast<pir::OpResult>();
    pir::OpResult bias = op.bias().dyn_cast<pir::OpResult>();

    auto replacement = rewriter.Build<paddle::dialect::BatchNormOp>(
        x, mean, variance, scale, bias, op->attributes());

    rewriter.ReplaceAllUsesWith(op.out(), replacement.out());
    rewriter.EraseOp(op);
    return true;
  }
};
// Pass that fuses conv2d + batch_norm into conv2d. Registered as
// "conv2d_fuse_pass" at optimization level 2.
class Conv2dFusePass : public pir::Pass {
 public:
  Conv2dFusePass() : pir::Pass("conv2d_fuse_pass", 2) {}

  // Builds the frozen pattern set: one pattern canonicalizes the in-place
  // batch_norm_ op into batch_norm, the other folds conv2d + batch_norm.
  bool Initialize(pir::IrContext *context) override {
    // Ops the fuse pattern may create (benefit = 1).
    const std::vector<std::string> fuse_generated_ops{
        paddle::dialect::FullOp::name(),
        paddle::dialect::AddOp::name(),
        paddle::dialect::SqrtOp::name(),
        paddle::dialect::DivideOp::name(),
        paddle::dialect::ReshapeOp::name(),
        paddle::dialect::MultiplyOp::name(),
        paddle::dialect::SubtractOp::name(),
        paddle::dialect::Conv2dOp::name()};
    auto fuse_pattern =
        std::make_unique<Conv2dBnFusePattern>(context, 1, fuse_generated_ops);
    VLOG(4) << "Conv2dBnFusePattern will generate the following operations: ";
    for (const auto &op_info : fuse_pattern->generated_ops()) {
      VLOG(4) << "--- " << op_info.name();
    }

    auto replace_pattern = std::make_unique<BatchNormReplacePattern>(
        context,
        1,
        std::vector<std::string>{paddle::dialect::BatchNormOp::name()});

    pir::RewritePatternSet ps(context);
    ps.Add(std::move(replace_pattern));
    ps.Add(std::move(fuse_pattern));
    patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
    return true;
  }

  // Applies the patterns greedily over the module's first region, walking
  // top-down, for at most 10 iterations.
  void Run(pir::Operation *op) override {
    pir::GreedyRewriteConfig cfg;
    cfg.use_top_down_traversal = true;
    cfg.max_iterations = 10;
    pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
  }

  // Restrict the pass to module ops that actually contain regions.
  bool CanApplyOn(pir::Operation *op) const override {
    return op->isa<::pir::ModuleOp>() && op->num_regions() > 0;
  }

 private:
  pir::FrozenRewritePatternSet patterns_;
};

} // namespace

namespace pir {

// Factory for the conv2d_fuse pass; the concrete Conv2dFusePass type stays
// hidden in the anonymous namespace above.
std::unique_ptr<Pass> CreateConv2dFusePass() {
  return std::make_unique<Conv2dFusePass>();
}

}  // namespace pir

// Registers the pass under the name "conv2d_fuse_pass" so it can be looked up
// by name (e.g. via USE_PIR_PASS in pybind/pir.cc).
REGISTER_IR_PASS(conv2d_fuse_pass, Conv2dFusePass);
26 changes: 26 additions & 0 deletions paddle/fluid/pir/transforms/fusion/conv2d_fuse_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include "paddle/pir/core/dll_decl.h"

namespace pir {

class Pass;

// Creates the conv2d_fuse pass (registered name "conv2d_fuse_pass",
// optimization level 2), which fuses a conv2d + batch_norm pair into a single
// conv2d with adjusted filter and bias.
IR_API std::unique_ptr<Pass> CreateConv2dFusePass();

} // namespace pir
1 change: 1 addition & 0 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ USE_PIR_PASS(fused_dropout_add_pass);
USE_PIR_PASS(fused_linear_param_grad_add_pass);
USE_PIR_PASS(inplace_pass);
USE_PIR_PASS(replace_fetch_with_shadow_output_pass);
USE_PIR_PASS(conv2d_fuse_pass);

PHI_DECLARE_bool(print_ir);

Expand Down
Loading