bugfix: bug fix on determine_attention_backend condition (#688)
Should only enable fa3 for CUDA 12.3+.
yzh119 authored Dec 20, 2024
1 parent 3470329 commit bcf7a3e
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions flashinfer/utils.py
@@ -20,6 +20,7 @@
 from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union
 
 import torch
+import torch.version
 from torch.torch_version import TorchVersion
 from torch.torch_version import __version__ as torch_version
 
@@ -342,12 +343,16 @@ def determine_attention_backend(
     """
     major, _ = get_compute_capability(device)
 
-    if major >= 9 and is_fa3_backend_supported(
-        pos_encoding_mode,
-        allow_fp16_qk_reductions,
-        use_custom_mask,
-        dtype_q,
-        dtype_kv,
+    if (
+        major >= 9
+        and torch.version.cuda >= "12.3"
+        and is_fa3_backend_supported(
+            pos_encoding_mode,
+            allow_fp16_qk_reductions,
+            use_custom_mask,
+            dtype_q,
+            dtype_kv,
+        )
     ):
         return "fa3"
     else:
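
For context, the sketch below restates the patched selection logic as a self-contained Python snippet. The stub bodies of get_compute_capability and is_fa3_backend_supported are placeholders (the real helpers live elsewhere in flashinfer/utils.py and are not shown in this diff), the explicit None check is an extra defensive guard for CPU-only PyTorch builds rather than something the commit adds, and the "fa2" fallback is an assumption since the else branch is collapsed here.

import torch
import torch.version


def get_compute_capability(device: torch.device) -> tuple:
    # Placeholder stand-in: torch.cuda.get_device_capability returns (major, minor).
    return torch.cuda.get_device_capability(device)


def is_fa3_backend_supported(*args) -> bool:
    # Placeholder stub; the real check inspects encoding mode, mask usage, and dtypes.
    return True


def determine_attention_backend_sketch(device: torch.device, *fa3_args) -> str:
    major, _ = get_compute_capability(device)
    if (
        major >= 9  # Hopper (SM90) or newer
        and torch.version.cuda is not None  # defensive: None on CPU-only builds
        and torch.version.cuda >= "12.3"  # string comparison, mirroring the diff
        and is_fa3_backend_supported(*fa3_args)
    ):
        return "fa3"
    return "fa2"  # assumed fallback; the else branch is not shown in this diff

On this sketch's assumptions, an SM90+ GPU with a CUDA 12.4 PyTorch build would get "fa3", while older GPUs or CUDA toolkits fall back to "fa2".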
