cherry-pick inference support bert when exists matmul_v2 #36500

Merged 1 commit on Oct 19, 2021

2 changes: 1 addition & 1 deletion cmake/external/lite.cmake
@@ -134,7 +134,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
GIT_TAG ${LITE_GIT_TAG}
PREFIX ${LITE_SOURCES_DIR}
UPDATE_COMMAND ""
PATCH_COMMAND sed -i "s?NNadapter_bridges_path = os.path.abspath('..')+\"\/lite\/kernels\/nnadapter\/bridges\/paddle_use_bridges.h\"?NNadapter_bridges_path = os.path.abspath(\'..\')+\"\/extern_lite\/lite\/kernels\/nnadapter\/bridges\/paddle_use_bridges.h\"?" ${LITE_SOURCES_DIR}/src/extern_lite//lite/tools/cmake_tools/record_supported_kernel_op.py && sed -i "/general::ssa::ConvertToSSA(cpp_prog)$<SEMICOLON>/d" ${LITE_SOURCES_DIR}/src/extern_lite/lite/model_parser/model_parser.cc
PATCH_COMMAND sed -i "s?NNadapter_bridges_path = os.path.abspath('..')+\"\/lite\/kernels\/nnadapter\/bridges\/paddle_use_bridges.h\"?NNadapter_bridges_path = os.path.abspath(\'..\')+\"\/extern_lite\/lite\/kernels\/nnadapter\/bridges\/paddle_use_bridges.h\"?" ${LITE_SOURCES_DIR}/src/extern_lite//lite/tools/cmake_tools/record_supported_kernel_op.py
BUILD_COMMAND ${LITE_BUILD_COMMAND}
INSTALL_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
19 changes: 19 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1615,6 +1615,25 @@ PDNode *patterns::Matmul::operator()() {
return matmul_out;
}

PDNode *patterns::MatmulV2::operator()() {
auto matmul_op =
pattern->NewNode(matmul_op_repr())->assert_is_op("matmul_v2");

auto matmul_in_x = pattern->NewNode(matmul_in_x_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "X");
auto matmul_in_y = pattern->NewNode(matmul_in_y_repr())
->assert_is_persistable_var()
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput()
->assert_is_op_output("matmul_v2", "Out");

matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out});
return matmul_out;
}

PDNode *patterns::Squeeze2Matmul::operator()() {
auto squeeze2_in_x = pattern->NewNode(squeeze2_in_x_repr())
->assert_is_op_input("squeeze2", "X")
13 changes: 13 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -976,6 +976,19 @@ struct Matmul : public PatternBase {
PATTERN_DECL_NODE(matmul_out);
};

// Matmul_v2 op
// Forward pass for matmul_v2.
struct MatmulV2 : public PatternBase {
MatmulV2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_v2") {}

PDNode* operator()();
PATTERN_DECL_NODE(matmul_in_x);
PATTERN_DECL_NODE(matmul_in_y);
PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_out);
};

// Squeeze2 + Matmul
// Forward pass.
struct Squeeze2Matmul : public PatternBase {
114 changes: 114 additions & 0 deletions paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc
@@ -16,6 +16,7 @@

#include <cmath>
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_proto_maker.h"

#include "paddle/fluid/framework/op_version_registry.h"
@@ -67,6 +68,42 @@ MapMatmul2MulPass::MapMatmul2MulPass() {
.End();
}

MapMatmulv2ToMulPass::MapMatmulv2ToMulPass() {
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsBoolEQ(false)
.End()
.AddAttr("trans_y")
.IsBoolEQ(false)
.End();

AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
}

Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
@@ -250,6 +287,75 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}

void MapMatmulv2ToMulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "map_matmul_v2_to_mul_pass";
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
patterns::MatmulV2 matmul_pattern(gpd.mutable_pattern(), name_scope);
matmul_pattern();

int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "map matmul_v2 to mul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);
bool flag = true;

bool trans_x = BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("trans_x"));
bool trans_y = BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("trans_y"));
flag = flag && !trans_x && !trans_y;

std::vector<int64_t> x_shape = matmul_in_x->Var()->GetShape();
std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
flag = flag && (x_rank == 2 || x_rank == 3) && y_rank == 2;

std::vector<Node*>& next_ops = matmul_out->outputs;
flag = flag && next_ops.size() == 1 &&
next_ops[0]->Name() == "elementwise_add";

if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
desc.SetType("mul");
desc.SetInput("X", {matmul_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()});
desc.SetAttr("x_num_col_dims", static_cast<int>(x_rank - 1));
desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
}
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(matmul_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {matmul_op});
++found_count;

if (!IsCompat(desc)) {
LOG(WARNING) << "MapMatmulv2ToMulPass in out mul op compat failed.";
return;
}
}
};

gpd(graph, handler);
AddStatis(found_count);
}

void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
@@ -567,6 +673,14 @@ REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass)
.LE("matmul", 1)
.EQ("mul", 0));

REGISTER_PASS(map_matmul_v2_to_mul_pass,
paddle::framework::ir::MapMatmulv2ToMulPass);
REGISTER_PASS_CAPABILITY(map_matmul_v2_to_mul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("mul", 0));

REGISTER_PASS(squeeze2_matmul_fuse_pass,
paddle::framework::ir::Squeeze2MatmulFusePass);
REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass)
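
For context on the conversion above: MapMatmulv2ToMulPass only fires when trans_x and trans_y are false, X has rank 2 or 3, Y has rank 2, and the output feeds a single elementwise_add. Under those conditions matmul_v2(X, Y) reduces to the mul op with x_num_col_dims = rank(X) - 1 and y_num_col_dims = 1. The standalone sketch below (not part of this PR; all names are illustrative) checks that equivalence numerically for a rank-3 X:

// Standalone sketch: matmul_v2 on a rank-3 X and rank-2 Y with no transposes
// equals the mul op's "flatten X to 2-D, GEMM, keep the same buffer layout"
// semantics used by the pass (x_num_col_dims = 2, y_num_col_dims = 1).
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  const int B = 2, M = 3, K = 4, N = 5;
  std::vector<float> X(B * M * K), Y(K * N);
  for (size_t i = 0; i < X.size(); ++i) X[i] = 0.01f * static_cast<float>(i);
  for (size_t i = 0; i < Y.size(); ++i) Y[i] = 0.1f * static_cast<float>(i);

  // matmul_v2 semantics: Y is broadcast over the batch dimension of X.
  std::vector<float> out_matmul(B * M * N, 0.f);
  for (int b = 0; b < B; ++b)
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n)
        for (int k = 0; k < K; ++k)
          out_matmul[(b * M + m) * N + n] +=
              X[(b * M + m) * K + k] * Y[k * N + n];

  // mul semantics: X flattened to [B*M, K], plain 2-D GEMM.
  std::vector<float> out_mul(B * M * N, 0.f);
  for (int r = 0; r < B * M; ++r)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        out_mul[r * N + n] += X[r * K + k] * Y[k * N + n];

  for (size_t i = 0; i < out_mul.size(); ++i) assert(out_matmul[i] == out_mul[i]);
  std::puts("matmul_v2 and mul agree for rank-3 X, rank-2 Y, no transpose");
  return 0;
}

After the rewrite, the resulting mul + elementwise_add pair is what fc_fuse_pass later folds into a single fc op.
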
12 changes: 12 additions & 0 deletions paddle/fluid/framework/ir/map_matmul_to_mul_pass.h
@@ -46,6 +46,18 @@ class MapMatmul2MulPass : public FusePassBase {
void ApplyImpl(Graph* graph) const override;
};

/*
* Map matmul_v2 to mul, the same as MapMatmul2MulPass.
*/
class MapMatmulv2ToMulPass : public FusePassBase {
public:
MapMatmulv2ToMulPass();
virtual ~MapMatmulv2ToMulPass() {}

protected:
void ApplyImpl(Graph* graph) const override;
};

/*
* Fuse squeeze2+matmul to mul, so the optimization can use fc_fuse_pass.
* The squeeze2 op must satisfy the following conditions:
33 changes: 17 additions & 16 deletions paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
@@ -425,15 +425,15 @@ PDNode* MultiHeadMatmulPattern::operator()() {
PDNode* MultiHeadMatmulV3Pattern::operator()() {
std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("matmul");
input0->assert_is_ops_input(matmul_ops);

// First path with scale
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("matmul");
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_ops(matmul_ops);
auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
->assert_is_ops_input(matmul_ops, "Y");
auto* mul0_out_var =
pattern->NewNode(mul0_out_repr())->assert_is_op_output("matmul");
pattern->NewNode(mul0_out_repr())->assert_is_ops_output(matmul_ops);

decltype(mul0) eltadd0;
decltype(mul0) eltadd0_b_var;
@@ -461,11 +461,12 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2");
transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul", "X");
transpose2_0_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);

auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops);
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
pattern->NewNode(matmul_qk_out_repr())->assert_is_ops_output(matmul_ops);
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");

auto* eltadd_qk =
@@ -499,15 +500,15 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2");
reshape2_qkv_out_var->assert_is_op_input("matmul");
reshape2_qkv_out_var->assert_is_ops_input(matmul_ops);

// Second path to matmul
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("matmul");
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(matmul_ops);
auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
->assert_is_ops_input(matmul_ops, "Y");
auto* mul1_out_var =
pattern->NewNode(mul1_out_repr())->assert_is_op_output("matmul");
pattern->NewNode(mul1_out_repr())->assert_is_ops_output(matmul_ops);

decltype(mul1) eltadd1;
decltype(mul1) eltadd1_b_var;
@@ -534,16 +535,16 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2");
transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
"matmul", "Y"); // link to matmul qk
transpose2_1_out_var->AsIntermediate()->assert_is_ops_input(
matmul_ops, "Y"); // link to matmul qk

// Third path to matmul
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("matmul");
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_ops(matmul_ops);
auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
->assert_is_ops_input(matmul_ops, "Y");
auto* mul2_out_var =
pattern->NewNode(mul2_out_repr())->assert_is_op_output("matmul");
pattern->NewNode(mul2_out_repr())->assert_is_ops_output(matmul_ops);

decltype(mul2) eltadd2;
decltype(mul2) eltadd2_b_var;
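
The edits above replace single-op assertions such as assert_is_op("matmul") with set-based ones (assert_is_ops(matmul_ops)), so MultiHeadMatmulV3Pattern matches attention blocks whether they were exported with matmul or matmul_v2. A minimal standalone sketch of that set-membership idiom (illustrative only; it does not use the real PDPattern API):

// Sketch of matching a node against a set of op types instead of one type.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
  const std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
  const std::vector<std::string> graph_ops{"matmul", "matmul_v2", "mul"};
  for (const auto& op : graph_ops)
    std::cout << op << (matmul_ops.count(op) ? " matches" : " does not match")
              << std::endl;
  return 0;
}
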
3 changes: 3 additions & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -94,6 +94,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"reshape2_matmul_fuse_pass", //
"flatten2_matmul_fuse_pass", //
"map_matmul_to_mul_pass", //
"map_matmul_v2_to_mul_pass", //
"fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
"tensorrt_subgraph_pass", //
@@ -141,6 +142,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"reshape2_matmul_fuse_pass", //
"flatten2_matmul_fuse_pass", //
"map_matmul_to_mul_pass", //
"map_matmul_v2_to_mul_pass", //
"fc_fuse_pass", //
"fc_elementwise_layernorm_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
@@ -201,6 +203,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
"reshape2_matmul_fuse_pass", //
"flatten2_matmul_fuse_pass", //
"map_matmul_to_mul_pass", //
"map_matmul_v2_to_mul_pass", //
"fc_fuse_pass", //
"repeated_fc_relu_fuse_pass", //
"squared_mat_sub_fuse_pass", //
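
The new pass is inserted into the TensorRT, GPU, and CPU pass lists directly after map_matmul_to_mul_pass and before fc_fuse_pass, since fc_fuse_pass only recognizes mul + elementwise_add. If the rewrite needs to be ruled out while debugging a model, it can be removed through the pass builder. A rough sketch, assuming the standard paddle::AnalysisConfig inference API (the model path is a placeholder):

// Sketch, assuming the standard paddle::AnalysisConfig inference API.
#include "paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("/path/to/bert_model");  // placeholder path
  config.SwitchIrOptim(true);
  // The pass runs by default; delete it to compare inference behavior with
  // and without the matmul_v2 -> mul rewrite.
  config.pass_builder()->DeletePass("map_matmul_v2_to_mul_pass");
  auto predictor = paddle::CreatePaddlePredictor(config);
  return predictor != nullptr ? 0 : 1;
}
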
35 changes: 18 additions & 17 deletions paddle/fluid/inference/lite/test_engine_lite.cc
@@ -110,23 +110,24 @@ TEST(EngineManager, engine) {
};

LOG(INFO) << "Create EngineManager";
inference::Singleton<inference::lite::EngineManager>::Global().Create(
unique_key, config);
LOG(INFO) << "Create EngineManager done";
ASSERT_EQ(
inference::Singleton<inference::lite::EngineManager>::Global().Empty(),
false);
ASSERT_EQ(inference::Singleton<inference::lite::EngineManager>::Global().Has(
unique_key),
true);
paddle::lite_api::PaddlePredictor* engine_0 =
inference::Singleton<inference::lite::EngineManager>::Global().Get(
unique_key);
CHECK_NOTNULL(engine_0);
inference::Singleton<inference::lite::EngineManager>::Global().DeleteAll();
CHECK(inference::Singleton<inference::lite::EngineManager>::Global().Get(
unique_key) == nullptr)
<< "the engine_0 should be nullptr";
// TODO(wilber): The ut is out of date, we need to add a new lite subgraph test.
// inference::Singleton<inference::lite::EngineManager>::Global().Create(
// unique_key, config);
// LOG(INFO) << "Create EngineManager done";
// ASSERT_EQ(
// inference::Singleton<inference::lite::EngineManager>::Global().Empty(),
// false);
// ASSERT_EQ(inference::Singleton<inference::lite::EngineManager>::Global().Has(
// unique_key),
// true);
// paddle::lite_api::PaddlePredictor* engine_0 =
// inference::Singleton<inference::lite::EngineManager>::Global().Get(
// unique_key);
// CHECK_NOTNULL(engine_0);
// inference::Singleton<inference::lite::EngineManager>::Global().DeleteAll();
// CHECK(inference::Singleton<inference::lite::EngineManager>::Global().Get(
// unique_key) == nullptr)
// << "the engine_0 should be nullptr";
}

} // namespace lite
19 changes: 10 additions & 9 deletions paddle/fluid/operators/lite/lite_engine_op_test.cc
@@ -105,15 +105,16 @@ TEST(LiteEngineOp, engine_op) {
engine_op_desc.SetAttr("use_gpu", true);
engine_op_desc.SetAttr("zero_copy", true);
engine_op_desc.SetBlockAttr("sub_block", &block_desc);
inference::Singleton<inference::lite::EngineManager>::Global().Create(
engine_key, config);
LOG(INFO) << "create engine op";
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
LOG(INFO) << "engine_op " << engine_op.get();
// Execute them.
LOG(INFO) << "engine_op run";
engine_op->Run(scope, place);
LOG(INFO) << "done";
// TODO(wilber): The ut is out of date, we need to add a new lite subgraph test.
// inference::Singleton<inference::lite::EngineManager>::Global().Create(
// engine_key, config);
// LOG(INFO) << "create engine op";
// auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
// LOG(INFO) << "engine_op " << engine_op.get();
// // Execute them.
// LOG(INFO) << "engine_op run";
// engine_op->Run(scope, place);
// LOG(INFO) << "done";
}
#endif
