From c3cacb8424dabcd98049203bd864a96c9020e363 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Sun, 19 Jan 2025 03:52:13 -0800 Subject: [PATCH] [V1] Add V1 support of Qwen2-VL (#12128) Signed-off-by: Roger Wang Signed-off-by: DarkLight1337 Co-authored-by: imkero Co-authored-by: DarkLight1337 Signed-off-by: Isotr0py <2037008807@qq.com> --- docs/source/models/supported_models.md | 2 +- .../vision_language/test_qwen2_vl.py | 18 +-- vllm/compilation/decorators.py | 14 +- .../model_executor/layers/rotary_embedding.py | 44 +++++- vllm/model_executor/models/llava_onevision.py | 6 +- vllm/model_executor/models/qwen2.py | 10 +- vllm/model_executor/models/qwen2_vl.py | 142 ++++++++++-------- vllm/v1/worker/gpu_input_batch.py | 3 + vllm/v1/worker/gpu_model_runner.py | 138 ++++++++++++++++- 9 files changed, 292 insertions(+), 85 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 2edb610ddf959..eb1bde9ec0089 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -754,7 +754,7 @@ See [this page](#generative-models) for more information on how to use generativ - `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. - ✅︎ - ✅︎ - - + - ✅︎ * - `UltravoxModel` - Ultravox - T + AE+ diff --git a/tests/models/decoder_only/vision_language/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/test_qwen2_vl.py index 16e256e040a74..2fd22f0cc88ec 100644 --- a/tests/models/decoder_only/vision_language/test_qwen2_vl.py +++ b/tests/models/decoder_only/vision_language/test_qwen2_vl.py @@ -105,7 +105,7 @@ def batch_make_image_embeddings( pixel_values = preprocess_result["pixel_values"] image_grid_thw = preprocess_result["image_grid_thw"] - # pixel values to embeddinds & grid_thws + # pixel values to embeddings & grid_thws with torch.no_grad(): visual = llm.llm_engine.model_executor.driver_worker. \ model_runner.model.visual @@ -124,11 +124,10 @@ def batch_make_image_embeddings( for image_batch in image_batches_: cur_batch_image_count = len(image_batch) merge_size = image_processor.merge_size - cur_batch_embed_len = sum([ - grid_thw.prod() // merge_size // merge_size + cur_batch_embed_len = sum( + grid_thw.prod(-1) // merge_size // merge_size for grid_thw in image_grid_thw[image_counter:image_counter + - cur_batch_image_count] - ]) + cur_batch_image_count]) result.append({ "image_embeds": @@ -187,7 +186,7 @@ def batch_make_video_embeddings( pixel_values = preprocess_result["pixel_values_videos"] video_grid_thw = preprocess_result["video_grid_thw"] - # pixel values to embeddinds & grid_thws + # pixel values to embeddings & grid_thws with torch.no_grad(): visual = llm.llm_engine.model_executor.driver_worker.\ model_runner.model.visual @@ -206,11 +205,10 @@ def batch_make_video_embeddings( for video_batch in video_batches_: cur_batch_video_count = len(video_batch) merge_size = image_processor.merge_size - cur_batch_embed_len = sum([ - grid_thw.prod() // merge_size // merge_size + cur_batch_embed_len = sum( + grid_thw.prod(-1) // merge_size // merge_size for grid_thw in video_grid_thw[video_counter:video_counter + - cur_batch_video_count] - ]) + cur_batch_video_count]) result.append({ "video_embeds": diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 10513111ea7f1..38f284794b8db 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -76,8 +76,8 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): During runtime, when we actually mark dimensions of tensors, it depends on the value of arguments: - - if it is a single integer, the corresponding dimension of the argument - will be marked as dynamic. + - if it is a single integer (can be negative), the corresponding dimension + of the argument will be marked as dynamic. - if it is `None`, ignored. - if it is `IntermediateTensors`, all the tensors in the intermediate tensors will be marked as dynamic. @@ -177,10 +177,20 @@ def __call__(self, *args, **kwargs): for k, dims in dynamic_arg_dims.items(): arg = bound_args.arguments.get(k) if arg is not None: + dims = [dims] if isinstance(dims, int) else dims if isinstance(arg, torch.Tensor): + # In case dims is specified with negative indexing + dims = [ + arg.ndim + dim if dim < 0 else dim for dim in dims + ] torch._dynamo.mark_dynamic(arg, dims) elif isinstance(arg, IntermediateTensors): for tensor in arg.tensors.values(): + # In case dims is specified with negative indexing + dims = [ + tensor.ndim + dim if dim < 0 else dim + for dim in dims + ] torch._dynamo.mark_dynamic(tensor, dims) else: raise ValueError( diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 3fcd81a3c4213..d071cfe888f05 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -841,6 +841,37 @@ def get_input_positions( ) -> Tuple[List[List[int]], int]: """Get mrope input positions and delta value.""" + llm_positions, mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + input_tokens, + image_grid_thw, + video_grid_thw, + image_token_id, + video_token_id, + vision_start_token_id, + vision_end_token_id, + spatial_merge_size, + context_len, + seq_len, + ) + + return llm_positions.tolist(), mrope_position_delta + + @staticmethod + def get_input_positions_tensor( + input_tokens: List[int], + image_grid_thw: Union[List[List[int]], torch.Tensor], + video_grid_thw: Union[List[List[int]], torch.Tensor], + image_token_id: int, + video_token_id: int, + vision_start_token_id: int, + vision_end_token_id: int, + spatial_merge_size: int, + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> Tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + if isinstance(image_grid_thw, torch.Tensor): image_grid_thw = image_grid_thw.tolist() if isinstance(video_grid_thw, torch.Tensor): @@ -916,7 +947,7 @@ def get_input_positions( len(input_tokens)).item() llm_positions = llm_positions[:, context_len:seq_len] - return llm_positions.tolist(), mrope_position_delta + return llm_positions, mrope_position_delta @staticmethod def get_next_input_positions( @@ -930,6 +961,17 @@ def get_next_input_positions( seq_len + mrope_position_delta)) for _ in range(3) ] + @staticmethod + def get_next_input_positions_tensor( + mrope_position_delta: int, + context_len: int, + seq_len: int, + ) -> torch.Tensor: + return torch.arange( + mrope_position_delta + context_len, + mrope_position_delta + seq_len, + ).expand(3, -1) + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index c9283e0c5ba20..6faa79f65d8de 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -554,10 +554,12 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key == "pixel_values" and "images" not in modalities: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: modalities["images"] = self._parse_and_validate_image_input( **kwargs) - if input_key == "pixel_values_videos" and "videos" not in modalities: # noqa E501 + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: modalities["videos"] = self._parse_and_validate_video_input( **kwargs) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index d015f60c6d065..82de1c3574090 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -256,7 +256,15 @@ def forward( return hidden_states, residual -@support_torch_compile +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) class Qwen2Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index d00e5d362c8bc..34d5c8ad089a3 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -67,11 +67,15 @@ from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix) + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) from .vision import get_vit_attn_backend logger = init_logger(__name__) +# For profile run +_MAX_FRAMES_PER_VIDEO = 16 + # === Vision Inputs === # @@ -135,7 +139,7 @@ class Qwen2VLVideoEmbeddingInputs(TypedDict): - List[`torch.Tensor`]: A list of tensors holding all videos' features. Each tensor holds an video's features. - `torch.Tensor`: A tensor holding all videos' features - (concatenation of all videos' feature tensors). + (concatenation of all videos' feature tensors). Tensor shape: `(num_image_features, hidden_size)` - `num_image_features` varies based on @@ -611,6 +615,7 @@ def forward( # adapter x = self.merger(x) + return x def load_weights(self, weights: Iterable[Tuple[str, @@ -874,8 +879,8 @@ def get_num_frames_with_most_features(self, seq_len: int) -> int: max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) - - num_frames = max(max_total_frames // max(max_videos, 1), 1) + num_frames = min(max(max_total_frames // max(max_videos, 1), 1), + _MAX_FRAMES_PER_VIDEO) # Temporary workaround for https://github.com/huggingface/transformers/issues/35412 if num_frames > 1 and num_frames % 2 == 1: @@ -955,13 +960,14 @@ def _get_prompt_replacements( "image": hf_processor.image_token, "video": hf_processor.video_token, } + merge_length = image_processor.merge_size**2 def get_replacement_qwen2vl(item_idx: int, modality: str): grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] assert isinstance(grid_thw, torch.Tensor) - num_tokens = grid_thw.prod() // merge_length + num_tokens = grid_thw.prod().item() // merge_length return placeholder[modality] * num_tokens return [ @@ -1047,11 +1053,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: Qwen2VLConfig = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config - assert not cache_config.enable_prefix_caching, \ - "Qwen2-VL currently does not support prefix caching" self.config = config self.multimodal_config = multimodal_config @@ -1173,59 +1176,82 @@ def _parse_and_validate_video_input( video_embeds=video_embeds, video_grid_thw=video_grid_thw) - def _process_image_input(self, - image_input: Qwen2VLImageInputs) -> torch.Tensor: + def _process_image_input( + self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + if image_input["type"] == "image_embeds": - return image_input["image_embeds"].type(self.visual.dtype) + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + + # Split concatenated embeddings for each image item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]: - pixel_values = image_input["pixel_values"].type(self.visual.dtype) - image_embeds = self.visual(pixel_values, - grid_thw=image_input["image_grid_thw"]) - return image_embeds + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 - def _process_video_input(self, - video_input: Qwen2VLVideoInputs) -> torch.Tensor: if video_input["type"] == "video_embeds": - return video_input["video_embeds"].type(self.visual.dtype) + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) - pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, - grid_thw=video_input["video_grid_thw"]) - return video_embeds + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size - def _merge_multimodal_embeddings( - self, - input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - multimodal_embeddings: torch.Tensor, - placeholder_token_id: int, - ) -> torch.Tensor: - mask = (input_ids == placeholder_token_id) - inputs_embeds[mask, :] = multimodal_embeddings - return inputs_embeds + return video_embeds.split(sizes.tolist()) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities def get_multimodal_embeddings( self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]: - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - if image_input is None and video_input is None: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: return None - # We make a tuple of each embedding with its modality string. This is a - # temporary workaround for models to handle mixed modalities when - # get_multimodal_embeddings and get_input_embeddings are called - # separately. - # TODO(ywang96): Add support for mixed-modality inference for v1. - multimodal_embeddings: List[Tuple[NestedTensors, str]] = [] - - if image_input is not None: - image_embeds = self._process_image_input(image_input) - multimodal_embeddings.append((image_embeds, "image")) - if video_input is not None: - video_embeds = self._process_video_input(video_input) - multimodal_embeddings.append((video_embeds, "video")) + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings return multimodal_embeddings @@ -1237,21 +1263,9 @@ def get_input_embeddings( ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: - for embeddings, modality in multimodal_embeddings: - if modality == "image": - inputs_embeds = self._merge_multimodal_embeddings( - input_ids, - inputs_embeds, - embeddings, - placeholder_token_id=self.config.image_token_id, - ) - if modality == "video": - inputs_embeds = self._merge_multimodal_embeddings( - input_ids, - inputs_embeds, - embeddings, - placeholder_token_id=self.config.video_token_id, - ) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + [self.config.image_token_id, self.config.video_token_id]) return inputs_embeds def forward( diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 40494e64b22f0..28d8e39053874 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -30,6 +30,9 @@ class CachedRequestState: num_computed_tokens: int output_token_ids: List[int] + mrope_positions: Optional[torch.Tensor] = None + mrope_position_delta: Optional[int] = None + @property def num_tokens(self) -> int: return len(self.prompt_token_ids) + len(self.output_token_ids) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index aa63d9414c296..87a1cd7f9e627 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -14,6 +14,7 @@ from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.sampling_params import SamplingType @@ -139,6 +140,32 @@ def __init__( self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.model_config.uses_mrope: + # NOTE: `mrope_positions` is implemented as a permuted tensor to + # satisfy the following properties to allow `torch.compile` to work + # properly: + # - shape: (3, ) + # - stride: (1, 3) + # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1921022256 + + # NOTE: When M-RoPE is enabled, position ids are 3D regardless of + # the modality of inputs. For text-only inputs, each dimension has + # identical position IDs, making M-RoPE functionally equivalent to + # 1D-RoPE. + # See page 5 of https://arxiv.org/abs/2409.12191 + self.mrope_positions = torch.zeros((self.max_num_tokens, 3), + dtype=torch.int64, + device=self.device) + self.mrope_positions_cpu = torch.zeros((self.max_num_tokens, 3), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + + self.mrope_positions = self.mrope_positions.permute((1, 0)) + self.mrope_positions_cpu = self.mrope_positions_cpu.permute((1, 0)) + self.inputs_embeds = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, @@ -246,6 +273,35 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: num_computed_tokens=new_req_data.num_computed_tokens, output_token_ids=[], ) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.model_config.uses_mrope: + image_grid_thw = [] + video_grid_thw = [] + for mm_input in self.requests[req_id].mm_inputs: + if mm_input.get("image_grid_thw") is not None: + image_grid_thw.extend( + mm_input["image_grid_thw"].tolist()) + if mm_input.get("video_grid_thw") is not None: + video_grid_thw.extend( + mm_input["video_grid_thw"].tolist()) + + hf_config = self.model_config.hf_config + + self.requests[req_id].mrope_positions, \ + self.requests[req_id].mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + self.requests[req_id].prompt_token_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + spatial_merge_size=hf_config.vision_config. + spatial_merge_size, + ) + req_ids_to_add.append(req_id) # Update the cached states of the resumed requests. @@ -313,6 +369,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): arange, out=positions_np) + # Calculate M-RoPE positions. + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.model_config.uses_mrope: + self._calc_mrope_positions(scheduler_output) + # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] @@ -359,8 +420,16 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Copy the tensors to the GPU. self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) - self.positions[:total_num_scheduled_tokens].copy_( - self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) + if self.model_config.uses_mrope: + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self.mrope_positions[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions_cpu[:, :total_num_scheduled_tokens], + non_blocking=True) + else: + # Common case (1D positions) + self.positions[:total_num_scheduled_tokens].copy_( + self.positions_cpu[:total_num_scheduled_tokens], + non_blocking=True) query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to( self.device, non_blocking=True) seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to( @@ -472,6 +541,61 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): logits_indices = query_start_loc[1:] - 1 return attn_metadata, logits_indices + def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): + mrope_pos_ptr = 0 + num_reqs = self.input_batch.num_reqs + for index, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + assert req_id is not None + + req = self.requests[req_id] + assert req.mrope_positions is not None + + num_computed_tokens = \ + self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = \ + scheduler_output.num_scheduled_tokens[req_id] + num_prompt_tokens = len(req.prompt_token_ids) + + if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: + prompt_part_len = max(0, + num_prompt_tokens - num_computed_tokens) + completion_part_len = max( + 0, num_scheduled_tokens - prompt_part_len) + else: + prompt_part_len = num_scheduled_tokens + completion_part_len = 0 + + assert num_scheduled_tokens == prompt_part_len + completion_part_len + + if prompt_part_len > 0: + # prompt's mrope_positions are pre-computed + dst_start = mrope_pos_ptr + dst_end = mrope_pos_ptr + prompt_part_len + src_start = num_computed_tokens + src_end = num_computed_tokens + prompt_part_len + + self.mrope_positions_cpu[:, dst_start:dst_end] = \ + req.mrope_positions[:,src_start:src_end] + + mrope_pos_ptr += prompt_part_len + + if completion_part_len > 0: + # compute completion's mrope_positions on-the-fly + dst_start = mrope_pos_ptr + dst_end = mrope_pos_ptr + completion_part_len + + self.mrope_positions_cpu[:, dst_start:dst_end] = \ + MRotaryEmbedding.get_next_input_positions_tensor( + req.mrope_position_delta, + context_len=num_computed_tokens + + prompt_part_len, + seq_len=num_computed_tokens + + prompt_part_len + + completion_part_len, + ) + + mrope_pos_ptr += completion_part_len + def _prepare_sampling( self, scheduler_output: "SchedulerOutput", @@ -618,9 +742,12 @@ def execute_model( # Run the decoder. # Use persistent buffers for CUDA graphs. with set_forward_context(attn_metadata, self.vllm_config): + positions = self.mrope_positions[:, :num_input_tokens] \ + if self.model_config.uses_mrope \ + else self.positions[:num_input_tokens] hidden_states = self.model( input_ids=input_ids, - positions=self.positions[:num_input_tokens], + positions=positions, kv_caches=self.kv_caches, attn_metadata=None, inputs_embeds=inputs_embeds, @@ -707,9 +834,12 @@ def _dummy_run( input_ids = self.input_ids[:num_tokens] inputs_embeds = None with set_forward_context(None, self.vllm_config): + positions = self.mrope_positions[:, :num_tokens] \ + if self.model_config.uses_mrope \ + else self.positions[:num_tokens] hidden_states = model( input_ids=input_ids, - positions=self.positions[:num_tokens], + positions=positions, kv_caches=kv_caches, attn_metadata=None, inputs_embeds=inputs_embeds,