
Fix Attention Runtime Error for CLIP model #17729

Merged
merged 2 commits on Sep 28, 2023
42 changes: 22 additions & 20 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -140,27 +140,29 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
 #endif
 
   if (!use_flash_attention) {
-    if (is_unidirectional_ && enable_fused_causal_attention_) {  // GPT
-      // GPT fused kernels requires left side padding. mask can be:
-      //     none (no padding), 1D sequence lengths or 2d mask.
-      // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token
-      // where past state is empty.
-      bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING;
-      bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) &&
-                                     nullptr == relative_position_bias &&
-                                     parameters.past_sequence_length == 0 &&
-                                     parameters.hidden_size == parameters.v_hidden_size &&
-                                     FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
-                                                                        enable_trt_flash_attention_, true);
-      if (use_causal_fused_runner) {
-        // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
-        if (nullptr == fused_fp16_runner_.get()) {
-          fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_,
-                                                            enable_trt_flash_attention_, parameters.scale);
-        }
-
-        // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check.
-        fused_runner = fused_fp16_runner_.get();
-      }
+    if (is_unidirectional_) {  // GPT
+      if (enable_fused_causal_attention_) {
+        // GPT fused kernels requires left side padding. mask can be:
+        //     none (no padding), 1D sequence lengths or 2d mask.
+        // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token
+        // where past state is empty.
+        bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING;
+        bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) &&
+                                       nullptr == relative_position_bias &&
+                                       parameters.past_sequence_length == 0 &&
+                                       parameters.hidden_size == parameters.v_hidden_size &&
+                                       FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
+                                                                          enable_trt_flash_attention_, true);
+        if (use_causal_fused_runner) {
+          // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
+          if (nullptr == fused_fp16_runner_.get()) {
+            fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_,
+                                                              enable_trt_flash_attention_, parameters.scale);
+          }
+
+          // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check.
+          fused_runner = fused_fp16_runner_.get();
+        }
+      }
     } else {  // BERT
       bool use_fused_runner = !disable_fused_self_attention_ &&
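The restructuring splits the combined condition is_unidirectional_ && enable_fused_causal_attention_ into nested checks. Read that way, the fix is that a unidirectional (causal) Attention node whose fused causal kernel is not enabled no longer falls through to the else (BERT) branch, whose fused runners assume non-causal masking; it stays in the GPT branch and simply selects no fused runner. The following is a minimal, self-contained C++ sketch of that control-flow difference only. It is not ONNX Runtime code; the function names and returned strings are hypothetical stand-ins for the real kernel-selection paths.

// Minimal, self-contained sketch of the dispatch change above. This is NOT ONNX Runtime
// code: the function names and returned strings are hypothetical stand-ins for the real
// kernel-selection logic in Attention<T>::ComputeInternal.
#include <iostream>
#include <string>

// Old shape of the dispatch: one combined condition. A causal (unidirectional) node with
// the fused causal kernel disabled takes the else branch, which is meant for BERT-style
// (non-causal) attention.
std::string SelectRunnerBefore(bool is_unidirectional, bool enable_fused_causal) {
  if (is_unidirectional && enable_fused_causal) {  // GPT
    return "fused causal runner";
  } else {  // BERT
    return "non-causal fused runner";  // not valid for a causal model
  }
}

// New shape of the dispatch: the unidirectional check is separated, so the causal case
// with the fused kernel disabled selects no fused runner and falls back to the generic
// unfused kernels instead of taking the BERT path.
std::string SelectRunnerAfter(bool is_unidirectional, bool enable_fused_causal) {
  if (is_unidirectional) {  // GPT
    if (enable_fused_causal) {
      return "fused causal runner";
    }
    return "unfused attention";  // safe fallback for causal attention
  } else {  // BERT
    return "non-causal fused runner";
  }
}

int main() {
  // CLIP text encoder scenario: causal attention with the fused causal kernel not enabled.
  std::cout << "before: " << SelectRunnerBefore(true, false) << "\n";  // non-causal fused runner
  std::cout << "after:  " << SelectRunnerAfter(true, false) << "\n";   // unfused attention
  return 0;
}

Under the old dispatch, the CLIP text encoder case (causal attention with fused causal attention left disabled, as it is unless explicitly enabled) landed in the non-causal branch; under the new dispatch it falls back to the unfused kernels, which is presumably what resolves the runtime error named in the PR title.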