Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add should_use_tensor_core #2179

Merged
merged 6 commits into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
from sglang.srt.utils import (
get_bool_env_var,
is_flashinfer_available,
should_use_tensor_cores,
)

if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
Expand All @@ -31,7 +35,6 @@
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels


class WrapperDispatch(Enum):
Expand All @@ -45,19 +48,14 @@ class FlashInferAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
super().__init__()

# Parse constants
if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
self.decode_use_tensor_cores = get_bool_env_var(
"SGLANG_FLASHINFER_USE_TENSOR_CORE"
)
else:
if not _grouped_size_compiled_for_decode_kernels(
model_runner.model_config.num_attention_heads // model_runner.tp_size,
model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
):
self.decode_use_tensor_cores = True
else:
self.decode_use_tensor_cores = False
self.decode_use_tensor_cores = should_use_tensor_cores(
kv_cache_dtype=model_runner.kv_cache_dtype,
num_attention_heads=model_runner.model_config.num_attention_heads
// model_runner.tp_size,
num_kv_heads=model_runner.model_config.get_num_kv_heads(
model_runner.tp_size
),
)

self.max_context_len = model_runner.model_config.context_len

Expand Down
47 changes: 47 additions & 0 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,3 +1108,50 @@ def cuda_device_count_stateless() -> int:
# This can be removed and simply replaced with torch.cuda.get_device_count
# after https://github.com/pytorch/pytorch/pull/122815 is released.
return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))


def should_use_tensor_cores(
kv_cache_dtype: torch.dtype,
num_attention_heads: int,
num_kv_heads: int,
) -> bool:
"""
Determine whether to use tensor cores for attention computation.

Args:
kv_cache_dtype: Data type of the KV cache
num_attention_heads: Number of attention heads
num_kv_heads: Number of key/value heads

Returns:
bool: Whether to use tensor cores
"""
# Try to use environment variable and dtype-based logic first
env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
if env_override is not None:
return env_override.lower() == "true"

# Try to use _grouped_size_compiled_for_decode_kernels if available
try:
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
zhyncs marked this conversation as resolved.
Show resolved Hide resolved

if not _grouped_size_compiled_for_decode_kernels(
num_attention_heads,
num_kv_heads,
):
return True
else:
return False
except (ImportError, AttributeError):
pass

# Calculate GQA group size
gqa_group_size = num_attention_heads // num_kv_heads

# Determine based on dtype and GQA group size
if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
return True
elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
return gqa_group_size > 4
zhyncs marked this conversation as resolved.
Show resolved Hide resolved
else:
return False
8 changes: 4 additions & 4 deletions scripts/deprecated/test_flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
extend_attention_fwd,
redundant_attention,
)
from sglang.srt.utils import should_use_tensor_cores

flashinfer_prefill_wrapper = None
flashinfer_decode_wrapper = None
Expand Down Expand Up @@ -195,10 +196,9 @@ def test_batch_decode_with_paged_kv_cache(


def init_flashinfer(num_attention_heads, num_kv_heads):
if not _grouped_size_compiled_for_decode_kernels(num_attention_heads, num_kv_heads):
use_tensor_cores = True
else:
use_tensor_cores = False
use_tensor_cores = should_use_tensor_cores(
torch.half, num_attention_heads, num_kv_heads
)

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")

Expand Down
Loading