diff --git a/cmake/external/flashattn.cmake b/cmake/external/flashattn.cmake index a437ec46d7fe2..47d6a76f2c6c1 100644 --- a/cmake/external/flashattn.cmake +++ b/cmake/external/flashattn.cmake @@ -62,7 +62,7 @@ else() set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS}) set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG}) set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE}) - set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14") + #set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE}) set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG}) endif() diff --git a/paddle/fluid/operators/fused/attn_gemm.h b/paddle/fluid/operators/fused/attn_gemm.h index 96e64e387e693..8d7919ac06874 100644 --- a/paddle/fluid/operators/fused/attn_gemm.h +++ b/paddle/fluid/operators/fused/attn_gemm.h @@ -263,12 +263,9 @@ template class AttnMatMulWeightOnly { #if defined(PADDLE_WITH_CUTLASS) using InputType = typename phi::PDDataTypeTraits::DataType; - using GemRunnerInt8 = - phi::CutlassFpAIntBGemmRunner; + using GemRunnerInt8 = phi::CutlassFpAIntBGemmRunner; using GemRunnerInt4 = - phi::CutlassFpAIntBGemmRunner; + phi::CutlassFpAIntBGemmRunner; #endif public: // (m, n, k) = bsz_seq, output_size, input_size @@ -277,11 +274,11 @@ class AttnMatMulWeightOnly { ~AttnMatMulWeightOnly() {} // get activation - int GetActivation(const std::string &act_method) { + int GetActivation(const std::string& act_method) { #if defined(PADDLE_WITH_CUTLASS) - return static_cast(phi::getActivationType(act_method)); + return static_cast(phi::getActivationType(act_method)); #else - return 0; + return 0; #endif } void Linear(const phi::DenseTensor& x, @@ -311,33 +308,30 @@ class AttnMatMulWeightOnly { dev_ctx_.template GetWorkSpacePtr(mixgemm_workspace_size_bytes)); if (bias_data) { mixed_gemm_runner_int4_.gemm_bias_act( - reinterpret_cast( - x_data), + reinterpret_cast(x_data), reinterpret_cast(weight_data), - reinterpret_cast(weight_scale_data), - reinterpret_cast( - bias_data), - reinterpret_cast(out_data), + reinterpret_cast(weight_scale_data), + reinterpret_cast(bias_data), + reinterpret_cast(out_data), m, n, k, static_cast(act_method), mixgemm_workspace_data, mixgemm_workspace_size_bytes, - dev_ctx_.stream()); + dev_ctx_.stream()); } else { mixed_gemm_runner_int4_.gemm( - reinterpret_cast( - x_data), + reinterpret_cast(x_data), reinterpret_cast(weight_data), - reinterpret_cast(weight_scale_data), - reinterpret_cast(out_data), + reinterpret_cast(weight_scale_data), + reinterpret_cast(out_data), m, n, k, mixgemm_workspace_data, mixgemm_workspace_size_bytes, - dev_ctx_.stream()); + dev_ctx_.stream()); } } else { int mixgemm_max_size = std::max(m, k); @@ -348,27 +342,24 @@ class AttnMatMulWeightOnly { dev_ctx_.template GetWorkSpacePtr(mixgemm_workspace_size_bytes)); if (bias_data) { mixed_gemm_runner_int8_.gemm_bias_act( - reinterpret_cast( - x_data), + reinterpret_cast(x_data), reinterpret_cast(weight_data), - reinterpret_cast(weight_scale_data), - reinterpret_cast( - bias_data), - reinterpret_cast(out_data), + reinterpret_cast(weight_scale_data), + reinterpret_cast(bias_data), + reinterpret_cast(out_data), m, n, k, - static_cast(act_method), + static_cast(act_method), mixgemm_workspace_data, mixgemm_workspace_size_bytes, dev_ctx_.stream()); } else { mixed_gemm_runner_int8_.gemm( - reinterpret_cast( - x_data), + reinterpret_cast(x_data), reinterpret_cast(weight_data), - reinterpret_cast(weight_scale_data), - reinterpret_cast(out_data), + reinterpret_cast(weight_scale_data), + reinterpret_cast(out_data), m, n, k, diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu index d214dfffd1d96..6a9675d4f7d13 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu @@ -8,12 +8,16 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -//#define DEBUG_MOE_TMPROFILE +// #define DEBUG_MOE_TMPROFILE #include "paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" #ifdef DEBUG_MOE_TMPROFILE #include "paddle/fluid/platform/timer.h" #endif +#if defined(PADDLE_WITH_CUTLASS) +#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" +#endif +DECLARE_bool(enable_moe_gemm_cutlass); namespace paddle { namespace operators { @@ -106,6 +110,16 @@ class FusedMultiTransformerMoeOpKernel : public framework::OpKernel { output_size, input_size, compute_bias); +#if defined(PADDLE_WITH_CUTLASS) + using InputType = typename phi::PDDataTypeTraits::DataType; + phi::MoeGemmRunner gemm_runner; + auto default_act = phi::getActivationType("none"); + auto expert_act = phi::getActivationType(act_method); +#else + PADDLE_ENFORCE_EQ(!FLAGS_enable_moe_gemm_cutlass, + "not support cutlass fused moe gemm please disable " + "FLAGS_enable_moe_gemm_cutlass"); +#endif Tensor qkv_out; qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}}); auto *qkv_out_data = @@ -202,7 +216,10 @@ class FusedMultiTransformerMoeOpKernel : public framework::OpKernel { auto expert_weights2 = ctx.MultiInput("ExpertWeight2"); auto expert_biases2 = ctx.MultiInput("ExpertBias2"); int dim_feedforward = expert_weights1[0]->dims()[1]; - // int dim_feedforward = expert_weights1[0]->dims()[2]; // batched gemm + // gemm cutlass used ColumnMajor store + if (FLAGS_enable_moe_gemm_cutlass) { + dim_feedforward = expert_weights1[0]->dims()[0]; // batched gemm + } int topk = ctx.Attr("topk"); int mp_size = ctx.Attr("mp_size"); int mp_rank = ctx.Attr("mp_rank"); @@ -244,15 +261,19 @@ class FusedMultiTransformerMoeOpKernel : public framework::OpKernel { local_expert_count.numel() * sizeof(int64_t)); dev_ctx.Alloc(&global_expert_count, global_expert_count.numel() * sizeof(int64_t)); + // fwd_expert_count, fwd_batch_size - Tensor fwd_expert_count, fwd_batch_size; - Tensor fwd_expert_count_cpu, fwd_batch_size_cpu; + Tensor fwd_expert_count, fwd_expert_csum_len; + Tensor fwd_expert_csum_len_cpu; fwd_expert_count.Resize({{num_expert}}); - fwd_batch_size.Resize({{1}}); + fwd_expert_csum_len.Resize({{num_expert + 1}}); dev_ctx.Alloc(&fwd_expert_count, fwd_expert_count.numel() * sizeof(int64_t)); - dev_ctx.Alloc(&fwd_batch_size, - fwd_batch_size.numel() * sizeof(int64_t)); + dev_ctx.Alloc(&fwd_expert_csum_len, + fwd_expert_csum_len.numel() * sizeof(int64_t)); + phi::funcs::set_constant( + dev_ctx, &fwd_expert_csum_len, static_cast(0)); + // pos, temp pos Tensor pos, temp_pos; pos.Resize({{out_batch_size}}); @@ -326,16 +347,6 @@ class FusedMultiTransformerMoeOpKernel : public framework::OpKernel { ln_tm.Pause(); #endif } - // auto *ln_scale_data = ln_scales[i]->data(); - // auto *ln_bias_data = ln_biases[i]->data(); - // // TODO(wangxi): can remove mean var in inference - // ln_compute.ComputeForward(x_data, - // ln_scale_data, - // ln_bias_data, - // buf0.data(), - // ln_mean_data, - // ln_var_data); - // step2. qkv #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step2, qkv"; @@ -559,42 +570,26 @@ class FusedMultiTransformerMoeOpKernel : public framework::OpKernel { false, &fwd_expert_count); // fwd batch size - phi::SumKernel( - dev_ctx, - fwd_expert_count, - phi::IntArray({}), // axis is None - fwd_expert_count.dtype(), - false, - &fwd_batch_size); + phi::CumsumTensorValue( + dev_ctx, fwd_expert_count, &fwd_expert_csum_len, 1); // step4.3 cumsum & assign pos #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "moe, cumsum"; #endif - phi::CumsumKernel( - dev_ctx, local_expert_count, 0, false, false, false, &lec_cum); + phi::CumsumTensorValue(dev_ctx, local_expert_count, &lec_cum); #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "moe, assign pos"; #endif - phi::AssignPosCompute( - dev_ctx, &lec_cum, &topk_idx, &pos, out_batch_size); -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "moe, floor divide"; -#endif - if (topk > 1) { - phi::FloorDivideKernel( - dev_ctx, pos, topk_tensor, &temp_pos); - } else { - temp_pos = pos; - } + phi::AssignInsAndPosCompute( + dev_ctx, &lec_cum, &topk_idx, &pos, out_batch_size, topk, &temp_pos); + #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "moe, tensor copy"; #endif framework::TensorCopy( - fwd_expert_count, platform::CPUPlace(), &fwd_expert_count_cpu); - framework::TensorCopy( - fwd_batch_size, platform::CPUPlace(), &fwd_batch_size_cpu); + fwd_expert_csum_len, platform::CPUPlace(), &fwd_expert_csum_len_cpu); dev_ctx.Wait(); - int fwd_bsz = fwd_batch_size_cpu.data()[0]; + int fwd_bsz = fwd_expert_csum_len_cpu.data()[num_expert]; Tensor global_scatter_out; global_scatter_out.Resize({{fwd_bsz, dim_embed}}); @@ -605,9 +600,6 @@ class FusedMultiTransformerMoeOpKernel : public framework::OpKernel { all_expert_out.Resize({{fwd_bsz, dim_embed}}); dev_ctx.Alloc(&all_expert_out, all_expert_out.numel() * sizeof(T)); - // global_scatter_out.Resize({{fwd_bsz, dim_embed}}); - // all_expert_out.Resize({{fwd_bsz, dim_embed}}); - // step 5, MOEScatter // step 5.1, index select // suppose tmp_pos->shape != [0] @@ -644,76 +636,148 @@ class FusedMultiTransformerMoeOpKernel : public framework::OpKernel { #endif if (fwd_bsz != 0) { // encoder, use matmul - int last_index = 0; - for (int idx = 0; idx < num_expert; idx++) { - int cur_expert_count = fwd_expert_count_cpu.data()[idx]; - if (cur_expert_count <= 0) { - continue; - } - int end = cur_expert_count + last_index; + Tensor expert_out1; + if (FLAGS_enable_moe_gemm_cutlass) { +#if defined(PADDLE_WITH_CUTLASS) + int expert_idx = i * num_expert; + // csum length + int64_t *total_rows_before_expert = + fwd_expert_csum_len.data(); + const T *permuted_data = global_scatter_out.data(); + const T *fc1_expert_weights = expert_weights1[expert_idx]->data(); + const T *fc_scales = nullptr; + const T *fc1_expert_biases = expert_biases1[expert_idx]->data(); - Tensor expert_out1; - expert_out1.Resize({{cur_expert_count, dim_feedforward}}); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + std::ostringstream ostr; + int64_t *pnum = fwd_expert_csum_len_cpu.data(); + for (int j = 0; j <= num_expert; ++j) { + ostr << pnum[j] << ","; + } + VLOG(0) + << "layer id=" << i << ", expert_idx=" << expert_idx + << ", numel=" << fwd_expert_count.numel() + << ", dim_feedforward=" << dim_feedforward + << ", dim_embed=" << dim_embed << ", num_expert=" << num_expert + << ", global_scatter_out=" << global_scatter_out.dims() + << ", expert_weights1=" << expert_weights1[expert_idx]->dims() + << ", start ptr=" + << (int64_t)(expert_weights1[expert_idx]->data()) << ", end ptr=" + << (int64_t)(expert_weights1[expert_idx + num_expert - 1]->data()) + << ", numel=" << expert_weights1[expert_idx]->numel() + << ", expert_weights2=" << expert_weights2[expert_idx]->dims() + << ", expert nums=" << ostr.str(); +#endif + + expert_out1.Resize({{ fwd_bsz, dim_feedforward }}); dev_ctx.Alloc(&expert_out1, expert_out1.numel() * sizeof(T)); - Tensor tmp_inp = global_scatter_out.Slice(last_index, end); - int expert_idx = i * num_expert + idx; - // cuda 11.4 + T *fc1_result = expert_out1.data(); + + gemm_runner.moe_gemm_bias_act( + reinterpret_cast(permuted_data), + reinterpret_cast(fc1_expert_weights), + reinterpret_cast(fc_scales), + reinterpret_cast(fc1_expert_biases), + reinterpret_cast(fc1_result), + total_rows_before_expert, + fwd_bsz, + dim_feedforward, + dim_embed, + num_expert, + static_cast(expert_act), + dev_ctx.stream()); + + const T *fc2_expert_weights = expert_weights2[expert_idx]->data(); + const T *fc2_expert_biases = expert_biases2[expert_idx]->data(); + T *fc2_result = all_expert_out.data(); + + gemm_runner.moe_gemm_bias_act( + reinterpret_cast(fc1_result), + reinterpret_cast(fc2_expert_weights), + reinterpret_cast(fc_scales), + reinterpret_cast(fc2_expert_biases), + reinterpret_cast(fc2_result), + total_rows_before_expert, + fwd_bsz, + dim_embed, + dim_feedforward, + num_expert, + static_cast(default_act), + dev_ctx.stream()); +#endif + } else { + int last_index = 0; + int64_t *csum_len = fwd_expert_csum_len_cpu.data(); + for (int idx = 0; idx < num_expert; idx++) { + int end = csum_len[idx + 1]; + int cur_expert_count = end - last_index; + if (cur_expert_count <= 0) { + continue; + } + + expert_out1.Resize({{cur_expert_count, dim_feedforward}}); + dev_ctx.Alloc(&expert_out1, expert_out1.numel() * sizeof(T)); + + Tensor tmp_inp = global_scatter_out.Slice(last_index, end); + int expert_idx = i * num_expert + idx; + // cuda 11.4 #if (CUDA_VERSION >= 11040) - phi::MatMulAndAddGelu(dev_ctx, - expert_weights1[expert_idx], - &tmp_inp, - expert_biases1[expert_idx], - false, - false, - false, // dont compute bias - &expert_out1); + phi::MatMulAndAddGelu(dev_ctx, + expert_weights1[expert_idx], + &tmp_inp, + expert_biases1[expert_idx], + false, + false, + false, // dont compute bias + &expert_out1); #else - // linear1 matmul - // VLOG(0) << "moe, Expert Computation, linear1 mul"; - phi::MatMulAndAdd(dev_ctx, - expert_weights1[expert_idx], - &tmp_inp, - nullptr, - false, - false, - false, // dont compute bias - &expert_out1, - nullptr); - // bias gelu - FusedDropoutHelper fused_act_dropout_helper( - dev_ctx, cur_expert_count, dim_feedforward, dropout_param); - // VLOG(0) << "moe, Expert Computation, add bias & gelu"; - // inplace - fused_act_dropout_helper.DropoutActBias( - dev_ctx, - expert_out1.data(), - expert_biases1[expert_idx]->data(), - "gelu", - expert_out1.data(), - nullptr, - 1.0, - nullptr, - 0, - 1.0, - 1, - 127.0, - -127.0, - approximate); -#endif - // linear2 matmul & add - // VLOG(0) << "moe, Expert Computation, linear2 matmul & add"; - Tensor expert_out2 = all_expert_out.Slice(last_index, end); - phi::MatMulAndAdd(dev_ctx, - expert_weights2[expert_idx], - &expert_out1, - expert_biases2[expert_idx], - false, - false, - true, // compute bias - &expert_out2, - &expert_out2); - last_index = end; + // linear1 matmul + // VLOG(0) << "moe, Expert Computation, linear1 mul"; + phi::MatMulAndAdd(dev_ctx, + expert_weights1[expert_idx], + &tmp_inp, + nullptr, + false, + false, + false, // dont compute bias + &expert_out1, + nullptr); + // bias gelu + FusedDropoutHelper fused_act_dropout_helper( + dev_ctx, cur_expert_count, dim_feedforward, dropout_param); + // VLOG(0) << "moe, Expert Computation, add bias & gelu"; + // inplace + fused_act_dropout_helper.DropoutActBias( + dev_ctx, + expert_out1.data(), + expert_biases1[expert_idx]->data(), + "gelu", + expert_out1.data(), + nullptr, + 1.0, + nullptr, + 0, + 1.0, + 1, + 127.0, + -127.0, + approximate); +#endif + // linear2 matmul & add + // VLOG(0) << "moe, Expert Computation, linear2 matmul & add"; + Tensor expert_out2 = all_expert_out.Slice(last_index, end); + phi::MatMulAndAdd(dev_ctx, + expert_weights2[expert_idx], + &expert_out1, + expert_biases2[expert_idx], + false, + false, + true, // compute bias + &expert_out2, + &expert_out2); + last_index = end; + } } // at last, concat all expert out } else { diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_weight_only_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_moe_weight_only_op.cu index 9e996392f8f8e..37e477c96db2d 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_moe_weight_only_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_weight_only_op.cu @@ -8,16 +8,16 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#define DEBUG_PRINT_LINEAR_SHAPE -#define DEBUG_TMPROFILE_WEIGHT_ONLY +// #define DEBUG_PRINT_LINEAR_SHAPE +// #define DEBUG_TMPROFILE_WEIGHT_ONLY #include "paddle/fluid/operators/fused/fused_multi_transformer_op.h" #ifdef DEBUG_TMPROFILE_WEIGHT_ONLY #include "paddle/fluid/platform/timer.h" #endif #include "paddle/fluid/operators/fused/attn_gemm.h" +#include "paddle/fluid/operators/fused/moe_expert_gemm.h" #include "paddle/phi/common/datatype_traits.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" -#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" #include "paddle/phi/kernels/gpu/fused_moe_kernel.cu.h" #include "paddle/phi/kernels/weight_only_linear_kernel.h" @@ -26,7 +26,6 @@ PADDLE_DEFINE_EXPORTED_bool(enable_moe_gemm_cutlass, "enable moe gemm cutlass ,default false"); namespace paddle { namespace operators { - using Tensor = phi::DenseTensor; // #define _DEBUG_FUSED_MULTI_TRANSFORMER inline bool CheckFlashAttn(const phi::GPUContext &dev_ctx, @@ -105,20 +104,24 @@ class FusedMultiTransformerMoeWeightOnlyOpKernel const auto qkv_w_dims = qkv_weights[0]->dims(); int num_head = qkv_w_dims[1]; int dim_head = qkv_w_dims[2]; - if (weight_dtype == "int4") { + const bool is_int4 = (weight_dtype == "int4"); + if (is_int4) { // int4 weight: [3, num_head, dim_head / 2, dim_embed] dim_head = dim_head * 2; } int hidden_size = num_head * dim_head; int qkv_output_size = 3 * hidden_size; // weight only gemm - auto weight_only_gemm = - AttnMatMulWeightOnly(dev_ctx, (weight_dtype == "int4")); + auto weight_only_gemm = AttnMatMulWeightOnly(dev_ctx, is_int4); int default_act = weight_only_gemm.GetActivation("none"); int expert_act = weight_only_gemm.GetActivation(act_method); - using InputType = typename phi::PDDataTypeTraits::DataType; - phi::MoeGemmRunner gemm_runner; +#ifndef PADDLE_WITH_CUTLASS + PADDLE_ENFORCE_EQ(!FLAGS_enable_moe_gemm_cutlass, + "not support cutlass fused moe gemm please disable " + "FLAGS_enable_moe_gemm_cutlass"); +#endif + auto moe_expert_gemm = MoeExpertGemmWeightOnly(dev_ctx, is_int4); Tensor qkv_out; qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}}); @@ -218,7 +221,7 @@ class FusedMultiTransformerMoeWeightOnlyOpKernel // expert_weights1: int8 [dim_feedforward, dim_embed] int8 [dim_feedforward // / 2, dim_embed] int dim_feedforward = expert_weights1[0]->dims()[0]; - if (weight_dtype == "int4") { + if (is_int4) { dim_feedforward = dim_feedforward * 2; } @@ -264,19 +267,25 @@ class FusedMultiTransformerMoeWeightOnlyOpKernel dev_ctx.Alloc(&global_expert_count, global_expert_count.numel() * sizeof(int64_t)); // fwd_expert_count, fwd_batch_size - Tensor fwd_expert_count, fwd_batch_size; - Tensor fwd_expert_count_cpu, fwd_batch_size_cpu; + Tensor fwd_expert_count, fwd_expert_csum_len; + Tensor fwd_expert_csum_len_cpu; fwd_expert_count.Resize({{num_expert}}); - fwd_batch_size.Resize({{1}}); + fwd_expert_csum_len.Resize({{num_expert + 1}}); dev_ctx.Alloc(&fwd_expert_count, fwd_expert_count.numel() * sizeof(int64_t)); - dev_ctx.Alloc(&fwd_batch_size, - fwd_batch_size.numel() * sizeof(int64_t)); + dev_ctx.Alloc(&fwd_expert_csum_len, + fwd_expert_csum_len.numel() * sizeof(int64_t)); + phi::funcs::set_constant( + dev_ctx, &fwd_expert_csum_len, static_cast(0)); + // pos, temp pos - Tensor pos; + Tensor pos, ins_pos; pos.Resize({{out_batch_size}}); + ins_pos.Resize({{out_batch_size}}); dev_ctx.Alloc(&pos, pos.numel() * sizeof(int64_t)); - + if (topk > 1) { + dev_ctx.Alloc(&ins_pos, ins_pos.numel() * sizeof(int64_t)); + } // cumsum Tensor lec_cum; lec_cum.Resize({{tot_expert}}); @@ -311,7 +320,7 @@ class FusedMultiTransformerMoeWeightOnlyOpKernel dev_ctx.Alloc(&buf0, buf0.numel() * sizeof(T)); moe_out.ShareDataWith(*out); moe_out.Resize({{bsz_seq, dim_embed}}); - // expert + // expert Tensor expert_out1; Tensor global_scatter_out; Tensor all_expert_out; @@ -552,7 +561,9 @@ class FusedMultiTransformerMoeWeightOnlyOpKernel gate_nccl_tm.Resume(); #endif if (world_size > 1) { +#ifdef DEBUG_PRINT_LINEAR_SHAPE VLOG(0) << "layer id=" << i << ", begin all2all"; +#endif moe_pg.AllToAll(local_expert_count, global_expert_count); } else { global_expert_count = local_expert_count; @@ -571,29 +582,18 @@ class FusedMultiTransformerMoeWeightOnlyOpKernel false, &fwd_expert_count); // fwd batch size - phi::SumKernel( - dev_ctx, - fwd_expert_count, - phi::IntArray({}), // axis is None - fwd_expert_count.dtype(), - false, - &fwd_batch_size); + phi::CumsumTensorValue( + dev_ctx, fwd_expert_count, &fwd_expert_csum_len, 1); // step4.3 cumsum & assign pos - phi::CumsumKernel( - dev_ctx, local_expert_count, 0, false, false, false, &lec_cum); - // phi::funcs::set_constant(dev_ctx, &pos, static_cast(-1)); - phi::AssignPosCompute( - dev_ctx, &lec_cum, &topk_idx, &pos, out_batch_size); - if (topk > 1) { - phi::FloorDivideKernel( - dev_ctx, pos, topk_tensor, &pos); - } - framework::TensorCopy( - fwd_expert_count, platform::CPUPlace(), &fwd_expert_count_cpu); + phi::CumsumTensorValue(dev_ctx, local_expert_count, &lec_cum); + // 1. assign pos and input ins pos + phi::AssignInsAndPosCompute( + dev_ctx, &lec_cum, &topk_idx, &pos, out_batch_size, topk, &ins_pos); + framework::TensorCopy( - fwd_batch_size, platform::CPUPlace(), &fwd_batch_size_cpu); + fwd_expert_csum_len, platform::CPUPlace(), &fwd_expert_csum_len_cpu); dev_ctx.Wait(); - int fwd_bsz = fwd_batch_size_cpu.data()[0]; + int fwd_bsz = fwd_expert_csum_len_cpu.data()[num_expert]; global_scatter_out.Resize({{fwd_bsz, dim_embed}}); dev_ctx.Alloc(&global_scatter_out, @@ -605,7 +605,7 @@ class FusedMultiTransformerMoeWeightOnlyOpKernel // step 5, MOEScatter // step 5.1, index select phi::IndexSelectKernel( - dev_ctx, sliced_inp, pos, 0, &index_select_out); + dev_ctx, sliced_inp, ins_pos, 0, &index_select_out); #ifdef DEBUG_TMPROFILE_WEIGHT_ONLY dev_ctx.Wait(); gate_tm.Pause(); @@ -614,7 +614,10 @@ class FusedMultiTransformerMoeWeightOnlyOpKernel scatter_tm.Resume(); #endif if (world_size > 1) { - VLOG(0) << "layer id=" << i << ", begin scatter x=" << index_select_out.dims(); +#ifdef DEBUG_PRINT_LINEAR_SHAPE + VLOG(0) << "layer id=" << i + << ", begin scatter x=" << index_select_out.dims(); +#endif moe_pg.Scatter(&index_select_out, local_expert_count, global_expert_count, @@ -630,10 +633,9 @@ class FusedMultiTransformerMoeWeightOnlyOpKernel expert_tm.Resume(); #endif #ifdef DEBUG_PRINT_LINEAR_SHAPE - VLOG(0) << "layer id=" << i - << ", begin expert fwd_bsz=" << fwd_bsz - << ", dim_feedforward=" << dim_feedforward - << ", dim_embed=" << dim_embed; + VLOG(0) << "layer id=" << i << ", begin expert fwd_bsz=" << fwd_bsz + << ", dim_feedforward=" << dim_feedforward + << ", dim_embed=" << dim_embed; #endif // step 6, Expert Computation if (fwd_bsz != 0) { @@ -641,68 +643,54 @@ class FusedMultiTransformerMoeWeightOnlyOpKernel if (FLAGS_enable_moe_gemm_cutlass) { int expert_idx = i * num_expert; #ifdef DEBUG_PRINT_LINEAR_SHAPE - VLOG(0) << "layer id=" << i << ", expert_idx=" << expert_idx - << ", numel=" << fwd_expert_count.numel() - << ", dim_feedforward=" << dim_feedforward - << ", dim_embed=" << dim_embed - << ", num_expert=" << num_expert - << ", global_scatter_out=" << global_scatter_out.dims() - << ", expert_weights1=" << expert_weights1[expert_idx]->dims() - << ", expert_weights2=" << expert_weights2[expert_idx]->dims(); -#endif - int64_t *total_rows_before_expert = fwd_expert_count.data(); - const T *permuted_data = global_scatter_out.data(); - const int8_t *fc1_expert_weights = - expert_weights1[expert_idx]->data(); - const T *fc1_scales = expert_scales1[expert_idx]->data(); - const T *fc1_expert_biases = expert_biases1[expert_idx]->data(); - + std::ostringstream ostr; + int64_t *pnum = fwd_expert_csum_len_cpu.data(); + for (int j = 0; j <= num_expert; ++j) { + ostr << pnum[j] << ","; + } + VLOG(0) + << "layer id=" << i << ", expert_idx=" << expert_idx + << ", numel=" << fwd_expert_count.numel() + << ", dim_feedforward=" << dim_feedforward + << ", dim_embed=" << dim_embed << ", num_expert=" << num_expert + << ", global_scatter_out=" << global_scatter_out.dims() + << ", expert_weights1=" << expert_weights1[expert_idx]->dims() + << ", start ptr=" + << (int64_t)(expert_weights1[expert_idx]->data()) << ", end ptr=" + << (int64_t)(expert_weights1[expert_idx + num_expert - 1]->data()) + << ", numel=" << expert_weights1[expert_idx]->numel() + << ", expert_weights2=" << expert_weights2[expert_idx]->dims() + << ", expert nums=" << ostr.str(); +#endif + // step 6.1, expert gemm expert_out1.Resize({{fwd_bsz, dim_feedforward}}); dev_ctx.Alloc(&expert_out1, expert_out1.numel() * sizeof(T)); - T *fc1_result = expert_out1.data(); - - gemm_runner.moe_gemm_bias_act( - reinterpret_cast(permuted_data), - reinterpret_cast(fc1_expert_weights), - reinterpret_cast(fc1_scales), - reinterpret_cast(fc1_expert_biases), - reinterpret_cast(fc1_result), - total_rows_before_expert, - fwd_bsz, - dim_feedforward, - dim_embed, - num_expert, - static_cast(expert_act), - dev_ctx.stream()); - - const int8_t *fc2_expert_weights = - expert_weights2[expert_idx]->data(); - const T *fc2_scales = expert_scales2[expert_idx]->data(); - const T *fc2_expert_biases = expert_biases2[expert_idx]->data(); - T *fc2_result = all_expert_out.data(); - - gemm_runner.moe_gemm_bias_act( - reinterpret_cast(fc1_result), - reinterpret_cast(fc2_expert_weights), - reinterpret_cast(fc2_scales), - reinterpret_cast(fc2_expert_biases), - reinterpret_cast(fc2_result), - total_rows_before_expert, - fwd_bsz, - dim_embed, - dim_feedforward, - num_expert, - static_cast(default_act), - dev_ctx.stream()); + moe_expert_gemm.moe_gemm(fwd_expert_csum_len, + global_scatter_out, + expert_weights1[expert_idx], + expert_scales1[expert_idx], + expert_biases1[expert_idx], + expert_weights2[expert_idx], + expert_scales2[expert_idx], + expert_biases2[expert_idx], + fwd_bsz, + dim_feedforward, + dim_embed, + num_expert, + expert_act, + default_act, + &expert_out1, + &all_expert_out); } else { int last_index = 0; + int64_t *csum_len = fwd_expert_csum_len_cpu.data(); for (int idx = 0; idx < num_expert; idx++) { - int cur_expert_count = fwd_expert_count_cpu.data()[idx]; + int end = csum_len[idx + 1]; + int cur_expert_count = end - last_index; if (cur_expert_count <= 0) { continue; } - int end = cur_expert_count + last_index; expert_out1.Resize({{cur_expert_count, dim_feedforward}}); dev_ctx.Alloc(&expert_out1, expert_out1.numel() * sizeof(T)); @@ -714,9 +702,13 @@ class FusedMultiTransformerMoeWeightOnlyOpKernel #ifdef DEBUG_PRINT_LINEAR_SHAPE VLOG(0) << "expert id=" << idx << ", liner1 input=" << tmp_inp.dims() - << ", weight=" << expert_weights1[expert_idx]->dims() << ", ptr=" << (int64_t)(expert_weights1[expert_idx]->data()) - << ", bias=" << expert_biases1[expert_idx]->dims() << ", ptr=" << (int64_t)(expert_biases1[expert_idx]->data()) - << ", scale=" << expert_scales1[expert_idx]->dims() << ", ptr=" << (int64_t)(expert_scales1[expert_idx]->data()) + << ", weight=" << expert_weights1[expert_idx]->dims() + << ", ptr=" + << (int64_t)(expert_weights1[expert_idx]->data()) + << ", bias=" << expert_biases1[expert_idx]->dims() + << ", ptr=" << (int64_t)(expert_biases1[expert_idx]->data()) + << ", scale=" << expert_scales1[expert_idx]->dims() + << ", ptr=" << (int64_t)(expert_scales1[expert_idx]->data()) << ", expert_out1=" << expert_out1.dims(); #endif @@ -752,7 +744,8 @@ class FusedMultiTransformerMoeWeightOnlyOpKernel default_act, // none &expert_out2); #ifdef DEBUG_PRINT_LINEAR_SHAPE - VLOG(0) << "layer id=" << i << ", expert_idx=" << expert_idx << " end"; + VLOG(0) << "layer id=" << i << ", expert_idx=" << expert_idx + << " end"; dev_ctx.Wait(); #endif last_index = end; @@ -769,8 +762,12 @@ class FusedMultiTransformerMoeWeightOnlyOpKernel #endif // step7. MOEGather if (world_size > 1) { - VLOG(0) << "layer id=" << i << ", begin gather data all_expert_out=" << all_expert_out.dims() - << ", global_gather_out=" << global_gather_out.dims() << ", pos=" << pos.dims(); +#ifdef DEBUG_PRINT_LINEAR_SHAPE + VLOG(0) << "layer id=" << i << ", begin gather data all_expert_out=" + << all_expert_out.dims() + << ", global_gather_out=" << global_gather_out.dims() + << ", pos=" << pos.dims(); +#endif moe_pg.Gather(&all_expert_out, &global_gather_out); } else { global_gather_out = all_expert_out; @@ -779,10 +776,12 @@ class FusedMultiTransformerMoeWeightOnlyOpKernel dev_ctx.Wait(); gather_tm.Pause(); #endif - VLOG(0) << "layer id=" << i - << ", begin global_gather_out=" << global_gather_out.dims() - << ", pos=" << pos.dims() - << ", moe_gather_out=" << moe_gather_out.dims(); +#ifdef DEBUG_PRINT_LINEAR_SHAPE + VLOG(0) << "layer id=" << i + << ", begin global_gather_out=" << global_gather_out.dims() + << ", pos=" << pos.dims() + << ", moe_gather_out=" << moe_gather_out.dims(); +#endif // step 7.2, local_gather or scatter phi::funcs::GPUScatterAssign( dev_ctx, global_gather_out, pos, &moe_gather_out, true); @@ -867,4 +866,4 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( fused_multi_transformer_moe_weight_only, - ops::FusedMultiTransformerMoeWeightOnlyOpKernel); + ops::FusedMultiTransformerMoeWeightOnlyOpKernel); \ No newline at end of file diff --git a/paddle/fluid/operators/fused/moe_expert_gemm.h b/paddle/fluid/operators/fused/moe_expert_gemm.h new file mode 100644 index 0000000000000..3fcf40c4a59e6 --- /dev/null +++ b/paddle/fluid/operators/fused/moe_expert_gemm.h @@ -0,0 +1,138 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/phi/core/dense_tensor.h" +#if defined(PADDLE_WITH_CUTLASS) +#include "paddle/phi/common/datatype_traits.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" +#endif +namespace paddle { +namespace operators { + +template +class MoeExpertGemmWeightOnly { +#if defined(PADDLE_WITH_CUTLASS) + using InputType = typename phi::PDDataTypeTraits::DataType; + using GemRunnerInt8 = phi::MoeGemmRunner; + using GemRunnerInt4 = phi::MoeGemmRunner; +#endif + public: + MoeExpertGemmWeightOnly(const phi::GPUContext &dev_ctx, bool is_uint4) + : dev_ctx_(dev_ctx), is_uint4_(is_uint4) {} + + ~MoeExpertGemmWeightOnly() {} + void moe_gemm(const phi::DenseTensor &expert_rows_pos, + const phi::DenseTensor &x, + const phi::DenseTensor *expert_weights1, + const phi::DenseTensor *expert_scales1, + const phi::DenseTensor *expert_biases1, + const phi::DenseTensor *expert_weights2, + const phi::DenseTensor *expert_scales2, + const phi::DenseTensor *expert_biases2, + const int fwd_bsz, + const int dim_feedforward, + const int dim_embed, + const int num_expert, + const int &act_method, // none, gelu, relu + const int &default_act, // none, gelu, relu + phi::DenseTensor *expert_out1, + phi::DenseTensor *out) { +#if defined(PADDLE_WITH_CUTLASS) + // csum length + const int64_t *total_rows_before_expert = expert_rows_pos.data(); + const T *permuted_data = x.data(); + const int8_t *fc1_expert_weights = expert_weights1->data(); + const T *fc1_scales = expert_scales1->data(); + const T *fc1_expert_biases = expert_biases1->data(); + T *fc1_result = expert_out1->data(); + + const int8_t *fc2_expert_weights = expert_weights2->data(); + const T *fc2_scales = expert_scales2->data(); + const T *fc2_expert_biases = expert_biases2->data(); + T *fc2_result = out->data(); + + if (is_uint4_) { + gemm_runner_int4_.moe_gemm_bias_act( + reinterpret_cast(permuted_data), + reinterpret_cast(fc1_expert_weights), + reinterpret_cast(fc1_scales), + reinterpret_cast(fc1_expert_biases), + reinterpret_cast(fc1_result), + const_cast(total_rows_before_expert), + fwd_bsz, + dim_feedforward, + dim_embed, + num_expert, + static_cast(act_method), + dev_ctx_.stream()); + gemm_runner_int4_.moe_gemm_bias_act( + reinterpret_cast(fc1_result), + reinterpret_cast(fc2_expert_weights), + reinterpret_cast(fc2_scales), + reinterpret_cast(fc2_expert_biases), + reinterpret_cast(fc2_result), + const_cast(total_rows_before_expert), + fwd_bsz, + dim_embed, + dim_feedforward, + num_expert, + static_cast(default_act), + dev_ctx_.stream()); + } else { + gemm_runner_int8_.moe_gemm_bias_act( + reinterpret_cast(permuted_data), + reinterpret_cast(fc1_expert_weights), + reinterpret_cast(fc1_scales), + reinterpret_cast(fc1_expert_biases), + reinterpret_cast(fc1_result), + const_cast(total_rows_before_expert), + fwd_bsz, + dim_feedforward, + dim_embed, + num_expert, + static_cast(act_method), + dev_ctx_.stream()); + + gemm_runner_int8_.moe_gemm_bias_act( + reinterpret_cast(fc1_result), + reinterpret_cast(fc2_expert_weights), + reinterpret_cast(fc2_scales), + reinterpret_cast(fc2_expert_biases), + reinterpret_cast(fc2_result), + const_cast(total_rows_before_expert), + fwd_bsz, + dim_embed, + dim_feedforward, + num_expert, + static_cast(default_act), + dev_ctx_.stream()); + } +#else + PADDLE_THROW(platform::errors::InvalidArgument( + "this machine not support weight only")); +#endif + } + + private: + const phi::GPUContext &dev_ctx_; +#if defined(PADDLE_WITH_CUTLASS) + GemRunnerInt8 gemm_runner_int8_; + GemRunnerInt4 gemm_runner_int4_; +#endif + bool is_uint4_ = false; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/tensor.cc b/paddle/fluid/pybind/tensor.cc index 870c5cac1dd6b..591657c73b15c 100644 --- a/paddle/fluid/pybind/tensor.cc +++ b/paddle/fluid/pybind/tensor.cc @@ -500,6 +500,7 @@ void BindTensor(pybind11::module &m) { // NOLINT }) .def("_share_data_with", &framework::Tensor::ShareDataWith) .def("_share_data_buffer", &framework::Tensor::ShareBufferWith) + .def("_share_buffer_with_tensors", &framework::Tensor::ShareBufferWithTensors) .def("__getitem__", PySliceTensor, py::return_value_policy::reference) .def("__str__", [](const framework::Tensor &self) { diff --git a/paddle/phi/core/dense_tensor.inl b/paddle/phi/core/dense_tensor.inl index 58ed734dafd33..ced279dd7ee50 100644 --- a/paddle/phi/core/dense_tensor.inl +++ b/paddle/phi/core/dense_tensor.inl @@ -63,7 +63,8 @@ void clear() { holder_.reset(); meta_.offset = 0; } - +// tensor list used continue memory buffer +void ShareBufferWithTensors(const std::vector& tensors); void ShareBufferWith(const DenseTensor& tensor, bool with_dtype = true); void ShareDataTypeWith(const DenseTensor& tensor) { meta_.dtype = tensor.meta().dtype; diff --git a/paddle/phi/core/dense_tensor_impl.cc b/paddle/phi/core/dense_tensor_impl.cc index 4254871d4505c..1009fd5b305a2 100644 --- a/paddle/phi/core/dense_tensor_impl.cc +++ b/paddle/phi/core/dense_tensor_impl.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/memory/memcpy.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/complex.h" #include "paddle/phi/common/float16.h" @@ -186,6 +187,35 @@ void DenseTensor::ShareBufferWith(const DenseTensor& tensor, bool with_dtype) { } } +void DenseTensor::ShareBufferWithTensors(const std::vector& tensors) { + int64_t total_num = 0; + for(auto &t : tensors) { + PADDLE_ENFORCE_EQ(t.dtype(), tensors[0].meta().dtype, "data type mismatch"); + total_num += t.numel(); + } + + size_t dtype_size = SizeOf(dtype()); + int64_t need_mem_size = total_num * dtype_size; + + auto place = tensors[0].place(); + if (holder_ == nullptr || !(holder_->place() == place) || + holder_->size() < need_mem_size + meta_.offset) { + holder_.reset(); + holder_ = paddle::memory::AllocShared(place, need_mem_size); + } + + char *ptr = reinterpret_cast(holder_->ptr()); + int64_t offset = 0; + for (size_t i = 0; i < tensors.size(); ++i) { + DenseTensor *tensor = const_cast(&tensors[i]); + int64_t data_len = tensor->numel() * dtype_size; + paddle::memory::Copy(place, (ptr + offset), place, tensor->data(), data_len); + tensor->set_offset(offset); + tensor->ResetHolder(holder_); + offset += data_len; + } +} + #define LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(dtype) \ template dtype* DenseTensor::mutable_data( \ const DDim& dims, const Place& place, size_t requested_size); \ diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 6b0bd0ac26025..ba635694d7771 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -182,7 +182,7 @@ if(WITH_CUTLASS) file( GLOB cutlass_cu RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" - "fusion/cutlass/*.cu" + # "fusion/cutlass/*.cu" # "fusion/cutlass/memory_efficient_attention/autogen/impl/*.cu" # "fusion/cutlass/memory_efficient_attention/autogen_variable/impl/*.cu" "fusion/cutlass/cutlass_kernels/*.cu" diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h index 4cbccace6b87d..bc59f8665ac37 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h @@ -41,20 +41,23 @@ inline int compute_occupancy_for_kernel() { int smem_size = static_cast(sizeof(typename GemmKernel::SharedStorage)); if (smem_size > (48 << 10)) { - cudaError_t status = - cudaFuncSetAttribute(cutlass::Kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - if (status == cudaError::cudaErrorInvalidValue) { - // Clear the error bit since we can ignore this. - // This should mean that smem_size > - // cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an - // occupancy of 0. This will cause the heuristic to ignore this - // configuration. - status = cudaGetLastError(); + cudaFuncAttributes attr; + int device = 0; + int max_smem_per_block = 0; + check_cuda_error(cudaGetDevice(&device)); + check_cuda_error(cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); + check_cuda_error( + cudaFuncGetAttributes(&attr, cutlass::Kernel)); + if (smem_size + attr.sharedSizeBytes >= + static_cast(max_smem_per_block)) { + // This should mean that + // cudaFuncSetAttribute(cutlass::Kernel, + // cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) wouldn't work. + // In that case, we return an occupancy of 0. This will cause the + // heuristic to ignore this configuration. return 0; } - check_cuda_error(status); } int max_active_blocks = -1; diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h index 9f82f3964792e..5f03318946924 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h @@ -343,7 +343,7 @@ struct MoeFCGemm { // The dummy template parameter is not used and exists so that we can compile // this code using a standard earlier than C++17. Prior to C++17, fully // specialized templates HAD to exists in a namespace - template + template struct KernelRunner { CUTLASS_DEVICE static void run_kernel(Params const& params, @@ -353,7 +353,7 @@ struct MoeFCGemm { }; template - struct KernelRunner { + struct KernelRunner { CUTLASS_DEVICE static void run_kernel(Params const& params, SharedStorage& shared_storage) { @@ -369,6 +369,9 @@ struct MoeFCGemm { using LayoutC = typename Epilogue::OutputTileIterator::Layout; static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + static_assert( + platform::is_same::value, + "B must be column major."); static_assert( platform::is_same::value && kInterleave == 1 || @@ -387,11 +390,21 @@ struct MoeFCGemm { int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + typename LayoutA::LongIndex ldm_A = gemm_k; + // typename LayoutB::LongIndex ldm_B = + // platform::is_same::value + // ? gemm_n + // : gemm_k * kInterleave; + typename LayoutB::LongIndex ldm_B = gemm_k * kInterleave; + LayoutC layout_C(0); + LayoutC layout_D(gemm_n); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + typename Mma::FragmentC accumulators; // Outer 'persistent' loop to iterate over tiles - int loop = 0; while (problem_visitor.next_tile()) { - loop++; - GemmCoord problem_size = problem_visitor.problem_size(); int32_t problem_idx = problem_visitor.problem_index(); int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); @@ -405,21 +418,168 @@ struct MoeFCGemm { // Load element pointers. Exchange pointers and strides if working on // the transpose - const int64_t rows_to_jump = - problem_idx == 0 - ? 0 - : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + const int64_t rows_to_jump = params.problem_visitor.last_row_for_problem[problem_idx]; ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; - typename LayoutA::LongIndex ldm_A = gemm_k; + char* byte_ptr_B = ((char*)params.ptr_B) + problem_idx * bytes_per_expert_matrix; + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{threadblock_offset.m(), 0}; + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(LayoutA(ldm_A), + ptr_A, + {problem_size.m(), problem_size.k()}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + LayoutB(ldm_B), + ptr_B, + {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, + thread_idx, + tb_offset_B); + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + // + // Matrix multiply phase + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the + // previous tile. + __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + Mma mma(shared_storage.main_loop, + thread_idx, + warp_idx, + lane_idx); + mma(gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + // + // Epilogue + // + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = + reinterpret_cast(params.ptr_C) + problem_idx * gemm_n; + ElementC* ptr_D = + reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params_C, + ptr_C, + problem_size.mn(), + thread_idx, + threadblock_offset.mn()); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params_D, + ptr_D, + problem_size.mn(), + thread_idx, + threadblock_offset.mn()); + + Epilogue epilogue( + shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + // Next tile + problem_visitor.advance(gridDim.x); + } + } + }; + + template + struct KernelRunner { + CUTLASS_DEVICE + static void run_kernel(Params const& params, + SharedStorage& shared_storage) { + // + // These types shadow the type-level definitions and support the ability + // to implement a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + static constexpr int kInterleave = + Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + static_assert( + platform::is_same::value, + "B must be column major."); + static_assert( + platform::is_same::value && + kInterleave == 1 || + platform::is_same::value && + kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + // + // Problem visitor. + // + ProblemVisitor problem_visitor( + params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + const int64_t gemm_k = params.problem_visitor.gemm_k; + const int64_t gemm_n = params.problem_visitor.gemm_n; + int64_t bytes_per_expert_matrix = + (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + + typename LayoutA::LongIndex ldm_A = gemm_k; + // typename LayoutB::LongIndex ldm_B = + // platform::is_same::value + // ? gemm_n + // : gemm_k * kInterleave; + typename LayoutB::LongIndex ldm_B = gemm_k * kInterleave; + LayoutC layout_C(0); + LayoutC layout_D(gemm_n); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + typename Mma::FragmentC accumulators; + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset( + int(cta_idx / grid_shape.n()) * Mma::Shape::kM, + int(cta_idx % grid_shape.n()) * Mma::Shape::kN, + 0); + + // Load element pointers. Exchange pointers and strides if working on + // the transpose + const int64_t rows_to_jump = params.problem_visitor.last_row_for_problem[problem_idx]; + + ElementA* ptr_A = + reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; char* byte_ptr_B = ((char*)params.ptr_B) + problem_idx * bytes_per_expert_matrix; ElementB* ptr_B = reinterpret_cast(byte_ptr_B); - typename LayoutB::LongIndex ldm_B = - platform::is_same::value - ? gemm_n - : gemm_k * kInterleave; // Compute initial location in logical coordinates cutlass::MatrixCoord tb_offset_A{ @@ -449,34 +609,15 @@ struct MoeFCGemm { thread_idx, tb_offset_B); - typename Mma::FragmentC accumulators; - accumulators.clear(); // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - int lane_idx = threadIdx.x % 32; // // Matrix multiply phase - // - - // Construct thread-scoped matrix multiply - auto CreateMMA = [&]() { - if constexpr (use_dq_gemm::value) - return Mma(shared_storage.main_loop, - params.group_size, - thread_idx, - warp_idx, - lane_idx); - else - return Mma( - shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - }; - Mma mma = CreateMMA(); - // Compute threadblock-scoped matrix multiply-add int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; @@ -485,37 +626,32 @@ struct MoeFCGemm { // previous tile. __syncthreads(); - // Compute threadblock-scoped matrix multiply-add + // Compute threadblock-scoped matrix multiply-add + // weight scale ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n(); - - if constexpr (use_dq_gemm::value) { - const MatrixCoord scale_extent = {1, problem_size.n()}; - typename Mma::IteratorScale iterator_scale( - Mma::IteratorScale::Layout(scale_extent.column()), - weight_scale_ptr, - scale_extent, - thread_idx, - tb_offset_scale); - - mma(gemm_k_iterations, - accumulators, - iterator_A, - iterator_B, - iterator_scale, - accumulators); - } else { - mma(gemm_k_iterations, - accumulators, - iterator_A, - iterator_B, - accumulators); - } + const MatrixCoord scale_extent = {1, problem_size.n()}; + typename Mma::IteratorScale iterator_scale( + Mma::IteratorScale::Layout(scale_extent.column()), + weight_scale_ptr, + scale_extent, + thread_idx, + tb_offset_scale); + Mma mma(shared_storage.main_loop, + params.group_size, + thread_idx, + warp_idx, + lane_idx); + mma(gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + iterator_scale, + accumulators); // // Epilogue // - EpilogueOutputOp output_op(params.output_op); ElementC* ptr_C = @@ -523,12 +659,6 @@ struct MoeFCGemm { ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; - LayoutC layout_C(0); - LayoutC layout_D(gemm_n); - - typename Epilogue::OutputTileIterator::Params params_C(layout_C); - typename Epilogue::OutputTileIterator::Params params_D(layout_D); - // Tile iterator loading from source tensor. typename Epilogue::OutputTileIterator iterator_C( params_C, @@ -564,24 +694,25 @@ struct MoeFCGemm { /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const& params, SharedStorage& shared_storage) { + static constexpr bool is_dq_gemm = use_dq_gemm::value; #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) static constexpr bool compile_needed = platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); + KernelRunner::run_kernel(params, shared_storage); #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) static constexpr bool compile_needed = platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); + KernelRunner::run_kernel(params, shared_storage); #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) static constexpr bool compile_needed = platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); + KernelRunner::run_kernel(params, shared_storage); #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) // TODO Update the arch to Sm90 once CUTLASS hopper specialisations are // available static constexpr bool compile_needed = platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); + KernelRunner::run_kernel(params, shared_storage); #else CUTLASS_NOT_IMPLEMENTED(); #endif diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/moe_problem_visitor.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/moe_problem_visitor.h index 7f5b6991af3aa..9754ff08a5312 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/moe_problem_visitor.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/moe_problem_visitor.h @@ -73,7 +73,7 @@ struct BaseMoeProblemVisitor { void const* workspace; int32_t tile_count; - // + //ppc // Methods // @@ -155,9 +155,8 @@ struct BaseMoeProblemVisitor { CUTLASS_HOST_DEVICE cutlass::gemm::GemmCoord problem_size(int idx) const { - const int64_t prev_problem_row = - idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; - const int64_t current_problem_row = params.last_row_for_problem[idx]; + const int64_t prev_problem_row = params.last_row_for_problem[idx]; + const int64_t current_problem_row = params.last_row_for_problem[idx + 1]; const int64_t gemm_m = current_problem_row - prev_problem_row; GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h index 0f6cbb95d935d..ac2edf925cf65 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -72,17 +72,14 @@ void generic_mixed_gemm_kernelLauncher(const T* A, size_t workspace_bytes, cudaStream_t stream, int* occupancy) { -#ifdef PADDLE_CUDA_BF16 + static_assert(cutlass::platform::is_same::value || +#ifdef PADDLE_CUDA_BF16 cutlass::platform::is_same::value || +#endif cutlass::platform::is_same::value, "Specialized for bfloat16, half, float"); -#else - static_assert(cutlass::platform::is_same::value || - cutlass::platform::is_same::value, - "Specialized for half, float"); -#endif - + static_assert( cutlass::platform::is_same::value || cutlass::platform::is_same::value || diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/generic_moe_gemm_kernelLauncher.py b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/generic_moe_gemm_kernelLauncher.py index 5b0e841b1a9c7..06ca1bda171ad 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/generic_moe_gemm_kernelLauncher.py +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/generic_moe_gemm_kernelLauncher.py @@ -85,28 +85,47 @@ WeightTypes = { "fp16": ["half", "uint8_t", "cutlass::uint4b_t"], - "bf16": ["uint8_t", "cutlass::uint4b_t"]} -ThreadblockShapes = [ - "cutlass::gemm::GemmShape<32, 128, 64>", - "cutlass::gemm::GemmShape<64, 128, 64>", - "cutlass::gemm::GemmShape<128, 128, 64>", -] -WarpShapes = [ - "cutlass::gemm::GemmShape<32, 32, 64>", - "cutlass::gemm::GemmShape<64, 32, 64>", - "cutlass::gemm::GemmShape<128, 32, 64>", -] -ThreadblockShapes_sm70 = [ - "cutlass::gemm::GemmShape<32, 128, 64>", - "cutlass::gemm::GemmShape<64, 128, 64>", -] -WarpShapes_sm70 = [ - "cutlass::gemm::GemmShape<32, 32, 64>", - "cutlass::gemm::GemmShape<64, 32, 64>", -] + "bf16": ["uint8_t", "cutlass::uint4b_t"], + "float": ["float"] +} + +ThreadBlockWrapShapes = { + "float" : { + "ThreadblockShapes": [ + "cutlass::gemm::GemmShape<128, 128, 8>", + ], + "WarpShapes": [ + "cutlass::gemm::GemmShape<64, 64, 8>" + ], + }, + "half" : { + "ThreadblockShapes": [ + "cutlass::gemm::GemmShape<32, 128, 64>", + "cutlass::gemm::GemmShape<64, 128, 64>", + "cutlass::gemm::GemmShape<128, 128, 64>", + ], + "WarpShapes": [ + "cutlass::gemm::GemmShape<32, 32, 64>", + "cutlass::gemm::GemmShape<32, 64, 64>", + "cutlass::gemm::GemmShape<64, 32, 64>", + ], + }, + "default" : { + "ThreadblockShapes" : [ + "cutlass::gemm::GemmShape<32, 128, 64>", + "cutlass::gemm::GemmShape<64, 128, 64>", + "cutlass::gemm::GemmShape<128, 128, 64>", + ], + "WarpShapes": [ + "cutlass::gemm::GemmShape<32, 32, 64>", + "cutlass::gemm::GemmShape<64, 32, 64>", + "cutlass::gemm::GemmShape<128, 32, 64>", + ] + } +} StagesList = {70: [2], 80: [2, 3, 4]} -ElementTypes = {"fp16": "half", "bf16": "__nv_bfloat16"} +ElementTypes = {"fp16": "half", "bf16": "__nv_bfloat16", "float": "float"} Archs = { 70: "cutlass::arch::Sm70", 80: "cutlass::arch::Sm80", @@ -176,11 +195,12 @@ def generate_source_cu( element_type: str, WeightType:str, arch: int, epilogue_tag: str, stages: int ): all_code = CommonHead - ThreadblockShapes_arch = ThreadblockShapes - WarpShapes_arch = WarpShapes - if arch < 80: - ThreadblockShapes_arch = ThreadblockShapes_sm70 - WarpShapes_arch = WarpShapes_sm70 + ThreadblockShapes_arch = ThreadBlockWrapShapes["default"]["ThreadblockShapes"] + WarpShapes_arch = ThreadBlockWrapShapes["default"]["WarpShapes"] + if WeightType in ThreadBlockWrapShapes: + ThreadblockShapes_arch = ThreadBlockWrapShapes[WeightType]["ThreadblockShapes"] + WarpShapes_arch = ThreadBlockWrapShapes[WeightType]["WarpShapes"] + for i in range(len(ThreadblockShapes_arch)): value_dict = { "T": ElementTypes[element_type], diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels.h index ec5c0d4477729..2157146b0873c 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels.h @@ -15,6 +15,7 @@ */ #pragma once +#include #include #include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm_configs.h" #include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/activation_types.h" @@ -24,6 +25,7 @@ namespace phi { template class MoeGemmRunner { + using CutlassGemmConfigCache = typename std::unordered_map; public: MoeGemmRunner(); @@ -83,6 +85,7 @@ class MoeGemmRunner { private: int sm_; int multi_processor_count_; + CutlassGemmConfigCache config_cache_; }; } // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.cu b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.cu index 866e8d4af6779..787442795994d 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.cu +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.cu @@ -42,30 +42,30 @@ template -void dispatch_moe_gemm_config(const T* A, - const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t num_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - CutlassGemmConfig gemm_config, - int multi_processor_count, - cudaStream_t stream, - int* occupancy) { +void dispatchGemmConfig(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t num_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); switch (gemm_config.stages) { case 2: using DispatcherStages2 = moe_dispatch_stages; + WeightType, + arch, + EpilogueTag, + ThreadblockShape, + WarpShape, + 2>; DispatcherStages2::dispatch(A, B, weight_scales, @@ -83,12 +83,12 @@ void dispatch_moe_gemm_config(const T* A, break; case 3: using DispatcherStages3 = moe_dispatch_stages; + WeightType, + arch, + EpilogueTag, + ThreadblockShape, + WarpShape, + 3>; DispatcherStages3::dispatch(A, B, weight_scales, @@ -106,12 +106,12 @@ void dispatch_moe_gemm_config(const T* A, break; case 4: using DispatcherStages4 = moe_dispatch_stages; + WeightType, + arch, + EpilogueTag, + ThreadblockShape, + WarpShape, + 4>; DispatcherStages4::dispatch(A, B, weight_scales, @@ -135,41 +135,102 @@ void dispatch_moe_gemm_config(const T* A, break; } } - // This overload will handle simt gemms. It is disabled via SFINAE for tensorop. -template -void dispatch_moe_gemm_to_cutlass(const T* A, - const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t num_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - CutlassGemmConfig gemm_config, - int sm_version, - int multi_processor_count, - cudaStream_t stream, - int* occupancy) { - FT_LOG_DEBUG(__PRETTY_FUNCTION__); - +template < + typename T, + typename WeightType, + typename arch, + typename EpilogueTag, + typename std::enable_if::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + CutlassGemmConfig gemm_config, + int sm_version, + int multi_processor_count, + cudaStream_t stream, + int* occupancy) { + switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 64, 8>>( + A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + gemm_config, + multi_processor_count, + stream, + occupancy); + break; + case CutlassTileConfig::Undefined: + throw std::runtime_error("GEMM config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + throw std::runtime_error("GEMM config should have already been set by heuristic."); + break; + default: + throw std::runtime_error("Unsupported config for float MoE gemm."); + break; + } +} +// Tensorop GEMM overload +// Overload for quantize MoE GEMMs. We disable some warp configs here since they +// will not be used and we can improve compile time +template ::value && + std::is_same::value>::type* = + nullptr> +void dispatchMoeGemmToCutlass(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + CutlassGemmConfig gemm_config, + int sm_version, + int multi_processor_count, + cudaStream_t stream, + int* occupancy) { switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatch_moe_gemm_config, - cutlass::gemm::GemmShape<32, 32, 64>>( + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 32, 64>>( A, B, weight_scales, biases, C, total_rows_before_expert, - num_rows, + total_rows, gemm_n, gemm_k, num_experts, @@ -178,20 +239,20 @@ void dispatch_moe_gemm_to_cutlass(const T* A, stream, occupancy); break; - case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatch_moe_gemm_config, - cutlass::gemm::GemmShape<64, 32, 64>>( + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 64, 64>>( A, B, weight_scales, biases, C, total_rows_before_expert, - num_rows, + total_rows, gemm_n, gemm_k, num_experts, @@ -200,20 +261,20 @@ void dispatch_moe_gemm_to_cutlass(const T* A, stream, occupancy); break; - case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: - dispatch_moe_gemm_config, - cutlass::gemm::GemmShape<128, 32, 64>>( + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>( A, B, weight_scales, biases, C, total_rows_before_expert, - num_rows, + total_rows, gemm_n, gemm_k, num_experts, @@ -223,58 +284,58 @@ void dispatch_moe_gemm_to_cutlass(const T* A, occupancy); break; case CutlassTileConfig::Undefined: - throw std::runtime_error( - "[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config " - "undefined."); + throw std::runtime_error("GEMM config undefined."); break; case CutlassTileConfig::ChooseWithHeuristic: throw std::runtime_error( - "[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config should " - "have already been set by heuristic."); + "GEMM config should have already been set by heuristic."); break; default: throw std::runtime_error( - "[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] Unsupported config " - "for float MoE gemm."); + "Config is invalid for same type tensorop GEMM."); break; } } - -template -void dispatch_moe_gemm_to_cutlass_sm7x(const T* A, - const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t num_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - CutlassGemmConfig gemm_config, - int sm_version, - int multi_processor_count, - cudaStream_t stream, - int* occupancy) { - // VLOG(3)<<__PRETTY_FUNCTION__; - // Note that SIMT configs are omitted here since they are not supported for - // fpA_intB. We also only instantiate configs here where threadblockShapeM == - // warpShapeM since those usually perform the best for mixed type gemms. +// Tensorop GEMM overload +// Overload for quantize MoE GEMMs. We disable some warp configs here since they +// will not be used and we can improve compile time +template ::value && + !std::is_same::value>::type* = + nullptr> +void dispatchMoeGemmToCutlass(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + CutlassGemmConfig gemm_config, + int sm_version, + int multi_processor_count, + cudaStream_t stream, + int* occupancy) { switch (gemm_config.tile_config) { - case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: // - dispatch_moe_gemm_config, - cutlass::gemm::GemmShape<32, 32, 64>>( + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 32, 64>>( A, B, weight_scales, biases, C, total_rows_before_expert, - num_rows, + total_rows, gemm_n, gemm_k, num_experts, @@ -284,19 +345,41 @@ void dispatch_moe_gemm_to_cutlass_sm7x(const T* A, occupancy); break; case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatch_moe_gemm_config, - cutlass::gemm::GemmShape<64, 32, 64>>( + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>( A, B, weight_scales, biases, C, total_rows_before_expert, - num_rows, + total_rows, + gemm_n, + gemm_k, + num_experts, + gemm_config, + multi_processor_count, + stream, + occupancy); + break; + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<128, 32, 64>>( + A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + total_rows, gemm_n, gemm_k, num_experts, @@ -306,18 +389,15 @@ void dispatch_moe_gemm_to_cutlass_sm7x(const T* A, occupancy); break; case CutlassTileConfig::Undefined: - throw std::runtime_error( - "[FT][dispatch_moe_gemm_to_cutlass] gemm config undefined."); + throw std::runtime_error("GEMM config undefined."); break; case CutlassTileConfig::ChooseWithHeuristic: throw std::runtime_error( - "[FT][dispatch_moe_gemm_to_cutlass] gemm config should have " - "already been set by heuristic."); + "GEMM config should have already been set by heuristic."); break; default: throw std::runtime_error( - "[FT][dispatch_moe_gemm_to_cutlass] Config is invalid for mixed type " - "GEMM."); + "Config is invalid for mixed type tensorop GEMM."); break; } } @@ -352,24 +432,22 @@ void MoeGemmRunner::dispatch_to_arch( if (sm_ >= 70 && sm_ < 75) { #if defined(USE_FPAINTB_GEMM_WITH_SM70) - dispatch_moe_gemm_to_cutlass_sm7x(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - num_rows, - gemm_n, - gemm_k, - num_experts, - gemm_config, - sm_, - multi_processor_count_, - stream, - occupancy); + dispatchMoeGemmToCutlass( + A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + num_rows, + gemm_n, + gemm_k, + num_experts, + gemm_config, + sm_, + multi_processor_count_, + stream, + occupancy); #else throw std::runtime_error( "[MoeGemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS mixed " @@ -378,46 +456,42 @@ void MoeGemmRunner::dispatch_to_arch( } #if defined(USE_FPAINTB_GEMM_WITH_SM75) else if (sm_ >= 75 && sm_ < 80) { - dispatch_moe_gemm_to_cutlass_sm7x(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - num_rows, - gemm_n, - gemm_k, - num_experts, - gemm_config, - sm_, - multi_processor_count_, - stream, - occupancy); + dispatchMoeGemmToCutlass( + A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + num_rows, + gemm_n, + gemm_k, + num_experts, + gemm_config, + sm_, + multi_processor_count_, + stream, + occupancy); } #endif #if defined(USE_FPAINTB_GEMM_WITH_SM80) || defined(USE_FPAINTB_GEMM_WITH_SM90) else if (sm_ >= 80 && sm_ <= 90) { - dispatch_moe_gemm_to_cutlass(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - num_rows, - gemm_n, - gemm_k, - num_experts, - gemm_config, - sm_, - multi_processor_count_, - stream, - occupancy); + dispatchMoeGemmToCutlass( + A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + num_rows, + gemm_n, + gemm_k, + num_experts, + gemm_config, + sm_, + multi_processor_count_, + stream, + occupancy); } #endif else { @@ -441,42 +515,52 @@ void MoeGemmRunner::run_gemm( int num_experts, cudaStream_t stream) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); - static constexpr bool is_weight_only = !std::is_same::value; - static constexpr bool only_simt_configs = std::is_same::value; + int64_t key = + static_cast((static_cast(num_experts) << 44) | + static_cast(gemm_n) << 22 | gemm_k); + CutlassGemmConfig chosen_config; + auto it = config_cache_.find(key); + if (it == config_cache_.end()) { + static constexpr bool is_weight_only = !std::is_same::value; + static constexpr bool only_simt_configs = std::is_same::value; - static constexpr int workspace_bytes = 0; // No workspace for MoE GEMMs. - static constexpr int split_k_limit = 1; // MoE GEMM does not support split-k. - std::vector candidate_configs = - get_candidate_configs(sm_, is_weight_only, only_simt_configs, false, split_k_limit); - std::vector occupancies(candidate_configs.size()); + static constexpr int workspace_bytes = 0; // No workspace for MoE GEMMs. + static constexpr int split_k_limit = + 1; // MoE GEMM does not support split-k. + std::vector candidate_configs = get_candidate_configs( + sm_, is_weight_only, only_simt_configs, false, split_k_limit); + std::vector occupancies(candidate_configs.size()); - for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { - dispatch_to_arch(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - total_rows, - gemm_n, - gemm_k, - num_experts, - candidate_configs[ii], - stream, - &occupancies[ii]); - } + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { + dispatch_to_arch(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + candidate_configs[ii], + stream, + &occupancies[ii]); + } - CutlassGemmConfig chosen_config = - estimate_best_config_from_occupancies(candidate_configs, - occupancies, - total_rows, - gemm_n, - gemm_k, - num_experts, - split_k_limit, - workspace_bytes, - multi_processor_count_, - is_weight_only); + chosen_config = + estimate_best_config_from_occupancies(candidate_configs, + occupancies, + total_rows, + gemm_n, + gemm_k, + num_experts, + split_k_limit, + workspace_bytes, + multi_processor_count_, + is_weight_only); + } else { + chosen_config = it->second; + } dispatch_to_arch(A, B, @@ -588,21 +672,22 @@ void MoeGemmRunner::moe_gemm(const T* A, cudaStream_t stream) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); run_gemm(A, - B, - weight_scales, - nullptr, - C, - total_rows_before_expert, - total_rows, - gemm_n, - gemm_k, - num_experts, - stream); + B, + weight_scales, + nullptr, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + stream); } template class MoeGemmRunner; template class MoeGemmRunner; template class MoeGemmRunner; +template class MoeGemmRunner; #ifdef PADDLE_CUDA_BF16 template class MoeGemmRunner<__nv_bfloat16, uint8_t>; template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>; diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h index 4b70d53d7b5fa..c9c88d579a348 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h @@ -382,36 +382,19 @@ template -void dispatch_moe_gemm_config(const T* A, - const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t num_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - CutlassGemmConfig gemm_config, - int multi_processor_count, - cudaStream_t stream, - int* occupancy); - -template -void dispatch_moe_gemm_to_cutlass(const T* A, - const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - CutlassGemmConfig gemm_config, - int sm_version, - int multi_processor_count, - cudaStream_t stream, - int* occupancy); +void dispatchGemmConfig(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t num_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy); } // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/moe/moe_kernel_impl.h b/paddle/phi/kernels/fusion/cutlass/moe/moe_kernel_impl.h index 7d63e74fb9370..bd5f311a38276 100644 --- a/paddle/phi/kernels/fusion/cutlass/moe/moe_kernel_impl.h +++ b/paddle/phi/kernels/fusion/cutlass/moe/moe_kernel_impl.h @@ -179,7 +179,7 @@ size_t getWorkspaceSize(const int num_rows, total_ws_bytes += sizeof(int) * num_experts * k; // permuted_experts_ total_ws_bytes += buf_size * sizeof(T); // permuted_data_ total_ws_bytes += - padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ + (padded_experts + 1) * sizeof(int64_t); // Hold total_rows_before_expert_ add 1 for the first element 0 total_ws_bytes += sizeof(T) * num_moe_inputs; // attr_mask: [e, n] total_ws_bytes += sizeof(T) * padded_num_moe_inputs; // sorted_softmax_output @@ -222,7 +222,10 @@ __global__ void initialize_expert_choice_route_kernel( attr_mask[start + i] = (T)1.0f; } if (threadIdx.x == 0) { - total_rows_before_expert[blockIdx.x] = batch_size * k * (blockIdx.x + 1); + if (blockIdx.x == 0) { + total_rows_before_expert[0] = 0; + } + total_rows_before_expert[blockIdx.x + 1] = batch_size * k * (blockIdx.x + 1); } } diff --git a/paddle/phi/kernels/fusion/cutlass/moe_kernel.cu b/paddle/phi/kernels/fusion/cutlass/moe_kernel.cu index 1c4cd2bd87b25..ba77cc3329e5e 100644 --- a/paddle/phi/kernels/fusion/cutlass/moe_kernel.cu +++ b/paddle/phi/kernels/fusion/cutlass/moe_kernel.cu @@ -329,7 +329,7 @@ void MoeKernel(const Context& ctx, total_rows_before_expert = reinterpret_cast(padded_expert_scales + padded_num_moe_inputs); sorted_softmax_output = - reinterpret_cast(total_rows_before_expert + padded_experts); + reinterpret_cast(total_rows_before_expert + padded_experts + 1); attr_mask = reinterpret_cast(sorted_softmax_output + padded_num_moe_inputs); fc1_result = reinterpret_cast(attr_mask + num_moe_inputs); diff --git a/paddle/phi/kernels/gpu/fused_moe_kernel.cu.h b/paddle/phi/kernels/gpu/fused_moe_kernel.cu.h index 6ef7eb3f81b94..6946d84f1429d 100644 --- a/paddle/phi/kernels/gpu/fused_moe_kernel.cu.h +++ b/paddle/phi/kernels/gpu/fused_moe_kernel.cu.h @@ -11,7 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - +#include #include "paddle/phi/kernels/fused_moe_kernel.h" DECLARE_bool(avoid_op_randomness); @@ -40,11 +40,12 @@ __global__ void AssignPos(T* cum_count, } template -void AssignPosCompute(const phi::GPUContext &dev_ctx, - framework::Tensor* cum_count, // (counter number) int32 | int64 - framework::Tensor* numbers, // (batch_size * seq_len, topk) int32 - framework::Tensor* out, - const int eff_num_len) { +void AssignPosCompute( + const phi::GPUContext& dev_ctx, + framework::Tensor* cum_count, // (counter number) int32 | int64 + framework::Tensor* numbers, // (batch_size * seq_len, topk) int32 + framework::Tensor* out, + const int eff_num_len) { auto place = dev_ctx.GetPlace(); auto numel = numbers->numel(); T* cum_data = const_cast(cum_count->data()); @@ -60,4 +61,71 @@ void AssignPosCompute(const phi::GPUContext &dev_ctx, AssignPos<<>>( cum_data, num_data, out_data, numel); } -} \ No newline at end of file + +template +__global__ void AssignInsAndPos(T* cum_count, + const T* numbers, + T* out, + int64_t limit, + const int topk, + T* ins_out) { + CUDA_KERNEL_LOOP(i, limit) { + auto& number_idx = numbers[i]; + if (number_idx > -1) { + T p = platform::CudaAtomicAdd(cum_count + number_idx, -1); + out[p - 1] = static_cast(i); + ins_out[p - 1] = static_cast(i / topk); + } + } +} + +template +void AssignInsAndPosCompute( + const phi::GPUContext& dev_ctx, + phi::DenseTensor* cum_count, // (counter number) int32 | int64 + const phi::DenseTensor* numbers, // (batch_size * seq_len, topk) int32 + phi::DenseTensor* out, + const int eff_num_len, + const int topk, + phi::DenseTensor* ins_out) { + auto place = dev_ctx.GetPlace(); + auto numel = numbers->numel(); + T* cum_data = const_cast(cum_count->data()); + + framework::DDim out_dims = phi::make_ddim({eff_num_len}); + auto out_data = out->mutable_data(out_dims, place); + + const T* num_data = numbers->data(); + + int blocks = NumBlocks(numel); + int threads = kNumCUDAThreads; + + if (topk > 1) { + AssignInsAndPos<<>>( + cum_data, num_data, out_data, numel, topk, ins_out->data()); + } else { + AssignPos<<>>( + cum_data, num_data, out_data, numel); + ins_out = out; + } +} +template +void CumsumTensorValue(const phi::GPUContext& dev_ctx, + const phi::DenseTensor& in, + phi::DenseTensor* out, + const int out_offset = 0) { + const T* d_in = in.data(); + T* d_out = &out->data()[out_offset]; + int num_items = in.numel(); + auto stream = dev_ctx.stream(); + + size_t temp_storage_bytes = 0; + cub::DeviceScan::InclusiveSum( + NULL, temp_storage_bytes, d_in, d_out, num_items, stream); + // Allocate temporary storage for inclusive prefix sum + void* d_temp_storage = dev_ctx.GetWorkSpacePtr(temp_storage_bytes); + // Run inclusive prefix sum + cub::DeviceScan::InclusiveSum( + d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream); +} +} // namespace phi \ No newline at end of file diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index 2bd0502d75281..b4ccb531a6474 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from paddle.nn import functional as F from paddle.incubate.nn import functional as incubate_f from paddle.nn import Layer @@ -1985,6 +1986,9 @@ def __init__( ), "Expected dim_feedforward to be greater than 0, but received {}".format( dim_feedforward ) + gemm_cutlass = (os.getenv("FLAGS_enable_moe_gemm_cutlass", "false") == "true") + if gemm_cutlass: + print("FusedMultiTransformerMoe use cutlass gemm") # only support mp/dp # for moe config self.group = moe_group @@ -2137,7 +2141,7 @@ def get_attr(attrs, idx): expert_bias2_attr = get_attr(expert_bias2_attrs, i * num_expert + j) expert_weight1 = self.create_parameter( - shape=[d_model, dim_feedforward], + shape=[d_model, dim_feedforward] if not gemm_cutlass else [dim_feedforward, d_model], attr=expert_weight1_attr, dtype=self._dtype, is_bias=False, @@ -2151,7 +2155,7 @@ def get_attr(attrs, idx): default_initializer=nn.initializer.Constant(value=0.0) ) expert_weight2 = self.create_parameter( - shape=[dim_feedforward, d_model], + shape=[dim_feedforward, d_model] if not gemm_cutlass else [d_model, dim_feedforward], attr=expert_weight2_attr, dtype=self._dtype, is_bias=False, @@ -2175,6 +2179,8 @@ def get_attr(attrs, idx): self.dropout_rate = dropout_rate self.activation = activation self.name = name + if gemm_cutlass: + self._share_expert_param(num_layers, num_expert, dim_feedforward, d_model) def forward(self, src, attn_mask=None, caches=None, seq_lens=None, beam_offset=None, time_step=None): """ @@ -2256,6 +2262,54 @@ def trans_to_fp16(l): trans_to_fp16(self.expert_biases2) self._dtype = dtype + def _share_expert_param(self, num_layers, num_expert, dim_feedforward, d_model): + """ + share_param + """ + def shard_tensor(dst_tensor, parent_tensor, pos): + tmp = parent_tensor.value().get_tensor()._slice(pos, pos + 1) + dst_tensor.value().get_tensor()._share_data_buffer(tmp, False) + #print(dst_tensor) + + self.shared_weights1, self.shared_biases1 = ParameterList(), ParameterList() + self.shared_weights2, self.shared_biases2 = ParameterList(), ParameterList() + + for i in range(num_layers): + shared_weight1 = paddle.create_parameter( + name=f"moe.expert.layer{i}.shared_weight1", + shape=[num_expert, dim_feedforward, d_model], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(value=0.0)) + shared_bias1 = paddle.create_parameter( + name=f"moe.expert.layer{i}.shared_bias1", + shape=[num_expert, dim_feedforward], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(value=0.0)) + + shared_weight2 = paddle.create_parameter( + name=f"moe.expert.layer{i}.shared_weight2", + shape=[num_expert, d_model, dim_feedforward], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(value=0.0)) + shared_bias2 = paddle.create_parameter( + name=f"moe.expert.layer{i}.shared_bias2", + shape=[num_expert, d_model], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(value=0.0)) + + for j in range(self.num_expert): + expert_idx = j + i * self.num_expert + shard_tensor(self.expert_weights1[expert_idx], shared_weight1, j) + shard_tensor(self.expert_biases1[expert_idx], shared_bias1, j) + shard_tensor(self.expert_weights2[expert_idx], shared_weight2, j) + shard_tensor(self.expert_biases2[expert_idx], shared_bias2, j) + + self.shared_weights1.append(shared_weight1) + self.shared_biases1.append(shared_bias1) + + self.shared_weights2.append(shared_weight2) + self.shared_biases2.append(shared_bias2) + class FusedMultiTransformerMoeINT8(Layer): """ @@ -2937,7 +2991,7 @@ def get_attr(attrs, idx): self.activation = activation self.name = name self._int8_decorate() - #self._share_expert_param(num_layers, num_expert, dim_feedforward, d_model, weight_int8) + self._share_expert_param(num_layers, num_expert, dim_feedforward, d_model, weight_int8) self._dtype = "int8" def forward(self, src, attn_mask=None, caches=None, seq_lens=None, beam_offset=None, time_step=None):