Skip to content

Commit

Permalink
fix SM90 compilation error (PaddlePaddle#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
sneaxiy authored Dec 8, 2023
1 parent 0598fa2 commit 2197ddb
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,23 @@ target_link_libraries(flashattn flashattn_with_bias_mask)
add_dependencies(flashattn flashattn_with_bias_mask)


if (NOT DEFINED NVCC_ARCH_BIN)
message(FATAL_ERROR "NVCC_ARCH_BIN is not defined.")
endif()

if (NVCC_ARCH_BIN STREQUAL "")
message(FATAL_ERROR "NVCC_ARCH_BIN is not set.")
endif()

STRING(REPLACE "-" ";" FA_NVCC_ARCH_BIN ${NVCC_ARCH_BIN})

set(FA_GENCODE_OPTION "SHELL:")
foreach(arch ${FA_NVCC_ARCH_BIN})
if(${arch} GREATER_EQUAL 80)
set(FA_GENCODE_OPTION "${FA_GENCODE_OPTION} -gencode arch=compute_${arch},code=sm_${arch}")
endif()
endforeach()

target_compile_options(flashattn PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:
-w
-Xcompiler="-fPIC"
Expand All @@ -96,7 +113,7 @@ target_compile_options(flashattn PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:
--expt-relaxed-constexpr
--expt-extended-lambda
--use_fast_math
"SHELL:-gencode arch=compute_80,code=sm_80"
"${FA_GENCODE_OPTION}"
>)

target_compile_options(flashattn_with_bias_mask PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:
Expand All @@ -111,7 +128,7 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$<COMPILE_LANGUAGE:CUD
--expt-relaxed-constexpr
--expt-extended-lambda
--use_fast_math
"SHELL:-gencode arch=compute_80,code=sm_80"
"${FA_GENCODE_OPTION}"
>)

INSTALL(TARGETS flashattn
Expand Down

0 comments on commit 2197ddb

Please sign in to comment.