From 98010c1ac3554b616fb6f44496784d635be1723c Mon Sep 17 00:00:00 2001 From: qingjun Date: Tue, 6 May 2025 20:15:09 +0800 Subject: [PATCH 01/30] Add image input processing and merge image patch embedding functionality. Signed-off-by: qingjun --- vllm/model_executor/models/minimax_vl_01.py | 151 ++++++++++++++++---- 1 file changed, 127 insertions(+), 24 deletions(-) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 4ac60f97bb5f..f4a2bdb0a19a 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -1,10 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Iterable, Mapping -from typing import Literal, Optional, Set, Tuple, TypedDict, Union, cast +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union, cast import torch import torch.nn as nn from transformers import BatchFeature +from transformers.models.llava_next.modeling_llava_next import ( + get_anyres_image_grid_shape, unpad_image) from vllm.config import VllmConfig from vllm.jsontree import json_map_leaves @@ -139,6 +141,7 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: return { "pixel_values": MultiModalFieldConfig.batched("image"), + "image_sizes": MultiModalFieldConfig.batched("image"), "image_embeds": MultiModalFieldConfig.batched("image"), } @@ -250,45 +253,144 @@ def _process_image_pixels( pixel_values = inputs["pixel_values"] - return self._image_pixels_to_features(self.vision_tower, pixel_values) + if isinstance(pixel_values, torch.Tensor): + b, num_patches, c, h, w = pixel_values.shape + stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w) + stacked_image_features = self._image_pixels_to_features( + self.vision_tower, stacked_pixel_values) + stacked_patch_embeddings = self.multi_modal_projector( + stacked_image_features) + + return stacked_patch_embeddings.view( + b, num_patches, *stacked_patch_embeddings.shape[1:]) + + num_patches_per_batch = [v.shape[0] for v in pixel_values] + stacked_pixel_values = torch.cat(pixel_values) + stacked_image_features = self._image_pixels_to_features( + self.vision_tower, stacked_pixel_values) + + return torch.split(self.multi_modal_projector(stacked_image_features), + num_patches_per_batch) def _process_image_input( self, image_input: MiniMaxVL01ImageInputs, - ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + ) -> Union[torch.Tensor, List[torch.Tensor]]: if image_input["type"] == "image_embeds": - return image_input["data"] - - assert self.vision_tower is not None - image_features = self._process_image_pixels(image_input) - - if isinstance(image_features, torch.Tensor): - return self.multi_modal_projector(image_features) - - feature_sizes = [ - image_feature.shape[0] for image_feature in image_features + return [image_input["data"]] + + patch_embeddings = self._process_image_pixels(image_input) + + image_sizes = image_input.get("image_sizes") + if image_sizes is None: + batch_size = len(image_input["data"]) + vision_config = self.config.vision_config + default_height = default_width = vision_config.image_size + image_sizes = torch.as_tensor([[default_height, default_width] + for _ in range(batch_size)]) + + return [ + self._merge_image_patch_embeddings(image_sizes[i], + patch_features_batch, + strategy="spatial_unpad") + for i, patch_features_batch in enumerate(patch_embeddings) ] - image_embeds = self.multi_modal_projector(torch.cat(image_features)) - image_embeds = torch.split(image_embeds, 
feature_sizes) - return image_embeds + # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py + def _merge_image_patch_embeddings(self, image_size: torch.Tensor, + patch_embeddings: torch.Tensor, *, + strategy: str) -> torch.Tensor: + if strategy == "flat": + return patch_embeddings.flatten(0, 1) + + if strategy.startswith("spatial"): + height = width = self.config.vision_config.image_size \ + // self.config.vision_config.patch_size + + base_patch_embeds = patch_embeddings[0] + if height * width != base_patch_embeds.shape[0]: + raise ValueError( + "The number of patches is not consistent with the " + "image size.") + + if patch_embeddings.shape[0] > 1: + other_patch_embeds = patch_embeddings[1:] + + # Move to CPU to avoid floating-point errors + orig_height, orig_width = image_size.tolist() + + # image_aspect_ratio == "anyres" + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + (orig_height, orig_width), + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + num_patches = num_patch_height * num_patch_width + + # Image patches might be padded for batch processing + other_patch_embeds = other_patch_embeds[:num_patches] \ + .view(num_patch_height, num_patch_width, height, width, -1) + + if "unpad" in strategy: + other_patch_embeds = other_patch_embeds \ + .permute(4, 0, 2, 1, 3).contiguous() \ + .flatten(1, 2).flatten(2, 3) + other_patch_embeds = unpad_image(other_patch_embeds, + (orig_height, orig_width)) + other_patch_embeds = torch.cat(( + other_patch_embeds, + self.image_newline[:, None, None] \ + .expand(*other_patch_embeds.shape[:-1], 1) \ + .to(other_patch_embeds.device), + ), dim=-1) + other_patch_embeds = other_patch_embeds \ + .flatten(1, 2).transpose(0, 1) + else: + other_patch_embeds = other_patch_embeds \ + .permute(0, 2, 1, 3, 4).contiguous() \ + .flatten(0, 3) + + merged_patch_embeddings = torch.cat( + (base_patch_embeds, other_patch_embeds), dim=0) + else: + if "unpad" in strategy: + merged_patch_embeddings = torch.cat( + (base_patch_embeds, + self.image_newline[None] \ + .to(base_patch_embeds.device) + ), dim=0) + else: + merged_patch_embeddings = base_patch_embeds + + return merged_patch_embeddings + + raise ValueError(f"Unexpected patch merge strategy: {strategy}") + + def _validate_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("num_patches", *map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_expr}. 
You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) return data def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[MiniMaxVL01ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is None and image_embeds is None: @@ -302,7 +404,8 @@ def _parse_and_validate_image_input( return MiniMaxVL01ImagePixelInputs( type="pixel_values", pixel_values=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + flatten_bn(pixel_values)), + image_sizes=flatten_bn(image_sizes, concat=True), ) if image_embeds is not None: From 01c91c00e45c5e551af83d713809504ef4aaff19 Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 11:22:26 +0800 Subject: [PATCH 02/30] fix code --- vllm/model_executor/models/minimax_text_01.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 951f4e2304a1..d8a35da49c20 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -432,18 +432,23 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, hidden = [] for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): if _prefill_idx >= len(attn_metadata.query_start_loc): - break + _start = 0 + else: + _start = attn_metadata.query_start_loc[_prefill_idx] + + if _prefill_idx+1 >= len(attn_metadata.query_end_loc): + _end = 0 + else: + _end = attn_metadata.query_end_loc[_prefill_idx+1] + if _prefill_idx >= len(state_indices_tensor): - break - _start = attn_metadata.query_start_loc[_prefill_idx] - _end = attn_metadata.query_start_loc[_prefill_idx + 1] - slot_id = state_indices_tensor[_prefill_idx] + slot_id = 0 + else: + slot_id = state_indices_tensor[_prefill_idx] qs = q[_start:_end].transpose(0, 1).contiguous() ks = k[_start:_end].transpose(0, 1).contiguous() vs = v[_start:_end].transpose(0, 1).contiguous() - slot_id = state_indices_tensor[_prefill_idx] slice_layer_cache = kv_cache[slot_id, ...] 
- out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix( qs, ks, @@ -458,7 +463,7 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, self._decode_infer(q, k, v, kv_cache, state_indices_tensor, attn_metadata)) - if not hidden: + if len(hidden) == 0: return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) hidden = torch.concat(hidden, dim=0).contiguous() From d66c1bbe4b03c35cb214b2adfa4070223f3a29d6 Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 11:28:11 +0800 Subject: [PATCH 03/30] fix code --- vllm/model_executor/models/minimax_text_01.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index d8a35da49c20..640b5c8e61bf 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -436,10 +436,10 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, else: _start = attn_metadata.query_start_loc[_prefill_idx] - if _prefill_idx+1 >= len(attn_metadata.query_end_loc): + if _prefill_idx+1 >= len(attn_metadata.query_start_loc): _end = 0 else: - _end = attn_metadata.query_end_loc[_prefill_idx+1] + _end = attn_metadata.query_start_loc[_prefill_idx+1] if _prefill_idx >= len(state_indices_tensor): slot_id = 0 From 9e3a12109cd0cf086770fcc7879bf01c4d1a1fb7 Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 11:36:28 +0800 Subject: [PATCH 04/30] fix code --- vllm/model_executor/models/minimax_text_01.py | 217 +++++------------- 1 file changed, 59 insertions(+), 158 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 640b5c8e61bf..0328e680d7d2 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -3,7 +3,8 @@ import copy import math import re -from typing import Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from vllm.model_executor.layers.rotary_embedding import get_rope import torch import torch.distributed @@ -333,180 +334,80 @@ class MiniMaxText01LinearAttention(nn.Module): def __init__( self, hidden_size: int, - hidden_inner_size: int, num_heads: int, - head_dim: int, - max_position: int, - block_size: int, - num_hidden_layer: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - layer_idx: int = 0, - linear_layer_idx: int = 0, - prefix: str = "linear_attn", + prefix: str = "", ) -> None: super().__init__() - - self.layer_idx = layer_idx - self.BLOCK = block_size self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = head_dim + tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads - self.hidden_inner_size = hidden_inner_size - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - assert self.total_num_heads % self.tp_size == 0 - self.tp_heads = self.total_num_heads // self.tp_size - self.qkv_size = self.num_heads * self.head_dim - self.tp_hidden = self.head_dim * self.tp_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= 
tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings - self.qkv_proj = ColumnParallelLinear( - hidden_size, - self.hidden_inner_size * 3, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.output_gate = ColumnParallelLinear( + self.qkv_proj = QKVParallelLinear( hidden_size, - self.hidden_inner_size, - bias=False, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, quant_config=quant_config, - prefix=f"{prefix}.output_gate", ) - self.out_proj = RowParallelLinear( - self.hidden_inner_size, + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - self.norm = MiniMaxText01RMSNormTP( - self.hidden_inner_size, - eps=1e-5, ) - slope_rate = MiniMaxText01LinearAttention._build_slope_tensor( - self.num_heads) - if num_hidden_layer <= 1: - self.slope_rate = slope_rate * (1 + 1e-5) - else: - self.slope_rate = slope_rate * (1 - layer_idx / - (num_hidden_layer - 1) + 1e-5) - self.tp_slope = self.slope_rate[self.tp_rank * - self.tp_heads:(self.tp_rank + 1) * - self.tp_heads].contiguous() - - @staticmethod - def weight_direct_load(param: torch.Tensor, - loaded_weight: torch.Tensor) -> None: - assert param.size() == loaded_weight.size() - param.data.copy_(loaded_weight) - return - - @staticmethod - def _build_slope_tensor(n_attention_heads: int): - - def get_slopes(n): - - def get_slopes_power_of_2(n): - start = 2**(-(2**-(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2**math.floor(math.log2(n)) - return (get_slopes_power_of_2(closest_power_of_2) + get_slopes( - 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) - - slopes = torch.tensor(get_slopes(n_attention_heads), - dtype=torch.float32).reshape( - n_attention_heads, 1, 1) - return slopes - - def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, - attn_metadata): - hidden = [] - for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): - if _prefill_idx >= len(attn_metadata.query_start_loc): - _start = 0 - else: - _start = attn_metadata.query_start_loc[_prefill_idx] - - if _prefill_idx+1 >= len(attn_metadata.query_start_loc): - _end = 0 - else: - _end = attn_metadata.query_start_loc[_prefill_idx+1] - - if _prefill_idx >= len(state_indices_tensor): - slot_id = 0 - else: - slot_id = state_indices_tensor[_prefill_idx] - qs = q[_start:_end].transpose(0, 1).contiguous() - ks = k[_start:_end].transpose(0, 1).contiguous() - vs = v[_start:_end].transpose(0, 1).contiguous() - slice_layer_cache = kv_cache[slot_id, ...] 
- out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix( - qs, - ks, - vs, - slice_layer_cache, - self.tp_slope, - self.BLOCK, - layer_idx=self.layer_idx) - hidden.append(out_slice.contiguous()) - if attn_metadata.num_decode_tokens > 0: - hidden.append( - self._decode_infer(q, k, v, kv_cache, state_indices_tensor, - attn_metadata)) - - if len(hidden) == 0: - return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) - - hidden = torch.concat(hidden, dim=0).contiguous() - return hidden - - def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, - attn_metadata): - q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0 - ):] - hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope, - slot_id, 32) - return hidden + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") - def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor: + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) - qkv32 = qkv.to(torch.float32) - qkvact = torch.nn.functional.silu(qkv32) - qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) - q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata - kv_cache = kv_caches.minimax_cache - state_indices_tensor = kv_caches.state_indices_tensor - - decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 - if not decode_only: - hidden = self._prefill_and_mix_infer(q, k, v, kv_cache, - state_indices_tensor, - attn_metadata) - else: - hidden = self._decode_infer(q, k, v, kv_cache, - state_indices_tensor, attn_metadata) - - hidden = self.norm._forward(hidden) - gate, _ = self.output_gate(hidden_states) - hidden = F.sigmoid(gate) * hidden - hidden = hidden.to(hidden_states.dtype) - hidden, _ = self.out_proj(hidden) - return hidden + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output class MiniMaxText01Attention(nn.Module): From 9691712e8c1eea286ad26a105e5885bdec58d0f4 Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 11:57:01 +0800 Subject: [PATCH 05/30] fix code --- vllm/model_executor/models/minimax_text_01.py | 204 +++++++++++++----- 1 file changed, 145 insertions(+), 59 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 0328e680d7d2..d1b7d90c6418 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -3,8 +3,7 @@ import copy import math import re -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union -from vllm.model_executor.layers.rotary_embedding import get_rope +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union import torch import 
torch.distributed @@ -334,80 +333,166 @@ class MiniMaxText01LinearAttention(nn.Module): def __init__( self, hidden_size: int, + hidden_inner_size: int, num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, + head_dim: int, + max_position: int, + block_size: int, + num_hidden_layer: int, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + layer_idx: int = 0, + linear_layer_idx: int = 0, + prefix: str = "linear_attn", ) -> None: super().__init__() + + self.layer_idx = layer_idx + self.BLOCK = block_size self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() + self.num_heads = num_heads + self.head_dim = head_dim self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. - assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings + self.hidden_inner_size = hidden_inner_size + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() - self.qkv_proj = QKVParallelLinear( + assert self.total_num_heads % self.tp_size == 0 + self.tp_heads = self.total_num_heads // self.tp_size + self.qkv_size = self.num_heads * self.head_dim + self.tp_hidden = self.head_dim * self.tp_heads + self.qkv_proj = ColumnParallelLinear( hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=True, + self.hidden_inner_size * 3, + bias=False, quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", ) - - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, + self.output_gate = ColumnParallelLinear( hidden_size, + self.hidden_inner_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.output_gate", ) - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, + self.out_proj = RowParallelLinear( + self.hidden_inner_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.norm = MiniMaxText01RMSNormTP( + self.hidden_inner_size, + eps=1e-5, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> torch.Tensor: + slope_rate = MiniMaxText01LinearAttention._build_slope_tensor( + self.num_heads) + if num_hidden_layer <= 1: + self.slope_rate = slope_rate * (1 + 1e-5) + else: + self.slope_rate = slope_rate * (1 - layer_idx / + (num_hidden_layer - 1) + 1e-5) + self.tp_slope = self.slope_rate[self.tp_rank * 
+ self.tp_heads:(self.tp_rank + 1) * + self.tp_heads].contiguous() + + @staticmethod + def weight_direct_load(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight) + return + + @staticmethod + def _build_slope_tensor(n_attention_heads: int): + + def get_slopes(n): + + def get_slopes_power_of_2(n): + start = 2**(-(2**-(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2**math.floor(math.log2(n)) + return (get_slopes_power_of_2(closest_power_of_2) + get_slopes( + 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + + slopes = torch.tensor(get_slopes(n_attention_heads), + dtype=torch.float32).reshape( + n_attention_heads, 1, 1) + return slopes + + def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor): + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + hidden = [] + for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): + _start = attn_metadata.query_start_loc[_prefill_idx] + _end = attn_metadata.query_start_loc[_prefill_idx + 1] + slot_id = state_indices_tensor[_prefill_idx] + qs = q[_start:_end].transpose(0, 1).contiguous() + ks = k[_start:_end].transpose(0, 1).contiguous() + vs = v[_start:_end].transpose(0, 1).contiguous() + slot_id = state_indices_tensor[_prefill_idx] + slice_layer_cache = kv_cache[slot_id, ...] + + out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix( + qs, + ks, + vs, + slice_layer_cache, + self.tp_slope, + self.BLOCK, + layer_idx=self.layer_idx) + hidden.append(out_slice.contiguous()) + if attn_metadata.num_decode_tokens > 0: + hidden.append( + self._decode_infer(q, k, v, kv_cache, state_indices_tensor, + attn_metadata)) + hidden = torch.concat(hidden, dim=0).contiguous() + return hidden + + def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, + attn_metadata): + q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0 + ):] + hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope, + slot_id, 32) + return hidden + + def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, + kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output + qkv32 = qkv.to(torch.float32) + qkvact = torch.nn.functional.silu(qkv32) + qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) + q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + kv_cache = kv_caches.minimax_cache + state_indices_tensor = kv_caches.state_indices_tensor + + decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 + if not decode_only: + hidden = self._prefill_and_mix_infer(q, k, v, kv_cache, + state_indices_tensor) + else: + hidden = self._decode_infer(q, k, v, kv_cache, + state_indices_tensor, attn_metadata) + + hidden = self.norm._forward(hidden) + gate, _ = self.output_gate(hidden_states) + hidden = 
F.sigmoid(gate) * hidden + hidden = hidden.to(hidden_states.dtype) + hidden, _ = self.out_proj(hidden) + return hidden class MiniMaxText01Attention(nn.Module): @@ -515,6 +600,7 @@ def __init__( config.max_model_len, int): max_position_embeddings = min(config.max_position_embeddings, config.max_model_len) + config.attention_type = 1 if config.attention_type == 0: use_headxdim = True hidden_inner = (head_dim * config.num_attention_heads From b5c27f92cdc46675cd20e1751774761f996cde9b Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 14:45:53 +0800 Subject: [PATCH 06/30] fix code --- vllm/model_executor/models/minimax_text_01.py | 31 +++++++++---------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index d1b7d90c6418..3d468c345663 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -433,7 +433,6 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor): for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): _start = attn_metadata.query_start_loc[_prefill_idx] _end = attn_metadata.query_start_loc[_prefill_idx + 1] - slot_id = state_indices_tensor[_prefill_idx] qs = q[_start:_end].transpose(0, 1).contiguous() ks = k[_start:_end].transpose(0, 1).contiguous() vs = v[_start:_end].transpose(0, 1).contiguous() @@ -600,7 +599,6 @@ def __init__( config.max_model_len, int): max_position_embeddings = min(config.max_position_embeddings, config.max_model_len) - config.attention_type = 1 if config.attention_type == 0: use_headxdim = True hidden_inner = (head_dim * config.num_attention_heads @@ -836,7 +834,7 @@ def layer_fn(prefix): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers") - + linear_layer_nums = sum(1 for i in range(config.num_hidden_layers) if self.decoder_attention_types[i] == 0) max_slots_number = scheduler_config.max_num_seqs @@ -848,8 +846,7 @@ def layer_fn(prefix): self._dtype = _dummy.dtype del _dummy - self.minimax_cache = MinimaxCacheManager(dtype=self._dtype, - cache_shape=self.cache_shape) + self.minimax_cache: Optional[MinimaxCacheManager] = None rope_theta = getattr(config, "rope_theta", 10000) head_dim = getattr(config, "head_dim", @@ -910,6 +907,8 @@ def get_input_embeddings( def forward(self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, + minimax_cache: MinimaxCacheManager, + minimax_cache_params: MinimaxCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, IntermediateTensors]: @@ -917,21 +916,14 @@ def forward(self, attn_metadata = forward_context.attn_metadata if attn_metadata is None: return None - if "request_ids_to_seq_ids" not in kwargs: - kwargs["request_ids_to_seq_ids"] = {} - if "finished_requests_ids" not in kwargs: - kwargs["finished_requests_ids"] = [] - - ( - minimax_cache_tensors, - state_indices_tensor, - ) = self.minimax_cache.current_run_tensors(**kwargs) + + self.minimax_cache = minimax_cache + minimax_cache_tensors = minimax_cache_params.minimax_cache_tensors + if getattr(attn_metadata, "num_prefills", 0) > 0: self._clear_prefill_cache(attn_metadata, minimax_cache_tensors, **kwargs) - minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors, - state_indices_tensor) if get_pp_group().is_first_rank: if inputs_embeds is None: hidden_states = self.embed_scale * 
self.embed_tokens(input_ids) @@ -989,6 +981,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config.sliding_window = None self.CONCAT_FFN = True + self.minimax_cache: Optional[MinimaxCacheManager] = None self.unpadded_vocab_size = self.config.vocab_size if hasattr(vllm_config.model_config, "max_model_len"): @@ -1038,7 +1031,11 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, intermediate_tensors, + if self.minimax_cache is None: + self.minimax_cache = MinimaxCacheManager(dtype=self._dtype, + cache_shape=self.cache_shape) + minimax_cache_params = self.minimax_cache.current_run_tensors(**kwargs) + hidden_states = self.model(input_ids, positions, self.minimax_cache, minimax_cache_params, intermediate_tensors, inputs_embeds, **kwargs) return hidden_states From 4aafe76fb3eb7b1c051369dc56046e2e5ea7e209 Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 14:56:12 +0800 Subject: [PATCH 07/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_vl_01.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index f4a2bdb0a19a..d3adad026ce3 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -8,6 +8,7 @@ from transformers.models.llava_next.modeling_llava_next import ( get_anyres_image_grid_shape, unpad_image) +from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams from vllm.config import VllmConfig from vllm.jsontree import json_map_leaves from vllm.model_executor.layers.activation import get_act_fn @@ -188,6 +189,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), ) + self.minimax_cache: Optional[MinimaxCacheManager] = None self.vision_feature_layer = config.vision_feature_layer self.vocab_size = config.text_config.vocab_size self.pad_token_id = -1 @@ -444,9 +446,15 @@ def forward( inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) input_ids = None + if self.minimax_cache is None: + self.minimax_cache = MinimaxCacheManager(dtype=self._dtype, + cache_shape=self.cache_shape) + minimax_cache_params = self.minimax_cache.current_run_tensors(**kwargs) hidden_states = self.language_model.model(input_ids, positions, + self.minimax_cache, + minimax_cache_params, intermediate_tensors, inputs_embeds=inputs_embeds) From 5fa7d0a44c0e3dc79663d42ef34e2e190841bf95 Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 16:50:52 +0800 Subject: [PATCH 08/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_vl_01.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index d3adad026ce3..cdd83cc89d1a 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -189,6 +189,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), ) + _dummy = torch.zeros(1) + self._dtype = _dummy.dtype self.minimax_cache: Optional[MinimaxCacheManager] = None self.vision_feature_layer = config.vision_feature_layer self.vocab_size = config.text_config.vocab_size From 
77cf22ad12081fa6609ecb9021247008f297c85c Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 16:54:03 +0800 Subject: [PATCH 09/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_vl_01.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index cdd83cc89d1a..06ea7b732b38 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -8,6 +8,9 @@ from transformers.models.llava_next.modeling_llava_next import ( get_anyres_image_grid_shape, unpad_image) +from vllm.distributed.parallel_state import ( + get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams from vllm.config import VllmConfig from vllm.jsontree import json_map_leaves @@ -192,6 +195,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: _dummy = torch.zeros(1) self._dtype = _dummy.dtype self.minimax_cache: Optional[MinimaxCacheManager] = None + linear_layer_nums = sum(1 for i in range(config.num_hidden_layers) + if self.decoder_attention_types[i] == 0) + max_slots_number = vllm_config.scheduler_config.max_num_seqs + self.cache_shape = (linear_layer_nums, max_slots_number, + config.num_attention_heads // + get_tensor_model_parallel_world_size(), + config.head_dim, config.head_dim) self.vision_feature_layer = config.vision_feature_layer self.vocab_size = config.text_config.vocab_size self.pad_token_id = -1 From 7264bec7ff60b3cfbc78047e0a08a59b52e7f3c2 Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 16:57:16 +0800 Subject: [PATCH 10/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_vl_01.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 06ea7b732b38..b04043735ecd 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -195,13 +195,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: _dummy = torch.zeros(1) self._dtype = _dummy.dtype self.minimax_cache: Optional[MinimaxCacheManager] = None - linear_layer_nums = sum(1 for i in range(config.num_hidden_layers) + linear_layer_nums = sum(1 for i in range(config.text_config.num_hidden_layers) if self.decoder_attention_types[i] == 0) max_slots_number = vllm_config.scheduler_config.max_num_seqs self.cache_shape = (linear_layer_nums, max_slots_number, - config.num_attention_heads // + config.text_config.num_attention_heads // get_tensor_model_parallel_world_size(), - config.head_dim, config.head_dim) + config.text_config.head_dim, config.text_config.head_dim) self.vision_feature_layer = config.vision_feature_layer self.vocab_size = config.text_config.vocab_size self.pad_token_id = -1 From 4d74259c26316d47a2a0ca806046bc1d12344983 Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 16:59:17 +0800 Subject: [PATCH 11/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_vl_01.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index b04043735ecd..dd73b4e2452e 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -194,6 +194,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: 
str = "") -> None: ) _dummy = torch.zeros(1) self._dtype = _dummy.dtype + self.decoder_attention_types = getattr( + config.text_config, "attn_type_list", False) or getattr( + config.text_config, "decoder_attention_types", False) + if not self.decoder_attention_types: + self.decoder_attention_types = [1] * config.text_config.num_hidden_layers self.minimax_cache: Optional[MinimaxCacheManager] = None linear_layer_nums = sum(1 for i in range(config.text_config.num_hidden_layers) if self.decoder_attention_types[i] == 0) From 8590946f045334117b5c46c74d559665b7f0c2b3 Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 17:06:05 +0800 Subject: [PATCH 12/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_vl_01.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index dd73b4e2452e..ce58b4281a3c 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -463,6 +463,12 @@ def forward( inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) input_ids = None + + if "request_ids_to_seq_ids" not in kwargs: + kwargs["request_ids_to_seq_ids"] = {} + if "finished_requests_ids" not in kwargs: + kwargs["finished_requests_ids"] = [] + if self.minimax_cache is None: self.minimax_cache = MinimaxCacheManager(dtype=self._dtype, cache_shape=self.cache_shape) From 714c7201a1ea4c5e7f6ba5b5e36d9b9e54cb5990 Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 17:13:02 +0800 Subject: [PATCH 13/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_text_01.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 3d468c345663..46ce3b52221c 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -918,7 +918,7 @@ def forward(self, return None self.minimax_cache = minimax_cache - minimax_cache_tensors = minimax_cache_params.minimax_cache_tensors + minimax_cache_tensors = minimax_cache_params.state_indices_tensor if getattr(attn_metadata, "num_prefills", 0) > 0: self._clear_prefill_cache(attn_metadata, minimax_cache_tensors, From 078ab53bcace1eb514d6feb4bb600e47dada8da2 Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 17:21:46 +0800 Subject: [PATCH 14/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_cache.py | 7 +++++++ vllm/model_executor/models/minimax_text_01.py | 19 ++++++++----------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/minimax_cache.py b/vllm/model_executor/models/minimax_cache.py index c95cbb419eb9..65d137a46ed3 100644 --- a/vllm/model_executor/models/minimax_cache.py +++ b/vllm/model_executor/models/minimax_cache.py @@ -33,3 +33,10 @@ def _copy_cache(self, from_index: int, to_index: int): for cache_t in self.cache: cache_t[:, to_index].copy_(cache_t[:, from_index], non_blocking=True) + + def current_run_tensors(self, **kwargs) -> MinimaxCacheParams: + """ + Return the tensors for the current run as MinimaxCacheParams. 
+ """ + cache_tensors, state_indices_tensor = super().current_run_tensors(**kwargs) + return MinimaxCacheParams(cache_tensors, state_indices_tensor) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 46ce3b52221c..9450c68d6fed 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -835,17 +835,6 @@ def layer_fn(prefix): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers") - linear_layer_nums = sum(1 for i in range(config.num_hidden_layers) - if self.decoder_attention_types[i] == 0) - max_slots_number = scheduler_config.max_num_seqs - self.cache_shape = (linear_layer_nums, max_slots_number, - config.num_attention_heads // - get_tensor_model_parallel_world_size(), - config.head_dim, config.head_dim) - _dummy = torch.zeros(1) - self._dtype = _dummy.dtype - del _dummy - self.minimax_cache: Optional[MinimaxCacheManager] = None rope_theta = getattr(config, "rope_theta", 10000) @@ -1009,6 +998,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: flash_layer_count = sum(1 for attn_type in self.config.attn_type_list if attn_type == 1) self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)] + + linear_layer_nums = sum(1 for i in range(config.num_hidden_layers) + if config.attn_type_list[i] == 0) + max_slots_number = vllm_config.scheduler_config.max_num_seqs + self.cache_shape = (linear_layer_nums, max_slots_number, + config.num_attention_heads // + get_tensor_model_parallel_world_size(), + config.head_dim, config.head_dim) return def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): From dfc7b06b5a1bd91b043583553b499c7421efc895 Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 17:28:26 +0800 Subject: [PATCH 15/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_vl_01.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index ce58b4281a3c..ff2263780dc1 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -479,7 +479,8 @@ def forward( self.minimax_cache, minimax_cache_params, intermediate_tensors, - inputs_embeds=inputs_embeds) + inputs_embeds=inputs_embeds, + **kwargs) return hidden_states From a73d9ccaa0d0ccf9741bcfedfd8ea8cdac762c7d Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 17:35:25 +0800 Subject: [PATCH 16/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_text_01.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 9450c68d6fed..093fbf14ea59 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -436,7 +436,10 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor): qs = q[_start:_end].transpose(0, 1).contiguous() ks = k[_start:_end].transpose(0, 1).contiguous() vs = v[_start:_end].transpose(0, 1).contiguous() - slot_id = state_indices_tensor[_prefill_idx] + if _prefill_idx >= len(state_indices_tensor): + slot_id = 0 + else: + slot_id = state_indices_tensor[_prefill_idx] slice_layer_cache = kv_cache[slot_id, ...] 
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix( From 6aabdb69af37b73b139a899b5a85c86c889f09b2 Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 17:41:13 +0800 Subject: [PATCH 17/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_text_01.py | 5 +---- vllm/model_executor/models/minimax_vl_01.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 093fbf14ea59..9450c68d6fed 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -436,10 +436,7 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor): qs = q[_start:_end].transpose(0, 1).contiguous() ks = k[_start:_end].transpose(0, 1).contiguous() vs = v[_start:_end].transpose(0, 1).contiguous() - if _prefill_idx >= len(state_indices_tensor): - slot_id = 0 - else: - slot_id = state_indices_tensor[_prefill_idx] + slot_id = state_indices_tensor[_prefill_idx] slice_layer_cache = kv_cache[slot_id, ...] out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix( diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index ff2263780dc1..31f76393e98b 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -455,7 +455,7 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: - + print("minimax_vl_01 forward", **kwargs) if intermediate_tensors is not None: inputs_embeds = None elif inputs_embeds is None: From 334394ab70d908eac9321e0959cfe907cc299006 Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 17:49:56 +0800 Subject: [PATCH 18/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_vl_01.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 31f76393e98b..e1afd6d4f2e7 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -455,7 +455,6 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: - print("minimax_vl_01 forward", **kwargs) if intermediate_tensors is not None: inputs_embeds = None elif inputs_embeds is None: From 29a7d2742a7341475ad352e05e05eafba3ea22db Mon Sep 17 00:00:00 2001 From: qingjun Date: Wed, 7 May 2025 17:56:54 +0800 Subject: [PATCH 19/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_text_01.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 9450c68d6fed..f5caa5a1f74b 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -431,6 +431,10 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor): attn_metadata = forward_context.attn_metadata hidden = [] for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): + if _prefill_idx+1 >= len(attn_metadata.query_start_loc): + break + if _prefill_idx >= len(state_indices_tensor): + break _start = attn_metadata.query_start_loc[_prefill_idx] _end = attn_metadata.query_start_loc[_prefill_idx + 1] qs = q[_start:_end].transpose(0, 1).contiguous() From 1e30f412219dd8dce025bb3002ad00181a3bdd6a Mon Sep 17 00:00:00 
2001 From: qingjun Date: Wed, 7 May 2025 19:48:28 +0800 Subject: [PATCH 20/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_text_01.py | 4 --- vllm/model_executor/models/minimax_vl_01.py | 32 +++++++++++++++---- vllm/v1/worker/gpu_model_runner.py | 1 + 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index f5caa5a1f74b..9450c68d6fed 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -431,10 +431,6 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor): attn_metadata = forward_context.attn_metadata hidden = [] for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): - if _prefill_idx+1 >= len(attn_metadata.query_start_loc): - break - if _prefill_idx >= len(state_indices_tensor): - break _start = attn_metadata.query_start_loc[_prefill_idx] _end = attn_metadata.query_start_loc[_prefill_idx + 1] qs = q[_start:_end].transpose(0, 1).contiguous() diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index e1afd6d4f2e7..8aa3c27f5e09 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -7,7 +7,7 @@ from transformers import BatchFeature from transformers.models.llava_next.modeling_llava_next import ( get_anyres_image_grid_shape, unpad_image) - +from vllm.v1.core.sched.output import SchedulerOutput from vllm.distributed.parallel_state import ( get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -447,6 +447,26 @@ def get_multimodal_embeddings( return self._process_image_input(image_input) + def calculate_request_ids_to_seq_ids(self, scheduler_output: SchedulerOutput): + request_ids_to_seq_ids = {} + + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + request_ids_to_seq_ids[req_id] = [0] + + for req_data in scheduler_output.scheduled_cached_reqs: + req_id = req_data.req_id + sampling_params = req_data.sampling_params + if hasattr(sampling_params, 'n') and sampling_params.n > 1: + request_ids_to_seq_ids[req_id] = list(range(sampling_params.n)) + else: + request_ids_to_seq_ids[req_id] = [0] + + return request_ids_to_seq_ids + + def calculate_finished_requests_ids(self,scheduler_output: SchedulerOutput): + return scheduler_output.finished_req_ids + def forward( self, input_ids: torch.Tensor, @@ -462,12 +482,12 @@ def forward( inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) input_ids = None - - if "request_ids_to_seq_ids" not in kwargs: - kwargs["request_ids_to_seq_ids"] = {} - if "finished_requests_ids" not in kwargs: - kwargs["finished_requests_ids"] = [] + if "scheduler_output" in kwargs: + scheduler_output = kwargs["scheduler_output"] + kwargs["request_ids_to_seq_ids"] = self.calculate_request_ids_to_seq_ids(scheduler_output) + kwargs["finished_requests_ids"] = self.calculate_finished_requests_ids(scheduler_output) + if self.minimax_cache is None: self.minimax_cache = MinimaxCacheManager(dtype=self._dtype, cache_shape=self.cache_shape) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 97d8c91b4659..6d30c80d01f2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1107,6 +1107,7 @@ def execute_model( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, + 
scheduler_output=scheduler_output, ) if self.use_aux_hidden_state_outputs: From f5149a9ad9bee67ce75716bf47bfb14a43290b6e Mon Sep 17 00:00:00 2001 From: qingjun Date: Thu, 8 May 2025 10:19:09 +0800 Subject: [PATCH 21/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_vl_01.py | 3 ++- vllm/worker/model_runner.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 8aa3c27f5e09..a89cec8b5887 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -464,7 +464,7 @@ def calculate_request_ids_to_seq_ids(self, scheduler_output: SchedulerOutput): return request_ids_to_seq_ids - def calculate_finished_requests_ids(self,scheduler_output: SchedulerOutput): + def calculate_finished_requests_ids(self, scheduler_output: SchedulerOutput): return scheduler_output.finished_req_ids def forward( @@ -487,6 +487,7 @@ def forward( scheduler_output = kwargs["scheduler_output"] kwargs["request_ids_to_seq_ids"] = self.calculate_request_ids_to_seq_ids(scheduler_output) kwargs["finished_requests_ids"] = self.calculate_finished_requests_ids(scheduler_output) + print("minimax_vl_01 add request_ids_to_seq_ids and finished_requests_ids") if self.minimax_cache is None: self.minimax_cache = MinimaxCacheManager(dtype=self._dtype, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e22bbcc656ff..0a4d3c3ba64c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1412,8 +1412,10 @@ def _dummy_run(self, # Disable KV Scale Calculation for dummy data during profile run if model_input.attn_metadata is not None: model_input.attn_metadata.enable_kv_scales_calculation = False - - self.execute_model(model_input, kv_caches, intermediate_tensors) + kwargs = {} + kwargs["request_ids_to_seq_ids"] = model_input.request_ids_to_seq_ids + kwargs["finished_requests_ids"] = model_input.finished_requests_ids + self.execute_model(model_input, kv_caches, intermediate_tensors, **kwargs) torch.cuda.synchronize() if self.lora_config: self._remove_dummy_loras() From c40564a982e3eb7f5899efda93739bf079582ff9 Mon Sep 17 00:00:00 2001 From: qingjun Date: Thu, 8 May 2025 10:19:43 +0800 Subject: [PATCH 22/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_vl_01.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index a89cec8b5887..af46b6b1e9d7 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -489,6 +489,11 @@ def forward( kwargs["finished_requests_ids"] = self.calculate_finished_requests_ids(scheduler_output) print("minimax_vl_01 add request_ids_to_seq_ids and finished_requests_ids") + if "request_ids_to_seq_ids" not in kwargs: + kwargs["request_ids_to_seq_ids"] = {} + if "finished_requests_ids" not in kwargs: + kwargs["finished_requests_ids"] = [] + if self.minimax_cache is None: self.minimax_cache = MinimaxCacheManager(dtype=self._dtype, cache_shape=self.cache_shape) From 9303eecc8cf436261d487c12f399eb77ed343a13 Mon Sep 17 00:00:00 2001 From: qingjun Date: Thu, 8 May 2025 10:29:30 +0800 Subject: [PATCH 23/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_text_01.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py 
b/vllm/model_executor/models/minimax_text_01.py index 9450c68d6fed..47ef1a66dec4 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -426,13 +426,17 @@ def get_slopes_power_of_2(n): n_attention_heads, 1, 1) return slopes - def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor): - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata + def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, + attn_metadata): hidden = [] for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): + if _prefill_idx >= len(attn_metadata.query_start_loc): + break + if _prefill_idx >= len(state_indices_tensor): + break _start = attn_metadata.query_start_loc[_prefill_idx] _end = attn_metadata.query_start_loc[_prefill_idx + 1] + slot_id = state_indices_tensor[_prefill_idx] qs = q[_start:_end].transpose(0, 1).contiguous() ks = k[_start:_end].transpose(0, 1).contiguous() vs = v[_start:_end].transpose(0, 1).contiguous() @@ -452,6 +456,10 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor): hidden.append( self._decode_infer(q, k, v, kv_cache, state_indices_tensor, attn_metadata)) + + if not hidden: + return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) + hidden = torch.concat(hidden, dim=0).contiguous() return hidden From 0c8815ec7fa99f3b8bd08de3a3b8bf581685f58d Mon Sep 17 00:00:00 2001 From: qingjun Date: Thu, 8 May 2025 10:35:42 +0800 Subject: [PATCH 24/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_text_01.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 47ef1a66dec4..596f39fd4e45 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -426,9 +426,10 @@ def get_slopes_power_of_2(n): n_attention_heads, 1, 1) return slopes - def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, - attn_metadata): + def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor): hidden = [] + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): if _prefill_idx >= len(attn_metadata.query_start_loc): break From 0ba44cb62ac92160c396196a2beafbfc4a6b25f1 Mon Sep 17 00:00:00 2001 From: qingjun Date: Thu, 8 May 2025 10:43:20 +0800 Subject: [PATCH 25/30] fix code Signed-off-by: qingjun --- vllm/model_executor/models/minimax_text_01.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 596f39fd4e45..ff3a093da3d0 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -495,11 +495,11 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, hidden = self._decode_infer(q, k, v, kv_cache, state_indices_tensor, attn_metadata) - hidden = self.norm._forward(hidden) - gate, _ = self.output_gate(hidden_states) - hidden = F.sigmoid(gate) * hidden - hidden = hidden.to(hidden_states.dtype) - hidden, _ = self.out_proj(hidden) + # hidden = self.norm._forward(hidden) + # gate, _ = self.output_gate(hidden_states) + # hidden = F.sigmoid(gate) * hidden + # hidden = hidden.to(hidden_states.dtype) + # hidden, _ = self.out_proj(hidden) return 
From 0ba44cb62ac92160c396196a2beafbfc4a6b25f1 Mon Sep 17 00:00:00 2001
From: qingjun
Date: Thu, 8 May 2025 10:43:20 +0800
Subject: [PATCH 25/30] fix code

Signed-off-by: qingjun
---
 vllm/model_executor/models/minimax_text_01.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py
index 596f39fd4e45..ff3a093da3d0 100644
--- a/vllm/model_executor/models/minimax_text_01.py
+++ b/vllm/model_executor/models/minimax_text_01.py
@@ -495,11 +495,11 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
             hidden = self._decode_infer(q, k, v, kv_cache,
                                         state_indices_tensor, attn_metadata)
 
-        hidden = self.norm._forward(hidden)
-        gate, _ = self.output_gate(hidden_states)
-        hidden = F.sigmoid(gate) * hidden
-        hidden = hidden.to(hidden_states.dtype)
-        hidden, _ = self.out_proj(hidden)
+        # hidden = self.norm._forward(hidden)
+        # gate, _ = self.output_gate(hidden_states)
+        # hidden = F.sigmoid(gate) * hidden
+        # hidden = hidden.to(hidden_states.dtype)
+        # hidden, _ = self.out_proj(hidden)
         return hidden
 
 

From fad38fac80a34d2738dca3463da262cbd26c6b7b Mon Sep 17 00:00:00 2001
From: qingjun
Date: Thu, 8 May 2025 10:49:46 +0800
Subject: [PATCH 26/30] fix code

Signed-off-by: qingjun
---
 vllm/model_executor/models/minimax_text_01.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py
index ff3a093da3d0..6b193b4bf842 100644
--- a/vllm/model_executor/models/minimax_text_01.py
+++ b/vllm/model_executor/models/minimax_text_01.py
@@ -608,6 +608,7 @@ def __init__(
                 config.max_model_len, int):
             max_position_embeddings = min(config.max_position_embeddings,
                                           config.max_model_len)
+        config.attention_type = 1
         if config.attention_type == 0:
             use_headxdim = True
             hidden_inner = (head_dim * config.num_attention_heads
@@ -642,7 +643,7 @@ def __init__(
         else:
             raise ValueError(
                 f"Unsupported attention type: {self.config.attention_type}")
-
+        config.attention_type = 0
         if expert_num == 1:
             self.mlp = MiniMaxText01MLP(
                 hidden_size=self.hidden_size,
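Patch 26 is a temporary debugging toggle around config.attention_type that patch 27 below reverts. Patch 27 also reflows the shared cache bookkeeping: the number of linear-attention layers is counted from the per-layer attention-type list and fixes the shape of the per-slot state cache. A small sketch of that computation, assuming a tensor-parallel world size of 1 (the function name and argument names are illustrative):

    def linear_cache_shape(attn_type_list, max_num_seqs, num_heads, head_dim,
                           tp_world_size=1):
        """Shape of the linear-attention state cache, as in the diffed code."""
        # Type-0 layers use linear attention and need a slot in this cache;
        # type-1 layers use flash attention and go through the normal KV cache.
        linear_layer_nums = sum(1 for t in attn_type_list if t == 0)
        return (linear_layer_nums, max_num_seqs,
                num_heads // tp_world_size, head_dim, head_dim)

    # Example: alternating linear/flash layers, 8 sequences, 8 heads of dim 64.
    assert linear_cache_shape([0, 1] * 4, 8, 8, 64) == (4, 8, 8, 64, 64)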
From 86f856fb8d67ca826526cbae2954cdcc6cd4a94c Mon Sep 17 00:00:00 2001
From: qingjun
Date: Thu, 8 May 2025 11:03:15 +0800
Subject: [PATCH 27/30] fix format

Signed-off-by: qingjun
---
 vllm/model_executor/models/minimax_text_01.py | 26 +++++-----
 vllm/model_executor/models/minimax_vl_01.py   | 51 +++++--------
 vllm/v1/worker/gpu_model_runner.py            |  1 -
 vllm/worker/model_runner.py                   |  6 +--
 4 files changed, 29 insertions(+), 55 deletions(-)

diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py
index 6b193b4bf842..c640949b71a3 100644
--- a/vllm/model_executor/models/minimax_text_01.py
+++ b/vllm/model_executor/models/minimax_text_01.py
@@ -495,11 +495,11 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
             hidden = self._decode_infer(q, k, v, kv_cache,
                                         state_indices_tensor, attn_metadata)
 
-        # hidden = self.norm._forward(hidden)
-        # gate, _ = self.output_gate(hidden_states)
-        # hidden = F.sigmoid(gate) * hidden
-        # hidden = hidden.to(hidden_states.dtype)
-        # hidden, _ = self.out_proj(hidden)
+        hidden = self.norm._forward(hidden)
+        gate, _ = self.output_gate(hidden_states)
+        hidden = F.sigmoid(gate) * hidden
+        hidden = hidden.to(hidden_states.dtype)
+        hidden, _ = self.out_proj(hidden)
         return hidden
 
 
@@ -608,7 +608,6 @@ def __init__(
                 config.max_model_len, int):
             max_position_embeddings = min(config.max_position_embeddings,
                                           config.max_model_len)
-        config.attention_type = 1
         if config.attention_type == 0:
             use_headxdim = True
             hidden_inner = (head_dim * config.num_attention_heads
@@ -642,7 +642,7 @@ def __init__(
         else:
             raise ValueError(
                 f"Unsupported attention type: {self.config.attention_type}")
-        config.attention_type = 0
+
         if expert_num == 1:
             self.mlp = MiniMaxText01MLP(
                 hidden_size=self.hidden_size,
@@ -844,7 +843,7 @@ def layer_fn(prefix):
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers")
-        
+
         self.minimax_cache: Optional[MinimaxCacheManager] = None
 
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -1008,7 +1007,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         flash_layer_count = sum(1 for attn_type in self.config.attn_type_list
                                 if attn_type == 1)
         self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
-        
+
         linear_layer_nums = sum(1 for i in range(config.num_hidden_layers)
                                 if config.attn_type_list[i] == 0)
         max_slots_number = vllm_config.scheduler_config.max_num_seqs
@@ -1038,11 +1037,12 @@ def forward(self,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs) -> torch.Tensor:
-        if self.minimax_cache is None:
-            self.minimax_cache = MinimaxCacheManager(dtype=self._dtype,
-                                                     cache_shape=self.cache_shape)
+        if self.minimax_cache is None:
+            self.minimax_cache = MinimaxCacheManager(
+                dtype=self._dtype, cache_shape=self.cache_shape)
         minimax_cache_params = self.minimax_cache.current_run_tensors(**kwargs)
-        hidden_states = self.model(input_ids, positions, self.minimax_cache, minimax_cache_params, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, self.minimax_cache,
+                                   minimax_cache_params, intermediate_tensors,
                                    inputs_embeds, **kwargs)
         return hidden_states
 
diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py
index af46b6b1e9d7..d34020255ded 100644
--- a/vllm/model_executor/models/minimax_vl_01.py
+++ b/vllm/model_executor/models/minimax_vl_01.py
@@ -7,12 +7,10 @@
 from transformers import BatchFeature
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
-from vllm.v1.core.sched.output import SchedulerOutput
+
+from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import (
-    get_pp_group, get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size)
-from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
-from vllm.config import VllmConfig
 from vllm.jsontree import json_map_leaves
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -29,6 +27,7 @@
 from .llava import (BaseLlavaMultiModalProcessor, LlavaDummyInputsBuilder,
                     init_vision_tower_for_llava)
 from .llava_next import LlavaNextProcessingInfo
+from .minimax_cache import MinimaxCacheManager
 from .pixtral import PixtralHFVisionModel
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
@@ -198,15 +197,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
             config.text_config, "attn_type_list", False) or getattr(
                 config.text_config, "decoder_attention_types", False)
         if not self.decoder_attention_types:
-            self.decoder_attention_types = [1] * config.text_config.num_hidden_layers
+            self.decoder_attention_types = [
+                1
+            ] * config.text_config.num_hidden_layers
         self.minimax_cache: Optional[MinimaxCacheManager] = None
-        linear_layer_nums = sum(1 for i in range(config.text_config.num_hidden_layers)
-                                if self.decoder_attention_types[i] == 0)
+        linear_layer_nums = sum(
+            1 for i in range(config.text_config.num_hidden_layers)
+            if self.decoder_attention_types[i] == 0)
         max_slots_number = vllm_config.scheduler_config.max_num_seqs
         self.cache_shape = (linear_layer_nums, max_slots_number,
                             config.text_config.num_attention_heads //
                             get_tensor_model_parallel_world_size(),
-                            config.text_config.head_dim, config.text_config.head_dim)
+                            config.text_config.head_dim,
+                            config.text_config.head_dim)
         self.vision_feature_layer = config.vision_feature_layer
         self.vocab_size = config.text_config.vocab_size
         self.pad_token_id = -1
@@ -447,26 +450,6 @@ def get_multimodal_embeddings(
 
         return self._process_image_input(image_input)
 
-    def calculate_request_ids_to_seq_ids(self, scheduler_output: SchedulerOutput):
-        request_ids_to_seq_ids = {}
-
-        for new_req_data in scheduler_output.scheduled_new_reqs:
-            req_id = new_req_data.req_id
-            request_ids_to_seq_ids[req_id] = [0]
-
-        for req_data in scheduler_output.scheduled_cached_reqs:
-            req_id = req_data.req_id
-            sampling_params = req_data.sampling_params
-            if hasattr(sampling_params, 'n') and sampling_params.n > 1:
-                request_ids_to_seq_ids[req_id] = list(range(sampling_params.n))
-            else:
-                request_ids_to_seq_ids[req_id] = [0]
-
-        return request_ids_to_seq_ids
-
-    def calculate_finished_requests_ids(self, scheduler_output: SchedulerOutput):
-        return scheduler_output.finished_req_ids
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -482,21 +465,15 @@ def forward(
             inputs_embeds = self.get_input_embeddings(input_ids,
                                                       vision_embeddings)
             input_ids = None
-
-        if "scheduler_output" in kwargs:
-            scheduler_output = kwargs["scheduler_output"]
-            kwargs["request_ids_to_seq_ids"] = self.calculate_request_ids_to_seq_ids(scheduler_output)
-            kwargs["finished_requests_ids"] = self.calculate_finished_requests_ids(scheduler_output)
-            print("minimax_vl_01 add request_ids_to_seq_ids and finished_requests_ids")
 
         if "request_ids_to_seq_ids" not in kwargs:
             kwargs["request_ids_to_seq_ids"] = {}
         if "finished_requests_ids" not in kwargs:
             kwargs["finished_requests_ids"] = []
 
-        if self.minimax_cache is None:
-            self.minimax_cache = MinimaxCacheManager(dtype=self._dtype,
-                                                     cache_shape=self.cache_shape)
+        if self.minimax_cache is None:
+            self.minimax_cache = MinimaxCacheManager(
+                dtype=self._dtype, cache_shape=self.cache_shape)
 
         minimax_cache_params = self.minimax_cache.current_run_tensors(**kwargs)
         hidden_states = self.language_model.model(input_ids,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 6d30c80d01f2..97d8c91b4659 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1107,7 +1107,6 @@ def execute_model(
                 positions=positions,
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
-                scheduler_output=scheduler_output,
             )
 
         if self.use_aux_hidden_state_outputs:
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 0a4d3c3ba64c..e22bbcc656ff 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1412,10 +1412,8 @@ def _dummy_run(self,
         # Disable KV Scale Calculation for dummy data during profile run
         if model_input.attn_metadata is not None:
             model_input.attn_metadata.enable_kv_scales_calculation = False
-        kwargs = {}
-        kwargs["request_ids_to_seq_ids"] = model_input.request_ids_to_seq_ids
-        kwargs["finished_requests_ids"] = model_input.finished_requests_ids
-        self.execute_model(model_input, kv_caches, intermediate_tensors, **kwargs)
+
+        self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
         if self.lora_config:
             self._remove_dummy_loras()
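After patch 27, both forward() implementations follow the same shape: default the bookkeeping kwargs, build the cache manager lazily on first use, and hand its per-run tensors to the underlying model. A compressed sketch of that flow with a stand-in cache class (DummyCacheManager and Runner are hypothetical placeholders for illustration, not vLLM APIs):

    class DummyCacheManager:
        """Stand-in for MinimaxCacheManager, only to show the lazy-init flow."""

        def __init__(self, dtype, cache_shape):
            self.dtype = dtype
            self.cache_shape = cache_shape

        def current_run_tensors(self, **kwargs):
            # The real manager returns cache tensors selected by the request
            # bookkeeping; here we simply echo what was asked for.
            return kwargs

    class Runner:
        def __init__(self, dtype="float16", cache_shape=(4, 8, 8, 64, 64)):
            self._dtype = dtype
            self.cache_shape = cache_shape
            self.minimax_cache = None  # built lazily, as in the diff

        def forward(self, **kwargs):
            kwargs.setdefault("request_ids_to_seq_ids", {})
            kwargs.setdefault("finished_requests_ids", [])
            if self.minimax_cache is None:
                self.minimax_cache = DummyCacheManager(
                    dtype=self._dtype, cache_shape=self.cache_shape)
            return self.minimax_cache.current_run_tensors(**kwargs)

    assert Runner().forward()["finished_requests_ids"] == []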
From b9d1c66321859c0e7b9c61fd5c652c1af7893d74 Mon Sep 17 00:00:00 2001
From: qingjun
Date: Thu, 8 May 2025 11:10:21 +0800
Subject: [PATCH 28/30] fix code

Signed-off-by: qingjun
---
 vllm/model_executor/models/minimax_text_01.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py
index c640949b71a3..50e0f0337448 100644
--- a/vllm/model_executor/models/minimax_text_01.py
+++ b/vllm/model_executor/models/minimax_text_01.py
@@ -488,12 +488,12 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
         state_indices_tensor = kv_caches.state_indices_tensor
 
         decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
-        if not decode_only:
-            hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
-                                                 state_indices_tensor)
-        else:
-            hidden = self._decode_infer(q, k, v, kv_cache,
-                                        state_indices_tensor, attn_metadata)
+        # if not decode_only:
+        #     hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
+        #                                          state_indices_tensor)
+        # else:
+        hidden = self._decode_infer(q, k, v, kv_cache,
+                                    state_indices_tensor, attn_metadata)
 
         hidden = self.norm._forward(hidden)
         gate, _ = self.output_gate(hidden_states)

From 34f7ae8253dcffc2f279f0874af1daeff9adcf27 Mon Sep 17 00:00:00 2001
From: qingjun
Date: Thu, 8 May 2025 11:22:51 +0800
Subject: [PATCH 29/30] fix code

Signed-off-by: qingjun
---
 vllm/model_executor/models/minimax_text_01.py | 12 ++++++------
 vllm/model_executor/models/minimax_vl_01.py   |  2 +-
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py
index 50e0f0337448..c640949b71a3 100644
--- a/vllm/model_executor/models/minimax_text_01.py
+++ b/vllm/model_executor/models/minimax_text_01.py
@@ -488,12 +488,12 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
         state_indices_tensor = kv_caches.state_indices_tensor
 
         decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
-        # if not decode_only:
-        #     hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
-        #                                          state_indices_tensor)
-        # else:
-        hidden = self._decode_infer(q, k, v, kv_cache,
-                                    state_indices_tensor, attn_metadata)
+        if not decode_only:
+            hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
+                                                 state_indices_tensor)
+        else:
+            hidden = self._decode_infer(q, k, v, kv_cache,
+                                        state_indices_tensor, attn_metadata)
 
         hidden = self.norm._forward(hidden)
         gate, _ = self.output_gate(hidden_states)
diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py
index d34020255ded..4bc4813a07d7 100644
--- a/vllm/model_executor/models/minimax_vl_01.py
+++ b/vllm/model_executor/models/minimax_vl_01.py
@@ -200,7 +200,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.decoder_attention_types = [
             1
         ] * config.text_config.num_hidden_layers
-        self.minimax_cache: Optional[MinimaxCacheManager] = None
+        self.minimax_cache = None
         linear_layer_nums = sum(
             1 for i in range(config.text_config.num_hidden_layers)
             if self.decoder_attention_types[i] == 0)
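Patch 29 restores the prefill/decode dispatch that patch 28 had short-circuited while debugging: the mixed prefill path runs whenever the batch still contains prefill tokens, otherwise only the decode path. A sketch of that selection, assuming the two inference methods exist on the layer object as in the diff:

    def run_linear_attention(layer, q, k, v, kv_cache, state_indices_tensor,
                             attn_metadata):
        """Dispatch between the mixed prefill path and the decode-only path."""
        decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
        if not decode_only:
            # At least one request is still prefilling.
            return layer._prefill_and_mix_infer(q, k, v, kv_cache,
                                                state_indices_tensor)
        return layer._decode_infer(q, k, v, kv_cache, state_indices_tensor,
                                   attn_metadata)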
From f2812847864dd71386dcb8416c6b4e74aa5dc359 Mon Sep 17 00:00:00 2001
From: qingjun
Date: Thu, 8 May 2025 14:23:25 +0800
Subject: [PATCH 30/30] fix code

Signed-off-by: qingjun
---
 vllm/model_executor/models/minimax_text_01.py | 1 -
 vllm/worker/model_runner.py                   | 4 +++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py
index c640949b71a3..2640a92aa498 100644
--- a/vllm/model_executor/models/minimax_text_01.py
+++ b/vllm/model_executor/models/minimax_text_01.py
@@ -437,7 +437,6 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor):
                 break
             _start = attn_metadata.query_start_loc[_prefill_idx]
             _end = attn_metadata.query_start_loc[_prefill_idx + 1]
-            slot_id = state_indices_tensor[_prefill_idx]
             qs = q[_start:_end].transpose(0, 1).contiguous()
             ks = k[_start:_end].transpose(0, 1).contiguous()
             vs = v[_start:_end].transpose(0, 1).contiguous()
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index e22bbcc656ff..066966a620e4 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1413,7 +1413,9 @@ def _dummy_run(self,
         if model_input.attn_metadata is not None:
             model_input.attn_metadata.enable_kv_scales_calculation = False
 
-        self.execute_model(model_input, kv_caches, intermediate_tensors)
+        self.execute_model(model_input, kv_caches, intermediate_tensors,
+                           request_ids_to_seq_ids=model_input.request_ids_to_seq_ids,
+                           finished_requests_ids = model_input.finished_requests_ids)
         torch.cuda.synchronize()
         if self.lora_config:
             self._remove_dummy_loras()
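Patch 30 closes the series by dropping the now-unused slot_id lookup and passing the bookkeeping from _dummy_run as explicit keyword arguments. The effect on the profiling path is the keyword forwarding sketched below (execute_model and ModelInput here are stand-ins that merely record what they receive, not the real runner types):

    def execute_model(model_input, kv_caches, intermediate_tensors, **kwargs):
        # Stand-in for ModelRunner.execute_model: surface the extra kwargs so
        # the forwarding from _dummy_run is visible.
        return kwargs

    class ModelInput:
        request_ids_to_seq_ids = {"req-0": [0]}
        finished_requests_ids = []

    received = execute_model(
        ModelInput(), kv_caches=[], intermediate_tensors=None,
        request_ids_to_seq_ids=ModelInput.request_ids_to_seq_ids,
        finished_requests_ids=ModelInput.finished_requests_ids)
    assert received["request_ids_to_seq_ids"] == {"req-0": [0]}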