
fix mypy
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
LucasWilkinson committed Jan 30, 2025
1 parent cfb2d26 commit 5afc1bf
Showing 3 changed files with 22 additions and 22 deletions.
2 changes: 1 addition & 1 deletion vllm/attention/backends/abstract.py
@@ -278,7 +278,7 @@ def forward(
         raise NotImplementedError


-class MLAAttentionImpl(AttentionImpl):
+class MLAAttentionImpl(AttentionImpl[T], Generic[T]):

     @abstractmethod
     def forward(
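
Editor's note: the change above makes MLAAttentionImpl generic over its metadata type T (the next file's diff imports T from this same abstract module), which is what lets mypy check that each backend's forward() receives the right metadata class. A minimal, self-contained sketch of the pattern follows; the class bodies and the bound= choice are simplified stand-ins, not the actual vLLM definitions.

    from abc import ABC, abstractmethod
    from dataclasses import dataclass
    from typing import Generic, TypeVar


    @dataclass
    class AttentionMetadata:      # stand-in for vLLM's AttentionMetadata
        num_prefills: int


    # Assumed shape of vLLM's T: a TypeVar bound to the metadata base class.
    T = TypeVar("T", bound=AttentionMetadata)


    class AttentionImpl(ABC, Generic[T]):

        @abstractmethod
        def forward(self, attn_metadata: T) -> None:
            ...


    class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
        """Still generic; each concrete MLA backend later binds T to its own metadata."""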
19 changes: 10 additions & 9 deletions vllm/attention/backends/mla/utils.py
@@ -1,29 +1,29 @@
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Generic, List, Optional

 import torch

 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionLayer,
                                               AttentionMetadata,
-                                              MLAAttentionImpl)
+                                              MLAAttentionImpl, T)
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.vllm_flash_attn import flash_attn_varlen_func


-@dataclass(kw_only=True)
-class MLAMetadataCommon(AttentionMetadata):
+@dataclass
+class MLACommonMetadata(AttentionMetadata):
     # Input positions for rotary embeddings since for MLA the rotary
     # position embeddings are applied inside the attention backend
     input_positions: torch.Tensor


-class MLACommonImpl(MLAAttentionImpl):
+class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     """
     Common class for implementing repeated parts
@@ -138,7 +138,7 @@ def __init__(
         # q_proj should be q_b_proj if q_lora_rank is not None, but from an
         # attention backend perspective we rely on the layer to pass in the
         # correct matrix
-        q_proj: Optional[ColumnParallelLinear],
+        q_proj: ColumnParallelLinear,
         kv_b_proj: ColumnParallelLinear,
         o_proj: RowParallelLinear,
     ) -> None:
@@ -252,7 +252,7 @@ def _forward_prefill(
         q: torch.Tensor,
         kv_c_normed: torch.Tensor,
         k_pe: torch.Tensor,
-        attn_metadata: MLAMetadataCommon,
+        attn_metadata: T,
     ) -> torch.Tensor:
         raise NotImplementedError

@@ -262,7 +262,7 @@ def _forward_decode(
         q_nope: torch.Tensor,
         q_pe: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: MLAMetadataCommon,
+        attn_metadata: T,
     ) -> torch.Tensor:
         raise NotImplementedError

@@ -273,7 +273,7 @@ def forward(
         k_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
         kv_cache: torch.Tensor,
-        attn_metadata: MLAMetadataCommon,
+        attn_metadata: T,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if output is not None:
@@ -289,6 +289,7 @@ def forward(

         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
+        assert hasattr(attn_metadata, "input_positions")

         if is_decode:
             q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
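
Editor's note: with MLACommonImpl generic over T, the shared helpers now annotate attn_metadata as T rather than the concrete metadata class, and the added hasattr assert guards the MLA-specific field before it is used. A rough sketch of that shape, using simplified stand-in classes rather than the real vLLM types:

    from dataclasses import dataclass
    from typing import Generic, TypeVar

    import torch


    @dataclass
    class AttentionMetadata:          # simplified stand-in
        num_prefills: int


    @dataclass
    class MLACommonMetadata(AttentionMetadata):
        # For MLA the rotary embedding is applied inside the backend, so the
        # input positions travel with the metadata.
        input_positions: torch.Tensor


    T = TypeVar("T", bound=AttentionMetadata)   # assumed bound


    class MLACommonImpl(Generic[T]):

        def forward(self, attn_metadata: T) -> torch.Tensor:
            # The shared code only knows attn_metadata is some AttentionMetadata,
            # hence the runtime check on the MLA-specific field before using it.
            assert hasattr(attn_metadata, "input_positions")
            return attn_metadata.input_positions


    # Usage: a concrete backend pins T, so callers see precise metadata types.
    impl = MLACommonImpl[MLACommonMetadata]()
    meta = MLACommonMetadata(num_prefills=0, input_positions=torch.arange(4))
    print(impl.forward(meta).shape)   # torch.Size([4])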
23 changes: 11 additions & 12 deletions vllm/attention/backends/triton_mla.py
@@ -20,7 +20,7 @@
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionState, AttentionType)
-from vllm.attention.backends.mla.utils import MLACommonImpl, MLAMetadataCommon
+from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata
 from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
@@ -122,8 +122,8 @@ def begin_forward(self, model_input):
         return


-@dataclass(kw_only=True)
-class TritonMLAMetadata(MLAMetadataCommon):
+@dataclass
+class TritonMLAMetadata(MLACommonMetadata):
     """Metadata for TritonMLAMetadata.

     NOTE: Any python object stored here is not updated when it is
@@ -212,10 +212,8 @@ def prefill_metadata(self) -> Optional["TritonMLAMetadata"]:
         if self._cached_prefill_metadata is not None:
             return self._cached_prefill_metadata

-        assert ((self.seq_lens is not None)
-                or (self.encoder_seq_lens is not None))
-        assert ((self.seq_lens_tensor is not None)
-                or (self.encoder_seq_lens_tensor is not None))
+        assert self.seq_lens is not None
+        assert self.seq_lens_tensor is not None

         # Compute some attn_metadata fields which default to None
         query_start_loc = (None if self.query_start_loc is None else
@@ -243,6 +241,7 @@ def prefill_metadata(self) -> Optional["TritonMLAMetadata"]:
             multi_modal_placeholder_index_maps=self.
             multi_modal_placeholder_index_maps,
             enable_kv_scales_calculation=self.enable_kv_scales_calculation,
+            input_positions=input_positions,
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=self.max_query_len,
@@ -254,7 +253,6 @@ def prefill_metadata(self) -> Optional["TritonMLAMetadata"]:
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=False,
-            input_positions=input_positions,
             head_dim=self.head_dim)
         return self._cached_prefill_metadata

@@ -265,8 +263,7 @@ def decode_metadata(self) -> Optional["TritonMLAMetadata"]:

         if self._cached_decode_metadata is not None:
             return self._cached_decode_metadata
-        assert ((self.seq_lens_tensor is not None)
-                or (self.encoder_seq_lens_tensor is not None))
+        assert self.seq_lens_tensor is not None

         # Compute some attn_metadata fields which default to None
         slot_mapping = (None if self.slot_mapping is None else
Expand Down Expand Up @@ -569,6 +566,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
seq_lens=seq_lens,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
input_positions=input_positions,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len,
@@ -579,13 +577,12 @@
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
-            input_positions=input_positions,
             num_kv_splits=num_kv_splits,
             head_dim=self.runner.model_config.get_head_size(),
         )


-class TritonMLAImpl(MLACommonImpl):
+class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):

     def __init__(
         self,
@@ -628,6 +625,7 @@ def _forward_prefill(
         k_pe: torch.Tensor,
         attn_metadata: TritonMLAMetadata,
     ) -> torch.Tensor:
+        assert isinstance(attn_metadata, TritonMLAMetadata)
         return self._forward_prefill_flash(q, kv_c_normed, k_pe,
                                            attn_metadata.seq_start_loc,
                                            attn_metadata.max_prefill_seq_len)
@@ -644,6 +642,7 @@ def _forward_decode(
             raise NotImplementedError("FP8 Triton MLA not yet supported")

         decode_meta = attn_metadata.decode_metadata
+        assert decode_meta is not None
         B = q_nope.shape[0]

         q = torch.cat([q_nope, q_pe], dim=-1)
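
Editor's note: the concrete backend binds the TypeVar (MLACommonImpl[TritonMLAMetadata]), so its overrides take TritonMLAMetadata directly, and the added asserts narrow Optional and base types for mypy. A simplified sketch of how the binding and the narrowing fit together, with hypothetical names that mirror (but do not reproduce) the vLLM classes:

    from dataclasses import dataclass
    from typing import Generic, List, Optional, TypeVar


    @dataclass
    class MLACommonMetadata:          # simplified stand-in
        input_positions: List[int]


    T = TypeVar("T", bound=MLACommonMetadata)


    class MLACommonImpl(Generic[T]):

        def _forward_decode(self, attn_metadata: T) -> int:
            raise NotImplementedError


    @dataclass
    class TritonMLAMetadata(MLACommonMetadata):
        decode_metadata: Optional["TritonMLAMetadata"] = None


    class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):

        def _forward_decode(self, attn_metadata: TritonMLAMetadata) -> int:
            decode_meta = attn_metadata.decode_metadata
            # decode_metadata is Optional; the assert narrows it for mypy,
            # mirroring the `assert decode_meta is not None` added above.
            assert decode_meta is not None
            return len(decode_meta.input_positions)


    # Usage sketch:
    inner = TritonMLAMetadata(input_positions=[0, 1, 2])
    outer = TritonMLAMetadata(input_positions=[3], decode_metadata=inner)
    print(TritonMLAImpl()._forward_decode(outer))   # 3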

