                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearBase, RowParallelLinear,
+from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
@@ -422,20 +420,7 @@ def __init__(
         blocksparse_params: Optional[dict[str, Any]],
         logits_soft_cap: Optional[float],
         attn_type: str,
-        # MLA Specific Arguments
-        q_lora_rank: Optional[int],
-        kv_lora_rank: int,
-        qk_nope_head_dim: int,
-        qk_rope_head_dim: int,
-        qk_head_dim: int,
-        v_head_dim: int,
-        rotary_emb: RotaryEmbedding,
-        # q_proj should be q_b_proj if q_lora_rank is not None, but from an
-        # attention backend perspective we rely on the layer to pass in the
-        # correct matrix
-        q_proj: ColumnParallelLinear,
-        kv_b_proj: ColumnParallelLinear,
-        o_proj: RowParallelLinear,
+        kv_sharing_target_layer_name: Optional[str] = None,
         **kwargs,
     ) -> None:
         self.num_heads = num_heads
@@ -444,25 +429,20 @@ def __init__(
         self.num_kv_heads = num_kv_heads
         self.kv_cache_dtype = kv_cache_dtype
 
-        self.q_lora_rank = q_lora_rank
-        self.kv_lora_rank = kv_lora_rank
-        self.qk_nope_head_dim = qk_nope_head_dim
-        self.qk_rope_head_dim = qk_rope_head_dim
-        self.qk_head_dim = qk_head_dim
-        self.v_head_dim = v_head_dim
-
-        # Hack for V1 for now to avoid torch library overhead (since we are
-        # already inside an attention custom op), pull out the forward
-        # method from the rotary embedding and call it directly
-        # TODO(lucas): we should probably find a cleaner way to do this
-        self.rotary_emb = rotary_emb
-
-        self.q_proj = q_proj
-        self.kv_b_proj = kv_b_proj
-        self.o_proj = o_proj
-
+        # MLA Args
+        self.q_lora_rank = kwargs['q_lora_rank']
+        self.kv_lora_rank = kwargs['kv_lora_rank']
+        self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
+        self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
+        self.qk_head_dim = kwargs['qk_head_dim']
+        self.v_head_dim = kwargs['v_head_dim']
+        self.rotary_emb = kwargs['rotary_emb']
+        self.q_proj = kwargs['q_proj']
+        self.kv_b_proj = kwargs['kv_b_proj']
+        self.o_proj = kwargs['o_proj']
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
+
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
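
The net effect of the two __init__ hunks is a signature change: the MLA-specific parameters (q_lora_rank, kv_lora_rank, the head-dim sizes, rotary_emb, q_proj, kv_b_proj, o_proj) are dropped from the explicit argument list and are instead read out of **kwargs, while kv_sharing_target_layer_name becomes the only new explicit argument. The sketch below is not part of this diff; DemoMLAImpl and the placeholder values are hypothetical and only illustrate the kwargs calling convention under those assumptions.

# Minimal sketch (assumed names/values, not from this PR): required MLA
# extras are fetched with kwargs['...'] so a missing key fails loudly,
# while optional ones keep the kwargs.get(..., None) pattern.
from typing import Any, Optional


class DemoMLAImpl:

    def __init__(
        self,
        num_heads: int,
        num_kv_heads: int,
        kv_sharing_target_layer_name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        # Required MLA args: direct indexing raises KeyError if the caller
        # forgets one, mirroring the kwargs['...'] lookups in the diff.
        self.kv_lora_rank = kwargs['kv_lora_rank']
        self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
        # Optional extras default to None, like kv_a_proj_with_mqa above.
        self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)


# A caller now forwards the MLA extras as keyword arguments:
impl = DemoMLAImpl(
    num_heads=16,
    num_kv_heads=1,
    kv_sharing_target_layer_name=None,
    kv_lora_rank=512,
    qk_rope_head_dim=64,
)
print(impl.kv_lora_rank, impl.kv_a_layernorm)  # -> 512 None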