Commit 18d403e
add cutlass3.0, support moe expert aggregate gemm
humingqing authored and committed Jan 5, 2024
1 parent 9293d34 commit 18d403e
Showing 21 changed files with 1,201 additions and 631 deletions.
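The commit pairs a CUTLASS 3.0 upgrade with a fused GEMM for Mixture-of-Experts layers, where the per-expert matmuls are aggregated into a single launch instead of one GEMM per expert. Only two of the 21 changed files are reproduced below. As orientation, the following CPU reference is a minimal sketch of the expert-aggregate contract, assuming tokens are already grouped contiguously by expert; every name in it is illustrative rather than an API from this commit.

// Hedged sketch (not the kernel added here): a CPU reference for an
// MoE expert-aggregate GEMM. Tokens are pre-sorted by expert,
// rows_per_expert gives each expert's contiguous row count, and each
// expert owns its own [k x n] weight matrix.
#include <cstddef>
#include <vector>

void expert_aggregate_gemm(const std::vector<float>& x,   // [total_rows x k]
                           const std::vector<float>& w,   // [num_experts x k x n]
                           const std::vector<int>& rows_per_expert,
                           std::vector<float>& out,       // [total_rows x n]
                           int num_experts, int k, int n) {
  int row0 = 0;  // first row of the current expert's token block
  for (int e = 0; e < num_experts; ++e) {
    const float* we = &w[static_cast<std::size_t>(e) * k * n];
    for (int i = row0; i < row0 + rows_per_expert[e]; ++i) {
      for (int j = 0; j < n; ++j) {
        float acc = 0.f;
        for (int kk = 0; kk < k; ++kk)
          acc += x[static_cast<std::size_t>(i) * k + kk] * we[kk * n + j];
        out[static_cast<std::size_t>(i) * n + j] = acc;
      }
    }
    row0 += rows_per_expert[e];  // advance to the next expert's token block
  }
}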
2 changes: 1 addition & 1 deletion cmake/external/flashattn.cmake
@@ -62,7 +62,7 @@ else()
   set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
   set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
   set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
-  set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
+  #set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
   set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
   set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
 endif()
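The change above removes the hard-coded -std=c++14 from the flashattn C++ flags, leaving only a commented-out C++17 variant. That is consistent with the CUTLASS upgrade in this commit: CUTLASS 3.x is a C++17 codebase, so flags pinning an older standard have to go. A minimal sketch of guarding for that requirement in a translation unit (illustrative, not part of this commit):

// Illustrative guard: fail early if the active standard is too old
// for CUTLASS 3.x headers. Not part of this commit.
static_assert(__cplusplus >= 201703L,
              "CUTLASS 3.x requires C++17 or newer");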
53 changes: 22 additions & 31 deletions paddle/fluid/operators/fused/attn_gemm.h
@@ -263,12 +263,9 @@ template <typename T>
 class AttnMatMulWeightOnly {
 #if defined(PADDLE_WITH_CUTLASS)
   using InputType = typename phi::PDDataTypeTraits<T>::DataType;
-  using GemRunnerInt8 =
-      phi::CutlassFpAIntBGemmRunner<InputType,
-                                    uint8_t>;
+  using GemRunnerInt8 = phi::CutlassFpAIntBGemmRunner<InputType, uint8_t>;
   using GemRunnerInt4 =
-      phi::CutlassFpAIntBGemmRunner<InputType,
-                                    cutlass::uint4b_t>;
+      phi::CutlassFpAIntBGemmRunner<InputType, cutlass::uint4b_t>;
 #endif
  public:
   // (m, n, k) = bsz_seq, output_size, input_size
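CutlassFpAIntBGemmRunner is a weight-only-quantization GEMM: activations stay in a floating-point InputType (mapped from the Paddle dtype by PDDataTypeTraits), while weights are stored as uint8_t or packed cutlass::uint4b_t and dequantized against per-output-channel scales inside the kernel. A CPU sketch of the numeric contract, assuming a signed int8 weight in row-major [k x n] layout and one scale per output column (the real kernel's packing differs):

// Hedged sketch of what an fpA-intB GEMM computes:
// out = x * (w_q * scale), with one scale per output column.
// Layout and signedness here are illustrative assumptions.
#include <cstddef>
#include <cstdint>
#include <vector>

void fp_a_int_b_gemm_ref(const std::vector<float>& x,      // [m x k]
                         const std::vector<int8_t>& w_q,   // [k x n] quantized
                         const std::vector<float>& scale,  // [n] per-column
                         std::vector<float>& out,          // [m x n]
                         int m, int n, int k) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int kk = 0; kk < k; ++kk)
        acc += x[static_cast<std::size_t>(i) * k + kk] *
               static_cast<float>(w_q[static_cast<std::size_t>(kk) * n + j]) *
               scale[j];  // dequantize weight on the fly
      out[static_cast<std::size_t>(i) * n + j] = acc;
    }
  }
}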
@@ -277,11 +274,11 @@ class AttnMatMulWeightOnly {

   ~AttnMatMulWeightOnly() {}
   // get activation
-  int GetActivation(const std::string &act_method) {
+  int GetActivation(const std::string& act_method) {
 #if defined(PADDLE_WITH_CUTLASS)
-    return static_cast<int>(phi::getActivationType(act_method));
+    return static_cast<int>(phi::getActivationType(act_method));
 #else
-    return 0;
+    return 0;
 #endif
   }
   void Linear(const phi::DenseTensor& x,
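GetActivation forwards a string such as "gelu" to phi::getActivationType and returns the enum as an int, falling back to 0 when CUTLASS support is compiled out. A standalone analogue of that mapping; the enum values here are hypothetical, since the real ordering lives in phi:

// Illustrative stand-in for phi::getActivationType; names and the
// enum ordering are hypothetical, not taken from this commit.
#include <stdexcept>
#include <string>

enum class ActivationType { kNone = 0, kGelu = 1, kRelu = 2 };  // hypothetical

ActivationType getActivationType(const std::string& name) {
  if (name == "none") return ActivationType::kNone;
  if (name == "gelu") return ActivationType::kGelu;
  if (name == "relu") return ActivationType::kRelu;
  throw std::invalid_argument("unknown activation: " + name);
}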
@@ -311,33 +308,30 @@
           dev_ctx_.template GetWorkSpacePtr(mixgemm_workspace_size_bytes));
       if (bias_data) {
         mixed_gemm_runner_int4_.gemm_bias_act(
-            reinterpret_cast<const InputType*>(
-                x_data),
+            reinterpret_cast<const InputType*>(x_data),
             reinterpret_cast<const cutlass::uint4b_t*>(weight_data),
-            reinterpret_cast<const InputType*>(weight_scale_data),
-            reinterpret_cast<const InputType*>(
-                bias_data),
-            reinterpret_cast<InputType *>(out_data),
+            reinterpret_cast<const InputType*>(weight_scale_data),
+            reinterpret_cast<const InputType*>(bias_data),
+            reinterpret_cast<InputType*>(out_data),
             m,
             n,
             k,
             static_cast<phi::ActivationType>(act_method),
             mixgemm_workspace_data,
             mixgemm_workspace_size_bytes,
-            dev_ctx_.stream());
+            dev_ctx_.stream());
       } else {
         mixed_gemm_runner_int4_.gemm(
-            reinterpret_cast<const InputType*>(
-                x_data),
+            reinterpret_cast<const InputType*>(x_data),
             reinterpret_cast<const cutlass::uint4b_t*>(weight_data),
-            reinterpret_cast<const InputType*>(weight_scale_data),
-            reinterpret_cast<InputType *>(out_data),
+            reinterpret_cast<const InputType*>(weight_scale_data),
+            reinterpret_cast<InputType*>(out_data),
             m,
             n,
             k,
             mixgemm_workspace_data,
             mixgemm_workspace_size_bytes,
-            dev_ctx_.stream());
+            dev_ctx_.stream());
       }
     } else {
       int mixgemm_max_size = std::max(m, k);
@@ -348,27 +342,24 @@
           dev_ctx_.template GetWorkSpacePtr(mixgemm_workspace_size_bytes));
       if (bias_data) {
         mixed_gemm_runner_int8_.gemm_bias_act(
-            reinterpret_cast<const InputType*>(
-                x_data),
+            reinterpret_cast<const InputType*>(x_data),
             reinterpret_cast<const uint8_t*>(weight_data),
-            reinterpret_cast<const InputType*>(weight_scale_data),
-            reinterpret_cast<const InputType*>(
-                bias_data),
-            reinterpret_cast<InputType *>(out_data),
+            reinterpret_cast<const InputType*>(weight_scale_data),
+            reinterpret_cast<const InputType*>(bias_data),
+            reinterpret_cast<InputType*>(out_data),
             m,
             n,
             k,
-            static_cast<phi::ActivationType>(act_method),
+            static_cast<phi::ActivationType>(act_method),
             mixgemm_workspace_data,
             mixgemm_workspace_size_bytes,
             dev_ctx_.stream());
       } else {
         mixed_gemm_runner_int8_.gemm(
-            reinterpret_cast<const InputType*>(
-                x_data),
+            reinterpret_cast<const InputType*>(x_data),
             reinterpret_cast<const uint8_t*>(weight_data),
-            reinterpret_cast<const InputType*>(weight_scale_data),
-            reinterpret_cast<InputType *>(out_data),
+            reinterpret_cast<const InputType*>(weight_scale_data),
+            reinterpret_cast<InputType*>(out_data),
             m,
             n,
             k,
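Both call sites in Linear follow the same pattern: pick the int4 or int8 runner, size a scratch workspace from the larger of m and k, then call gemm_bias_act when a bias pointer is present (fusing bias and activation into the epilogue) and plain gemm otherwise. A condensed sketch of that control flow with the runner calls stubbed out; only the branching mirrors the diff:

// Condensed control flow of AttnMatMulWeightOnly::Linear in this diff.
// The runner calls are stubbed; WeightType is a hypothetical name.
#include <algorithm>
#include <cstdio>

enum class WeightType { kInt4, kInt8 };  // hypothetical

void linear_dispatch(WeightType wt, bool has_bias, int m, int k) {
  // Workspace sizing mirrors the diff: driven by the larger of m and k;
  // the actual byte count comes from the CUTLASS runner and is omitted.
  const int mixgemm_max_size = std::max(m, k);
  (void)mixgemm_max_size;
  if (wt == WeightType::kInt4) {
    if (has_bias) std::puts("int4: gemm_bias_act (bias + activation fused)");
    else          std::puts("int4: gemm");
  } else {
    if (has_bias) std::puts("int8: gemm_bias_act (bias + activation fused)");
    else          std::puts("int8: gemm");
  }
}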