[cherry-pick] Squeeze2 and transpose2 fuse using oneDNN (#47712)
* squeeze2 + transpose2 fuse onednn, cherry-pick to 2.4

* format

* fix merge
zh794390558 authored Nov 9, 2022
1 parent 34f67a8 commit ea5f44b
Showing 9 changed files with 217 additions and 7 deletions.
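
At a high level, the new pass collapses a squeeze2 node into the transpose2 node that consumes it. A sketch of the rewrite (tensor names are illustrative):

    Before:  X -> squeeze2(axes) -> squeeze2_out -> transpose2(axis) -> Out
    After:   X -> transpose2(axis, fused_squeeze2_axes=axes) -> Out

Instead of executing a separate squeeze2 op, the oneDNN transpose2 kernel applies the recorded squeeze to its input memory descriptor at runtime (see the mkldnn_reuse.h hunk below).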
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -219,6 +219,7 @@ if(WITH_MKLDNN)
pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
pass_library(squeeze2_transpose2_onednn_fuse_pass inference DIR mkldnn)
pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn)
pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn)
19 changes: 19 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1042,6 +1042,25 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
return relu_out_var;
}

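// Note: squeeze2 has two outputs, Out and XShape, which is why the matched
// squeeze2 op below is asserted to have exactly two outputs.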
PDNode *patterns::Squeeze2Transpose2::operator()() {
auto *squeeze2_op_in = pattern->NewNode(squeeze2_op_in_repr())
->AsInput()
->assert_is_op_input("squeeze2", "X");
auto *squeeze2_op = pattern->NewNode(squeeze2_op_repr())
->assert_is_op("squeeze2")
->assert_has_n_outputs(2);
auto *squeeze2_op_out = pattern->NewNode(squeeze2_op_out_repr())
->AsIntermediate()
->assert_is_op_output("squeeze2", "Out")
->assert_is_op_input("transpose2", "X");
auto *transpose2_op =
pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2");

squeeze2_op->LinksFrom({squeeze2_op_in}).LinksTo({squeeze2_op_out});
transpose2_op->LinksFrom({squeeze2_op_out});
return transpose2_op;
}

PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
bool with_bias,
bool with_relu) {
20 changes: 20 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
(file mode changed 100755 → 100644)
@@ -634,6 +634,20 @@ struct FCMKLDNN : public PatternBase {
PATTERN_DECL_NODE(output);
};

// Squeeze2 + Transpose2
// Forward pass
struct Squeeze2Transpose2 : public PatternBase {
Squeeze2Transpose2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "squeeze2_transpose2") {}

PDNode* operator()();

PATTERN_DECL_NODE(squeeze2_op_in);
PATTERN_DECL_NODE(squeeze2_op);
PATTERN_DECL_NODE(squeeze2_op_out);
PATTERN_DECL_NODE(transpose2_op);
};

// Embedding
struct Embedding : public PatternBase {
Embedding(PDPattern* pattern, const std::string& name_scope)
@@ -2002,6 +2016,12 @@ struct AddSupportInt8 : public PatternBase {
out_var->inputs.clear(); \
out_var->inputs.push_back(op);

// Rewire in_var so that op becomes its only consumer, then add it to op's inputs
#define IR_VAR_OP_LINK(in_var, op) \
in_var->outputs.clear(); \
in_var->outputs.push_back(op); \
op->inputs.push_back(in_var);

} // namespace ir
} // namespace framework
} // namespace paddle
86 changes: 86 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.cc
@@ -0,0 +1,86 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

using string::PrettyLogDetail;

void FuseSqueeze2Transpose2OneDNNPass::ApplyImpl(Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));

FusePassBase::Init("squeeze2_transpose2_onednn_fuse_pass", graph);

GraphPatternDetector gpd;
patterns::Squeeze2Transpose2 squeeze2_transpose2_pattern(
gpd.mutable_pattern(), "squeeze2_transpose2_onednn_fuse_pass");
squeeze2_transpose2_pattern();

int found_count = 0;

auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
GET_IR_NODE_FROM_SUBGRAPH(
squeeze2_op_in, squeeze2_op_in, squeeze2_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
squeeze2_op, squeeze2_op, squeeze2_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
squeeze2_op_out, squeeze2_op_out, squeeze2_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_op, transpose2_op, squeeze2_transpose2_pattern);

if (!transpose2_op->Op()->HasAttr("use_mkldnn") ||
    !PADDLE_GET_CONST(bool, transpose2_op->Op()->GetAttr("use_mkldnn"))) {
  VLOG(4) << "Only the oneDNN version of transpose2 can be fused with "
             "squeeze2.";
return;
}

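// Record squeeze2's axes on transpose2 and make transpose2 read directly
// from squeeze2's input; the oneDNN transpose kernel later squeezes its
// input mem_desc accordingly (see SetInMemDescWithSqueeze2FuseSupport).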
std::vector<int> squeeze2_axes =
PADDLE_GET_CONST(std::vector<int>, squeeze2_op->Op()->GetAttr("axes"));
transpose2_op->Op()->SetAttr("fused_squeeze2_axes", squeeze2_axes);
transpose2_op->Op()->SetInput("X", {squeeze2_op_in->Name()});

IR_VAR_OP_LINK(squeeze2_op_in, transpose2_op);
GraphSafeRemoveNodes(g, {squeeze2_op, squeeze2_op_out});
found_count++;
};

gpd(graph, handler);
AddStatis(found_count);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
PrettyLogDetail("--- fused %d squeeze2 with transpose2", found_count);
}
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(squeeze2_transpose2_onednn_fuse_pass,
paddle::framework::ir::FuseSqueeze2Transpose2OneDNNPass);
REGISTER_PASS_CAPABILITY(squeeze2_transpose2_onednn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("squeeze2", 0)
.GE("transpose2", 0));
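
For orientation, a minimal sketch of exercising this pass through the C++ inference API (assuming the paddle_infer API available in Paddle 2.x; the model path is a placeholder). EnableMKLDNN() picks the pass up via CpuPassStrategy::EnableMKLDNN(), as the paddle_pass_builder.cc hunk below shows:

#include "paddle_inference_api.h"

int main() {
  // Build an inference config; the model directory here is hypothetical.
  paddle_infer::Config config;
  config.SetModel("./model_dir");
  // Enabling oneDNN installs the CPU pass list, which now begins with
  // squeeze2_transpose2_onednn_fuse_pass.
  config.EnableMKLDNN();
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}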
35 changes: 35 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h
@@ -0,0 +1,35 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace ir {

class FuseSqueeze2Transpose2OneDNNPass : public FusePassBase {
public:
virtual ~FuseSqueeze2Transpose2OneDNNPass() {}

protected:
void ApplyImpl(Graph *graph) const override;
};

} // namespace ir
} // namespace framework

} // namespace paddle
2 changes: 2 additions & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -307,6 +307,7 @@ void CpuPassStrategy::EnableMKLDNN() {
passes_.insert(passes_.begin(), "mkldnn_placement_pass");

for (auto &pass : std::vector<std::string>({
"squeeze2_transpose2_onednn_fuse_pass",
"depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
@@ -386,6 +387,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("mkldnn_placement_pass");
passes_.push_back("simplify_with_basic_ops_pass");
passes_.push_back("constant_folding_pass");
passes_.push_back("squeeze2_transpose2_onednn_fuse_pass");
passes_.push_back("layer_norm_fuse_pass");
passes_.push_back("attention_lstm_fuse_pass");
passes_.push_back("seqconv_eltadd_relu_fuse_pass");
3 changes: 3 additions & 0 deletions paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
@@ -42,6 +42,9 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {

auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

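// When transpose2 carries "fused_squeeze2_axes" (set by
// squeeze2_transpose2_onednn_fuse_pass), rewrite the input's mem_desc and
// dims to the squeezed shape before transposing.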
platform::SetInMemDescWithLogicalLayoutFusesSupport(
ctx, const_cast<phi::DenseTensor*>(x), x->mem_desc());

if (ndims == 1) {
framework::TensorCopy(*x, x->place(), out);
out->set_mem_desc(x->mem_desc());
6 changes: 4 additions & 2 deletions paddle/fluid/operators/transpose_op.cc
@@ -39,11 +39,13 @@ class TransposeOp : public framework::OperatorWithKernel {
size_t x_rank = x_dims.size();
size_t axis_size = axis.size();

PADDLE_ENFORCE_EQ(x_rank,
// Note: x_rank > axis_size when squeeze2 + transpose2 are fused;
// otherwise x_rank == axis_size.
PADDLE_ENFORCE_GE(x_rank,
axis_size,
platform::errors::InvalidArgument(
"The input tensor's dimension "
"should be equal to the axis's size. "
"should be equal to or greater than the axis's size. "
"But received input tensor's dimension is %d, "
"axis's size is %d",
x_rank,
52 changes: 47 additions & 5 deletions paddle/fluid/platform/mkldnn_reuse.h
@@ -120,12 +120,10 @@ static void SetOutMemDescWithUnsqueeze2FuseSupport(
const std::vector<int64_t>& op_tz = out_md.dims();
std::vector<int64_t> unsqueezed_op_tz(
op_tz.size() + fused_unsqueeze2_axes.size(), 0);

for (const auto& axis : fused_unsqueeze2_axes) {
int positive_axis = axis < 0 ? unsqueezed_op_tz.size() + axis : axis;
unsqueezed_op_tz[positive_axis] = 1;
}

int j = 0;
for (size_t i = 0; i < unsqueezed_op_tz.size(); ++i) {
if (unsqueezed_op_tz[i] == 0) {
@@ -143,20 +141,17 @@ static void SetOutMemDescWithReshape2FuseSupport(
std::vector<int64_t> fused_reshape2_shape(
ctx.Attr<std::vector<int>>("fused_reshape2_shape").begin(),
ctx.Attr<std::vector<int>>("fused_reshape2_shape").end());

const int out_shape_numel = out->numel();
const int new_shape_numel = std::accumulate(fused_reshape2_shape.begin(),
fused_reshape2_shape.end(),
1,
std::multiplies<int64_t>());

for (size_t i = 0; i < fused_reshape2_shape.size(); ++i) {
if (fused_reshape2_shape[i] == -1) {
fused_reshape2_shape[i] = -out_shape_numel / new_shape_numel;
break;
}
}

out->set_mem_desc(out_md.reshape(fused_reshape2_shape));
out->Resize(phi::make_ddim(fused_reshape2_shape));
}
@@ -169,11 +164,58 @@ static void SetOutMemDescWithLogicalLayoutFusesSupport(
SetOutMemDescWithUnsqueeze2FuseSupport(ctx, out, out_md);
} else if (ctx.HasAttr("fused_reshape2_shape")) {
SetOutMemDescWithReshape2FuseSupport(ctx, out, out_md);
} else if (ctx.HasAttr("fused_squeeze2_axes")) {
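// For the squeeze2 fuse, the input was already squeezed on entry via
// SetInMemDescWithSqueeze2FuseSupport, so the output simply adopts out_md.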
out->set_mem_desc(out_md);
out->Resize(phi::make_ddim(out_md.dims()));
} else {
out->set_mem_desc(out_md);
}
}

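// Illustrative walk-through: with in_md.dims() = {1, 32, 1, 128} and
// fused_squeeze2_axes = {0, 2}, both squeezed extents equal 1, so the loop
// below keeps the remaining extents and yields squeezed_op_tz = {32, 128};
// the tensor's mem_desc and dims are then reshaped to that squeezed shape.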
static void SetInMemDescWithSqueeze2FuseSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* in,
const dnnl::memory::desc& in_md) {
const std::vector<int> fused_squeeze2_axes =
ctx.Attr<std::vector<int>>("fused_squeeze2_axes");
const std::set<int64_t> squeeze2_axes_set(fused_squeeze2_axes.begin(),
fused_squeeze2_axes.end());
const std::vector<int64_t>& x_vec_dims = in_md.dims();
std::vector<int64_t> squeezed_op_tz(
x_vec_dims.size() - fused_squeeze2_axes.size(), 0);

int j = 0;
for (size_t i = 0; i < x_vec_dims.size(); ++i) {
if (squeeze2_axes_set.count(i) ||
    squeeze2_axes_set.count(static_cast<int64_t>(i) -
                            static_cast<int64_t>(x_vec_dims.size()))) {
PADDLE_ENFORCE_EQ(
x_vec_dims[i],
1,
platform::errors::InvalidArgument(
"Squeeze2 input '%d' dim should be equal to one, but get '%d'.",
i,
x_vec_dims[i]));
continue;
}
squeezed_op_tz[j++] = x_vec_dims[i];
}

in->set_mem_desc(in_md.reshape(squeezed_op_tz));
in->Resize(phi::make_ddim(squeezed_op_tz));
}

static void SetInMemDescWithLogicalLayoutFusesSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* in,
const dnnl::memory::desc& in_md) {
if (ctx.HasAttr("fused_squeeze2_axes")) {
SetInMemDescWithSqueeze2FuseSupport(ctx, in, in_md);
} else {
in->set_mem_desc(in_md);
in->Resize(phi::make_ddim(in_md.dims()));
}
}

template <typename T>
constexpr bool IsInt8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;