add CUTLASS 3.0, support MoE expert aggregate GEMM
humingqing authored and humingqing committed Jan 5, 2024
2 parents 97e5554 + 18d403e commit 4b57358
Showing 62 changed files with 4,491 additions and 4,148 deletions.
5 changes: 3 additions & 2 deletions cmake/cuda.cmake
@@ -290,10 +290,10 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS_EXTRA}")
 message(STATUS "NVCC_FLAGS_EXTRA: ${NVCC_FLAGS_EXTRA}, NVCC_ARCH_BIN: ${NVCC_ARCH_BIN}")

 # Set C++14 support
-set(CUDA_PROPAGATE_HOST_FLAGS OFF)
+set(CUDA_PROPAGATE_HOST_FLAGS ON)
 # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
 # So, don't set these flags here.
-set(CMAKE_CUDA_STANDARD 14)
+set(CMAKE_CUDA_STANDARD 17)

 # (Note) For windows, if delete /W[1-4], /W1 will be added defaultly and conflic with -w
 # So replace /W[1-4] with /W0
@@ -321,6 +321,7 @@ if(WIN32)
     endforeach()
   endif()
 endif()
+message(STATUS "CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")

 mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD)
 mark_as_advanced(CUDA_SDK_ROOT_DIR CUDA_SEPARABLE_COMPILATION)
4 changes: 2 additions & 2 deletions cmake/external/cutlass.cmake
@@ -16,9 +16,9 @@ include(ExternalProject)

 set(CUTLASS_PREFIX_DIR ${THIRD_PARTY_PATH}/cutlass)
 set(CUTLASS_REPOSITORY https://github.com/NVIDIA/cutlass.git)
-set(CUTLASS_TAG v2.11.0)
+set(CUTLASS_TAG v3.3.0)

-set(CUTLASS_SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/cutlass)
+set(CUTLASS_SOURCE_DIR ${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass)
 include_directories("${CUTLASS_SOURCE_DIR}/")
 include_directories("${CUTLASS_SOURCE_DIR}/include/")
 include_directories("${CUTLASS_SOURCE_DIR}/tools/util/include/")
6 changes: 3 additions & 3 deletions cmake/external/flashattn.cmake
@@ -19,14 +19,14 @@ add_definitions(-DPADDLE_WITH_FLASHATTN)
 set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
 set(FLASHATTN_SOURCE_SUBDIR csrc)
 set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
-set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/flashattn)
+set(SOURCE_DIR ${THIRD_PARTY_PATH}/flashattn/src/extern_flashattn)
 #set(FLASHATTN_TAG 0598fa245bbfb8c4462002600864518c0e37e714)
 set(FLASHATTN_TAG 705e8c69fe1511aa6abd4bfea493f24e119193ee)
 set(FLASHATTN_INCLUDE_DIR
   "${FLASHATTN_INSTALL_DIR}/include"
   CACHE PATH "flash-attn Directory" FORCE)
 set(FLASHATTN_LIB_DIR
-  "${FLASHATTN_INSTALL_DIR}/lib"
+  "${FLASHATTN_INSTALL_DIR}/lib"ex
   CACHE PATH "flash-attn Library Directory" FORCE)

 if(WIN32)
@@ -62,7 +62,7 @@ else()
   set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
   set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
   set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
-  set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
+  #set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
   set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
   set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
 endif()
2 changes: 1 addition & 1 deletion cmake/flags.cmake
@@ -35,7 +35,7 @@ endfunction()

 checkcompilercxx14flag()
 if(NOT WIN32)
-  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
 else()
   set(CMAKE_CXX_STANDARD 14)
 endif()
2 changes: 1 addition & 1 deletion cmake/third_party.cmake
@@ -495,7 +495,7 @@ if(WITH_GPU
    AND NOT WITH_ARM
    AND NOT WIN32
    AND NOT APPLE)
-  if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0)
+  if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.4)
     message(STATUS "add cutlass lib")
     include(external/cutlass) # download, build, install cutlass
     list(APPEND third_party_deps extern_cutlass)
53 changes: 22 additions & 31 deletions paddle/fluid/operators/fused/attn_gemm.h
@@ -263,12 +263,9 @@ template <typename T>
 class AttnMatMulWeightOnly {
 #if defined(PADDLE_WITH_CUTLASS)
   using InputType = typename phi::PDDataTypeTraits<T>::DataType;
-  using GemRunnerInt8 =
-      phi::CutlassFpAIntBGemmRunner<InputType,
-                                    uint8_t>;
+  using GemRunnerInt8 = phi::CutlassFpAIntBGemmRunner<InputType, uint8_t>;
   using GemRunnerInt4 =
-      phi::CutlassFpAIntBGemmRunner<InputType,
-                                    cutlass::uint4b_t>;
+      phi::CutlassFpAIntBGemmRunner<InputType, cutlass::uint4b_t>;
 #endif
  public:
   // (m, n, k) = bsz_seq, output_size, input_size
@@ -277,11 +274,11 @@ class AttnMatMulWeightOnly {

   ~AttnMatMulWeightOnly() {}
   // get activation
-  int GetActivation(const std::string &act_method) {
+  int GetActivation(const std::string& act_method) {
 #if defined(PADDLE_WITH_CUTLASS)
-  return static_cast<int>(phi::getActivationType(act_method));
+    return static_cast<int>(phi::getActivationType(act_method));
 #else
-  return 0;
+    return 0;
 #endif
   }
   void Linear(const phi::DenseTensor& x,
@@ -311,33 +308,30 @@ class AttnMatMulWeightOnly {
           dev_ctx_.template GetWorkSpacePtr(mixgemm_workspace_size_bytes));
       if (bias_data) {
         mixed_gemm_runner_int4_.gemm_bias_act(
-            reinterpret_cast<const InputType*>(
-                x_data),
+            reinterpret_cast<const InputType*>(x_data),
             reinterpret_cast<const cutlass::uint4b_t*>(weight_data),
-        reinterpret_cast<const InputType*>(weight_scale_data),
-            reinterpret_cast<const InputType*>(
-                bias_data),
-            reinterpret_cast<InputType *>(out_data),
+            reinterpret_cast<const InputType*>(weight_scale_data),
+            reinterpret_cast<const InputType*>(bias_data),
+            reinterpret_cast<InputType*>(out_data),
             m,
             n,
             k,
             static_cast<phi::ActivationType>(act_method),
             mixgemm_workspace_data,
             mixgemm_workspace_size_bytes,
-        dev_ctx_.stream());
+            dev_ctx_.stream());
       } else {
         mixed_gemm_runner_int4_.gemm(
-            reinterpret_cast<const InputType*>(
-                x_data),
+            reinterpret_cast<const InputType*>(x_data),
             reinterpret_cast<const cutlass::uint4b_t*>(weight_data),
-        reinterpret_cast<const InputType*>(weight_scale_data),
-            reinterpret_cast<InputType *>(out_data),
+            reinterpret_cast<const InputType*>(weight_scale_data),
+            reinterpret_cast<InputType*>(out_data),
             m,
             n,
             k,
             mixgemm_workspace_data,
             mixgemm_workspace_size_bytes,
-        dev_ctx_.stream());
+            dev_ctx_.stream());
       }
     } else {
       int mixgemm_max_size = std::max(m, k);
@@ -348,27 +342,24 @@ class AttnMatMulWeightOnly {
           dev_ctx_.template GetWorkSpacePtr(mixgemm_workspace_size_bytes));
       if (bias_data) {
         mixed_gemm_runner_int8_.gemm_bias_act(
-            reinterpret_cast<const InputType*>(
-                x_data),
+            reinterpret_cast<const InputType*>(x_data),
             reinterpret_cast<const uint8_t*>(weight_data),
-        reinterpret_cast<const InputType*>(weight_scale_data),
-            reinterpret_cast<const InputType*>(
-                bias_data),
-            reinterpret_cast<InputType *>(out_data),
+            reinterpret_cast<const InputType*>(weight_scale_data),
+            reinterpret_cast<const InputType*>(bias_data),
+            reinterpret_cast<InputType*>(out_data),
             m,
             n,
             k,
-        static_cast<phi::ActivationType>(act_method),
+            static_cast<phi::ActivationType>(act_method),
             mixgemm_workspace_data,
             mixgemm_workspace_size_bytes,
             dev_ctx_.stream());
       } else {
         mixed_gemm_runner_int8_.gemm(
-            reinterpret_cast<const InputType*>(
-                x_data),
+            reinterpret_cast<const InputType*>(x_data),
             reinterpret_cast<const uint8_t*>(weight_data),
-        reinterpret_cast<const InputType*>(weight_scale_data),
-            reinterpret_cast<InputType *>(out_data),
+            reinterpret_cast<const InputType*>(weight_scale_data),
+            reinterpret_cast<InputType*>(out_data),
             m,
             n,
             k,
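For context on the calls in the attn_gemm.h hunks above: CutlassFpAIntBGemmRunner performs a mixed-precision GEMM in which activations stay in fp16/bf16 while weights are stored as int8 (or packed int4) with per-output-channel scales. The following CPU reference is a minimal sketch of what that computes, assuming a row-major x[m, k], an int8 weight[k, n], and one scale per output channel; the function name, layouts, and the bool activation flag are illustrative stand-ins, not the Paddle or CUTLASS API.

#include <cmath>
#include <cstdint>

// Illustrative CPU reference for out = act(x * dequant(W) + bias), where W is
// int8 with one float scale per output channel. Hypothetical sketch only;
// the real runner takes a phi::ActivationType enum instead of a bool.
void ref_fpa_intb_gemm_bias_act(const float* x, const int8_t* weight,
                                const float* weight_scale, const float* bias,
                                float* out, int m, int n, int k, bool gelu) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) {
        // Dequantize each weight on the fly: int8 value times its channel scale.
        acc += x[i * k + p] * static_cast<float>(weight[p * n + j]) * weight_scale[j];
      }
      if (bias != nullptr) acc += bias[j];
      // Epilogue activation (exact GELU: 0.5*x*(1 + erf(x/sqrt(2)))), else identity.
      out[i * n + j] = gelu ? 0.5f * acc * (1.0f + std::erf(acc * 0.70710678f)) : acc;
    }
  }
}

The CUDA runner fuses this dequantization into the GEMM main loop, which is why the call sites above pass only the quantized weights and their scales; the buffer obtained via GetWorkSpacePtr is scratch space for that fused kernel.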
(The remaining 56 changed files in this commit are not shown here.)
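The "MoE expert aggregate GEMM" half of the commit title lives in those unloaded files. The general pattern it names is a grouped GEMM: tokens are bucketed by expert, and each expert's weight matrix multiplies only its own slice of rows. A toy CPU sketch of that aggregation, with all names and layouts hypothetical rather than taken from this commit:

#include <cstdint>

// Toy sketch of an expert-aggregate GEMM: x rows are pre-sorted by expert,
// expert_weights holds num_experts stacked [k, n] matrices, and each expert
// multiplies only its rows_per_expert[e] rows. Hypothetical names throughout.
void moe_expert_aggregate_gemm(const float* x, const float* expert_weights,
                               const int64_t* rows_per_expert, float* out,
                               int num_experts, int n, int k) {
  int64_t row_offset = 0;
  for (int e = 0; e < num_experts; ++e) {
    const float* w = expert_weights + static_cast<int64_t>(e) * k * n;
    for (int64_t i = row_offset; i < row_offset + rows_per_expert[e]; ++i) {
      for (int j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int p = 0; p < k; ++p) acc += x[i * k + p] * w[p * n + j];
        out[i * n + j] = acc;
      }
    }
    row_offset += rows_per_expert[e];
  }
}

On the GPU, aggregating the per-expert problems this way lets a single grouped-GEMM kernel launch cover all experts instead of issuing one GEMM per expert, which is the usual motivation for such a kernel in MoE inference.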
