diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 9f6146d66d2c..9580ba48ecae 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -358,7 +358,7 @@ Specified using `--task generate`. | `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ | | `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ | -| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | | +| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | | | `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | | ✅︎ | ✅︎ | diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 7d569fd83821..fcedd85d8870 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -174,6 +174,7 @@ def iter_params(self, model_id: str): "internlm/internlm2-chat-7b": PPTestSettings.fast(), "inceptionai/jais-13b-chat": PPTestSettings.fast(), "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(), + "pfnet/plamo-2-1b": PPTestSettings.fast(), "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(), # Tests TransformersForCausalLM "ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(), diff --git a/tests/quantization/test_experts_int8.py b/tests/quantization/test_experts_int8.py index 50179b9a904d..84a656a3b9da 100644 --- a/tests/quantization/test_experts_int8.py +++ b/tests/quantization/test_experts_int8.py @@ -9,7 +9,7 @@ from tests.quantization.utils import is_quant_method_supported -MODELS = ["ai21labs/Jamba-tiny-random"] +MODELS = ["ai21labs/Jamba-tiny-random", "pfnet/plamo-2-1b"] @pytest.mark.skipif(not is_quant_method_supported("experts_int8"), diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 670576c68efd..cb3c86e06a17 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only PLaMo2 model.""" -import math from collections.abc import Iterable from typing import Optional @@ -11,30 +10,40 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from 
vllm.model_executor.layers.mamba.mamba2_metadata import ( + Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_scan_fn, selective_state_update) + selective_state_update) +from vllm.model_executor.layers.mamba.ops.ssd_combined import ( + mamba_chunk_scan_combined) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, - SupportsV0Only) + SupportsPP, SupportsV0Only) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) -from vllm.model_executor.models.utils import maybe_prefix +from vllm.model_executor.models.utils import ( + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, + make_layers, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors @@ -77,17 +86,6 @@ def _init_weights(self, module: torch.nn.Module) -> None: module.weight.data[module.padding_idx].zero_() -def get_initial_dt_bias(num_heads: int) -> torch.Tensor: - dt_min = 0.001 - dt_max = 0.1 - dt = torch.exp( - torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) + - math.log(dt_min)) - dt = torch.clamp(dt, 1e-4) - inv_dt = dt + torch.log(-torch.expm1(-dt)) - return inv_dt - - def is_mamba(config: Plamo2Config, i: int) -> bool: assert config.mamba_step > 1 @@ -97,52 +95,36 @@ def is_mamba(config: Plamo2Config, i: int) -> bool: return (i % config.mamba_step) != (config.mamba_step // 2) -# TODO(Shinichi): Replace this with RMSNorm. -def _rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, - eps: float) -> torch.Tensor: - input_shape = hidden_states.shape - hidden_states = hidden_states.reshape(input_shape[:-1] + weight.shape) - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + eps) - hidden_states = hidden_states.to(input_dtype) - hidden_states = weight * hidden_states - return hidden_states.reshape(input_shape) - - -def _swiglu(h: torch.Tensor) -> torch.Tensor: - h0, h1 = h.chunk(2, dim=-1) - return torch.nn.functional.silu(h0) * h1 - - -# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +# Adapted from: +# vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2 +# transformers.models.mamba.modeling_mamba.MambaMixer class Plamo2MambaMixer(nn.Module): - # TODO(Shinichi): Rebase on Mamba2 implementation. 
def __init__(self, - config: Plamo2Config, - cache_config: CacheConfig, - quant_config: QuantizationConfig, - max_model_len: int, + vllm_config: VllmConfig, + *, prefix: str = "", **kwargs) -> None: super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.ssm_state_size = config.mamba_d_state - self.conv_kernel_size = config.mamba_d_conv - self.intermediate_size = (config.mamba_num_heads * - config.hidden_size_per_head) - self.hidden_size_per_head = config.hidden_size_per_head - self.num_heads = config.mamba_num_heads + self.config = vllm_config.model_config.hf_config + self.quant_config = vllm_config.quant_config + self.hidden_size = self.config.hidden_size + self.ssm_state_size = self.config.mamba_d_state + self.conv_kernel_size = self.config.mamba_d_conv + self.intermediate_size = (self.config.mamba_num_heads * + self.config.hidden_size_per_head) + self.tp_size = get_tensor_model_parallel_world_size() + self.intermediate_size_per_tp_worker = \ + self.intermediate_size // self.tp_size + self.head_dim = self.config.hidden_size_per_head + self.num_heads = self.config.mamba_num_heads self.time_step_rank = max(64, self.hidden_size // 16) - self.use_conv_bias = False - self.use_bias = False self.conv1d = ColumnParallelLinear( input_size=self.conv_kernel_size, output_size=self.intermediate_size, - bias=self.use_conv_bias, + bias=False, + prefix=f"{prefix}.conv1d", + return_bias=False, ) # unsqueeze to fit conv1d weights shape into the linear weights shape. # Can't do this in `weight_loader` since it already exists in @@ -153,15 +135,19 @@ def __init__(self, self.in_proj = MergedColumnParallelLinear( self.hidden_size, [self.intermediate_size] * 2, - bias=self.use_bias, + bias=False, + quant_config=self.quant_config, prefix=f"{prefix}.in_proj", + return_bias=False, ) # selective projection used to make dt, B and C input dependent self.bcdt_proj = RowParallelLinear( self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False, + quant_config=self.quant_config, prefix=f"{prefix}.bcdt_proj", + return_bias=False, ) # time step projection (discretization) - # In the forward we need to apply dt_proj without the bias, @@ -170,154 +156,223 @@ def __init__(self, self.time_step_rank, self.num_heads, bias=False, + quant_config=self.quant_config, prefix=f"{prefix}.dt_proj", + return_bias=False, ) - self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads)) - tp_size = get_tensor_model_parallel_world_size() self.A = nn.Parameter( torch.empty( - self.intermediate_size // tp_size, - self.ssm_state_size, + divide(self.num_heads, self.tp_size), dtype=torch.float32, )) - self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + self.D = nn.Parameter(torch.ones(divide(self.num_heads, self.tp_size))) + self.dt_bias = nn.Parameter( + torch.ones(divide(self.num_heads, self.tp_size))) set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) a_weight_loader = composed_weight_loader( sharded_weight_loader(0), lambda x: -torch.exp(x.float())) set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + set_weight_attrs(self.dt_bias, + {"weight_loader": sharded_weight_loader(0)}) self.out_proj = RowParallelLinear( self.intermediate_size, self.hidden_size, - bias=self.use_bias, + bias=False, input_is_parallel=True, + quant_config=self.quant_config, prefix=f"{prefix}.out_proj", + return_bias=False, ) # The activation function is fixed to SiLU. 
self.activation = "silu" - self.dt_norm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps) - self.B_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) - self.C_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + self.dt_norm = RMSNorm(self.time_step_rank, + eps=self.config.rms_norm_eps) + self.B_norm = RMSNorm(self.ssm_state_size, + eps=self.config.rms_norm_eps) + self.C_norm = RMSNorm(self.ssm_state_size, + eps=self.config.rms_norm_eps) + + def _project_ssm_parameters(self, hidden_states): + ssm_parameters = self.bcdt_proj(hidden_states) + B, C, time_step = torch.split( + ssm_parameters, + [self.ssm_state_size, self.ssm_state_size, self.time_step_rank], + dim=-1, + ) + + # vllm._custom_ops.rms_norm requires contiguous input tensors. + time_step = self.dt_norm(time_step.contiguous()) + B = self.B_norm(B.contiguous()) + C = self.C_norm(C.contiguous()) + dt = self.dt_proj(time_step) + return B, C, dt def forward( self, hidden_states: torch.Tensor, mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, **kwargs, ) -> torch.Tensor: + # mamba2_metadata contains metadata necessary for the mamba2 triton + # kernels to operate in continuous batching and in chunked prefill + # modes; they are computed at top-level model forward since they + # stay the same and reused for all mamba layers in the same iteration attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + has_prefill = num_prefills > 0 + has_decode = num_decodes > 0 + # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states)[0] - # Reshaping the projected states as in modeling_plamo.py. - length = len(hidden_states) - projected_states = projected_states.reshape(length, self.num_heads, -1) - gate, hidden_states = torch.split( - projected_states, - [self.hidden_size_per_head, self.hidden_size_per_head], - dim=-1) - hidden_states = hidden_states.reshape(length, -1).transpose(0, 1) - gate = gate.reshape(length, -1).transpose(0, 1) + projected_states = self.in_proj(hidden_states) + gate, hidden_states = projected_states.chunk(2, dim=-1) # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - hidden_states = causal_conv1d_fn( - hidden_states, + # Separate prefill and decode by splitting varlen input + # Split along token dimension + hidden_states_p, hidden_states_d = torch.split( + hidden_states, + [num_prefill_tokens, num_decodes], + dim=0, + ) + gate_p, gate_d = torch.split(gate, [num_prefill_tokens, num_decodes], + dim=0) + # Split along batch dimension + state_indices_tensor_p, state_indices_tensor_d = torch.split( + mamba_cache_params.state_indices_tensor, + [num_prefills, num_decodes], + dim=0, + ) + query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1] + if has_prefill else None) + + ssd_output_list = [] + + # Process prefill requests + if has_prefill: + # 2. 
Convolution sequence transformation + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "mamba_cache_params.state_indices_tensor" + hidden_states_p = causal_conv1d_fn( + hidden_states_p.transpose(0, 1), conv_weights, self.conv1d.bias, activation=self.activation, conv_states=mamba_cache_params.conv_state, - has_initial_state=attn_metadata.context_lens_tensor > 0, - cache_indices=mamba_cache_params.state_indices_tensor, - query_start_loc=attn_metadata.query_start_loc) - else: - hidden_states = causal_conv1d_update( - hidden_states.transpose(0, 1), + has_initial_state=mamba2_metadata.has_initial_states, + cache_indices=state_indices_tensor_p, + query_start_loc=query_start_loc_p) + hidden_states_p = hidden_states_p.transpose(0, 1) + hidden_states_p = hidden_states_p[:num_prefill_tokens] + # In some instances, the following `bcdt_proj` op + # requires contiguous inputs + # (e.g. if the Marlin kernel is used). + hidden_states_p = hidden_states_p.contiguous() + + B, C, dt = self._project_ssm_parameters(hidden_states_p) + + # 3. State Space Model sequence transformation + initial_states = None + if (mamba2_metadata.has_initial_states is not None + and mamba2_metadata.prep_initial_states): + # making a copy of the states + initial_states = torch.where( + mamba2_metadata.has_initial_states[:, None, None, None], + mamba_cache_params.ssm_state[state_indices_tensor_p], 0) + scan_output, varlen_state = mamba_chunk_scan_combined( + hidden_states_p.view(1, num_prefill_tokens, + self.num_heads // self.tp_size, + self.head_dim), + dt.unsqueeze(0), + self.A, + B.view(1, num_prefill_tokens, 1, -1), + C.view(1, num_prefill_tokens, 1, -1), + chunk_size=mamba2_metadata.chunk_size, + D=self.D, + z=gate_p.view(1, num_prefill_tokens, + self.num_heads // self.tp_size, self.head_dim), + dt_bias=self.dt_bias, + seq_idx=mamba2_metadata.seq_idx, + chunk_indices=mamba2_metadata.chunk_indices, + chunk_offsets=mamba2_metadata.chunk_offsets, + cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1], + initial_states=initial_states, + return_varlen_states=True, + return_final_states=False, + dt_softplus=True, + dt_limit=(0.0, float("inf")), + ) + + # update ssm states + # - varlen state is a (batch, nheads, headdim, dstate) tensor + mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state + + # - reshape + ssd_output_list.append(scan_output.view(num_prefill_tokens, -1)) + + # Process decode requests + if has_decode: + # 2. Convolution sequence transformation + hidden_states_d = causal_conv1d_update( + hidden_states_d, mamba_cache_params.conv_state, conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=mamba_cache_params.state_indices_tensor) - hidden_states = hidden_states.transpose(0, 1) - - # 3. State Space Model sequence transformation - # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.bcdt_proj(hidden_states.transpose(-2, -1))[0] - - # Splitting the ssm_parameters as in modeling_plamo.py. - B, C, time_step = torch.split( - ssm_parameters, - [self.ssm_state_size, self.ssm_state_size, self.time_step_rank], - dim=-1, - ) - time_step = self.dt_norm(time_step.contiguous()) - B = self.B_norm(B.contiguous()) - C = self.C_norm(C.contiguous()) - - discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) - time_proj_bias = (self.dt_bias.float() if hasattr( - self.dt_proj, "bias") else None) - - # Broadcasting as in modeling_plamo.py. 
- discrete_time_step = discrete_time_step.transpose( - 0, 1)[..., None].expand(-1, -1, self.hidden_size_per_head) - discrete_time_step = discrete_time_step.reshape( - -1, self.intermediate_size).transpose(0, 1) - time_proj_bias = time_proj_bias[..., - None].expand(-1, - self.hidden_size_per_head) - time_proj_bias = time_proj_bias.reshape(self.intermediate_size) - - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - scan_outputs = selective_scan_fn( - hidden_states, - mamba_cache_params.ssm_state, - discrete_time_step, - self.A, - B.transpose(-2, -1), - C.transpose(-2, -1), - self.D.float(), - gate, - time_proj_bias, - delta_softplus=True, - cache_indices=mamba_cache_params.state_indices_tensor, - has_initial_state=attn_metadata.context_lens_tensor > 0, - query_start_loc=attn_metadata.query_start_loc) - else: - scan_outputs = selective_state_update( + conv_state_indices=state_indices_tensor_d) + + B, C, dt = self._project_ssm_parameters(hidden_states_d) + + # 3. State Space Model sequence transformation + A = self.A[:, None, ...][:, :, + None].expand(-1, self.head_dim, + self.config.mamba_d_state) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.unsqueeze(1) + C = C.unsqueeze(1) + hidden_states_d = hidden_states_d.view( + -1, self.num_heads // self.tp_size, self.head_dim) + + # - the hidden is reshaped into (bs, num_heads, head_dim) + # - mamba_cache_params.ssm_state's slots will be selected + # using state_indices_tensor_d + + hidden_states_d = selective_state_update( mamba_cache_params.ssm_state, - hidden_states.transpose(0, 1), - discrete_time_step.transpose(0, 1), - self.A, + hidden_states_d, + dt, + A, B, C, - self.D, - gate.transpose(0, 1), - time_proj_bias, + D, + z=gate_d.reshape(num_decodes, -1, self.head_dim), + dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=mamba_cache_params.state_indices_tensor) - scan_outputs = scan_outputs.transpose(0, 1) + state_batch_indices=state_indices_tensor_d, + ) + ssd_output_list.append( + hidden_states_d.view(-1, (self.num_heads // self.tp_size) * + self.head_dim)) + + # Merge prefill and decode outputs before passing to MLP + hidden_states = torch.vstack(ssd_output_list) # 4. 
Final linear projection - contextualized_states = self.out_proj(scan_outputs.transpose(-2, - -1))[0] - return contextualized_states + out = self.out_proj(hidden_states) + return out class DenseMLP(nn.Module): @@ -332,33 +387,39 @@ def __init__( self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_up_proj = MergedColumnParallelLinear( - self.hidden_size, [self.intermediate_size] * 2, + self.hidden_size, + [self.intermediate_size] * 2, bias=False, prefix=f"{prefix}.gate_up_proj", - quant_config=quant_config) + quant_config=quant_config, + return_bias=False, + ) + self.act = SiluAndMul() self.down_proj = RowParallelLinear(self.intermediate_size, self.hidden_size, bias=False, prefix=f"{prefix}.down_proj", - quant_config=quant_config) + quant_config=quant_config, + return_bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - h = self.gate_up_proj(hidden_states)[0] - h = _swiglu(h) - output, _ = self.down_proj(h) - return output # type: ignore + h = self.gate_up_proj(hidden_states) + h = self.act(h) + return self.down_proj(h) +@support_torch_compile class Plamo2AttentionMixer(nn.Module): def __init__(self, - config: Plamo2Config, - cache_config: CacheConfig, - quant_config: QuantizationConfig, - max_model_len: int | None = None, + *, + vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads @@ -396,19 +457,35 @@ def __init__(self, "rope_theta") else 10000 self.rope_scaling = config.rope_scaling if hasattr( config, "rope_scaling") else None + max_position = config.max_position_embeddings + if hasattr(vllm_config.model_config, "max_model_len") and isinstance( + vllm_config.model_config.max_model_len, int): + max_position = min(max_position, + vllm_config.model_config.max_model_len) - assert max_model_len is not None, "max_model_len must be provided" self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, - max_position=max_model_len, + max_position=max_position, base=self.rope_theta, rope_scaling=self.rope_scaling, ) - self.q_weight = torch.nn.Parameter( + self.q_norm = RMSNorm(config.hidden_size_per_head, + eps=config.rms_norm_eps) + self.q_norm.weight = torch.nn.Parameter( torch.ones((self.num_heads, config.hidden_size_per_head))) - self.k_weight = torch.nn.Parameter( + set_weight_attrs(self.q_norm.weight, + {"weight_loader": sharded_weight_loader(0)}) + self.k_norm = RMSNorm(config.hidden_size_per_head, + eps=config.rms_norm_eps) + self.k_norm.weight = torch.nn.Parameter( torch.ones((self.num_kv_heads, config.hidden_size_per_head))) + # Tensor-parallelism shards the K norm weights to the tp ranks + # in a head-wise manner. This approach does not work if there is only + # a single KV head, as is the case for PLaMo 2-1B. 
+ if self.total_num_kv_heads != 1: + set_weight_attrs(self.k_norm.weight, + {"weight_loader": sharded_weight_loader(0)}) self.attn = Attention( self.num_heads, @@ -423,13 +500,18 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], **kwargs, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = _rms_norm(q, self.q_weight, 1e-6) - k = _rms_norm(k, self.k_weight, 1e-6) + + q_shape = q.shape + q = q.reshape(q_shape[:-1] + self.q_norm.weight.shape) + q = self.q_norm.forward_native(q).reshape(q_shape) + k_shape = k.shape + k = k.reshape(k_shape[:-1] + self.k_norm.weight.shape) + k = self.k_norm.forward_native(k).reshape(k_shape) + q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) @@ -441,27 +523,18 @@ class Plamo2DecoderLayer(nn.Module): def __init__(self, vllm_config: VllmConfig, layer_idx: int, - max_model_len: int | None = None, prefix: str = "", **kwargs) -> None: super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - max_model_len = vllm_config.scheduler_config.max_model_len self.is_mamba = is_mamba(config, layer_idx) if self.is_mamba: - self.mixer = Plamo2MambaMixer(config=config, - cache_config=cache_config, - quant_config=quant_config, - max_model_len=max_model_len, + self.mixer = Plamo2MambaMixer(vllm_config=vllm_config, prefix=f"{prefix}.mixer") else: - self.mixer = Plamo2AttentionMixer(config=config, - cache_config=cache_config, - quant_config=quant_config, - max_model_len=max_model_len, + self.mixer = Plamo2AttentionMixer(vllm_config=vllm_config, prefix=f"{prefix}.mixer") self.mlp = DenseMLP(config=config, @@ -482,6 +555,7 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, **kwargs, ): if residual is None: @@ -491,10 +565,12 @@ def forward( hidden_states, residual = self.pre_mixer_norm( hidden_states, residual) - hidden_states = self.mixer(positions=positions, - hidden_states=hidden_states, - residual=residual, - mamba_cache_params=mamba_cache_params) + hidden_states = self.mixer( + positions=positions, + hidden_states=hidden_states, + mamba_cache_params=mamba_cache_params, + mamba2_metadata=mamba2_metadata, + ) hidden_states = self.post_mixer_norm(hidden_states) # Fully Connected hidden_states, residual = self.pre_mlp_norm(hidden_states, residual) @@ -507,14 +583,18 @@ class Plamo2Decoder(torch.nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() - num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers + config = vllm_config.model_config.hf_config + extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)} + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + return Plamo2DecoderLayer(vllm_config=vllm_config, + layer_idx=layer_idx, + prefix=prefix, + **extra_kwargs) - self.layers = nn.ModuleList([ - Plamo2DecoderLayer(vllm_config=vllm_config, - layer_idx=i, - prefix=f"{prefix}.layers.{i}") - for i in range(num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") def forward( self, @@ -522,9 +602,10 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], mamba_cache_params: 
MambaCacheParams, + mamba2_metadata: Mamba2Metadata, ) -> torch.Tensor: mamba_cache_index = 0 - for layer in self.layers: + for layer in self.layers[self.start_layer:self.end_layer]: layer_mamba_cache_params = None if layer.is_mamba: layer_mamba_cache_params = mamba_cache_params.at_layer_idx( @@ -535,7 +616,9 @@ def forward( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=layer_mamba_cache_params) + mamba_cache_params=layer_mamba_cache_params, + mamba2_metadata=mamba2_metadata, + ) return hidden_states, residual @@ -557,10 +640,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=config.vocab_size, prefix=f"{prefix}.embed_tokens", ) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_init() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -569,21 +658,41 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # TODO(Shinichi): Implement pipeline parallelism. - hidden_states = self.embed_tokens(input_ids) - residual = None + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + attn_metadata=attn_metadata, + ) hidden_states, residual = self.layers( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=mamba_cache_params) + mamba_cache_params=mamba_cache_params, + mamba2_metadata=mamba2_metadata, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid, - SupportsV0Only): +class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, + IsHybrid, SupportsV0Only): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -629,10 +738,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.config.vocab_size) - + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, @@ -661,7 +775,7 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: + self) -> tuple[tuple[int, int], tuple[int, int, int]]: world_size = 
get_tensor_model_parallel_world_size()
         hidden_size = (self.config.mamba_num_heads *
                        self.config.hidden_size_per_head)
@@ -670,7 +784,8 @@ def _get_mamba_cache_shape(
             self.config.mamba_d_conv - 1,
         )
         temporal_state_shape = (
-            hidden_size // world_size,
+            divide(self.config.mamba_num_heads, world_size),
+            self.config.hidden_size_per_head,
             self.config.mamba_d_state,
         )
         return conv_state_shape, temporal_state_shape
@@ -684,6 +799,14 @@ def compute_logits(
                                        sampling_metadata)
         return logits

+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
@@ -703,23 +826,46 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
             ".B_norm_weight": ".B_norm.weight",
             ".C_norm_weight": ".C_norm.weight",
             ".dt_norm_weight": ".dt_norm.weight",
+            ".q_weight": ".q_norm.weight",
+            ".k_weight": ".k_norm.weight",
         }

         # Apply replacements based on the defined mappings
         for old, new in replacements.items():
             if old in name:
                 name = name.replace(old, new)

-            # Broadcast the loaded weight to match the model's parameter shape.
-            if ".A" in name:
-                loaded_weight = loaded_weight[:, None, None].expand(
-                    -1, self.config.hidden_size_per_head,
-                    self.config.mamba_d_state)
+            # Reshape the in_proj weights to match the shape expected
+            # by MergedColumnParallelLinear.
+            # This works both for unquantized weights and
+            # for quantized weights.
+            # In the quantized case, the weights are already transposed.
+            # Also, in addition to the quantized weights,
+            # the zero points and scales have to be reshaped as well.
+            # Packing should not be affected by this.
+            if ".mixer.in_proj.weight" in name \
+                    or "mixer.in_proj.qweight" in name \
+                    or "mixer.in_proj.scales" in name \
+                    or "mixer.in_proj.qzeros" in name:
+                if "mixer.in_proj.weight" in name:
+                    loaded_weight = loaded_weight.transpose(0, 1)
+                # for weight:
+                # loaded_weight.shape[0] == self.config.hidden_size
+                # for qweight:
+                # loaded_weight.shape[0] == self.config.hidden_size // param.pack_factor  # noqa
+                # for scales and qzeros:
+                # loaded_weight.shape[0] == self.config.hidden_size // self.vllm_config.quant_config.group_size  # noqa
                 loaded_weight = loaded_weight.reshape(
-                    -1, self.config.mamba_d_state)
-            elif ".D" in name:
-                loaded_weight = loaded_weight[:, None].expand(
-                    -1, self.config.hidden_size_per_head)
-                loaded_weight = loaded_weight.reshape(-1)
+                    loaded_weight.shape[0], self.config.mamba_num_heads, -1)
+                gate_weight, hidden_states_weight = loaded_weight.chunk(2,
+                                                                        dim=-1)
+                gate_weight = gate_weight.reshape(loaded_weight.shape[0], -1)
+                hidden_states_weight = hidden_states_weight.reshape(
+                    loaded_weight.shape[0], -1)
+                loaded_weight = torch.cat([gate_weight, hidden_states_weight],
+                                          dim=-1)
+                if "mixer.in_proj.weight" in name:
+                    loaded_weight = loaded_weight.transpose(0, 1)
+
             # Offset parameters aren't supported by vLLM's RMSNorm yet,
             # so the +1 offset is folded into the loaded weight here.
             if ".pre_mixer_norm" in name:
                 loaded_weight += 1.0
@@ -732,6 +878,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
             elif "model.norm.weight" in name:
                 loaded_weight += 1.0

+            # Skip layers on other devices.
+            if is_pp_missing_parameter(name, self):
+                continue
+
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
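
Note (not part of the patch): the rewritten Plamo2MambaMixer.forward above relies on vLLM's V0 varlen batch layout, where all prefill tokens come first and are followed by one token per decoding request, so a single torch.split along the token dimension separates the two phases. The sketch below is a minimal, self-contained illustration of that split in plain PyTorch; the sizes and hidden dimension are made up for illustration.

import torch

# Illustrative only: two prefill requests contributing 7 tokens in total,
# plus 3 decode requests contributing one token each.
num_prefill_tokens, num_decodes = 7, 3
hidden_states = torch.randn(num_prefill_tokens + num_decodes, 16)

# Prefill tokens are laid out first, then one token per decode request,
# so one split along dim 0 yields the per-phase tensors.
hidden_states_p, hidden_states_d = torch.split(
    hidden_states, [num_prefill_tokens, num_decodes], dim=0)

assert hidden_states_p.shape == (num_prefill_tokens, 16)
assert hidden_states_d.shape == (num_decodes, 16)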
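
Note (not part of the patch): with the docs table now marking Plamo2ForCausalLM as pipeline-parallel capable and pfnet/plamo-2-1b added to the PP test list, a quick end-to-end check could look like the sketch below. This is a hedged example: it uses vLLM's public LLM/SamplingParams API, assumes trust_remote_code=True is needed for the PLaMo2 tokenizer, and whether offline pipeline parallelism is available (and which executor backend it requires) depends on the vLLM version in use.

from vllm import LLM, SamplingParams

# Single-GPU smoke test; for pipeline parallelism one could additionally pass
# e.g. pipeline_parallel_size=2, assuming two devices and a supporting backend.
llm = LLM(
    model="pfnet/plamo-2-1b",
    trust_remote_code=True,  # assumed: PLaMo2 ships a custom tokenizer
    dtype="bfloat16",
)

outputs = llm.generate(
    ["The capital of Japan is"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)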