diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 2a24227d65f4..e9b488be0030 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -36,6 +36,7 @@
 
 logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}")
 
+
 def check_xformers_availability():
     global USE_XFORMERS_OPS
     if USE_XFORMERS_OPS is not None:
@@ -262,12 +263,12 @@ def forward(
                             if output_shape is not None else query.shape)
             if positions is not None:
                 output = torch.empty(output_shape,
-                                 dtype=query.dtype,
-                                 device=query.device)
+                                     dtype=query.dtype,
+                                     device=query.device)
             else:
                 output = torch.zeros(output_shape,
-                                 dtype=query.dtype,
-                                 device=query.device)
+                                     dtype=query.dtype,
+                                     device=query.device)
             hidden_size = output_shape[-1]
 
             # We skip reshaping query, key and value tensors for the MLA
@@ -290,15 +291,21 @@ def forward(
                     attn_metadata = attn_metadata[self.layer_name]
                 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                 self.impl.forward(self,
-                              query,
-                              key,
-                              value,
-                              self_kv_cache,
-                              attn_metadata,
-                              output=output)
+                                  query,
+                                  key,
+                                  value,
+                                  self_kv_cache,
+                                  attn_metadata,
+                                  output=output)
             else:
                 torch.ops.vllm.unified_attention_with_output(
-                    query, key, value, output, self.layer_name, None, positions=positions)
+                    query,
+                    key,
+                    value,
+                    output,
+                    self.layer_name,
+                    None,
+                    positions=positions)
             return output.view(-1, hidden_size)
         else:
             if self.use_direct_call:
@@ -515,7 +522,7 @@ def unified_attention_with_output(
     output: torch.Tensor,
     layer_name: str,
     output_scale: Optional[torch.Tensor] = None,
-    positions: Optional[torch.Tensor] = None, 
+    positions: Optional[torch.Tensor] = None,
     output_block_scale: Optional[torch.Tensor] = None,
 ) -> None:
     wait_for_kv_layer_from_connector(layer_name)
@@ -526,31 +533,45 @@ def unified_attention_with_output(
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
 
-    from vllm.v1.attention.backends.triton_attn import TritonAttentionImpl
-    from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionImpl
     from vllm.v1.attention.backends.mla.rocm_aiter_mla import AiterMLAImpl
-    if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and (isinstance(self.impl, TritonAttentionImpl) or isinstance(self.impl, AiterFlashAttentionImpl) or isinstance(self.impl, AiterMLAImpl)):
+    from vllm.v1.attention.backends.rocm_aiter_fa import (
+        AiterFlashAttentionImpl)
+    from vllm.v1.attention.backends.triton_attn import TritonAttentionImpl
+
+    # Not all layers can use RoPE fusing, so check that they were given all
+    # needed inputs along with the environment variable to enable this.
+    if (
+        hasattr(self.impl, "rotary_emb")
+        and self.impl.rotary_emb is not None
+        and positions is not None
+        and isinstance(
+            self.impl, (TritonAttentionImpl, AiterFlashAttentionImpl, AiterMLAImpl)
+        )
+    ):
+        assert (VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
+                ), "rotary_emb and positions provided when unexpected."
         # fusing RoPE with flushing kv_cache operation
-        assert hasattr(self.impl, "rotary_emb") and self.impl.rotary_emb is not None and positions is not None, f"rotary_emb not found in {self.impl=} and positions cannot be None"
-        self.impl.forward(self,
-                      query,
-                      key,
-                      value,
-                      kv_cache,
-                      attn_metadata,
-                      output=output,
-                      output_scale=output_scale,
-                      positions=positions)
+        self.impl.forward(
+            self,
+            query,
+            key,
+            value,
+            kv_cache,
+            attn_metadata,
+            output=output,
+            output_scale=output_scale,
+            positions=positions,
+        )
     else:
         assert positions is None, f"positions must be None {positions=}"
         self.impl.forward(self,
-                      query,
-                      key,
-                      value,
-                      kv_cache,
-                      attn_metadata,
-                      output=output,
-                      output_scale=output_scale)
+                          query,
+                          key,
+                          value,
+                          kv_cache,
+                          attn_metadata,
+                          output=output,
+                          output_scale=output_scale)
 
     maybe_save_kv_layer_to_connector(layer_name, kv_cache)
 
@@ -562,7 +583,7 @@ def unified_attention_with_output_fake(
     output: torch.Tensor,
     layer_name: str,
     output_scale: Optional[torch.Tensor] = None,
-    positions: Optional[torch.Tensor] = None, 
+    positions: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
 ) -> None:
     return
diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py
index 087c5004bde0..1c61a7692db3 100644
--- a/vllm/attention/layers/chunked_local_attention.py
+++ b/vllm/attention/layers/chunked_local_attention.py
@@ -58,7 +58,8 @@ def __init__(self,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  kv_sharing_target_layer_name: Optional[str] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 **kwargs):
         dtype = torch.get_default_dtype()
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
@@ -88,4 +89,5 @@ def __init__(self,
             quant_config=quant_config,
             prefix=prefix,
             kv_sharing_target_layer_name=kv_sharing_target_layer_name,
-            attn_backend=attn_backend)
+            attn_backend=attn_backend,
+            **kwargs)
diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index ba08e6f81f7f..75a049975d60 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -24,6 +24,7 @@
 from torch import nn
 from transformers import Llama4TextConfig
 
+import vllm.envs as envs
 from vllm.attention import Attention
 from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.decorators import support_torch_compile
@@ -38,11 +39,16 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.platforms import current_platform
 
 from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
 from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
                     is_pp_missing_parameter)
 
+VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
+    current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
+    and envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE)
+
 
 class Llama4MoE(nn.Module):
 
@@ -198,17 +204,24 @@ def __init__(self,
         use_chunked_local_attn = not self.nope and config.attention_chunk_size
         attn_cls = (ChunkedLocalAttention
                     if use_chunked_local_attn else Attention)
-        self.attn = attn_cls(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.attn",
-            **({
-                "attention_chunk_size": config.attention_chunk_size
-            } if use_chunked_local_attn else {}))
+        extra_args = {}
+        if use_chunked_local_attn:
+            extra_args["attention_chunk_size"] = config.attention_chunk_size
+        # Use the rotary_emb in attention only when it's supported
+        self.use_fused_rope = (
+            VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
+            and self.rotary_emb is not None and self.qk_norm is None
+            and not self.attn_temperature_tuning)
+        if self.use_fused_rope:
+            extra_args["rotary_emb"] = self.rotary_emb
+        self.attn = attn_cls(self.num_heads,
+                             self.head_dim,
+                             self.scaling,
+                             num_kv_heads=self.num_kv_heads,
+                             cache_config=cache_config,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.attn",
+                             **extra_args)
 
     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)
@@ -224,6 +237,15 @@ def forward(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                             dim=-1)
 
+        # rotary_emb is fused into self.attn in this case
+        if self.use_fused_rope:
+            assert not (
+                self.attn_temperature_tuning
+            ), f"{self.attn_temperature_tuning=} must be False when using fused rope"
+            attn_output = self.attn(q, k, v, positions=positions)
+            output, _ = self.o_proj(attn_output)
+            return output
+
         if self.rotary_emb is not None:
             q, k = self.rotary_emb(positions, q, k)
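For reference, the dispatch rule this patch adds to unified_attention_with_output can be summarized on its own: the fused RoPE + KV-cache write path is taken only when the attention backend was handed a rotary_emb, positions is provided, the backend is one of the supported implementations, and the AITER environment flag is set. The sketch below mirrors only that gating logic; FusedRoPECapableImpl, FUSED_ROPE_FLAG, and use_fused_rope are illustrative stand-ins, not vLLM APIs.

# Minimal sketch of the gating logic, with stand-in names (not vLLM code).
from typing import Optional

import torch


class FusedRoPECapableImpl:
    """Stand-in for TritonAttentionImpl / AiterFlashAttentionImpl / AiterMLAImpl."""

    def __init__(self, rotary_emb=None):
        # Populated only when the model layer opted in (see the llama4.py change).
        self.rotary_emb = rotary_emb


# Stand-in for VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE.
FUSED_ROPE_FLAG = True


def use_fused_rope(impl, positions: Optional[torch.Tensor]) -> bool:
    # Fuse only when every required input is present, mirroring the patch.
    eligible = (getattr(impl, "rotary_emb", None) is not None
                and positions is not None
                and isinstance(impl, FusedRoPECapableImpl))
    if eligible:
        # The patch asserts the env flag here rather than silently falling back.
        assert FUSED_ROPE_FLAG, "rotary_emb and positions provided when unexpected."
    return eligible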