From f8f9bfaaf876a74abff565f634bfa596233935ed Mon Sep 17 00:00:00 2001 From: Hongwen Xin <55238914+penPenf28@users.noreply.github.com> Date: Thu, 23 May 2024 18:11:28 +0800 Subject: [PATCH] Fused_mt Branch Migration (#64125) * Merge fused_mt branch * Adjusted fuse_mt_int8 * Revert attention_layer_norm.h * Revert paddle/phi/kernels/fusion/gpu/fmha_ref.h * Add win support and refine format. * Reformat for win. * Removed redundant files, now only supports flash_attn_v2 and variable length * Refine static_fused_ft test * Refine fused_mt related testcase * Remove custom_adll_reduce * Remove operator cublaslt and revert parallel test * Refine empty seq_len * Refine ft * Refine ft_static test * Remove float32 support and static parallel ft test * Refine type static error. * Fix doc type error * Fuse_mt code format * Remove some redundant code * Remove redundant attention_layer_norm.h * Remove redundant code in ft_op * Remove Redundant code and skip fuse_mt doctest * Remove redundant fmha_ref mmha_util and other code * Remove redundant kernel * Remove redundant file * Refine fuse_mt code * Refine cublaslt comment --- .../fluid/inference/api/analysis_predictor.h | 4 + paddle/fluid/inference/api/api.cc | 2 + paddle/fluid/inference/api/api_impl.cc | 2 + .../inference/api/details/zero_copy_tensor.cc | 1 + paddle/fluid/operators/fused/attn_gemm_int8.h | 7 +- .../fused/fused_multi_transformer_helper.cu.h | 325 +++ .../fused/fused_multi_transformer_int8_op.cu | 51 +- .../fused/fused_multi_transformer_op.cc | 75 +- .../fused/fused_multi_transformer_op.cu | 2094 +++++--------- .../fused/fused_multi_transformer_op.cu.h | 2535 +++++++++-------- .../fused_multi_transformer_sig.cc | 32 +- paddle/fluid/platform/dynload/cublasLt.h | 32 +- paddle/fluid/pybind/eager_generator.h | 26 +- paddle/fluid/pybind/inference_api.cc | 23 +- paddle/phi/backends/dynload/cublasLt.h | 32 +- paddle/phi/infermeta/fusion.cc | 5 + paddle/phi/infermeta/fusion.h | 5 + paddle/phi/kernels/fusion/gpu/mmha_util.cu.h | 83 + .../ops/yaml/inconsistent/dygraph_ops.yaml | 4 +- .../phi/ops/yaml/inconsistent/static_ops.yaml | 4 +- .../communication/stream/all_to_all.py | 2 +- .../distributed/communication/stream/recv.py | 2 +- .../distributed/communication/stream/send.py | 2 +- python/paddle/distributed/utils/moe_utils.py | 4 +- python/paddle/incubate/layers/nn.py | 1 + .../nn/functional/fused_transformer.py | 138 +- .../incubate/nn/layer/fused_transformer.py | 166 +- python/paddle/tensor/creation.py | 4 +- python/paddle/tensor/linalg.py | 2 +- python/paddle/tensor/math.py | 15 +- test/legacy_test/CMakeLists.txt | 3 - ..._model_parallel_fused_multi_transformer.py | 192 -- test/legacy_test/test_empty_like_op.py | 9 +- test/legacy_test/test_empty_op.py | 9 +- .../test_fused_multi_transformer_op.py | 852 ++++-- ..._model_parallel_fused_multi_transformer.py | 50 - 36 files changed, 3695 insertions(+), 3098 deletions(-) create mode 100644 paddle/fluid/operators/fused/fused_multi_transformer_helper.cu.h delete mode 100644 test/legacy_test/static_model_parallel_fused_multi_transformer.py delete mode 100644 test/legacy_test/test_static_model_parallel_fused_multi_transformer.py diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index 9c426a7a14219..e3fbf097bbb7a 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -29,6 +29,8 @@ #include "paddle/fluid/inference/api/paddle_inference_api.h" #include 
"paddle/fluid/inference/api/resource_manager.h" #include "paddle/fluid/platform/device/gpu/gpu_types.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/phi/common/bfloat16.h" #include "paddle/utils/string/printf.h" #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) @@ -45,6 +47,8 @@ #include "paddle/pir/include/core/program.h" namespace paddle_infer { +using float16 = paddle::platform::float16; +using bfloat16 = phi::dtype::bfloat16; namespace experimental { class InternalUtils; }; diff --git a/paddle/fluid/inference/api/api.cc b/paddle/fluid/inference/api/api.cc index da29b3124fa72..b9699e2692678 100644 --- a/paddle/fluid/inference/api/api.cc +++ b/paddle/fluid/inference/api/api.cc @@ -28,6 +28,8 @@ int PaddleDtypeSize(PaddleDType dtype) { switch (dtype) { case PaddleDType::FLOAT32: return sizeof(float); + case PaddleDType::BFLOAT16: + return sizeof(uint16_t); case PaddleDType::INT64: return sizeof(int64_t); case PaddleDType::INT32: diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 9ae284402f196..744b10c659aab 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -221,6 +221,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector &inputs, input_ptr = input.mutable_data(ddim, place_); } else if (inputs[i].dtype == PaddleDType::INT32) { input_ptr = input.mutable_data(ddim, place_); + } else if (inputs[i].dtype == PaddleDType::BFLOAT16) { + input_ptr = input.mutable_data(ddim, place_); } else { LOG(ERROR) << "unsupported feed type " << inputs[i].dtype; return false; diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index a1bc622120694..d8206093efa53 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -826,6 +826,7 @@ template void Tensor::ORTCopyToCpu(int32_t *data) const; template void Tensor::ORTCopyToCpu(uint8_t *data) const; template void Tensor::ORTCopyToCpu(int8_t *data) const; template void Tensor::ORTCopyToCpu(float16 *data) const; +template void Tensor::ORTCopyToCpu(bfloat16 *data) const; #endif namespace experimental { diff --git a/paddle/fluid/operators/fused/attn_gemm_int8.h b/paddle/fluid/operators/fused/attn_gemm_int8.h index 157fade9e1526..8d5312ec6367b 100644 --- a/paddle/fluid/operators/fused/attn_gemm_int8.h +++ b/paddle/fluid/operators/fused/attn_gemm_int8.h @@ -16,12 +16,12 @@ limitations under the License. 
*/ #include #include -#include "paddle/fluid/operators/fused/cublaslt.h" #include "paddle/fluid/operators/fused/quant_dequant_kernel.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/cublaslt.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" namespace paddle { @@ -35,7 +35,8 @@ class AttnMatmulINT8 { AttnMatmulINT8( const phi::GPUContext& dev_ctx, int m, int n, int k, bool compute_bias) : dev_ctx_(dev_ctx), m_(m), n_(n), k_(k), compute_bias_(compute_bias) { - auto helper = std::make_shared(m, k, n); + auto helper = std::make_shared( + m, k, n, dev_ctx.cublaslt_handle()); helpers_.emplace_back(helper); gpu_config_ = std::make_unique( phi::backends::gpu::GetGpuLaunchConfig1D( @@ -186,7 +187,7 @@ class AttnMatmulINT8 { int k_; // k int compute_bias_; - std::vector> helpers_; + std::vector> helpers_; std::unique_ptr gpu_config_; }; diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_helper.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_helper.cu.h new file mode 100644 index 0000000000000..4594918e6f7b5 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_helper.cu.h @@ -0,0 +1,325 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/operators/fused/attn_gemm_int8.h" +#include "paddle/fluid/operators/fused/fused_dropout_helper.h" +#include "paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h" +#include "paddle/phi/kernels/funcs/cublaslt.h" +#include "paddle/phi/kernels/funcs/load_store_util.h" +#include "paddle/phi/kernels/funcs/quant_dequant.h" +#include "paddle/phi/kernels/fusion/gpu/attention_layer.norm.h" +#include "paddle/phi/kernels/fusion/gpu/attn_gemm.h" + +/* +Note(Zhengzekang): +This header file is to store General Function Helper which has been used in +FusedMultiTransformer. +*/ + +namespace phi { +namespace fusion { + +namespace { // NOLINT + +template +class BiasActHelper { + public: + BiasActHelper(const phi::GPUContext &dev_ctx, + const std::string &act_method, + int rows, + int cols) + : dev_ctx_(dev_ctx), act_method_(act_method), rows_(rows), cols_(cols) {} + + // dst = Activation(x + bias(optional)) + void Compute(const phi::DenseTensor *x, + const phi::DenseTensor *bias, + phi::DenseTensor *output) { + const T *bias_data = (bias == nullptr) ? nullptr : bias->data(); + phi::funcs::Load load_func(x->data()); + phi::funcs::Store store_func(output->data()); + ComputeImpl(bias_data, load_func, store_func); + } + + private: + template + void ComputeImpl(const T *bias_data, + LoadFunc load_func, + StoreFunc store_func) { + if (act_method_ == "geglu") { + // Note(Zhengzekang): For GLU structure, we need divide the cols by 2. 
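+ // GeGLU / SwiGLU split the ffn1 output of width `cols` into two halves
+ // along the last dimension; one half is passed through the activation and
+ // gates the other half elementwise, so only cols / 2 columns are written.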
+ VLOG(5) << "doing geglu"; + LaunchActFFNGlu, + LoadFunc, + StoreFunc, + LoadT>( + dev_ctx_, bias_data, rows_, cols_ / 2, load_func, store_func); + } else if (act_method_ == "swiglu") { + VLOG(5) << "doing swiglu"; + LaunchActFFNGlu, LoadFunc, StoreFunc, LoadT>( + dev_ctx_, bias_data, rows_, cols_ / 2, load_func, store_func); + } else if (act_method_ == "gelu") { + if (FLAGS_use_fast_math) { + VLOG(5) << "doing Fast GELU"; + LaunchBiasAct, + LoadFunc, + StoreFunc, + LoadT>( + dev_ctx_, bias_data, rows_, cols_, load_func, store_func); + } else { + VLOG(5) << "doing GELU"; + LaunchBiasAct, + LoadFunc, + StoreFunc, + LoadT>( + dev_ctx_, bias_data, rows_, cols_, load_func, store_func); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Currently Only Support GeGLU, SwiGLU, GeLU")); + } + } + const phi::GPUContext &dev_ctx_; + std::string act_method_; + int rows_; + int cols_; +}; + +template ::DataType> +class GEMMHelper { + public: + GEMMHelper(const phi::GPUContext &dev_ctx, + int token_num, + int dim_ffn, + int dim_embed, + const std::string gemm_method, + bool transpose_weight = false) + : dev_ctx_(dev_ctx), + token_num_(token_num), + dim_ffn_(dim_ffn), + dim_embed_(dim_embed), + gemm_method_(gemm_method), + transpose_weight_(transpose_weight) {} + + // dst = act(fc(src[0]) + bias) * src[1] + void Compute(const phi::DenseTensor *input, + const phi::DenseTensor *weight, + const phi::DenseTensor *scale, + const phi::DenseTensor *bias, + phi::DenseTensor *workspace, + phi::DenseTensor *output) { + VLOG(5) << "GEMMHelper," + << " token_num_:" << token_num_ << " dim_ffn_:" << dim_ffn_ + << " dim_embed_:" << dim_embed_; + bool compute_bias = true; + if (bias == nullptr) { + compute_bias = false; + } + using NvType = typename phi::PDDataTypeTraits::DataType; + + if (gemm_method_ == "None") { + auto ffn_linear_compute = phi::fusion::AttnMatMul(dev_ctx_, + false, + transpose_weight_, + token_num_, + dim_ffn_, + dim_embed_, + compute_bias); + ffn_linear_compute.ComputeForward(weight, input, bias, output, output); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Currently GemmHelper only support `None`. ")); + } + } + + private: + const phi::GPUContext &dev_ctx_; + int token_num_; + int dim_ffn_; + int dim_embed_; + std::string gemm_method_; + bool transpose_weight_; // Just For AttnMatmul. +}; + +template +class NormHelper { + public: + NormHelper(const phi::GPUContext &dev_ctx, + const std::string &norm_type, + const int rows, + const int cols, + const float epsilon, + const float residual_alpha) + : dev_ctx_(dev_ctx), + norm_type_(norm_type), + rows_(rows), + cols_(cols), + epsilon_(epsilon), + residual_alpha_( + residual_alpha), // TODO(zhengzekang): currently only available for + // Layernorm. Need support rmsnorm. + layernorm_helper_(dev_ctx_, epsilon_, rows_, cols_) { + // VLOG(0) << "NormHelper residual_alpha:" << residual_alpha_; + paddle::operators::DropoutParam dropout_param( + true, 0, true, true, 0.0, nullptr, 0); + residual_bias_add_layernorm_helper_ = + paddle::operators::FusedDropoutLayerNormHelper( + dev_ctx, rows_, cols_, dropout_param, epsilon_); + } + + /* + Note(Zhengzekang): + Since input `X` and `Residual` in FusedMT will be swaped by preallocated + buffer, I have no choice but to pass the data pointer instead of + phi::DenseTensor. 
+ */ + + // dst = Norm(x + residual + bias(optional)) + void NormResidualBias(const T *x_data, + const T *residual_data, + const phi::DenseTensor *bias, + const phi::DenseTensor *norm_weight, + const phi::DenseTensor *norm_bias, + phi::DenseTensor *mean, + phi::DenseTensor *var, + phi::DenseTensor *bias_residual_out, + phi::DenseTensor *output) { + using U = paddle::operators::LayerNormParamType; + const T *bias_data = bias ? bias->data() : nullptr; + U *mean_data = mean ? mean->data() : nullptr; + U *var_data = var ? var->data() : nullptr; + T *bias_residual_out_data = bias_residual_out->data(); + T *output_data = output->data(); + + if (norm_type_ == "layernorm") { + // For layernorm, it use FP32 type weight and bias. + const U *norm_weight_data = + norm_weight ? norm_weight->data() : nullptr; + const U *norm_bias_data = norm_bias ? norm_bias->data() : nullptr; + residual_bias_add_layernorm_helper_.LayernormResidualDropoutBias( + dev_ctx_, + x_data, + residual_data, + bias_data, + norm_weight_data, + norm_bias_data, + bias_residual_out_data, + nullptr, + output_data, + mean_data, + var_data); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Currently NormHelper only support `layernorm`. ")); + } + } + + // dst = Norm(x) + void Norm(const T *x_data, + const phi::DenseTensor *norm_weight, + const phi::DenseTensor *norm_bias, + phi::DenseTensor *mean, + phi::DenseTensor *var, + phi::DenseTensor *output) { + using U = paddle::operators::LayerNormParamType; + U *mean_data = mean ? mean->data() : nullptr; + U *var_data = var ? var->data() : nullptr; + T *output_data = output->data(); + + if (norm_type_ == "layernorm") { + // For layernorm, it use FP32 type weight and bias. + const U *norm_weight_data = + norm_weight ? norm_weight->data() : nullptr; + const U *norm_bias_data = norm_bias ? norm_bias->data() : nullptr; + layernorm_helper_.ComputeForward(x_data, + norm_weight_data, + norm_bias_data, + output_data, + mean_data, + var_data); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Currently NormHelper only support `layernorm`. ")); + } + } + + private: + const phi::GPUContext &dev_ctx_; + std::string norm_type_; + int rows_; + int cols_; + float epsilon_; + float residual_alpha_; + paddle::operators::FusedDropoutLayerNormHelper + residual_bias_add_layernorm_helper_; + AttnLayerNorm layernorm_helper_; +}; + +template ::DataType> +class FFNHelper { + public: + FFNHelper(const phi::GPUContext &dev_ctx, + const std::string &act_method, + int token_num, + int dim_ffn, + int dim_embed, + const std::string gemm_method) + : dev_ctx_(dev_ctx), + act_method_(act_method), + token_num_(token_num), + dim_ffn_(dim_ffn), + dim_embed_(dim_embed), + gemm_method_(gemm_method) {} + + // dst = act(fc(src[0]) + bias) * src[1] + void Compute(const phi::DenseTensor *input, + const phi::DenseTensor *weight, + const phi::DenseTensor *scale, + const phi::DenseTensor *bias, + phi::DenseTensor *workspace, + phi::DenseTensor *bias_out, + phi::DenseTensor *output) { + /* + input's shape [token_num, dim_embed] + weight's shape [dim_embed, dim_ffn] + bias' shape [dim_ffn] + output's shape [token_num, dim_ffn]. 
+ */ + GEMMHelper gemm_helper( + dev_ctx_, token_num_, dim_ffn_, dim_embed_, gemm_method_); + BiasActHelper bias_act_helper( + dev_ctx_, act_method_, token_num_, dim_ffn_); + + gemm_helper.Compute(input, weight, scale, bias, workspace, bias_out); + bias_act_helper.Compute(bias_out, nullptr, output); + } + + private: + const phi::GPUContext &dev_ctx_; + std::string act_method_; + int token_num_; + int dim_ffn_; + int dim_embed_; + std::string gemm_method_; +}; + +} // namespace + +} // namespace fusion +} // namespace phi diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu index 043ad86ebee71..28fc341f217a1 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/fused/attn_gemm_int8.h" -#include "paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h" +#include "paddle/fluid/operators/fused/fused_multi_transformer_helper.cu.h" +#include "paddle/fluid/platform/device/gpu/gpu_resource_pool.h" #include "paddle/phi/kernels/fusion/gpu/attention_layer.norm.h" +#include "paddle/phi/kernels/fusion/gpu/fmha_ref.h" namespace paddle { namespace operators { @@ -100,7 +102,6 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { auto *src_mask = ctx.Input("SrcMask"); auto cache_kvs = ctx.MultiInput("CacheKV"); auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); - // auto *time_step = ctx.Input("TimeStep"); auto out_seq_len = seq_len; if (time_step) { @@ -289,9 +290,6 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { quant_max_bound, quant_min_bound); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step1"; -#endif // step2. qkv const phi::DenseTensor *qkv_bias = @@ -334,30 +332,39 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &qkv_out, qkv_out_scales[i]); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step2"; -#endif // step3. fmha const phi::DenseTensor *cache_kv = cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; phi::DenseTensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; + int cache_bsz = 0; + if (cache_kv) { + cache_bsz = cache_kv->dims()[1]; + } + if (time_step) { // generation decoder stage // [2, batch_size, num_head, max_seq_len, head_size] int max_seq_len = cache_kv->dims()[3]; phi::fusion::fmha(dev_ctx, qkv_out, *qkv_bias, - *src_mask, + src_mask, + nullptr, + nullptr, + nullptr, + nullptr, cache_kv_out, &fmha_out, bsz, + cache_bsz, + seq_len, max_seq_len, num_head, dim_head, time_step->data()[0], - 1. / std::sqrt(dim_head)); + 0, + 1. / sqrt(dim_head)); } else if (cache_kv_out) { // generation context stage // TODO(wangxi): can remove dropout in inference fmha_compute.ComputeForward(qkv_out, @@ -413,9 +420,6 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &qktv_out, &fmha_out); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step3"; -#endif if (pre_layer_norm) { out_linear_compute.ComputeForwardTToINT8(out_linear_weights[i], @@ -447,9 +451,6 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { quant_min_bound); phi::fusion::AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step4"; -#endif // step5. 
ln(residual + dropout(input + bias)) if (pre_layer_norm) { @@ -495,9 +496,6 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { ln_mean_data, ln_var_data); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step5"; -#endif // step6. ffn matmul1 @@ -523,9 +521,6 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { quant_max_bound, quant_min_bound); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step6"; -#endif // step7. act bias // TODO(wangxi): remove dropout mask in inference @@ -552,9 +547,6 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { ffn1_dropout_out_data, ffn1_dropout_mask_data); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step7"; -#endif // step8. ffn matmul2 if (pre_layer_norm) { @@ -579,9 +571,6 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { quant_max_bound, quant_min_bound); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step8.0"; -#endif if (pre_layer_norm) { phi::fusion::AllReduce(output_workspace, @@ -591,9 +580,6 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { } else { phi::fusion::AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step8.1"; -#endif // step9. residual bias if (pre_layer_norm) { @@ -648,9 +634,6 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { ln_mean_data, ln_var_data); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step9"; -#endif if (pre_layer_norm) { x_data = buf1->data(); } diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc index dc90eaa3e5306..22029e3176163 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc @@ -61,8 +61,6 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { // x: qkv's input [batch_size, seq_len, dim_embed] // y: qkv's weight: [3, num_head, dim_head, dim_embed] auto x_dim = ctx->GetInputDim("X"); - auto y_dim = ctx->GetInputsDim("QKVW")[0]; - bool trans_qkvw = ctx->Attrs().Get("trans_qkvw"); PADDLE_ENFORCE_EQ( x_dim.size(), 3, @@ -71,25 +69,6 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { "but received dimensions of" "Input is [%d]", x_dim.size())); - PADDLE_ENFORCE_EQ( - y_dim.size(), - 4, - phi::errors::InvalidArgument("The dimensions of qkv_weight must be 4" - "(3, num_head, dim_head, dim_embed)," - "but received dimensions of" - "Input is [%d]", - y_dim.size())); - PADDLE_ENFORCE_EQ( - x_dim[2], - trans_qkvw ? y_dim[3] : y_dim[0], - phi::errors::InvalidArgument( - "ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is " - "true) or y_dim[0](trans_qkvw is false)" - "must be equal. But received: the shape " - "of input x = [%s], and the shape of " - "input qkv_weight = [%s]", - x_dim, - y_dim)); if (ctx->HasInputs("CacheKV")) { // [2, batch_size, num_head, max_seq_len, head_size] @@ -113,20 +92,6 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { "batch size %d, but got %d", x_dim[0], c_dim[1])); // batch_size - PADDLE_ENFORCE_EQ(c_dim[2], - trans_qkvw ? y_dim[1] : y_dim[2], - phi::errors::InvalidArgument( - "The third dim of CacheKV must be equal with num " - "head %d, but got %d", - trans_qkvw ? y_dim[1] : y_dim[2], - c_dim[2])); // num_head - PADDLE_ENFORCE_EQ(c_dim[4], - trans_qkvw ? 
y_dim[2] : y_dim[3], - phi::errors::InvalidArgument( - "The fifth dim of CacheKV must be equal with head " - "size %d, but got %d", - trans_qkvw ? y_dim[2] : y_dim[3], - c_dim[4])); // head_size } ctx->SetOutputDim("Out", ctx->GetInputDim("X")); @@ -166,6 +131,7 @@ class FusedMultiTransformerOpOpMaker AddInput("LnBias", "Bias is a 1-dimensional tensor of size " "H. Here, H represents the last dimension of its input tensor.") + .AsDispensable() .AsDuplicable(); AddInput("QKVW", "The qkv weight tensor.").AsDuplicable(); AddInput("QKVBias", "The qkv bias tensor.").AsDispensable().AsDuplicable(); @@ -179,6 +145,9 @@ class FusedMultiTransformerOpOpMaker AddInput("RotaryPosEmb", "(optional) The RoPE embeddings for generation inference.") .AsDispensable(); + AddInput("BeamCacheOffset", + "(optional) The offset of CacheKV when using BeamSearch.") + .AsDispensable(); AddInput("TimeStep", "(optional, int) The time step for generation inference.") .AsDispensable(); @@ -194,6 +163,7 @@ class FusedMultiTransformerOpOpMaker AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op") .AsDuplicable(); AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op") + .AsDispensable() .AsDuplicable(); AddInput("FFN1Weight", "The linear1 weight of FusedFeedForward op") .AsDuplicable(); @@ -240,6 +210,10 @@ class FusedMultiTransformerOpOpMaker epsilon)); }); + AddAttr("residual_alpha", + "Constant for residual_alpha [default 1.0].") + .SetDefault(1.0f); + AddAttr("dropout_rate", "Probability of setting units to zero.") .SetDefault(.5f) .AddCustomChecker([](const float &drop_p) { @@ -269,12 +243,14 @@ class FusedMultiTransformerOpOpMaker AddAttr("act_method", "act_method") .SetDefault("gelu") .AddCustomChecker([](const std::string &act_type) { - PADDLE_ENFORCE_EQ( - act_type == "gelu" || act_type == "relu" || act_type == "none", - true, - phi::errors::InvalidArgument( - "Only support `gelu`, `relu`, `none` activation in " - "FusedMultiTransformer. ")); + PADDLE_ENFORCE_EQ(act_type == "gelu" || act_type == "geglu" || + act_type == "swiglu" || act_type == "relu" || + act_type == "none", + true, + phi::errors::InvalidArgument( + "Only support `gelu`, `geglu`, `swiglu`, " + "`relu`, `none` activation in " + "FusedMultiTransformer. ")); }); AddAttr( @@ -290,6 +266,23 @@ class FusedMultiTransformerOpOpMaker "ring id for tensor model parallel. distributed training and inference") .SetDefault(-1); + AddAttr("norm_type", "norm_type") + .SetDefault("layernorm") + .AddCustomChecker([](const std::string &norm_type) { + PADDLE_ENFORCE_EQ( + norm_type == "layernorm" || norm_type == "rmsnorm", + true, + phi::errors::InvalidArgument( + "Only support `layernorm`, `rmsnorm` method for in" + "FusedMultiTransformerDyquant. ")); + }); + + AddAttr("use_neox_rotary_style", + "Whether use neox rotary embedding. ") + .SetDefault(false); + + AddAttr("gqa_group_size", "(int, default -1) the group size of GQA") + .SetDefault(-1); AddComment(R"DOC(fused multi transformer layers op)DOC"); } }; diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu index b3718dfe1f7d5..6924a698b65fe 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h" +#include "paddle/fluid/operators/fused/fused_multi_transformer_helper.cu.h" #include "paddle/fluid/framework/op_registry.h" @@ -22,1387 +22,831 @@ limitations under the License. */ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h" -#include "paddle/phi/kernels/fusion/gpu/attention_layer.norm.h" +#include "paddle/phi/kernels/flash_attn_kernel.h" #include "paddle/phi/kernels/fusion/gpu/fmha_ref.h" -#include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h" - -namespace phi { -namespace fusion { - -#if CUDA_VERSION >= 11060 // Use cublasLt to fuse FFN operation. - -template -void FusedMultiTransformerKernel( - const Context &dev_ctx, - const DenseTensor &x, - const std::vector &ln_scales, - const std::vector &ln_biases, - const std::vector &qkv_weights, - const paddle::optional> &qkv_biases, - const paddle::optional> &cache_kvs, - const paddle::optional> &pre_caches, - const paddle::optional &rotary_tensor, - const paddle::optional &time_step, - const paddle::optional &seq_lengths, - const paddle::optional &src_mask, - const std::vector &out_linear_weights, - const paddle::optional> &out_linear_biases, - const std::vector &ffn_ln_scales, - const std::vector &ffn_ln_biases, - const std::vector &ffn1_weights, - const paddle::optional> &ffn1_biases, - const std::vector &ffn2_weights, - const paddle::optional> &ffn2_biases, - bool pre_layer_norm, - float epsilon, - float dropout_rate, - int rotary_emb_dims, - bool is_test, - const std::string &dropout_implementation, - const std::string &act_method, - bool trans_qkvw, - int ring_id, - std::vector cache_kv_outs, - DenseTensor *out) { - if (cache_kvs) { - for (size_t i = 0; i < cache_kv_outs.size(); i++) { - *(cache_kv_outs[i]) = *(cache_kvs.get()[i]); +#include "paddle/phi/kernels/reduce_sum_kernel.h" + +namespace paddle { +namespace operators { + +template +class FusedMultiTransformerOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; + auto &dev_ctx = ctx.cuda_device_context(); + + auto *time_step = ctx.Input("TimeStep"); + // 0. 
input + auto *input_x = ctx.Input("X"); + const auto input_x_dims = input_x->dims(); + int bsz = input_x_dims[0]; + int seq_len = input_x_dims[1]; + int dim_embed = input_x_dims[2]; + int bsz_seq = bsz * seq_len; + const std::string act_method = ctx.Attr("act_method"); + bool use_glu = (act_method == "geglu" || act_method == "swiglu"); + const std::string norm_type = ctx.Attr("norm_type"); + const bool use_neox_rotary_style = ctx.Attr("use_neox_rotary_style"); + bool remove_padding = false; + auto *sequence_lengths = ctx.Input("SeqLengths"); + phi::DenseTensor sequence_lengths_backup; + if (sequence_lengths) { + remove_padding = true; + } else { + sequence_lengths_backup.Resize({{1}}); + auto *sequence_lengths_backup_data = + dev_ctx.Alloc(&sequence_lengths_backup, + sequence_lengths_backup.numel() * sizeof(int)); + phi::fusion::InitValue(dev_ctx, + sequence_lengths_backup_data, + sequence_lengths_backup.numel() * sizeof(int), + static_cast(seq_len)); + remove_padding = true; } - } - using U = phi::funcs::LayerNormParamType; - - auto *rotary_tensor_t = rotary_tensor.get_ptr(); - auto *seq_lengths_t = seq_lengths.get_ptr(); - auto *src_mask_t = src_mask.get_ptr(); - auto *time_step_t = time_step.get_ptr(); - - const auto input_x_dims = x.dims(); - int bsz = input_x_dims[0]; - int seq_len = input_x_dims[1]; - int dim_embed = input_x_dims[2]; - int bsz_seq = bsz * seq_len; - bool remove_padding = false; - if (seq_lengths_t) { - remove_padding = true; - } - phi::DenseTensor d_token_tensor; - phi::DenseTensor padding_offset_tensor; - phi::DenseTensor x_remove_padding; - bool encoder_remove_padding = (remove_padding && !time_step_t); - int token_num = 0; - - // remove padding in encoder - if (encoder_remove_padding) { - // just for encoder - d_token_tensor.Resize({1}); - auto *d_token_num = dev_ctx.template Alloc( - &d_token_tensor, d_token_tensor.numel() * sizeof(int)); - // alloc the max size of padding_offset_tensor - padding_offset_tensor.Resize({bsz_seq}); - dev_ctx.template Alloc(&padding_offset_tensor, - padding_offset_tensor.numel() * sizeof(int)); - InvokeGetPaddingOffset(dev_ctx, - &token_num, - d_token_num, - padding_offset_tensor.data(), - seq_lengths_t->data(), - bsz, - seq_len); - padding_offset_tensor.Resize({token_num}); - x_remove_padding.Resize({token_num, dim_embed}); - dev_ctx.template Alloc(&x_remove_padding, - x_remove_padding.numel() * sizeof(T)); - InvokeRemovePadding(dev_ctx, - x_remove_padding.data(), - x.data(), - padding_offset_tensor.data(), - token_num, - dim_embed); - } else { - token_num = bsz_seq; - } - auto *padding_offset_data = - encoder_remove_padding ? padding_offset_tensor.data() : nullptr; - - auto ln_compute = - phi::fusion::AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); - phi::DenseTensor ln_mean, ln_var; - ln_mean.Resize({token_num}); - auto *ln_mean_data = - dev_ctx.template Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); - ln_var.Resize({token_num}); - auto *ln_var_data = - dev_ctx.template Alloc(&ln_var, ln_var.numel() * sizeof(U)); - - // 2. qkv - // x: qkv's input [batch_size, seq_len, dim_embed] - // y: qkv's weight: [3, num_head, dim_head, dim_embed] - const auto qkv_w_dims = qkv_weights[0]->dims(); - int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; - int dim_head = trans_qkvw ? 
qkv_w_dims[2] : qkv_w_dims[3]; - int hidden_size = num_head * dim_head; - int output_size = 3 * hidden_size; - int input_size = dim_embed; - - bool compute_bias = - qkv_biases && !qkv_biases.get().empty() && time_step_t == nullptr; - // (transA, transB, compute_bias) = (false, trans_qkvw, false) - // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we set - // compute_bias as false. - auto qkv_compute = phi::fusion::AttnMatMul(dev_ctx, - false, - trans_qkvw, - token_num, - output_size, - input_size, - /*compute_bias=*/false); - - phi::DenseTensor qkv_out; - qkv_out.Resize({token_num, 3, num_head, dim_head}); - auto *qkv_out_data = - dev_ctx.template Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); - - // 3. fmha - AttnDropoutParam attn_param( - true, "upscale_in_train", 0.0, true, true, 0, nullptr); - auto fmha_compute = - FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); - int cache_offset = 0; - if (pre_caches && pre_caches.get().size() > 0) { - cache_offset = pre_caches.get()[0]->dims()[3]; - } - - auto out_seq_len = seq_len; - if (time_step_t) { - PADDLE_ENFORCE_EQ(time_step_t->place(), - phi::CPUPlace(), - phi::errors::PreconditionNotMet( - "The place of input(TimeStep) must be CPUPlace.")); - // cache_seq_len - int time_step_value = time_step_t->data()[0]; - PADDLE_ENFORCE_GT(time_step_value, - 0, - phi::errors::PreconditionNotMet( - "The value of time_step_t must > 0, but now is %d", - time_step_value)); - PADDLE_ENFORCE_EQ( - seq_len, - 1, - phi::errors::PreconditionNotMet( - "In decode stage, the seq_len of input must be 1, but now is %d", - seq_len)); - out_seq_len += time_step_value; - } else { - out_seq_len += cache_offset; - } + auto gqa_group_size = ctx.Attr("gqa_group_size"); - phi::DenseTensor q_transpose_out, kv_transpose_out, qk_out; - q_transpose_out.Resize({bsz, num_head, seq_len, dim_head}); - auto *q_transpose_out_data = dev_ctx.template Alloc( - &q_transpose_out, q_transpose_out.numel() * sizeof(T)); + auto *beam_cache_offset = ctx.Input("BeamCacheOffset"); + int beam_size = 1; + if (beam_cache_offset) { + beam_size = beam_cache_offset->dims()[1]; + } - kv_transpose_out.Resize({2, bsz, num_head, seq_len, dim_head}); - auto *kv_transpose_out_data = dev_ctx.template Alloc( - &kv_transpose_out, kv_transpose_out.numel() * sizeof(T)); + phi::DenseTensor d_token_tensor; + phi::DenseTensor padding_offset_tensor; + phi::DenseTensor x_remove_padding; - qk_out.Resize({bsz, num_head, seq_len, out_seq_len}); - auto *qk_out_data = - dev_ctx.template Alloc(&qk_out, qk_out.numel() * sizeof(T)); + // cumulative seqlens [batch_size+1] + phi::DenseTensor cu_seqlens_q, cu_seqlens_k; + bool encoder_remove_padding = (remove_padding && !time_step); + int token_num = 0; - phi::DenseTensor src_mask_out; - if (cache_offset > 0) { - src_mask_out.Resize({bsz, num_head, seq_len, out_seq_len}); - auto *src_mask_out_data = dev_ctx.template Alloc( - &src_mask_out, src_mask_out.numel() * sizeof(T)); - } + auto *out = ctx.Output("Out"); + auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); - // [2, bs, num_head, cache_seq_len + seq_len, head_dim] - phi::DenseTensor pre_cache_kv_out; - if (cache_offset > 0) { - pre_cache_kv_out.Resize( - {{2, bsz, num_head, seq_len + cache_offset, dim_head}}); - auto *pre_cache_kv_out_data = dev_ctx.template Alloc( - &pre_cache_kv_out, pre_cache_kv_out.numel() * sizeof(T)); - } + // Init out + if (encoder_remove_padding) { + phi::fusion::InitValue( + dev_ctx, from_data, out->numel(), static_cast(0.)); + phi::fusion::InitValue( 
+ dev_ctx, from_data, out->numel(), static_cast(0.)); + } - phi::DenseTensor softmax_out; - phi::DenseTensor attn_dropout_mask_out, attn_dropout_out; - phi::DenseTensor qktv_out, fmha_out; - softmax_out.Resize({bsz, num_head, seq_len, out_seq_len}); - auto *softmax_out_data = - dev_ctx.template Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); - - attn_dropout_mask_out.Resize({bsz, num_head, seq_len, out_seq_len}); - auto *attn_dropout_mask_out_data = dev_ctx.template Alloc( - &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T)); - attn_dropout_out.Resize({bsz, num_head, seq_len, out_seq_len}); - auto *attn_dropout_data_data = dev_ctx.template Alloc( - &attn_dropout_out, attn_dropout_out.numel() * sizeof(T)); - - qktv_out.Resize({bsz, num_head, seq_len, dim_head}); - auto *qktv_out_data = - dev_ctx.template Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); - fmha_out.Resize({bsz, seq_len, num_head, dim_head}); - auto *fmha_out_data = - dev_ctx.template Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); - - // (transA, transB, compute_bias) = (false, false, false) - auto out_linear_compute = phi::fusion::AttnMatMul( - dev_ctx, false, false, token_num, dim_embed, hidden_size, false); - - // 5. ln(residual + bias) - DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( - dev_ctx, token_num, dim_embed, dropout_param2, epsilon); - phi::DenseTensor bias_dropout_residual_out, dropout_mask_out; - T *bias_dropout_residual_out_data = nullptr; - if (pre_layer_norm) { - bias_dropout_residual_out.Resize({token_num, dim_embed}); - bias_dropout_residual_out_data = dev_ctx.template Alloc( - &bias_dropout_residual_out, - bias_dropout_residual_out.numel() * sizeof(T)); - } - dropout_mask_out.Resize({token_num, dim_embed}); - auto *dropout_mask_out_data = dev_ctx.template Alloc( - &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t)); - - // 6. ffn1 matmul + act + bias - auto ffn1_weight_dim = ffn1_weights[0]->dims(); - - int dim_ffn = ffn1_weight_dim[1]; - - auto ffn1_cublas_linear = CublasFusedMLP(dev_ctx); - const phi::DDim ffn1_input_shape({token_num, dim_embed}); - ffn1_cublas_linear.Setup(ffn1_input_shape, ffn1_weight_dim, false, false); - - phi::DenseTensor ffn1_out; - ffn1_out.Resize({token_num, dim_ffn}); - auto *ffn1_out_data = - dev_ctx.template Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); - - // 7. ffn2 matmul + bias + residual. - auto ffn2_linear_compute = phi::fusion::AttnMatMul( - dev_ctx, false, false, token_num, dim_embed, dim_ffn, false); - - // 8. 
ffn2 Layernorm residual bias - DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutLayerNormHelper ffn2_fused_dropout_helper( - dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); - - // calc - auto *from_data = dev_ctx.template Alloc(out, out->numel() * sizeof(T)); - phi::DenseTensor *from_tensor = out; - phi::DenseTensor tmp_out, tmp_out_rm_padding; - tmp_out.Resize({token_num, dim_embed}); - if (encoder_remove_padding) { - tmp_out_rm_padding.Resize({token_num, dim_embed}); - auto *tmp_out_rm_padding_data = dev_ctx.template Alloc( - &tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T)); - } - auto *tmp_out_data = - dev_ctx.template Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); - - const T *x_data; - if (encoder_remove_padding) { - x_data = x_remove_padding.data(); - } else { - x_data = x.data(); - } - phi::DenseTensor *buf0 = nullptr; - phi::DenseTensor *buf1 = nullptr; - - // step0: x --> buf1 - // step1: buf1 --> buf0 - // step2: buf0 --> buf1 - int layers = qkv_weights.size(); - if (encoder_remove_padding) { - // In the case of variable lengths, the padding needs to be rebuilt - // eventually. So buf0 and buf1 do not need to be changed according to the - // pre_layer_norm and the number of layers. - buf0 = &tmp_out; - buf1 = &tmp_out_rm_padding; - } else { - if (pre_layer_norm) { - if (layers & 1) { - // odd, set buf1 as out - buf0 = &tmp_out; - buf1 = out; - } else { - // even, set buf0 as out - buf0 = out; - buf1 = &tmp_out; - } + // remove padding in encoder + if (encoder_remove_padding) { + // just for encoder + d_token_tensor.Resize({{1}}); + auto *d_token_num = dev_ctx.Alloc( + &d_token_tensor, d_token_tensor.numel() * sizeof(int)); + // alloc the max size of padding_offset_tensor + padding_offset_tensor.Resize({{bsz_seq}}); + dev_ctx.Alloc(&padding_offset_tensor, + padding_offset_tensor.numel() * sizeof(int)); + cu_seqlens_q.Resize({{bsz + 1}}); + dev_ctx.Alloc(&cu_seqlens_q, + cu_seqlens_q.numel() * sizeof(int32_t)); + + phi::fusion::InvokeGetPaddingOffset( + dev_ctx, + &token_num, + d_token_num, + padding_offset_tensor.data(), + cu_seqlens_q.data(), + sequence_lengths ? sequence_lengths->data() + : sequence_lengths_backup.data(), + bsz, + seq_len); + if (token_num == 0) return; + padding_offset_tensor.Resize({{token_num}}); + x_remove_padding.Resize({{token_num, dim_embed}}); + dev_ctx.Alloc(&x_remove_padding, x_remove_padding.numel() * sizeof(T)); + phi::fusion::InvokeRemovePadding(dev_ctx, + x_remove_padding.data(), + input_x->data(), + padding_offset_tensor.data(), + token_num, + dim_embed); } else { - buf0 = &tmp_out; - buf1 = out; + token_num = bsz_seq; + if (token_num == 0) return; } - } - for (int i = 0; i < layers; ++i) { - // step1. layer_norm - if (i == 0 && pre_layer_norm) { - auto *ln_scale_data = ln_scales[i]->data(); - auto *ln_bias_data = ln_biases[i]->data(); - // TODO(wangxi): can remove mean var in inference - ln_compute.ComputeForward(x_data, - ln_scale_data, - ln_bias_data, - buf1->data(), - ln_mean_data, - ln_var_data); + auto *padding_offset_data = + encoder_remove_padding ? padding_offset_tensor.data() : nullptr; + + // 1. 
layer norm + const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); + const float epsilon = ctx.Attr("epsilon"); + const float residual_alpha = ctx.Attr("residual_alpha"); + auto ln_scales = ctx.MultiInput("LnScale"); + auto ln_biases = ctx.MultiInput("LnBias"); + phi::fusion::NormHelper norm_helper( + dev_ctx, norm_type, token_num, dim_embed, epsilon, residual_alpha); + phi::DenseTensor ln_mean, ln_var; + ln_mean.Resize({{token_num}}); + auto *ln_mean_data = + dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); + ln_var.Resize({{token_num}}); + auto *ln_var_data = dev_ctx.Alloc(&ln_var, ln_var.numel() * sizeof(U)); + + // 2. qkv + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] if not GQA else + // [num_head + 2 * gqa_group_size, dim_head, dim_embed] + auto qkv_weights = ctx.MultiInput("QKVW"); + auto qkv_biases = ctx.MultiInput("QKVBias"); + const bool trans_qkvw = ctx.Attr("trans_qkvw"); + const auto qkv_w_dims = qkv_weights[0]->dims(); + int num_head, dim_head; + if (gqa_group_size > 0) { + num_head = trans_qkvw ? (qkv_w_dims[0] - 2 * gqa_group_size) + : (qkv_w_dims[1] - 2 * gqa_group_size); + dim_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; + } else { + num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; + dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3]; } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step1"; -#endif - - // step2. qkv - const phi::DenseTensor *qkv_bias = - qkv_biases && !qkv_biases.get().empty() ? qkv_biases.get()[i] : nullptr; - // NOTE: in decoder stage, bias is fused in fmha - const phi::DenseTensor *bias = time_step_t ? nullptr : qkv_bias; - if (!pre_layer_norm && i == 0) { - const phi::DenseTensor *tmp_input_x = - (encoder_remove_padding) ? &x_remove_padding : &x; - qkv_compute.ComputeForward( - qkv_weights[i], tmp_input_x, bias, &qkv_out, &qkv_out); + int hidden_size = num_head * dim_head; + int output_size = gqa_group_size <= 0 + ? 3 * hidden_size + : (num_head + 2 * gqa_group_size) * dim_head; + int input_size = dim_embed; + + // Set a flag whether need to add Matmul / Layernorm bias. + bool compute_bias = qkv_biases.size() > 0; + bool compute_ln_bias = ln_biases.size() > 0; + + auto qkv_compute = phi::fusion::GEMMHelper( + dev_ctx, token_num, output_size, input_size, "None", trans_qkvw); + + phi::DenseTensor qkv_out; + if (gqa_group_size > 0) { + qkv_out.Resize({{token_num, num_head + 2 * gqa_group_size, dim_head}}); } else { - qkv_compute.ComputeForward( - qkv_weights[i], buf1, bias, &qkv_out, &qkv_out); + qkv_out.Resize({{token_num, 3, num_head, dim_head}}); + } + auto *qkv_out_data = + dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); + + // 2.1 rotary + auto *rotary_tensor = ctx.Input("RotaryPosEmb"); + const int rotary_emb_dims = ctx.Attr("rotary_emb_dims"); + + // 3. fmha + phi::fusion::AttnDropoutParam attn_param( + true, "upscale_in_train", 0.0, true, true, 0, nullptr); + auto *src_mask = ctx.Input("SrcMask"); + auto cache_kvs = ctx.MultiInput("CacheKV"); + auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); + auto pre_caches = ctx.MultiInput("PreCaches"); + int cache_offset = 0; + if (pre_caches.size() > 0) { + cache_offset = pre_caches[0]->dims()[3]; } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step2"; -#endif - // step3. fmha - const phi::DenseTensor *cache_kv = - cache_kvs && cache_kvs.get().size() > 0 ? cache_kvs.get()[i] : nullptr; - phi::DenseTensor *cache_kv_out = cache_kv ? 
cache_kv_outs[i] : nullptr; - - if (time_step_t) { // generation decoder stage - // [2, batch_size, num_head, max_seq_len, head_size] - int max_seq_len = cache_kv->dims()[3]; - fmha(dev_ctx, - qkv_out, - *qkv_bias, - *src_mask_t, - seq_lengths_t, - rotary_tensor_t, - cache_kv_out, - &fmha_out, - bsz, - max_seq_len, - num_head, - dim_head, - time_step_t->data()[0], - rotary_emb_dims, - 1. / std::sqrt(dim_head)); - } else if (cache_kv_out) { // generation context stage - const phi::DenseTensor *pre_cache_kv_tensor = - pre_caches && pre_caches.get().size() > 0 ? pre_caches.get()[i] - : nullptr; - phi::DenseTensor *pre_cache_kv_out_tmp = - cache_offset > 0 ? &pre_cache_kv_out : nullptr; - phi::DenseTensor *src_mask_tmp = - cache_offset > 0 ? &src_mask_out : nullptr; - qkv_bias_add_transpose_split(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - qkv_out_data, - qkv_bias->data(), - padding_offset_data, - token_num, - bsz, - num_head, - seq_len, - dim_head, - compute_bias); - // q_transpose_out_data [bs, head_num, seq_len, dim_head] - // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] - if (rotary_emb_dims != 0) { - auto *rotary_emb_data = rotary_tensor_t->data(); - const int *sequence_lengths_data = - encoder_remove_padding ? seq_lengths_t->data() : nullptr; - rotary_qk(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - q_transpose_out_data, - kv_transpose_out_data, - rotary_emb_data, - sequence_lengths_data, - rotary_emb_dims, - bsz, - num_head, - seq_len, - dim_head); - } + auto out_seq_len = seq_len; + if (time_step) { + PADDLE_ENFORCE_EQ(time_step->place(), + phi::CPUPlace(), + phi::errors::PreconditionNotMet( + "The place of input(TimeStep) must be CPUPlace.")); + // cache_seq_len + int time_step_value = time_step->data()[0]; + PADDLE_ENFORCE_GT(time_step_value, + 0, + phi::errors::PreconditionNotMet( + "The value of time_step must > 0, but now is %d", + time_step_value)); + PADDLE_ENFORCE_EQ( + seq_len, + 1, + phi::errors::PreconditionNotMet( + "In decode stage, the seq_len of input must be 1, but now is %d", + seq_len)); + out_seq_len += time_step_value; + } else { + out_seq_len += cache_offset; + } - phi::DenseTensor *tmp_padding_offset_tensor = - encoder_remove_padding ? &padding_offset_tensor : nullptr; - fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor, - src_mask_t, - tmp_padding_offset_tensor, - &q_transpose_out, - &kv_transpose_out, - pre_cache_kv_out_tmp, - &qk_out, - src_mask_tmp, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out, - token_num); - const T *k_ptr = nullptr; - const T *v_ptr = nullptr; - - if (cache_offset > 0) { - // [2, bsz, num_head, cache_offset + seq_len, head_dim] - const T *kv_data = pre_cache_kv_out.data(); - k_ptr = kv_data; - int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head; - v_ptr = k_ptr + k_size; + // whether to broadcast 2nd dimension for src_mask, default true + // if mask_broadcast_num_heads if False, which means src_mask shape + // will be: + // 1. [batch_size, num_head, seq_len, seq_len] for encoder + // 2. 
[batch_size, num_heads, 1, time_step+1] for decoder + // and do not need to broadcast num_heads dimension when calculating + // attn_mask offset in MHA + bool mask_broadcast_num_heads = true; + if (src_mask) { + if (src_mask->dims()[1] == 1) { + mask_broadcast_num_heads = true; + } else if (src_mask->dims()[1] == num_head) { + mask_broadcast_num_heads = false; } else { - // [3, bsz, num_head, seq_len, head_dim] - int64_t k_size = bsz * seq_len * num_head * dim_head; - const T *q_ptr = q_transpose_out_data; - k_ptr = kv_transpose_out_data; - v_ptr = k_ptr + k_size; + PADDLE_THROW(phi::errors::InvalidArgument( + "Unknow dimension for attn_mask, the num_head(2nd) " + "dimension is invalid, it should be 1 or num_head(%d), " + "but got %d", + num_head, + src_mask->dims()[1])); } - - // [2, bsz, num_head, max_seq_len, head_dim] - int max_seq_len = cache_kv_out->dims()[3]; - T *cache_kv_data = cache_kv_out->data(); - int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; - - T *cache_k_ptr = cache_kv_data; - T *cache_v_ptr = cache_kv_data + cache_k_size; - - const int seq_len_tmp = seq_len + cache_offset; - write_cache_kv(dev_ctx, - cache_k_ptr, - cache_v_ptr, - k_ptr, - v_ptr, - bsz, - num_head, - seq_len_tmp, - max_seq_len, - dim_head); - } else { // not generation - // TODO(wangxi): can remove dropout in inference - qkv_bias_add_transpose_split(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - qkv_out_data, - qkv_bias->data(), - padding_offset_data, - token_num, - bsz, - num_head, - seq_len, - dim_head, - compute_bias); - - // q_transpose_out_data [bs, head_num, seq_len, dim_head] - // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] - if (rotary_emb_dims != 0) { - auto *rotary_emb_data = rotary_tensor_t->data(); - const int *sequence_lengths_data = - encoder_remove_padding ? seq_lengths_t->data() : nullptr; - rotary_qk(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - q_transpose_out_data, - kv_transpose_out_data, - rotary_emb_data, - sequence_lengths_data, - rotary_emb_dims, - bsz, - num_head, - seq_len, - dim_head); - } - - phi::DenseTensor *tmp_padding_offset_tensor = - encoder_remove_padding ? 
&padding_offset_tensor : nullptr; - fmha_compute.ComputeForwardWithoutTranspose(cache_kv, - src_mask_t, - tmp_padding_offset_tensor, - &q_transpose_out, - &kv_transpose_out, - cache_kv_out, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out, - token_num); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step3"; -#endif - if (pre_layer_norm) { - out_linear_compute.ComputeForward( - out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr); - AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); - } else { - out_linear_compute.ComputeForward( - out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr); - AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + phi::DenseTensor q_transpose_out, kv_transpose_out; + q_transpose_out.Resize({{bsz, num_head, seq_len, dim_head}}); + auto *q_transpose_out_data = + dev_ctx.Alloc(&q_transpose_out, q_transpose_out.numel() * sizeof(T)); + + kv_transpose_out.Resize({{2, bsz, num_head, seq_len, dim_head}}); + auto *kv_transpose_out_data = dev_ctx.Alloc( + &kv_transpose_out, kv_transpose_out.numel() * sizeof(T)); + + if (encoder_remove_padding) { + phi::fusion::InitValue(dev_ctx, + q_transpose_out_data, + q_transpose_out.numel(), + static_cast(0.)); + phi::fusion::InitValue(dev_ctx, + kv_transpose_out_data, + kv_transpose_out.numel(), + static_cast(0.)); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step4"; -#endif - // step5. ln(residual + dropout(input + bias)) - if (pre_layer_norm) { - auto *ln_scale_data = ffn_ln_scales[i]->data(); - auto *ln_bias_data = ffn_ln_biases[i]->data(); - auto *out_linear_bias_data = out_linear_biases.get()[i]->data(); - - // inplace - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - dev_ctx, - buf1->data(), - x_data, - out_linear_bias_data, - ln_scale_data, - ln_bias_data, - bias_dropout_residual_out_data, - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); - } else { - auto *ln_scale_data = ln_scales[i]->data(); - auto *ln_bias_data = ln_biases[i]->data(); - auto *out_linear_bias_data = out_linear_biases.get()[i]->data(); - auto *residual_data = (i == 0 ? x_data : buf1->data()); - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - dev_ctx, - buf0->data(), - residual_data, - out_linear_bias_data, - ln_scale_data, - ln_bias_data, - buf0->data(), - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); + // [2, bs, num_head, cache_seq_len + seq_len, head_dim] + phi::DenseTensor pre_cache_kv_out; + if (cache_offset > 0) { + pre_cache_kv_out.Resize( + {{2, bsz, num_head, seq_len + cache_offset, dim_head}}); + auto *pre_cache_kv_out_data = dev_ctx.Alloc( + &pre_cache_kv_out, pre_cache_kv_out.numel() * sizeof(T)); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step5"; -#endif - // step6. ffn matmul1 - ffn1_cublas_linear.ComputeForward(buf1, - ffn1_weights[i], - ffn1_biases.get()[i], - nullptr, - &ffn1_out, - act_method); + phi::DenseTensor softmax_out; + phi::DenseTensor qktv_out, fmha_out; -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step6"; -#endif + // unpadding_q/unpadding_k/unpadding_v: [token_num, num_head, dim_head] + phi::DenseTensor unpadding_q, unpadding_k, unpadding_v; + phi::DenseTensor softmax_lse, seed_offset; - // step7. 
ffn2 matmul - if (pre_layer_norm) { - ffn2_linear_compute.ComputeForward( - ffn2_weights[i], &ffn1_out, nullptr, buf1, nullptr); + unpadding_q.Resize({{token_num, num_head, dim_head}}); + if (gqa_group_size > 0) { + unpadding_k.Resize({{token_num, gqa_group_size, dim_head}}); + unpadding_v.Resize({{token_num, gqa_group_size, dim_head}}); } else { - ffn2_linear_compute.ComputeForward( - ffn2_weights[i], &ffn1_out, nullptr, buf0, nullptr); + unpadding_k.Resize({{token_num, num_head, dim_head}}); + unpadding_v.Resize({{token_num, num_head, dim_head}}); } - -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step7"; -#endif - - if (pre_layer_norm) { - AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); + cu_seqlens_k.Resize(cu_seqlens_q.dims()); + + dev_ctx.Alloc(&unpadding_q, unpadding_q.numel() * sizeof(T)); + dev_ctx.Alloc(&unpadding_k, unpadding_k.numel() * sizeof(T)); + dev_ctx.Alloc(&unpadding_v, unpadding_v.numel() * sizeof(T)); + dev_ctx.Alloc(&cu_seqlens_k, + cu_seqlens_k.numel() * sizeof(int32_t)); + + T *attn_dropout_mask_out_data = nullptr; + T *attn_dropout_data_data = nullptr; + + qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); + auto *qktv_out_data = + dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); + if (remove_padding) { + fmha_out.Resize({{token_num, num_head, dim_head}}); } else { - AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step7.1"; -#endif - - // step8. layer norm + bias_add + residual - if (pre_layer_norm) { - // TODO(wangxi): remove dropout mask in inference - if (i < layers - 1) { - auto *ln_scale_data = ln_scales[i + 1]->data(); - auto *ln_bias_data = ln_biases[i + 1]->data(); - ffn2_fused_dropout_helper.LayernormResidualDropoutBias( - dev_ctx, - buf1->data(), - bias_dropout_residual_out_data, - ffn2_biases.get()[i]->data(), - ln_scale_data, - ln_bias_data, - buf1->data(), - dropout_mask_out_data, - buf0->data(), - ln_mean_data, - ln_var_data); - } else { - ffn2_fused_dropout_helper.ResidualDropoutBias( - dev_ctx, - buf1->data(), - bias_dropout_residual_out_data, - ffn2_biases.get()[i]->data(), - buf1->data(), - dropout_mask_out_data); - } - } else { - auto *ln_scale_data = ffn_ln_scales[i]->data(); - auto *ln_bias_data = ffn_ln_biases[i]->data(); - ffn2_fused_dropout_helper.LayernormResidualDropoutBias( - dev_ctx, - buf0->data(), - buf1->data(), - ffn2_biases.get()[i]->data(), - ln_scale_data, - ln_bias_data, - buf0->data(), - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); + fmha_out.Resize({{bsz, seq_len, num_head, dim_head}}); } - -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step8"; -#endif + auto *fmha_out_data = + dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); + + // 4. out_linear + auto out_linear_weights = ctx.MultiInput("OutLinearW"); + auto out_linear_biases = ctx.MultiInput("OutLinearBias"); + int ring_id = ctx.Attr("ring_id"); + // (transA, transB, compute_bias) = (false, false, false) + + auto out_linear_compute = phi::fusion::GEMMHelper( + dev_ctx, token_num, dim_embed, hidden_size, "None", false); + + // 5. 
ln(residual + bias) + auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); + auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); + phi::DenseTensor bias_dropout_residual_out, dropout_mask_out; + T *bias_dropout_residual_out_data = nullptr; if (pre_layer_norm) { - x_data = buf1->data(); - std::swap(buf0, buf1); + bias_dropout_residual_out.Resize({{token_num, dim_embed}}); + bias_dropout_residual_out_data = + dev_ctx.Alloc(&bias_dropout_residual_out, + bias_dropout_residual_out.numel() * sizeof(T)); } - } - if (encoder_remove_padding) { - if (pre_layer_norm) { - InvokeRebuildPadding(dev_ctx, - from_data, - buf0->data(), - padding_offset_data, - token_num, - dim_embed); - } else { - InvokeRebuildPadding(dev_ctx, - from_data, - buf1->data(), - padding_offset_data, - token_num, - dim_embed); + uint8_t *dropout_mask_out_data = nullptr; + + // 6. ffn matmul1 + auto ffn1_weights = ctx.MultiInput("FFN1Weight"); + auto ffn1_biases = ctx.MultiInput("FFN1Bias"); + auto ffn1_weight_dim = ffn1_weights[0]->dims(); + // if quant weight, + // matmul weight is transposed + int dim_ffn = ffn1_weight_dim[1]; + phi::fusion::FFNHelper ffn1_helper( + dev_ctx, act_method, token_num, dim_ffn, dim_embed, "None"); + + phi::DenseTensor ffn1_out; + ffn1_out.Resize({{token_num, dim_ffn}}); + auto *ffn1_out_data = + dev_ctx.Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); + + // Note(Zhengzekang): It is no need when using FP16 matmul. + phi::DenseTensor mixgemm_workspace; + char *mixgemm_workspace_data = nullptr; + + // 7. ffn act + bias + DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0); + FusedDropoutHelper fused_act_dropout_helper( + dev_ctx, token_num, dim_ffn, ffn1_dropout_param); + phi::DenseTensor ffn1_dropout_out, ffn1_dropout_mask; + int tmp_dim_ffn = dim_ffn; + if (use_glu) tmp_dim_ffn /= 2; + int8_t *ffn1_dropout_mask_data = nullptr; + ffn1_dropout_out.Resize({{token_num, tmp_dim_ffn}}); + auto *ffn1_dropout_out_data = dev_ctx.Alloc( + &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T)); + + // 8. ffn2 matmul + auto ffn2_weights = ctx.MultiInput("FFN2Weight"); + auto ffn2_biases = ctx.MultiInput("FFN2Bias"); + auto ffn2_linear_compute = phi::fusion::GEMMHelper( + dev_ctx, token_num, dim_embed, tmp_dim_ffn, "None", false); + + // 9. 
ffn2 residual bias + DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); + FusedDropoutLayerNormHelper ffn2_fused_dropout_helper( + dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); + + phi::DenseTensor tmp_out, tmp_out_rm_padding; + tmp_out.Resize({{token_num, dim_embed}}); + if (encoder_remove_padding) { + tmp_out_rm_padding.Resize({{token_num, dim_embed}}); + auto *tmp_out_rm_padding_data = dev_ctx.Alloc( + &tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T)); } - } -} + auto *tmp_out_data = + dev_ctx.Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); -#else - -template -void FusedMultiTransformerKernel( - const Context &dev_ctx, - const DenseTensor &x, - const std::vector &ln_scales, - const std::vector &ln_biases, - const std::vector &qkv_weights, - const paddle::optional> &qkv_biases, - const paddle::optional> &cache_kvs, - const paddle::optional> &pre_caches, - const paddle::optional &rotary_tensor, - const paddle::optional &time_step, - const paddle::optional &seq_lengths, - const paddle::optional &src_mask, - const std::vector &out_linear_weights, - const paddle::optional> &out_linear_biases, - const std::vector &ffn_ln_scales, - const std::vector &ffn_ln_biases, - const std::vector &ffn1_weights, - const paddle::optional> &ffn1_biases, - const std::vector &ffn2_weights, - const paddle::optional> &ffn2_biases, - bool pre_layer_norm, - float epsilon, - float dropout_rate, - int rotary_emb_dims, - bool is_test, - const std::string &dropout_implementation, - const std::string &act_method, - bool trans_qkvw, - int ring_id, - std::vector cache_kv_outs, - DenseTensor *out) { - if (cache_kvs) { - for (size_t i = 0; i < cache_kv_outs.size(); i++) { - *(cache_kv_outs[i]) = *(cache_kvs.get()[i]); + const T *x_data; + if (encoder_remove_padding) { + x_data = x_remove_padding.data(); + } else { + x_data = input_x->data(); } - } - using U = phi::funcs::LayerNormParamType; - auto *rotary_tensor_t = rotary_tensor.get_ptr(); - auto *seq_lengths_t = seq_lengths.get_ptr(); - auto *src_mask_t = src_mask.get_ptr(); - auto *time_step_t = time_step.get_ptr(); - - // 0. input - const auto input_x_dims = x.dims(); - int bsz = input_x_dims[0]; - int seq_len = input_x_dims[1]; - int dim_embed = input_x_dims[2]; - int bsz_seq = bsz * seq_len; - bool remove_padding = false; - if (seq_lengths_t) { - remove_padding = true; - } - phi::DenseTensor d_token_tensor; - phi::DenseTensor padding_offset_tensor; - phi::DenseTensor x_remove_padding; - bool encoder_remove_padding = (remove_padding && !time_step_t); - int token_num = 0; - - // remove padding in encoder - if (encoder_remove_padding) { - // just for encoder - d_token_tensor.Resize({1}); - auto *d_token_num = dev_ctx.template Alloc( - &d_token_tensor, d_token_tensor.numel() * sizeof(int)); - // alloc the max size of padding_offset_tensor - padding_offset_tensor.Resize({bsz_seq}); - dev_ctx.template Alloc(&padding_offset_tensor, - padding_offset_tensor.numel() * sizeof(int)); - InvokeGetPaddingOffset(dev_ctx, - &token_num, - d_token_num, - padding_offset_tensor.data(), - seq_lengths_t->data(), - bsz, - seq_len); - padding_offset_tensor.Resize({token_num}); - x_remove_padding.Resize({token_num, dim_embed}); - dev_ctx.template Alloc(&x_remove_padding, - x_remove_padding.numel() * sizeof(T)); - InvokeRemovePadding(dev_ctx, - x_remove_padding.data(), - x.data(), - padding_offset_tensor.data(), - token_num, - dim_embed); - } else { - token_num = bsz_seq; - } - auto *padding_offset_data = - encoder_remove_padding ? 
padding_offset_tensor.data() : nullptr; - - // 1. layer norm - - auto ln_compute = - phi::fusion::AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); - phi::DenseTensor ln_mean, ln_var; - ln_mean.Resize({token_num}); - auto *ln_mean_data = - dev_ctx.template Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); - ln_var.Resize({token_num}); - auto *ln_var_data = - dev_ctx.template Alloc(&ln_var, ln_var.numel() * sizeof(U)); - - // 2. qkv - // x: qkv's input [batch_size, seq_len, dim_embed] - // y: qkv's weight: [3, num_head, dim_head, dim_embed] - const auto qkv_w_dims = qkv_weights[0]->dims(); - int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; - int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3]; - int hidden_size = num_head * dim_head; - int output_size = 3 * hidden_size; - int input_size = dim_embed; - - bool compute_bias = - qkv_biases && !qkv_biases.get().empty() && time_step_t == nullptr; - // (transA, transB, compute_bias) = (false, trans_qkvw, false) - // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we - // set compute_bias as false. - auto qkv_compute = phi::fusion::AttnMatMul(dev_ctx, - false, - trans_qkvw, - token_num, - output_size, - input_size, - /*compute_bias=*/false); - - phi::DenseTensor qkv_out; - qkv_out.Resize({token_num, 3, num_head, dim_head}); - auto *qkv_out_data = - dev_ctx.template Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); - - // 3. fmha - AttnDropoutParam attn_param( - true, "upscale_in_train", 0.0, true, true, 0, nullptr); - auto fmha_compute = - FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); - int cache_offset = 0; - if (pre_caches && pre_caches.get().size() > 0) { - cache_offset = pre_caches.get()[0]->dims()[3]; - } - - auto out_seq_len = seq_len; - if (time_step_t) { - PADDLE_ENFORCE_EQ(time_step_t->place(), - phi::CPUPlace(), - phi::errors::PreconditionNotMet( - "The place of input(TimeStep) must be CPUPlace.")); - // cache_seq_len - int time_step_value = time_step_t->data()[0]; - PADDLE_ENFORCE_GT(time_step_value, - 0, - phi::errors::PreconditionNotMet( - "The value of time_step_t must > 0, but now is %d", - time_step_value)); - PADDLE_ENFORCE_EQ( - seq_len, - 1, - phi::errors::PreconditionNotMet( - "In decode stage, the seq_len of input must be 1, but now is %d", - seq_len)); - out_seq_len += time_step_value; - } else { - out_seq_len += cache_offset; - } - - phi::DenseTensor q_transpose_out, kv_transpose_out, qk_out; - q_transpose_out.Resize({bsz, num_head, seq_len, dim_head}); - auto *q_transpose_out_data = dev_ctx.template Alloc( - &q_transpose_out, q_transpose_out.numel() * sizeof(T)); - - kv_transpose_out.Resize({2, bsz, num_head, seq_len, dim_head}); - auto *kv_transpose_out_data = dev_ctx.template Alloc( - &kv_transpose_out, kv_transpose_out.numel() * sizeof(T)); - - qk_out.Resize({bsz, num_head, seq_len, out_seq_len}); - auto *qk_out_data = - dev_ctx.template Alloc(&qk_out, qk_out.numel() * sizeof(T)); - - phi::DenseTensor src_mask_out; - if (cache_offset > 0) { - src_mask_out.Resize({bsz, num_head, seq_len, out_seq_len}); - auto *src_mask_out_data = dev_ctx.template Alloc( - &src_mask_out, src_mask_out.numel() * sizeof(T)); - } - - // [2, bs, num_head, cache_seq_len + seq_len, head_dim] - phi::DenseTensor pre_cache_kv_out; - if (cache_offset > 0) { - pre_cache_kv_out.Resize( - {{2, bsz, num_head, seq_len + cache_offset, dim_head}}); - auto *pre_cache_kv_out_data = dev_ctx.template Alloc( - &pre_cache_kv_out, pre_cache_kv_out.numel() * sizeof(T)); - } - - phi::DenseTensor softmax_out; - 
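// The variable-length path set up above compacts the input before the layer
// loop: from the per-batch sequence lengths it derives token_num plus a
// per-token padding offset, then gathers only the valid rows into a
// [token_num, dim_embed] buffer that is re-padded once after the last layer.
// A minimal host-side sketch of that bookkeeping, for illustration only; the
// helper names (compute_padding_offset_cpu, remove_padding_cpu) are
// illustrative and not part of this patch.

#include <cstddef>
#include <vector>

// padding_offset[i] = number of padded slots that precede valid token i in
// the original [bsz, seq_len] layout, so its source row is
// i + padding_offset[i].
static int compute_padding_offset_cpu(const std::vector<int> &seq_lens,
                                      int seq_len,
                                      std::vector<int> *padding_offset) {
  int token_num = 0;
  int cum_pad = 0;
  for (size_t b = 0; b < seq_lens.size(); ++b) {
    for (int s = 0; s < seq_lens[b]; ++s) {
      padding_offset->push_back(cum_pad);
      ++token_num;
    }
    cum_pad += seq_len - seq_lens[b];
  }
  return token_num;
}

// Gather x[bsz * seq_len, dim_embed] into x_rm_pad[token_num, dim_embed].
static void remove_padding_cpu(const float *x,
                               const std::vector<int> &padding_offset,
                               int dim_embed,
                               float *x_rm_pad) {
  for (size_t i = 0; i < padding_offset.size(); ++i) {
    const size_t src_row = i + static_cast<size_t>(padding_offset[i]);
    for (int d = 0; d < dim_embed; ++d) {
      x_rm_pad[i * dim_embed + d] = x[src_row * dim_embed + d];
    }
  }
}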
phi::DenseTensor attn_dropout_mask_out, attn_dropout_out; - phi::DenseTensor qktv_out, fmha_out; - softmax_out.Resize({bsz, num_head, seq_len, out_seq_len}); - auto *softmax_out_data = - dev_ctx.template Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); - - attn_dropout_mask_out.Resize({bsz, num_head, seq_len, out_seq_len}); - auto *attn_dropout_mask_out_data = dev_ctx.template Alloc( - &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T)); - attn_dropout_out.Resize({bsz, num_head, seq_len, out_seq_len}); - auto *attn_dropout_data_data = dev_ctx.template Alloc( - &attn_dropout_out, attn_dropout_out.numel() * sizeof(T)); - - qktv_out.Resize({bsz, num_head, seq_len, dim_head}); - auto *qktv_out_data = - dev_ctx.template Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); - fmha_out.Resize({bsz, seq_len, num_head, dim_head}); - auto *fmha_out_data = - dev_ctx.template Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); - - // 4. out_linear - // (transA, transB, compute_bias) = (false, false, false) - auto out_linear_compute = phi::fusion::AttnMatMul( - dev_ctx, false, false, token_num, dim_embed, hidden_size, false); - - // 5. ln(residual + bias) - DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( - dev_ctx, token_num, dim_embed, dropout_param2, epsilon); - phi::DenseTensor bias_dropout_residual_out, dropout_mask_out; - T *bias_dropout_residual_out_data = nullptr; - if (pre_layer_norm) { - bias_dropout_residual_out.Resize({token_num, dim_embed}); - bias_dropout_residual_out_data = dev_ctx.template Alloc( - &bias_dropout_residual_out, - bias_dropout_residual_out.numel() * sizeof(T)); - } - dropout_mask_out.Resize({token_num, dim_embed}); - auto *dropout_mask_out_data = dev_ctx.template Alloc( - &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t)); - - // 6. ffn matmul1 - auto ffn1_weight_dim = ffn1_weights[0]->dims(); - - int dim_ffn = ffn1_weight_dim[1]; - auto ffn1_linear_compute = phi::fusion::AttnMatMul( - dev_ctx, false, false, token_num, dim_ffn, dim_embed, false); - phi::DenseTensor ffn1_out; - ffn1_out.Resize({token_num, dim_ffn}); - auto *ffn1_out_data = - dev_ctx.template Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); - - // 7. ffn act + bias - DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutHelper fused_act_dropout_helper( - dev_ctx, token_num, dim_ffn, ffn1_dropout_param); - phi::DenseTensor ffn1_dropout_out, ffn1_dropout_mask; - ffn1_dropout_out.Resize({token_num, dim_ffn}); - auto *ffn1_dropout_out_data = dev_ctx.template Alloc( - &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T)); - ffn1_dropout_mask.Resize({token_num, dim_ffn}); - auto *ffn1_dropout_mask_data = dev_ctx.template Alloc( - &ffn1_dropout_mask, ffn1_dropout_mask.numel() * sizeof(uint8_t)); - - // 8. ffn2 matmul - auto ffn2_linear_compute = phi::fusion::AttnMatMul( - dev_ctx, false, false, token_num, dim_embed, dim_ffn, false); - - // 9. 
ffn2 residual bias - DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutLayerNormHelper ffn2_fused_dropout_helper( - dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); - - // calc - auto *from_data = dev_ctx.template Alloc(out, out->numel() * sizeof(T)); - phi::DenseTensor *from_tensor = out; - phi::DenseTensor tmp_out, tmp_out_rm_padding; - tmp_out.Resize({token_num, dim_embed}); - if (encoder_remove_padding) { - tmp_out_rm_padding.Resize({token_num, dim_embed}); - auto *tmp_out_rm_padding_data = dev_ctx.template Alloc( - &tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T)); - } - auto *tmp_out_data = - dev_ctx.template Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); - - const T *x_data; - if (encoder_remove_padding) { - x_data = x_remove_padding.data(); - } else { - x_data = x.data(); - } - phi::DenseTensor *buf0 = nullptr; - phi::DenseTensor *buf1 = nullptr; - - // step0: x --> buf1 - // step1: buf1 --> buf0 - // step2: buf0 --> buf1 - int layers = qkv_weights.size(); - if (encoder_remove_padding) { - // In the case of variable lengths, the padding needs to be rebuilt - // eventually. So buf0 and buf1 do not need to be changed according to the - // pre_layer_norm and the number of layers. - buf0 = &tmp_out; - buf1 = &tmp_out_rm_padding; - } else { - if (pre_layer_norm) { - if (layers & 1) { - // odd, set buf1 as out + phi::DenseTensor *buf0 = nullptr; + phi::DenseTensor *buf1 = nullptr; + + // step0: x --> buf1 + // step1: buf1 --> buf0 + // step2: buf0 --> buf1 + int layers = qkv_weights.size(); + if (encoder_remove_padding) { + // In the case of variable lengths, the padding needs to be rebuilt + // eventually. So buf0 and buf1 do not need to be changed according to the + // pre_layer_norm and the number of layers. + buf0 = &tmp_out; + buf1 = &tmp_out_rm_padding; + } else { + if (pre_layer_norm) { + if (layers & 1) { + // odd, set buf1 as out + buf0 = &tmp_out; + buf1 = out; + } else { + // even, set buf0 as out + buf0 = out; + buf1 = &tmp_out; + } + } else { buf0 = &tmp_out; buf1 = out; - } else { - // even, set buf0 as out - buf0 = out; - buf1 = &tmp_out; } - } else { - buf0 = &tmp_out; - buf1 = out; } - } - for (int i = 0; i < layers; ++i) { - // step1. layer_norm - if (i == 0 && pre_layer_norm) { - auto *ln_scale_data = ln_scales[i]->data(); - auto *ln_bias_data = ln_biases[i]->data(); - // TODO(wangxi): can remove mean var in inference - ln_compute.ComputeForward(x_data, - ln_scale_data, - ln_bias_data, - buf1->data(), - ln_mean_data, - ln_var_data); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step1"; -#endif + for (int i = 0; i < layers; ++i) { + // step1. layer_norm + if (i == 0 && pre_layer_norm) { + norm_helper.Norm(x_data, + ln_scales[i], + compute_ln_bias ? ln_biases[i] : nullptr, /*norm_bias*/ + &ln_mean, /*mean*/ + &ln_var, /*var*/ + buf1); + } - // step2. qkv - const phi::DenseTensor *qkv_bias = - qkv_biases && !qkv_biases.get().empty() ? qkv_biases.get()[i] : nullptr; - // NOTE: in decoder stage, bias is fused in fmha - const phi::DenseTensor *bias = time_step_t ? nullptr : qkv_bias; - if (!pre_layer_norm && i == 0) { - const phi::DenseTensor *tmp_input_x = - (encoder_remove_padding) ? &x_remove_padding : &x; - qkv_compute.ComputeForward( - qkv_weights[i], tmp_input_x, bias, &qkv_out, &qkv_out); - } else { - qkv_compute.ComputeForward( - qkv_weights[i], buf1, bias, &qkv_out, &qkv_out); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step2"; -#endif + // step2. 
qkv + // NOTE: In decoder stage, bias is fused in fmha. In encoder stage, bias + // is fused in QKVBiasAddTransposeSplit + const phi::DenseTensor *qkv_bias = + qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; + if (!pre_layer_norm && i == 0) { + const phi::DenseTensor *tmp_input_x = + (encoder_remove_padding) ? &x_remove_padding : input_x; + qkv_compute.Compute(tmp_input_x, + qkv_weights[i], + /*weight_scale*/ nullptr, + /*bias*/ nullptr, + &mixgemm_workspace, + &qkv_out); + } else { + qkv_compute.Compute(buf1, + qkv_weights[i], + /*weight_scale*/ nullptr, + /*bias*/ nullptr, + &mixgemm_workspace, + &qkv_out); + } - // step3. fmha - const phi::DenseTensor *cache_kv = - cache_kvs && cache_kvs.get().size() > 0 ? cache_kvs.get()[i] : nullptr; - phi::DenseTensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; - - if (time_step_t) { // generation decoder stage - // [2, batch_size, num_head, max_seq_len, head_size] - int max_seq_len = cache_kv->dims()[3]; - fmha(dev_ctx, - qkv_out, - *qkv_bias, - *src_mask_t, - seq_lengths_t, - rotary_tensor_t, - cache_kv_out, - &fmha_out, + // step3. fmha + const phi::DenseTensor *cache_kv = + cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; + phi::DenseTensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; + int cache_bsz = 0; + if (cache_kv) { + cache_bsz = cache_kv->dims()[1]; + } + + if (time_step) { // generation decoder stage + // [2, batch_size, num_head, max_seq_len, head_size] + int max_seq_len = cache_kv->dims()[3]; + phi::fusion::fmha(dev_ctx, + qkv_out, + *qkv_bias, + src_mask, + nullptr, + sequence_lengths, + rotary_tensor, + beam_cache_offset, + cache_kv_out, + &fmha_out, + bsz, + cache_bsz, + seq_len, + max_seq_len, + num_head, + dim_head, + src_mask->dims()[3] - 1, + rotary_emb_dims, + 1. / sqrt(dim_head), + mask_broadcast_num_heads, + compute_bias, + use_neox_rotary_style, + gqa_group_size); + } else if (cache_kv_out) { // generation context stage + if (!encoder_remove_padding) { + PADDLE_THROW(phi::errors::InvalidArgument( + "encoder_remove_padding must be True, but got False")); + } + if (rotary_emb_dims != 0) { + if (gqa_group_size <= 0) { + phi::fusion::rotary_qk_variable( + dev_ctx, + qkv_out_data, + qkv_out_data, + qkv_bias->data(), + rotary_tensor->data(), + padding_offset_data, + sequence_lengths ? sequence_lengths->data() + : sequence_lengths_backup.data(), + token_num, + num_head, + seq_len, + rotary_tensor->dims()[3], + dim_head, + rotary_tensor->dims()[1]); + } else { + phi::fusion::gqa_rotary_qk_variable( + dev_ctx, + qkv_out_data, + qkv_out_data, + qkv_bias->data(), + rotary_tensor->data(), + padding_offset_data, + sequence_lengths ? sequence_lengths->data() + : sequence_lengths_backup.data(), + token_num, + num_head, + seq_len, + rotary_tensor->dims()[3], + dim_head, + gqa_group_size, + rotary_tensor->dims()[1]); + } + } + if (gqa_group_size <= 0) { + phi::fusion::qkv_transpose_split( + dev_ctx, + unpadding_q.data(), + unpadding_k.data(), + unpadding_v.data(), + qkv_out_data, + padding_offset_data, + sequence_lengths ? sequence_lengths->data() + : sequence_lengths_backup.data(), + token_num, + bsz, + num_head, + seq_len, + dim_head); + } else { + phi::fusion::gqa_qkv_transpose_split( + dev_ctx, + unpadding_q.data(), + unpadding_k.data(), + unpadding_v.data(), + qkv_out_data, + padding_offset_data, + sequence_lengths ? sequence_lengths->data() + : sequence_lengths_backup.data(), + token_num, bsz, - max_seq_len, num_head, + seq_len, dim_head, - time_step_t->data()[0], - rotary_emb_dims, - 1. 
/ std::sqrt(dim_head)); - } else if (cache_kv_out) { // generation context stage - const phi::DenseTensor *pre_cache_kv_tensor = - pre_caches && pre_caches.get().size() > 0 ? pre_caches.get()[i] - : nullptr; - phi::DenseTensor *pre_cache_kv_out_tmp = - cache_offset > 0 ? &pre_cache_kv_out : nullptr; - phi::DenseTensor *src_mask_tmp = - cache_offset > 0 ? &src_mask_out : nullptr; - qkv_bias_add_transpose_split(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - qkv_out_data, - qkv_bias->data(), - padding_offset_data, - token_num, - bsz, - num_head, - seq_len, - dim_head, - compute_bias); - - // q_transpose_out_data [bs, head_num, seq_len, dim_head] - // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] - if (rotary_emb_dims != 0) { - auto *rotary_emb_data = rotary_tensor_t->data(); - const int *sequence_lengths_data = - encoder_remove_padding ? seq_lengths_t->data() : nullptr; - rotary_qk(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - q_transpose_out_data, - kv_transpose_out_data, - rotary_emb_data, - sequence_lengths_data, - rotary_emb_dims, - bsz, - num_head, - seq_len, - dim_head); + gqa_group_size); + } + phi::Copy( + dev_ctx, cu_seqlens_q, cu_seqlens_k.place(), false, &cu_seqlens_k); + + // fmha_out[token_num, num_head, dim_head] + phi::FlashAttnUnpaddedKernel( + dev_ctx, + unpadding_q, + unpadding_k, + unpadding_v, + cu_seqlens_q, + cu_seqlens_k, + none /*fixed_seed_offset*/, + none /*attn_mask*/, + seq_len, + seq_len, + 1.0f / sqrt(static_cast(dim_head)), + 0.0, + true /*causal*/, + false, + true /* is_test*/, + "" /*rng_name*/, + &fmha_out, + &softmax_out, + &softmax_lse, + &seed_offset); + // Note(@RichardWooSJTU): gqa_write_cachekv do not support pre_cache + // and cache quantization + phi::fusion::gqa_write_cachekv(dev_ctx, + cache_kv_out, + unpadding_k, + unpadding_v, + padding_offset_tensor, + *sequence_lengths, + seq_len); + } else { // not generation + // TODO(wangxi): can remove dropout in inference + phi::fusion::qkv_bias_add_transpose_split( + dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias ? qkv_bias->data() : nullptr, + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor->data(); + const int *sequence_lengths_data = + sequence_lengths ? sequence_lengths->data() + : sequence_lengths_backup.data(); + // encoder_remove_padding ? sequence_lengths->data() : nullptr; + phi::fusion::rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + sequence_lengths_data, + rotary_emb_dims, + rotary_tensor->dims()[1], + bsz, + num_head, + seq_len, + dim_head, + use_neox_rotary_style); + } + phi::DenseTensor *tmp_padding_offset_tensor = + encoder_remove_padding ? &padding_offset_tensor : nullptr; + + if (encoder_remove_padding) { + phi::fusion::TransposeSplit( + dev_ctx, + unpadding_q.data(), + unpadding_k.data(), + unpadding_v.data(), + q_transpose_out.data(), + kv_transpose_out.data(), + padding_offset_data, + sequence_lengths ? 
sequence_lengths->data() + : sequence_lengths_backup.data(), + token_num, + bsz, + num_head, + seq_len, + dim_head); + phi::Copy(dev_ctx, + cu_seqlens_q, + cu_seqlens_k.place(), + false, + &cu_seqlens_k); + + // fmha_out[token_num, num_head, dim_head] + phi::FlashAttnUnpaddedKernel( + dev_ctx, + unpadding_q, + unpadding_k, + unpadding_v, + cu_seqlens_q, + cu_seqlens_k, + none /*fixed_seed_offset*/, + none /*attn_mask*/, + seq_len, + seq_len, + 1.0f / sqrt(static_cast(dim_head)), + 0.0, + true /*causal*/, + false, + true /* is_test*/, + "" /*rng_name*/, + &fmha_out, + &softmax_out, + &softmax_lse, + &seed_offset); + } + } + if (pre_layer_norm) { + out_linear_compute.Compute(&fmha_out, + out_linear_weights[i], + /*weight_scale*/ nullptr, + /*bias*/ nullptr, + &mixgemm_workspace, + buf1); + + phi::fusion::AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); + } else { + out_linear_compute.Compute(&fmha_out, + out_linear_weights[i], + /*weight_scale*/ nullptr, + /*bias*/ nullptr, + &mixgemm_workspace, + buf0); + + phi::fusion::AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); } - phi::DenseTensor *tmp_padding_offset_tensor = - encoder_remove_padding ? &padding_offset_tensor : nullptr; - fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor, - src_mask_t, - tmp_padding_offset_tensor, - &q_transpose_out, - &kv_transpose_out, - pre_cache_kv_out_tmp, - &qk_out, - src_mask_tmp, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out, - token_num); - const T *k_ptr = nullptr; - const T *v_ptr = nullptr; - if (cache_offset > 0) { - // [2, bsz, num_head, cache_offset + seq_len, head_dim] - const T *kv_data = pre_cache_kv_out.data(); - k_ptr = kv_data; - int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head; - v_ptr = k_ptr + k_size; + // step5. ln(residual + dropout(input + bias)) + if (pre_layer_norm) { + norm_helper.NormResidualBias( + buf1->data(), + x_data, + compute_bias ? out_linear_biases[i] : nullptr, /*skip_bias*/ + ffn_ln_scales[i], + compute_ln_bias ? ffn_ln_biases[i] : nullptr, /*norm_bias*/ + &ln_mean, /*mean*/ + &ln_var, /*var*/ + &bias_dropout_residual_out, + buf1); } else { - // [3, bsz, num_head, seq_len, head_dim] - int64_t k_size = bsz * seq_len * num_head * dim_head; - const T *q_ptr = q_transpose_out_data; - k_ptr = kv_transpose_out_data; - v_ptr = k_ptr + k_size; + auto *residual_data = (i == 0 ? x_data : buf1->data()); + norm_helper.NormResidualBias( + buf0->data(), + residual_data, + compute_bias ? out_linear_biases[i] : nullptr, /*skip_bias*/ + ln_scales[i], + compute_ln_bias ? ln_biases[i] : nullptr, /*norm_bias*/ + &ln_mean, /*mean*/ + &ln_var, /*var*/ + buf0, + buf1); + } + // step6. ffn matmul1 + ffn1_helper.Compute(buf1, + ffn1_weights[i], + /*weight_scale*/ nullptr, + compute_bias ? ffn1_biases[i] : nullptr, + &mixgemm_workspace, + &ffn1_out, + &ffn1_dropout_out); + + // step7. 
ffn2 matmul + if (pre_layer_norm) { + ffn2_linear_compute.Compute(&ffn1_dropout_out, + ffn2_weights[i], + nullptr, + /*bias*/ nullptr, + &mixgemm_workspace, + buf1); + } else { + ffn2_linear_compute.Compute(&ffn1_dropout_out, + ffn2_weights[i], + nullptr, + /*bias*/ nullptr, + &mixgemm_workspace, + buf0); } - // [2, bsz, num_head, max_seq_len, head_dim] - int max_seq_len = cache_kv_out->dims()[3]; - T *cache_kv_data = cache_kv_out->data(); - int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; - - T *cache_k_ptr = cache_kv_data; - T *cache_v_ptr = cache_kv_data + cache_k_size; - const int seq_len_tmp = seq_len + cache_offset; - write_cache_kv(dev_ctx, - cache_k_ptr, - cache_v_ptr, - k_ptr, - v_ptr, - bsz, - num_head, - seq_len_tmp, - max_seq_len, - dim_head); - } else { // not generation - // TODO(wangxi): can remove dropout in inference - qkv_bias_add_transpose_split(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - qkv_out_data, - qkv_bias->data(), - padding_offset_data, - token_num, - bsz, - num_head, - seq_len, - dim_head, - compute_bias); - - // q_transpose_out_data [bs, head_num, seq_len, dim_head] - // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] - if (rotary_emb_dims != 0) { - auto *rotary_emb_data = rotary_tensor_t->data(); - const int *sequence_lengths_data = - encoder_remove_padding ? seq_lengths_t->data() : nullptr; - rotary_qk(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - q_transpose_out_data, - kv_transpose_out_data, - rotary_emb_data, - sequence_lengths_data, - rotary_emb_dims, - bsz, - num_head, - seq_len, - dim_head); + if (pre_layer_norm) { + phi::fusion::AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); + } else { + phi::fusion::AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); } - phi::DenseTensor *tmp_padding_offset_tensor = - encoder_remove_padding ? &padding_offset_tensor : nullptr; - fmha_compute.ComputeForwardWithoutTranspose(cache_kv, - src_mask_t, - tmp_padding_offset_tensor, - &q_transpose_out, - &kv_transpose_out, - cache_kv_out, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out, - token_num); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step3"; -#endif + // step8. residual bias + // TODO(wangxi): remove dropout mask in inference + if (pre_layer_norm) { + // TODO(wangxi): remove dropout mask in inference + if (i < layers - 1) { + norm_helper.NormResidualBias( + buf1->data(), + bias_dropout_residual_out_data, + compute_bias ? ffn2_biases[i] : nullptr, /*skip_bias*/ + ln_scales[i + 1], + compute_ln_bias ? ln_biases[i + 1] : nullptr, /*norm_bias*/ + &ln_mean, /*mean*/ + &ln_var, /*var*/ + buf1, + buf0); + } else { + ffn2_fused_dropout_helper.ResidualDropoutBias( + dev_ctx, + buf1->data(), + bias_dropout_residual_out_data, + compute_bias ? ffn2_biases[i]->data() : nullptr, + buf1->data(), + dropout_mask_out_data); + } + } else { + norm_helper.NormResidualBias( + buf0->data(), + buf1->data(), + compute_bias ? ffn2_biases[i] : nullptr, /*skip_bias*/ + ffn_ln_scales[i], + compute_ln_bias ? 
ffn_ln_biases[i] : nullptr, /*norm_bias*/ + &ln_mean, /*mean*/ + &ln_var, /*var*/ + buf0, + buf1); + } - if (pre_layer_norm) { - out_linear_compute.ComputeForward( - out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr); - AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); - } else { - out_linear_compute.ComputeForward( - out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr); - AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + if (pre_layer_norm) { + x_data = buf1->data(); + std::swap(buf0, buf1); + } } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step4"; -#endif - - // step5. ln(residual + dropout(input + bias)) - if (pre_layer_norm) { - auto *ln_scale_data = ffn_ln_scales[i]->data(); - auto *ln_bias_data = ffn_ln_biases[i]->data(); - auto *out_linear_bias_data = out_linear_biases.get()[i]->data(); - // inplace - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - dev_ctx, - buf1->data(), - x_data, - out_linear_bias_data, - ln_scale_data, - ln_bias_data, - bias_dropout_residual_out_data, - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); - } else { - auto *ln_scale_data = ln_scales[i]->data(); - auto *ln_bias_data = ln_biases[i]->data(); - auto *out_linear_bias_data = out_linear_biases.get()[i]->data(); - auto *residual_data = (i == 0 ? x_data : buf1->data()); - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - dev_ctx, - buf0->data(), - residual_data, - out_linear_bias_data, - ln_scale_data, - ln_bias_data, - buf0->data(), - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); + if (encoder_remove_padding) { + if (pre_layer_norm) { + phi::fusion::InvokeRebuildPadding(dev_ctx, + from_data, + buf0->data(), + padding_offset_data, + token_num, + dim_embed); + } else { + phi::fusion::InvokeRebuildPadding(dev_ctx, + from_data, + buf1->data(), + padding_offset_data, + token_num, + dim_embed); + } } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step5"; -#endif - - // step6. ffn matmul1 - ffn1_linear_compute.ComputeForward( - ffn1_weights[i], buf1, nullptr, &ffn1_out, nullptr); -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step6"; -#endif + } +}; - // step7. act bias - // TODO(wangxi): remove dropout mask in inference - fused_act_dropout_helper.DropoutActBias(dev_ctx, - ffn1_out_data, - ffn1_biases.get()[i]->data(), - act_method, - ffn1_dropout_out_data, - ffn1_dropout_mask_data); -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step7"; -#endif +} // namespace operators +} // namespace paddle - // step8. ffn matmul2 - if (pre_layer_norm) { - ffn2_linear_compute.ComputeForward( - ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr); - } else { - ffn2_linear_compute.ComputeForward( - ffn2_weights[i], &ffn1_dropout_out, nullptr, buf0, nullptr); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step8.0"; -#endif - - if (pre_layer_norm) { - AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); - } else { - AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step8.1"; -#endif +namespace ops = paddle::operators; - // step9. 
residual bias - if (pre_layer_norm) { - // TODO(wangxi): remove dropout mask in inference - if (i < layers - 1) { - auto *ln_scale_data = ln_scales[i + 1]->data(); - auto *ln_bias_data = ln_biases[i + 1]->data(); - ffn2_fused_dropout_helper.LayernormResidualDropoutBias( - dev_ctx, - buf1->data(), - bias_dropout_residual_out_data, - ffn2_biases.get()[i]->data(), - ln_scale_data, - ln_bias_data, - buf1->data(), - dropout_mask_out_data, - buf0->data(), - ln_mean_data, - ln_var_data); - } else { - ffn2_fused_dropout_helper.ResidualDropoutBias( - dev_ctx, - buf1->data(), - bias_dropout_residual_out_data, - ffn2_biases.get()[i]->data(), - buf1->data(), - dropout_mask_out_data); - } - } else { - auto *ln_scale_data = ffn_ln_scales[i]->data(); - auto *ln_bias_data = ffn_ln_biases[i]->data(); - ffn2_fused_dropout_helper.LayernormResidualDropoutBias( - dev_ctx, - buf0->data(), - buf1->data(), - ffn2_biases.get()[i]->data(), - ln_scale_data, - ln_bias_data, - buf0->data(), - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step9"; +#if CUDA_VERSION >= 11000 +PD_REGISTER_STRUCT_KERNEL(fused_multi_transformer, + GPU, + ALL_LAYOUT, + ops::FusedMultiTransformerOpKernel, + phi::dtype::float16, + phi::dtype::bfloat16) {} +#else +PD_REGISTER_STRUCT_KERNEL(fused_multi_transformer, + GPU, + ALL_LAYOUT, + ops::FusedMultiTransformerOpKernel, + phi::dtype::float16) {} #endif - if (pre_layer_norm) { - x_data = buf1->data(); - std::swap(buf0, buf1); - } - } - if (encoder_remove_padding) { - if (pre_layer_norm) { - InvokeRebuildPadding(dev_ctx, - from_data, - buf0->data(), - padding_offset_data, - token_num, - dim_embed); - } else { - InvokeRebuildPadding(dev_ctx, - from_data, - buf1->data(), - padding_offset_data, - token_num, - dim_embed); - } - } -} -#endif // CUDA_VERSION >= 11060 - -} // namespace fusion -} // namespace phi - -PD_REGISTER_KERNEL(fused_multi_transformer, - GPU, - ALL_LAYOUT, - phi::fusion::FusedMultiTransformerKernel, - float, - phi::dtype::float16) { - kernel->InputAt(8).SetBackend(phi::Backend::CPU); -} diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h index 62156bee87300..5880ce1ff19c9 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h @@ -19,41 +19,25 @@ limitations under the License. 
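// The layer loop in the fused_multi_transformer kernel above ping-pongs
// between two buffers and chooses which one aliases the kernel output from
// the parity of the layer count, so the last layer writes straight into
// `out` with no trailing copy (the variable-length path instead rebuilds
// padding at the end). A simplified sketch of that selection rule, assuming
// a layer_step() stand-in for one fused transformer layer; neither function
// below exists in this patch.

#include <cassert>
#include <utility>
#include <vector>

static void layer_step(const std::vector<float> &in,
                       std::vector<float> *out_buf) {
  *out_buf = in;  // placeholder for one layer of attention + FFN
}

static void run_layers_pre_layernorm(int layers,
                                     const std::vector<float> &x,
                                     std::vector<float> *out) {
  std::vector<float> tmp(x.size());
  // Each iteration ends with std::swap(buf0, buf1), so writes alternate
  // between the two initial buffers; an odd layer count must therefore start
  // with buf1 == out, an even count with buf0 == out.
  std::vector<float> *buf0 = (layers & 1) ? &tmp : out;
  std::vector<float> *buf1 = (layers & 1) ? out : &tmp;
  const std::vector<float> *last_written = &x;
  for (int i = 0; i < layers; ++i) {
    layer_step(*last_written, buf1);  // this layer's result lands in buf1
    last_written = buf1;
    std::swap(buf0, buf1);            // matches the swap at the end of a layer
  }
  assert(layers == 0 || last_written == out);
}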
*/ #pragma once -#include -#include - -#include - -#include "paddle/common/flags.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/fused/fused_dropout_helper.h" -#include "paddle/phi/api/include/tensor.h" -#include "paddle/phi/backends/dynload/cublasLt.h" -#include "paddle/phi/backends/gpu/gpu_device_function.h" -#include "paddle/phi/backends/gpu/gpu_dnn.h" -#include "paddle/phi/core/distributed/comm_context_manager.h" -#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" -#include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/fusion/gpu/attn_gemm.h" -#include "paddle/phi/kernels/fusion/gpu/fmha_ref.h" +#include +#include #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) -#include "paddle/fluid/distributed/collective/process_group.h" +#include "paddle/fluid/distributed/collective/process_group_nccl.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" #include "paddle/phi/core/distributed/nccl_comm_context.h" COMMON_DECLARE_bool(dynamic_static_unified_comm); #endif -COMMON_DECLARE_bool(gemm_use_half_precision_compute_type); - +#include "paddle/phi/kernels/flash_attn_kernel.h" +#include "paddle/phi/kernels/funcs/load_store_util.h" +#include "paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h" +#include "paddle/phi/kernels/fusion/gpu/mmha_util.cu.h" +#include "paddle/phi/kernels/gpu/flash_attn_utils.h" namespace phi { namespace fusion { -// for debug -// #define _DEBUG_FUSED_MULTI_TRANSFORMER - template static void AllReduce(phi::DenseTensor &tensor, // NOLINT const int ring_id, @@ -65,60 +49,22 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT if (map->has(ring_id)) { paddle::distributed::ProcessGroup *pg = map->get(ring_id); - std::vector in_tensor; - std::vector out_tensor; - in_tensor.push_back(tensor); - out_tensor.push_back(tensor); paddle::distributed::AllreduceOptions opts; opts.reduce_op = distributed::ReduceOp::SUM; - auto task = pg->AllReduce(in_tensor, out_tensor, opts); + auto task = pg->AllReduce(&tensor, tensor, opts, false, true); task->Wait(); } else { - auto dtype = phi::ToNCCLDataType(tensor.dtype()); + auto dtype = paddle::platform::ToNCCLDataType( + paddle::framework::TransToProtoVarType(tensor.dtype())); int64_t numel = tensor.numel(); const void *sendbuff = tensor.data(); auto place = ctx.GetPlace(); void *recvbuff = tensor.mutable_data(place); - gpuStream_t stream = nullptr; - paddle::platform::NCCLComm *comm = nullptr; - phi::distributed::NCCLCommContext *comm_ctx = nullptr; - - const auto &comm_context_manager = - phi::distributed::CommContextManager::GetInstance(); - - if (FLAGS_dynamic_static_unified_comm) { - // Use New Communication Library - PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)), - true, - phi::errors::InvalidArgument( - "You choose to use new communication library by " - "setting environment " - "variable FLAGS_dynamic_static_unified_comm True. 
" - "But ring_id(%d) is " - "not found in comm_context_manager.", - std::to_string(ring_id))); - comm_ctx = static_cast( - comm_context_manager.Get(std::to_string(ring_id))); - PADDLE_ENFORCE_NE(comm_ctx, - nullptr, - phi::errors::Unavailable( - "NCCLCommContext is nullptr, collective op should " - "has ring_id attr.")); - - stream = comm_ctx->GetStream(); - - VLOG(3) << "new comm_context_manager has ring_id" << ring_id; - } else { - comm = paddle::platform::NCCLCommContext::Instance().Get(ring_id, place); - stream = ctx.stream(); - VLOG(3) << "old NCCLCommContext has ring_id " << ring_id; - } - if (comm_ctx) { - comm_ctx->AllReduce(&tensor, tensor, ncclSum, stream); - } else { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce( - sendbuff, recvbuff, count, dtype, ncclSum, comm->comm(), stream)); - } + auto comm = + paddle::platform::NCCLCommContext::Instance().Get(ring_id, place); + auto stream = ctx.stream(); + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce( + sendbuff, recvbuff, count, dtype, ncclSum, comm->comm(), stream)); } #else PADDLE_THROW(phi::errors::Unimplemented( @@ -158,630 +104,71 @@ struct Masked_multihead_attention_params { // qkv_out, [B, 1(seq_len), 3, num_head * dim_head] const T *qkv; // bias, [3, num_head, dim_head] - const T *qkv_bias; + T *qkv_bias; + // [bsz, seq_len] + const int *cum_offsets; // TODO(wangxi): optimize with input_lengths and max_input_len? // [bsz, 1, 1, time_step(cache_seq_length)+1] const T *attn_mask; + int mask_length; + // whether to broadcast num_heads(2nd) dimension for attn_mask + // in MMHA, if false, attn_mask shape should be + // [bsz, num_heads, 1, time_step(cache_seq_length)+1] + bool mask_broadcast_num_heads; // [2, B, num_head, max_seq_len(valid cache_seq_len), dim_head] // k [B, num_head, dim_head/x, max_seq_len, x], that is `seq_len` first // v [B, num_head, max_seq_len, dim_head] - T *cache_kv; + T *cache_kv = nullptr; + // [B, max_seq_len] + const int *beam_cache_offset = nullptr; const int *sequence_lengths{nullptr}; - // The RoPE embedding, [B, 1, 1, dim_head] + // The RoPE embedding, [2, B, rotary_seq_len, 1, dim_head] // rotary_emb_dims = 1 if pos_ids_extra is null else 2 - const T *rotary_emb; + const float *rotary_emb; + int rotary_bsz; int rotary_emb_dims; + int rotary_seq_len = 1; - int batch_size; + int batch_size; // batch * beam + int beam_width; + int cache_batch_size; int num_head; int timestep; // cache_seq_length + int seq_len; int max_seq_length; + int gqa_group_size; + int gqa_num_per_partitions; + // 1.f / sqrt(Dh) float inv_sqrt_dh; -}; - -struct Float8_ { - float2 x; - float2 y; - float2 z; - float2 w; -}; - -// clang-format off - -template struct Qk_vec_ {}; -template <> struct Qk_vec_ { using Type = float; }; -template <> struct Qk_vec_ { using Type = float2; }; -template <> struct Qk_vec_ { using Type = float4; }; -template <> struct Qk_vec_ { using Type = float4; }; -template <> struct Qk_vec_ { using Type = uint32_t; }; -template <> struct Qk_vec_ { using Type = uint32_t; }; -template <> struct Qk_vec_ { using Type = uint2; }; -template <> struct Qk_vec_ { using Type = uint4; }; - -template struct K_vec_ {}; -template <> struct K_vec_ { using Type = float; }; -template <> struct K_vec_ { using Type = float2; }; -template <> struct K_vec_ { using Type = float4; }; -template <> struct K_vec_ { using Type = uint32_t; }; -template <> struct K_vec_ { using Type = uint2; }; -template <> struct K_vec_ { using Type = uint4; }; - -template struct V_vec_ {}; -template <> struct V_vec_ { 
using Type = float; }; -template <> struct V_vec_ { using Type = float2; }; -template <> struct V_vec_ { using Type = float4; }; -template <> struct V_vec_ { using Type = uint32_t; }; -template <> struct V_vec_ { using Type = uint2; }; -template <> struct V_vec_ { using Type = uint4; }; - -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA -template -struct K_vec_acum_fp32_ { -}; - -template<> -struct K_vec_acum_fp32_ { - using Type = float2; -}; -#endif - -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT -template struct V_vec_acum_fp32_ {}; -// template <> struct V_vec_acum_fp32_ { using Type = float; }; -// template <> struct V_vec_acum_fp32_ { using Type = float2; }; -template <> struct V_vec_acum_fp32_ { using Type = float4; }; -// template <> struct V_vec_acum_fp32_ { using Type = float2; }; -// template <> struct V_vec_acum_fp32_ { using Type = Float4_; }; -template <> struct V_vec_acum_fp32_ { using Type = Float8_; }; -#endif - -// clang-format on - -inline __device__ float half_to_float(uint16_t h) { - float f; - asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); - return f; -} - -inline __device__ float2 half2_to_float2(uint32_t v) { - uint16_t lo, hi; - asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); - return make_float2(half_to_float(lo), half_to_float(hi)); -} - -inline __device__ uint32_t float2_to_half2(float2 f) { - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" - : "=r"(tmp.u32) - : "f"(f.y), "f"(f.x)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); -#endif - return tmp.u32; -} - -inline __device__ float add(float a, float b) { return a + b; } - -inline __device__ float2 add(float2 a, float2 b) { - float2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -inline __device__ float4 add(float4 a, float4 b) { - float4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -inline __device__ uint16_t add(uint16_t a, uint16_t b) { - uint16_t c; - asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -inline __device__ uint32_t add(uint32_t a, uint32_t b) { - uint32_t c; - asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -inline __device__ uint2 add(uint2 a, uint2 b) { - uint2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -inline __device__ uint4 add(uint4 a, uint4 b) { - uint4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -inline __device__ float2 add(uint32_t a, float2 fb) { - float2 fa = half2_to_float2(a); - return add(fa, fb); -} - -inline __device__ Float8_ add(uint4 a, Float8_ fb) { - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} - -template -inline __device__ Acc mul(A a, B b); - -template <> -inline __device__ float mul(float a, float b) { - return a * b; -} - -template <> -inline __device__ float2 mul(float2 a, float2 b) { - float2 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - return c; -} - -template <> -inline __device__ float4 mul(float4 a, float4 b) { - float4 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - c.z = a.z * b.z; - c.w = a.w * b.w; - return c; -} - -template <> -inline __device__ uint16_t mul(uint16_t a, uint16_t b) { - uint16_t c; - asm volatile("mul.f16 
%0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -template <> -inline __device__ uint32_t mul(uint32_t a, uint32_t b) { - uint32_t c; - asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -template <> -inline __device__ uint2 mul(uint2 a, uint2 b) { - uint2 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - return c; -} - -template <> -inline __device__ uint4 mul(uint4 a, uint4 b) { - uint4 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - c.z = mul(a.z, b.z); - c.w = mul(a.w, b.w); - return c; -} - -template <> -inline __device__ uint32_t mul(uint32_t a, float b) { - float2 tmp = half2_to_float2(a); - float2 tmp_res; - tmp_res.x = tmp.x * b; - tmp_res.y = tmp.y * b; - uint32_t res = float2_to_half2(tmp_res); - return res; -} - -template <> -inline __device__ float2 mul(uint32_t a, float b) { - float2 tmp = half2_to_float2(a); - float2 res; - res.x = tmp.x * b; - res.y = tmp.y * b; - return res; -} - -template <> -inline __device__ uint2 mul(uint2 a, float b) { - uint2 res; - res.x = mul(a.x, b); - res.y = mul(a.y, b); - return res; -} - -template <> -inline __device__ uint4 mul(uint4 a, float b) { - uint4 res; - res.x = mul(a.x, b); - res.y = mul(a.y, b); - res.z = mul(a.z, b); - res.w = mul(a.w, b); - return res; -} - -template <> -inline __device__ float2 mul(float2 a, float b) { - float2 res; - res.x = a.x * b; - res.y = a.y * b; - return res; -} - -template <> -inline __device__ float2 mul(float2 a, uint32_t b) { - float2 tmp_b = half2_to_float2(b); - float2 res; - res.x = a.x * tmp_b.x; - res.y = a.y * tmp_b.y; - return res; -} - -template <> -inline __device__ float4 mul(float4 a, float b) { - float4 res; - res.x = a.x * b; - res.y = a.y * b; - res.z = a.z * b; - res.w = a.w * b; - return res; -} - -template -inline __device__ Qk_vec apply_rotary_emb(Qk_vec input_left, - Qk_vec input_right, - Qk_vec cos_emb, - Qk_vec sin_emb, - float alpha) { - Qk_vec res1 = mul(input_left, cos_emb); - Qk_vec res2 = mul(input_right, sin_emb); - res2 = mul(res2, alpha); - return add(res1, res2); -} - -inline __device__ float sum(float v) { return v; } -inline __device__ float sum(float2 v) { return v.x + v.y; } -inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } -inline __device__ float sum(uint16_t v) { return half_to_float(v); } -inline __device__ float sum(uint32_t v) { - float2 tmp = half2_to_float2(v); - return tmp.x + tmp.y; -} - -inline __device__ float sum(uint2 v) { - uint32_t c = add(v.x, v.y); - return sum(c); -} - -inline __device__ float sum(uint4 v) { - uint32_t c = add(v.x, v.y); - c = add(c, v.z); - c = add(c, v.w); - return sum(c); -} - -template -inline __device__ float dot(T a, T b) { - return sum(mul(a, b)); -} - -template -inline __device__ float dot(T a, T b) { - return sum(mul(a, b)); -} - -inline __device__ constexpr uint32_t shfl_mask(int threads) { - return threads == 32 ? 
uint32_t(-1) : (1u << threads) - 1u; -} - -template -inline __device__ __host__ T div_up(T m, T n) { - return (m + n - 1) / n; -} - -inline __device__ float fma(float a, float b, float c) { return a * b + c; } - -inline __device__ float2 fma(float2 a, float2 b, float2 c) { - float2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -inline __device__ float2 fma(float2 a, uint32_t b, float2 c) { - float2 tmp_b = half2_to_float2(b); - float2 d = fma(a, tmp_b, c); - return d; -} - -inline __device__ float4 fma(float4 a, float4 b, float4 c) { - float4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { - uint32_t d; - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(d) - : "r"(a), "r"(b), "r"(c)); - return d; -} - -inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { - uint2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { - uint4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -inline __device__ float2 fma(float a, float2 b, float2 c) { - float2 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -inline __device__ float4 fma(float a, float4 b, float4 c) { - float4 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { - Float8_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -inline __device__ uint32_t h0_h0(uint16_t a) { - uint32_t b; - asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); - return b; -} - -inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { - return fma(h0_h0(a), b, c); -} -inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { - uint32_t s = h0_h0(a); - uint2 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} - -inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { - uint32_t s = h0_h0(a); - uint4 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} - -inline __device__ float cast_to_float(float u) { return u; } - -inline __device__ float2 cast_to_float(float2 u) { return u; } - -inline __device__ float4 cast_to_float(float4 u) { return u; } - -inline __device__ Float8_ cast_to_float(uint4 u) { - Float8_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - tmp.z = half2_to_float2(u.z); - tmp.w = half2_to_float2(u.w); - return tmp; -} - -template -inline __device__ float qk_dot_(const K_vec (&q)[N], - const K_vec (&k)[N], - float inv_sqrt_dh) { - K_vec inv_q = mul(q[0], inv_sqrt_dh); - K_vec qk_vec = mul(inv_q, k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - inv_q = mul(q[ii], inv_sqrt_dh); - qk_vec = fma(inv_q, k[ii], qk_vec); - } - - float qk = sum(qk_vec); -#pragma unroll - for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); - } - return qk; -} - -inline __device__ float4 hmma_fp32_tensorcore(const uint2 &a, uint32_t b) { - float4 c; - float zero = 0.f; - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" - " {%0, %1, %2, %3}, \n" - " {%4, %5}, \n" - " {%6}, \n" - " {%7, %7, 
%7, %7}; \n" - - : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) - : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); - return c; -} - -template -inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], - const uint32_t (&k)[N], - float inv_sqrt_dh) { -#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ - __CUDA_ARCH__ >= 750 -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = uint32_t; -#endif - K_vec_acum inv_q = mul(q[0], inv_sqrt_dh); - K_vec_acum qk_vec = mul(inv_q, k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - inv_q = mul(q[ii], inv_sqrt_dh); - qk_vec = fma(inv_q, k[ii], qk_vec); - } -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - uint32_t qk_vec_ = float2_to_half2(qk_vec); - return hmma_fp32_tensorcore(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; -#else - return hmma_fp32_tensorcore(make_uint2(qk_vec, 0u), 0x3c003c00u).x; -#endif -#else - return 0.f; -#endif -} - -template -struct Qk_dot { - template - static inline __device__ float dot(const K_vec (&q)[N], - const K_vec (&k)[N], - float inv_sqrt_dh) { - return qk_dot_(q, k, inv_sqrt_dh); - } -}; - -template <> -struct Qk_dot { - template - static inline __device__ float dot(const uint32_t (&q)[N], - const uint32_t (&k)[N], - float inv_sqrt_dh) { -#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ - __CUDA_ARCH__ >= 750 - return qk_hmma_dot_(q, k, inv_sqrt_dh); -#else - return qk_dot_<4>(q, k, inv_sqrt_dh); -#endif - } + bool add_qkv_bias; + bool neox_rotary_style; }; -template -inline __device__ float block_sum(float *red_smem, float sum) { - int warp = threadIdx.x / WARP_SIZE_T; - int lane = threadIdx.x % WARP_SIZE_T; - -#pragma unroll - for (int mask = WARP_SIZE_T / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - if (lane == 0) { - red_smem[warp] = sum; - } - __syncthreads(); - - if (lane < WARPS_PER_BLOCK) { - sum = red_smem[lane]; - } - -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - return __shfl_sync(uint32_t(-1), sum, 0); -} - -inline __device__ void convert_from_float(float &dst, float src) { // NOLINT - dst = src; -} - -inline __device__ void convert_from_float(float4 &dst, float4 src) { // NOLINT - dst = src; -} - -inline __device__ void convert_from_float(phi::dtype::float16 &dst, // NOLINT - float src) { - dst = static_cast(src); -} - -inline __device__ void convert_from_float(uint4 &dst, Float8_ src) { // NOLINT - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); - dst.z = float2_to_half2(src.z); - dst.w = float2_to_half2(src.w); -} - -inline __device__ void zero(uint16_t &dst) { dst = uint16_t(0); } // NOLINT - -template -inline __device__ void zero(T &dst) { // NOLINT - constexpr int WORDS = sizeof(T) / 4; - union { - T raw; - uint32_t words[WORDS]; - } tmp; -#pragma unroll - for (int ii = 0; ii < WORDS; ++ii) { - tmp.words[ii] = 0u; - } - dst = tmp.raw; -} - template + int THREADS_PER_BLOCK, + typename LoadFunc, + typename StoreFunc> __global__ void masked_multihead_attention_kernel( - Masked_multihead_attention_params params) { + Masked_multihead_attention_params params, + LoadFunc load_func, + StoreFunc store_func) { #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) - typedef PDDataTypeTraits traits_; + const int bi = blockIdx.y; + if (params.sequence_lengths && params.sequence_lengths[bi] == 0) { + return; + } + + typedef phi::PDDataTypeTraits traits_; typedef typename 
traits_::DataType DataType_; static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); @@ -802,13 +189,24 @@ __global__ void masked_multihead_attention_kernel( __shared__ float red_smem[WARPS_PER_BLOCK * 2]; using Qk_vec = typename Qk_vec_::Type; + using Qk_vec_RoPE = typename Qk_vec_RoPE_::Type; __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; - const int bi = blockIdx.y; + // beam id + const int beami = bi % params.beam_width; + // real batch id + const int bbi = bi / params.beam_width; const int hi = blockIdx.x; + const int kv_hi = hi / params.gqa_num_per_partitions; const int bhi = bi * params.num_head + hi; + const int bbhi = bbi * params.beam_width * params.num_head + hi; + const int ti = + params.cum_offsets ? bi * params.seq_len - params.cum_offsets[bi] : -1; + const int thi = params.cum_offsets ? ti * params.num_head + hi : -1; const int tid = threadIdx.x; + const int bi_seq_len_offset = bi * params.max_seq_length; + float qk_max = -FLT_MAX; float qk = 0; @@ -816,8 +214,8 @@ __global__ void masked_multihead_attention_kernel( ? params.timestep : params.sequence_lengths[bi]; - // qkv [B, S=1, 3, num_head, head_dim] - int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh; + // qkv [B, S=1, num_head + 2 * gqa_group_size, head_dim] + int qkv_base_offset = bi * (params.num_head + 2 * params.gqa_group_size) * Dh; constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); @@ -830,110 +228,129 @@ __global__ void masked_multihead_attention_kernel( constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); - const T *q_base = params.qkv; - const T *k_base = params.qkv + params.num_head * Dh; - const T *q_bias_base = params.qkv_bias; - const T *k_bias_base = params.qkv_bias + params.num_head * Dh; + // const T *q_base = params.qkv; + // const T *k_base = params.qkv + params.num_head * Dh; + T *q_bias_base = nullptr; + T *k_bias_base = nullptr; + + if (params.add_qkv_bias) { + q_bias_base = params.qkv_bias; + k_bias_base = params.qkv_bias + params.num_head * Dh; + } if (tid < QK_VECS_PER_WARP) { int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE; - int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE; + const int q_bias_offset = hi * Dh + tid * QK_VEC_SIZE; + const int k_bias_offset = kv_hi * Dh + tid * QK_VEC_SIZE; Qk_vec q; zero(q); - q = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&q_base[qk_offset]) - : q; + if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) { + load_func.template load(q, qk_offset + hi * Dh); + } + Qk_vec k; zero(k); - k = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&k_base[qk_offset]) - : k; - - Qk_vec q_bias; - zero(q_bias); - q_bias = - (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&q_bias_base[qk_bias_offset]) - : q_bias; - Qk_vec k_bias; - zero(k_bias); - k_bias = - (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&k_bias_base[qk_bias_offset]) - : k_bias; - - q = add(q, q_bias); - // TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510 - // we may not require k_bias. 
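// In the GQA-aware indexing above, query head `hi` reads and writes the K/V
// cache of head `kv_hi = hi / gqa_num_per_partitions`, and the K cache splits
// the head dimension into 16-byte groups with `seq_len` placed just inside
// each group (the "`seq_len` first" layout described in the struct comments),
// so a thread fetches a full 16-byte chunk of K per time step. A host-side
// sketch of that index arithmetic, for illustration only: these helpers are
// not part of the patch, and gqa_num_per_partitions is assumed here to equal
// num_head / gqa_group_size.

#include <cstddef>

// Element index of K-cache entry (b, kv_head, t, d) in a cache laid out as
// [B, kv_heads, Dh / x, max_seq_len, x], where x = 16 / sizeof(T) elements
// per 16-byte group (x == 8 for fp16 / bf16).
static size_t k_cache_elem_index(int b, int kv_head, int t, int d,
                                 int kv_heads, int max_seq_len, int Dh,
                                 int x) {
  const int group = d / x;     // which 16-byte chunk of the head dim
  const int in_group = d % x;  // position inside that chunk
  return ((static_cast<size_t>(b) * kv_heads + kv_head) * (Dh / x) + group) *
             max_seq_len * x +
         static_cast<size_t>(t) * x + in_group;
}

// Grouped-query attention maps several query heads onto one shared K/V head.
static int kv_head_of_query_head(int hi, int num_head, int gqa_group_size) {
  const int q_heads_per_kv = num_head / gqa_group_size;
  return hi / q_heads_per_kv;
}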
- k = add(k, k_bias); - - // rotary pos emb - if (params.rotary_emb_dims != 0) { - int last_dim = Dh / params.rotary_emb_dims; - int half_lastdim = last_dim / 2; - int rotary_offset = bi * Dh + tid * QK_VEC_SIZE; - const T *cos_base = params.rotary_emb; - const T *sin_base = params.rotary_emb + params.batch_size * Dh; - int stride = half_lastdim / QK_VEC_SIZE; - int stride_all_lastdim = 2 * stride; - int right_id = tid / stride_all_lastdim * stride_all_lastdim + - (tid + stride) % (stride_all_lastdim); - int qk_right_offset = qkv_base_offset + right_id * QK_VEC_SIZE; - int qk_right_bias_offset = hi * Dh + right_id * QK_VEC_SIZE; - Qk_vec q_right; - zero(q_right); - q_right = - (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&q_base[qk_right_offset]) - : q_right; - Qk_vec k_right; - zero(k_right); - k_right = - (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&k_base[qk_right_offset]) - : k_right; - - Qk_vec q_right_bias; - zero(q_right_bias); - q_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) - ? *reinterpret_cast( - &q_bias_base[qk_right_bias_offset]) - : q_right_bias; - Qk_vec k_right_bias; - zero(k_right_bias); - k_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) - ? *reinterpret_cast( - &k_bias_base[qk_right_bias_offset]) - : k_right_bias; - - q_right = add(q_right, q_right_bias); - k_right = add(k_right, k_right_bias); - - Qk_vec cos_emb; - zero(cos_emb); - cos_emb = - (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&cos_base[rotary_offset]) - : cos_emb; + if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) { + load_func.template load( + k, params.num_head * Dh + qk_offset + kv_hi * Dh); + } - Qk_vec sin_emb; - zero(sin_emb); - sin_emb = + if (params.add_qkv_bias) { + Qk_vec q_bias; + zero(q_bias); + Qk_vec k_bias; + zero(k_bias); + + q_bias = + (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&q_bias_base[q_bias_offset]) + : q_bias; + k_bias = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&sin_base[rotary_offset]) - : sin_emb; - float alpha = (tid % stride_all_lastdim) < stride ? static_cast(-1) - : static_cast(1); - q = apply_rotary_emb(q, q_right, cos_emb, sin_emb, alpha); - k = apply_rotary_emb(k, k_right, cos_emb, sin_emb, alpha); + ? *reinterpret_cast(&k_bias_base[k_bias_offset]) + : k_bias; + + q = add(q, q_bias); + // TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510 + // we may not require k_bias. + k = add(k, k_bias); + } + + if (!params.neox_rotary_style) { + if (params.rotary_emb_dims != 0) { + int rotary_offset = bi * Dh + tid * QK_VEC_SIZE; + const float *cos_base = params.rotary_emb; + const float *sin_base = params.rotary_emb + params.rotary_bsz * Dh; + Qk_vec_RoPE cos_emb, sin_emb; + zero(cos_emb); + zero(sin_emb); + cos_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &cos_base[rotary_offset]) + : cos_emb; + sin_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? 
*reinterpret_cast( + &sin_base[rotary_offset]) + : sin_emb; + apply_rotary_embedding(q, k, cos_emb, sin_emb); + } + } else { + /* old rotary pos emb */ + if (params.rotary_emb_dims != 0) { + int last_dim = Dh / params.rotary_emb_dims; + int half_lastdim = last_dim / 2; + int rotary_offset = bi * Dh + tid * QK_VEC_SIZE; + const float *cos_base = params.rotary_emb; + const float *sin_base = params.rotary_emb + params.rotary_bsz * Dh; + int stride = half_lastdim / QK_VEC_SIZE; + int stride_all_lastdim = 2 * stride; + int right_id = tid / stride_all_lastdim * stride_all_lastdim + + (tid + stride) % (stride_all_lastdim); + int q_right_offset = qkv_base_offset + hi * Dh + right_id * QK_VEC_SIZE; + int k_right_offset = qkv_base_offset + params.num_head * Dh + + kv_hi * Dh + right_id * QK_VEC_SIZE; + Qk_vec q_right; + zero(q_right); + if (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) { + load_func.template load(q_right, q_right_offset); + } + Qk_vec k_right; + zero(k_right); + if (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) { + load_func.template load(k_right, k_right_offset); + } + + Qk_vec_RoPE cos_emb; + zero(cos_emb); + cos_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &cos_base[rotary_offset]) + : cos_emb; + + Qk_vec_RoPE sin_emb; + zero(sin_emb); + sin_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &sin_base[rotary_offset]) + : sin_emb; + float alpha = (tid % stride_all_lastdim) < stride + ? static_cast(-1) + : static_cast(1); + q = apply_rotary_emb( + q, q_right, cos_emb, sin_emb, alpha); + k = apply_rotary_emb( + k, k_right, cos_emb, sin_emb, alpha); + } } *reinterpret_cast(&q_smem[tid * QK_VEC_SIZE]) = q; int co = tid / QK_VECS_IN_16B; int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE; - int offset = bhi * params.max_seq_length * Dh + + + int offset = bi * params.gqa_group_size * params.max_seq_length * Dh + + kv_hi * params.max_seq_length * Dh + co * params.max_seq_length * QK_ELTS_IN_16B + act_time_step * QK_ELTS_IN_16B + ci; if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { @@ -955,10 +372,6 @@ __global__ void masked_multihead_attention_kernel( qk = block_sum(&red_smem[WARPS_PER_RED], qk); } if (tid == 0) { - // NOTE(wangxi): mask must be 0.0 - // T mask = params.attn_mask[ - // bi * (params.timestep + 1) + params.timestep]; - // qk += static_cast(mask); qk *= params.inv_sqrt_dh; qk_max = qk; qk_smem[act_time_step] = qk; @@ -966,11 +379,12 @@ __global__ void masked_multihead_attention_kernel( __syncthreads(); #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - if (bi == 0 && hi == 0 && tid == 0) { - VLOG(0) << "=======q_out=======\n"; - for (int i = 0; i < Dh; ++i) VLOG(0) << static_cast(q_smem[i]); - } - __syncthreads(); + // if (bi == 0 && hi == 0 && tid == 0) { + // printf("=======q_out=======\n"); + // for (int i = 0; i < Dh; ++i) printf("%f ", + // static_cast(q_smem[i])); printf("\n"); + // } + // __syncthreads(); #endif using K_vec = typename K_vec_::Type; @@ -994,7 +408,10 @@ __global__ void masked_multihead_attention_kernel( constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; constexpr int K_PER_WARP = WARP_SIZE_TMP / THREADS_PER_KEY; - T *k_cache = ¶ms.cache_kv[bhi * params.max_seq_length * Dh + ki]; + T *k_cache = + ¶ms.cache_kv[bi * params.gqa_group_size * params.max_seq_length * Dh + + kv_hi * params.max_seq_length * Dh + ki]; + int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP; for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { @@ -1004,6 +421,7 @@ __global__ void masked_multihead_attention_kernel( #pragma unroll for 
(int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { int jj = ii * params.max_seq_length + ti; + // get k from the cache_kv, and dequant k for qk operation if (ti < act_time_step) { k[ii] = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length) @@ -1020,8 +438,11 @@ __global__ void masked_multihead_attention_kernel( // bool is_mask = false; if (ti < act_time_step && tid % THREADS_PER_KEY == 0) { // qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); - T mask = params.attn_mask[bi * (params.timestep + 1) + ti]; - qk += static_cast(mask); + auto mask_bhi = params.mask_broadcast_num_heads ? bi : bhi; + if (params.attn_mask) { + T mask = params.attn_mask[mask_bhi * params.mask_length + ti]; + qk += static_cast(mask); + } qk_max = fmaxf(qk_max, qk); qk_smem[ti] = qk; @@ -1051,12 +472,12 @@ __global__ void masked_multihead_attention_kernel( qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - if (bi == 0 && hi == 0 && tid == 0) { - printf("=======qk_out=======\n"); - for (int i = 0; i <= params.timestep; ++i) printf("%f ", qk_smem[i]); - printf("qk_max=%f\n", qk_max); - } - __syncthreads(); + // if (bi == 0 && hi == 0 && tid == 0) { + // printf("=======qk_out=======\n"); + // for (int i = 0; i <= params.timestep; ++i) printf("%f ", qk_smem[i]); + // printf("qk_max=%f\n", qk_max); + // } + // __syncthreads(); #endif float sum = 0.f; @@ -1083,9 +504,11 @@ __global__ void masked_multihead_attention_kernel( int vo = tid / THREADS_PER_VALUE; int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE; - T *v_cache = ¶ms.cache_kv[params.batch_size * params.num_head * - params.max_seq_length * Dh + - bhi * params.max_seq_length * Dh + vi]; + T *v_cache = + ¶ms.cache_kv[params.cache_batch_size * params.gqa_group_size * + params.max_seq_length * Dh + + bi * params.gqa_group_size * params.max_seq_length * Dh + + kv_hi * params.max_seq_length * Dh + vi]; #ifdef MMHA_USE_FP32_ACUM_FOR_OUT using V_vec_acum = typename V_vec_acum_fp32_::Type; @@ -1099,7 +522,8 @@ __global__ void masked_multihead_attention_kernel( constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; if (Dh == Dh_MAX || vi < Dh) { for (int ti = vo; ti < act_time_step; ti += V_PER_ITER) { - V_vec v = *reinterpret_cast(&v_cache[ti * Dh]); + V_vec v; + v = *reinterpret_cast(&v_cache[ti * Dh]); #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti]; out = fma(logit, cast_to_float(v), out); @@ -1112,22 +536,29 @@ __global__ void masked_multihead_attention_kernel( } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - if (bi == 0 && hi == 0 && tid == 0) { - printf("======logits_out=====\n"); - for (int i = 0; i <= params.timestep; ++i) printf("%f ", logits_smem[i]); - printf("\n"); - } - __syncthreads(); + // if (bi == 0 && hi == 0 && tid == 0) { + // printf("======logits_out=====\n"); + // for (int i = 0; i <= params.timestep; ++i) printf("%f ", logits_smem[i]); + // printf("\n"); + // } + // __syncthreads(); #endif V_vec v_bias; zero(v_bias); if (vo == (act_time_step % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) { - V_vec v = *reinterpret_cast( - ¶ms.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]); - v_bias = *reinterpret_cast( - ¶ms.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]); - v = add(v, v_bias); + V_vec v; + load_func.template load(v, + params.num_head * Dh + + params.gqa_group_size * Dh + + qkv_base_offset + kv_hi * Dh + vi); + if (params.add_qkv_bias) { + v_bias = *reinterpret_cast( + ¶ms.qkv_bias[(params.num_head + params.gqa_group_size) * Dh + + kv_hi * Dh + vi]); + v = add(v, v_bias); + } + 
*reinterpret_cast(&v_cache[act_time_step * Dh]) = v; #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) @@ -1165,21 +596,24 @@ __global__ void masked_multihead_attention_kernel( if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { #ifdef MMHA_USE_FP32_ACUM_FOR_OUT - convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), - out); + V_vec tmp_out; + convert_from_float(tmp_out, out); + store_func.template store(tmp_out, + thi != -1 ? thi * Dh + vi : bhi * Dh + vi); #else - *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; + store_func.template store(out, + thi != -1 ? thi * Dh + vi : bhi * Dh + vi); #endif } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - __syncthreads(); - if (bi == 0 && hi == 0 && tid == 0) { - printf("======fmha_out=====\n"); - for (int i = 0; i < Dh; ++i) - printf("%f ", static_cast(params.out[i])); - printf("\n"); - } + // __syncthreads(); + // if (bi == 0 && hi == 0 && tid == 0) { + // printf("======fmha_out=====\n"); + // for (int i = 0; i < Dh; ++i) + // printf("%f ", static_cast(params.out[i])); + // printf("\n"); + // } #endif #else assert(false); @@ -1208,34 +642,122 @@ inline size_t smem_size_in_bytes( return max(softmax_sz, red_sz); } -#define MMHA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ - size_t smem_sz = \ - smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_head, params.batch_size); \ - masked_multihead_attention_kernel \ - <<>>(params) - -template +#define MMHA_LAUNCH_KERNEL(T, \ + Dh, \ + Dh_MAX, \ + THDS_PER_KEY, \ + THDS_PER_VALUE, \ + THDS_PER_BLOCK, \ + stream, \ + load_func, \ + store_func) \ + size_t smem_sz = \ + smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_head, params.batch_size); \ + constexpr auto kernel_fn = \ + masked_multihead_attention_kernel; \ + if (smem_sz > 0xc000) { \ + cudaFuncSetAttribute( \ + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + } \ + kernel_fn<<>>( \ + params, load_func, store_func); + +template void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, - const cudaStream_t &stream) { + const cudaStream_t &stream, + LoadFunc load_func, + StoreFunc store_func) { constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; if (params.timestep < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream); + MMHA_LAUNCH_KERNEL( + T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream, load_func, store_func); } else if (params.timestep < 2048) { #if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ __CUDA_ARCH__ >= 750 - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 256, stream); + MMHA_LAUNCH_KERNEL(T, + Dh, + Dh_MAX, + 4, + THREADS_PER_VALUE, + 256, + stream, + load_func, + store_func); #else - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream); + MMHA_LAUNCH_KERNEL(T, + Dh, + Dh_MAX, + 2, + THREADS_PER_VALUE, + 128, + stream, + load_func, + store_func); #endif } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream); + MMHA_LAUNCH_KERNEL(T, + Dh, + Dh_MAX, + 1, + THREADS_PER_VALUE, + 256, + stream, + load_func, + store_func); + } +} + +template +void fmha_impl(const phi::GPUContext &dev_ctx, + const Masked_multihead_attention_params ¶ms, + int dim_head, + LoadFunc load_func, + StoreFunc store_func) { + switch (dim_head) { + case 10: + fmha_launch_kernel( + params, dev_ctx.stream(), load_func, store_func); + break; + case 26: + fmha_launch_kernel( + params, dev_ctx.stream(), load_func, store_func); + break; + 
case 32: + fmha_launch_kernel( + params, dev_ctx.stream(), load_func, store_func); + break; + case 64: + fmha_launch_kernel( + params, dev_ctx.stream(), load_func, store_func); + break; + case 96: + fmha_launch_kernel( + params, dev_ctx.stream(), load_func, store_func); + break; + case 128: + fmha_launch_kernel( + params, dev_ctx.stream(), load_func, store_func); + break; + case 192: + fmha_launch_kernel( + params, dev_ctx.stream(), load_func, store_func); + break; + default: + PADDLE_THROW( + phi::errors::Unimplemented("Dim_head = %d is unsupport!", dim_head)); } } @@ -1243,102 +765,95 @@ template void fmha(const phi::GPUContext &dev_ctx, const phi::DenseTensor &qkv_tensor, const phi::DenseTensor &qkv_bias_tensor, - const phi::DenseTensor &src_mask_tensor, + const phi::DenseTensor *src_mask_tensor, + const phi::DenseTensor *cum_offsets_tensor, const phi::DenseTensor *sequence_lengths_tensor, const phi::DenseTensor *rotary_tensor, + const phi::DenseTensor *beam_cache_offset_tensor, phi::DenseTensor *cache_kv_tensor, phi::DenseTensor *out_tensor, int batch_size, + int cache_batch_size, + int seq_len, int max_seq_length, int num_head, int dim_head, int timestep, int rotary_emb_dims, - float inv_sqrt_dh) { + float inv_sqrt_dh, + const bool mask_broadcast_num_heads = true, + const bool add_qkv_bias = true, + const bool neox_rotary_style = false, + const int gqa_group_size = -1) { Masked_multihead_attention_params params; - params.out = out_tensor->data(); - params.qkv = qkv_tensor.data(); - params.qkv_bias = qkv_bias_tensor.data(); - params.attn_mask = src_mask_tensor.data(); + // params.out = out_tensor->data(); + // params.qkv = qkv_tensor.data(); + + if (add_qkv_bias) { + // Because we may not add qkv_bias, so here we cast to T*. + // Author(zhengzekang). 
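// A small illustrative sketch (not in the patch) of how a caller typically
// derives the scalar arguments consumed by fmha() above: the attention scale
// is 1/sqrt(dim_head), and only the head sizes enumerated in fmha_impl's
// switch are dispatchable; anything else hits the Unimplemented error.
// Function names here are assumptions for illustration.
#include <cmath>
#include <initializer_list>

inline bool IsSupportedDimHead(int dim_head) {
  // Mirrors the cases handled by fmha_impl above.
  for (int d : {10, 26, 32, 64, 96, 128, 192}) {
    if (d == dim_head) return true;
  }
  return false;
}

inline float AttnScale(int dim_head) {
  // Passed to the kernel as inv_sqrt_dh and folded into the q*k logits.
  return 1.0f / std::sqrt(static_cast<float>(dim_head));
}
// e.g. dim_head = 128 -> inv_sqrt_dh ~= 0.0884.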
+ params.qkv_bias = const_cast(qkv_bias_tensor.data()); + } + params.mask_broadcast_num_heads = mask_broadcast_num_heads; params.cache_kv = cache_kv_tensor->data(); + params.neox_rotary_style = neox_rotary_style; + if (src_mask_tensor) { + params.attn_mask = src_mask_tensor->data(); + params.mask_length = src_mask_tensor->dims()[3]; + } else { + params.attn_mask = nullptr; + params.mask_length = -1; + } + if (sequence_lengths_tensor) { params.sequence_lengths = sequence_lengths_tensor->data(); } + if (cum_offsets_tensor) { + params.cum_offsets = cum_offsets_tensor->data(); + } else { + params.cum_offsets = nullptr; + } + params.seq_len = seq_len; + if (rotary_emb_dims > 0) { - params.rotary_emb = rotary_tensor->data(); + params.rotary_emb = rotary_tensor->data(); + params.rotary_bsz = rotary_tensor->dims()[1]; } else { params.rotary_emb = nullptr; + params.rotary_bsz = 0; + } + + if (beam_cache_offset_tensor) { + params.beam_cache_offset = beam_cache_offset_tensor->data(); + params.beam_width = beam_cache_offset_tensor->dims()[1]; + } + + if (gqa_group_size > 0) { + params.gqa_group_size = gqa_group_size; + params.gqa_num_per_partitions = num_head / gqa_group_size; + } else { + params.gqa_group_size = num_head; + params.gqa_num_per_partitions = 1; } + VLOG(1) << "gqa_group_size " << params.gqa_group_size; + VLOG(1) << "gqa_num_per_partitions " << params.gqa_num_per_partitions; + + params.add_qkv_bias = add_qkv_bias; params.batch_size = batch_size; + params.cache_batch_size = cache_batch_size; params.num_head = num_head; params.timestep = timestep; params.max_seq_length = max_seq_length; params.inv_sqrt_dh = inv_sqrt_dh; params.rotary_emb_dims = rotary_emb_dims; - switch (dim_head) { - case 10: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 26: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 32: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 64: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - // opt model - case 80: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 96: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 128: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 192: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - default: - PADDLE_THROW( - phi::errors::Unimplemented("Dim_head = %d is unsupport!", dim_head)); - } -} - -template -void fmha(const phi::GPUContext &dev_ctx, - const phi::DenseTensor &qkv_tensor, - const phi::DenseTensor &qkv_bias_tensor, - const phi::DenseTensor &src_mask_tensor, - phi::DenseTensor *cache_kv_tensor, - phi::DenseTensor *out_tensor, - int batch_size, - int max_seq_length, - int num_head, - int dim_head, - int timestep, - float inv_sqrt_dh) { - fmha(dev_ctx, - qkv_tensor, - qkv_bias_tensor, - src_mask_tensor, - nullptr, - nullptr, - cache_kv_tensor, - out_tensor, - batch_size, - max_seq_length, - num_head, - dim_head, - timestep, - 0, - inv_sqrt_dh); + MMHALoad load_func(qkv_tensor.data()); + MMHAStore store_func(out_tensor->data()); + fmha_impl( + dev_ctx, params, dim_head, load_func, store_func); } // NOTE: simd with 16Bytes(128bit), float is 4, float16 is 8 @@ -1446,6 +961,221 @@ void write_cache_kv(const phi::GPUContext &dev_ctx, cache_v, v, num_head, dim_head, seq_len, max_seq_len); } +template +__global__ void gqa_write_cache_k_kernel(T *cache_k, + const T *k, + const int *seq_lens, + const int *padding_offsets, + const int gqa_group_size, + const int max_seq_len, + const int seq_len, + const int dim_head, + const int64_t num_elems) { + 
phi::AlignedVector in_vec; + + for (int64_t linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * X_ELEMS; + linear_idx < num_elems; + linear_idx += blockDim.x * gridDim.x * X_ELEMS) { + const int hidden_size = gqa_group_size * dim_head; + const int token_idx = linear_idx / hidden_size; + const int head_idx = (linear_idx % hidden_size) / dim_head; + const int head_offset = linear_idx % dim_head; + const int head_vec_id = head_offset / X_ELEMS; + const int ori_token_id = token_idx + padding_offsets[token_idx]; + const int ori_bi = ori_token_id / seq_len; + + if (seq_lens[ori_bi] == 0) continue; + + const int local_token_id = ori_token_id % seq_len; + + const int tgt_idx = ori_bi * gqa_group_size * max_seq_len * dim_head + + head_idx * max_seq_len * dim_head + + head_vec_id * max_seq_len * X_ELEMS + + local_token_id * X_ELEMS; + + phi::Load(&k[linear_idx], &in_vec); + phi::Store(in_vec, &cache_k[tgt_idx]); + } +} + +template +__global__ void gqa_write_cache_v_kernel(T *cache_v, + const T *v, + const int *seq_lens, + const int *padding_offsets, + const int gqa_group_size, + const int max_seq_len, + const int seq_len, + const int dim_head, + const int64_t num_elems) { + phi::AlignedVector in_vec; + + for (int64_t linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * X_ELEMS; + linear_idx < num_elems; + linear_idx += blockDim.x * gridDim.x * X_ELEMS) { + const int hidden_size = gqa_group_size * dim_head; + const int token_idx = linear_idx / hidden_size; + const int head_idx = (linear_idx % hidden_size) / dim_head; + const int head_offset = linear_idx % dim_head; + const int ori_token_id = token_idx + padding_offsets[token_idx]; + const int ori_bi = ori_token_id / seq_len; + + if (seq_lens[ori_bi] == 0) continue; + + const int local_token_id = ori_token_id % seq_len; + + const int tgt_idx = ori_bi * gqa_group_size * max_seq_len * dim_head + + head_idx * max_seq_len * dim_head + + local_token_id * dim_head + head_offset; + + phi::Load(&v[linear_idx], &in_vec); + phi::Store(in_vec, &cache_v[tgt_idx]); + } +} + +template +void gqa_write_cachekv( + const phi::GPUContext &dev_ctx, + phi::DenseTensor *cache_kv_out, // [2, cache_bsz, gqa_group_size, + // max_seq_len, dim_head] k need + const phi::DenseTensor + &unpadding_k, // [token_num, gqa_group_size, dim_head] + const phi::DenseTensor &unpadding_v, + const phi::DenseTensor &padding_offsets, + const phi::DenseTensor &seq_lens, + const int seq_len) { + constexpr int block_sz = 128; + constexpr int x = VEC_16B / sizeof(T); + + const int cache_bsz = cache_kv_out->dims()[1]; + const int gqa_group_size = cache_kv_out->dims()[2]; + const int max_seq_len = cache_kv_out->dims()[3]; + const int dim_head = cache_kv_out->dims()[4]; + + assert(dim_head % x == 0); + PADDLE_ENFORCE_EQ( + dim_head % x, + 0, + phi::errors::PreconditionNotMet( + "dim_head=%d must be divisible by vec_size=%d", dim_head, x)); + + const int64_t num_elems = unpadding_k.numel(); + + T *cache_k = cache_kv_out->data(); + T *cache_v = cache_k + cache_bsz * gqa_group_size * max_seq_len * dim_head; + + int grid_size; + GetNumBlocks(num_elems, &grid_size); + + gqa_write_cache_k_kernel<<>>( + cache_k, + unpadding_k.data(), + seq_lens.data(), + padding_offsets.data(), + gqa_group_size, + max_seq_len, + seq_len, + dim_head, + num_elems); + gqa_write_cache_v_kernel<<>>( + cache_v, + unpadding_v.data(), + seq_lens.data(), + padding_offsets.data(), + gqa_group_size, + max_seq_len, + seq_len, + dim_head, + num_elems); +} + +template +__global__ void fusedQKV_transpose_split_kernel(T *q_buf, + T 
*k_buf, + T *v_buf, + const T *qkv, + const int *padding_offset, + const int *seq_lens, + const int32_t elem_cnt, + const int batch_size, + const int seq_len, + const int token_num, + const int head_num, + const int size_per_head) { + const int32_t hidden_size = head_num * size_per_head; + const int32_t fused_hidden_size = 3 * hidden_size; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + using LoadT = phi::AlignedVector; + LoadT src_vec; + + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + phi::Load(&qkv[linear_index], &src_vec); + int32_t bias_idx = linear_index % fused_hidden_size; + const int32_t token_idx = linear_index / fused_hidden_size; + const int32_t ori_token_idx = + token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]); + const int32_t target_batch_id = ori_token_idx / seq_len; + if (seq_lens[target_batch_id] == 0) continue; + + const int32_t qkv_id = bias_idx / hidden_size; + const int32_t head_id = (linear_index % hidden_size) / size_per_head; + const int32_t size_id = linear_index % size_per_head; + + const int32_t write_idx = + token_idx * hidden_size + head_id * size_per_head + size_id; + if (qkv_id == 0) { + phi::Store(src_vec, &q_buf[write_idx]); + } else if (qkv_id == 1) { + phi::Store(src_vec, &k_buf[write_idx]); + } else { + phi::Store(src_vec, &v_buf[write_idx]); + } + } +} + +template +void qkv_transpose_split(const phi::GPUContext &dev_ctx, + T *q_buf, + T *k_buf, + T *v_buf, + const T *qkv, + const int *padding_offset, + const int *seq_lens, + const int token_num, + const int batch_size, + const int head_num, + const int seq_len, + const int size_per_head) { + const int32_t elem_cnt = token_num * head_num * size_per_head * 3; + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ(size_per_head % PackSize, + 0, + phi::errors::PreconditionNotMet( + "dim_head=%d must be divisible by vec_size=%d", + size_per_head, + PackSize)); + const int32_t pack_num = elem_cnt / PackSize; + const int32_t blocksize = 128; + int32_t grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + fusedQKV_transpose_split_kernel + <<>>(q_buf, + k_buf, + v_buf, + qkv, + padding_offset, + seq_lens, + elem_cnt, + batch_size, + seq_len, + token_num, + head_num, + size_per_head); +} + template __global__ void add_fusedQKV_bias_transpose_split_kernel( T *q_buf, @@ -1581,10 +1311,151 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx, } } +template +__global__ void gqa_fusedQKV_transpose_split_kernel(T *q_buf, + T *k_buf, + T *v_buf, + const T *qkv, + const int *padding_offset, + const int *seq_lens, + const int32_t elem_cnt, + const int batch_size, + const int seq_len, + const int token_num, + const int head_num, + const int size_per_head, + const int gqa_group_size) { + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + using LoadT = phi::AlignedVector; + LoadT src_vec; + + const int fused_hidden_size = (head_num + 2 * gqa_group_size) * size_per_head; + + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + phi::Load(&qkv[linear_index], &src_vec); + int32_t bias_idx = linear_index % fused_hidden_size; + const int32_t token_idx = linear_index / fused_hidden_size; + const int32_t ori_token_idx = + token_idx + (padding_offset == nullptr ? 
0 : padding_offset[token_idx]); + const int32_t target_batch_id = ori_token_idx / seq_len; + if (seq_lens[target_batch_id] == 0) continue; + + const int32_t head_id = bias_idx / size_per_head; + const int32_t size_id = linear_index % size_per_head; + + // [token_num, num_head or gqa_group_size, size_per_head] + if (head_id < head_num) { + const int32_t write_idx = token_idx * head_num * size_per_head + + head_id * size_per_head + size_id; + phi::Store(src_vec, &q_buf[write_idx]); + } else { + if (head_id < head_num + gqa_group_size) { + const int32_t write_idx = token_idx * gqa_group_size * size_per_head + + (head_id - head_num) * size_per_head + + size_id; + phi::Store(src_vec, &k_buf[write_idx]); + } else { + const int32_t write_idx = + token_idx * gqa_group_size * size_per_head + + (head_id - head_num - gqa_group_size) * size_per_head + size_id; + phi::Store(src_vec, &v_buf[write_idx]); + } + } + } +} + +template +void gqa_qkv_transpose_split(const phi::GPUContext &dev_ctx, + T *q_buf, + T *k_buf, + T *v_buf, + const T *qkv, + const int *padding_offset, + const int *seq_lens, + const int token_num, + const int batch_size, + const int head_num, + const int seq_len, + const int size_per_head, + const int gqa_group_size) { + const int32_t elem_cnt = + token_num * (head_num + 2 * gqa_group_size) * size_per_head; + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ(size_per_head % PackSize, + 0, + phi::errors::PreconditionNotMet( + "dim_head=%d must be divisible by vec_size=%d", + size_per_head, + PackSize)); + const int32_t pack_num = elem_cnt / PackSize; + const int32_t blocksize = 128; + int32_t grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + gqa_fusedQKV_transpose_split_kernel + <<>>(q_buf, + k_buf, + v_buf, + qkv, + padding_offset, + seq_lens, + elem_cnt, + batch_size, + seq_len, + token_num, + head_num, + size_per_head, + gqa_group_size); +} + +/* old rope emb */ +template +__global__ void NeoXRotaryKernel(const T *input, + const float *cos_emb, + const float *sin_emb, + const int *sequence_lengths, + T *output, + const int rotary_emb_dims, + const int batch_size, + const int head_num, + const int seq_len, + const int last_dim) { + int bi = blockIdx.x; + int hi = blockIdx.y; + int si = blockIdx.z; + if (sequence_lengths && si >= sequence_lengths[bi] * rotary_emb_dims) return; + int half_lastdim = last_dim / 2; + for (int ti = threadIdx.x; ti < half_lastdim; ti += blockDim.x) { + int base_idx = bi * head_num * seq_len * last_dim + + hi * seq_len * last_dim + si * last_dim; + int left_idx = base_idx + ti; + const int right_idx = base_idx + ti + half_lastdim; + int emb_idx_left = bi * seq_len * last_dim + si * last_dim + ti; + int emb_idx_right = + bi * seq_len * last_dim + si * last_dim + ti + half_lastdim; + float input_left = static_cast(input[left_idx]); + float input_right = static_cast(input[right_idx]); + + float cos_tmp_left = cos_emb[emb_idx_left]; + float sin_tmp_left = sin_emb[emb_idx_left]; + float cos_tmp_right = cos_emb[emb_idx_right]; + float sin_tmp_right = sin_emb[emb_idx_right]; + + T res1 = + static_cast(input_left * cos_tmp_left - input_right * sin_tmp_left); + T res2 = static_cast(input_right * cos_tmp_right + + input_left * sin_tmp_right); + output[left_idx] = res1; + output[right_idx] = res2; + } +} + template -__global__ void RotrayKernel(const T *input, - const T *cos_emb, - const T *sin_emb, +__global__ void RotaryKernel(const T *input, + const float *cos_emb, + const float *sin_emb, const int *sequence_lengths, T *output, const int 
rotary_emb_dims, @@ -1602,15 +1473,15 @@ __global__ void RotrayKernel(const T *input, for (int ti = threadIdx.x; ti < half_lastdim; ti += blockDim.x) { int base_idx = bi * head_num * seq_len * last_dim + hi * seq_len * last_dim + si * last_dim; - int left_idx = base_idx + ti; - const int right_idx = base_idx + ti + half_lastdim; - int emb_idx = bi * seq_len * last_dim + si * last_dim + ti; - T input_left = input[left_idx]; - T input_right = input[right_idx]; - T cos_tmp = cos_emb[emb_idx]; - T sin_tmp = sin_emb[emb_idx]; - T res1 = input_left * cos_tmp - input_right * sin_tmp; - T res2 = input_right * cos_tmp + input_left * sin_tmp; + int left_idx = base_idx + 2 * ti; + const int right_idx = base_idx + 2 * ti + 1; + int emb_idx = bi * seq_len * last_dim + si * last_dim + 2 * ti; + float input_left = static_cast(input[left_idx]); + float input_right = static_cast(input[right_idx]); + float cos_tmp = cos_emb[emb_idx]; + float sin_tmp = sin_emb[emb_idx]; + T res1 = static_cast(input_left * cos_tmp - input_right * sin_tmp); + T res2 = static_cast(input_right * cos_tmp + input_left * sin_tmp); output[left_idx] = res1; output[right_idx] = res2; } @@ -1622,13 +1493,15 @@ void rotary_qk(const phi::GPUContext &dev_ctx, T *k, // kv const T *q_input, // q const T *k_input, // kv - const T *rotary_emb, + const float *rotary_emb, const int *sequence_lengths, const int rotary_emb_dims, + const int rope_bsz, const int batch_size, const int head_num, const int seq_len, - const int dim_head) { + const int dim_head, + const bool neox_rotary_style) { // q_transpose_out_data [bs, head_num, seq_len, dim_head] -> [bs, head_num, // seq_len * rotary_emb_dims, dim_head / rotary_emb_dims] // kv_transpose_out_data [bs, head_num, seq_len, dim_head] -> [bs, head_num, @@ -1651,34 +1524,60 @@ void rotary_qk(const phi::GPUContext &dev_ctx, } }; int BlockSize = getBlockSize(last_dim / 2); - const T *cos_emb = rotary_emb; - const T *sin_emb = rotary_emb + batch_size * seq_len * dim_head; - RotrayKernel<<>>( - q_input, - cos_emb, - sin_emb, - sequence_lengths, - q, - rotary_emb_dims, - batch_size, - head_num, - seq_len * rotary_emb_dims, - last_dim); - RotrayKernel<<>>( - k_input, - cos_emb, - sin_emb, - sequence_lengths, - k, - rotary_emb_dims, - batch_size, - head_num, - seq_len * rotary_emb_dims, - last_dim); + const float *cos_emb = rotary_emb; + const float *sin_emb = rotary_emb + batch_size * seq_len * dim_head; + if (!neox_rotary_style) { + RotaryKernel<<>>( + q_input, + cos_emb, + sin_emb, + sequence_lengths, + q, + rotary_emb_dims, + batch_size, + head_num, + seq_len * rotary_emb_dims, + last_dim); + RotaryKernel<<>>( + k_input, + cos_emb, + sin_emb, + sequence_lengths, + k, + rotary_emb_dims, + batch_size, + head_num, + seq_len * rotary_emb_dims, + last_dim); + } else { + NeoXRotaryKernel<<>>( + q_input, + cos_emb, + sin_emb, + sequence_lengths, + q, + rotary_emb_dims, + batch_size, + head_num, + seq_len * rotary_emb_dims, + last_dim); + NeoXRotaryKernel<<>>( + k_input, + cos_emb, + sin_emb, + sequence_lengths, + k, + rotary_emb_dims, + batch_size, + head_num, + seq_len * rotary_emb_dims, + last_dim); + } } __global__ void GetPaddingOffset(int *d_token_num, int *padding_offset, + int *cu_seqlens_data, const int *sequence_lengths, const int batch_size, const int max_seq_len) { @@ -1686,6 +1585,7 @@ __global__ void GetPaddingOffset(int *d_token_num, int total_seq_len = 0; int cum_offset = 0; int index = 0; + cu_seqlens_data[0] = 0; for (int i = 0; i < batch_size; i++) { const int seq_len = sequence_lengths[i]; for 
(int j = 0; j < seq_len; j++) { @@ -1694,6 +1594,7 @@ __global__ void GetPaddingOffset(int *d_token_num, } cum_offset += max_seq_len - seq_len; total_seq_len += seq_len; + cu_seqlens_data[i + 1] = cu_seqlens_data[i] + seq_len; } d_token_num[0] = total_seq_len; } @@ -1702,11 +1603,16 @@ void InvokeGetPaddingOffset(const phi::GPUContext &dev_ctx, int *h_token_num, int *d_token_num, int *padding_offset, + int *cu_seqlens_data, const int *sequence_lengths, const int batch_size, const int max_seq_len) { - GetPaddingOffset<<<1, 1, 0, dev_ctx.stream()>>>( - d_token_num, padding_offset, sequence_lengths, batch_size, max_seq_len); + GetPaddingOffset<<<1, 1, 0, dev_ctx.stream()>>>(d_token_num, + padding_offset, + cu_seqlens_data, + sequence_lengths, + batch_size, + max_seq_len); phi::memory_utils::Copy(phi::CPUPlace(), h_token_num, dev_ctx.GetPlace(), @@ -1771,249 +1677,588 @@ void InvokeRebuildPadding(const phi::GPUContext &dev_ctx, output_data, input_data, padding_offset, dim_embed); } -#if CUDA_VERSION >= 11060 -// Only Used in Inference -template -class CublasFusedMLP { - public: - // (m, n, k) = bsz_seq, hidden_feature, in_feature - explicit CublasFusedMLP(const phi::GPUContext &dev_ctx) : dev_ctx_(dev_ctx) { - cudaDataType_t mat_type = CUDA_R_32F; - cudaDataType_t scale_type = CUDA_R_32F; - cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; - if (std::is_same::value) { - mat_type = CUDA_R_16F; - if (FLAGS_gemm_use_half_precision_compute_type) { - // This option default value is true, it tends to result NaN, but get - // better inference speed. you can turn off by using `export - // FLAGS_gemm_use_half_precision_compute_type=0`. - compute_type = CUBLAS_COMPUTE_16F; - scale_type = CUDA_R_16F; - } +template +__global__ void InitOutValueKernel(T *output_data, + const int64_t numel, + const T init_value) { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + int64_t global_thread_idx = bid * blockDim.x + tid; + + for (int linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < numel; + linear_index += step) { + for (int i = 0; i < VecSize; i++) { + output_data[linear_index + i] = init_value; } - if (std::is_same::value) { - mat_type = CUDA_R_16BF; + } +} + +template +void InitValue(const phi::GPUContext &dev_ctx, + T *output_data, + const int64_t numel, + const T init_value) { + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ( + numel % PackSize, + 0, + phi::errors::PreconditionNotMet( + "numel=%d must be divisible by vec_size=%d", numel, PackSize)); + const int pack_num = numel / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + InitOutValueKernel + <<>>( + output_data, numel, init_value); +} + +template +__global__ void ActFFNGlu(const T *bias, + Functor act_functor, + const int token_num, + const int hid_dim, + const int elem_num, + LoadFunc load_func, + StoreFunc store_func) { + using LoadT = phi::AlignedVector; + LoadT src_vec1; + LoadT src_vec2; + LoadT bias_vec1; + LoadT bias_vec2; + const int global_tid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = global_tid * VecSize; i < elem_num; + i += gridDim.x * blockDim.x * VecSize) { + int bi = i / hid_dim; + int idx = i % hid_dim; + // const T *input_this_thread = input + bi * hid_dim * 2; + // T *output_this_thread = output + bi * hid_dim; + // phi::Load(&input_this_thread[idx], &src_vec1); + // phi::Load(&input_this_thread[idx + hid_dim], &src_vec2); + + load_func.template load(&src_vec1, bi * hid_dim * 2 + 
idx); + load_func.template load(&src_vec2, + bi * hid_dim * 2 + idx + hid_dim); + + if (bias) { + phi::Load(&bias[idx], &bias_vec1); + phi::Load(&bias[idx + hid_dim], &bias_vec2); } - if (std::is_same::value) { - mat_type = CUDA_R_64F; - scale_type = CUDA_R_64F; - compute_type = CUBLAS_COMPUTE_64F; +#pragma unroll + for (int j = 0; j < VecSize; j++) { + if (bias) { + src_vec1[j] += bias_vec1[j]; + src_vec2[j] += bias_vec2[j]; + } + src_vec1[j] = act_functor(src_vec1[j]); + src_vec1[j] *= src_vec2[j]; } - - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescCreate( - &operation_desc_, compute_type, scale_type)); - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasLtMatrixLayoutCreate(&x_desc_, mat_type, 1, 1, 1)); - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasLtMatrixLayoutCreate(&w_desc_, mat_type, 1, 1, 1)); - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate( - &out_desc_, mat_type, 1, 1, 1)); - } - ~CublasFusedMLP() { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasLtMatmulDescDestroy(operation_desc_)); - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasLtMatrixLayoutDestroy(x_desc_)); - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasLtMatrixLayoutDestroy(w_desc_)); - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasLtMatrixLayoutDestroy(out_desc_)); + // phi::Store(src_vec1, &output_this_thread[idx]); + store_func.template store(src_vec1, bi * hid_dim + idx); } +} - void Setup(const phi::DDim &x_shape, - const phi::DDim &w_shape, - bool trans_x, - bool trans_w) { - int64_t M = trans_x ? x_shape[1] : x_shape[0]; - int64_t K = trans_w ? w_shape[1] : w_shape[0]; - int64_t N = trans_w ? w_shape[0] : w_shape[1]; - - cublasOperation_t cublas_transA = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t cublas_transB = trans_w ? 
CUBLAS_OP_T : CUBLAS_OP_N; - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( - operation_desc_, - CUBLASLT_MATMUL_DESC_TRANSB, - &cublas_transA, - sizeof(cublas_transA))); - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( - operation_desc_, - CUBLASLT_MATMUL_DESC_TRANSA, - &cublas_transB, - sizeof(cublas_transB))); - - SetCublasMatrixLayout(x_desc_, trans_x, M, K); - SetCublasMatrixLayout(w_desc_, trans_w, K, N); - SetCublasMatrixLayout(out_desc_, false, M, N); +template +void LaunchActFFNGlu(const phi::GPUContext &dev_ctx, + const T *bias, + const int token_num, + const int hid_dim, + LoadFunc load_func, + StoreFunc store_func) { + constexpr int VecSize = 16; + constexpr int PackSize = VecSize / sizeof(LoadT); + const int elem_cnt = token_num * hid_dim; + const int blocksize = 128; + int grid_size = 1; + Functor functor; + switch (hid_dim % PackSize) { + case 0: + GetNumBlocks(elem_cnt / PackSize, &grid_size); + ActFFNGlu + <<>>(bias, + functor, + token_num, + hid_dim, + elem_cnt, + load_func, + store_func); + break; + default: + GetNumBlocks(elem_cnt, &grid_size); + ActFFNGlu<<>>( + bias, functor, token_num, hid_dim, elem_cnt, load_func, store_func); + break; } +} - void ComputeForward(const phi::DenseTensor *x, - const phi::DenseTensor *weight, - const phi::DenseTensor *bias, - phi::DenseTensor *residual, - phi::DenseTensor *output, - const std::string &activation) { - T *out_data = output->data(); +template +__global__ void BiasAct(const T *bias, + Functor act_functor, + const int rows, + const int cols, + const int elem_num, + LoadFunc load_func, + StoreFunc store_func) { + using LoadT = phi::AlignedVector; + LoadT src_vec; + LoadT bias_vec; - const bool add_residual = (residual == nullptr) ? false : true; - const bool add_bias = (bias == nullptr) ? false : true; +// Zero Initialize BiasVec. +#pragma unroll + for (int unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) { + bias_vec[unroll_idx] = 0; + } - const T *bias_data = nullptr; - if (add_bias) { - bias_data = bias->data(); + const int global_tid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = global_tid * VecSize; i < elem_num; + i += gridDim.x * blockDim.x * VecSize) { + int row_idx = i / cols; + int col_idx = i % cols; + int linear_idx = row_idx * cols + col_idx; + // phi::Load(&input[linear_idx], &src_vec); + load_func.template load(&src_vec, linear_idx); + if (bias) { + phi::Load(&bias[col_idx], &bias_vec); } - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( - operation_desc_, - CUBLASLT_MATMUL_DESC_BIAS_POINTER, - &bias_data, - sizeof(bias_data))); - - cublasLtEpilogue_t epiloque_func = GetEpilogueType(activation, add_bias); - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( - operation_desc_, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &epiloque_func, - sizeof(epiloque_func))); - - T *residual_data = add_residual ? residual->data() : out_data; - - cublasLtHandle_t lt_handle = dev_ctx_.cublaslt_handle(); - size_t workspace_size = static_cast(4) * 1024 * 1024; - cudaStream_t stream = dev_ctx_.stream(); - phi::Allocator::AllocationPtr workspace = phi::memory_utils::Alloc( - dev_ctx_.GetPlace(), - workspace_size, - phi::Stream(reinterpret_cast(dev_ctx_.stream()))); - - // if add_residual, we compute result + 1.0 * residual, - // else result + 0.0 * out. - double alpha64 = 1.0, beta64 = add_residual ? 1.0 : 0.0; - float alpha32 = 1.0f, beta32 = add_residual ? 
1.0f : 0.0f; - half alpha16 = static_cast(1.0), - beta16 = - add_residual ? static_cast(1.0) : static_cast(0.0); - - void *alpha = &alpha32, *beta = &beta32; - if (std::is_same::value) { - alpha = &alpha64; - beta = &beta64; +#pragma unroll + for (int j = 0; j < VecSize; j++) { + if (bias) { + src_vec[j] += bias_vec[j]; + } + src_vec[j] = act_functor(src_vec[j]); } + // phi::Store(src_vec, &output[linear_idx]); + store_func.template store(src_vec, linear_idx); + } +} - if (std::is_same::value && - FLAGS_gemm_use_half_precision_compute_type) { - alpha = &alpha16; - beta = &beta16; - } +template +void LaunchBiasAct(const phi::GPUContext &dev_ctx, + const T *bias, + const int token_num, + const int hid_dim, + LoadFunc load_func, + StoreFunc store_func) { + constexpr int VecSize = 16; + constexpr int PackSize = VecSize / sizeof(LoadT); + const int elem_cnt = token_num * hid_dim; + const int blocksize = 128; + int grid_size = 1; + Functor functor; + switch (hid_dim % PackSize) { + case 0: + GetNumBlocks(elem_cnt / PackSize, &grid_size); + BiasAct + <<>>(bias, + functor, + token_num, + hid_dim, + elem_cnt, + load_func, + store_func); + break; + default: + GetNumBlocks(elem_cnt, &grid_size); + BiasAct<<>>( + bias, functor, token_num, hid_dim, elem_cnt, load_func, store_func); + break; + } +} + +template +__global__ void fused_transpose_split_kernel( + T *q_out, // [total, num_head, head_dim] + T *k_out, // [total, num_head, head_dim] + T *v_out, // [total, num_head, head_dim] + const T *q_input, // [bsz, num_head, seq_len, head_dim] + const T *kv_input, // [2, bsz, num_head, seq_len, head_dim] + const int *padding_offset, + const int *seq_lens, + const int32_t elem_cnt, + const int batch_size, + const int max_len_this_time, + const int seq_len, + const int token_num, + const int head_num, + const int size_per_head) { + const int32_t offset = + batch_size * max_len_this_time * head_num * size_per_head; + const int32_t hidden_size = head_num * size_per_head; + const int32_t fused_hidden_size = 3 * hidden_size; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + using LoadT = phi::AlignedVector; + LoadT src_vec; + LoadT bias_vec; - const auto *x_data = x->data(); - const auto *w_data = weight->data(); - - auto algo = phi::funcs::GemmEpilogueAlgoCache::Instance().GetGemmAlgo( - lt_handle, - operation_desc_, - w_desc_, - x_desc_, - out_desc_, - alpha, - beta, - w_data, - x_data, - out_data, - stream, - workspace->ptr(), - workspace_size); - - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmul(lt_handle, - operation_desc_, - alpha, - w_data, - w_desc_, - x_data, - x_desc_, - beta, - residual_data, - out_desc_, - out_data, - out_desc_, - algo, - workspace->ptr(), - workspace_size, - stream)); + int q_size = token_num * hidden_size; + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + int32_t bias_idx = linear_index % fused_hidden_size; + int32_t current_token = linear_index / fused_hidden_size; + const int32_t token_idx = linear_index / fused_hidden_size; + const int32_t ori_token_idx = + token_idx + (padding_offset == nullptr ? 
0 : padding_offset[token_idx]); + const int32_t target_batch_id = ori_token_idx / seq_len; + if (seq_lens[target_batch_id] == 0) continue; + const int32_t seq_id = ori_token_idx % seq_len; + + // equal to: + // const int qkv_id = (linear_index % fused_hidden_size) / hidden_size; + const int32_t qkv_id = bias_idx / hidden_size; + const int32_t head_id = (linear_index % hidden_size) / size_per_head; + const int32_t size_id = linear_index % size_per_head; + + if (qkv_id == 0) { // read q + phi::Load( + &q_input[target_batch_id * head_num * max_len_this_time * + size_per_head + + head_id * max_len_this_time * size_per_head + + seq_id * size_per_head + size_id], + &src_vec); + } else { // read k/v + const int32_t kv_store_offset = (qkv_id - 1) * offset; + phi::Load( + &kv_input[kv_store_offset + + target_batch_id * head_num * max_len_this_time * + size_per_head + + head_id * max_len_this_time * size_per_head + + seq_id * size_per_head + size_id], + &src_vec); + } + int32_t write_index = + linear_index - (qkv_id + 2 * current_token) * hidden_size; + if (qkv_id == 0) { + phi::Store(src_vec, &q_out[write_index]); + } else if (qkv_id == 1) { + phi::Store(src_vec, &k_out[write_index]); + } else if (qkv_id == 2) { + phi::Store(src_vec, &v_out[write_index]); + } } +} - private: - cublasLtEpilogue_t GetEpilogueType(const std::string &activation, - const bool add_bias) { - if (activation == "relu") { - if (add_bias) { - return CUBLASLT_EPILOGUE_RELU_BIAS; - } else { - return CUBLASLT_EPILOGUE_RELU; - } - } else if (activation == "gelu") { - if (add_bias) { - return CUBLASLT_EPILOGUE_GELU_BIAS; - } else { - return CUBLASLT_EPILOGUE_GELU; - } - } else if (activation == "none") { - if (add_bias) { - return CUBLASLT_EPILOGUE_BIAS; +template +void TransposeSplit(const phi::GPUContext &dev_ctx, + T *q_out, + T *k_out, + T *v_out, + const T *q_input, + const T *kv_input, + const int *padding_offset, + const int *seq_lens, + const int token_num, + const int batch_size, + const int head_num, + const int max_len_this_time, + const int seq_len, + const int size_per_head) { + const int32_t elem_cnt = token_num * head_num * size_per_head * 3; + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ(size_per_head % PackSize, + 0, + phi::errors::PreconditionNotMet( + "dim_head=%d must be divisible by vec_size=%d", + size_per_head, + PackSize)); + const int32_t pack_num = elem_cnt / PackSize; + const int32_t blocksize = 128; + int32_t grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + fused_transpose_split_kernel + <<>>(q_out, + k_out, + v_out, + q_input, + kv_input, + padding_offset, + seq_lens, + elem_cnt, + batch_size, + max_len_this_time, + seq_len, + token_num, + head_num, + size_per_head); +} + +template +void TransposeSplit(const phi::GPUContext &dev_ctx, + T *q_out, + T *k_out, + T *v_out, + const T *q_input, + const T *kv_input, + const int *padding_offset, + const int *seq_lens, + const int token_num, + const int batch_size, + const int head_num, + const int seq_len, + const int size_per_head) { + TransposeSplit(dev_ctx, + q_out, + k_out, + v_out, + q_input, + kv_input, + padding_offset, + seq_lens, + token_num, + batch_size, + head_num, + seq_len, + seq_len, + size_per_head); +} + +template +__global__ void VariableLengthRotaryKernel( + const T *qkv, + const float *cos_emb, // [1, 1, seq_len, dim_head / 2] + const float *sin_emb, + const int *padding_offsets, + const int *seq_lens, + const T *qkv_biases, + T *qkv_out, + const int64_t elem_cnt, + const int num_head, + const int seq_len, + const 
int last_dim) { + using LoadT = phi::AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = phi::AlignedVector; + LoadT src_vec; + LoadT bias_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int half_lastdim = last_dim / 2; + const int hidden_size = num_head * last_dim; + const int offset = 3 * hidden_size; + for (int64_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_idx = linear_index / offset; + const int ori_token_idx = token_idx + padding_offsets[token_idx]; + const int ori_bi = ori_token_idx / seq_len; + if (seq_lens[ori_bi] == 0) continue; + const int bias = linear_index % offset; + const int qkv_id = bias / hidden_size; + const int qkv_bias = bias % hidden_size; + const int hi = qkv_bias / last_dim; + const int h_bias = qkv_bias % last_dim; + + const int ori_seq_id = ori_token_idx % seq_len; + + const int64_t emb_idx = + ori_bi * seq_len * last_dim + ori_seq_id * last_dim + h_bias; + const int64_t bias_idx = qkv_id * hidden_size + hi * last_dim + h_bias; + const int64_t base_idx = token_idx * 3 * hidden_size + bias_idx; + phi::Load(&qkv[base_idx], &src_vec); + phi::Load(&qkv_biases[bias_idx], &bias_vec); + phi::Load(&cos_emb[emb_idx], &cos_emb_vec); + phi::Load(&sin_emb[emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + const float input_left = + static_cast(src_vec[2 * i] + bias_vec[2 * i]); + const float input_right = + static_cast(src_vec[2 * i + 1] + bias_vec[2 * i + 1]); + // const float cos_tmp = cos_emb_vec[i]; + // const float sin_tmp = sin_emb_vec[i]; + // src_vec[2 * i] = static_cast(input_left * cos_tmp - input_right * + // sin_tmp); src_vec[2 * i + 1] = static_cast(input_right * cos_tmp + + // input_left * sin_tmp); + + if (qkv_id < 2) { // qk rope + const float cos_tmp = cos_emb_vec[2 * i]; + const float sin_tmp = sin_emb_vec[2 * i]; + src_vec[2 * i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + src_vec[2 * i + 1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); } else { - return CUBLASLT_EPILOGUE_DEFAULT; + src_vec[2 * i] = static_cast(input_left); + src_vec[2 * i + 1] = static_cast(input_right); } - } else { - PADDLE_ENFORCE_EQ( - true, - false, - phi::errors::InvalidArgument( - "The activation attribute of fused_gemm_epilogue op should be" - " one of {\"none\", \"relu\", \"gelu\"}. But received %s." 
- "But received activation=%s.", - activation)); } + phi::Store(src_vec, &qkv_out[base_idx]); } +} - void SetCublasMatrixLayout(cublasLtMatrixLayout_t layout_desc, - const bool transpose, - const uint64_t cublas_row, - const uint64_t cublas_col) { - cudaDataType_t mat_type = CUDA_R_32F; - if (std::is_same::value) { - mat_type = CUDA_R_16F; - } - if (std::is_same::value) { - mat_type = CUDA_R_16BF; - } - if (std::is_same::value) { - mat_type = CUDA_R_64F; +template +void rotary_qk_variable( + const phi::GPUContext &dev_ctx, + T *qkv, // [token_num, 3, num_head, dim_head] + const T *qkv_input, // qkv + const T *qkv_bias, + const float *rotary_emb, // [2, bs, 1, seq_len, dim_head] + const int *padding_offsets, + const int *seq_lens, + const int token_num, + const int head_num, + const int seq_len, + const int input_output_len, + const int dim_head, + const int rope_bsz) { + const int elem_nums = token_num * 3 * head_num * dim_head; // just q and k + constexpr int PackSize = 16 / sizeof(T); + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + const float *cos_emb = rotary_emb; + const float *sin_emb = rotary_emb + rope_bsz * input_output_len * dim_head; + + VariableLengthRotaryKernel + <<>>(qkv_input, + cos_emb, + sin_emb, + padding_offsets, + seq_lens, + qkv_bias, + qkv, + elem_nums, + head_num, + seq_len, + dim_head); +} + +template +__global__ void GQAVariableLengthRotaryKernel( + const T *qkv, + const float *cos_emb, // [1, 1, seq_len, dim_head / 2] + const float *sin_emb, + const int *padding_offsets, + const int *seq_lens, + const T *qkv_biases, + T *qkv_out, + const int64_t elem_cnt, + const int num_head, + const int seq_len, + const int last_dim, + const int gqa_group_size) { + using LoadT = phi::AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = phi::AlignedVector; + LoadT src_vec; + LoadT bias_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int half_lastdim = last_dim / 2; + // const int hidden_size = num_head * last_dim; + const int offset = (num_head + 2 * gqa_group_size) * last_dim; + for (int64_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_idx = linear_index / offset; + const int ori_token_idx = token_idx + padding_offsets[token_idx]; + const int ori_bi = ori_token_idx / seq_len; + if (seq_lens[ori_bi] == 0) continue; + const int bias = linear_index % offset; + const int hi = bias / last_dim; + const int h_bias = bias % last_dim; + + const int ori_seq_id = ori_token_idx % seq_len; + + const int64_t emb_idx = + ori_bi * seq_len * last_dim + ori_seq_id * last_dim + h_bias; + const int64_t bias_idx = hi * last_dim + h_bias; + const int64_t base_idx = token_idx * offset + bias_idx; + phi::Load(&qkv[base_idx], &src_vec); + phi::Load(&qkv_biases[bias_idx], &bias_vec); + phi::Load(&cos_emb[emb_idx], &cos_emb_vec); + phi::Load(&sin_emb[emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + const float input_left = + static_cast(src_vec[2 * i] + bias_vec[2 * i]); + const float input_right = + static_cast(src_vec[2 * i + 1] + bias_vec[2 * i + 1]); + // const float cos_tmp = cos_emb_vec[i]; + // const float sin_tmp = sin_emb_vec[i]; + // src_vec[2 * i] = static_cast(input_left * cos_tmp - input_right * + // sin_tmp); src_vec[2 * i + 1] = static_cast(input_right * 
cos_tmp + + // input_left * sin_tmp); + + if (hi < num_head + gqa_group_size) { // qk rope + const float cos_tmp = cos_emb_vec[2 * i]; + const float sin_tmp = sin_emb_vec[2 * i]; + src_vec[2 * i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + src_vec[2 * i + 1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + src_vec[2 * i] = static_cast(input_left); + src_vec[2 * i + 1] = static_cast(input_right); + } } - - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutSetAttribute( - layout_desc, CUBLASLT_MATRIX_LAYOUT_TYPE, &mat_type, sizeof(mat_type))); - - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutSetAttribute( - layout_desc, - CUBLASLT_MATRIX_LAYOUT_ROWS, - transpose ? &cublas_row : &cublas_col, - sizeof(cublas_row))); - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutSetAttribute( - layout_desc, - CUBLASLT_MATRIX_LAYOUT_COLS, - transpose ? &cublas_col : &cublas_row, - sizeof(cublas_col))); - int64_t cublas_ld = transpose ? cublas_row : cublas_col; - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutSetAttribute( - layout_desc, CUBLASLT_MATRIX_LAYOUT_LD, &cublas_ld, sizeof(cublas_ld))); + phi::Store(src_vec, &qkv_out[base_idx]); } +} - const phi::GPUContext &dev_ctx_; - cublasLtMatmulDesc_t operation_desc_ = NULL; - cublasLtMatrixLayout_t x_desc_ = NULL; - cublasLtMatrixLayout_t w_desc_ = NULL; - cublasLtMatrixLayout_t out_desc_ = NULL; -}; - -#endif // PADDLE_FLUID_OPERATORS_FUSED_FUSED_MULTI_TRANSFORMER_OP_CU_H_ +template +void gqa_rotary_qk_variable( + const phi::GPUContext &dev_ctx, + T *qkv, // [token_num, 3, num_head, dim_head] + const T *qkv_input, // qkv + const T *qkv_bias, + const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2] + const int *padding_offsets, + const int *seq_lens, + const int token_num, + const int head_num, + const int seq_len, + const int input_output_len, + const int dim_head, + const int gqa_group_size, + const int rope_bsz) { + const int elem_nums = + token_num * (head_num + 2 * gqa_group_size) * dim_head; // for all q k v + constexpr int PackSize = 16 / sizeof(T); + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + const float *cos_emb = rotary_emb; + const float *sin_emb = rotary_emb + rope_bsz * input_output_len * dim_head; + GQAVariableLengthRotaryKernel + <<>>(qkv_input, + cos_emb, + sin_emb, + padding_offsets, + seq_lens, + qkv_bias, + qkv, + elem_nums, + head_num, + seq_len, + dim_head, + gqa_group_size); +} } // namespace diff --git a/paddle/fluid/operators/ops_signature/fused_multi_transformer_sig.cc b/paddle/fluid/operators/ops_signature/fused_multi_transformer_sig.cc index 184df326b79e8..8aa4b0c8b6b10 100644 --- a/paddle/fluid/operators/ops_signature/fused_multi_transformer_sig.cc +++ b/paddle/fluid/operators/ops_signature/fused_multi_transformer_sig.cc @@ -20,25 +20,13 @@ KernelSignature FusedMultiTransformerOpArgumentMapping( const ArgumentMappingContext& ctx UNUSED) { return KernelSignature("fused_multi_transformer", { - "X", - "LnScale", - "LnBias", - "QKVW", - "QKVBias", - "CacheKV", - "PreCaches", - "RotaryPosEmb", - "TimeStep", - "SeqLengths", - "SrcMask", - "OutLinearW", - "OutLinearBias", - "FFNLnScale", - "FFNLnBias", - "FFN1Weight", - "FFN1Bias", - "FFN2Weight", - "FFN2Bias", + "X", "LnScale", "LnBias", + "QKVW", "QKVBias", "CacheKV", + "PreCaches", "RotaryPosEmb", "BeamCacheOffset", + "TimeStep", "SeqLengths", "SrcMask", + "OutLinearW", "OutLinearBias", 
"FFNLnScale", + "FFNLnBias", "FFN1Weight", "FFN1Bias", + "FFN2Weight", "FFN2Bias", }, {"pre_layer_norm", "epsilon", @@ -48,7 +36,11 @@ KernelSignature FusedMultiTransformerOpArgumentMapping( "dropout_implementation", "act_method", "trans_qkvw", - "ring_id"}, + "ring_id", + "residual_alpha", + "norm_type", + "use_neox_rotary_style", + "gqa_group_size"}, {"CacheKVOut", "Out"}); } diff --git a/paddle/fluid/platform/dynload/cublasLt.h b/paddle/fluid/platform/dynload/cublasLt.h index c3425ac604858..ac3944a7ce12c 100644 --- a/paddle/fluid/platform/dynload/cublasLt.h +++ b/paddle/fluid/platform/dynload/cublasLt.h @@ -38,8 +38,8 @@ namespace dynload { using DynLoad__##__name = phi::dynload::DynLoad__##__name; \ extern DynLoad__##__name __name -// APIs available after CUDA 10.1 -// #if CUDA_VERSION >= 10100 +// APIs available after CUDA 11.1 +#if CUDA_VERSION >= 11010 #define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ __macro(cublasLtCreate); \ __macro(cublasLtDestroy); \ @@ -61,7 +61,33 @@ namespace dynload { __macro(cublasLtMatrixTransformDescDestroy); \ __macro(cublasLtMatrixTransformDescSetAttribute); \ __macro(cublasLtMatmulAlgoInit); \ - __macro(cublasLtMatmulAlgoConfigSetAttribute); + __macro(cublasLtMatmulAlgoConfigSetAttribute); \ + __macro(cublasLtMatmulAlgoGetIds); \ + __macro(cublasLtMatmulAlgoCapGetAttribute); \ + __macro(cublasLtMatmulAlgoCheck); \ + __macro(cublasLtGetCudartVersion); +#else +#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasLtCreate); \ + __macro(cublasLtDestroy); \ + __macro(cublasLtMatmul); \ + __macro(cublasLtMatmulDescCreate); \ + __macro(cublasLtMatmulDescDestroy); \ + __macro(cublasLtMatmulDescSetAttribute); \ + __macro(cublasLtMatmulDescGetAttribute); \ + __macro(cublasLtMatrixLayoutCreate); \ + __macro(cublasLtMatrixLayoutDestroy); \ + __macro(cublasLtMatrixLayoutSetAttribute); \ + __macro(cublasLtMatrixLayoutGetAttribute); \ + __macro(cublasLtMatmulPreferenceCreate); \ + __macro(cublasLtMatmulPreferenceDestroy); \ + __macro(cublasLtMatmulPreferenceSetAttribute); \ + __macro(cublasLtMatmulAlgoGetHeuristic); \ + __macro(cublasLtMatrixTransform); \ + __macro(cublasLtMatrixTransformDescCreate); \ + __macro(cublasLtMatrixTransformDescDestroy); \ + __macro(cublasLtMatrixTransformDescSetAttribute); +#endif CUBLASLT_BLAS_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) // #endif diff --git a/paddle/fluid/pybind/eager_generator.h b/paddle/fluid/pybind/eager_generator.h index 2d218b7d352be..a2f2eca47d219 100644 --- a/paddle/fluid/pybind/eager_generator.h +++ b/paddle/fluid/pybind/eager_generator.h @@ -59,25 +59,13 @@ std::map> op_ins_map = { "OutLinearWeight", "OutLinearBias"}}, {"fused_multi_transformer", - {"X", - "LnScale", - "LnBias", - "QKVW", - "QKVBias", - "CacheKV", - "PreCaches", - "RotaryPosEmb", - "TimeStep", - "SeqLengths", - "SrcMask", - "OutLinearW", - "OutLinearBias", - "FFNLnScale", - "FFNLnBias", - "FFN1Weight", - "FFN1Bias", - "FFN2Weight", - "FFN2Bias"}}, + {"X", "LnScale", "LnBias", + "QKVW", "QKVBias", "CacheKV", + "PreCaches", "RotaryPosEmb", "BeamCacheOffset", + "TimeStep", "SeqLengths", "SrcMask", + "OutLinearW", "OutLinearBias", "FFNLnScale", + "FFNLnBias", "FFN1Weight", "FFN1Bias", + "FFN2Weight", "FFN2Bias"}}, {"fused_multi_transformer_int8", {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "TimeStep", "SrcMask", diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index d679bd7ab2d88..0b3d79b6e4ea4 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ 
b/paddle/fluid/pybind/inference_api.cc @@ -184,6 +184,9 @@ py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) { case PaddleDType::FLOAT16: dt = py::dtype::of(); break; + case PaddleDType::BFLOAT16: + dt = py::dtype::of(); + break; case PaddleDType::UINT8: dt = py::dtype::of(); break; @@ -196,7 +199,7 @@ py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) { default: PADDLE_THROW(platform::errors::Unimplemented( "Unsupported data type. Now only supports INT32, INT64, FLOAT64, " - "FLOAT32, FLOAT16, INT8, UINT8 and BOOL.")); + "FLOAT32, FLOAT16, BFLOAT16, INT8, UINT8 and BOOL.")); } return dt; @@ -386,6 +389,9 @@ size_t PaddleGetDTypeSize(PaddleDType dt) { case PaddleDType::FLOAT16: size = sizeof(phi::dtype::float16); break; + case PaddleDType::BFLOAT16: + size = sizeof(phi::dtype::bfloat16); + break; case PaddleDType::INT8: size = sizeof(int8_t); break; @@ -398,7 +404,7 @@ size_t PaddleGetDTypeSize(PaddleDType dt) { default: PADDLE_THROW(platform::errors::Unimplemented( "Unsupported data t ype. Now only supports INT32, INT64, FLOAT64, " - "FLOAT32, FLOAT16, INT8, UINT8 and BOOL.")); + "FLOAT32, FLOAT16, BFLOAT16, INT8, UINT8 and BOOL.")); } return size; } @@ -426,6 +432,10 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) { // NOLINT tensor.copy_to_cpu( static_cast(array.mutable_data())); break; + case PaddleDType::BFLOAT16: + tensor.copy_to_cpu( + static_cast(array.mutable_data())); + break; case PaddleDType::UINT8: tensor.copy_to_cpu(static_cast(array.mutable_data())); break; @@ -438,7 +448,7 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) { // NOLINT default: PADDLE_THROW(platform::errors::Unimplemented( "Unsupported data type. Now only supports INT32, INT64, FLOAT64, " - "FLOAT32, FLOAT16, INT8, UINT8 and BOOL.")); + "FLOAT32, FLOAT16, BFLOAT16, INT8, UINT8 and BOOL.")); } return array; } @@ -466,6 +476,10 @@ py::array PaddleInferTensorToNumpy(paddle_infer::Tensor &tensor) { // NOLINT tensor.CopyToCpu( static_cast(array.mutable_data())); break; + case PaddleDType::BFLOAT16: + tensor.CopyToCpu( + static_cast(array.mutable_data())); + break; case PaddleDType::UINT8: tensor.CopyToCpu(static_cast(array.mutable_data())); break; @@ -478,7 +492,7 @@ py::array PaddleInferTensorToNumpy(paddle_infer::Tensor &tensor) { // NOLINT default: PADDLE_THROW(platform::errors::Unimplemented( "Unsupported data t ype. 
Now only supports INT32, INT64, FLOAT64, " - "FLOAT32, FLOAT16, INT8, UINT8 and BOOL.")); + "FLOAT32, FLOAT16, BFLOAT16, INT8, UINT8 and BOOL.")); } return array; } @@ -561,6 +575,7 @@ void BindPaddleDType(py::module *m) { .value("FLOAT64", PaddleDType::FLOAT64) .value("FLOAT32", PaddleDType::FLOAT32) .value("FLOAT16", PaddleDType::FLOAT16) + .value("BFLOAT16", PaddleDType::BFLOAT16) .value("INT64", PaddleDType::INT64) .value("INT32", PaddleDType::INT32) .value("UINT8", PaddleDType::UINT8) diff --git a/paddle/phi/backends/dynload/cublasLt.h b/paddle/phi/backends/dynload/cublasLt.h index 5b05ee644f6c5..5da900f7515db 100644 --- a/paddle/phi/backends/dynload/cublasLt.h +++ b/paddle/phi/backends/dynload/cublasLt.h @@ -52,8 +52,8 @@ extern void *cublasLt_dso_handle; }; \ extern DynLoad__##__name __name -// APIs available after CUDA 10.1 -// #if CUDA_VERSION >= 10100 +// APIs available after CUDA 11.1 +#if CUDA_VERSION >= 11010 #define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ __macro(cublasLtCreate); \ __macro(cublasLtDestroy); \ @@ -75,7 +75,33 @@ extern void *cublasLt_dso_handle; __macro(cublasLtMatrixTransformDescDestroy); \ __macro(cublasLtMatrixTransformDescSetAttribute); \ __macro(cublasLtMatmulAlgoInit); \ - __macro(cublasLtMatmulAlgoConfigSetAttribute); + __macro(cublasLtMatmulAlgoConfigSetAttribute); \ + __macro(cublasLtMatmulAlgoGetIds); \ + __macro(cublasLtMatmulAlgoCapGetAttribute); \ + __macro(cublasLtMatmulAlgoCheck); \ + __macro(cublasLtGetCudartVersion); +#else +#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasLtCreate); \ + __macro(cublasLtDestroy); \ + __macro(cublasLtMatmul); \ + __macro(cublasLtMatmulDescCreate); \ + __macro(cublasLtMatmulDescDestroy); \ + __macro(cublasLtMatmulDescSetAttribute); \ + __macro(cublasLtMatmulDescGetAttribute); \ + __macro(cublasLtMatrixLayoutCreate); \ + __macro(cublasLtMatrixLayoutDestroy); \ + __macro(cublasLtMatrixLayoutSetAttribute); \ + __macro(cublasLtMatrixLayoutGetAttribute); \ + __macro(cublasLtMatmulPreferenceCreate); \ + __macro(cublasLtMatmulPreferenceDestroy); \ + __macro(cublasLtMatmulPreferenceSetAttribute); \ + __macro(cublasLtMatmulAlgoGetHeuristic); \ + __macro(cublasLtMatrixTransform); \ + __macro(cublasLtMatrixTransformDescCreate); \ + __macro(cublasLtMatrixTransformDescDestroy); \ + __macro(cublasLtMatrixTransformDescSetAttribute); +#endif CUBLASLT_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) // #endif diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index ed946f4377a8f..9987524d4997d 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -139,6 +139,7 @@ void FusedMultiTransformerInferMeta( const paddle::optional>& cache_kvs, const paddle::optional>& pre_caches, const MetaTensor& rotary_tensor, + const MetaTensor& beam_offset, const MetaTensor& time_step, const MetaTensor& seq_lengths, const MetaTensor& src_mask, @@ -152,6 +153,7 @@ void FusedMultiTransformerInferMeta( const paddle::optional>& ffn2_biases, bool pre_layer_norm, float epsilon, + float residual_alpha, float dropout_rate, int rotary_emb_dims, bool is_test, @@ -159,6 +161,9 @@ void FusedMultiTransformerInferMeta( const std::string& act_method, bool trans_qkvw, int ring_id, + const std::string& norm_type, + bool use_neox_rotary_style, + int gqa_group_size, std::vector cache_kv_outs, MetaTensor* out) { // x: qkv's input [batch_size, seq_len, dim_embed] diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index fd72ad88ce87f..aa48f64434ee3 100644 --- 
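Note on the BFLOAT16 plumbing above: the inference API stores bfloat16 values in uint16-sized buffers because NumPy has no native bfloat16 dtype; a bfloat16 value is simply the upper 16 bits of the corresponding float32, which is also why the Python-side dtype allow-lists elsewhere in this patch accept 'uint16'. A small illustrative round-trip in NumPy (values chosen by hand, not taken from the patch):

    import numpy as np

    # bfloat16 keeps the top 16 bits of a float32, so a uint16 buffer widens
    # back to float32 by shifting into the high half-word.
    bf16_bits = np.array([0x3F80, 0x4000, 0xC040], dtype=np.uint16)  # 1.0, 2.0, -3.0
    f32 = (bf16_bits.astype(np.uint32) << np.uint32(16)).view(np.float32)
    print(f32)  # [ 1.  2. -3.]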
a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -31,6 +31,7 @@ void FusedMultiTransformerInferMeta( const paddle::optional>& cache_kvs, const paddle::optional>& pre_caches, const MetaTensor& rotary_tensor, + const MetaTensor& beam_offset, const MetaTensor& time_step, const MetaTensor& seq_lengths, const MetaTensor& src_mask, @@ -44,6 +45,7 @@ void FusedMultiTransformerInferMeta( const paddle::optional>& ffn2_biases, bool pre_layer_norm, float epsilon, + float residual_alpha, float dropout_rate, int rotary_emb_dims, bool is_test, @@ -51,6 +53,9 @@ void FusedMultiTransformerInferMeta( const std::string& act_method, bool trans_qkvw, int ring_id, + const std::string& norm_type, + bool use_neox_rotary_style, + int gqa_group_size, std::vector cache_kv_outs, MetaTensor* out); diff --git a/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h b/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h index 7d4e0c81198b1..d6b9912d56bd3 100644 --- a/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h +++ b/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h @@ -628,6 +628,14 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { template inline __device__ void mul_pointer_v2(T* c, float a, IntT* b); +template <> +inline __device__ void mul_pointer_v2(float4* c, float a, uint8_t* b) { + c->x = a * (static_cast(b[0]) - 128.0); + c->y = a * (static_cast(b[1]) - 128.0); + c->z = a * (static_cast(b[2]) - 128.0); + c->w = a * (static_cast(b[3]) - 128.0); +} + template <> inline __device__ void mul_pointer_v2(float4* c, float a, uint32_t* b) { uint8_t* b_tmp = reinterpret_cast(b); @@ -637,6 +645,12 @@ inline __device__ void mul_pointer_v2(float4* c, float a, uint32_t* b) { c->w = a * (static_cast(b_tmp[3]) - 128.0); } +template <> +inline __device__ void mul_pointer_v2(float2* c, float a, uint8_t* b) { + c->x = a * (static_cast(b[0]) - 128.0); + c->y = a * (static_cast(b[1]) - 128.0); +} + template <> inline __device__ void mul_pointer_v2(float* c, float a, uint8_t* b) { c[0] = a * (static_cast(b[0]) - 128.0); @@ -667,6 +681,17 @@ inline __device__ void convert_(float16* result, uint32_t const& source) { #endif } +template <> +inline __device__ void mul_pointer_v2(uint32_t* c, float a, uint8_t* b) { + float16* tmp_fp16 = reinterpret_cast(c); + float16 a_prime = static_cast(a); + float16 offset = static_cast(128.0); +#pragma unroll + for (int i = 0; i < 2; ++i) { + tmp_fp16[i] = a_prime * (static_cast(b[i]) - offset); + } +} + // float16 * 2 <- uint8_t * 2 template <> inline __device__ void mul_pointer_v2(uint32_t* c, float a, uint16_t* b) { @@ -696,6 +721,17 @@ inline __device__ void mul_pointer_v2(uint32_t* c, float a, uint16_t* b) { #endif } +template <> +inline __device__ void mul_pointer_v2(uint2* c, float a, uint8_t* b) { + float16* tmp_fp16 = reinterpret_cast(c); + float16 a_prime = static_cast(a); + float16 offset = static_cast(128.0); +#pragma unroll + for (int i = 0; i < 4; ++i) { + tmp_fp16[i] = a_prime * (static_cast(b[i]) - offset); + } +} + // float16 * 4 <- uint8_t * 4 template <> inline __device__ void mul_pointer_v2(uint2* c, float a, uint32_t* b) { @@ -707,6 +743,18 @@ inline __device__ void mul_pointer_v2(uint2* c, float a, uint32_t* b) { c_prime[i] *= a_prime; } } + +template <> +inline __device__ void mul_pointer_v2(uint4* c, float a, uint8_t* b) { + float16* tmp_fp16 = reinterpret_cast(c); + float16 a_prime = static_cast(a); + float16 offset = static_cast(128.0); +#pragma unroll + for (int i = 0; i < 8; ++i) { + tmp_fp16[i] = a_prime * (static_cast(b[i]) - offset); + } +} + // float16 * 8 <- 
uint8_t * 8 template <> inline __device__ void mul_pointer_v2(uint4* c, float a, uint64_t* b) { @@ -750,6 +798,19 @@ inline __device__ static void convert_(__nv_bfloat16* result, #endif } +template <> +inline __device__ void mul_pointer_v2(__nv_bfloat162* c, float a, uint8_t* b) { +#if __CUDA_ARCH__ >= 800 + __nv_bfloat16 a_prime = static_cast<__nv_bfloat16>(a); + __nv_bfloat16* c_prime = reinterpret_cast<__nv_bfloat16*>(c); + convert_(c_prime, static_cast(*reinterpret_cast(b))); +#pragma unroll + for (int i = 0; i < 2; ++i) { + c_prime[i] *= a_prime; + } +#endif +} + template <> inline __device__ void mul_pointer_v2(__nv_bfloat162* c, float a, uint16_t* b) { using Packed_Int8_t = typename packed_type::type; @@ -779,6 +840,19 @@ inline __device__ void mul_pointer_v2(__nv_bfloat162* c, float a, uint16_t* b) { c->y = c->y * scale; } +template <> +inline __device__ void mul_pointer_v2(bf16_4_t* c, float a, uint8_t* b) { +#if __CUDA_ARCH__ >= 800 + __nv_bfloat16 a_prime = static_cast<__nv_bfloat16>(a); + __nv_bfloat16* c_prime = reinterpret_cast<__nv_bfloat16*>(c); + convert_(c_prime, *reinterpret_cast(b)); +#pragma unroll + for (int i = 0; i < 4; ++i) { + c_prime[i] *= a_prime; + } +#endif +} + template <> inline __device__ void mul_pointer_v2(bf16_4_t* c, float a, uint32_t* b) { __nv_bfloat16 a_prime = static_cast<__nv_bfloat16>(a); @@ -790,6 +864,15 @@ inline __device__ void mul_pointer_v2(bf16_4_t* c, float a, uint32_t* b) { } } +template <> +inline __device__ void mul_pointer_v2(bf16_8_t* c, float a, uint8_t* b) { + bf16_4_t* tmp_c = reinterpret_cast(c); +#pragma unroll + for (int i = 0; i < 2; ++i) { + mul_pointer_v2(tmp_c + i, a, b + 4 * i); + } +} + template <> inline __device__ void mul_pointer_v2(bf16_8_t* c, float a, uint64_t* b) { bf16_4_t* tmp_c = reinterpret_cast(c); diff --git a/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml b/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml index f27c19984c599..e76f9b444e2cf 100755 --- a/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml +++ b/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml @@ -600,8 +600,8 @@ optional: reserve_space - op : fused_multi_transformer - args : (Tensor x, Tensor[] ln_scales, Tensor[] ln_biases, Tensor[] qkv_weights, Tensor[] qkv_biases, Tensor[] cache_kvs, Tensor[] pre_caches, Tensor rotary_tensor, Tensor time_step, Tensor seq_lengths, Tensor src_mask, Tensor[] out_linear_weights, Tensor[] out_linear_biases, Tensor[] ffn_ln_scales, Tensor[] ffn_ln_biases, Tensor[] ffn1_weights, Tensor[] ffn1_biases, Tensor[] ffn2_weights, Tensor[] ffn2_biases, bool pre_layer_norm = true, float epsilon = 1e-5, float dropout_rate = .5f, int rotary_emb_dims = 0, bool is_test = false, str dropout_implementation = "downgrade_in_infer", str act_method = "gelu", bool trans_qkvw =true, int ring_id = -1) - optional : qkv_biases, cache_kvs, pre_caches, rotary_tensor, time_step, seq_lengths, src_mask, out_linear_biases, ffn1_biases, ffn2_biases, cache_kv_outs + args : (Tensor x, Tensor[] ln_scales, Tensor[] ln_biases, Tensor[] qkv_weights, Tensor[] qkv_biases, Tensor[] cache_kvs, Tensor[] pre_caches, Tensor rotary_tensor, Tensor beam_offset, Tensor time_step, Tensor seq_lengths, Tensor src_mask, Tensor[] out_linear_weights, Tensor[] out_linear_biases, Tensor[] ffn_ln_scales, Tensor[] ffn_ln_biases, Tensor[] ffn1_weights, Tensor[] ffn1_biases, Tensor[] ffn2_weights, Tensor[] ffn2_biases, bool pre_layer_norm = true, float epsilon = 1e-5, float residual_alpha = 1.0f, float dropout_rate = .5f, int rotary_emb_dims = 0, bool is_test = false, 
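Note on the new uint8 overloads of mul_pointer_v2 above: they all dequantize with a per-call scale and a fixed zero point of 128, i.e. c = a * (b - 128), unrolled over 2, 4 or 8 lanes for the float, float16 and bfloat16 packs. The same arithmetic in NumPy, with a hypothetical scale:

    import numpy as np

    scale = 0.05                                    # hypothetical dequantization scale ("a")
    q = np.array([0, 100, 128, 255], dtype=np.uint8)
    deq = scale * (q.astype(np.float32) - 128.0)    # zero point 128 recenters uint8 around 0
    print(deq)  # approximately [-6.4, -1.4, 0.0, 6.35]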
str dropout_implementation = "downgrade_in_infer", str act_method = "gelu", bool trans_qkvw =true, int ring_id = -1, str norm_type = "layernorm", bool use_neox_rotary_style=true, int gqa_group_size=-1) + optional : qkv_biases, cache_kvs, pre_caches, rotary_tensor, beam_offset, time_step, seq_lengths, src_mask, out_linear_biases, ffn1_biases, ffn2_biases, cache_kv_outs output : Tensor[](cache_kv_outs){out_linear_weights.size()}, Tensor(out) infer_meta : func : FusedMultiTransformerInferMeta diff --git a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml index 5202b5243f072..b948415ea318d 100644 --- a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml +++ b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml @@ -803,8 +803,8 @@ backward : fused_bn_add_activation_grad - op : fused_multi_transformer - args : (Tensor x, Tensor[] ln_scales, Tensor[] ln_biases, Tensor[] qkv_weights, Tensor[] qkv_biases, Tensor[] cache_kvs, Tensor[] pre_caches, Tensor rotary_tensor, Tensor time_step, Tensor seq_lengths, Tensor src_mask, Tensor[] out_linear_weights, Tensor[] out_linear_biases, Tensor[] ffn_ln_scales, Tensor[] ffn_ln_biases, Tensor[] ffn1_weights, Tensor[] ffn1_biases, Tensor[] ffn2_weights, Tensor[] ffn2_biases, bool pre_layer_norm = true, float epsilon = 1e-5, float dropout_rate = .5f, int rotary_emb_dims = 0, bool is_test = false, str dropout_implementation = "downgrade_in_infer", str act_method = "gelu", bool trans_qkvw =true, int ring_id = -1) - optional : qkv_biases, cache_kvs, pre_caches, rotary_tensor, time_step, seq_lengths, src_mask, out_linear_biases, ffn1_biases, ffn2_biases, cache_kv_outs + args : (Tensor x, Tensor[] ln_scales, Tensor[] ln_biases, Tensor[] qkv_weights, Tensor[] qkv_biases, Tensor[] cache_kvs, Tensor[] pre_caches, Tensor rotary_tensor, Tensor beam_offset, Tensor time_step, Tensor seq_lengths, Tensor src_mask, Tensor[] out_linear_weights, Tensor[] out_linear_biases, Tensor[] ffn_ln_scales, Tensor[] ffn_ln_biases, Tensor[] ffn1_weights, Tensor[] ffn1_biases, Tensor[] ffn2_weights, Tensor[] ffn2_biases, bool pre_layer_norm = true, float epsilon = 1e-5, float residual_alpha = 1.0f, float dropout_rate = .5f, int rotary_emb_dims = 0, bool is_test = false, str dropout_implementation = "downgrade_in_infer", str act_method = "gelu", bool trans_qkvw = true, int ring_id = -1, str norm_type = "layernorm", bool use_neox_rotary_style=true, int gqa_group_size=-1) + optional : qkv_biases, cache_kvs, pre_caches, rotary_tensor, beam_offset, time_step, seq_lengths, src_mask, out_linear_biases, ffn1_biases, ffn2_biases, cache_kv_outs output : Tensor[](cache_kv_outs){out_linear_weights.size()}, Tensor(out) infer_meta : func : FusedMultiTransformerInferMeta diff --git a/python/paddle/distributed/communication/stream/all_to_all.py b/python/paddle/distributed/communication/stream/all_to_all.py index d63456d6d83dd..9c24d71cb0f7d 100644 --- a/python/paddle/distributed/communication/stream/all_to_all.py +++ b/python/paddle/distributed/communication/stream/all_to_all.py @@ -97,7 +97,7 @@ def _all_to_all_in_static_mode( data_feeder.check_variable_and_dtype( in_tensor, 'in_tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'all_to_all', ) helper.append_op( diff --git a/python/paddle/distributed/communication/stream/recv.py b/python/paddle/distributed/communication/stream/recv.py index 8a55db4407590..f009ae1c1ae16 100644 --- a/python/paddle/distributed/communication/stream/recv.py 
+++ b/python/paddle/distributed/communication/stream/recv.py @@ -43,7 +43,7 @@ def _recv_in_static_mode( data_feeder.check_variable_and_dtype( tensor, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'recv', ) ring_id = 0 if group is None else group.id diff --git a/python/paddle/distributed/communication/stream/send.py b/python/paddle/distributed/communication/stream/send.py index 2013c619f278f..36275344ebbaf 100644 --- a/python/paddle/distributed/communication/stream/send.py +++ b/python/paddle/distributed/communication/stream/send.py @@ -43,7 +43,7 @@ def _send_in_static_mode( data_feeder.check_variable_and_dtype( tensor, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'send', ) diff --git a/python/paddle/distributed/utils/moe_utils.py b/python/paddle/distributed/utils/moe_utils.py index 7e4db782c91ab..d41b4892dd639 100644 --- a/python/paddle/distributed/utils/moe_utils.py +++ b/python/paddle/distributed/utils/moe_utils.py @@ -124,7 +124,7 @@ def global_scatter( check_variable_and_dtype( x, 'x', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'global_scatter', ) check_variable_and_dtype( @@ -249,7 +249,7 @@ def global_gather( check_variable_and_dtype( x, 'x', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'global_gather', ) diff --git a/python/paddle/incubate/layers/nn.py b/python/paddle/incubate/layers/nn.py index aee7f2b9088de..cecf49d0ce091 100644 --- a/python/paddle/incubate/layers/nn.py +++ b/python/paddle/incubate/layers/nn.py @@ -564,6 +564,7 @@ def partial_concat(input, start_index=0, length=-1): 'float16', 'float32', 'float64', + 'uint16', 'int32', 'int64', 'complex64', diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 5a25e0b91f082..30c226d05dee4 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -977,7 +977,9 @@ def fused_multi_transformer( ffn2_biases, pre_layer_norm=True, epsilon=1e-05, + residual_alpha=1.0, cache_kvs=None, + beam_offset=None, pre_caches=None, seq_lens=None, rotary_embs=None, @@ -990,6 +992,9 @@ def fused_multi_transformer( mode='upscale_in_train', trans_qkvw=True, ring_id=-1, + norm_type="layernorm", + use_neox_rotary_style=False, + gqa_group_size=-1, name=None, ): r""" @@ -1032,32 +1037,55 @@ def fused_multi_transformer( ... out = ffn_layer_norm(out) Args: - x (Tensor): the input tensor could be 3-D tensor, the input data type could be float16 or float32, the shape is `[batch\_size, sequence\_length, d\_model]`. - ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of attention layer_norm, the shape is `[d\_model]`. - ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of attention layer_norm. the shape is `[d\_model]`. - qkv_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention qkv computation. The shape is `[3, num\_head, dim\_head, d\_model]`. - qkv_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention qkv computation. The shape is `[3, num\_head, dim\_head]`. - linear_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention linear. The shape is `[num\_head * dim\_head, d\_model]`. 
- linear_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention linear. The shape is `[d\_model]`. - ffn_ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward layer_norm, the shape is `[d\_model]` - ffn_ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of feedforward layer_norm, the shape is `[d\_model]` - ffn1_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward first linear, the shape is `[d\_model, dim\_feedforward]`. - ffn1_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward first linear, the shape is `[dim\_feedforward]`. - ffn2_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward second linear, the shape is `[dim\_feedforward, d\_model]`. - ffn2_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward second linear, the shape is `[d_model]`. - pre_layer_norm (bool, optional): whether it is pre_layer_norm(True) or post_layer_norm(False). Default True. - epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5. - cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None. - pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None. + x (Tensor): the input tensor could be 3-D tensor, the input data type could be float16, + the shape is `[batch\_size, sequence\_length, d\_model]`. + ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of attention layer_norm, + the shape is `[d\_model]`. + ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of attention layer_norm. + the shape is `[d\_model]`. + qkv_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention qkv computation. + The shape is `[3, num\_head, dim\_head, d\_model]`. + qkv_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention qkv computation. + The shape is `[3, num\_head, dim\_head]`. + linear_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention linear. + The shape is `[num\_head * dim\_head, d\_model]`. + linear_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention linear. + The shape is `[d\_model]`. + ffn_ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward layer_norm, + the shape is `[d\_model]` + ffn_ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of feedforward layer_norm, + the shape is `[d\_model]` + ffn1_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward first linear, + the shape is `[d\_model, dim\_feedforward]`. + ffn1_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward first linear, + the shape is `[dim\_feedforward]`. + ffn2_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward second linear, + the shape is `[dim\_feedforward, d\_model]`. + ffn2_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward second linear, + the shape is `[d_model]`. + pre_layer_norm (bool, optional): whether it is pre_layer_norm(True) or post_layer_norm(False). + Default True. + epsilon (float, optional): Small float value added to denominator of the layer_norm + to avoid dividing by zero. Default is 1e-5. + cache_kvs (list(Tensor)|tuple(Tensor), optional): + The cache structure tensors for the generation model. 
+ The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None. + pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. + The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None. seq_lens (Tensor optional): The sequence lengths of this batch. The shape is `[bsz]`. Default None. - rotary_embs (Tensor optional): The RoPE embs for rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None. - time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be in CPUPlace. Default None. + rotary_embs (Tensor optional): The RoPE embs for rotary computation. + The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None. + time_step (Tensor, optional): The time step tensor for the generation model. + Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. + The shape is `[1]`, must be in CPUPlace. Default None. attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to some unwanted positions, usually the paddings or the subsequent positions. It is a tensor with shape `[batch_size, 1, sequence_length, sequence_length]`. Default None. dropout_rate (float, optional): The dropout probability of setting units to zero. Default 0.0. - rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, and it is 0 when rotary_embs is None, - 1 when rotary_embs is not None and pos_extra_ids is None, 2 when rotary_embs and pos_extra_ids are both not None. Default 0. + rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, + and it is 0 when rotary_embs is None, + 1 when rotary_embs is not None and pos_extra_ids is None, + 2 when rotary_embs and pos_extra_ids are both not None. Default 0. activation (str, optional): The activation. Default "gelu". training (bool, optional): A flag indicating whether it is in train phrase or not. Default False. mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer'] @@ -1074,8 +1102,10 @@ def fused_multi_transformer( trans_qkvw (bool, optional): Whether to transpose for weights of qkv. If true, the shape eights of qkv should be [3, num_head, dim_head, dim_embed]. Otherwise the shape of weights of qkv should be [dim_embed, 3, num_head, dim_head]. Default True. - ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using mp. - name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. + Default is -1, means not using mp. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. Returns: Tensor|tuple: If `cache_kvs` is None, return a tensor that has @@ -1087,37 +1117,38 @@ def fused_multi_transformer( Examples: .. 
code-block:: python - >>> # doctest: +REQUIRES(env:GPU) + >>> # doctest: +SKIP('Depends on Flash Attention 2.') + >>> import re >>> import paddle >>> paddle.device.set_device('gpu') >>> import paddle.incubate.nn.functional as F >>> # input: [batch_size, seq_len, embed_dim] - >>> x = paddle.rand(shape=(2, 4, 128), dtype="float32") + >>> x = paddle.rand(shape=(2, 4, 128), dtype="float16") >>> # ln_scale: [embed_dim], ln_bias: [embed_dim] >>> ln_scale = paddle.rand(shape=(128,), dtype="float32") >>> ln_bias = paddle.rand(shape=(128,), dtype="float32") >>> # qkv_weight: [3, num_head, head_dim, embed_dim], qkv_bias: [3, num_head, head_dim] - >>> qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32") - >>> qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32") + >>> qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float16") + >>> qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float16") >>> # linear_weight: [embed_dim, embed_dim], linear_bias: [embed_dim] - >>> linear_weight = paddle.rand(shape=(128, 128), dtype="float32") - >>> linear_bias = paddle.rand(shape=(128,), dtype="float32") + >>> linear_weight = paddle.rand(shape=(128, 128), dtype="float16") + >>> linear_bias = paddle.rand(shape=(128,), dtype="float16") >>> # ffn_ln_scale: [embed_dim], ffn_ln_bias: [embed_dim] >>> ffn_ln_scale = paddle.rand(shape=(128,), dtype="float32") >>> ffn_ln_bias = paddle.rand(shape=(128,), dtype="float32") >>> # ffn1_weight: [embed_dim, 4*embed_dim], ffn1_bias: [4*embed_dim] - >>> ffn1_weight = paddle.rand(shape=(128, 4*128), dtype="float32") - >>> ffn1_bias = paddle.rand(shape=(4*128,), dtype="float32") + >>> ffn1_weight = paddle.rand(shape=(128, 4*128), dtype="float16") + >>> ffn1_bias = paddle.rand(shape=(4*128,), dtype="float16") >>> # ffn2_weight: [4*embed_dim, embed_dim], ffn2_bias: [embed_dim] - >>> ffn2_weight = paddle.rand(shape=(4*128, 128), dtype="float32") - >>> ffn2_bias = paddle.rand(shape=(128,), dtype="float32") + >>> ffn2_weight = paddle.rand(shape=(4*128, 128), dtype="float16") + >>> ffn2_bias = paddle.rand(shape=(128,), dtype="float16") >>> # self attention mask: [batch_size, 1, seq_len, seq_len] >>> attn_mask = paddle.rand(shape=(2, 1, 4, 4), dtype="float32") @@ -1140,7 +1171,7 @@ def fused_multi_transformer( ) # semantic transfer if in_dynamic_or_pir_mode(): - cache_kv_out, final_out = _C_ops.fused_multi_transformer( + cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer( x, ln_scales, ln_biases, @@ -1149,6 +1180,7 @@ def fused_multi_transformer( cache_kvs, pre_caches, rotary_embs, + beam_offset, time_step, seq_lens, attn_mask, @@ -1160,15 +1192,33 @@ def fused_multi_transformer( ffn1_biases, ffn2_weights, ffn2_biases, + cache_kvs, + 'pre_layer_norm', pre_layer_norm, + 'epsilon', epsilon, + 'residual_alpha', + residual_alpha, + 'dropout_rate', dropout_rate, + 'rotary_emb_dims', rotary_emb_dims, + 'is_test', not training, + 'dropout_implementation', mode, + 'act_method', activation, + 'trans_qkvw', trans_qkvw, + 'ring_id', ring_id, + 'norm_type', + norm_type, + 'use_neox_rotary_style', + use_neox_rotary_style, + 'gqa_group_size', + gqa_group_size, ) if cache_kvs is not None: return final_out, cache_kv_out @@ -1178,18 +1228,23 @@ def fused_multi_transformer( dtype = x.dtype # check dtypes check_variable_and_dtype( - x, 'x', ['float16', 'float32'], 'fused_multi_transformer' + x, 'x', ['uint16', 'float16'], 'fused_multi_transformer' ) check_dtype( - dtype, 'dtype', ['float16', 'float32'], 'fused_multi_transformer' + dtype, + 'dtype', + ['uint16', 'float16'], + 
'fused_multi_transformer', ) # set inputs inputs = {} inputs['X'] = [x] inputs['LnScale'] = ln_scales - inputs['LnBias'] = ln_biases inputs['QKVW'] = qkv_weights + + if ln_biases is not None: + inputs['LnBias'] = ln_biases if qkv_biases is not None: inputs['QKVBias'] = qkv_biases if cache_kvs is not None: @@ -1199,6 +1254,8 @@ def fused_multi_transformer( inputs['TimeStep'] = time_step if pre_caches is not None: inputs['PreCaches'] = pre_caches + if beam_offset is not None: + inputs['BeamCacheOffset'] = beam_offset if rotary_emb_dims > 0: inputs['RotaryPosEmb'] = rotary_embs inputs['SeqLengths'] = seq_lens @@ -1208,7 +1265,8 @@ def fused_multi_transformer( inputs['OutLinearBias'] = linear_biases inputs['FFNLnScale'] = ffn_ln_scales - inputs['FFNLnBias'] = ffn_ln_biases + if ffn_ln_biases is not None: + inputs['FFNLnBias'] = ffn_ln_biases inputs['FFN1Weight'] = ffn1_weights if ffn1_biases is not None: inputs['FFN1Bias'] = ffn1_biases @@ -1220,6 +1278,7 @@ def fused_multi_transformer( attrs = { 'pre_layer_norm': pre_layer_norm, 'epsilon': epsilon, + 'residual_alpha': residual_alpha, 'dropout_rate': dropout_rate, 'rotary_emb_dims': rotary_emb_dims, 'is_test': not training, @@ -1227,6 +1286,9 @@ def fused_multi_transformer( 'act_method': activation, 'trans_qkvw': trans_qkvw, 'ring_id': ring_id, + 'norm_type': norm_type, + 'use_neox_rotary_style': use_neox_rotary_style, + 'gqa_group_size': gqa_group_size, } outputs = {} diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index c1465bd9e9379..44de5c3366b51 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -1176,10 +1176,14 @@ def __init__( ffn2_weight_attrs=None, ffn2_bias_attrs=None, epsilon=1e-5, + residual_alpha=1.0, num_layers=-1, nranks=1, trans_qkvw=True, ring_id=-1, + norm_type="layernorm", + use_neox_rotary_style=False, + gqa_group_size=-1, name=None, ): super().__init__() @@ -1198,8 +1202,15 @@ def __init__( self.normalize_before = normalize_before self._dtype = self._helper.get_default_dtype() self._epsilon = epsilon + self._residual_alpha = residual_alpha self._trans_qkvw = trans_qkvw self._ring_id = ring_id + self._norm_type = norm_type + self._use_neox_rotary_style = use_neox_rotary_style + self._gqa_group_size = gqa_group_size + self._norm_weight_dtype = ( + "float32" if self._norm_type == "layernorm" else self._dtype + ) self.embed_dim = embed_dim self.num_heads = num_heads @@ -1227,6 +1238,10 @@ def __init__( self.ffn_ln_scales, self.ffn_ln_biases = [], [] self.ffn1_weights, self.ffn1_biases = [], [] self.ffn2_weights, self.ffn2_biases = [], [] + self.qkv_weights_scales = [] + self.linear_weights_scales = [] + self.ffn1_weights_scales = [] + self.ffn2_weights_scales = [] def get_attr(attrs, idx): if isinstance(attrs, (list, tuple)): @@ -1234,6 +1249,12 @@ def get_attr(attrs, idx): return attrs[idx] return attrs + def _add_parameter(param): + if param is None: + return + assert param.name not in self._parameters + self._parameters[param.name] = param + for i in range(num_layers): ln_scale_attr = get_attr(ln_scale_attrs, i) ln_bias_attr = get_attr(ln_bias_attrs, i) @@ -1253,70 +1274,99 @@ def get_attr(attrs, idx): attr=ln_scale_attr, shape=[embed_dim], default_initializer=Constant(value=1.0), + dtype=self._norm_weight_dtype, ) - ln_bias = self.create_parameter( - attr=ln_bias_attr, shape=[embed_dim], is_bias=True + ln_bias = None + if ln_bias_attr: + ln_bias = self.create_parameter( + 
attr=ln_bias_attr, + shape=[embed_dim], + is_bias=True, + dtype=self._norm_weight_dtype, + ) + qkv_head_shape = ( + [3, num_heads] + if self._gqa_group_size <= 0 + else [num_heads + 2 * self._gqa_group_size] ) qkv_weight = self.create_parameter( - shape=[3, num_heads, self.head_dim, embed_dim] + shape=qkv_head_shape + [self.head_dim, embed_dim] if trans_qkvw - else [embed_dim, 3, num_heads, self.head_dim], + else [embed_dim] + qkv_head_shape + [self.head_dim], attr=qkv_weight_attr, dtype=self._dtype, is_bias=False, ) - qkv_bias = self.create_parameter( - shape=[3, num_heads, self.head_dim], - attr=qkv_bias_attr, - dtype=self._dtype, - is_bias=True, - ) + qkv_bias = None + if qkv_bias_attr: + qkv_bias = self.create_parameter( + shape=qkv_head_shape + [self.head_dim], + attr=qkv_bias_attr, + dtype=self._dtype, + is_bias=True, + ) linear_weight = self.create_parameter( shape=[num_heads * self.head_dim, embed_dim], attr=linear_weight_attr, dtype=self._dtype, is_bias=False, ) - linear_bias = self.create_parameter( - shape=[embed_dim], - attr=linear_bias_attr, - dtype=self._dtype, - is_bias=True, - ) + linear_bias = None + if linear_bias_attr: + linear_bias = self.create_parameter( + shape=[embed_dim], + attr=linear_bias_attr, + dtype=self._dtype, + is_bias=True, + ) ffn_ln_scale = self.create_parameter( shape=[embed_dim], attr=ffn_ln_scale_attr, is_bias=False, default_initializer=Constant(1.0), + dtype=self._norm_weight_dtype, ) - ffn_ln_bias = self.create_parameter( - shape=[embed_dim], attr=ffn_ln_bias_attr, is_bias=True - ) + ffn_ln_bias = None + if ffn_ln_bias_attr: + ffn_ln_bias = self.create_parameter( + shape=[embed_dim], + attr=ffn_ln_bias_attr, + is_bias=True, + dtype=self._norm_weight_dtype, + ) ffn1_weight = self.create_parameter( - shape=[embed_dim, dim_feedforward], + shape=[embed_dim, dim_feedforward * 2] + if activation.endswith("glu") + else [embed_dim, dim_feedforward], attr=ffn1_weight_attr, dtype=self._dtype, is_bias=False, ) - ffn1_bias = self.create_parameter( - shape=[dim_feedforward], - attr=ffn1_bias_attr, - dtype=self._dtype, - is_bias=True, - ) + ffn1_bias = None + if ffn1_bias_attr: + ffn1_bias = self.create_parameter( + shape=[dim_feedforward * 2] + if activation.endswith("glu") + else [dim_feedforward], + attr=ffn1_bias_attr, + dtype=self._dtype, + is_bias=True, + ) ffn2_weight = self.create_parameter( shape=[dim_feedforward, embed_dim], attr=ffn2_weight_attr, dtype=self._dtype, is_bias=False, ) - ffn2_bias = self.create_parameter( - shape=[embed_dim], - attr=ffn2_bias_attr, - dtype=self._dtype, - is_bias=True, - ) + ffn2_bias = None + if ffn2_bias_attr: + ffn2_bias = self.create_parameter( + shape=[embed_dim], + attr=ffn2_bias_attr, + dtype=self._dtype, + is_bias=True, + ) # tensor model parallel if nranks > 1: @@ -1342,6 +1392,37 @@ def get_attr(attrs, idx): self.ffn1_biases.append(ffn1_bias) self.ffn2_weights.append(ffn2_weight) self.ffn2_biases.append(ffn2_bias) + _add_parameter(ln_scale) + _add_parameter(ln_bias) + _add_parameter(qkv_weight) + _add_parameter(qkv_bias) + _add_parameter(linear_weight) + _add_parameter(linear_bias) + + _add_parameter(ffn_ln_scale) + _add_parameter(ffn_ln_bias) + _add_parameter(ffn1_weight) + _add_parameter(ffn1_bias) + _add_parameter(ffn2_weight) + _add_parameter(ffn2_bias) + + if self.ln_biases[0] is None: + self.ln_biases = None + + if self.qkv_biases[0] is None: + self.qkv_biases = None + + if self.linear_biases[0] is None: + self.linear_biases = None + + if self.ffn_ln_biases[0] is None: + self.ffn_ln_biases = None + + if 
self.ffn1_biases[0] is None: + self.ffn1_biases = None + + if self.ffn2_biases[0] is None: + self.ffn2_biases = None self.dropout_rate = dropout_rate self.activation = activation @@ -1355,6 +1436,7 @@ def forward( pre_caches=None, rotary_embs=None, rotary_emb_dims=0, + beam_offset=None, seq_lens=None, time_step=None, ): @@ -1376,11 +1458,16 @@ def forward( inference and should be None for training. The shape is `[2, batch_size, num_head, max_seq_len, head_dim]`. Default None. pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches - for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None. - rotary_embs (Tensor optional): The RoPE embs for the rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None. - rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, and it is 0 when rotary_embs is None, - 1 when rotary_embs is not None and pos_extra_ids is None, 2 when rotary_embs and pos_extra_ids are both not None. Default 0. - seq_lens (Tensor optional): The sequence lengths of this batch. The shape is `[bsz]`. Default None. + for the generation model. The shape is + `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None. + rotary_embs (Tensor optional): The RoPE embs for the rotary computation. + The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None. + rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, + and it is 0 when rotary_embs is None, + 1 when rotary_embs is not None and pos_extra_ids is None, + 2 when rotary_embs and pos_extra_ids are both not None. Default 0. + seq_lens (Tensor optional): The sequence lengths of this batch. + The shape is `[bsz]`. Default None. time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. 
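Note on the GQA parameter shapes above: with gqa_group_size set, the layer fuses the projections into a single head axis of num_heads + 2 * gqa_group_size instead of the [3, num_heads] axis used for plain multi-head attention, and the optional qkv bias follows the same head shape. A small sketch of the resulting weight shape, using made-up sizes:

    # Made-up sizes; mirrors the qkv weight shapes created by FusedMultiTransformer.
    num_heads, head_dim, embed_dim, gqa_group_size = 8, 64, 512, 2
    trans_qkvw = True

    if gqa_group_size > 0:
        qkv_head_shape = [num_heads + 2 * gqa_group_size]   # q heads + grouped k/v heads
    else:
        qkv_head_shape = [3, num_heads]                      # classic fused q/k/v layout

    qkv_weight_shape = (
        qkv_head_shape + [head_dim, embed_dim]
        if trans_qkvw
        else [embed_dim] + qkv_head_shape + [head_dim]
    )
    print(qkv_weight_shape)  # [12, 64, 512]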
The shape is `[1]`, must be @@ -1412,7 +1499,9 @@ def forward( self.ffn2_biases, pre_layer_norm=self.normalize_before, epsilon=self._epsilon, + residual_alpha=self._residual_alpha, cache_kvs=caches, + beam_offset=beam_offset, pre_caches=pre_caches, rotary_embs=rotary_embs, time_step=time_step, @@ -1425,6 +1514,9 @@ def forward( mode='upscale_in_train', trans_qkvw=self._trans_qkvw, ring_id=self._ring_id, + norm_type=self._norm_type, + use_neox_rotary_style=self._use_neox_rotary_style, + gqa_group_size=self._gqa_group_size, name=self.name, ) return out diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 9aebccc72d7bf..24c60af7499e6 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1216,6 +1216,7 @@ def _check_attr(attr, message): 'float16', 'float32', 'float64', + 'uint16', 'int32', 'int64', 'complex64', @@ -1918,7 +1919,7 @@ def diagflat(x, offset=0, name=None): check_dtype( x.dtype, 'x', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'diagflat', ) check_type(offset, 'offset', (int), 'diagflat') @@ -2041,6 +2042,7 @@ def diag(x, offset=0, padding_value=0, name=None): 'uint16', 'float32', 'float64', + 'uint16', 'int32', 'int64', 'complex64', diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 5e36fbda8e874..575bc61fe2d38 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1855,7 +1855,7 @@ def t(input, name=None): check_variable_and_dtype( input, 'input', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'transpose', ) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 4680695d0e121..2de9d81b79d62 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -1835,7 +1835,10 @@ def nansum(x, axis=None, dtype=None, keepdim=False, name=None): [9. 
, 18.]) """ check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'nansum' + x, + 'x', + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], + 'nansum', ) check_type(axis, 'axis', (int, list, tuple, type(None)), 'nansum') @@ -3963,10 +3966,16 @@ def kron(x, y, name=None): else: helper = LayerHelper('kron', **locals()) check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'kron' + x, + 'x', + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], + 'kron', ) check_variable_and_dtype( - y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64'], 'kron' + y, + 'y', + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], + 'kron', ) out = helper.create_variable_for_type_inference(dtype=x.dtype) diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 405c903b8041f..161e6c317b83a 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -30,7 +30,6 @@ endif() list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel) list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_feedforward) list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_attention) -list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_multi_transformer) list(APPEND DIST_TEST_OPS test_auto_parallel_data_unshard) list(APPEND DIST_TEST_OPS test_auto_parallel_save_load) list(APPEND DIST_TEST_OPS test_auto_parallel_autoconvert) @@ -901,8 +900,6 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) PROPERTIES TIMEOUT 120) set_tests_properties(test_static_model_parallel_fused_attention PROPERTIES TIMEOUT 120) - set_tests_properties(test_static_model_parallel_fused_multi_transformer - PROPERTIES TIMEOUT 120) set_tests_properties(test_pipeline_parallel PROPERTIES LABELS "RUN_TYPE=DIST") set_tests_properties(test_fleet_perf_test PROPERTIES LABELS "RUN_TYPE=DIST") diff --git a/test/legacy_test/static_model_parallel_fused_multi_transformer.py b/test/legacy_test/static_model_parallel_fused_multi_transformer.py deleted file mode 100644 index 1da41d0e0dd21..0000000000000 --- a/test/legacy_test/static_model_parallel_fused_multi_transformer.py +++ /dev/null @@ -1,192 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
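Note on the deleted static model-parallel test below: it built its per-rank parameters by hand, splitting the fused QKV weight column-wise over a contiguous slice of heads and the output projection row-wise over the matching rows, with the FFN weights split the same way along dim_ffn. Roughly, with illustrative sizes:

    import numpy as np

    # Illustrative 2-way tensor-parallel split of the fused weights.
    num_head, dim_head, hidden, mp_size, rank = 4, 8, 32, 2, 0
    per_rank = num_head // mp_size
    start, end = rank * per_rank, (rank + 1) * per_rank

    qkv_w = np.random.rand(3, num_head, dim_head, hidden)
    col_qkv_w = qkv_w[:, start:end, :, :]                        # column parallel: a slice of heads
    linear_w = np.random.rand(num_head * dim_head, hidden)
    row_linear_w = linear_w[start * dim_head : end * dim_head]   # row parallel: the matching rows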
- -import numpy as np -from test_dist_base import TestDistRunnerBase, runtime_main - -import paddle -from paddle import base -from paddle.distributed import fleet -from paddle.incubate.nn import FusedMultiTransformer - -paddle.enable_static() - - -def get_param_attr(weight, bias): - weight_attr = paddle.ParamAttr( - initializer=paddle.nn.initializer.Assign(weight) - ) - bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(bias)) - return weight_attr, bias_attr - - -DTYPE = "float32" -MODEL_PARALLEL_SIZE = 2 -num_head = 2 * MODEL_PARALLEL_SIZE -dim_head = 4 -hidden = num_head * dim_head -dim_ffn = 4 * hidden - - -def create_model(data, rank): - np.random.seed(2021) - ln_w = np.random.uniform(-1, 1, size=(hidden,)).astype(DTYPE) - ln_b = np.random.uniform(-1, 1, size=(hidden,)).astype(DTYPE) - qkv_w = np.random.uniform( - -1, 1, size=(3, num_head, dim_head, hidden) - ).astype(DTYPE) - qkv_b = np.random.uniform(-1, 1, size=(3, num_head, dim_head)).astype(DTYPE) - linear_w = np.random.uniform( - -1, 1, size=(num_head * dim_head, hidden) - ).astype(DTYPE) - linear_b = np.random.uniform(-1, 1, size=(hidden,)).astype(DTYPE) - - ffn_ln_w = np.random.uniform(-1, 1, size=(hidden,)).astype(DTYPE) - ffn_ln_b = np.random.uniform(-1, 1, size=(hidden,)).astype(DTYPE) - ffn1_w = np.random.uniform(-1, 1, size=(hidden, dim_ffn)).astype(DTYPE) - ffn1_b = np.random.uniform(-1, 1, size=(dim_ffn,)).astype(DTYPE) - ffn2_w = np.random.uniform(-1, 1, size=(dim_ffn, hidden)).astype(DTYPE) - ffn2_b = np.random.uniform(-1, 1, size=(hidden,)).astype(DTYPE) - - if rank is not None: - start = 0 if rank == 0 else (num_head // MODEL_PARALLEL_SIZE) - end = start + (num_head // MODEL_PARALLEL_SIZE) - col_qkv_w = qkv_w[:, start:end, :, :] - col_qkv_b = qkv_b[:, start:end, :] - row_linear_w = linear_w[(start * dim_head) : (end * dim_head), :] - - ln_w_attr, ln_b_attr = get_param_attr(ln_w, ln_b) - qkv_w_attr, qkv_b_attr = get_param_attr(col_qkv_w, col_qkv_b) - linear_w_attr, linear_b_attr = get_param_attr(row_linear_w, linear_b) - - start = 0 if rank == 0 else (dim_ffn // MODEL_PARALLEL_SIZE) - end = start + (dim_ffn // MODEL_PARALLEL_SIZE) - col_ffn1_w = ffn1_w[:, start:end] - col_ffn1_b = ffn1_b[start:end] - row_ffn2_w = ffn2_w[start:end, :] - - ffn_ln_w_attr, ffn_ln_b_attr = get_param_attr(ffn_ln_w, ffn_ln_b) - ffn1_w_attr, ffn1_b_attr = get_param_attr(col_ffn1_w, col_ffn1_b) - ffn2_w_attr, ffn2_b_attr = get_param_attr(row_ffn2_w, ffn2_b) - - multi_transformer = FusedMultiTransformer( - hidden, - num_head, - dim_ffn, - dropout_rate=0.0, - activation="gelu", - normalize_before=True, - ln_scale_attrs=[ln_w_attr], - ln_bias_attrs=[ln_b_attr], - qkv_weight_attrs=[qkv_w_attr], - qkv_bias_attrs=[qkv_b_attr], - linear_weight_attrs=[linear_w_attr], - linear_bias_attrs=[linear_b_attr], - ffn_ln_scale_attrs=[ffn_ln_w_attr], - ffn_ln_bias_attrs=[ffn_ln_b_attr], - ffn1_weight_attrs=[ffn1_w_attr], - ffn1_bias_attrs=[ffn1_b_attr], - ffn2_weight_attrs=[ffn2_w_attr], - ffn2_bias_attrs=[ffn2_b_attr], - nranks=MODEL_PARALLEL_SIZE, - ring_id=0, - ) - result = multi_transformer(data) - else: - ln_w_attr, ln_b_attr = get_param_attr(ln_w, ln_b) - qkv_w_attr, qkv_b_attr = get_param_attr(qkv_w, qkv_b) - linear_w_attr, linear_b_attr = get_param_attr(linear_w, linear_b) - - ffn_ln_w_attr, ffn_ln_b_attr = get_param_attr(ffn_ln_w, ffn_ln_b) - ffn1_w_attr, ffn1_b_attr = get_param_attr(ffn1_w, ffn1_b) - ffn2_w_attr, ffn2_b_attr = get_param_attr(ffn2_w, ffn2_b) - - multi_transformer = FusedMultiTransformer( - hidden, - num_head, - 
dim_ffn, - dropout_rate=0.0, - activation="gelu", - normalize_before=True, - ln_scale_attrs=[ln_w_attr], - ln_bias_attrs=[ln_b_attr], - qkv_weight_attrs=[qkv_w_attr], - qkv_bias_attrs=[qkv_b_attr], - linear_weight_attrs=[linear_w_attr], - linear_bias_attrs=[linear_b_attr], - ffn_ln_scale_attrs=[ffn_ln_w_attr], - ffn_ln_bias_attrs=[ffn_ln_b_attr], - ffn1_weight_attrs=[ffn1_w_attr], - ffn1_bias_attrs=[ffn1_b_attr], - ffn2_weight_attrs=[ffn2_w_attr], - ffn2_bias_attrs=[ffn2_b_attr], - ) - result = multi_transformer(data) - - # fused_multi_transformer have no backward - result.stop_gradient = True - predict = paddle.mean(result) - return predict - - -class TestModelParallel(TestDistRunnerBase): - def get_model(self, batch_size=2, use_dgc=False, dist_strategy=None): - # Input data - seq_len = 2 - data_in = paddle.static.data( - name='data_in', shape=[batch_size, seq_len, hidden], dtype=DTYPE - ) - - if dist_strategy: - data_loader = base.io.DataLoader.from_generator( - feed_list=[data_in], - capacity=64, - use_double_buffer=False, - iterable=False, - ) - - if dist_strategy: - fleet.init(is_collective=True) - strategy = fleet.DistributedStrategy() - strategy.tensor_parallel = True - strategy.tensor_parallel_configs = {'tensor_parallel_degree': 2} - - rank = fleet.worker_index() if dist_strategy else None - avg_cost = create_model(data_in, rank) - opt = paddle.optimizer.SGD(0.1) - - if dist_strategy: - dist_opt = fleet.distributed_optimizer( - optimizer=opt, strategy=strategy - ) - dist_opt.minimize(avg_cost) - else: - opt.minimize(avg_cost) - - def gen_data(): - np.random.seed(2021) - while True: - data = [np.random.random([seq_len, hidden]).astype(DTYPE)] - yield data - - train_reader = paddle.batch(gen_data, batch_size=batch_size) - - if dist_strategy: - return None, avg_cost, train_reader, None, None, None, data_loader - else: - return None, avg_cost, train_reader, None, None, None - - -if __name__ == "__main__": - runtime_main(TestModelParallel) diff --git a/test/legacy_test/test_empty_like_op.py b/test/legacy_test/test_empty_like_op.py index f25ce87664d07..7a1e191c49460 100644 --- a/test/legacy_test/test_empty_like_op.py +++ b/test/legacy_test/test_empty_like_op.py @@ -39,7 +39,14 @@ def __check_out__(self, out): f'shape should be {self.dst_shape}, but get {shape}', ) - if data_type in ['float16', 'float32', 'float64', 'int32', 'int64']: + if data_type in [ + 'float16', + 'float32', + 'float64', + 'int32', + 'int64', + 'uint16', + ]: max_value = np.nanmax(out) min_value = np.nanmin(out) always_non_full_zero = max_value >= min_value diff --git a/test/legacy_test/test_empty_op.py b/test/legacy_test/test_empty_op.py index 49fb2526b7e53..2db103333a6cf 100644 --- a/test/legacy_test/test_empty_op.py +++ b/test/legacy_test/test_empty_op.py @@ -35,7 +35,14 @@ def test_check_output(self): def verify_output(self, outs): data_type = outs[0].dtype - if data_type in ['float16', 'float32', 'float64', 'int32', 'int64']: + if data_type in [ + 'float16', + 'float32', + 'float64', + 'int32', + 'int64', + 'uint16', + ]: max_value = np.nanmax(outs[0]) min_value = np.nanmin(outs[0]) diff --git a/test/legacy_test/test_fused_multi_transformer_op.py b/test/legacy_test/test_fused_multi_transformer_op.py index b7fec52341be6..9a94a052ebcd2 100644 --- a/test/legacy_test/test_fused_multi_transformer_op.py +++ b/test/legacy_test/test_fused_multi_transformer_op.py @@ -18,28 +18,40 @@ import numpy as np from op_test import OpTest +from test_sparse_attention_op import get_cuda_version import paddle import 
paddle.nn.functional as F from paddle import tensor +from paddle.base.framework import default_main_program +from paddle.base.param_attr import ParamAttr from paddle.incubate.nn import FusedMultiTransformer from paddle.incubate.nn.functional import fused_multi_transformer from paddle.nn.layer.common import Dropout, Linear from paddle.nn.layer.norm import LayerNorm from paddle.nn.layer.transformer import _convert_attention_mask -from paddle.pir_utils import test_with_pir_api seed = 42 random.seed(seed) +default_main_program().random_seed = seed np.random.seed(seed) paddle.seed(seed) +# now only support flash_attention_v2 and variable +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) class TestFusedMultiTransformerOp(OpTest): def setUp(self): - self.with_new_comm() self.config() + self.kv_num_heads = ( + self.gqa_group_size if self.gqa_group_size > 0 else self.num_heads + ) self.generate_input_data() self.rtol = 1e-5 @@ -58,8 +70,8 @@ def setUp(self): # use autograd to check grad in this unittest. self.__class__.no_need_check_grad = False - bias_attr = paddle.base.ParamAttr( - initializer=paddle.nn.initializer.Constant(value=0.0005) + bias_attr = ParamAttr( + initializer=paddle.paddle.nn.initializer.Constant(value=0.0005) ) self.q_proj = Linear( self.embed_dim, @@ -71,13 +83,13 @@ def setUp(self): self.k_proj = Linear( self.kdim, - self.embed_dim, + self.kv_num_heads * self.head_dim, self.weight_attr, bias_attr=self.bias_attr, ) self.v_proj = Linear( self.vdim, - self.embed_dim, + self.kv_num_heads * self.head_dim, self.weight_attr, bias_attr=self.bias_attr, ) @@ -109,9 +121,6 @@ def setUp(self): self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train") self.activation = getattr(F, self.act_method) - def with_new_comm(self): - os.environ["FLAGS_dynamic_static_unified_comm"] = "0" - def config(self): # for debug self.debug = False @@ -131,6 +140,7 @@ def config(self): self.has_pre_cache = False self.rotary_embs = None self.rotary_emb_dims = 0 + self.neox_rotary_style = False self.remove_padding = False @@ -146,6 +156,10 @@ def config(self): self.num_heads = 16 self.embed_dim = self.head_dim * self.num_heads + # For GQA + self.gqa_group_size = -1 + self.use_fake_mha = False + self.dropout_prob = 0.0 self.attn_dropout_prob = 0.0 self.act_method = 'gelu' @@ -173,7 +187,7 @@ def generate_input_data(self): ( 2, self.batch_size, - self.num_heads, + self.kv_num_heads, self.cache_length, self.head_dim, ), @@ -207,7 +221,8 @@ def generate_input_data(self): ] = self.query_length self.seq_lens = np.array(self.seq_lens).astype(np.int32) - if self.has_pre_cache: + if self.has_pre_cache and self.gqa_group_size <= 0: + assert self.gqa_group_size <= 0, "GQA does not support pre cache" out_seq_len += self.pre_cache_num self.pre_cache_kv = np.random.uniform( -1, @@ -249,6 +264,9 @@ def generate_input_data(self): self.attn_mask = None if self.rotary_emb_dims > 0: + self.rotary_emb_dims = ( + 1 if not self.neox_rotary_style else self.rotary_emb_dims + ) self.rotary_emb = np.random.uniform( -1, 1, @@ -260,11 +278,18 @@ def generate_input_data(self): self.head_dim // 2 // self.rotary_emb_dims, ), ).astype(self.x_type) - concat_nums = 2 * self.rotary_emb_dims - rotary_embs = [] - for _ in range(concat_nums): - rotary_embs.append(self.rotary_emb) - self.rotary_embs = np.concatenate(rotary_embs, -1) + if self.neox_rotary_style: + concat_nums = 2 
* self.rotary_emb_dims + rotary_embs = [] + for _ in range(concat_nums): + rotary_embs.append(self.rotary_emb) + self.rotary_embs = np.concatenate(rotary_embs, -1) + else: + rotary_emb = paddle.to_tensor(self.rotary_emb) + self.rotary_embs = paddle.reshape( + paddle.stack([rotary_emb, rotary_emb], axis=-1), + [2, self.batch_size, 1, self.query_length, self.head_dim], + ).numpy() self.key, self.value = self.query, self.query @@ -276,7 +301,7 @@ def rotate_half(self, x): x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] return paddle.concat((-x2, x1), axis=-1) - def apply_rotary_emb(self, x, cos_emb, sin_emb, rotary_emb_dims): + def apply_neoXrotary_emb(self, x, cos_emb, sin_emb, rotary_emb_dims): # x shape [bsz, num_heads, seq_len, head_dim] # cos_emb, sin_emb shape [bsz, 1, seq_len, head_dim] x_dims = paddle.split(x, num_or_sections=rotary_emb_dims, axis=-1) @@ -294,6 +319,15 @@ def apply_rotary_emb(self, x, cos_emb, sin_emb, rotary_emb_dims): ) return paddle.concat(rotary_dims, axis=-1) + def apply_rotary_emb(self, x, cos_emb, sin_emb): + # x shape [bsz, num_heads, seq_len, head_dim] + # cos_emb, sin_emb shape [bsz, 1, seq_len, head_dim] + rotate_half_x = paddle.reshape( + paddle.stack([-x[:, :, :, 1::2], x[:, :, :, 0::2]], axis=-1), + paddle.shape(x), + ) + return x * cos_emb + rotate_half_x * sin_emb + def GetBaselineOut(self): paddle.disable_static(place=paddle.CUDAPlace(0)) tensor_query = paddle.to_tensor(self.query, stop_gradient=False) @@ -329,20 +363,28 @@ def GetBaselineOut(self): q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3]) k = self.k_proj(ln1_out) v = self.v_proj(ln1_out) - k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k = tensor.reshape( + x=k, shape=[0, 0, self.kv_num_heads, self.head_dim] + ) k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3]) - v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v = tensor.reshape( + x=v, shape=[0, 0, self.kv_num_heads, self.head_dim] + ) v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3]) if self.rotary_emb_dims > 0: cos_emb = rotary_embs[0] sin_emb = rotary_embs[1] - q_out = self.apply_rotary_emb( - q_out, cos_emb, sin_emb, self.rotary_emb_dims - ) - k_out = self.apply_rotary_emb( - k_out, cos_emb, sin_emb, self.rotary_emb_dims - ) + if self.neox_rotary_style: + q_out = self.apply_neoXrotary_emb( + q_out, cos_emb, sin_emb, self.rotary_emb_dims + ) + k_out = self.apply_neoXrotary_emb( + k_out, cos_emb, sin_emb, self.rotary_emb_dims + ) + else: + q_out = self.apply_rotary_emb(q_out, cos_emb, sin_emb) + k_out = self.apply_rotary_emb(k_out, cos_emb, sin_emb) if self.has_cache_kv: # [1, B, n_head, cache_seq_len, head_dim] @@ -493,12 +535,17 @@ def GetVariableDecoderBaselineOut(self): if self.rotary_emb_dims > 0: cos_emb = rotary_embs[0][i : i + 1] sin_emb = rotary_embs[1][i : i + 1] - q_out = self.apply_rotary_emb( - q_out, cos_emb, sin_emb, self.rotary_emb_dims - ) - k_out = self.apply_rotary_emb( - k_out, cos_emb, sin_emb, self.rotary_emb_dims - ) + + if self.neox_rotary_style: + q_out = self.apply_neoXrotary_emb( + q_out, cos_emb, sin_emb, self.rotary_emb_dims + ) + k_out = self.apply_neoXrotary_emb( + k_out, cos_emb, sin_emb, self.rotary_emb_dims + ) + else: + q_out = self.apply_rotary_emb(q_out, cos_emb, sin_emb) + k_out = self.apply_rotary_emb(k_out, cos_emb, sin_emb) if self.has_cache_kv: # [1, B, n_head, cache_seq_len, head_dim] @@ -639,7 +686,7 @@ def GetFusedMultiTransformerOut(self): if self.rotary_emb_dims > 0: rotary_embs = paddle.to_tensor( self.rotary_embs, 
stop_gradient=False - ) + ).astype('float32') else: rotary_embs = None @@ -793,6 +840,8 @@ def GetFusedMultiTransformerOut(self): dropout_rate=self.dropout_prob, activation=self.act_method, training=self.training, + use_neox_rotary_style=self.neox_rotary_style, + gqa_group_size=self.gqa_group_size, ) if self.has_cache_kv: @@ -1013,12 +1062,14 @@ def GetFusedMultiTransformerOutStatic(self): paddle.disable_static() return out - @test_with_pir_api def test_fused_multi_transformer_op(self): + if not self.remove_padding or self.gqa_group_size > 0: + return if self.has_cache_kv and not self.gen_cache_kv and self.remove_padding: final_out_ref = self.GetVariableDecoderBaselineOut() else: final_out_ref = self.GetBaselineOut() + final_out = self.GetFusedMultiTransformerOut() if self.has_cache_kv: final_out, cache_kv_out = final_out @@ -1129,154 +1180,438 @@ def test_fused_multi_transformer_op(self): final_out_ref, final_out, rtol=self.rtol, atol=self.atol ) + def GetFusedMultiTransformerGQAOut(self): + if self.has_cache_kv and self.use_fake_mha: + if self.gen_cache_kv: + self.cache_kv[:] = 0 + shape = [ + 2, + self.batch_size, + self.num_heads, + self.cache_length, + self.head_dim, + ] + self.cache_kv = paddle.to_tensor(self.cache_kv) + self.cache_kv = paddle.stack( + [self.cache_kv] * (self.num_heads // self.kv_num_heads), axis=3 + ) -class TestFusedMultiTransformerOpWithNewComm(TestFusedMultiTransformerOp): - def with_new_comm(self): - os.environ["FLAGS_dynamic_static_unified_comm"] = "1" - - -class TestFusedMultiTransformerOpRotaryFP16(TestFusedMultiTransformerOp): - def config(self): - super().config() - self.x_type = np.float16 - self.rotary_emb_dims = 1 + # import pdb;pdb.set_trace() + self.cache_kv = paddle.reshape(self.cache_kv, shape).numpy() -class TestFusedMultiTransformerOpGenRotaryFP16(TestFusedMultiTransformerOp): - def config(self): - super().config() - self.x_type = np.float16 - self.has_cache_kv = True - self.gen_cache_kv = False - self.query_length = 1 - self.key_length, self.value_length = ( - self.query_length, - self.query_length, + paddle.disable_static(place=paddle.CUDAPlace(0)) + q_proj_weight = paddle.to_tensor( + self.q_proj.weight, stop_gradient=False + ) + k_proj_weight = paddle.to_tensor( + self.k_proj.weight, stop_gradient=False + ) + v_proj_weight = paddle.to_tensor( + self.v_proj.weight, stop_gradient=False ) - self.rotary_emb_dims = 2 - - -class TestFusedMultiTransformerOpGenCacheRotaryFP16( - TestFusedMultiTransformerOp -): - def config(self): - super().config() - self.x_type = np.float16 - self.has_cache_kv = True - self.gen_cache_kv = True - self.rotary_emb_dims = 1 - -class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp): - def config(self): - super().config() - self.x_type = np.float16 - self.layers = 3 # odd layers + if self.use_fake_mha: + origin_shape = [self.embed_dim, self.embed_dim] + k_proj_weight = paddle.reshape( + k_proj_weight, + [self.embed_dim, self.kv_num_heads, self.head_dim], + ) + v_proj_weight = paddle.reshape( + v_proj_weight, + [self.embed_dim, self.kv_num_heads, self.head_dim], + ) -class TestFusedMultiTransformerOpActReluFp16(TestFusedMultiTransformerOp): - def config(self): - super().config() - self.x_type = np.float16 - self.act_method = "relu" - self.layers = 3 # odd layers + k_proj_weight = paddle.stack( + [k_proj_weight] * (self.num_heads // self.kv_num_heads), axis=-2 + ) + v_proj_weight = paddle.stack( + [v_proj_weight] * (self.num_heads // self.kv_num_heads), axis=-2 + ) + k_proj_weight = paddle.reshape(k_proj_weight, 
origin_shape) + v_proj_weight = paddle.reshape(v_proj_weight, origin_shape) -class TestFusedMultiTransformerOpCacheKV(TestFusedMultiTransformerOp): - def config(self): - super().config() - self.has_cache_kv = True - self.query_length = 1 - self.key_length, self.value_length = 1, 1 - self.layers = 3 # odd layers + out_linear_weight = paddle.to_tensor( + self.out_proj.weight, stop_gradient=False + ) + ffn1_weight = paddle.to_tensor( + self.ffn1_proj.weight, stop_gradient=False + ) + ffn2_weight = paddle.to_tensor( + self.ffn2_proj.weight, stop_gradient=False + ) + if self.bias_attr is False: + qkv_bias_tensor = None + out_linear_bias = None + else: + q_proj_bias = paddle.to_tensor( + self.q_proj.bias, stop_gradient=False + ) + k_proj_bias = paddle.to_tensor( + self.k_proj.bias, stop_gradient=False + ) + v_proj_bias = paddle.to_tensor( + self.v_proj.bias, stop_gradient=False + ) + if self.use_fake_mha: + origin_shape = [self.embed_dim] + k_proj_bias = paddle.reshape( + k_proj_bias, [self.kv_num_heads, self.head_dim] + ) + v_proj_bias = paddle.reshape( + v_proj_bias, [self.kv_num_heads, self.head_dim] + ) -class TestFusedMultiTransformerOpCacheKVFp16(TestFusedMultiTransformerOp): - def config(self): - super().config() - self.has_cache_kv = True - self.query_length = 1 - self.key_length, self.value_length = 1, 1 - self.x_type = np.float16 + k_proj_bias = paddle.reshape( + paddle.stack( + [k_proj_bias] * (self.num_heads // self.kv_num_heads), + axis=-2, + ), + origin_shape, + ) + v_proj_bias = paddle.reshape( + paddle.stack( + [v_proj_bias] * (self.num_heads // self.kv_num_heads), + axis=-2, + ), + origin_shape, + ) + qkv_bias = np.concatenate( + (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy()) + ) + if self.gqa_group_size <= 0 or self.use_fake_mha: + qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) + else: + qkv_bias = qkv_bias.reshape( + (self.num_heads + 2 * self.kv_num_heads, self.head_dim) + ) + qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) + out_linear_bias = paddle.to_tensor( + self.out_proj.bias, stop_gradient=False + ) + ffn1_bias = paddle.to_tensor( + self.ffn1_proj.bias, stop_gradient=False + ) + ffn2_bias = paddle.to_tensor( + self.ffn2_proj.bias, stop_gradient=False + ) -class TestFusedMultiTransformerOpGenCacheKV(TestFusedMultiTransformerOp): - def config(self): - super().config() - self.has_cache_kv = True - self.gen_cache_kv = True + ln_scale = paddle.to_tensor(self.norm.weight, stop_gradient=False) + ln_bias = paddle.to_tensor(self.norm.bias, stop_gradient=False) + ffn_ln_scale = paddle.to_tensor( + self.ffn_norm.weight, stop_gradient=False + ) + ffn_ln_bias = paddle.to_tensor(self.ffn_norm.bias, stop_gradient=False) + q_proj_weight = q_proj_weight.numpy().transpose((1, 0)) + k_proj_weight = k_proj_weight.numpy().transpose((1, 0)) + v_proj_weight = v_proj_weight.numpy().transpose((1, 0)) + qkv_weight = np.concatenate( + (q_proj_weight, k_proj_weight, v_proj_weight) + ) + if self.gqa_group_size <= 0 or self.use_fake_mha: + qkv_weight = qkv_weight.reshape( + (3, self.num_heads, self.head_dim, self.embed_dim) + ) + else: + qkv_weight = qkv_weight.reshape( + ( + self.num_heads + 2 * self.kv_num_heads, + self.head_dim, + self.embed_dim, + ) + ) -class TestFusedMultiTransformerOpGenCacheKVFp16(TestFusedMultiTransformerOp): - def config(self): - super().config() - self.has_cache_kv = True - self.gen_cache_kv = True - self.x_type = np.float16 - self.layers = 3 # odd layers + if self.rotary_emb_dims > 0: + rotary_embs = paddle.to_tensor( 
+ self.rotary_embs, stop_gradient=False + ).astype('float32') + else: + rotary_embs = None + x = paddle.to_tensor(self.query, stop_gradient=False) + cache_kvs, cache_kv = None, None + time_step = None + pre_caches = None -class TestFusedMultiTransformerOpPostLayerNormFp16(TestFusedMultiTransformerOp): - def config(self): - super().config() - self.x_type = np.float16 - self.layers = 3 # odd layers - self.pre_layer_norm = False + fuse_kv_num_heads = ( + self.kv_num_heads + if self.gqa_group_size > 0 and not self.use_fake_mha + else self.num_heads + ) + if self.has_cache_kv: + cache_kvs = [] + max_seq_length = (self.cache_length + 128) // 128 * 128 + cache_kv = np.zeros( + [ + 2, + self.batch_size, + fuse_kv_num_heads, + max_seq_length, + self.head_dim, + ], + dtype=self.x_type, + ) -class TestFusedMultiTransformerOpCacheKVPostLayerNorm( - TestFusedMultiTransformerOp -): - def config(self): - super().config() - self.has_cache_kv = True - self.query_length = 1 - self.key_length, self.value_length = 1, 1 - self.layers = 3 # odd layers - self.pre_layer_norm = False + elems = 4 + if self.x_type is np.float16: + elems = 8 + assert self.head_dim % elems == 0 + v_elems = self.head_dim // elems -class TestFusedMultiTransformerOpCacheKVPostLayerNormFp16( - TestFusedMultiTransformerOp -): - def config(self): - super().config() - self.has_cache_kv = True - self.query_length = 1 - self.key_length, self.value_length = 1, 1 - self.x_type = np.float16 - self.pre_layer_norm = False + # [B, num_head, 128, head_dim] + # cache_k_tmp = self.cache_kv[0, :] + # [B, num_head, 128, head_dim / 4, 4] + cache_k_tmp = self.cache_kv[0].reshape( + [ + self.batch_size, + fuse_kv_num_heads, + self.cache_length, + v_elems, + elems, + ] + ) + # [B, num_head, head_dim / 4, 128, 4] + cache_k_tmp = cache_k_tmp.transpose([0, 1, 3, 2, 4]) + cache_kv[0, :].reshape( + [ + self.batch_size, + fuse_kv_num_heads, + v_elems, + max_seq_length, + elems, + ] + )[:, :, :, : self.cache_length, :] = cache_k_tmp -class TestFusedMultiTransformerOpGenCacheKVPostLayerNorm( - TestFusedMultiTransformerOp -): - def config(self): - super().config() - self.has_cache_kv = True - self.gen_cache_kv = True - self.pre_layer_norm = False + cache_kv[1, :, :, : self.cache_length, :] = self.cache_kv[1] + if self.gen_cache_kv: + assert self.query_length == self.cache_length + cache_kv[:] = 0 + else: + time_step = paddle.to_tensor( + [self.cache_length], dtype='int32', place=paddle.CPUPlace() + ) + if self.remove_padding: + seq_lens = paddle.to_tensor(self.seq_lens, dtype='int32') + else: + seq_lens = None -class TestFusedMultiTransformerOpGenCacheKVPostLayerNormFp16( - TestFusedMultiTransformerOp -): - def config(self): - super().config() - self.has_cache_kv = True - self.gen_cache_kv = True - self.x_type = np.float16 - self.layers = 3 # odd layers - self.pre_layer_norm = False + if self.has_pre_cache: + cache_kvs = [] + max_seq_length = ( + self.cache_length + 128 + ) // 128 * 128 + self.pre_cache_num + cache_kv = np.zeros( + [ + 2, + self.batch_size, + fuse_kv_num_heads, + max_seq_length, + self.head_dim, + ], + dtype=self.x_type, + ) + pre_caches = [] + + if self.has_attn_mask: + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + else: + attn_mask = None + qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False) + epsilon = 1e-05 + ln2_epsilon = 1e-05 + if attn_mask is not None and self.attn_mask_type != np.bool_: + attn_mask = _convert_attention_mask(attn_mask, x.dtype) -class 
TestFusedMultiTransformerOpPreCache(TestFusedMultiTransformerOp): - def config(self): - super().config() - self.has_pre_cache = True - self.x_type = np.float16 + qkv_weights, qkv_biases = [], [] + out_weights, out_biases = [], [] + ln_scales, ln_biases = [], [] + ffn1_weights, ffn1_biases = [], [] + ffn2_weights, ffn2_biases = [], [] + ffn_ln_scales, ffn_ln_biases = [], [] + for i in range(self.layers): + qkv_weights.append(qkv_weight_tensor) + qkv_biases.append(qkv_bias_tensor) + out_weights.append(out_linear_weight) + out_biases.append(out_linear_bias) + ln_scales.append(ln_scale) + ln_biases.append(ln_bias) + ffn1_weights.append(ffn1_weight) + ffn1_biases.append(ffn1_bias) + ffn2_weights.append(ffn2_weight) + ffn2_biases.append(ffn2_bias) + ffn_ln_scales.append(ffn_ln_scale) + ffn_ln_biases.append(ffn_ln_bias) + if self.has_cache_kv: + cache_kvs.append( + paddle.to_tensor(cache_kv, stop_gradient=False) + ) + if self.has_pre_cache: + cache_kvs.append( + paddle.to_tensor(cache_kv, stop_gradient=False) + ) + pre_caches.append( + paddle.to_tensor(self.pre_cache_kv, stop_gradient=False) + ) + + final_out = fused_multi_transformer( + x, + ln_scales, + ln_biases, + qkv_weights, + qkv_biases, + out_weights, + out_biases, + ffn_ln_scales, + ffn_ln_biases, + ffn1_weights, + ffn1_biases, + ffn2_weights, + ffn2_biases, + pre_layer_norm=self.pre_layer_norm, + epsilon=epsilon, + cache_kvs=cache_kvs, + rotary_embs=rotary_embs, + rotary_emb_dims=self.rotary_emb_dims, + pre_caches=pre_caches, + time_step=time_step, + seq_lens=seq_lens, + attn_mask=attn_mask, + dropout_rate=self.dropout_prob, + activation=self.act_method, + training=self.training, + use_neox_rotary_style=self.neox_rotary_style, + gqa_group_size=self.gqa_group_size if not self.use_fake_mha else -1, + ) + + if self.has_cache_kv: + return final_out[0], final_out[1] + + if self.has_pre_cache: + return final_out[0] + + return final_out + + def test_fused_multi_transformer_gqa_op(self): + if not self.remove_padding or self.gqa_group_size <= 0: + return + + final_out = self.GetFusedMultiTransformerGQAOut() + self.use_fake_mha = True + final_out_ref = self.GetFusedMultiTransformerGQAOut() + + if self.has_cache_kv: + final_out, cache_kv_out = final_out + s = cache_kv_out[0].shape + bsz = s[1] + num_head = s[2] + max_seq_len = s[3] + head_dim = s[4] + elems = 8 if self.x_type is np.float16 else 4 + v_elems = head_dim // elems + + if self.debug: + print("cache_k out timestep=128") + print( + cache_kv_out[0].reshape( + [2, bsz, num_head, v_elems, max_seq_len, elems] + )[0, 0, 0, :, self.cache_length, :] + ) + + print("cache_v out timestep=128") + print(cache_kv_out[0][1, 0, 0, self.cache_length, :]) + + if self.remove_padding and not self.gen_cache_kv: + # test decoder + final_out_ref, cache_kvs = final_out_ref + for i in range(self.batch_size): + for j in range(self.layers): + cache_k = cache_kv_out[j][0, :, -1] + cache_v = cache_kv_out[j][1, :, -1] + + cache_k_ref = cache_kvs[j][0, :, -1] + cache_v_ref = cache_kvs[j][1, :, -1] + np.testing.assert_allclose( + cache_k_ref, + cache_k, + rtol=self.rtol, + atol=self.atol, + ) + np.testing.assert_allclose( + cache_v_ref, + cache_v, + rtol=self.rtol, + atol=self.atol, + ) + + if self.gen_cache_kv: + final_out_ref, cache_kvs = final_out_ref + for i in range(self.layers): + cache_k_ref = cache_kvs[i][0, :, 0] + cache_v_ref = cache_kvs[i][1, :, 0] + cache_k = cache_kv_out[i][0, :, 0] + cache_v = cache_kv_out[i][1, :, 0] + if self.remove_padding: + for i in range(self.batch_size): + 
np.testing.assert_allclose( + cache_k_ref, + cache_k, + rtol=self.rtol, + atol=self.atol, + ) + np.testing.assert_allclose( + cache_v_ref, + cache_v, + rtol=self.rtol, + atol=self.atol, + ) + else: + np.testing.assert_allclose( + cache_k_ref, cache_k, rtol=self.rtol, atol=self.atol + ) + np.testing.assert_allclose( + cache_v_ref, cache_v, rtol=self.rtol, atol=self.atol + ) + if i == 0: + break + + if self.remove_padding: + for i in range(self.batch_size): + np.testing.assert_allclose( + final_out_ref[i, : self.seq_lens[i]], + final_out[i, : self.seq_lens[i]], + rtol=self.rtol, + atol=self.atol, + ) + else: + np.testing.assert_allclose( + final_out_ref, final_out, rtol=self.rtol, atol=self.atol + ) + + +class TestFusedMultiTransformerOpWithNewComm(TestFusedMultiTransformerOp): + def with_new_comm(self): + self.remove_padding = True + os.environ["FLAGS_dynamic_static_unified_comm"] = "1" + + +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) class TestFusedMultiTransformerOpVariableGenCache1(TestFusedMultiTransformerOp): def config(self): super().config() @@ -1288,6 +1623,12 @@ def config(self): self.pre_layer_norm = False +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) class TestFusedMultiTransformerOpVariableGenCache2(TestFusedMultiTransformerOp): def config(self): super().config() @@ -1295,8 +1636,15 @@ def config(self): self.gen_cache_kv = True self.remove_padding = True self.layers = 4 # even layers + self.x_type = np.float16 +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) class TestFusedMultiTransformerOpVariableGenCache3(TestFusedMultiTransformerOp): def config(self): super().config() @@ -1305,8 +1653,15 @@ def config(self): self.remove_padding = True self.layers = 4 # even layers self.rotary_emb_dims = 2 + self.x_type = np.float16 +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) class TestFusedMultiTransformerOpVariableGenCache4(TestFusedMultiTransformerOp): def config(self): super().config() @@ -1315,8 +1670,15 @@ def config(self): self.remove_padding = True self.layers = 3 # odd layers self.rotary_emb_dims = 2 + self.x_type = np.float16 +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) class TestFusedMultiTransformerOpVariableNormTransformer1( TestFusedMultiTransformerOp ): @@ -1330,6 +1692,12 @@ def config(self): self.pre_layer_norm = False +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) class TestFusedMultiTransformerOpVariableNormTransformer2( TestFusedMultiTransformerOp ): @@ -1339,8 +1707,15 @@ def config(self): self.gen_cache_kv = False self.remove_padding = True self.layers = 4 # even layers + 
self.x_type = np.float16 +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) class TestFusedMultiTransformerOpVariableDecoder1(TestFusedMultiTransformerOp): def config(self): super().config() @@ -1354,6 +1729,12 @@ def config(self): self.pre_layer_norm = False +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) class TestFusedMultiTransformerOpVariableDecoder2(TestFusedMultiTransformerOp): def config(self): super().config() @@ -1363,8 +1744,15 @@ def config(self): self.query_length = 1 self.key_length, self.value_length = 1, 1 self.layers = 4 # even layers + self.x_type = np.float16 +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) class TestFusedMultiTransformerOpVariableDecoder3(TestFusedMultiTransformerOp): def config(self): super().config() @@ -1375,13 +1763,159 @@ def config(self): self.key_length, self.value_length = 1, 1 self.layers = 4 # even layers self.rotary_emb_dims = 2 + self.x_type = np.float16 + + +# gqa test +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class TestFusedMultiTransformerOpVariableGQAGenCache1( + TestFusedMultiTransformerOp +): + def config(self): + super().config() + self.gqa_group_size = 8 + self.has_cache_kv = True + self.gen_cache_kv = True + self.remove_padding = True + self.x_type = np.float16 + self.layers = 3 # odd layers + # self.pre_layer_norm = False + + +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class TestFusedMultiTransformerOpVariableGQAGenCache2( + TestFusedMultiTransformerOp +): + def config(self): + super().config() + self.gqa_group_size = 8 + self.has_cache_kv = True + self.gen_cache_kv = True + self.remove_padding = True + self.layers = 4 # even layers + self.x_type = np.float16 + + +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class TestFusedMultiTransformerOpVariableGQAGenCache3( + TestFusedMultiTransformerOp +): + def config(self): + super().config() + self.gqa_group_size = 8 + self.has_cache_kv = True + self.gen_cache_kv = True + self.remove_padding = True + self.layers = 4 # even layers + self.rotary_emb_dims = 2 + self.x_type = np.float16 + + +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class TestFusedMultiTransformerOpVariableGQAGenCache4( + TestFusedMultiTransformerOp +): + def config(self): + super().config() + self.gqa_group_size = 8 + self.has_cache_kv = True + self.gen_cache_kv = True + self.remove_padding = True + self.layers = 3 # odd layers + self.rotary_emb_dims = 2 + 
self.x_type = np.float16 + + +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class TestFusedMultiTransformerOpVariableGQADecoder1( + TestFusedMultiTransformerOp +): + def config(self): + super().config() + self.gqa_group_size = 8 + self.has_cache_kv = True + self.gen_cache_kv = False + self.remove_padding = True + self.query_length = 1 + self.key_length, self.value_length = 1, 1 + self.x_type = np.float16 + self.layers = 3 # odd layers + self.pre_layer_norm = False + + +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class TestFusedMultiTransformerOpVariableGQADecoder2( + TestFusedMultiTransformerOp +): + def config(self): + super().config() + self.gqa_group_size = 8 + self.has_cache_kv = True + self.gen_cache_kv = False + self.remove_padding = True + self.query_length = 1 + self.key_length, self.value_length = 1, 1 + self.layers = 4 # even layers + self.x_type = np.float16 + + +@unittest.skipIf( + not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or paddle.device.cuda.get_device_capability()[0] < 8, + "FusedMultiTransformer requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class TestFusedMultiTransformerOpVariableGQADecoder3( + TestFusedMultiTransformerOp +): + def config(self): + super().config() + self.gqa_group_size = 8 + self.has_cache_kv = True + self.gen_cache_kv = False + self.remove_padding = True + self.query_length = 1 + self.key_length, self.value_length = 1, 1 + self.layers = 4 # even layers + self.rotary_emb_dims = 2 + self.x_type = np.float16 class TestFusedMultiTransformerOpPreCacheStatic1(TestFusedMultiTransformerOp): def config(self): super().config() self.has_attn_mask = False - self.x_type = np.float32 + self.x_type = np.float16 self.weight_attr = paddle.ParamAttr( initializer=paddle.nn.initializer.Constant(0.0) ) @@ -1395,19 +1929,7 @@ def config(self): initializer=paddle.nn.initializer.Constant(0.0) ) - @test_with_pir_api def test_fused_multi_transformer_op(self): - self.has_pre_cache = True - self.remove_padding = False - self.rotary_emb_dims = 2 - self.generate_input_data() - final_out_ref = self.GetBaselineOut() - final_out = self.GetFusedMultiTransformerOutStatic()[0] - - np.testing.assert_allclose( - final_out_ref, final_out, rtol=self.rtol, atol=self.atol - ) - self.has_pre_cache = False self.remove_padding = True self.generate_input_data() diff --git a/test/legacy_test/test_static_model_parallel_fused_multi_transformer.py b/test/legacy_test/test_static_model_parallel_fused_multi_transformer.py deleted file mode 100644 index 729772699d90e..0000000000000 --- a/test/legacy_test/test_static_model_parallel_fused_multi_transformer.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import unittest
-
-from test_dist_base import TestDistBase
-
-import paddle
-
-paddle.enable_static()
-flag_name = os.path.splitext(__file__)[0]
-
-
-class TestStaticModelParallel(TestDistBase):
-    def _setup_config(self):
-        self._sync_mode = True
-        self._use_reduce = False
-        self._use_reader_alloc = False
-        self._nccl_comm_num = 1
-        self._pipeline_mode = True
-
-    def test_dist_static_model_parallel_fused_multi_transformer(self):
-        from paddle import base
-
-        if (
-            base.core.is_compiled_with_cuda()
-            and not paddle.is_compiled_with_rocm()
-        ):
-            self.check_with_place(
-                "static_model_parallel_fused_multi_transformer.py",
-                delta=1e-5,
-                check_error_log=True,
-                log_name=flag_name,
-            )
-
-
-if __name__ == '__main__':
-    unittest.main()
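
On the interleaved rotary path exercised by the test (neox_rotary_style off): the reference apply_rotary_emb computes x * cos + rotate_half(x) * sin, where rotate_half maps each interleaved pair (x0, x1) to (-x1, x0) and the cos/sin tables are duplicated per pair. The following is a minimal NumPy sketch of that rotation only; the shapes are hypothetical and not taken from the test config.

    import numpy as np

    # Interleaved ("non-neox") rotary embedding: each pair (x0, x1) becomes
    # (x0*cos - x1*sin, x1*cos + x0*sin).  All sizes below are illustrative.
    bsz, heads, seq, head_dim = 1, 2, 4, 8
    x = np.random.rand(bsz, heads, seq, head_dim).astype(np.float32)

    # One angle per (position, pair); duplicate so cos/sin line up with the pairs.
    theta = np.random.rand(bsz, 1, seq, head_dim // 2).astype(np.float32)
    cos = np.repeat(np.cos(theta), 2, axis=-1)   # [..., head_dim]
    sin = np.repeat(np.sin(theta), 2, axis=-1)

    # rotate_half for the interleaved layout:
    # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    rot = np.stack([-x[..., 1::2], x[..., 0::2]], axis=-1).reshape(x.shape)

    out = x * cos + rot * sin
    print(out.shape)  # (1, 2, 4, 8)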
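
On the GQA reference path (GetFusedMultiTransformerGQAOut with use_fake_mha): the grouped-query kernel is checked against a plain multi-head run by replicating each shared KV head num_heads // kv_num_heads times in the K/V projection weights, biases and cache. A minimal NumPy sketch of just that weight-replication equivalence, with made-up sizes:

    import numpy as np

    # Hypothetical sizes, not taken from the test config.
    embed_dim, num_heads, kv_heads = 64, 8, 2
    head_dim = embed_dim // num_heads
    group = num_heads // kv_heads            # Q heads per shared KV head

    # GQA K projection weight: [embed_dim, kv_heads * head_dim]
    wk_gqa = np.random.rand(embed_dim, kv_heads * head_dim).astype(np.float32)

    # Replicate each KV head 'group' times -> MHA-shaped [embed_dim, num_heads * head_dim]
    wk_mha = (
        wk_gqa.reshape(embed_dim, kv_heads, 1, head_dim)
        .repeat(group, axis=2)
        .reshape(embed_dim, num_heads * head_dim)
    )

    x = np.random.rand(4, embed_dim).astype(np.float32)      # [tokens, embed_dim]
    k_gqa = (x @ wk_gqa).reshape(4, kv_heads, head_dim)
    k_mha = (x @ wk_mha).reshape(4, num_heads, head_dim)

    # MHA head h reads the GQA head it was copied from (h // group).
    for h in range(num_heads):
        np.testing.assert_allclose(
            k_mha[:, h], k_gqa[:, h // group], rtol=1e-5, atol=1e-6
        )

The same head-index mapping is what the stacked-and-reshaped k_proj/v_proj weights, biases and cache in the fake-MHA branch rely on.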
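
On the K-cache seeding in the test helpers: K is written as [B, kv_heads, head_dim/x, max_seq, x], where x covers 16 bytes of the element type (8 for fp16, 4 for fp32), presumably so the decoder kernel can issue 16-byte vector loads along the sequence axis, while V keeps the plain [B, kv_heads, max_seq, head_dim] layout. A small NumPy round-trip sketch with hypothetical sizes:

    import numpy as np

    # Illustrative sizes only.
    bsz, kv_heads, seq_len, max_seq, head_dim = 1, 2, 5, 128, 32
    elems = 8                      # fp16: 8 halves per 16-byte vector
    v_elems = head_dim // elems

    k = np.random.rand(bsz, kv_heads, seq_len, head_dim).astype(np.float16)

    # Pack: [B, H, S, D] -> [B, H, D/x, max_seq, x]
    cache_k = np.zeros((bsz, kv_heads, v_elems, max_seq, elems), dtype=np.float16)
    cache_k[:, :, :, :seq_len, :] = k.reshape(
        bsz, kv_heads, seq_len, v_elems, elems
    ).transpose(0, 1, 3, 2, 4)

    # Unpack one token: reverses the transpose exactly.
    tok = 3
    restored = cache_k[:, :, :, tok, :].reshape(bsz, kv_heads, head_dim)
    np.testing.assert_array_equal(restored, k[:, :, tok, :])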