87 changes: 54 additions & 33 deletions vllm/attention/layer.py
@@ -30,12 +30,13 @@
USE_XFORMERS_OPS = None

if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
        envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
        and envs.VLLM_ROCM_USE_AITER_MLA)
else:
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False

logger.info("[Aiter] VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=%s",
            VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE)



def check_xformers_availability():
global USE_XFORMERS_OPS
if USE_XFORMERS_OPS is not None:
@@ -262,12 +263,12 @@
if output_shape is not None else query.shape)
if positions is not None:
    output = torch.empty(output_shape,
                         dtype=query.dtype,
                         device=query.device)
else:
    output = torch.zeros(output_shape,
                         dtype=query.dtype,
                         device=query.device)

hidden_size = output_shape[-1]
# We skip reshaping query, key and value tensors for the MLA
@@ -290,15 +291,21 @@
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
                  query,
                  key,
                  value,
                  self_kv_cache,
                  attn_metadata,
                  output=output)
else:
    torch.ops.vllm.unified_attention_with_output(
        query,
        key,
        value,
        output,
        self.layer_name,
        None,
        positions=positions)
return output.view(-1, hidden_size)
else:
if self.use_direct_call:
@@ -515,7 +522,7 @@
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
@@ -526,31 +533,45 @@
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]

from vllm.v1.attention.backends.mla.rocm_aiter_mla import AiterMLAImpl
from vllm.v1.attention.backends.rocm_aiter_fa import (
    AiterFlashAttentionImpl)
from vllm.v1.attention.backends.triton_attn import TritonAttentionImpl

# Not all layers can use RoPE fusing, so check that they were given all
# needed inputs along with the environment variable to enable this.
if (
    hasattr(self.impl, "rotary_emb")
    and self.impl.rotary_emb is not None
    and positions is not None
    and isinstance(
        self.impl,
        (TritonAttentionImpl, AiterFlashAttentionImpl, AiterMLAImpl))
):
    assert (VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
            ), "rotary_emb and positions provided when unexpected."
    # Fuse RoPE with the kv_cache write operation.
    self.impl.forward(
        self,
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        positions=positions,
    )
else:
    assert positions is None, f"positions must be None {positions=}"
    self.impl.forward(self,
                      query,
                      key,
                      value,
                      kv_cache,
                      attn_metadata,
                      output=output,
                      output_scale=output_scale)

maybe_save_kv_layer_to_connector(layer_name, kv_cache)

@@ -562,7 +583,7 @@
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
return
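
Note: the fused path in `unified_attention_with_output` is taken only when the layer was constructed with a `rotary_emb`, `positions` are passed at call time, the impl type supports the fusion, and the environment flag is set. Below is a minimal standalone sketch of that dispatch rule, using hypothetical stand-in classes (`FusedCapableImpl`, `PlainImpl`) rather than the real vLLM backend impls; the real check accepts `TritonAttentionImpl`, `AiterFlashAttentionImpl`, and `AiterMLAImpl`.

```python
# Hypothetical stand-ins for attention backend impls (illustration only).
class FusedCapableImpl:
    def __init__(self, rotary_emb=None):
        self.rotary_emb = rotary_emb


class PlainImpl:
    pass


def use_fused_rope_path(impl, positions, env_flag: bool) -> bool:
    """Mirror the branch above: take the fused RoPE + kv-cache path only
    when the layer has a rotary_emb, positions were passed, and the impl
    type supports fusion; the env flag must then also be enabled."""
    eligible = (hasattr(impl, "rotary_emb")
                and impl.rotary_emb is not None
                and positions is not None
                and isinstance(impl, FusedCapableImpl))
    if eligible:
        assert env_flag, "rotary_emb and positions provided when unexpected."
    return eligible


# Fused path: rotary_emb + positions + capable impl + env flag set.
assert use_fused_rope_path(FusedCapableImpl(rotary_emb=object()),
                           positions=[0, 1], env_flag=True)
# Regular path: no positions, so the unfused forward is used.
assert not use_fused_rope_path(PlainImpl(), positions=None, env_flag=False)
```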
6 changes: 4 additions & 2 deletions vllm/attention/layers/chunked_local_attention.py
@@ -58,7 +58,8 @@ def __init__(self,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
kv_sharing_target_layer_name: Optional[str] = None,
prefix: str = ""):
prefix: str = "",
**kwargs):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
@@ -88,4 +89,5 @@ def __init__(self,
quant_config=quant_config,
prefix=prefix,
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
attn_backend=attn_backend,
**kwargs)
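
The `**kwargs` pass-through lets `ChunkedLocalAttention` forward new optional arguments such as `rotary_emb` to the base `Attention` layer without naming them. A small illustrative sketch of that pattern, with hypothetical classes that are not vLLM APIs:

```python
# Hypothetical classes showing the kwargs pass-through pattern.
class BaseAttention:
    def __init__(self, num_heads: int, rotary_emb=None, **kwargs):
        self.num_heads = num_heads
        self.rotary_emb = rotary_emb


class LocalAttention(BaseAttention):
    def __init__(self, num_heads: int, prefix: str = "", **kwargs):
        # Forward anything this subclass does not handle itself.
        super().__init__(num_heads, **kwargs)
        self.prefix = prefix


layer = LocalAttention(8, prefix="model.layers.0.attn", rotary_emb=object())
assert layer.rotary_emb is not None  # reached the base class unchanged
```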
44 changes: 33 additions & 11 deletions vllm/model_executor/models/llama4.py
@@ -24,6 +24,7 @@
from torch import nn
from transformers import Llama4TextConfig

import vllm.envs as envs
from vllm.attention import Attention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.decorators import support_torch_compile
@@ -38,11 +39,16 @@
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.platforms import current_platform

from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
is_pp_missing_parameter)

VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE)


class Llama4MoE(nn.Module):

@@ -198,17 +204,24 @@ def __init__(self,
use_chunked_local_attn = not self.nope and config.attention_chunk_size
attn_cls = (ChunkedLocalAttention
if use_chunked_local_attn else Attention)
extra_args = {}
if use_chunked_local_attn:
    extra_args["attention_chunk_size"] = config.attention_chunk_size
# Use the rotary_emb in attention only when it's supported.
self.use_fused_rope = (
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
    and self.rotary_emb is not None and self.qk_norm is None
    and not self.attn_temperature_tuning)
if self.use_fused_rope:
    extra_args["rotary_emb"] = self.rotary_emb
self.attn = attn_cls(self.num_heads,
                     self.head_dim,
                     self.scaling,
                     num_kv_heads=self.num_kv_heads,
                     cache_config=cache_config,
                     quant_config=quant_config,
                     prefix=f"{prefix}.attn",
                     **extra_args)

def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
floor = torch.floor((positions + 1.0) / self.floor_scale)
@@ -224,6 +237,15 @@ def forward(
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

# rotary_emb is fused into self.attn in this case
if self.use_fused_rope:
assert not (
self.attn_temperature_tuning
), f"{self.attn_temperature_tuning=} must be False when using fused rope"
attn_output = self.attn(q, k, v, positions=positions)
output, _ = self.o_proj(attn_output)
return output

if self.rotary_emb is not None:
q, k = self.rotary_emb(positions, q, k)
