                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearBase, RowParallelLinear,
+from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
@@ -278,6 +276,9 @@ def build_dummy(self, num_reqs: int,
             attn_state=AscendAttentionState.DecodeOnly,
             prefill=None,
             decode=decode_metadata,
+            query_start_loc=None,
+            seq_lens=seq_lens,
+            block_tables=block_table,
         )
 
     def build(self,
@@ -409,20 +410,7 @@ def __init__(
         blocksparse_params: Optional[dict[str, Any]],
         logits_soft_cap: Optional[float],
         attn_type: str,
-        # MLA Specific Arguments
-        q_lora_rank: Optional[int],
-        kv_lora_rank: int,
-        qk_nope_head_dim: int,
-        qk_rope_head_dim: int,
-        qk_head_dim: int,
-        v_head_dim: int,
-        rotary_emb: RotaryEmbedding,
-        # q_proj should be q_b_proj if q_lora_rank is not None, but from an
-        # attention backend perspective we rely on the layer to pass in the
-        # correct matrix
-        q_proj: ColumnParallelLinear,
-        kv_b_proj: ColumnParallelLinear,
-        o_proj: RowParallelLinear,
+        kv_sharing_target_layer_name: Optional[str] = None,
         **kwargs,
     ) -> None:
         self.num_heads = num_heads
@@ -431,25 +419,20 @@ def __init__(
         self.num_kv_heads = num_kv_heads
         self.kv_cache_dtype = kv_cache_dtype
 
-        self.q_lora_rank = q_lora_rank
-        self.kv_lora_rank = kv_lora_rank
-        self.qk_nope_head_dim = qk_nope_head_dim
-        self.qk_rope_head_dim = qk_rope_head_dim
-        self.qk_head_dim = qk_head_dim
-        self.v_head_dim = v_head_dim
-
-        # Hack for V1 for now to avoid torch library overhead (since we are
-        # already inside an attention custom op), pull out the forward
-        # method from the rotary embedding and call it directly
-        # TODO(lucas): we should probably find a cleaner way to do this
-        self.rotary_emb = rotary_emb
-
-        self.q_proj = q_proj
-        self.kv_b_proj = kv_b_proj
-        self.o_proj = o_proj
-
+        # MLA Args
+        self.q_lora_rank = kwargs['q_lora_rank']
+        self.kv_lora_rank = kwargs['kv_lora_rank']
+        self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
+        self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
+        self.qk_head_dim = kwargs['qk_head_dim']
+        self.v_head_dim = kwargs['v_head_dim']
+        self.rotary_emb = kwargs['rotary_emb']
+        self.q_proj = kwargs['q_proj']
+        self.kv_b_proj = kwargs['kv_b_proj']
+        self.o_proj = kwargs['o_proj']
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
+
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
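
For reference, a minimal, self-contained sketch of the kwargs pattern this commit switches to: required MLA arguments are read with kwargs['...'] so a missing value fails fast with KeyError, while optional ones fall back to None via kwargs.get(...). The class name and values below are illustrative only and are not part of the commit.

from typing import Any, Optional


class _MLAArgsSketch:
    """Illustrative only: mirrors the required-vs-optional kwargs split above."""

    def __init__(self, **kwargs: Any) -> None:
        # Required: indexing raises KeyError if the layer forgot to pass them.
        self.kv_lora_rank: int = kwargs['kv_lora_rank']
        self.qk_rope_head_dim: int = kwargs['qk_rope_head_dim']
        # Optional: .get() defaults to None when the argument is absent.
        self.kv_a_layernorm: Optional[Any] = kwargs.get('kv_a_layernorm', None)


sketch = _MLAArgsSketch(kv_lora_rank=512, qk_rope_head_dim=64)
assert sketch.kv_a_layernorm is None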