2 files changed, +18 −6 lines changed

vllm/model_executor/layers/rotary_embedding.py

@@ -161,8 +161,13 @@ def forward_cuda(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from vllm import _custom_ops as ops
 
-        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
-                                                   dtype=query.dtype)
+        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
+        # is expensive, so avoid calling it if possible
+        if self.cos_sin_cache.device != query.device or \
+                self.cos_sin_cache.dtype != query.dtype:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
+                                                       dtype=query.dtype)
+
         # ops.rotary_embedding()/batched_rotary_embedding()
         # are in-place operations that update the query and key tensors.
         if offsets is not None:
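For context, here is a minimal standalone sketch of the pattern this hunk applies (the class name and sizes are hypothetical, not vLLM code). Note that torch.Tensor.to() already returns self unchanged when device and dtype match, so the new guard exists purely to skip nn.Module.__setattr__, which does isinstance bookkeeping for parameters, modules, and buffers on every attribute assignment:

import torch
import torch.nn as nn

class RopeCache(nn.Module):
    """Hypothetical module illustrating the guarded-assignment pattern."""

    def __init__(self, max_positions: int = 4096, dim: int = 64):
        super().__init__()
        # One __setattr__ call at init time is fine; the goal is to avoid
        # paying for it on every forward pass.
        self.cos_sin_cache = torch.randn(max_positions, dim)

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # Two cheap comparisons in the steady state; __setattr__ (and the
        # isinstance checks it performs for Parameter/Module/buffer
        # bookkeeping) only runs when the cache actually has to move.
        if (self.cos_sin_cache.device != query.device
                or self.cos_sin_cache.dtype != query.dtype):
            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                       dtype=query.dtype)
        return self.cos_sin_cache

cache = RopeCache()
q = torch.randn(2, 64, dtype=torch.float16)
_ = cache(q)  # first call pays the move + __setattr__
_ = cache(q)  # subsequent calls skip both

In the steady state (same device, same dtype every step), the forward path reduces to two attribute reads and two comparisons, which is the point of the change.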
v1/attention/backends/mla

@@ -222,8 +222,8 @@
     apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
-from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
 
 try:
@@ -627,8 +627,15 @@ def __init__(
         self.v_head_dim = v_head_dim
 
         self.rotary_emb = rotary_emb
-        self.use_yarn_rope = isinstance(rotary_emb,
-                                        DeepseekScalingRotaryEmbedding)
+
+        if current_platform.is_cuda():
+            # Hack for V1 for now to avoid torch library overhead (since we are
+            # already inside an attention custom op), pull out the forward
+            # method from the rotary embedding and call it directly (and avoid
+            # calling forward_native, when we can call forward_cuda)
+            # TODO(lucas): we should probably find a cleaner way to do this
+            self.rotary_emb = rotary_emb.forward_cuda
+
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
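The bound-method trick above is worth spelling out, again as a hedged standalone sketch (class and variable names here are hypothetical): replacing the stored module reference with a bound method means later calls bypass nn.Module.__call__ and its hook machinery entirely, and pin the CUDA path instead of routing through the default forward. The trade-off, as the TODO in the diff acknowledges, is that any registered forward hooks are silently skipped.

import torch
import torch.nn as nn

class Rope(nn.Module):
    """Hypothetical rotary embedding with per-backend forward paths."""

    def forward_native(self, q: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch reference path.
        return q

    def forward_cuda(self, q: torch.Tensor) -> torch.Tensor:
        # Stand-in for a path that would launch a fused CUDA kernel.
        return q

    forward = forward_native  # default entry point via nn.Module.__call__

rope = Rope()
if torch.cuda.is_available():
    # Bound method: calls go straight to forward_cuda, skipping
    # nn.Module.__call__ (hook handling, dispatch) on every invocation.
    apply_rope = rope.forward_cuda
else:
    apply_rope = rope  # falls back to __call__ -> forward_native

out = apply_rope(torch.randn(2, 8))

Because both the module and its bound method are callable with the same signature, the caller's code (here, the MLA attention backend, which is already executing inside an attention custom op) does not need to know which one it holds.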