Revert "build flash-attn whl (PaddlePaddle#33)" (PaddlePaddle#39)
This reverts commit 4b554d0.
kircle888 authored Apr 19, 2024
1 parent 5341f98 commit ce8a8fe
Showing 6 changed files with 10 additions and 215 deletions.
54 changes: 7 additions & 47 deletions csrc/CMakeLists.txt
@@ -1,24 +1,15 @@
cmake_minimum_required(VERSION 3.9 FATAL_ERROR)
project(flash-attention LANGUAGES CXX CUDA)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)


find_package(Git QUIET REQUIRED)

execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
RESULT_VARIABLE GIT_SUBMOD_RESULT)

#cmake -DWITH_ADVANCED=ON
if (WITH_ADVANCED)
add_compile_definitions(PADDLE_WITH_ADVANCED)
endif()

add_definitions("-DFLASH_ATTN_WITH_TORCH=0")

set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass)
set(BINARY_DIR ${CMAKE_BINARY_DIR})

set(FA2_SOURCES_CU
flash_attn/src/cuda_utils.cu
@@ -64,7 +55,6 @@ target_include_directories(flashattn PRIVATE
flash_attn
${CUTLASS_3_DIR}/include)

if (WITH_ADVANCED)
set(FA1_SOURCES_CU
flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu
flash_attn_with_bias_and_mask/src/cuda_utils.cu
@@ -75,12 +65,6 @@ set(FA1_SOURCES_CU
flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu
flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu
flash_attn_with_bias_and_mask/src/utils.cu)
else()
set(FA1_SOURCES_CU
flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu
flash_attn_with_bias_and_mask/src/cuda_utils.cu
flash_attn_with_bias_and_mask/src/utils.cu)
endif()

add_library(flashattn_with_bias_mask STATIC
flash_attn_with_bias_and_mask/
@@ -99,14 +83,18 @@ target_link_libraries(flashattn flashattn_with_bias_mask)

add_dependencies(flashattn flashattn_with_bias_mask)

set(NVCC_ARCH_BIN 80 CACHE STRING "CUDA architectures")

message("NVCC_ARCH_BIN is set to: ${NVCC_ARCH_BIN}")
if (NOT DEFINED NVCC_ARCH_BIN)
message(FATAL_ERROR "NVCC_ARCH_BIN is not defined.")
endif()

if (NVCC_ARCH_BIN STREQUAL "")
message(FATAL_ERROR "NVCC_ARCH_BIN is not set.")
endif()

STRING(REPLACE "-" ";" FA_NVCC_ARCH_BIN ${NVCC_ARCH_BIN})

set(FA_GENCODE_OPTION "SHELL:")

foreach(arch ${FA_NVCC_ARCH_BIN})
if(${arch} GREATER_EQUAL 80)
set(FA_GENCODE_OPTION "${FA_GENCODE_OPTION} -gencode arch=compute_${arch},code=sm_${arch}")
@@ -143,35 +131,7 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$<COMPILE_LANGUAGE:CUD
"${FA_GENCODE_OPTION}"
>)


INSTALL(TARGETS flashattn
LIBRARY DESTINATION "lib")

INSTALL(FILES capi/flash_attn.h DESTINATION "include")

if (WITH_ADVANCED)
if(WIN32)
set(target_output_name "flashattn")
else()
set(target_output_name "libflashattn")
endif()
set_target_properties(flashattn PROPERTIES
OUTPUT_NAME ${target_output_name}_advanced
PREFIX ""
)

configure_file(${CMAKE_SOURCE_DIR}/env_dict.py.in ${CMAKE_SOURCE_DIR}/env_dict.py @ONLY)
set_target_properties(flashattn PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/paddle_flash_attn/
)
add_custom_target(build_whl
COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel
WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
DEPENDS flashattn
COMMENT "Running build wheel"
)

add_custom_target(default_target DEPENDS build_whl)

set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target)
endif()
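
For reference, NVCC_ARCH_BIN in the hunks above is parsed as a dash-separated list of SM architectures, and only values of 80 and above are expanded into -gencode flags by the foreach loop. A minimal configure-and-build sketch that passes the variable explicitly; the build directory name and the architecture list are illustrative assumptions, not taken from this commit:

cmake -S csrc -B build -DNVCC_ARCH_BIN="80-90"
cmake --build build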
3 changes: 0 additions & 3 deletions csrc/env_dict.py.in

This file was deleted.

4 changes: 2 additions & 2 deletions csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -68,8 +68,8 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH_ADVANCED(is_attn_mask, Is_attn_mask, [&] {
BOOL_SWITCH_ADVANCED(is_deterministic, Is_deterministic, [&] {
BOOL_SWITCH(is_attn_mask, Is_attn_mask, [&] {
BOOL_SWITCH(is_deterministic, Is_deterministic, [&] {
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst, Is_attn_mask && !IsCausalConst, Is_deterministic>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -40,7 +40,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
BOOL_SWITCH_ADVANCED(is_attn_mask, Is_attn_mask, [&] {
BOOL_SWITCH(is_attn_mask, Is_attn_mask, [&] {
BOOL_SWITCH(is_equal_qk, Is_equal_seq_qk, [&] {
// Will only return softmax if dropout, to reduce compilation time.
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout, Is_attn_mask && !Is_causal, Is_equal_seq_qk>;
19 changes: 0 additions & 19 deletions csrc/flash_attn/src/static_switch.h
@@ -25,25 +25,6 @@
} \
}()

#ifdef PADDLE_WITH_ADVANCED
#define BOOL_SWITCH_ADVANCED(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#else
#define BOOL_SWITCH_ADVANCED(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
}()
#endif

#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
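
For orientation, the deleted BOOL_SWITCH_ADVANCED above follows the same dispatch pattern as BOOL_SWITCH: a runtime bool is mapped onto a constexpr constant so that both template instantiations are compiled, and the non-advanced fallback simply pins the constant to false. A self-contained sketch of that pattern follows; the names MY_BOOL_SWITCH and run are hypothetical and not code from this repository:

#include <cstdio>

// Hypothetical re-creation of the BOOL_SWITCH-style dispatch: the runtime
// condition picks one of two branches, each of which declares the flag as a
// compile-time constant before invoking the caller-supplied lambda.
#define MY_BOOL_SWITCH(COND, CONST_NAME, ...)      \
  [&] {                                            \
    if (COND) {                                    \
      constexpr static bool CONST_NAME = true;     \
      return __VA_ARGS__();                        \
    } else {                                       \
      constexpr static bool CONST_NAME = false;    \
      return __VA_ARGS__();                        \
    }                                              \
  }()

// Stand-in for a kernel launcher that is templated on the flag.
template <bool IsAttnMask>
void run() {
  std::printf("is_attn_mask = %s\n", IsAttnMask ? "true" : "false");
}

int main(int argc, char **) {
  bool is_attn_mask = (argc > 1);  // value known only at runtime
  MY_BOOL_SWITCH(is_attn_mask, IsAttnMaskConst, [&] {
    run<IsAttnMaskConst>();  // both run<true> and run<false> get instantiated
  });
  return 0;
}

In the launch templates above, this is how runtime flags such as is_attn_mask and is_deterministic are lowered into template parameters of the flash_fwd/flash_bwd kernels.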
143 changes: 0 additions & 143 deletions csrc/setup.py

This file was deleted.
