diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index a53a1afbb9930..5b305fa98c01b 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -84,6 +84,23 @@ target_link_libraries(flashattn flashattn_with_bias_mask) add_dependencies(flashattn flashattn_with_bias_mask) +if (NOT DEFINED NVCC_ARCH_BIN) + message(FATAL_ERROR "NVCC_ARCH_BIN is not defined.") +endif() + +if (NVCC_ARCH_BIN STREQUAL "") + message(FATAL_ERROR "NVCC_ARCH_BIN is not set.") +endif() + +STRING(REPLACE "-" ";" FA_NVCC_ARCH_BIN ${NVCC_ARCH_BIN}) + +set(FA_GENCODE_OPTION "SHELL:") +foreach(arch ${FA_NVCC_ARCH_BIN}) + if(${arch} GREATER_EQUAL 80) + set(FA_GENCODE_OPTION "${FA_GENCODE_OPTION} -gencode arch=compute_${arch},code=sm_${arch}") + endif() +endforeach() + target_compile_options(flashattn PRIVATE $<$: -w -Xcompiler="-fPIC" @@ -96,7 +113,7 @@ target_compile_options(flashattn PRIVATE $<$: --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math - "SHELL:-gencode arch=compute_80,code=sm_80" + "${FA_GENCODE_OPTION}" >) target_compile_options(flashattn_with_bias_mask PRIVATE $<$: @@ -111,7 +128,7 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$) INSTALL(TARGETS flashattn