diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
index c7274e8ce36da..089c252ea6947 100644
--- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
@@ -18,6 +18,9 @@
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/float16.h"
+#include "paddle/phi/common/data_type.h"

 namespace paddle {
 namespace framework {
@@ -257,16 +260,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
 }

 PDNode* MultiHeadMatmulPattern::operator()() {
+  std::unordered_set<std::string> mul_ops{"mul", "matmul_v2"};
+  std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
   auto* input0 = pattern->NewNode(input0_repr());
-  input0->assert_is_op_input("mul");
+  input0->assert_is_ops_input(mul_ops);

   // First path with scale
-  auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul");
+  auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_ops(mul_ops);
   auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
                          ->AsInput()
-                         ->assert_is_op_input("mul", "Y");
+                         ->assert_is_ops_input(mul_ops, "Y");
   auto* mul0_out_var =
-      pattern->NewNode(mul0_out_repr())->assert_is_op_output("mul");
+      pattern->NewNode(mul0_out_repr())->assert_is_ops_output(mul_ops);

   decltype(mul0) eltadd0;
   decltype(mul0) eltadd0_b_var;
@@ -299,11 +304,12 @@ PDNode* MultiHeadMatmulPattern::operator()() {
   auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale");
   auto* scale_out_var =
       pattern->NewNode(scale_out_repr())->assert_is_op_output("scale");
-  scale_out_var->AsIntermediate()->assert_is_op_input("matmul");
+  scale_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);

-  auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
+  auto* matmul_qk =
+      pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops);
   auto* matmul_qk_out_var =
-      pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
+      pattern->NewNode(matmul_qk_out_repr())->assert_is_ops_output(matmul_ops);
   matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");

   auto* eltadd_qk =
@@ -319,12 +325,12 @@ PDNode* MultiHeadMatmulPattern::operator()() {
       pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
   auto* softmax_qk_out_var =
       pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax");
-  softmax_qk_out_var->AsIntermediate()->assert_is_op_input("matmul");
+  softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);

   auto* matmul_qkv =
-      pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul");
+      pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops);
   auto* matmul_qkv_out_var =
-      pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul");
+      pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops);
   matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");

   auto* transpose2_qkv =
@@ -337,15 +343,15 @@ PDNode* MultiHeadMatmulPattern::operator()() {
       pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
   auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
                                    ->assert_is_op_output("reshape2");
-  reshape2_qkv_out_var->assert_is_op_input("mul");
+  reshape2_qkv_out_var->assert_is_ops_input(mul_ops);

   // Second path to matmul
-  auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("mul");
+  auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(mul_ops);
   auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
                          ->AsInput()
-                         ->assert_is_op_input("mul", "Y");
+                         ->assert_is_ops_input(mul_ops, "Y");
   auto* mul1_out_var =
-      pattern->NewNode(mul1_out_repr())->assert_is_op_output("mul");
+      pattern->NewNode(mul1_out_repr())->assert_is_ops_output(mul_ops);

   decltype(mul1) eltadd1;
   decltype(mul1) eltadd1_b_var;
@@ -372,16 +378,16 @@ PDNode* MultiHeadMatmulPattern::operator()() {
       pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
   auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
                                    ->assert_is_op_output("transpose2");
-  transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
-      "matmul");  // link to matmul qk
+  transpose2_1_out_var->AsIntermediate()->assert_is_ops_input(
+      matmul_ops);  // link to matmul qk

   // Third path to matmul
-  auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("mul");
+  auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_ops(mul_ops);
   auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
                          ->AsInput()
-                         ->assert_is_op_input("mul", "Y");
+                         ->assert_is_ops_input(mul_ops, "Y");
   auto* mul2_out_var =
-      pattern->NewNode(mul2_out_repr())->assert_is_op_output("mul");
+      pattern->NewNode(mul2_out_repr())->assert_is_ops_output(mul_ops);

   decltype(mul2) eltadd2;
   decltype(mul2) eltadd2_b_var;
@@ -408,8 +414,8 @@ PDNode* MultiHeadMatmulPattern::operator()() {
       pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
   auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
                                    ->assert_is_op_output("transpose2");
-  transpose2_2_out_var->AsIntermediate()->assert_is_op_input(
-      "matmul");  // link to matmul qkv
+  transpose2_2_out_var->AsIntermediate()->assert_is_ops_input(
+      matmul_ops);  // link to matmul qkv

   // Q path
   mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
@@ -631,6 +637,68 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
 }

 }  // namespace patterns
+namespace {
+template <typename T>
+inline void QKVWeightsProcess(Tensor* wq_tensor,
+                              Tensor* wk_tensor,
+                              Tensor* wv_tensor,
+                              Tensor* bq_tensor,
+                              Tensor* bk_tensor,
+                              Tensor* bv_tensor) {
+  auto* wq_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
+  auto* wk_data = wk_tensor->mutable_data<T>(platform::CPUPlace());
+  auto* wv_data = wv_tensor->mutable_data<T>(platform::CPUPlace());
+  auto* bq_data = bq_tensor->mutable_data<T>(platform::CPUPlace());
+  auto* bk_data = bk_tensor->mutable_data<T>(platform::CPUPlace());
+  auto* bv_data = bv_tensor->mutable_data<T>(platform::CPUPlace());
+
+  auto combined_w_dims =
+      phi::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
+  auto combined_bias_dims = phi::make_ddim({3, bq_tensor->dims()[0]});
+
+  framework::LoDTensor tmp_combined_w_tensor;
+  tmp_combined_w_tensor.Resize(combined_w_dims);
+  auto* tmp_combined_w_data =
+      tmp_combined_w_tensor.mutable_data<T>(platform::CPUPlace());
+
+  std::vector<T*> w_vec = {wq_data, wk_data, wv_data};
+  int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
+  // Combine the three fc weights together.
+  for (int i = 0; i < dims_h; i++) {
+    for (int j = 0; j < 3; j++) {
+      for (int k = 0; k < dims_w; k++) {
+        int out_index = i * (3 * dims_w) + j * dims_w + k;
+        int in_index = i * dims_w + k;
+        tmp_combined_w_data[out_index] = w_vec[j][in_index];
+      }
+    }
+  }
+
+  wq_tensor->Resize(combined_w_dims);
+  auto* new_combined_w_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
+  memcpy(
+      new_combined_w_data, tmp_combined_w_data, sizeof(T) * wq_tensor->numel());
+
+  framework::LoDTensor tmp_combined_bias_tensor;
+  tmp_combined_bias_tensor.Resize(combined_bias_dims);
+  auto* tmp_combined_bias_data =
+      tmp_combined_bias_tensor.mutable_data<T>(platform::CPUPlace());
+
+  size_t bias_size = bq_tensor->numel();
+  memcpy(tmp_combined_bias_data, bq_data, sizeof(T) * bias_size);
+  memcpy(tmp_combined_bias_data + bias_size, bk_data, sizeof(T) * bias_size);
+  memcpy(
+      tmp_combined_bias_data + 2 * bias_size, bv_data, sizeof(T) * bias_size);
+
+  bq_tensor->Resize(combined_bias_dims);
+  auto* new_combined_bias_data =
+      bq_tensor->mutable_data<T>(platform::CPUPlace());
+  memcpy(new_combined_bias_data,
+         tmp_combined_bias_data,
+         sizeof(T) * bq_tensor->numel());
+}
+}  // namespace
+
 void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
   FusePassBase::Init(name_scope_, graph);
@@ -757,6 +825,23 @@ MultiHeadMatmulV2FusePass::MultiHeadMatmulV2FusePass() {
       .IsType()
       .End();

+  AddOpCompat(OpCompat("matmul_v2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsType<bool>()
+      .End()
+      .AddAttr("trans_y")
+      .IsType<bool>()
+      .End();
+
   AddOpCompat(OpCompat("softmax"))
       .AddInput("X")
       .IsTensor()
@@ -820,16 +905,17 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
     auto* bv_tensor =
         scope->FindVar(eltadd2_b->Name())->GetMutable<LoDTensor>();

-    auto* wq_data = wq_tensor->mutable_data<float>(platform::CPUPlace());
-    auto* wk_data = wk_tensor->mutable_data<float>(platform::CPUPlace());
-    auto* wv_data = wv_tensor->mutable_data<float>(platform::CPUPlace());
-    auto* bq_data = bq_tensor->mutable_data<float>(platform::CPUPlace());
-    auto* bk_data = bk_tensor->mutable_data<float>(platform::CPUPlace());
-    auto* bv_data = bv_tensor->mutable_data<float>(platform::CPUPlace());
-
-    auto combined_w_dims =
-        phi::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
-    auto combined_bias_dims = phi::make_ddim({3, bq_tensor->dims()[0]});
+    if (wq_tensor->dtype() == phi::DataType::FLOAT32) {
+      QKVWeightsProcess<float>(
+          wq_tensor, wk_tensor, wv_tensor, bq_tensor, bk_tensor, bv_tensor);
+    } else if (wq_tensor->dtype() == phi::DataType::FLOAT16) {
+      QKVWeightsProcess<platform::float16>(
+          wq_tensor, wk_tensor, wv_tensor, bq_tensor, bk_tensor, bv_tensor);
+    } else {
+      PADDLE_THROW(platform::errors::Unavailable(
+          "multihead_matmul not supported weight dtype. we now only support "
+          "fp32 and fp16."));
+    }

     // reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
     auto* combined_w_desc = mul0_w->Var();
@@ -840,53 +926,7 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
     combined_bias_desc->SetShape({3, bq_tensor->dims()[0]});
     combined_bias_desc->SetPersistable(true);

-    framework::LoDTensor tmp_combined_w_tensor;
-    tmp_combined_w_tensor.Resize(combined_w_dims);
-    auto* tmp_combined_w_data =
-        tmp_combined_w_tensor.mutable_data<float>(platform::CPUPlace());
-
-    std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
-    int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
-    // Combine the three fc weights together.
-    for (int i = 0; i < dims_h; i++) {
-      for (int j = 0; j < 3; j++) {
-        for (int k = 0; k < dims_w; k++) {
-          int out_index = i * (3 * dims_w) + j * dims_w + k;
-          int in_index = i * dims_w + k;
-          tmp_combined_w_data[out_index] = w_vec[j][in_index];
-        }
-      }
-    }
-
-    wq_tensor->Resize(combined_w_dims);
-    auto* new_combined_w_data =
-        wq_tensor->mutable_data<float>(platform::CPUPlace());
-    memcpy(new_combined_w_data,
-           tmp_combined_w_data,
-           sizeof(float) * wq_tensor->numel());
-
     scope->EraseVars({mul1_w->Name(), mul2_w->Name()});
-
-    framework::LoDTensor tmp_combined_bias_tensor;
-    tmp_combined_bias_tensor.Resize(combined_bias_dims);
-    auto* tmp_combined_bias_data =
-        tmp_combined_bias_tensor.mutable_data<float>(platform::CPUPlace());
-
-    size_t bias_size = bq_tensor->numel();
-    memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
-    memcpy(
-        tmp_combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size);
-    memcpy(tmp_combined_bias_data + 2 * bias_size,
-           bv_data,
-           sizeof(float) * bias_size);
-
-    bq_tensor->Resize(combined_bias_dims);
-    auto* new_combined_bias_data =
-        bq_tensor->mutable_data<float>(platform::CPUPlace());
-    memcpy(new_combined_bias_data,
-           tmp_combined_bias_data,
-           sizeof(float) * bq_tensor->numel());
-
     scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});

     auto reshape_desc = reshape2->Op();
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 42d58dd782891..52b5d52449581 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -154,18 +154,21 @@ const std::vector<std::string> kLiteSubgraphPasses({
 // support fp16/bf16 precision, temporarily use low precision pass to prevent
 // running errors. After fusion operator supports low precision, delete this.
 const std::vector<std::string> kGpuLowerPrecisionPasses{
+    "simplify_with_basic_ops_pass",
     "conv_bn_fuse_pass",
     "conv_eltwiseadd_bn_fuse_pass",
     "conv_elementwise_add_act_fuse_pass",
     "conv_elementwise_add2_act_fuse_pass",
     "conv_elementwise_add_fuse_pass",
-    "gpu_cpu_map_matmul_v2_to_mul_pass",     //
-    "gpu_cpu_map_matmul_v2_to_matmul_pass",  //
+    "multihead_matmul_fuse_pass_v2",
+    "gpu_cpu_map_matmul_v2_to_mul_pass",
+    "gpu_cpu_map_matmul_v2_to_matmul_pass",
     "fc_fuse_pass",
     "fc_elementwise_layernorm_fuse_pass",
 };

 const std::vector<std::string> kTrtLowerPrecisionPasses{
+    "simplify_with_basic_ops_pass",  //
     "conv_bn_fuse_pass",             //
     "conv_eltwiseadd_bn_fuse_pass",
     "trt_map_matmul_v2_to_mul_pass",
diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu
index 8a6d5b313ad36..f2d010e16a2ea 100644
--- a/paddle/fluid/operators/fused/multihead_matmul_op.cu
+++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu
@@ -15,10 +15,12 @@
 #include 
 #include 
+#include 

 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/math/bert_encoder_functor.h"
+#include "paddle/fluid/platform/float16.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"

 namespace paddle {
@@ -64,6 +66,26 @@ __device__ float4 add_func(float4 a, float4 b) {
   c.w = a.w + b.w;
   return c;
 }
+#if defined(PADDLE_WITH_CUDA)
+template <>
+__device__ half2 add_func(half2 a, half2 b) {
+#if __CUDA_ARCH__ >= 530
+  return __hadd2(a, b);
+#else
+  return half2(__float2half(__half2float(a.x) + __half2float(b.x)),
+               __float2half(__half2float(a.y) + __half2float(b.y)));
+#endif
+}
+
+template <>
+__device__ half add_func(half a, half b) {
+#if __CUDA_ARCH__ >= 530
+  return __hadd(a, b);
+#else
+  return __float2half(__half2float(a) + __half2float(b));
+#endif
+}
+#endif

 template <typename T>
 __global__ void TransposeQkvKernel(const int H,
@@ -71,7 +93,7 @@ __global__ void TransposeQkvKernel(const int H,
                                    const T *bias,
                                    T *output) {
   // Input: BxSx3xNxH
-  // Bias: 3xSxB
+  // Bias: 3xNxH
   // Output: 3xBxNxSxH
   int n = threadIdx.y;
   int s = blockIdx.x;
@@ -93,6 +115,17 @@ __global__ void TransposeQkvKernel(const int H,
       add_func(input[in_offset + i], bias[bias_offset + i]);
 }

+template <typename T>
+void TransQKVWithBias(const int batch,
+                      const int seq_len,
+                      const int head_size,
+                      const int head_num,
+                      const T *input,
+                      const T *bias,
+                      T *output,
+                      gpuStream_t stream);
+
+template <>
 void TransQKVWithBias(const int batch,
                       const int seq_len,
                       const int head_size,
@@ -153,6 +186,55 @@ void TransQKVWithBias(const int batch,
   }
 }

+#if defined(PADDLE_WITH_CUDA)
+template <>
+void TransQKVWithBias(const int batch,
+                      const int seq_len,
+                      const int head_size,
+                      const int head_num,
+                      const platform::float16 *input,
+                      const platform::float16 *bias,
+                      platform::float16 *output,
+                      gpuStream_t stream) {
+  // BxSx3xNxH + 3xNxH -> 3xBxNxSxH
+  int scratch_size = batch * head_num * seq_len * seq_len;
+  const dim3 grid(seq_len, batch, 3);
+  if (head_size % 2 == 0 && scratch_size % 2 == 0) {
+    const int h = head_size / 2;
+    const half2 *input2 = reinterpret_cast<const half2 *>(input);
+    const half2 *bias2 = reinterpret_cast<const half2 *>(bias);
+    half2 *output2 = reinterpret_cast<half2 *>(output);
+    const dim3 block(h, head_num, 1);
+    // limit h * head_num to max block size(1024).
+    PADDLE_ENFORCE_LE(h * head_num,
+                      1024,
+                      platform::errors::InvalidArgument(
+                          "head_num (%d) * head_size (%d) should <= %d",
+                          head_num,
+                          head_size,
+                          1024 * 2));
+    TransposeQkvKernel<half2>
+        <<<grid, block, 0, stream>>>(h, input2, bias2, output2);
+  } else {
+    const dim3 block(head_size, head_num, 1);
+    const half *input_half = reinterpret_cast<const half *>(input);
+    const half *bias_half = reinterpret_cast<const half *>(bias);
+    half *output_half = reinterpret_cast<half *>(output);
+
+    // limit head_size * head_num to max block size(1024).
+    PADDLE_ENFORCE_LE(head_size * head_num,
+                      1024,
+                      platform::errors::InvalidArgument(
+                          "head_num (%d) * head_size (%d) should <= %d",
+                          head_num,
+                          head_size,
+                          1024));
+    TransposeQkvKernel<half><<<grid, block, 0, stream>>>(
+        head_size, input_half, bias_half, output_half);
+  }
+}
+#endif
+
 inline int round_up(int seq_len, int multiple = 32) {
   PADDLE_ENFORCE_GT(
       multiple,
@@ -261,18 +343,31 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
                      bias_d,
                      tptr,
                      stream);
-
-    math::MultiHeadGPUComputeFunctor<T> multihead_compute_func;
-    multihead_compute_func(device_ctx,
-                           batch,
-                           seq_len,
-                           head_number,
-                           head_size,
-                           qkptr,
-                           bias_qk_d,
-                           tptr,
-                           scale,
-                           T(0.0));
+    if (std::is_same<T, platform::float16>::value) {
+      math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
+      multihead_compute_func(device_ctx,
+                             batch,
+                             seq_len,
+                             head_number,
+                             head_size,
+                             reinterpret_cast<half *>(qkptr),
+                             reinterpret_cast<const half *>(bias_qk_d),
+                             reinterpret_cast<half *>(tptr),
+                             __float2half(static_cast<float>(scale)),
+                             __float2half(0.0));
+    } else {
+      math::MultiHeadGPUComputeFunctor<T> multihead_compute_func;
+      multihead_compute_func(device_ctx,
+                             batch,
+                             seq_len,
+                             head_number,
+                             head_size,
+                             qkptr,
+                             bias_qk_d,
+                             tptr,
+                             scale,
+                             T(0.0));
+    }

     int grid = batch * head_number * seq_len;
     int block = head_size;
@@ -285,5 +380,12 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
 }  // namespace paddle

 namespace ops = paddle::operators;
+#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
+REGISTER_OP_CUDA_KERNEL(
+    multihead_matmul,
+    ops::MultiHeadMatMulV2Kernel,
+    ops::MultiHeadMatMulV2Kernel);
+#else
 REGISTER_OP_CUDA_KERNEL(multihead_matmul,
                         ops::MultiHeadMatMulV2Kernel);
+#endif
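
Note (illustrative, not part of the patch): the sketch below shows, on plain arrays, the packing that QKVWeightsProcess performs — three [H, W] projection weights interleaved into one [H, 3, W] buffer and three length-W biases concatenated into [3, W], which is the layout the fused multihead_matmul op consumes. The function name, container choices, and argument layout here are assumptions made only for clarity; the fp32 and fp16 branches in BuildFusionV2 differ solely in the element type T, while the index arithmetic is shared.

// Standalone sketch of the QKV weight/bias packing (assumed shapes:
// W_q/W_k/W_v are [H, W] row-major, b_q/b_k/b_v are [W]); mirrors the
// loops in QKVWeightsProcess above.
#include <cstddef>
#include <vector>

template <typename T>
void PackQKV(const std::vector<T>& wq, const std::vector<T>& wk,
             const std::vector<T>& wv, const std::vector<T>& bq,
             const std::vector<T>& bk, const std::vector<T>& bv,
             int H, int W,
             std::vector<T>* w_out,    // packed weights, [H, 3, W]
             std::vector<T>* b_out) {  // packed bias, [3, W]
  const std::vector<T>* w_src[3] = {&wq, &wk, &wv};
  const std::vector<T>* b_src[3] = {&bq, &bk, &bv};
  w_out->resize(static_cast<size_t>(H) * 3 * W);
  b_out->resize(static_cast<size_t>(3) * W);
  for (int i = 0; i < H; ++i) {
    for (int j = 0; j < 3; ++j) {
      for (int k = 0; k < W; ++k) {
        // w_out[i][j][k] = w_j[i][k]: row i of Q/K/V becomes three adjacent
        // W-wide segments of row i in the combined tensor.
        (*w_out)[(static_cast<size_t>(i) * 3 + j) * W + k] =
            (*w_src[j])[static_cast<size_t>(i) * W + k];
      }
    }
  }
  for (int j = 0; j < 3; ++j) {
    for (int k = 0; k < W; ++k) {
      // b_out[j][k] = b_j[k]: biases are simply concatenated.
      (*b_out)[static_cast<size_t>(j) * W + k] = (*b_src[j])[k];
    }
  }
}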