build flash-attn whl (PaddlePaddle#33)
* simplify code

* add files

* fix python version

* windows fixed

* del time
zhangting2020 authored Feb 26, 2024
1 parent d98d8a3 commit 4b554d0
Showing 6 changed files with 215 additions and 10 deletions.
54 changes: 47 additions & 7 deletions csrc/CMakeLists.txt
@@ -1,15 +1,24 @@
cmake_minimum_required(VERSION 3.9 FATAL_ERROR)
project(flash-attention LANGUAGES CXX CUDA)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)


find_package(Git QUIET REQUIRED)

execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
RESULT_VARIABLE GIT_SUBMOD_RESULT)

#cmake -DWITH_ADVANCED=ON
if (WITH_ADVANCED)
add_compile_definitions(PADDLE_WITH_ADVANCED)
endif()

add_definitions("-DFLASH_ATTN_WITH_TORCH=0")

set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass)
set(BINARY_DIR ${CMAKE_BINARY_DIR})

set(FA2_SOURCES_CU
flash_attn/src/cuda_utils.cu
@@ -55,6 +64,7 @@ target_include_directories(flashattn PRIVATE
flash_attn
${CUTLASS_3_DIR}/include)

if (WITH_ADVANCED)
set(FA1_SOURCES_CU
flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu
flash_attn_with_bias_and_mask/src/cuda_utils.cu
@@ -65,6 +75,12 @@ set(FA1_SOURCES_CU
flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu
flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu
flash_attn_with_bias_and_mask/src/utils.cu)
else()
set(FA1_SOURCES_CU
flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu
flash_attn_with_bias_and_mask/src/cuda_utils.cu
flash_attn_with_bias_and_mask/src/utils.cu)
endif()

add_library(flashattn_with_bias_mask STATIC
flash_attn_with_bias_and_mask/
@@ -83,18 +99,14 @@ target_link_libraries(flashattn flashattn_with_bias_mask)

add_dependencies(flashattn flashattn_with_bias_mask)

set(NVCC_ARCH_BIN 80 CACHE STRING "CUDA architectures")

if (NOT DEFINED NVCC_ARCH_BIN)
message(FATAL_ERROR "NVCC_ARCH_BIN is not defined.")
endif()

if (NVCC_ARCH_BIN STREQUAL "")
message(FATAL_ERROR "NVCC_ARCH_BIN is not set.")
endif()
message("NVCC_ARCH_BIN is set to: ${NVCC_ARCH_BIN}")

STRING(REPLACE "-" ";" FA_NVCC_ARCH_BIN ${NVCC_ARCH_BIN})

set(FA_GENCODE_OPTION "SHELL:")

foreach(arch ${FA_NVCC_ARCH_BIN})
if(${arch} GREATER_EQUAL 80)
set(FA_GENCODE_OPTION "${FA_GENCODE_OPTION} -gencode arch=compute_${arch},code=sm_${arch}")
@@ -131,7 +143,35 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$<COMPILE_LANGUAGE:CUD
"${FA_GENCODE_OPTION}"
>)


INSTALL(TARGETS flashattn
LIBRARY DESTINATION "lib")

INSTALL(FILES capi/flash_attn.h DESTINATION "include")

if (WITH_ADVANCED)
if(WIN32)
set(target_output_name "flashattn")
else()
set(target_output_name "libflashattn")
endif()
set_target_properties(flashattn PROPERTIES
OUTPUT_NAME ${target_output_name}_advanced
PREFIX ""
)

configure_file(${CMAKE_SOURCE_DIR}/env_dict.py.in ${CMAKE_SOURCE_DIR}/env_dict.py @ONLY)
set_target_properties(flashattn PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/paddle_flash_attn/
)
add_custom_target(build_whl
COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel
WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
DEPENDS flashattn
COMMENT "Running build wheel"
)

add_custom_target(default_target DEPENDS build_whl)

set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target)
endif()
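For context (not part of this commit), the new options can be exercised roughly as follows. This is a minimal Python sketch: it assumes cmake is on PATH, that it is run from the repository root, and that csrc/ and a scratch build/ directory serve as source and build trees; WITH_ADVANCED, NVCC_ARCH_BIN, and build_whl are the names introduced in the CMakeLists.txt changes above.

# Hypothetical driver script, not shipped with the commit.
import os
import subprocess

def build_flash_attn_wheel(source_dir="csrc", build_dir="build"):
    os.makedirs(build_dir, exist_ok=True)
    # Configure: enable the advanced kernels and pick GPU architectures.
    # NVCC_ARCH_BIN is "-"-separated; the CMake script only emits gencode
    # flags for architectures >= 80.
    subprocess.run(
        [
            "cmake",
            os.path.abspath(source_dir),
            "-DWITH_ADVANCED=ON",
            "-DNVCC_ARCH_BIN=80",
        ],
        cwd=build_dir,
        check=True,
    )
    # build_whl depends on the flashattn target and then runs
    # `python setup.py bdist_wheel` inside the build directory.
    subprocess.run(
        ["cmake", "--build", build_dir, "--target", "build_whl"],
        check=True,
    )

if __name__ == "__main__":
    build_flash_attn_wheel()
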
3 changes: 3 additions & 0 deletions csrc/env_dict.py.in
@@ -0,0 +1,3 @@
env_dict = {
'CMAKE_BINARY_DIR': '@CMAKE_BINARY_DIR@'
}
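This template is consumed by the configure_file() call added to CMakeLists.txt: at configure time CMake substitutes @CMAKE_BINARY_DIR@ (the @ONLY mode leaves everything else untouched), producing a csrc/env_dict.py that setup.py imports. The generated file looks roughly like this, with an illustrative path:

# Example of the generated env_dict.py; the path is hypothetical.
env_dict = {
    'CMAKE_BINARY_DIR': '/home/user/flash-attention/csrc/build'
}
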
4 changes: 2 additions & 2 deletions csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -68,8 +68,8 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(is_attn_mask, Is_attn_mask, [&] {
BOOL_SWITCH(is_deterministic, Is_deterministic, [&] {
BOOL_SWITCH_ADVANCED(is_attn_mask, Is_attn_mask, [&] {
BOOL_SWITCH_ADVANCED(is_deterministic, Is_deterministic, [&] {
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst, Is_attn_mask && !IsCausalConst, Is_deterministic>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -40,7 +40,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
BOOL_SWITCH(is_attn_mask, Is_attn_mask, [&] {
BOOL_SWITCH_ADVANCED(is_attn_mask, Is_attn_mask, [&] {
BOOL_SWITCH(is_equal_qk, Is_equal_seq_qk, [&] {
// Will only return softmax if dropout, to reduce compilation time.
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout, Is_attn_mask && !Is_causal, Is_equal_seq_qk>;
19 changes: 19 additions & 0 deletions csrc/flash_attn/src/static_switch.h
@@ -25,6 +25,25 @@
} \
}()

#ifdef PADDLE_WITH_ADVANCED
#define BOOL_SWITCH_ADVANCED(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#else
#define BOOL_SWITCH_ADVANCED(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
}()
#endif

#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
143 changes: 143 additions & 0 deletions csrc/setup.py
@@ -0,0 +1,143 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
import ast
import os
import re
import subprocess
import sys
from pathlib import Path

from env_dict import env_dict
from setuptools import setup
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel


with open("../../README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()


cur_dir = os.path.dirname(os.path.abspath(__file__))
PACKAGE_NAME = "paddle-flash-attn"


def get_platform():
"""
Returns the platform name as used in wheel filenames.
"""
if sys.platform.startswith('linux'):
return 'linux_x86_64'
elif sys.platform == 'win32':
return 'win_amd64'
else:
raise ValueError(f'Unsupported platform: {sys.platform}')


def get_cuda_version():
try:
result = subprocess.run(
['nvcc', '--version'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
if result.returncode == 0:
output_lines = result.stdout.split('\n')
for line in output_lines:
if line.startswith('Cuda compilation tools'):
cuda_version = (
line.split('release')[1].strip().split(',')[0]
)
return cuda_version
else:
print("Error:", result.stderr)

except Exception as e:
print("Error:", str(e))

return None


def get_package_version():
with open(Path(cur_dir) / "../flash_attn" / "__init__.py", "r") as f:
version_match = re.search(
r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE
)
public_version = ast.literal_eval(version_match.group(1))
local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
if local_version:
return f"{public_version}+{local_version}"
else:
return str(public_version)


def get_package_data():
binary_dir = env_dict.get("CMAKE_BINARY_DIR")
lib = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
binary_dir + '/paddle_flash_attn/*',
)
package_data = {'paddle_flash_attn': [lib]}
return package_data


class CustomWheelsCommand(_bdist_wheel):
"""
The CustomWheelsCommand plugs into the default bdist_wheel command, which is run by pip when it
cannot find an existing wheel (currently the case for all flash-attention installs). It uses
environment parameters to detect whether a pre-built, compatible wheel is already available
and, if so, short-circuits the standard full build pipeline.
"""

def run(self):
self.run_command('build_ext')
super().run()
cuda_version = get_cuda_version()
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_name = get_platform()
flash_version = get_package_version()
wheel_name = 'paddle_flash_attn'

# Determine wheel URL based on CUDA version, python version and OS
impl_tag, abi_tag, plat_tag = self.get_tag()
archive_basename = (
f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
)
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
print("Raw wheel path", wheel_path)
wheel_filename = f'{wheel_name}-{flash_version}+cu{cuda_version}-{impl_tag}-{abi_tag}-{platform_name}.whl'
os.rename(wheel_path, os.path.join(self.dist_dir, wheel_filename))


setup(
name=PACKAGE_NAME,
version=get_package_version(),
packages=['paddle_flash_attn'],
package_data=get_package_data(),
author_email="Paddle-better@baidu.com",
description="Flash Attention: Fast and Memory-Efficient Exact Attention",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/PaddlePaddle/flash-attention",
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: BSD License",
"Operating System :: Unix",
],
cmdclass={
'bdist_wheel': CustomWheelsCommand,
},
python_requires=">=3.7",
)
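For illustration only, the rename at the end of CustomWheelsCommand.run() produces wheel file names of the following shape; all values below are hypothetical placeholders, not output of this commit.

# Hypothetical values plugged into the format string used above.
wheel_name = 'paddle_flash_attn'
flash_version = '2.0.8'      # read from flash_attn/__init__.py in practice
cuda_version = '12.0'        # parsed from `nvcc --version` in practice
impl_tag, abi_tag = 'cp38', 'cp38'
platform_name = 'linux_x86_64'

wheel_filename = f'{wheel_name}-{flash_version}+cu{cuda_version}-{impl_tag}-{abi_tag}-{platform_name}.whl'
print(wheel_filename)
# paddle_flash_attn-2.0.8+cu12.0-cp38-cp38-linux_x86_64.whl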
