From 75a7eecd723b0226cf647e29cb51ee6bf2d2c16a Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Thu, 28 Aug 2025 00:30:33 +0200 Subject: [PATCH 1/4] `Apertus` and `XIELU` Co-authored-by: AllenHaoHuang Signed-off-by: EduardDurech <39579228+EduardDurech@users.noreply.github.com> --- .../models/language/generation/test_common.py | 3 +- tests/models/registry.py | 2 + vllm/model_executor/layers/activation.py | 109 ++++ vllm/model_executor/models/apertus.py | 576 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 5 files changed, 690 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/models/apertus.py diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index 57382914bfea..4c4434c94145 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -92,7 +92,8 @@ pytest.param( "allenai/OLMoE-1B-7B-0924-Instruct", marks=[pytest.mark.cpu_model], - ) + ), + pytest.param("swiss-ai/Apertus-8B"), # apertus ]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) diff --git a/tests/models/registry.py b/tests/models/registry.py index 85b4c96e3b1c..db1aecfb060f 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -137,6 +137,8 @@ def check_available_online( # yapf: disable _TEXT_GENERATION_EXAMPLE_MODELS = { # [Decoder-only] + "ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B", + trust_remote_code=True), "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True), "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index f3248589abc4..ed4f0576b42f 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -10,11 +10,14 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import LazyDict +logger = init_logger(__name__) + @CustomOp.register("fatrelu_and_mul") class FatreluAndMul(CustomOp): @@ -363,6 +366,110 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: return self.forward_native(x) +@CustomOp.register("xielu") +class XIELU(CustomOp): + """ + Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010 + If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA + Otherwise, we emit a single warning and use xIELU Python + """ + + def __init__( + self, + alpha_p_init: float = 0.8, + alpha_n_init: float = 0.8, + beta: float = 0.5, + eps: float = -1e-6, + dtype: torch.dtype = torch.bfloat16, + with_vector_loads: bool = False, + ): + super().__init__() + self.alpha_p = nn.Parameter( + torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - + 1).unsqueeze(0)) + self.alpha_n = nn.Parameter( + torch.log( + torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - + 1).unsqueeze(0)) + self.register_buffer("beta", torch.tensor(beta, dtype=dtype)) + self.register_buffer("eps", torch.tensor(eps, dtype=dtype)) + self.with_vector_loads = with_vector_loads + # Temporary until xIELU CUDA fully implemented + self._beta_scalar = float(self.beta.detach().cpu().float().item()) + self._eps_scalar = float(self.eps.detach().cpu().float().item()) + + self._xielu_cuda_obj = None + try: + import xielu.ops # noqa: F401 + + self._xielu_cuda_obj = torch.classes.xielu.XIELU() + msg = "Using experimental xIELU CUDA." + try: + from torch._dynamo import allow_in_graph + + self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda) + msg += " Enabled torch._dynamo for xIELU CUDA." + except Exception as err: + msg += (f" Could not enable torch._dynamo for xIELU ({err}) - " + "this may result in slower performance.") + self._xielu_cuda_fn = self._xielu_cuda + logger.warning_once(msg) + except Exception as err: + logger.warning_once( + "CUDA-fused xIELU not available (%s) –" + " falling back to a Python version.\n" + "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`", + str(err), + ) + + def _xielu_python(self, x: torch.Tensor) -> torch.Tensor: + alpha_p = nn.functional.softplus(self.alpha_p) + alpha_n = self.beta + nn.functional.softplus(self.alpha_n) + return torch.where( + x > 0, + alpha_p * x * x + self.beta * x, + (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + + self.beta * x, + ) + + def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor: + """Firewall function to prevent torch.compile from seeing .item()""" + original_shape = x.shape + # CUDA kernel expects 3D tensors, reshape if needed + while x.dim() < 3: + x = x.unsqueeze(0) + if x.dim() > 3: + x = x.view(-1, 1, x.size(-1)) + if original_shape != x.shape: + logger.warning_once( + "Warning: xIELU input tensor expects 3 dimensions" + " but got (shape: %s). Reshaping to (shape: %s).", + original_shape, + x.shape, + ) + result = self._xielu_cuda_obj.forward( + x, + self.alpha_p, + self.alpha_n, + # Temporary until xIELU CUDA fully implemented -> + # self.{beta,eps}.item() + self._beta_scalar, + self._eps_scalar, + self.with_vector_loads, + ) + return result.view(original_shape) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self._xielu_cuda_obj is not None and input.is_cuda: + if not torch._dynamo.is_compiling(): + return self._xielu_cuda_fn(input) + else: + logger.warning_once( + "torch._dynamo is compiling, using Python version of xIELU." + ) + return self._xielu_python(input) + + class ScaledActivation(nn.Module): """An activation function with post-scale parameters. @@ -426,6 +533,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): lambda: nn.Tanh(), "sigmoid": lambda: nn.Sigmoid(), + "xielu": + lambda: XIELU(), }) diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py new file mode 100644 index 000000000000..0de683d2cd06 --- /dev/null +++ b/vllm/model_executor/models/apertus.py @@ -0,0 +1,576 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2025 The Swiss AI Initiative. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate the architectural differences made by +# the Swiss AI Initiative that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Apertus model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import ApertusConfig + +from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import XIELU +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class ApertusMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + reduce_results: bool = True, + ) -> None: + super().__init__() + self.up_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "xielu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only xIELU is supported for now.") + self.act_fn = XIELU() + + def forward(self, x): + x, _ = self.up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class ApertusAttention(nn.Module): + + def __init__( + self, + config: ApertusConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + head_dim = getattr(config, "head_dim", None) + if head_dim is None: + head_dim = self.hidden_size // self.total_num_heads + self.head_dim = head_dim + # Phi models introduced a partial_rotary_factor parameter in the config + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", + 1) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self._init_rotary_emb(config, + rope_scaling=rope_scaling, + quant_config=quant_config) + + sliding_window = None + if layer_types := getattr(config, "layer_types", None): + is_sliding = layer_types[layer_idx] == "sliding_attention" + if is_sliding: + sliding_window = config.sliding_window + + attn_cls = (EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY else Attention) + + self.attn = attn_cls( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + attn_type=attn_type, + prefix=f"{prefix}.attn", + ) + + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = self.q_norm(q.contiguous().view(-1, self.head_dim)).view_as(q) + k = self.k_norm(k.contiguous().view(-1, self.head_dim)).view_as(k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + def _init_rotary_emb(self, config: ApertusConfig, + rope_scaling: Optional[dict[str, Any]], + quant_config: Optional[QuantizationConfig]) -> None: + is_neox_style = True + is_gguf = quant_config and quant_config.get_name() == "gguf" + if is_gguf and config.model_type == "apertus": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=int(self.partial_rotary_factor * self.head_dim), + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + partial_rotary_factor=self.partial_rotary_factor, + ) + + +class ApertusDecoderLayer(nn.Module): + + def __init__( + self, + config: ApertusConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + bias_o_proj = attention_bias + # support internlm/internlm3-8b with qkv_bias + if hasattr(config, 'qkv_bias'): + attention_bias = config.qkv_bias + + # Apertus defaults to causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. parasail-ai/GritLM-7B-vllm) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = ApertusAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = ApertusMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.feedforward_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.attention_layernorm(hidden_states) + else: + hidden_states, residual = self.attention_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + + # Fully Connected + hidden_states, residual = self.feedforward_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class ApertusModel(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = ApertusDecoderLayer): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.aux_hidden_state_layers = tuple[int, ...]() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, + list[torch.Tensor]]]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + aux_hidden_states = [] + for idx, layer in enumerate( + self.layers[self.start_layer:self.end_layer]): + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append(hidden_states + residual) + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + if "scale" in name: + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings" + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = ApertusDecoderLayer): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + + self.model = self._init_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + layer_type=layer_type) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.embed_tokens) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + + def _init_model(self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = ApertusDecoderLayer): + return ApertusModel(vllm_config=vllm_config, + prefix=prefix, + layer_type=layer_type) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 9040189ee558..98115f862356 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -39,6 +39,7 @@ # yapf: disable _TEXT_GENERATION_MODELS = { # [Decoder-only] + "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"), "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"), From 81ab8f74429956939743059af3db3ac846b1806e Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Thu, 28 Aug 2025 18:47:18 +0200 Subject: [PATCH 2/4] Assert XIELU CUDA obj is not `None` Signed-off-by: EduardDurech <39579228+EduardDurech@users.noreply.github.com> --- vllm/model_executor/layers/activation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index ed4f0576b42f..eb7e494e3286 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -434,6 +434,8 @@ def _xielu_python(self, x: torch.Tensor) -> torch.Tensor: def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor: """Firewall function to prevent torch.compile from seeing .item()""" + assert self._xielu_cuda_obj is not None, ( + "XIELU CUDA object must not be None") original_shape = x.shape # CUDA kernel expects 3D tensors, reshape if needed while x.dim() < 3: From 5d7b48ce0dba0f99b64c9d87bcb6edf589a4656f Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Fri, 29 Aug 2025 06:03:57 +0200 Subject: [PATCH 3/4] Version constraint Signed-off-by: EduardDurech <39579228+EduardDurech@users.noreply.github.com> --- tests/models/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/registry.py b/tests/models/registry.py index db1aecfb060f..1f963e104ff6 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -138,6 +138,7 @@ def check_available_online( _TEXT_GENERATION_EXAMPLE_MODELS = { # [Decoder-only] "ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B", + min_transformers_version="4.56.0", trust_remote_code=True), "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True), From 5f3765037c7a9ccd25a4942a381792b62c064e41 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Fri, 29 Aug 2025 11:55:34 +0200 Subject: [PATCH 4/4] `test_registry_imports` Transformers version compatibility Signed-off-by: EduardDurech <39579228+EduardDurech@users.noreply.github.com> --- tests/models/test_registry.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 8769ad45eb93..36882aba5e94 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -24,6 +24,9 @@ @pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs()) def test_registry_imports(model_arch): + # Skip if transformers version is incompatible + model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) + model_info.check_transformers_version(on_fail="skip") # Ensure all model classes can be imported successfully model_cls = ModelRegistry._try_load_model_cls(model_arch) assert model_cls is not None