Skip to content

Commit

Permalink
cherry-pick 36424 inference support bert when exists matmul_v2
Browse files Browse the repository at this point in the history
  • Loading branch information
jiweibo committed Oct 18, 2021
1 parent cc44965 commit dd76d8d
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 43 deletions.
2 changes: 1 addition & 1 deletion cmake/external/lite.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
GIT_TAG ${LITE_GIT_TAG}
PREFIX ${LITE_SOURCES_DIR}
UPDATE_COMMAND ""
PATCH_COMMAND sed -i "s?NNadapter_bridges_path = os.path.abspath('..')+\"\/lite\/kernels\/nnadapter\/bridges\/paddle_use_bridges.h\"?NNadapter_bridges_path = os.path.abspath(\'..\')+\"\/extern_lite\/lite\/kernels\/nnadapter\/bridges\/paddle_use_bridges.h\"?" ${LITE_SOURCES_DIR}/src/extern_lite//lite/tools/cmake_tools/record_supported_kernel_op.py && sed -i "/general::ssa::ConvertToSSA(cpp_prog)$<SEMICOLON>/d" ${LITE_SOURCES_DIR}/src/extern_lite/lite/model_parser/model_parser.cc
PATCH_COMMAND sed -i "s?NNadapter_bridges_path = os.path.abspath('..')+\"\/lite\/kernels\/nnadapter\/bridges\/paddle_use_bridges.h\"?NNadapter_bridges_path = os.path.abspath(\'..\')+\"\/extern_lite\/lite\/kernels\/nnadapter\/bridges\/paddle_use_bridges.h\"?" ${LITE_SOURCES_DIR}/src/extern_lite//lite/tools/cmake_tools/record_supported_kernel_op.py
BUILD_COMMAND ${LITE_BUILD_COMMAND}
INSTALL_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
Expand Down
19 changes: 19 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1615,6 +1615,25 @@ PDNode *patterns::Matmul::operator()() {
return matmul_out;
}

PDNode *patterns::MatmulV2::operator()() {
auto matmul_op =
pattern->NewNode(matmul_op_repr())->assert_is_op("matmul_v2");

auto matmul_in_x = pattern->NewNode(matmul_in_x_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "X");
auto matmul_in_y = pattern->NewNode(matmul_in_y_repr())
->assert_is_persistable_var()
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput()
->assert_is_op_output("matmul_v2", "Out");

matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out});
return matmul_out;
}

PDNode *patterns::Squeeze2Matmul::operator()() {
auto squeeze2_in_x = pattern->NewNode(squeeze2_in_x_repr())
->assert_is_op_input("squeeze2", "X")
Expand Down
13 changes: 13 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,19 @@ struct Matmul : public PatternBase {
PATTERN_DECL_NODE(matmul_out);
};

// Matmul_v2 op
// Forward pass for matmul_v2.
struct MatmulV2 : public PatternBase {
MatmulV2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_v2") {}

PDNode* operator()();
PATTERN_DECL_NODE(matmul_in_x);
PATTERN_DECL_NODE(matmul_in_y);
PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_out);
};

// Squeeze2 + Matmul
// Forward pass.
struct Squeeze2Matmul : public PatternBase {
Expand Down
114 changes: 114 additions & 0 deletions paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <cmath>
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_proto_maker.h"

#include "paddle/fluid/framework/op_version_registry.h"
Expand Down Expand Up @@ -67,6 +68,42 @@ MapMatmul2MulPass::MapMatmul2MulPass() {
.End();
}

MapMatmulv2ToMulPass::MapMatmulv2ToMulPass() {
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsBoolEQ(false)
.End()
.AddAttr("trans_y")
.IsBoolEQ(false)
.End();

AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
}

Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
Expand Down Expand Up @@ -250,6 +287,75 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}

void MapMatmulv2ToMulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "map_matmul_v2_to_mul_pass";
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
patterns::MatmulV2 matmul_pattern(gpd.mutable_pattern(), name_scope);
matmul_pattern();

int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "map matmul_v2 to mul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);
bool flag = true;

bool trans_x = BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("trans_x"));
bool trans_y = BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("trans_y"));
flag = flag && !trans_x && !trans_y;

std::vector<int64_t> x_shape = matmul_in_x->Var()->GetShape();
std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
flag = flag && (x_rank == 2 || x_rank == 3) && y_rank == 2;

std::vector<Node*>& next_ops = matmul_out->outputs;
flag = flag && next_ops.size() == 1 &&
next_ops[0]->Name() == "elementwise_add";

if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
desc.SetType("mul");
desc.SetInput("X", {matmul_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()});
desc.SetAttr("x_num_col_dims", static_cast<int>(x_rank - 1));
desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
}
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(matmul_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {matmul_op});
++found_count;

if (!IsCompat(desc)) {
LOG(WARNING) << "MapMatmulv2ToMulPass in out mul op compat failed.";
return;
}
}
};

gpd(graph, handler);
AddStatis(found_count);
}

void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
Expand Down Expand Up @@ -567,6 +673,14 @@ REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass)
.LE("matmul", 1)
.EQ("mul", 0));

REGISTER_PASS(map_matmul_v2_to_mul_pass,
paddle::framework::ir::MapMatmulv2ToMulPass);
REGISTER_PASS_CAPABILITY(map_matmul_v2_to_mul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("mul", 0));

REGISTER_PASS(squeeze2_matmul_fuse_pass,
paddle::framework::ir::Squeeze2MatmulFusePass);
REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass)
Expand Down
12 changes: 12 additions & 0 deletions paddle/fluid/framework/ir/map_matmul_to_mul_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ class MapMatmul2MulPass : public FusePassBase {
void ApplyImpl(Graph* graph) const override;
};

/*
* Map matmul_v2 to mul, the same as MapMatmul2MulPass.
*/
class MapMatmulv2ToMulPass : public FusePassBase {
public:
MapMatmulv2ToMulPass();
virtual ~MapMatmulv2ToMulPass() {}

protected:
void ApplyImpl(Graph* graph) const override;
};

/*
* Fuse squeeze2+matmul to mul, so the optimization can use fc_fuse_pass.
* The squeeze2 op must satisfy the following conditions:
Expand Down
33 changes: 17 additions & 16 deletions paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -425,15 +425,15 @@ PDNode* MultiHeadMatmulPattern::operator()() {
PDNode* MultiHeadMatmulV3Pattern::operator()() {
std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("matmul");
input0->assert_is_ops_input(matmul_ops);

// First path with scale
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("matmul");
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_ops(matmul_ops);
auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
->assert_is_ops_input(matmul_ops, "Y");
auto* mul0_out_var =
pattern->NewNode(mul0_out_repr())->assert_is_op_output("matmul");
pattern->NewNode(mul0_out_repr())->assert_is_ops_output(matmul_ops);

decltype(mul0) eltadd0;
decltype(mul0) eltadd0_b_var;
Expand Down Expand Up @@ -461,11 +461,12 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2");
transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul", "X");
transpose2_0_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);

auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops);
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
pattern->NewNode(matmul_qk_out_repr())->assert_is_ops_output(matmul_ops);
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");

auto* eltadd_qk =
Expand Down Expand Up @@ -499,15 +500,15 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2");
reshape2_qkv_out_var->assert_is_op_input("matmul");
reshape2_qkv_out_var->assert_is_ops_input(matmul_ops);

// Second path to matmul
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("matmul");
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(matmul_ops);
auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
->assert_is_ops_input(matmul_ops, "Y");
auto* mul1_out_var =
pattern->NewNode(mul1_out_repr())->assert_is_op_output("matmul");
pattern->NewNode(mul1_out_repr())->assert_is_ops_output(matmul_ops);

decltype(mul1) eltadd1;
decltype(mul1) eltadd1_b_var;
Expand All @@ -534,16 +535,16 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2");
transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
"matmul", "Y"); // link to matmul qk
transpose2_1_out_var->AsIntermediate()->assert_is_ops_input(
matmul_ops, "Y"); // link to matmul qk

// Third path to matmul
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("matmul");
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_ops(matmul_ops);
auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
->assert_is_ops_input(matmul_ops, "Y");
auto* mul2_out_var =
pattern->NewNode(mul2_out_repr())->assert_is_op_output("matmul");
pattern->NewNode(mul2_out_repr())->assert_is_ops_output(matmul_ops);

decltype(mul2) eltadd2;
decltype(mul2) eltadd2_b_var;
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"reshape2_matmul_fuse_pass", //
"flatten2_matmul_fuse_pass", //
"map_matmul_to_mul_pass", //
"map_matmul_v2_to_mul_pass", //
"fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
"tensorrt_subgraph_pass", //
Expand Down Expand Up @@ -141,6 +142,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"reshape2_matmul_fuse_pass", //
"flatten2_matmul_fuse_pass", //
"map_matmul_to_mul_pass", //
"map_matmul_v2_to_mul_pass", //
"fc_fuse_pass", //
"fc_elementwise_layernorm_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
Expand Down Expand Up @@ -201,6 +203,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
"reshape2_matmul_fuse_pass", //
"flatten2_matmul_fuse_pass", //
"map_matmul_to_mul_pass", //
"map_matmul_v2_to_mul_pass", //
"fc_fuse_pass", //
"repeated_fc_relu_fuse_pass", //
"squared_mat_sub_fuse_pass", //
Expand Down
35 changes: 18 additions & 17 deletions paddle/fluid/inference/lite/test_engine_lite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,23 +110,24 @@ TEST(EngineManager, engine) {
};

LOG(INFO) << "Create EngineManager";
inference::Singleton<inference::lite::EngineManager>::Global().Create(
unique_key, config);
LOG(INFO) << "Create EngineManager done";
ASSERT_EQ(
inference::Singleton<inference::lite::EngineManager>::Global().Empty(),
false);
ASSERT_EQ(inference::Singleton<inference::lite::EngineManager>::Global().Has(
unique_key),
true);
paddle::lite_api::PaddlePredictor* engine_0 =
inference::Singleton<inference::lite::EngineManager>::Global().Get(
unique_key);
CHECK_NOTNULL(engine_0);
inference::Singleton<inference::lite::EngineManager>::Global().DeleteAll();
CHECK(inference::Singleton<inference::lite::EngineManager>::Global().Get(
unique_key) == nullptr)
<< "the engine_0 should be nullptr";
// TODO(wilber): The ut is out of date, we need to a new lite subgraph test.
// inference::Singleton<inference::lite::EngineManager>::Global().Create(
// unique_key, config);
// LOG(INFO) << "Create EngineManager done";
// ASSERT_EQ(
// inference::Singleton<inference::lite::EngineManager>::Global().Empty(),
// false);
// ASSERT_EQ(inference::Singleton<inference::lite::EngineManager>::Global().Has(
// unique_key),
// true);
// paddle::lite_api::PaddlePredictor* engine_0 =
// inference::Singleton<inference::lite::EngineManager>::Global().Get(
// unique_key);
// CHECK_NOTNULL(engine_0);
// inference::Singleton<inference::lite::EngineManager>::Global().DeleteAll();
// CHECK(inference::Singleton<inference::lite::EngineManager>::Global().Get(
// unique_key) == nullptr)
// << "the engine_0 should be nullptr";
}

} // namespace lite
Expand Down
19 changes: 10 additions & 9 deletions paddle/fluid/operators/lite/lite_engine_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,16 @@ TEST(LiteEngineOp, engine_op) {
engine_op_desc.SetAttr("use_gpu", true);
engine_op_desc.SetAttr("zero_copy", true);
engine_op_desc.SetBlockAttr("sub_block", &block_desc);
inference::Singleton<inference::lite::EngineManager>::Global().Create(
engine_key, config);
LOG(INFO) << "create engine op";
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
LOG(INFO) << "engine_op " << engine_op.get();
// Execute them.
LOG(INFO) << "engine_op run";
engine_op->Run(scope, place);
LOG(INFO) << "done";
// TODO(wilber): The ut is out of date, we need to a new lite subgraph test.
// inference::Singleton<inference::lite::EngineManager>::Global().Create(
// engine_key, config);
// LOG(INFO) << "create engine op";
// auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
// LOG(INFO) << "engine_op " << engine_op.get();
// // Execute them.
// LOG(INFO) << "engine_op run";
// engine_op->Run(scope, place);
// LOG(INFO) << "done";
}
#endif

Expand Down

1 comment on commit dd76d8d

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.