
Merge pull request PaddlePaddle#52 from bob80333/main
Make flash attention compile on Windows.
tridao authored Oct 5, 2022
2 parents 0c01568 + 2211db5 commit 88dc204
Showing 4 changed files with 34 additions and 5 deletions.
5 changes: 3 additions & 2 deletions csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
@@ -2,6 +2,7 @@
 */
 
 #include "static_switch.h"
+#include "fp16_switch.h"
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_loop.h"

@@ -52,8 +53,8 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
 }
 
 void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
-    BOOL_SWITCH(params.is_bf16, IsBf16Const, [&] {
-        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
+    // work around for MSVC issue
+    FP16_SWITCH(params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (params.d == 16) {
             if( params.seqlen_k == 128 ) {
4 changes: 2 additions & 2 deletions csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
@@ -29,6 +29,7 @@
 #include <cuda_bf16.h>
 
 #include "static_switch.h"
+#include "fp16_switch.h"
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"

@@ -83,8 +84,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
 
 void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
                         const bool configure) {
-    BOOL_SWITCH(launch_params.params.is_bf16, IsBf16Const, [&] {
-        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
+    FP16_SWITCH(launch_params.params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (launch_params.params.d == 16) {
             if( launch_params.params.seqlen_k == 128 ) {
27 changes: 27 additions & 0 deletions csrc/flash_attn/src/fp16_switch.h
@@ -0,0 +1,27 @@
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h

// modified from static_switch.h
// because MSVC cannot handle std::conditional with constexpr variable

#pragma once

/// @param COND - a boolean expression to switch by
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// FP16_SWITCH(flag, [&] {
/// some_function(...);
/// });
/// ```
#define FP16_SWITCH(COND, ...)             \
  [&] {                                    \
    if (COND) {                            \
      using elem_type = __nv_bfloat16;     \
      return __VA_ARGS__();                \
    } else {                               \
      using elem_type = __half;            \
      return __VA_ARGS__();                \
    }                                      \
  }()
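
For reference, the following is a minimal, self-contained sketch (not part of this commit) of how the FP16_SWITCH pattern is used from a host-side dispatch function; run_kernel_for_type and dispatch are hypothetical stand-ins for the real kernel launchers, and the snippet assumes it is compiled with nvcc so the CUDA half/bfloat16 headers are available. Because the macro expands textually, the caller's lambda body is pasted into each branch after the elem_type alias is declared, so each branch instantiates the template with a concrete type.

// Hypothetical, minimal illustration of the FP16_SWITCH usage pattern.
// run_kernel_for_type and dispatch are stand-ins, not real flash-attn code.
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

#define FP16_SWITCH(COND, ...)             \
  [&] {                                    \
    if (COND) {                            \
      using elem_type = __nv_bfloat16;     \
      return __VA_ARGS__();                \
    } else {                               \
      using elem_type = __half;            \
      return __VA_ARGS__();                \
    }                                      \
  }()

// Stand-in for a templated kernel launcher such as run_fmha_fp16_sm80_loop_.
template <typename elem_type>
void run_kernel_for_type() {
    printf("element size = %zu bytes\n", sizeof(elem_type));
}

void dispatch(bool is_bf16) {
    // The lambda body below is duplicated by macro expansion into both branches,
    // so elem_type resolves to __nv_bfloat16 in one branch and __half in the other.
    FP16_SWITCH(is_bf16, [&] {
        run_kernel_for_type<elem_type>();
    });
}

int main() {
    dispatch(false);  // fp16 path
    dispatch(true);   // bf16 path
    return 0;
}

The design point is that each branch selects the element type with a plain type alias instead of std::conditional over a captured constexpr variable, which is the construct MSVC rejects.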
3 changes: 2 additions & 1 deletion setup.py
@@ -125,10 +125,11 @@ def append_nvcc_threads(nvcc_extra_args):
             "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
         ],
         extra_compile_args={
-            "cxx": ["-O3"] + generator_flag,
+            "cxx": ["-O3", "-std=c++17"] + generator_flag,
             "nvcc": append_nvcc_threads(
                 [
                     "-O3",
+                    "-std=c++17",
                     "-U__CUDA_NO_HALF_OPERATORS__",
                     "-U__CUDA_NO_HALF_CONVERSIONS__",
                     "--expt-relaxed-constexpr",
