Commit 59392a8

[bugfix] fix early import of flash attention (vllm-project#12959)

youkaichao authored and lulmer committed

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>

1 parent c32d4bf commit 59392a8

4 files changed, +20 -19 lines
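
Why this is a bugfix: before this change, vllm/attention/backends/utils.py ran the FlashAttention version probe in a module-level try/except, so the vllm_flash_attn import was attempted as a side effect of merely importing the utils module. The hunks below wrap that probe in get_flash_attn_version(), which only runs when an attention implementation is constructed. A minimal, self-contained sketch of the two patterns (probe_version and ExampleImpl are illustrative stand-ins, not vLLM code):

# Illustrative sketch of eager (import-time) vs. lazy (call-time) probing.
# `probe_version` is a stand-in for the real FlashAttention version check.

def probe_version() -> int:
    # Stand-in: imagine this imports a native extension and may raise.
    return 2


# Eager pattern (the old behaviour): runs as a side effect of importing the
# module, so every importer pays the cost and risks the failure up front.
try:
    EAGER_FA_VERSION = probe_version()
except (ImportError, AssertionError):
    EAGER_FA_VERSION = None


# Lazy pattern (what this commit switches to): the probe only runs when a
# backend instance asks for it, typically once in __init__.
def get_version():
    try:
        return probe_version()
    except (ImportError, AssertionError):
        return None


class ExampleImpl:
    def __init__(self) -> None:
        # Resolved at construction time rather than at import time.
        self.fa_version = get_version()


if __name__ == "__main__":
    print(ExampleImpl().fa_version)  # -> 2 in this toy example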

vllm/attention/backends/flash_attn.py

Lines changed: 7 additions & 6 deletions

@@ -14,8 +14,8 @@
                                               AttentionMetadataBuilder,
                                               AttentionType)
 from vllm.attention.backends.utils import (
-    PAD_SLOT_ID, VLLM_FLASH_ATTN_VERSION, CommonAttentionState,
-    compute_slot_mapping, compute_slot_mapping_start_idx,
+    PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
+    compute_slot_mapping_start_idx, get_flash_attn_version,
     get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
     is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set,
     is_block_tables_empty)
@@ -640,6 +640,7 @@ def __init__(
                 f"Head size {head_size} is not supported by FlashAttention. "
                 f"Supported head sizes are: {support_head_sizes}.")
         self.attn_type = attn_type
+        self.vllm_flash_attn_version = get_flash_attn_version()
 
     def forward(
         self,
@@ -759,7 +760,7 @@ def forward(
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=prefill_output,
-                    fa_version=VLLM_FLASH_ATTN_VERSION,
+                    fa_version=self.vllm_flash_attn_version,
                 )
             else:
                 # prefix-enabled attention
@@ -782,7 +783,7 @@ def forward(
                     block_table=prefill_meta.block_tables,
                     softcap=logits_soft_cap,
                     out=prefill_output,
-                    fa_version=VLLM_FLASH_ATTN_VERSION,
+                    fa_version=self.vllm_flash_attn_version,
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
@@ -811,7 +812,7 @@ def forward(
                     softcap=logits_soft_cap,
                     block_table=decode_meta.block_tables,
                     out=decode_output,
-                    fa_version=VLLM_FLASH_ATTN_VERSION,
+                    fa_version=self.vllm_flash_attn_version,
                 )
             else:
                 # Use flash_attn_with_kvcache for normal decoding.
@@ -832,7 +833,7 @@ def forward(
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=decode_output.unsqueeze(1),
-                    fa_version=VLLM_FLASH_ATTN_VERSION,
+                    fa_version=self.vllm_flash_attn_version,
                 )
         return output
vllm/attention/backends/mla/utils.py

Lines changed: 3 additions & 2 deletions

@@ -12,7 +12,7 @@
 from vllm.attention.backends.abstract import (AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl, T)
-from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
+from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.distributed import (get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -181,6 +181,7 @@ def __init__(
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
+        self.vllm_flash_attn_version = get_flash_attn_version()
 
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
@@ -515,7 +516,7 @@ def _forward_prefill_flash(
             max_seqlen_k=max_prefill_seq_len,
             softmax_scale=self.scale,
             causal=True,
-            fa_version=VLLM_FLASH_ATTN_VERSION,
+            fa_version=self.vllm_flash_attn_version,
         )
         attn_output = attn_output\
             .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\

vllm/attention/backends/utils.py

Lines changed: 6 additions & 8 deletions

@@ -587,11 +587,11 @@ def get_num_prefill_decode_query_kv_tokens(
             num_decode_query_tokens)
 
 
-try:
-    from vllm.vllm_flash_attn.flash_attn_interface import (
-        fa_version_unsupported_reason, is_fa_version_supported)
+def get_flash_attn_version():
+    try:
+        from vllm.vllm_flash_attn.flash_attn_interface import (
+            fa_version_unsupported_reason, is_fa_version_supported)
 
-    def flash_attn_version():
         # if hopper default to FA3, otherwise stick to FA2 for now
         # TODO(lucas): profile FA3 on ampere to see if it makes sense to
         # use FA3 as default for both
@@ -610,7 +610,5 @@ def flash_attn_version():
 
         assert is_fa_version_supported(fa_version)
         return fa_version
-
-    VLLM_FLASH_ATTN_VERSION = flash_attn_version()
-except (ImportError, AssertionError):
-    VLLM_FLASH_ATTN_VERSION = None
+    except (ImportError, AssertionError):
+        return None
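
Taken together with the backend hunks above, callers now cache the lazily resolved version on the impl instance and thread it into the kernel call; get_flash_attn_version() returns None when the vllm_flash_attn interface cannot be imported or the detected version fails the support assertion. A hedged caller-side sketch (the class and its methods are illustrative, not vLLM's actual API):

# Illustrative sketch of the caller-side pattern used by the backends above:
# resolve the version once per impl in __init__, pass it as fa_version later.
from typing import Optional

from vllm.attention.backends.utils import get_flash_attn_version


class SketchAttentionImpl:  # hypothetical class for illustration only
    def __init__(self) -> None:
        # None means vllm_flash_attn could not be imported or the detected
        # FlashAttention version is not supported on this build/device.
        self.vllm_flash_attn_version: Optional[int] = get_flash_attn_version()

    def kernel_kwargs(self) -> dict:
        # The cached value is forwarded to the attention kernel, e.g.
        # flash_attn_varlen_func(..., fa_version=self.vllm_flash_attn_version).
        return {"fa_version": self.vllm_flash_attn_version}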

vllm/v1/attention/backends/flash_attn.py

Lines changed: 4 additions & 3 deletions

@@ -10,7 +10,7 @@
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
-from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
+from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -132,6 +132,7 @@ def __init__(
                 "encoder/decoder cross-attention "
                 "are not implemented for "
                 "FlashAttentionImpl")
+        self.vllm_flash_attn_version = get_flash_attn_version()
 
     def forward(
         self,
@@ -205,7 +206,7 @@ def forward(
                 window_size=self.sliding_window,
                 block_table=attn_metadata.block_table,
                 softcap=self.logits_soft_cap,
-                fa_version=VLLM_FLASH_ATTN_VERSION,
+                fa_version=self.vllm_flash_attn_version,
             )
             return output
 
@@ -227,7 +228,7 @@ def forward(
             logits_soft_cap=self.logits_soft_cap,
             block_table=attn_metadata.block_table,
             common_prefix_len=attn_metadata.common_prefix_len,
-            fa_version=VLLM_FLASH_ATTN_VERSION,
+            fa_version=self.vllm_flash_attn_version,
         )
         return output
