forked from PaddlePaddle/Paddle
Commit
Try to crop the library's size. (#6)
* Hide the internal symbols and remove the hdim16 implementation.
* Remove SYMBOL_EXPORT.
* Remove some debugging statements.
* Change has_attn_bias and has_attn_mask to runtime arguments instead of template parameters (see the sketch below).
* Avoid compiling the unused .cu files.
* Remove the return_softmax-related template parameter and argument.
* Remove support for is_causal in the implementation with mask and bias.
* Reorganize code.
* Polish code.
* Add a check of softmax_scale.
Showing 20 changed files with 443 additions and 636 deletions.
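The template-to-argument change and the softmax_scale check mentioned in the commit message are not part of the diffs captured below, so the following is only a minimal sketch of the general technique. The struct and function names with a "Sketch"/"_sketch" suffix are placeholders, not identifiers from this fork, and the exact validation rule for softmax_scale is an assumption.

#include <cuda_runtime.h>

// Hypothetical, simplified parameter struct; the real FMHA_dgrad_params in
// this fork has many more fields and possibly different names.
struct FmhaParamsSketch {
    bool has_attn_bias = false;   // was a compile-time template parameter before this commit
    bool has_attn_mask = false;   // was a compile-time template parameter before this commit
    float softmax_scale = 0.f;
};

// Hypothetical launcher illustrating two of the bullets: the bias/mask flags
// arrive as runtime state, and softmax_scale is validated before any launch.
bool run_fmha_bwd_sketch(FmhaParamsSketch &params, cudaStream_t /*stream*/) {
    // "Add a check of softmax_scale": fail fast on an invalid scale instead of
    // launching a kernel that would quietly produce wrong gradients.
    if (!(params.softmax_scale > 0.f)) {
        return false;  // surfaced through the new bool status (see the diffs below)
    }
    // The kernel reads params.has_attn_bias / params.has_attn_mask at runtime,
    // so no <kHasAttnBias, kHasAttnMask> template dispatch is needed here, and
    // far fewer kernel instantiations end up in the binary.
    return true;
}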
@@ -1,12 +1,12 @@
 // Copyright (c) 2022, Tri Dao.

 // Splitting the different head dimensions to different files to speed up compilation.
-#include "fmha_bwd_with_mask_bias_launch_template.h"
+#include "fmha_bwd_launch_template.h"

-void run_fmha_bwd_with_mask_bias_hdim128(FMHA_dgrad_params &params, cudaStream_t stream) {
+bool run_fmha_bwd_with_mask_bias_hdim128(FMHA_dgrad_params &params, cudaStream_t stream) {
+    bool status = true;
     FP16_SWITCH(params.is_bf16, ([&] {
         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
-        run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+        status = run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
     }));
+    return status;
 }
A file was deleted in this commit (contents not shown).
@@ -1,17 +1,17 @@
 // Copyright (c) 2022, Tri Dao.

 // Splitting the different head dimensions to different files to speed up compilation.
-#include "fmha_bwd_with_mask_bias_launch_template.h"
+#include "fmha_bwd_launch_template.h"

-void run_fmha_bwd_with_mask_bias_hdim32(FMHA_dgrad_params &params, cudaStream_t stream) {
+bool run_fmha_bwd_with_mask_bias_hdim32(FMHA_dgrad_params &params, cudaStream_t stream) {
+    bool status = true;
     FP16_SWITCH(params.is_bf16, ([&] {
         if( params.seqlen_k == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            status = run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
         } else if( params.seqlen_k >= 256 ) {
             using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            status = run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
         }
     }));
+    return status;
 }
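The launchers above call FP16_SWITCH and run_fmha_dgrad_fp16_sm80_loop_, neither of which is defined in this commit. As rough orientation only, a dispatch macro of this kind is usually shaped like the sketch below; the _SKETCH name is a placeholder and the fork's real macro may differ in its details.

#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Sketch: bind elem_type to __half or __nv_bfloat16, then invoke the
// parenthesised lambda body, so one launcher source covers both element types.
// The lambda is passed wrapped in parentheses so that commas inside
// FMHA_kernel_traits<...> do not split the macro argument.
#define FP16_SWITCH_SKETCH(IS_BF16, BODY)        \
    do {                                         \
        if (IS_BF16) {                           \
            using elem_type = __nv_bfloat16;     \
            BODY();                              \
        } else {                                 \
            using elem_type = __half;            \
            BODY();                              \
        }                                        \
    } while (0)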
@@ -1,30 +1,30 @@
 // Copyright (c) 2022, Tri Dao.

 // Splitting the different head dimensions to different files to speed up compilation.
-#include "fmha_bwd_with_mask_bias_launch_template.h"
+#include "fmha_bwd_launch_template.h"

-void run_fmha_bwd_with_mask_bias_hdim64(FMHA_dgrad_params &params, cudaStream_t stream) {
+bool run_fmha_bwd_with_mask_bias_hdim64(FMHA_dgrad_params &params, cudaStream_t stream) {
+    bool status = true;
     auto dprops = GetDeviceProperties(-1);
     FP16_SWITCH(params.is_bf16, ([&] {
         if( params.seqlen_k == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            status = run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
         } else if( params.seqlen_k >= 256 ) {
             if (dprops->major == 8 && dprops->minor == 0) {
                 // Don't share smem for K & V, and don't keep V in registers.
                 // This speeds things up by 2-3% by avoiding register spills, but it
                 // uses more shared memory, which is fine on A100 but not other GPUs.
                 // For other GPUs, we keep V in registers.
                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                status = run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
             } else if (dprops->major == 8 && dprops->minor > 0) {
                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                status = run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
             } else if (dprops->major == 7 && dprops->minor == 5) {
                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                status = run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
             }
         }
     }));
+    return status;
 }
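With the launchers now returning a bool status instead of void, a call site can detect unsupported configurations (for example a seqlen_k or architecture that none of the branches above handle). The following is a hypothetical caller, not code from this commit; the head_dim thresholds, the function name launch_fmha_bwd_sketch, and the error handling are assumptions.

#include <stdexcept>
#include <cuda_runtime.h>

// Forward declaration; the full struct lives in this library's headers.
struct FMHA_dgrad_params;

// New launcher signatures from this commit.
bool run_fmha_bwd_with_mask_bias_hdim32(FMHA_dgrad_params &params, cudaStream_t stream);
bool run_fmha_bwd_with_mask_bias_hdim64(FMHA_dgrad_params &params, cudaStream_t stream);
bool run_fmha_bwd_with_mask_bias_hdim128(FMHA_dgrad_params &params, cudaStream_t stream);

// Hypothetical call site: pick the launcher by head dimension and turn a
// false status into an error instead of silently skipping the backward pass.
void launch_fmha_bwd_sketch(FMHA_dgrad_params &params, int head_dim, cudaStream_t stream) {
    bool launched = false;
    if (head_dim <= 32) {
        launched = run_fmha_bwd_with_mask_bias_hdim32(params, stream);
    } else if (head_dim <= 64) {
        launched = run_fmha_bwd_with_mask_bias_hdim64(params, stream);
    } else if (head_dim <= 128) {
        launched = run_fmha_bwd_with_mask_bias_hdim128(params, stream);
    }
    if (!launched) {
        throw std::runtime_error("fmha backward: unsupported head_dim / seqlen_k / arch combination");
    }
}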