7 changes: 7 additions & 0 deletions vllm/attention/layer.py
@@ -137,6 +137,13 @@ def __init__(
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window

# For v1 we have backend-agnostic iRoPE (local chunked attention), so we
# have to store the flag on the layer so the gpu model runner can set the
# KVSpec appropriately (and pop it so it doesn't get passed to the
# backends).
if envs.VLLM_USE_V1:
self.use_irope = extra_impl_args.pop("use_irope", False)

quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
if quant_method is not None and not isinstance(
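The block added above pops use_irope out of extra_impl_args and stores it on the layer itself, so the flag never reaches the backend constructors. A minimal standalone sketch of that pattern (the class and argument names below are illustrative, not the actual vLLM API):

# Sketch only: illustrates the pop-and-store pattern from layer.py above.
class AttentionLayerSketch:
    def __init__(self, use_v1: bool = True, **extra_impl_args):
        if use_v1:
            # Keep the flag on the layer so the model runner can read it,
            # and pop it so backend __init__ methods never receive it.
            self.use_irope = extra_impl_args.pop("use_irope", False)
        # Whatever remains can be forwarded to a backend constructor that
        # has no use_irope parameter.
        self.extra_impl_args = extra_impl_args

layer = AttentionLayerSketch(use_irope=True)
assert layer.use_irope
assert "use_irope" not in layer.extra_impl_args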
5 changes: 0 additions & 5 deletions vllm/v1/attention/backends/cpu_attn.py
@@ -446,17 +446,12 @@ def __init__(
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
if logits_soft_cap is not None:
logger.warning_once("Torch SPDA does not support logits soft cap. "
"Outputs may be slightly off.")
if use_irope:
logger.warning_once(
"Using irope in Torch SPDA is not supported yet, it will fall"
" back to global attention for long context.")
self.paged_attn_impl = _get_paged_attn_impl()
self.num_heads = num_heads
self.head_size = head_size
2 changes: 0 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
@@ -352,7 +352,6 @@ def __init__(
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -381,7 +380,6 @@ def __init__(
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")
self.use_irope = use_irope
self.vllm_flash_attn_version = get_flash_attn_version()
if is_quantized_kv_cache(self.kv_cache_dtype) \
and not flash_attn_supports_fp8():
2 changes: 0 additions & 2 deletions vllm/v1/attention/backends/flashinfer.py
@@ -493,7 +493,6 @@ def __init__(
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
use_irope: bool = False,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -509,7 +508,6 @@ def __init__(
self.kv_cache_dtype = kv_cache_dtype
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.use_irope = use_irope

self.num_queries_per_kv = self.num_heads // self.num_kv_heads

5 changes: 0 additions & 5 deletions vllm/v1/attention/backends/pallas.py
@@ -135,12 +135,7 @@ def __init__(
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
use_irope: bool = False,
) -> None:
if use_irope:
logger.warning_once(
"Using irope in Pallas is not supported yet, it will fall back "
"to global attention for long context.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
2 changes: 0 additions & 2 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -337,7 +337,6 @@ def __init__(
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
use_irope: bool = False,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -367,7 +366,6 @@ def __init__(
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")
self.use_irope = use_irope
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"AiterFlashAttention does not support fp8 kv-cache on this "
6 changes: 0 additions & 6 deletions vllm/v1/attention/backends/triton_attn.py
@@ -72,9 +72,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
vllm_config.parallel_config)
self.headdim = model_config.get_head_size()

self.attention_chunk_size = getattr(vllm_config.scheduler_config,
'attention_chunk_size', None)

def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> TritonAttentionMetadata:
@@ -208,7 +205,6 @@ def __init__(
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
use_irope: bool = False,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -228,8 +224,6 @@ def __init__(
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

self.use_irope = use_irope

self.num_queries_per_kv = self.num_heads // self.num_kv_heads

TritonAttentionBackend.validate_head_size(head_size)
7 changes: 3 additions & 4 deletions vllm/v1/worker/gpu_model_runner.py
@@ -2710,8 +2710,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
# TODO: Support other attention modules, e.g., cross-attention
if attn_module.attn_type == AttentionType.DECODER:
use_local_attention = (self.attention_chunk_size is not None
and getattr(attn_module.impl,
"use_irope", False))
and attn_module.use_irope)
if attn_module.sliding_window is not None:
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
@@ -2724,13 +2723,13 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"attention module can not be with ",
"both local attention and sliding window")
elif use_local_attention:
kv_cache_spec[layer_name] = (ChunkedLocalAttentionSpec(
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
attention_chunk_size=self.attention_chunk_size,
use_mla=use_mla))
use_mla=use_mla)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
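With the flag on the layer, the runner no longer needs getattr(attn_module.impl, "use_irope", False); it reads attn_module.use_irope directly when choosing a KV cache spec. A simplified sketch of that selection logic, with placeholder spec classes standing in for vLLM's FullAttentionSpec, SlidingWindowSpec, and ChunkedLocalAttentionSpec:

# Sketch only: simplified version of the spec-selection branch above.
from dataclasses import dataclass
from typing import Optional

@dataclass
class FullSpec:
    block_size: int

@dataclass
class SlidingSpec:
    block_size: int
    sliding_window: int

@dataclass
class ChunkedLocalSpec:
    block_size: int
    attention_chunk_size: int

def pick_spec(block_size: int,
              sliding_window: Optional[int],
              attention_chunk_size: Optional[int],
              use_irope: bool):
    # Local chunked attention applies only when the model sets a chunk size
    # and the layer was built with use_irope=True.
    use_local_attention = attention_chunk_size is not None and use_irope
    if sliding_window is not None:
        # Mirrors the assertion in the diff: a layer cannot combine local
        # chunked attention with sliding-window attention.
        assert not use_local_attention, (
            "attention module can not be with both local attention "
            "and sliding window")
        return SlidingSpec(block_size, sliding_window)
    if use_local_attention:
        return ChunkedLocalSpec(block_size, attention_chunk_size)
    return FullSpec(block_size)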
4 changes: 4 additions & 0 deletions vllm/v1/worker/tpu_model_runner.py
@@ -521,6 +521,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
continue

if attn_module.attn_type == AttentionType.DECODER:
if attn_module.use_irope:
logger.warning_once(
"Using irope in Pallas is not supported yet, it "
"will fall back to global attention for long context.")
if attn_module.sliding_window is not None:
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
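Since the Pallas backend itself no longer sees the flag, the TPU runner now emits the fallback warning while building the KV cache spec, reading the layer-level attribute. A hedged sketch of that warn-once check (the logger name and helper function are illustrative, not the actual vLLM code):

# Sketch only: warn once when a layer requests iRoPE on a backend that
# falls back to global attention.
import logging

logger = logging.getLogger("tpu_model_runner_sketch")
_warned_irope_fallback = False

def maybe_warn_irope_fallback(use_irope: bool) -> None:
    global _warned_irope_fallback
    if use_irope and not _warned_irope_fallback:
        logger.warning(
            "Using irope in Pallas is not supported yet, it will fall "
            "back to global attention for long context.")
        _warned_irope_fallback = True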