Commit

upd
zhyncs committed Dec 1, 2024
1 parent 5f12f0e commit d57c747
Showing 3 changed files with 64 additions and 19 deletions.
28 changes: 13 additions & 15 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -18,7 +18,11 @@
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_flashinfer_available,
+    should_use_tensor_cores,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -31,7 +35,6 @@
     BatchPrefillWithRaggedKVCacheWrapper,
 )
 from flashinfer.cascade import merge_state
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
 
 class WrapperDispatch(Enum):
@@ -45,19 +48,14 @@ class FlashInferAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
 
-        # Parse constants
-        if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
-            self.decode_use_tensor_cores = get_bool_env_var(
-                "SGLANG_FLASHINFER_USE_TENSOR_CORE"
-            )
-        else:
-            if not _grouped_size_compiled_for_decode_kernels(
-                model_runner.model_config.num_attention_heads // model_runner.tp_size,
-                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-            ):
-                self.decode_use_tensor_cores = True
-            else:
-                self.decode_use_tensor_cores = False
+        self.decode_use_tensor_cores = should_use_tensor_cores(
+            kv_cache_dtype=model_runner.kv_cache_dtype,
+            num_attention_heads=model_runner.model_config.num_attention_heads
+            // model_runner.tp_size,
+            num_kv_heads=model_runner.model_config.get_num_kv_heads(
+                model_runner.tp_size
+            ),
+        )
 
         self.max_context_len = model_runner.model_config.context_len
 
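For context only (not part of this commit): the decode_use_tensor_cores flag computed above is the kind of value FlashInfer's decode wrapper accepts at construction time. A minimal sketch, assuming FlashInfer's BatchDecodeWithPagedKVCacheWrapper with its use_tensor_cores argument, and with an illustrative workspace buffer and head counts:

import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from sglang.srt.utils import should_use_tensor_cores

# Illustrative values; the real ones come from the model config and TP size.
decode_use_tensor_cores = should_use_tensor_cores(
    kv_cache_dtype=torch.half, num_attention_heads=32, num_kv_heads=8
)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer,
    "NHD",  # assumed KV-cache layout
    use_tensor_cores=decode_use_tensor_cores,
)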
47 changes: 47 additions & 0 deletions python/sglang/srt/utils.py
@@ -1108,3 +1108,50 @@ def cuda_device_count_stateless() -> int:
     # This can be removed and simply replaced with torch.cuda.get_device_count
     # after https://github.com/pytorch/pytorch/pull/122815 is released.
     return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
+
+
+def should_use_tensor_cores(
+    kv_cache_dtype: torch.dtype,
+    num_attention_heads: int,
+    num_kv_heads: int,
+) -> bool:
+    """
+    Determine whether to use tensor cores for attention computation.
+
+    Args:
+        kv_cache_dtype: Data type of the KV cache
+        num_attention_heads: Number of attention heads
+        num_kv_heads: Number of key/value heads
+
+    Returns:
+        bool: Whether to use tensor cores
+    """
+    # Check the environment variable override first
+    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
+    if env_override is not None:
+        return env_override.lower() == "true"
+
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
+    # Calculate GQA group size
+    gqa_group_size = num_attention_heads // num_kv_heads
+
+    # Determine based on dtype and GQA group size
+    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        return True
+    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
+        return gqa_group_size > 4
+    else:
+        return False
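A quick illustration of the dtype/GQA fallback above. These results assume the FlashInfer _grouped_size_compiled_for_decode_kernels probe is unavailable and the environment variable is unset, since both take precedence; the head counts are illustrative:

import torch
from sglang.srt.utils import should_use_tensor_cores

# FP8 KV cache: tensor cores are always requested.
should_use_tensor_cores(torch.float8_e4m3fn, num_attention_heads=32, num_kv_heads=8)  # True

# FP16/BF16 KV cache: only when the GQA group size exceeds 4.
should_use_tensor_cores(torch.float16, num_attention_heads=32, num_kv_heads=4)  # 32 // 4 = 8 > 4 -> True
should_use_tensor_cores(torch.float16, num_attention_heads=32, num_kv_heads=8)  # 32 // 8 = 4 -> False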
8 changes: 4 additions & 4 deletions scripts/deprecated/test_flashinfer.py
@@ -11,6 +11,7 @@
     extend_attention_fwd,
     redundant_attention,
 )
+from sglang.srt.utils import should_use_tensor_cores
 
 flashinfer_prefill_wrapper = None
 flashinfer_decode_wrapper = None
@@ -195,10 +196,9 @@ def test_batch_decode_with_paged_kv_cache(


 def init_flashinfer(num_attention_heads, num_kv_heads):
-    if not _grouped_size_compiled_for_decode_kernels(num_attention_heads, num_kv_heads):
-        use_tensor_cores = True
-    else:
-        use_tensor_cores = False
+    use_tensor_cores = should_use_tensor_cores(
+        torch.half, num_attention_heads, num_kv_heads
+    )
 
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
 
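Note that the SGLANG_FLASHINFER_USE_TENSOR_CORE environment variable still overrides both the FlashInfer probe and the dtype heuristic. A minimal sketch of exercising that override (head counts are illustrative):

import os
import torch
from sglang.srt.utils import should_use_tensor_cores

# With the override set, head counts and KV-cache dtype are ignored.
os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"] = "true"
assert should_use_tensor_cores(torch.half, num_attention_heads=32, num_kv_heads=32)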
