
Commit dae6896

[Perf] Reduce MLA CPU overheads in V1 (#14384)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent c34eeec commit dae6896

2 files changed: +18 -6 lines changed

vllm/model_executor/layers/rotary_embedding.py
7 additions, 2 deletions

@@ -161,8 +161,13 @@ def forward_cuda(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from vllm import _custom_ops as ops
 
-        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
-                                                   dtype=query.dtype)
+        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
+        # is expensive, so avoid calling it if possible
+        if self.cos_sin_cache.device != query.device or \
+                self.cos_sin_cache.dtype != query.dtype:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
+                                                       dtype=query.dtype)
+
         # ops.rotary_embedding()/batched_rotary_embedding()
         # are in-place operations that update the query and key tensors.
         if offsets is not None:

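The rotary_embedding.py change relies on the fact that nn.Module.__setattr__ intercepts every attribute assignment (to keep its parameter/buffer/submodule bookkeeping up to date), which adds noticeable Python overhead on a per-token hot path. Guarding the assignment with a device/dtype check makes the steady-state case a couple of cheap attribute reads. A minimal standalone sketch of the pattern; the CachedRope class and its sizes are hypothetical illustrations, not vLLM code:

import torch
import torch.nn as nn


class CachedRope(nn.Module):
    """Hypothetical module with a cos/sin cache, showing the guarded assignment."""

    def __init__(self, max_positions: int = 4096, rotary_dim: int = 64):
        super().__init__()
        # Assigning to self.cos_sin_cache routes through nn.Module.__setattr__,
        # which consults the parameter/buffer/submodule dicts on every call.
        self.cos_sin_cache = torch.randn(max_positions, rotary_dim)

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # Only pay the __setattr__ cost when the cache actually has to move;
        # in steady state this branch is never taken.
        if (self.cos_sin_cache.device != query.device
                or self.cos_sin_cache.dtype != query.dtype):
            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                       dtype=query.dtype)
        # ... apply the rotary embedding using self.cos_sin_cache ...
        return query
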
vllm/v1/attention/backends/mla/common.py
11 additions, 4 deletions

@@ -222,8 +222,8 @@
     apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
-from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
 
 try:

@@ -627,8 +627,15 @@ def __init__(
         self.v_head_dim = v_head_dim
 
         self.rotary_emb = rotary_emb
-        self.use_yarn_rope = isinstance(rotary_emb,
-                                        DeepseekScalingRotaryEmbedding)
+
+        if current_platform.is_cuda():
+            # Hack for V1 for now to avoid torch library overhead (since we are
+            # already inside an attention custom op), pull out the forward
+            # method from the rotary embedding and call it directly (and avoid
+            # calling forward_native, when we can call forward_cuda)
+            # TODO(lucas): we should probably find a cleaner way to do this
+            self.rotary_emb = rotary_emb.forward_cuda
+
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj

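The common.py change works because binding the concrete forward_cuda method once in __init__ lets every subsequent call skip the nn.Module.__call__ machinery (hook handling, op dispatch) and go straight to the chosen Python method, which also avoids falling back to forward_native on CUDA. A minimal sketch of the general pattern under assumed names; Rope, MLAImplSketch, and their signatures are hypothetical, not vLLM code:

import torch
import torch.nn as nn


class Rope(nn.Module):
    """Hypothetical rotary embedding with separate native and CUDA paths."""

    def forward_native(self, positions: torch.Tensor,
                       query: torch.Tensor) -> torch.Tensor:
        return query  # placeholder for the pure-PyTorch reference path

    def forward_cuda(self, positions: torch.Tensor,
                     query: torch.Tensor) -> torch.Tensor:
        return query  # placeholder for the fused-kernel path

    # Calling the module directly would dispatch to the native path.
    forward = forward_native


class MLAImplSketch:
    """Hypothetical attention impl that applies the rope on every step."""

    def __init__(self, rotary_emb: Rope, on_cuda: bool):
        if on_cuda:
            # Bind the concrete method once; later calls bypass
            # nn.Module.__call__ entirely and always take the CUDA path.
            self.rotary_emb = rotary_emb.forward_cuda
        else:
            self.rotary_emb = rotary_emb  # regular Module.__call__ path

    def step(self, positions: torch.Tensor,
             query: torch.Tensor) -> torch.Tensor:
        return self.rotary_emb(positions, query)


# Usage sketch:
#   impl = MLAImplSketch(Rope(), on_cuda=torch.cuda.is_available())
#   out = impl.step(torch.arange(4), torch.randn(4, 64))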