Multihead matmul fp16 (#44792)
* multihead matmul add fp16

* fix windows error

* fix rocm error

* fix rocm error
jiweibo authored Aug 2, 2022
1 parent be0ec90 commit 0fd8ee6
Showing 3 changed files with 237 additions and 92 deletions.
194 changes: 117 additions & 77 deletions paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
@@ -18,6 +18,9 @@

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/common/data_type.h"

namespace paddle {
namespace framework {
@@ -257,16 +260,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
}

PDNode* MultiHeadMatmulPattern::operator()() {
std::unordered_set<std::string> mul_ops{"mul", "matmul_v2"};
std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("mul");
input0->assert_is_ops_input(mul_ops);

// First path with scale
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul");
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_ops(mul_ops);
auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
->AsInput()
->assert_is_op_input("mul", "Y");
->assert_is_ops_input(mul_ops, "Y");
auto* mul0_out_var =
pattern->NewNode(mul0_out_repr())->assert_is_op_output("mul");
pattern->NewNode(mul0_out_repr())->assert_is_ops_output(mul_ops);

decltype(mul0) eltadd0;
decltype(mul0) eltadd0_b_var;
@@ -299,11 +304,12 @@ PDNode* MultiHeadMatmulPattern::operator()() {
auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale");
auto* scale_out_var =
pattern->NewNode(scale_out_repr())->assert_is_op_output("scale");
scale_out_var->AsIntermediate()->assert_is_op_input("matmul");
scale_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);

auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops);
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
pattern->NewNode(matmul_qk_out_repr())->assert_is_ops_output(matmul_ops);
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");

auto* eltadd_qk =
@@ -319,12 +325,12 @@ PDNode* MultiHeadMatmulPattern::operator()() {
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var =
pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax");
softmax_qk_out_var->AsIntermediate()->assert_is_op_input("matmul");
softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);

auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul");
pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops);
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul");
pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops);
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");

auto* transpose2_qkv =
@@ -337,15 +343,15 @@
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2");
reshape2_qkv_out_var->assert_is_op_input("mul");
reshape2_qkv_out_var->assert_is_ops_input(mul_ops);

// Second path to matmul
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("mul");
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(mul_ops);
auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
->AsInput()
->assert_is_op_input("mul", "Y");
->assert_is_ops_input(mul_ops, "Y");
auto* mul1_out_var =
pattern->NewNode(mul1_out_repr())->assert_is_op_output("mul");
pattern->NewNode(mul1_out_repr())->assert_is_ops_output(mul_ops);

decltype(mul1) eltadd1;
decltype(mul1) eltadd1_b_var;
@@ -372,16 +378,16 @@
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2");
transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
"matmul"); // link to matmul qk
transpose2_1_out_var->AsIntermediate()->assert_is_ops_input(
matmul_ops); // link to matmul qk

// Third path to matmul
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("mul");
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_ops(mul_ops);
auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
->AsInput()
->assert_is_op_input("mul", "Y");
->assert_is_ops_input(mul_ops, "Y");
auto* mul2_out_var =
pattern->NewNode(mul2_out_repr())->assert_is_op_output("mul");
pattern->NewNode(mul2_out_repr())->assert_is_ops_output(mul_ops);

decltype(mul2) eltadd2;
decltype(mul2) eltadd2_b_var;
@@ -408,8 +414,8 @@
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2");
transpose2_2_out_var->AsIntermediate()->assert_is_op_input(
"matmul"); // link to matmul qkv
transpose2_2_out_var->AsIntermediate()->assert_is_ops_input(
matmul_ops); // link to matmul qkv

// Q path
mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
@@ -631,6 +637,68 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
}
} // namespace patterns

namespace {
template <typename T>
inline void QKVWeightsProcess(Tensor* wq_tensor,
Tensor* wk_tensor,
Tensor* wv_tensor,
Tensor* bq_tensor,
Tensor* bk_tensor,
Tensor* bv_tensor) {
auto* wq_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
auto* wk_data = wk_tensor->mutable_data<T>(platform::CPUPlace());
auto* wv_data = wv_tensor->mutable_data<T>(platform::CPUPlace());
auto* bq_data = bq_tensor->mutable_data<T>(platform::CPUPlace());
auto* bk_data = bk_tensor->mutable_data<T>(platform::CPUPlace());
auto* bv_data = bv_tensor->mutable_data<T>(platform::CPUPlace());

auto combined_w_dims =
phi::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
auto combined_bias_dims = phi::make_ddim({3, bq_tensor->dims()[0]});

framework::LoDTensor tmp_combined_w_tensor;
tmp_combined_w_tensor.Resize(combined_w_dims);
auto* tmp_combined_w_data =
tmp_combined_w_tensor.mutable_data<T>(platform::CPUPlace());

std::vector<T*> w_vec = {wq_data, wk_data, wv_data};
int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
// Combine the three fc weights together.
for (int i = 0; i < dims_h; i++) {
for (int j = 0; j < 3; j++) {
for (int k = 0; k < dims_w; k++) {
int out_index = i * (3 * dims_w) + j * dims_w + k;
int in_index = i * dims_w + k;
tmp_combined_w_data[out_index] = w_vec[j][in_index];
}
}
}

wq_tensor->Resize(combined_w_dims);
auto* new_combined_w_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
memcpy(
new_combined_w_data, tmp_combined_w_data, sizeof(T) * wq_tensor->numel());

framework::LoDTensor tmp_combined_bias_tensor;
tmp_combined_bias_tensor.Resize(combined_bias_dims);
auto* tmp_combined_bias_data =
tmp_combined_bias_tensor.mutable_data<T>(platform::CPUPlace());

size_t bias_size = bq_tensor->numel();
memcpy(tmp_combined_bias_data, bq_data, sizeof(T) * bias_size);
memcpy(tmp_combined_bias_data + bias_size, bk_data, sizeof(T) * bias_size);
memcpy(
tmp_combined_bias_data + 2 * bias_size, bv_data, sizeof(T) * bias_size);

bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data =
bq_tensor->mutable_data<T>(platform::CPUPlace());
memcpy(new_combined_bias_data,
tmp_combined_bias_data,
sizeof(T) * bq_tensor->numel());
}
} // namespace

void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);

@@ -757,6 +825,23 @@ MultiHeadMatmulV2FusePass::MultiHeadMatmulV2FusePass() {
.IsType<bool>()
.End();

AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();

AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
@@ -820,16 +905,17 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
auto* bv_tensor =
scope->FindVar(eltadd2_b->Name())->GetMutable<LoDTensor>();

auto* wq_data = wq_tensor->mutable_data<float>(platform::CPUPlace());
auto* wk_data = wk_tensor->mutable_data<float>(platform::CPUPlace());
auto* wv_data = wv_tensor->mutable_data<float>(platform::CPUPlace());
auto* bq_data = bq_tensor->mutable_data<float>(platform::CPUPlace());
auto* bk_data = bk_tensor->mutable_data<float>(platform::CPUPlace());
auto* bv_data = bv_tensor->mutable_data<float>(platform::CPUPlace());

auto combined_w_dims =
phi::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
auto combined_bias_dims = phi::make_ddim({3, bq_tensor->dims()[0]});
if (wq_tensor->dtype() == phi::DataType::FLOAT32) {
QKVWeightsProcess<float>(
wq_tensor, wk_tensor, wv_tensor, bq_tensor, bk_tensor, bv_tensor);
} else if (wq_tensor->dtype() == phi::DataType::FLOAT16) {
QKVWeightsProcess<platform::float16>(
wq_tensor, wk_tensor, wv_tensor, bq_tensor, bk_tensor, bv_tensor);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"multihead_matmul not supported weight dtype. we now only support "
"fp32 and fp16."));
}

// reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
auto* combined_w_desc = mul0_w->Var();
@@ -840,53 +926,7 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
combined_bias_desc->SetShape({3, bq_tensor->dims()[0]});
combined_bias_desc->SetPersistable(true);

framework::LoDTensor tmp_combined_w_tensor;
tmp_combined_w_tensor.Resize(combined_w_dims);
auto* tmp_combined_w_data =
tmp_combined_w_tensor.mutable_data<float>(platform::CPUPlace());

std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
// Combine the three fc weights together.
for (int i = 0; i < dims_h; i++) {
for (int j = 0; j < 3; j++) {
for (int k = 0; k < dims_w; k++) {
int out_index = i * (3 * dims_w) + j * dims_w + k;
int in_index = i * dims_w + k;
tmp_combined_w_data[out_index] = w_vec[j][in_index];
}
}
}

wq_tensor->Resize(combined_w_dims);
auto* new_combined_w_data =
wq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_w_data,
tmp_combined_w_data,
sizeof(float) * wq_tensor->numel());

scope->EraseVars({mul1_w->Name(), mul2_w->Name()});

framework::LoDTensor tmp_combined_bias_tensor;
tmp_combined_bias_tensor.Resize(combined_bias_dims);
auto* tmp_combined_bias_data =
tmp_combined_bias_tensor.mutable_data<float>(platform::CPUPlace());

size_t bias_size = bq_tensor->numel();
memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
memcpy(
tmp_combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + 2 * bias_size,
bv_data,
sizeof(float) * bias_size);

bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data =
bq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_bias_data,
tmp_combined_bias_data,
sizeof(float) * bq_tensor->numel());

scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});

auto reshape_desc = reshape2->Op();
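A minimal standalone sketch of the packing that the new templated helper QKVWeightsProcess<T> performs, with illustrative names (PackQKVWeights and PackQKVBias are not Paddle APIs): the three [H, W] projection weights are interleaved into a single [H, 3, W] tensor and the three length-n biases are stacked into [3, n], with every copy typed on T so the identical code serves both fp32 and platform::float16.

#include <cstddef>
#include <cstring>

// Interleave Q/K/V weights; out is laid out as [H, 3, W], matching
// the out_index = i * (3 * W) + j * W + k arithmetic in the diff above.
template <typename T>
void PackQKVWeights(const T* wq, const T* wk, const T* wv,
                    int H, int W, T* out) {
  const T* src[3] = {wq, wk, wv};
  for (int i = 0; i < H; ++i) {      // row of each [H, W] weight
    for (int j = 0; j < 3; ++j) {    // 0 = Q, 1 = K, 2 = V
      for (int k = 0; k < W; ++k) {  // column
        out[i * 3 * W + j * W + k] = src[j][i * W + k];
      }
    }
  }
}

// Stack Q/K/V biases; out is laid out as [3, n].
template <typename T>
void PackQKVBias(const T* bq, const T* bk, const T* bv,
                 std::size_t n, T* out) {
  std::memcpy(out, bq, sizeof(T) * n);
  std::memcpy(out + n, bk, sizeof(T) * n);
  std::memcpy(out + 2 * n, bv, sizeof(T) * n);
}

BuildFusionV2 then only has to pick the instantiation from the weight tensor's dtype: FLOAT32 maps to float, FLOAT16 to platform::float16, and any other dtype raises the error shown above.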
7 changes: 5 additions & 2 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -154,18 +154,21 @@ const std::vector<std::string> kLiteSubgraphPasses({
// support fp16/bf16 precision, temporarily use low precision pass to prevent
// running errors. After fusion operator supports low precision, delete this.
const std::vector<std::string> kGpuLowerPrecisionPasses{
"simplify_with_basic_ops_pass",
"conv_bn_fuse_pass",
"conv_eltwiseadd_bn_fuse_pass",
"conv_elementwise_add_act_fuse_pass",
"conv_elementwise_add2_act_fuse_pass",
"conv_elementwise_add_fuse_pass",
"gpu_cpu_map_matmul_v2_to_mul_pass", //
"gpu_cpu_map_matmul_v2_to_matmul_pass", //
"multihead_matmul_fuse_pass_v2",
"gpu_cpu_map_matmul_v2_to_mul_pass",
"gpu_cpu_map_matmul_v2_to_matmul_pass",
"fc_fuse_pass",
"fc_elementwise_layernorm_fuse_pass",
};

const std::vector<std::string> kTrtLowerPrecisionPasses{
"simplify_with_basic_ops_pass",
// "conv_bn_fuse_pass",
// "conv_eltwiseadd_bn_fuse_pass",
"trt_map_matmul_v2_to_mul_pass",
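For context on how kGpuLowerPrecisionPasses is consumed, a hedged usage sketch follows. Exp_EnableUseGpuFp16 and the model path are assumptions about the experimental inference API of this Paddle generation, not something this commit adds; verify against your version.

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("./model_dir");  // illustrative model path
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/512, /*device_id=*/0);
  // Assumed experimental switch that applies the GPU low-precision pass
  // list above, which now includes multihead_matmul_fuse_pass_v2.
  config.Exp_EnableUseGpuFp16();
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}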
(Diff for the third changed file was not loaded in this view.)
