2 files changed: +6, -18 lines

vllm/model_executor/layers/rotary_embedding.py
@@ -161,13 +161,8 @@ def forward_cuda(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from vllm import _custom_ops as ops

-        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
-        # is expensive, so avoid calling it if possible
-        if self.cos_sin_cache.device != query.device or \
-                self.cos_sin_cache.dtype != query.dtype:
-            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
-                                                       dtype=query.dtype)
-
+        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
+                                                   dtype=query.dtype)
         # ops.rotary_embedding()/batched_rotary_embedding()
         # are in-place operations that update the query and key tensors.
         if offsets is not None:
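For context: the removed comment notes that reassigning `self.cos_sin_cache` goes through `nn.Module.__setattr__` (the cache is a registered buffer), which is comparatively expensive, so the old code only reassigned when the device or dtype actually differed; `Tensor.to()` already returns the same tensor when nothing changes. A minimal standalone sketch of both variants — `ToyRope`, `sync_guarded`, and `sync_unguarded` are illustrative names, not vLLM code:

```python
import torch
import torch.nn as nn


class ToyRope(nn.Module):
    """Illustrative stand-in; not vLLM's RotaryEmbedding."""

    def __init__(self) -> None:
        super().__init__()
        # A registered buffer: reassigning it via `self.cos_sin_cache = ...`
        # routes through nn.Module.__setattr__, which does extra bookkeeping.
        self.register_buffer("cos_sin_cache", torch.randn(16, 8))

    def sync_guarded(self, query: torch.Tensor) -> None:
        # The variant this diff removes: skip the assignment (and hence
        # __setattr__) when device and dtype already match.
        if self.cos_sin_cache.device != query.device or \
                self.cos_sin_cache.dtype != query.dtype:
            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                       dtype=query.dtype)

    def sync_unguarded(self, query: torch.Tensor) -> None:
        # The variant this diff keeps: Tensor.to() returns the same tensor
        # when nothing differs, but the reassignment still pays the
        # __setattr__ cost on every call.
        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                   dtype=query.dtype)


rope = ToyRope()
q = torch.randn(4, 8, dtype=torch.float16)
rope.sync_guarded(q)    # converts the cache to float16 once
rope.sync_unguarded(q)  # no-op .to(), but still reassigns the buffer
```

The unconditional form is simpler, at the cost of paying the `__setattr__` bookkeeping on every call, which is exactly what the removed guard was avoiding.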

v1/attention/backends/mla
@@ -222,8 +222,8 @@
     Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
-from vllm.platforms import current_platform
+from vllm.model_executor.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.utils import cdiv, round_down

 try:
@@ -627,15 +627,8 @@ def __init__(
         self.v_head_dim = v_head_dim

         self.rotary_emb = rotary_emb
-
-        if current_platform.is_cuda():
-            # Hack for V1 for now to avoid torch library overhead (since we are
-            # already inside an attention custom op), pull out the forward
-            # method from the rotary embedding and call it directly (and avoid
-            # calling forward_native, when we can call forward_cuda)
-            # TODO(lucas): we should probably find a cleaner way to do this
-            self.rotary_emb = rotary_emb.forward_cuda
-
+        self.use_yarn_rope = isinstance(rotary_emb,
+                                        DeepseekScalingRotaryEmbedding)
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
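Rather than binding `rotary_emb.forward_cuda` directly (the removed hack that sidestepped torch custom-op dispatch on CUDA), the new code keeps the module as-is and records once, at construction time, whether it is the DeepSeek YaRN-scaled variant. A hedged sketch of that pattern — `MLAStub` and `_rope_path` are hypothetical stand-ins, and only the `isinstance` check mirrors the diff:

```python
import torch.nn as nn


class RotaryEmbedding(nn.Module):
    """Stand-in for vLLM's RotaryEmbedding."""


class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
    """Stand-in for the DeepSeek YaRN-scaled variant."""


class MLAStub:
    """Hypothetical consumer; not vLLM's actual MLA backend."""

    def __init__(self, rotary_emb: RotaryEmbedding) -> None:
        self.rotary_emb = rotary_emb
        # One isinstance check at __init__ instead of a per-call branch.
        self.use_yarn_rope = isinstance(rotary_emb,
                                        DeepseekScalingRotaryEmbedding)

    def _rope_path(self) -> str:
        # Illustrative consumer of the flag.
        return "yarn" if self.use_yarn_rope else "standard"


stub = MLAStub(DeepseekScalingRotaryEmbedding())
assert stub.use_yarn_rope and stub._rope_path() == "yarn"
```

Checking the type once up front keeps the per-call path free of both the dispatch overhead and the platform special-casing that the deleted `current_platform.is_cuda()` branch introduced.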