diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 9f72210c60bf9..889ebc6c2e1ff 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -67,6 +67,7 @@
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.processor import get_processor
+from vllm.utils import is_cpu
 
 from .utils import (PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory)
@@ -281,6 +282,21 @@ def forward(
             context_layer = rearrange(output,
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
+        elif is_cpu():
+            seq_length = q.size(1)
+            q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
+            attention_mask = torch.zeros([1, seq_length, seq_length],
+                                         device=q.device,
+                                         dtype=torch.bool)
+            for i in range(1, len(cu_seqlens)):
+                attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
+                               cu_seqlens[i - 1]:cu_seqlens[i]] = True
+            output = F.scaled_dot_product_attention(q,
+                                                    k,
+                                                    v,
+                                                    attention_mask,
+                                                    dropout_p=0.0)
+            context_layer = rearrange(output, "b h s d -> b s h d")
         else:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index d7d7d65659b73..cebb0f36a2b28 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -12,11 +12,13 @@
                          SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalInputs)
-from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+from vllm.sequence import (IntermediateTensors, SequenceData,
+                           SequenceGroupMetadata)
 from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
@@ -145,6 +147,38 @@ def build(self) -> ModelInputForCPU:
             query_lens=seq_lens,
         )
 
+    def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data,
+                                   computed_len: int):
+        mm_kwargs = self.multi_modal_input_mapper(mm_data)
+
+        # Special processing for mrope position deltas.
+        mrope_positions = None
+        if self.runner.model_is_mrope:
+            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
+            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
+            assert image_grid_thw is not None or video_grid_thw is not None, (
+                "mrope embedding type requires multi-modal input mapper "
+                "to return 'image_grid_thw' or 'video_grid_thw'.")
+
+            hf_config = self.runner.model_config.hf_config
+            token_ids = seq_data.get_token_ids()
+
+            mrope_positions, mrope_position_delta = \
+                MRotaryEmbedding.get_input_positions(
+                    token_ids,
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    image_token_id=hf_config.image_token_id,
+                    video_token_id=hf_config.video_token_id,
+                    vision_start_token_id=hf_config.vision_start_token_id,
+                    vision_end_token_id=hf_config.vision_end_token_id,
+                    spatial_merge_size=hf_config.vision_config.
+                    spatial_merge_size,
+                    context_len=computed_len,
+                )
+            seq_data.mrope_position_delta = mrope_position_delta
+        return mm_kwargs, mrope_positions
+
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -153,6 +187,8 @@ def _prepare_prompt(
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[int] = []
         input_positions: List[int] = []
+        input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
+
         slot_mapping: List[int] = []
         seq_lens: List[int] = []
         multi_modal_inputs_list: List[MultiModalInputs] = []
@@ -171,14 +207,20 @@ def _prepare_prompt(
             seq_lens.append(seq_len)  # Prompt token num
             input_tokens.extend(prompt_tokens)  # Token ids
 
+            mrope_positions = None
+            if (mm_data := seq_group_metadata.multi_modal_data):
+                mm_kwargs, mrope_positions = self._compute_multi_modal_input(
+                    seq_data, mm_data, computed_len)
+                multi_modal_inputs_list.append(mm_kwargs)
+
             # Token position ids
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
-            input_positions.extend(list(range(computed_len, seq_len)))
-
-            if (mm_data := seq_group_metadata.multi_modal_data):
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                multi_modal_inputs_list.append(mm_kwargs)
+            if mrope_positions:
+                for idx in range(3):
+                    input_mrope_positions[idx].extend(mrope_positions[idx])
+            else:
+                input_positions.extend(list(range(computed_len, seq_len)))
 
             # Compute the slot mapping.
             block_table = seq_group_metadata.block_tables[seq_id]
@@ -202,12 +244,18 @@ def _prepare_prompt(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)
 
+        if any(input_mrope_positions):
+            input_positions = None  # type: ignore
+        else:
+            input_mrope_positions = None  # type: ignore
+
         num_prompt_tokens = len(input_tokens)
 
         input_tokens = torch.tensor(input_tokens,
                                     dtype=torch.long,
                                     device=self.device)  # type: ignore
-        input_positions = torch.tensor(input_positions,
+        input_positions = torch.tensor(input_positions
+                                       or input_mrope_positions,
                                        dtype=torch.long,
                                        device=self.device)  # type: ignore
         slot_mapping = torch.tensor(slot_mapping,
@@ -238,6 +286,7 @@ def _prepare_decode(
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[int] = []
         input_positions: List[int] = []
+        input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
         slot_mapping: List[int] = []
         seq_lens: List[int] = []
         block_tables: List[List[int]] = []
@@ -255,7 +304,17 @@ def _prepare_decode(
 
                 seq_len = seq_data.get_len()
                 position = seq_len - 1
-                input_positions.append(position)
+                if seq_data.mrope_position_delta is not None:
+                    context_len = seq_data.get_num_computed_tokens()
+                    next_pos = MRotaryEmbedding.get_next_input_positions(
+                        seq_data.mrope_position_delta,
+                        context_len,
+                        seq_len,
+                    )
+                    for idx in range(3):
+                        input_mrope_positions[idx].extend(next_pos[idx])
+                else:
+                    input_positions.append(position)
 
                 seq_len = seq_len if self.sliding_window is None else min(
                     seq_len, self.sliding_window)
@@ -273,12 +332,18 @@ def _prepare_decode(
                     block_table = block_table[-sliding_window_blocks:]
                 block_tables.append(block_table)
 
+        if any(input_mrope_positions):
+            input_positions = None  # type: ignore
+        else:
+            input_mrope_positions = None  # type: ignore
+
         max_decode_seq_len = max(seq_lens)
 
         input_tokens = torch.tensor(input_tokens,
                                     dtype=torch.long,
                                     device=self.device)
-        input_positions = torch.tensor(input_positions,
+        input_positions = torch.tensor(input_positions
+                                       or input_mrope_positions,
                                        dtype=torch.long,
                                        device=self.device)
         slot_mapping = torch.tensor(slot_mapping,
@@ -373,6 +438,15 @@ def __init__(
             raise NotImplementedError(
                 STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
 
+    @property
+    def model_is_mrope(self) -> bool:
+        """Detect if the model has "mrope" rope_scaling type.
+        mrope requires keeping "rope_deltas" between prompt and decoding phases."""
+        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
+        if rope_scaling is None:
+            return False
+        return rope_scaling.get("type", None) == "mrope"
+
     def load_model(self) -> None:
         self.model = get_model(model_config=self.model_config,
                                load_config=self.load_config,
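
Reviewer note (illustration only, not part of the patch): the CPU branch added to qwen2_vl.py emulates varlen attention by building a block-diagonal boolean mask from cu_seqlens and handing it to torch.nn.functional.scaled_dot_product_attention. A minimal standalone sketch of that masking scheme, with made-up toy shapes:

import torch
import torch.nn.functional as F

# Two packed sequences of lengths 3 and 4, expressed as cumulative boundaries.
cu_seqlens = torch.tensor([0, 3, 7])
seq_length = int(cu_seqlens[-1])
num_heads, head_dim = 2, 8

q = torch.randn(1, num_heads, seq_length, head_dim)
k = torch.randn(1, num_heads, seq_length, head_dim)
v = torch.randn(1, num_heads, seq_length, head_dim)

# True marks query/key pairs allowed to attend to each other; each packed
# sequence only sees tokens inside its own [start, end) window.
attention_mask = torch.zeros([1, seq_length, seq_length], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    start, end = cu_seqlens[i - 1], cu_seqlens[i]
    attention_mask[..., start:end, start:end] = True

# The (1, S, S) mask broadcasts across the head dimension of (1, H, S, D) inputs.
output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
print(output.shape)  # torch.Size([1, 2, 7, 8])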