
fix VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
LucasWilkinson committed Jan 30, 2025
1 parent 27ad92c commit c34e5ca
Showing 3 changed files with 15 additions and 10 deletions.
13 changes: 8 additions & 5 deletions vllm/attention/backends/mla/utils.py
@@ -188,8 +188,9 @@ def _q_proj_and_k_up_proj(self, x):
             return torch.matmul(x, self.W_Q_UK)\
                 .view(-1, self.num_heads, self.kv_lora_rank)
         else:
-            x = torch.matmul(x, self.W_Q)
-            return torch.matmul(x, self.W_UK.T)\
+            x = torch.matmul(x, self.W_Q)\
+                .view(-1, self.num_heads, self.qk_nope_head_dim)
+            return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
                 .view(-1, self.num_heads, self.kv_lora_rank)
 
     def process_weights_after_loading(self):
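
The fix for the non-absorbed path: x @ W_Q yields a flat (tokens, num_heads * qk_nope_head_dim) tensor, so it must be viewed per-head before contracting each head against its own W_UK slice; a plain matmul against W_UK.T has no valid shape when the weight is 3-D. A minimal shape sketch (illustrative dimensions, not vLLM code; the W_UK layout is taken from the "lnp" einsum subscript above):

    import torch

    # Hypothetical sizes: B tokens, n heads, qk_nope_head_dim p, kv_lora_rank l.
    B, n, p, l = 2, 8, 128, 512
    x = torch.randn(B, n * p)      # q projection output, still flattened
    W_UK = torch.randn(l, n, p)    # one (l, p) up-projection matrix per head

    # Old path: torch.matmul(x, W_UK.T) -- W_UK is 3-D (a matrix per head),
    # so no single 2-D transpose makes this multiply valid.

    # Fixed path: split x into heads, then contract head-wise over p.
    x = x.view(-1, n, p)                               # (B, n, p)
    q_latent = torch.einsum("bnp,lnp->bnl", x, W_UK)   # (B, n, l)
    assert q_latent.shape == (B, n, l)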
@@ -249,13 +250,15 @@ def process_weights_after_loading(self):
                 self.W_UV_O.shape[0] * tp_size,
                 self.W_UV_O.shape[1],
                 bias=False,
-                #quant_config=self.o_proj.quant_method, TODO
+                # TODO(lucas) figure out how to properly forward quant_method
+                #quant_config=self.o_proj.quant_method,
             )
 
             self.o_proj_absored.weight = torch.nn.Parameter(self.W_UV_O.T)
         else:
-            print("Not absorbing weights")
-            self.W_UK, self.W_UV, self.W_Q = W_UK, W_UV, W_Q
+            self.W_UV = W_UV
+            self.W_UK = W_UK
+            self.W_Q = W_Q.flatten(start_dim=1)
 
     @abstractmethod
     def _forward_prefill(
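In the unabsorbed branch, W_Q is now flattened so the single torch.matmul(x, self.W_Q) in _q_proj_and_k_up_proj sees a 2-D weight. A sketch of that bookkeeping, assuming W_Q is loaded as (hidden_size, num_heads, qk_nope_head_dim); the real loading code may differ:

    import torch

    hidden, n, p = 1024, 8, 128            # assumed sizes
    W_Q = torch.randn(hidden, n, p)

    W_Q_flat = W_Q.flatten(start_dim=1)    # (hidden, n * p), as stored by the fix
    x = torch.randn(4, hidden)             # 4 tokens of hidden states
    q = torch.matmul(x, W_Q_flat)          # (4, n * p)
    q = q.view(-1, n, p)                   # per-head view used by the einsum above
    assert q.shape == (4, n, p)
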
4 changes: 2 additions & 2 deletions vllm/attention/backends/triton_mla.py
@@ -124,7 +124,7 @@ def begin_forward(self, model_input):
 
 @dataclass(kw_only=True)
 class TritonMLAMetadata(MLAMetadataCommon):
-    """Metadata for FlashAttentionBackend.
+    """Metadata for TritonMLAMetadata.
 
     NOTE: Any python object stored here is not updated when it is
     cuda-graph replayed. If you have values that need to be changed
@@ -189,7 +189,7 @@ class TritonMLAMetadata(MLAMetadataCommon):
 
     num_prefill_tokens: int
 
-    num_kv_splits: int = 4
+    num_kv_splits: int = 4 # TODO(lucas) add heuristic
     attn_logits: Optional[torch.Tensor] = None
     req_idx: Optional[torch.Tensor] = None
 
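
num_kv_splits controls how many chunks the decode kernel partitions the KV sequence into, hence the TODO for a heuristic instead of a fixed 4. The merge trick such split-KV decode relies on can be sketched outside Triton: chunked attention outputs combine exactly via their log-sum-exp weights. Illustrative only; the names and shapes below are not the kernel's:

    import torch

    def attn_chunk(q, k, v):
        # q: (d,), k and v: (t, d) -> chunk output (d,) and its logsumexp.
        s = (k @ q) / k.shape[-1] ** 0.5
        return torch.softmax(s, dim=0) @ v, torch.logsumexp(s, dim=0)

    torch.manual_seed(0)
    q = torch.randn(64)
    k, v = torch.randn(256, 64), torch.randn(256, 64)

    num_kv_splits = 4
    outs, lses = zip(*(attn_chunk(q, ks, vs)
                       for ks, vs in zip(k.chunk(num_kv_splits),
                                         v.chunk(num_kv_splits))))

    # Rescale each chunk by its share of the global softmax mass and sum.
    w = torch.softmax(torch.stack(lses), dim=0)       # (num_kv_splits,)
    merged = (torch.stack(outs) * w[:, None]).sum(dim=0)

    reference, _ = attn_chunk(q, k, v)                # unsplit attention
    assert torch.allclose(merged, reference, atol=1e-5)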
8 changes: 5 additions & 3 deletions vllm/envs.py
@@ -512,9 +512,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE":
     lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")),
 
-    # Flag that can control whether
-    #
-    #
+    # Flag that controls whether or not we perform matrix absorption for MLA
+    # decode, i.e. absorb W_UK into W_Q and W_UV into W_O. Absorbing the
+    # matrices reduces the runtime FLOPs needed to compute MLA but requires
+    # storing larger weights, W_Q_UK and W_UV_O, so it can increase memory
+    # usage; this is enabled by default.
     "VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
     lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1")))
 }
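
The tradeoff the comment describes can be checked numerically: precomputing W_Q_UK gives the same result as the two-step path, at the cost of storing a (hidden, num_heads * kv_lora_rank) weight instead of the smaller factors. A toy verification with assumed shapes, not vLLM's actual tensors:

    import torch

    torch.manual_seed(0)
    hidden, n, p, l = 64, 4, 16, 32   # hypothetical sizes
    W_Q = torch.randn(hidden, n, p)
    W_UK = torch.randn(l, n, p)
    x = torch.randn(3, hidden)

    # Two-step path (VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0): project, then
    # up-project per head, as in _q_proj_and_k_up_proj above.
    q = torch.matmul(x, W_Q.flatten(start_dim=1)).view(-1, n, p)
    two_step = torch.einsum("bnp,lnp->bnl", q, W_UK)

    # Absorbed path (=1): fold W_UK into W_Q once at load time; decode then
    # needs a single matmul per token.
    W_Q_UK = torch.einsum("hnp,lnp->hnl", W_Q, W_UK).reshape(hidden, n * l)
    absorbed = torch.matmul(x, W_Q_UK).view(-1, n, l)

    assert torch.allclose(two_step, absorbed, atol=1e-3)

Setting VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0 therefore trades the extra decode FLOPs of the two-step path for not having to hold W_Q_UK and W_UV_O in memory.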
