Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multihead matmul fp16 #44792

Merged
merged 6 commits into from
Aug 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 117 additions & 77 deletions paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/common/data_type.h"

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -257,16 +260,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
}

PDNode* MultiHeadMatmulPattern::operator()() {
std::unordered_set<std::string> mul_ops{"mul", "matmul_v2"};
std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("mul");
input0->assert_is_ops_input(mul_ops);

// First path with scale
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul");
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_ops(mul_ops);
auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
->AsInput()
->assert_is_op_input("mul", "Y");
->assert_is_ops_input(mul_ops, "Y");
auto* mul0_out_var =
pattern->NewNode(mul0_out_repr())->assert_is_op_output("mul");
pattern->NewNode(mul0_out_repr())->assert_is_ops_output(mul_ops);

decltype(mul0) eltadd0;
decltype(mul0) eltadd0_b_var;
Expand Down Expand Up @@ -299,11 +304,12 @@ PDNode* MultiHeadMatmulPattern::operator()() {
auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale");
auto* scale_out_var =
pattern->NewNode(scale_out_repr())->assert_is_op_output("scale");
scale_out_var->AsIntermediate()->assert_is_op_input("matmul");
scale_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);

auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops);
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
pattern->NewNode(matmul_qk_out_repr())->assert_is_ops_output(matmul_ops);
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");

auto* eltadd_qk =
Expand All @@ -319,12 +325,12 @@ PDNode* MultiHeadMatmulPattern::operator()() {
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var =
pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax");
softmax_qk_out_var->AsIntermediate()->assert_is_op_input("matmul");
softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);

auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul");
pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops);
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul");
pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops);
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");

auto* transpose2_qkv =
Expand All @@ -337,15 +343,15 @@ PDNode* MultiHeadMatmulPattern::operator()() {
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2");
reshape2_qkv_out_var->assert_is_op_input("mul");
reshape2_qkv_out_var->assert_is_ops_input(mul_ops);

// Second path to matmul
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("mul");
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(mul_ops);
auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
->AsInput()
->assert_is_op_input("mul", "Y");
->assert_is_ops_input(mul_ops, "Y");
auto* mul1_out_var =
pattern->NewNode(mul1_out_repr())->assert_is_op_output("mul");
pattern->NewNode(mul1_out_repr())->assert_is_ops_output(mul_ops);

decltype(mul1) eltadd1;
decltype(mul1) eltadd1_b_var;
Expand All @@ -372,16 +378,16 @@ PDNode* MultiHeadMatmulPattern::operator()() {
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2");
transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
"matmul"); // link to matmul qk
transpose2_1_out_var->AsIntermediate()->assert_is_ops_input(
matmul_ops); // link to matmul qk

// Third path to matmul
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("mul");
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_ops(mul_ops);
auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
->AsInput()
->assert_is_op_input("mul", "Y");
->assert_is_ops_input(mul_ops, "Y");
auto* mul2_out_var =
pattern->NewNode(mul2_out_repr())->assert_is_op_output("mul");
pattern->NewNode(mul2_out_repr())->assert_is_ops_output(mul_ops);

decltype(mul2) eltadd2;
decltype(mul2) eltadd2_b_var;
Expand All @@ -408,8 +414,8 @@ PDNode* MultiHeadMatmulPattern::operator()() {
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2");
transpose2_2_out_var->AsIntermediate()->assert_is_op_input(
"matmul"); // link to matmul qkv
transpose2_2_out_var->AsIntermediate()->assert_is_ops_input(
matmul_ops); // link to matmul qkv

// Q path
mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
Expand Down Expand Up @@ -631,6 +637,68 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
}
} // namespace patterns

namespace {
template <typename T>
inline void QKVWeightsProcess(Tensor* wq_tensor,
Tensor* wk_tensor,
Tensor* wv_tensor,
Tensor* bq_tensor,
Tensor* bk_tensor,
Tensor* bv_tensor) {
auto* wq_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
auto* wk_data = wk_tensor->mutable_data<T>(platform::CPUPlace());
auto* wv_data = wv_tensor->mutable_data<T>(platform::CPUPlace());
auto* bq_data = bq_tensor->mutable_data<T>(platform::CPUPlace());
auto* bk_data = bk_tensor->mutable_data<T>(platform::CPUPlace());
auto* bv_data = bv_tensor->mutable_data<T>(platform::CPUPlace());

auto combined_w_dims =
phi::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
auto combined_bias_dims = phi::make_ddim({3, bq_tensor->dims()[0]});

framework::LoDTensor tmp_combined_w_tensor;
tmp_combined_w_tensor.Resize(combined_w_dims);
auto* tmp_combined_w_data =
tmp_combined_w_tensor.mutable_data<T>(platform::CPUPlace());

std::vector<T*> w_vec = {wq_data, wk_data, wv_data};
int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
// Combine the three fc weights together.
for (int i = 0; i < dims_h; i++) {
for (int j = 0; j < 3; j++) {
for (int k = 0; k < dims_w; k++) {
int out_index = i * (3 * dims_w) + j * dims_w + k;
int in_index = i * dims_w + k;
tmp_combined_w_data[out_index] = w_vec[j][in_index];
}
}
}

wq_tensor->Resize(combined_w_dims);
auto* new_combined_w_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
memcpy(
new_combined_w_data, tmp_combined_w_data, sizeof(T) * wq_tensor->numel());

framework::LoDTensor tmp_combined_bias_tensor;
tmp_combined_bias_tensor.Resize(combined_bias_dims);
auto* tmp_combined_bias_data =
tmp_combined_bias_tensor.mutable_data<T>(platform::CPUPlace());

size_t bias_size = bq_tensor->numel();
memcpy(tmp_combined_bias_data, bq_data, sizeof(T) * bias_size);
memcpy(tmp_combined_bias_data + bias_size, bk_data, sizeof(T) * bias_size);
memcpy(
tmp_combined_bias_data + 2 * bias_size, bv_data, sizeof(T) * bias_size);

bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data =
bq_tensor->mutable_data<T>(platform::CPUPlace());
memcpy(new_combined_bias_data,
tmp_combined_bias_data,
sizeof(T) * bq_tensor->numel());
}
} // namespace

void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);

Expand Down Expand Up @@ -757,6 +825,23 @@ MultiHeadMatmulV2FusePass::MultiHeadMatmulV2FusePass() {
.IsType<bool>()
.End();

AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();

AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
Expand Down Expand Up @@ -820,16 +905,17 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
auto* bv_tensor =
scope->FindVar(eltadd2_b->Name())->GetMutable<LoDTensor>();

auto* wq_data = wq_tensor->mutable_data<float>(platform::CPUPlace());
auto* wk_data = wk_tensor->mutable_data<float>(platform::CPUPlace());
auto* wv_data = wv_tensor->mutable_data<float>(platform::CPUPlace());
auto* bq_data = bq_tensor->mutable_data<float>(platform::CPUPlace());
auto* bk_data = bk_tensor->mutable_data<float>(platform::CPUPlace());
auto* bv_data = bv_tensor->mutable_data<float>(platform::CPUPlace());

auto combined_w_dims =
phi::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
auto combined_bias_dims = phi::make_ddim({3, bq_tensor->dims()[0]});
if (wq_tensor->dtype() == phi::DataType::FLOAT32) {
QKVWeightsProcess<float>(
wq_tensor, wk_tensor, wv_tensor, bq_tensor, bk_tensor, bv_tensor);
} else if (wq_tensor->dtype() == phi::DataType::FLOAT16) {
QKVWeightsProcess<platform::float16>(
wq_tensor, wk_tensor, wv_tensor, bq_tensor, bk_tensor, bv_tensor);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"multihead_matmul not supported weight dtype. we now only support "
"fp32 and fp16."));
}

// reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
auto* combined_w_desc = mul0_w->Var();
Expand All @@ -840,53 +926,7 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
combined_bias_desc->SetShape({3, bq_tensor->dims()[0]});
combined_bias_desc->SetPersistable(true);

framework::LoDTensor tmp_combined_w_tensor;
tmp_combined_w_tensor.Resize(combined_w_dims);
auto* tmp_combined_w_data =
tmp_combined_w_tensor.mutable_data<float>(platform::CPUPlace());

std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
// Combine the three fc weights together.
for (int i = 0; i < dims_h; i++) {
for (int j = 0; j < 3; j++) {
for (int k = 0; k < dims_w; k++) {
int out_index = i * (3 * dims_w) + j * dims_w + k;
int in_index = i * dims_w + k;
tmp_combined_w_data[out_index] = w_vec[j][in_index];
}
}
}

wq_tensor->Resize(combined_w_dims);
auto* new_combined_w_data =
wq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_w_data,
tmp_combined_w_data,
sizeof(float) * wq_tensor->numel());

scope->EraseVars({mul1_w->Name(), mul2_w->Name()});

framework::LoDTensor tmp_combined_bias_tensor;
tmp_combined_bias_tensor.Resize(combined_bias_dims);
auto* tmp_combined_bias_data =
tmp_combined_bias_tensor.mutable_data<float>(platform::CPUPlace());

size_t bias_size = bq_tensor->numel();
memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
memcpy(
tmp_combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + 2 * bias_size,
bv_data,
sizeof(float) * bias_size);

bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data =
bq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_bias_data,
tmp_combined_bias_data,
sizeof(float) * bq_tensor->numel());

scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});

auto reshape_desc = reshape2->Op();
Expand Down
7 changes: 5 additions & 2 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,18 +154,21 @@ const std::vector<std::string> kLiteSubgraphPasses({
// support fp16/bf16 precision, temporarily use low precision pass to prevent
// running errors. After fusion operator supports low precision, delete this.
const std::vector<std::string> kGpuLowerPrecisionPasses{
"simplify_with_basic_ops_pass",
"conv_bn_fuse_pass",
"conv_eltwiseadd_bn_fuse_pass",
"conv_elementwise_add_act_fuse_pass",
"conv_elementwise_add2_act_fuse_pass",
"conv_elementwise_add_fuse_pass",
"gpu_cpu_map_matmul_v2_to_mul_pass", //
"gpu_cpu_map_matmul_v2_to_matmul_pass", //
"multihead_matmul_fuse_pass_v2",
"gpu_cpu_map_matmul_v2_to_mul_pass",
"gpu_cpu_map_matmul_v2_to_matmul_pass",
"fc_fuse_pass",
"fc_elementwise_layernorm_fuse_pass",
};

const std::vector<std::string> kTrtLowerPrecisionPasses{
"simplify_with_basic_ops_pass",
// "conv_bn_fuse_pass",
// "conv_eltwiseadd_bn_fuse_pass",
"trt_map_matmul_v2_to_mul_pass",
Expand Down
Loading