Skip to content

Commit 9638902

Browse files
committed
fix mla swap_states
Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent 7e15c1a commit 9638902

File tree

4 files changed

+62
-2
lines changed

4 files changed

+62
-2
lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@
2727
from vllm.forward_context import ForwardContext, get_forward_context
2828
from vllm.utils import direct_register_custom_op
2929
from vllm.v1.core.sched.output import SchedulerOutput
30-
from vllm.v1.worker.gpu_input_batch import InputBatch
3130

3231
from vllm_ascend.ops.attention import vanilla_chunked_prefill
3332
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
3433
nd_to_nz_2d, nd_to_nz_spec)
34+
from vllm_ascend.worker.npu_input_batch import InputBatch
3535

3636

3737
class AscendAttentionBackend(AttentionBackend):

vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525

2626
if TYPE_CHECKING:
2727
from vllm.v1.core.sched.output import SchedulerOutput
28-
from vllm.v1.worker.gpu_input_batch import InputBatch
28+
29+
from vllm_ascend.worker.npu_input_batch import InputBatch
2930

3031

3132
@dataclass

vllm_ascend/pool/__init__.py

Whitespace-only changes.

vllm_ascend/worker/npu_input_batch.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
2727
from vllm.pooling_params import PoolingParams
2828
from vllm.sampling_params import SamplingParams, SamplingType
29+
from vllm.utils import swap_dict_values
2930
from vllm.v1.outputs import LogprobsTensors
3031
from vllm.v1.sample.metadata import SamplingMetadata
3132
from vllm.v1.utils import copy_slice
@@ -423,6 +424,64 @@ def remove_request(self, req_id: str) -> Optional[int]:
423424
self.pooling_params.pop(req_id, None)
424425
return req_index
425426

427+
def swap_states(self, i1: int, i2: int) -> None:
    """Exchange every piece of per-request state held at rows ``i1`` and ``i2``.

    After the call, the request that lived at index ``i1`` lives at
    ``i2`` and vice versa: request ids, the id -> index map, token
    counts, sampling parameters, token ids, per-index dict entries
    (generators, min_tokens, bad words), LoRA mapping, logit bias, the
    optional allowed-token mask and the block table rows.

    Args:
        i1: Index of the first request slot.
        i2: Index of the second request slot.

    Raises:
        AssertionError: If either slot has no request id (``None``).
    """
    old_id_i1 = self._req_ids[i1]
    old_id_i2 = self._req_ids[i2]
    # Validate before mutating anything so a failed check cannot leave
    # the batch in a half-swapped state.
    assert old_id_i1 is not None and old_id_i2 is not None

    self._req_ids[i1], self._req_ids[i2] = old_id_i2, old_id_i1
    self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
        self.req_id_to_index[old_id_i2],
        self.req_id_to_index[old_id_i1],
    )

    # Containers that are tuple-swapped element by element. This is
    # only safe when indexing yields an independent value rather than a
    # view -- presumably true here (1-D arrays / plain lists); verify
    # against the constructor if any of them becomes multi-dimensional.
    per_request_values = (
        self.req_output_token_ids,
        self.num_tokens,
        self.num_tokens_no_spec,
        self.num_prompt_tokens,
        self.num_computed_tokens_cpu,
        self.temperature_cpu,
        self.top_p_cpu,
        self.top_k_cpu,
        self.frequency_penalties_cpu,
        self.presence_penalties_cpu,
        self.repetition_penalties_cpu,
        self.min_p_cpu,
        self.request_lora_mapping,
        self.logit_bias,
    )
    for values in per_request_values:
        values[i1], values[i2] = values[i2], values[i1]

    # NOTE: the following is unsafe
    # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
    #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
    # because both right-hand sides are views into the same buffer;
    # instead, we need to temporarily copy the data for one of the indices
    # TODO(lucas): optimize this by only copying valid indices
    tmp = self.token_ids_cpu[i1, ...].copy()
    self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
    self.token_ids_cpu[i2, ...] = tmp

    # These dicts are keyed by request index and may lack an entry for
    # either index, so they go through the shared helper.
    swap_dict_values(self.generators, i1, i2)
    swap_dict_values(self.min_tokens, i1, i2)
    swap_dict_values(self.bad_words_token_ids, i1, i2)

    mask = self.allowed_token_ids_mask_cpu_tensor
    if mask is not None:
        # A tuple swap of rows has the same aliasing problem as
        # token_ids_cpu above when the mask is multi-dimensional: both
        # rows would end up holding row i2's data. Indexing with a list
        # (advanced indexing) materialises a copy of the two rows before
        # they are written back in swapped order.
        mask[[i1, i2]] = mask[[i2, i1]]

    self.block_table.swap_row(i1, i2)
484+
426485
def condense(self, empty_req_indices: list[int]) -> None:
427486
"""Move non-empty requests down into lower, empty indices.
428487

0 commit comments

Comments
 (0)