From 8214f1ef2ed7459da67b93986d08f56b783e79d3 Mon Sep 17 00:00:00 2001 From: WyldeCat Date: Tue, 22 Jul 2025 13:21:02 +0000 Subject: [PATCH 01/13] feat(motif): add Motif Signed-off-by: WyldeCat Signed-off-by: ca1207 --- vllm/model_executor/models/motif.py | 581 +++++++++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 2 files changed, 582 insertions(+) create mode 100644 vllm/model_executor/models/motif.py diff --git a/vllm/model_executor/models/motif.py b/vllm/model_executor/models/motif.py new file mode 100644 index 000000000000..ade176d7842f --- /dev/null +++ b/vllm/model_executor/models/motif.py @@ -0,0 +1,581 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/modeling_motif.py +# Copyright (c) Alibaba Cloud. +# LICENSE: https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/LICENSE +"""Inference-only Motif model compatible with HuggingFace weights.""" +import math +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionType +from vllm.attention.selector import _Backend +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .adapters import as_seq_cls_model +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class PolyNorm(torch.nn.Module): + """ + A trainable activation function introduced in https://arxiv.org/html/2411.03884v1. 
+ The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md + """ + + def __init__(self, eps=1e-6): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1)) + self.eps = eps + + def _norm(self, x): + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + orig_dtype = x.dtype + x_float = x.to(torch.float32) + output = (self.weight[0] * self._norm(x_float**3) + + self.weight[1] * self._norm(x_float**2) + + self.weight[2] * self._norm(x_float) + self.bias) + return output.to(orig_dtype) + + +class MotifMLP(nn.Module): + """MLP for the language component of the Motif model, which contains a + MergedColumnParallelLinear merging 2 outputs via PolyNorm activation.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str = "poly_norm", + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + reduce_results: bool = True, + ): + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "poly_norm": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only poly_norm is supported for now.") + self.act_fn = PolyNorm() + self.intermediate_size = intermediate_size + self.tp_size = get_tensor_model_parallel_world_size() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn( + x[..., :self.intermediate_size // + self.tp_size]) * x[..., self.intermediate_size // self.tp_size:] + x, _ = self.down_proj(x) + return x + + +class MotifAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + head_dim = getattr(config, "head_dim", None) + if head_dim is None: + head_dim = self.hidden_size // self.total_num_heads + self.head_dim = head_dim + # Phi models introduced a partial_rotary_factor parameter in the config + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", + 1) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + assert self.num_heads % 2 == 0, 'num_heads should be even' + assert self.num_kv_heads % 2 == 0, 'num_heads should be even' + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self._init_rotary_emb(config, + rope_scaling=rope_scaling, + quant_config=quant_config) + sliding_window = None + + self.lambda_init = self.lambda_init_fn(layer_idx) + self.lambda_q1 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.lambda_k1 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.lambda_q2 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.lambda_k2 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.subln = nn.RMSNorm(2 * self.head_dim, + eps=1e-5, + elementwise_affine=True) + + params = { + 'differential_flash_attention_config': { + 'lambda_init': self.lambda_init, + 'lambda_q1': self.lambda_q1, + 'lambda_k1': self.lambda_k1, + 'lambda_q2': self.lambda_q2, + 'lambda_k2': self.lambda_k2, + "subln": self.subln, + } + } + + diff_attn_err_msg = ( + 'Set VLLM_ATTENTION_BACKEND="DIFFERENTIAL_FLASH_ATTN" ' + 'to enable Differential Flash Attention.') + try: + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + attn_type=attn_type, + prefix=f"{prefix}.attn", + **params, + ) + except TypeError as e: + raise ValueError(diff_attn_err_msg) from e + assert (self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN + ), diff_attn_err_msg + + def lambda_init_fn(self, depth): + return 0.8 - 0.6 * math.exp(-0.3 * depth) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + def _init_rotary_emb(self, config: PretrainedConfig, + rope_scaling: Optional[dict[str, Any]], + quant_config: Optional[QuantizationConfig]) -> None: + is_neox_style = True + is_gguf = quant_config and quant_config.get_name() == "gguf" + if is_gguf and config.model_type == "llama": + is_neox_style = False + + 
self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + partial_rotary_factor=self.partial_rotary_factor, + ) + + +class MotifDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + bias_o_proj = attention_bias + if hasattr(config, 'qkv_bias'): + attention_bias = config.qkv_bias + + # By default, Motif uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. parasail-ai/GritLM-7B-vllm) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = MotifAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = MotifMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). 
+ "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class MotifModel(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = MotifDecoderLayer): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.quant_config = quant_config + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + # Use the provided decoder layer type or default to MotifDecoderLayer + decoder_layer_type = decoder_layer_type or MotifDecoderLayer + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: decoder_layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = 
name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class MotifForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = MotifModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + +MotifForSequenceClassification = as_seq_cls_model(MotifForCausalLM) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 465c25f09480..ae8829c7dacc 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -101,6 +101,7 @@ "MambaForCausalLM": ("mamba", "MambaForCausalLM"), "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), "FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"), + "MotifForCausalLM": ("motif", "MotifForCausalLM"), "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), 
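
Note on usage (not part of the patch): the MotifAttention constructor above only accepts the DIFFERENTIAL_FLASH_ATTN backend, surfacing the error message 'Set VLLM_ATTENTION_BACKEND="DIFFERENTIAL_FLASH_ATTN"' otherwise. A minimal sketch of how that requirement reaches users is shown below; the model id comes from the registry entry added in this series, while the prompt and sampling parameters are illustrative assumptions only.

    # Minimal usage sketch, assuming the Motif checkpoint is available on the HF Hub.
    # The attention backend must be selected before the engine is constructed,
    # matching the assertion in MotifAttention.__init__ above.
    import os
    os.environ["VLLM_ATTENTION_BACKEND"] = "DIFFERENTIAL_FLASH_ATTN"

    from vllm import LLM, SamplingParams

    llm = LLM(model="Motif-Technologies/Motif-2.6B", trust_remote_code=True)
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)
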
From ef386df53e9863c56f9b2afa12cce5ae09e82734 Mon Sep 17 00:00:00 2001 From: Jeesoo Lee Date: Thu, 7 Aug 2025 12:21:03 +0900 Subject: [PATCH 02/13] fix(motif): fix lambda init fn Sync with https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/modeling_motif.py#L366 Signed-off-by: WyldeCat Signed-off-by: ca1207 --- vllm/model_executor/models/motif.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/motif.py b/vllm/model_executor/models/motif.py index ade176d7842f..56b45c18f717 100644 --- a/vllm/model_executor/models/motif.py +++ b/vllm/model_executor/models/motif.py @@ -236,7 +236,7 @@ def __init__( ), diff_attn_err_msg def lambda_init_fn(self, depth): - return 0.8 - 0.6 * math.exp(-0.3 * depth) + return 0.8 - 0.6 * math.exp(-0.3 * (depth - 1)) def forward( self, From 1154e07ec8e00839741f3ed9104304661eb7f82b Mon Sep 17 00:00:00 2001 From: ca1207 Date: Fri, 22 Aug 2025 07:15:28 +0000 Subject: [PATCH 03/13] Implementation of FusedPolyNormKernel Signed-off-by: ca1207 --- benchmarks/kernels/benchmark_polynorm.py | 157 +++++++++++++++ csrc/layernorm_kernels.cu | 233 +++++++++++++++++++++++ csrc/ops.h | 5 +- csrc/torch_bindings.cpp | 6 + tests/kernels/core/test_layernorm.py | 33 +++- tests/models/registry.py | 2 + vllm/_custom_ops.py | 8 + vllm/model_executor/layers/layernorm.py | 59 ++++++ vllm/model_executor/models/motif.py | 30 +-- vllm/model_executor/models/registry.py | 2 +- 10 files changed, 504 insertions(+), 31 deletions(-) create mode 100644 benchmarks/kernels/benchmark_polynorm.py diff --git a/benchmarks/kernels/benchmark_polynorm.py b/benchmarks/kernels/benchmark_polynorm.py new file mode 100644 index 000000000000..fc15612fcce5 --- /dev/null +++ b/benchmarks/kernels/benchmark_polynorm.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import itertools +from typing import Optional, Union + +import torch + +from vllm import _custom_ops as vllm_ops +from vllm.triton_utils import triton + + +def polynorm_naive( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + + def norm(x, eps: float): + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + x = x.float() + return (weight[0] * norm(x**3, eps) + weight[1] * norm(x**2, eps) + + weight[2] * norm(x, eps) + bias).to(weight.dtype).view(orig_shape) + + +def polynorm_vllm( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + + out = torch.empty_like(x) + vllm_ops.poly_norm(out, x, weight, bias, eps) + output = out + + output = output.view(orig_shape) + return output + + +def calculate_diff(batch_size, seq_len, hidden_size): + dtype = torch.bfloat16 + x = torch.randn(batch_size, + seq_len, + hidden_size, + dtype=dtype, + device="cuda") + weight = torch.ones(3, dtype=dtype, device="cuda") + bais = torch.ones(1, dtype=dtype, device="cuda") + + output_naive = polynorm_naive(x.clone(), weight, bais) + output_vllm = polynorm_vllm(x.clone(), weight, bais) + + if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [2**i for i in range(0, 7, 2)] +seq_length_range = [2**i for i in range(6, 11, 1)] +head_num_range = [32, 48] +configs = list( + itertools.product(head_num_range, batch_size_range, seq_length_range)) 
+ + +def get_benchmark(): + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["head_num", "batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["naive", "vllm"], + line_names=["Naive", "vLLM"], + styles=[("blue", "-"), ("red", "-")], + ylabel="us", + plot_name=f"polynorm-perf", + args={}, + )) + def benchmark(head_num, batch_size, seq_len, provider): + dtype = torch.bfloat16 + hidden_size = head_num * 128 # assuming head_dim = 128 + + x = torch.randn(batch_size, + seq_len, + hidden_size, + dtype=dtype, + device="cuda") + weight = torch.ones(3, dtype=dtype, device="cuda") + bias = torch.ones(1, dtype=dtype, device="cuda") + + quantiles = [0.5, 0.2, 0.8] + + if provider == "naive": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: polynorm_naive(x.clone(), weight, bias), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: polynorm_vllm(x.clone(), weight, bias), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--batch-size", + type=int, + default=4, + help="Batch size", + ) + parser.add_argument( + "--seq-len", + type=int, + default=128, + help="Sequence length", + ) + parser.add_argument( + "--hidden-size", + type=int, + default=4096, + help="Hidden size (2nd dimension) of the sequence", + ) + parser.add_argument( + "--save-path", + type=str, + default="./configs/polnorm/", + help="Path to save polnorm benchmark results", + ) + + args = parser.parse_args() + + # Run correctness test + calculate_diff( + batch_size=args.batch_size, + seq_len=args.seq_len, + hidden_size=args.hidden_size, + ) + + benchmark = get_benchmark() + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index f051eb070222..2733485dff01 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -140,6 +140,194 @@ fused_add_rms_norm_kernel( } } +/* Function specialization in the case of FP16/BF16 tensors. + Additional optimizations we can make in this case are + packed and vectorized operations, which help with the + memory latency bottleneck. + + _f16VecPN struct extends _f16Vec to add operations specifically required for + polynomial normalization (poly norm). + The original _f16Vec does not include the sum-of-powers computation or + in-place polynomial normalization logic. 
*/ +template +struct alignas(16) _f16VecPN : _f16Vec { + using Base = _f16Vec; + using Converter = typename Base::Converter; + using T1 = typename Base::T1; + using T2 = typename Base::T2; + using Base::data; + + __device__ auto sum_pows() const { + float s2 = 0.0f, s4 = 0.0f, s6 = 0.0f; + +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 z = Converter::convert(T2{data[i], data[i + 1]}); + float x2 = z.x * z.x; + float x4 = x2 * x2; + float x6 = x4 * x2; + + float y2 = z.y * z.y; + float y4 = y2 * y2; + float y6 = y4 * y2; + + s2 += x2 + y2; + s4 += x4 + y4; + s6 += x6 + y6; + } + return std::make_tuple(s2, s4, s6); + } + + __device__ void poly_norm_inplace(const float w2_inv_std, + const float w1_inv_std2, + const float w0_inv_std3, const float bias) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 z = Converter::convert(T2{data[i], data[i + 1]}); + + float x2 = z.x * z.x; + float x3 = x2 * z.x; + z.x = w2_inv_std * z.x + w1_inv_std2 * x2 + w0_inv_std3 * x3 + bias; + + float y2 = z.y * z.y; + float y3 = y2 * z.y; + z.y = w2_inv_std * z.y + w1_inv_std2 * y2 + w0_inv_std3 * y3 + bias; + + auto out = Converter::convert(z); + data[i] = out.x; + data[i + 1] = out.y; + } + } +}; + +template +__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> +poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] + scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [3] + const scalar_t* __restrict__ bias, // [1] + const float epsilon, const int hidden_size) { + // Sanity checks on our vector struct and type-punned pointer arithmetic + static_assert(std::is_pod_v<_f16VecPN>); + static_assert(sizeof(_f16VecPN) == sizeof(scalar_t) * width); + + /* These and the argument pointers are all declared `restrict` as they are + not aliased in practice. 
Argument pointers should not be dereferenced + in this kernel as that would be undefined behavior */ + auto* __restrict__ input_v = + reinterpret_cast<_f16VecPN*>(input); + const int vec_hidden_size = hidden_size / width; + float variance = 0.0f; + float variance2 = 0.0f; + float variance3 = 0.0f; + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16VecPN temp = input_v[id]; + auto [x2, x4, x6] = temp.sum_pows(); + + variance += x2; + variance2 += x4; + variance3 += x6; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + + variance = BlockReduce(reduceStore).Sum(variance, blockDim.x); + __syncthreads(); + variance2 = BlockReduce(reduceStore).Sum(variance2, blockDim.x); + __syncthreads(); + variance3 = BlockReduce(reduceStore).Sum(variance3, blockDim.x); + + __shared__ float s_w2_inv_std; + __shared__ float s_w1_inv_std2; + __shared__ float s_w0_inv_std3; + __shared__ float s_bias; + + if (threadIdx.x == 0) { + float w0 = (float)weight[0]; + float w1 = (float)weight[1]; + float w2 = (float)weight[2]; + s_bias = (float)bias[0]; + + s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon); + s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon); + s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon); + } + __syncthreads(); + + auto* __restrict__ out_v = reinterpret_cast<_f16VecPN*>(out); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16VecPN temp = input_v[id]; + temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias); + out_v[id] = temp; + } +} + +/* Generic poly_norm_kernel + The width field is not used here but necessary for other specializations. 
+ */ +template +__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> +poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] + scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [3] + const scalar_t* __restrict__ bias, // [1] + const float epsilon, const int hidden_size) { + float variance = 0.0f; + float variance2 = 0.0f; + float variance3 = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)input[blockIdx.x * hidden_size + idx]; + float x2 = x * x; + float x4 = x2 * x2; + float x6 = x4 * x2; + + variance += x2; + variance2 += x4; + variance3 += x6; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + + variance = BlockReduce(reduceStore).Sum(variance, blockDim.x); + __syncthreads(); + variance2 = BlockReduce(reduceStore).Sum(variance2, blockDim.x); + __syncthreads(); + variance3 = BlockReduce(reduceStore).Sum(variance3, blockDim.x); + + __shared__ float s_w2_inv_std; + __shared__ float s_w1_inv_std2; + __shared__ float s_w0_inv_std3; + __shared__ float s_bias; + + if (threadIdx.x == 0) { + float w0 = (float)weight[0]; + float w1 = (float)weight[1]; + float w2 = (float)weight[2]; + s_bias = (float)bias[0]; + + s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon); + s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon); + s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)input[blockIdx.x * hidden_size + idx]; + float x2 = x * x; + float x3 = x2 * x; + + out[blockIdx.x * hidden_size + idx] = (scalar_t)( + x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 + s_bias); + } +} + } // namespace vllm void rms_norm(torch::Tensor& out, // [..., hidden_size] @@ -219,3 +407,48 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] LAUNCH_FUSED_ADD_RMS_NORM(0); } } + +#define LAUNCH_FUSED_POLY_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \ + vllm::poly_norm_kernel<<>>( \ + out.data_ptr(), input.data_ptr(), \ + weight.data_ptr(), bias.data_ptr(), epsilon, \ + hidden_size); \ + }); + +void poly_norm(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [3] + torch::Tensor& bias, // [1] + double epsilon) { + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(input.is_contiguous()); + + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + /* This kernel is memory-latency bound in many scenarios. + When num_tokens is large, a smaller block size allows + for increased block occupancy on CUs and better latency + hiding on global mem ops. */ + const int max_block_size = (num_tokens < 256) ? 1024 : 256; + dim3 block(std::min(hidden_size, max_block_size)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + /*If the tensor types are FP16/BF16, try to use the optimized kernel + with packed + vectorized ops. + Max optimization is achieved with a width-8 vector of FP16/BF16s + since we can load at most 128 bits at once in a global memory op. + However, this requires each tensor's data to be aligned to 16 + bytes. 
+ */ + auto inp_ptr = reinterpret_cast(input.data_ptr()); + auto out_ptr = reinterpret_cast(out.data_ptr()); + bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0; + if (ptrs_are_aligned && hidden_size % 8 == 0) { + LAUNCH_FUSED_POLY_NORM(8); + } else { + LAUNCH_FUSED_POLY_NORM(0); + } +} diff --git a/csrc/ops.h b/csrc/ops.h index 86fe848e2fd5..fcce5a170a72 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -92,6 +92,9 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); +void poly_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, + torch::Tensor& bias, double epsilon); + void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& prompt_mask, const torch::Tensor& output_mask, @@ -346,4 +349,4 @@ void qr_open_handles(fptr_t _fa, const std::vector& handles); void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false); int64_t qr_max_size(); -#endif \ No newline at end of file +#endif diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 4edb7af50f10..0ee35037f7da 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -161,6 +161,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); + // Polynomial Normalization. + ops.def( + "poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float " + "epsilon) -> ()"); + ops.impl("poly_norm", torch::kCUDA, &poly_norm); + // Apply repetition penalties to logits in-place ops.def( "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, " diff --git a/tests/kernels/core/test_layernorm.py b/tests/kernels/core/test_layernorm.py index 02316ceaac73..44e8b644ddf2 100644 --- a/tests/kernels/core/test_layernorm.py +++ b/tests/kernels/core/test_layernorm.py @@ -6,7 +6,7 @@ from tests.kernels.quant_utils import FP8_DTYPE from tests.kernels.utils import opcheck -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.layernorm import RMSNorm, PolyNorm from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -70,6 +70,37 @@ def test_rms_norm( (out, x, layer.weight.data, layer.variance_epsilon)) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_poly_norm( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + current_platform.seed_everything(seed) + torch.set_default_device(device) + layer = PolyNorm().to(dtype=dtype) + layer.weight.data.normal_(mean=1.0, std=0.1) + layer.bias.data.normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + x *= scale + + ref_out = layer.forward_native(x) + out = layer(x) + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + + opcheck( + torch.ops._C.poly_norm, + (out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon)) + + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("add_residual", ADD_RESIDUAL) diff --git a/tests/models/registry.py 
b/tests/models/registry.py index 4035319b45ce..96aa18d9baa4 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -257,6 +257,8 @@ def check_available_online( "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1", # noqa: E501 {"tiny": "TitanML/tiny-mixtral"}), # noqa: E501 "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501 + "MotifForCausalLM": _HfExamplesInfo("Motif-Technologies/Motif-2.6B", + trust_remote_code=True), "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False), "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"), "NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"), diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0043456e0009..35bda698a624 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -280,6 +280,13 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) +def poly_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + bias: torch.Tensor, epsilon: float) -> None: + # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input + input_contiguous = input.contiguous() + torch.ops._C.poly_norm(out, input_contiguous, weight, bias, epsilon) + + def apply_repetition_penalties_torch( logits: torch.Tensor, prompt_mask: torch.Tensor, output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None: @@ -686,6 +693,7 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool: return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability) + def cutlass_sparse_compress(a: torch.Tensor) \ -> tuple[torch.Tensor, torch.Tensor]: """ diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index a5fc1db2dc10..ce0f5eb8ab0c 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -43,6 +43,20 @@ def fused_add_rms_norm( return x, residual +def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, + variance_epsilon: float) -> torch.Tensor: + from vllm import _custom_ops as ops + out = torch.empty_like(x) + ops.poly_norm( + out, + x, + weight, + bias, + variance_epsilon, + ) + return out + + def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float) -> torch.Tensor: import aiter as rocm_aiter @@ -265,3 +279,48 @@ def forward_cuda( self.forward_static) self._is_compiled = True return self.forward_native(x, residual) + + +@CustomOp.register("poly_norm") +class PolyNorm(CustomOp): + """Polynomial normalization. + + Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b + where w_n is the learned weight and b is the bias. + Refer to https://arxiv.org/html/2411.03884v1 + """ + + def __init__( + self, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1)) + self.variance_epsilon = eps + + def _norm(self, x): + return x / torch.sqrt( + x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon) + + def forward_native( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward(). 
+ + Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md + """ + + orig_dtype = x.dtype + x_float = x.to(torch.float32) + output = (self.weight[0] * self._norm(x_float**3) + + self.weight[1] * self._norm(x_float**2) + + self.weight[2] * self._norm(x_float) + self.bias) + return output.to(orig_dtype) + + def forward_cuda( + self, + x: torch.Tensor, + ) -> torch.Tensor: + return poly_norm(x, self.weight, self.bias, self.variance_epsilon) diff --git a/vllm/model_executor/models/motif.py b/vllm/model_executor/models/motif.py index 56b45c18f717..2d428f340874 100644 --- a/vllm/model_executor/models/motif.py +++ b/vllm/model_executor/models/motif.py @@ -19,7 +19,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.layernorm import RMSNorm, PolyNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) @@ -41,30 +41,6 @@ maybe_prefix) -class PolyNorm(torch.nn.Module): - """ - A trainable activation function introduced in https://arxiv.org/html/2411.03884v1. - The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md - """ - - def __init__(self, eps=1e-6): - super().__init__() - self.weight = torch.nn.Parameter(torch.ones(3) / 3) - self.bias = torch.nn.Parameter(torch.zeros(1)) - self.eps = eps - - def _norm(self, x): - return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - orig_dtype = x.dtype - x_float = x.to(torch.float32) - output = (self.weight[0] * self._norm(x_float**3) + - self.weight[1] * self._norm(x_float**2) + - self.weight[2] * self._norm(x_float) + self.bias) - return output.to(orig_dtype) - - class MotifMLP(nn.Module): """MLP for the language component of the Motif model, which contains a MergedColumnParallelLinear merging 2 outputs via PolyNorm activation.""" @@ -199,9 +175,7 @@ def __init__( self.lambda_k2 = nn.Parameter( torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) - self.subln = nn.RMSNorm(2 * self.head_dim, - eps=1e-5, - elementwise_affine=True) + self.subln = RMSNorm(2 * self.head_dim, eps=1e-5) params = { 'differential_flash_attention_config': { diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index ae8829c7dacc..dd90f9bfb761 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -101,13 +101,13 @@ "MambaForCausalLM": ("mamba", "MambaForCausalLM"), "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), "FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"), - "MotifForCausalLM": ("motif", "MotifForCausalLM"), "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), + "MotifForCausalLM": ("motif", "MotifForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), From 8d69392c9ee796bba042736c8fb50b28f1c83fc4 Mon Sep 17 00:00:00 2001 From: ca1207 Date: Mon, 18 Aug 2025 07:25:22 +0000 Subject: [PATCH 04/13] 
misc changes Signed-off-by: ca1207 --- benchmarks/kernels/benchmark_polynorm.py | 60 +++++++++---------- csrc/layernorm_kernels.cu | 51 ++++++++++------ docs/models/supported_models.md | 1 + tests/kernels/core/test_layernorm.py | 2 +- tests/models/registry.py | 3 +- tests/models/test_initialization.py | 5 +- .../backends/differential_flash_attn.py | 3 + vllm/model_executor/models/motif.py | 17 +++--- 8 files changed, 83 insertions(+), 59 deletions(-) diff --git a/benchmarks/kernels/benchmark_polynorm.py b/benchmarks/kernels/benchmark_polynorm.py index fc15612fcce5..6836894ec728 100644 --- a/benchmarks/kernels/benchmark_polynorm.py +++ b/benchmarks/kernels/benchmark_polynorm.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools -from typing import Optional, Union import torch @@ -23,8 +22,16 @@ def norm(x, eps: float): return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) x = x.float() - return (weight[0] * norm(x**3, eps) + weight[1] * norm(x**2, eps) + - weight[2] * norm(x, eps) + bias).to(weight.dtype).view(orig_shape) + return ( + ( + weight[0] * norm(x**3, eps) + + weight[1] * norm(x**2, eps) + + weight[2] * norm(x, eps) + + bias + ) + .to(weight.dtype) + .view(orig_shape) + ) def polynorm_vllm( @@ -44,18 +51,14 @@ def polynorm_vllm( return output -def calculate_diff(batch_size, seq_len, hidden_size): +def calculate_diff(batch_size, seq_len, hidden_dim): dtype = torch.bfloat16 - x = torch.randn(batch_size, - seq_len, - hidden_size, - dtype=dtype, - device="cuda") + x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda") weight = torch.ones(3, dtype=dtype, device="cuda") bais = torch.ones(1, dtype=dtype, device="cuda") - output_naive = polynorm_naive(x.clone(), weight, bais) - output_vllm = polynorm_vllm(x.clone(), weight, bais) + output_naive = polynorm_naive(x, weight, bais) + output_vllm = polynorm_vllm(x, weight, bais) if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2): print("✅ All implementations match") @@ -65,34 +68,29 @@ def calculate_diff(batch_size, seq_len, hidden_size): batch_size_range = [2**i for i in range(0, 7, 2)] seq_length_range = [2**i for i in range(6, 11, 1)] -head_num_range = [32, 48] -configs = list( - itertools.product(head_num_range, batch_size_range, seq_length_range)) +dim_range = [2048, 4096] +configs = list(itertools.product(dim_range, batch_size_range, seq_length_range)) def get_benchmark(): - @triton.testing.perf_report( triton.testing.Benchmark( - x_names=["head_num", "batch_size", "seq_len"], + x_names=["dim", "batch_size", "seq_len"], x_vals=[list(_) for _ in configs], line_arg="provider", line_vals=["naive", "vllm"], line_names=["Naive", "vLLM"], styles=[("blue", "-"), ("red", "-")], ylabel="us", - plot_name=f"polynorm-perf", + plot_name="polynorm-perf", args={}, - )) - def benchmark(head_num, batch_size, seq_len, provider): + ) + ) + def benchmark(dim, batch_size, seq_len, provider): dtype = torch.bfloat16 - hidden_size = head_num * 128 # assuming head_dim = 128 + hidden_dim = dim * 4 - x = torch.randn(batch_size, - seq_len, - hidden_size, - dtype=dtype, - device="cuda") + x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda") weight = torch.ones(3, dtype=dtype, device="cuda") bias = torch.ones(1, dtype=dtype, device="cuda") @@ -100,12 +98,12 @@ def benchmark(head_num, batch_size, seq_len, provider): if provider == "naive": ms, min_ms, max_ms = triton.testing.do_bench( - lambda: polynorm_naive(x.clone(), weight, bias), + lambda: 
polynorm_naive(x, weight, bias), quantiles=quantiles, ) else: ms, min_ms, max_ms = triton.testing.do_bench( - lambda: polynorm_vllm(x.clone(), weight, bias), + lambda: polynorm_vllm(x, weight, bias), quantiles=quantiles, ) @@ -131,10 +129,10 @@ def benchmark(head_num, batch_size, seq_len, provider): help="Sequence length", ) parser.add_argument( - "--hidden-size", + "--hidden-dim", type=int, - default=4096, - help="Hidden size (2nd dimension) of the sequence", + default=8192, + help="Intermediate size of MLP", ) parser.add_argument( "--save-path", @@ -149,7 +147,7 @@ def benchmark(head_num, batch_size, seq_len, provider): calculate_diff( batch_size=args.batch_size, seq_len=args.seq_len, - hidden_size=args.hidden_size, + hidden_dim=args.hidden_dim, ) benchmark = get_benchmark() diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 2733485dff01..b00c8e4d84da 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -203,7 +203,7 @@ struct alignas(16) _f16VecPN : _f16Vec { template __global__ std::enable_if_t<(width > 0) && _typeConvert::exists> poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] - scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] const scalar_t* __restrict__ weight, // [3] const scalar_t* __restrict__ bias, // [1] const float epsilon, const int hidden_size) { @@ -215,7 +215,7 @@ poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] not aliased in practice. Argument pointers should not be dereferenced in this kernel as that would be undefined behavior */ auto* __restrict__ input_v = - reinterpret_cast<_f16VecPN*>(input); + reinterpret_cast*>(input); const int vec_hidden_size = hidden_size / width; float variance = 0.0f; float variance2 = 0.0f; @@ -231,14 +231,22 @@ poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] variance3 += x6; } - using BlockReduce = cub::BlockReduce; + float3 thread_variances = make_float3(variance, variance2, variance3); + + struct SumOp { + __device__ float3 operator()(const float3& a, const float3& b) const { + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); + } + }; + + using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; + float3 block_variances = + BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x); - variance = BlockReduce(reduceStore).Sum(variance, blockDim.x); - __syncthreads(); - variance2 = BlockReduce(reduceStore).Sum(variance2, blockDim.x); - __syncthreads(); - variance3 = BlockReduce(reduceStore).Sum(variance3, blockDim.x); + variance = block_variances.x; + variance2 = block_variances.y; + variance3 = block_variances.z; __shared__ float s_w2_inv_std; __shared__ float s_w1_inv_std2; @@ -273,7 +281,7 @@ poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] template __global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] - scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] const scalar_t* __restrict__ weight, // [3] const scalar_t* __restrict__ bias, // [1] const float epsilon, const int hidden_size) { @@ -292,14 +300,22 @@ poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] variance3 += x6; } - using BlockReduce = cub::BlockReduce; + float3 thread_variances = make_float3(variance, variance2, variance3); + + struct SumOp { + __device__ float3 operator()(const float3& a, 
const float3& b) const { + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); + } + }; + + using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; + float3 block_variances = + BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x); - variance = BlockReduce(reduceStore).Sum(variance, blockDim.x); - __syncthreads(); - variance2 = BlockReduce(reduceStore).Sum(variance2, blockDim.x); - __syncthreads(); - variance3 = BlockReduce(reduceStore).Sum(variance3, blockDim.x); + variance = block_variances.x; + variance2 = block_variances.y; + variance3 = block_variances.z; __shared__ float s_w2_inv_std; __shared__ float s_w1_inv_std2; @@ -323,8 +339,9 @@ poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] float x2 = x * x; float x3 = x2 * x; - out[blockIdx.x * hidden_size + idx] = (scalar_t)( - x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 + s_bias); + out[blockIdx.x * hidden_size + idx] = + (scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 + + s_bias); } } diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 297d98142b5f..85b64b904eea 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -382,6 +382,7 @@ th { | `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MotifForCausalLM` | Motif-1-Tiny | `Motif-Technologies/Motif-2.6B`, `Motif-Technologies/Motif-2.6b-v1.1-LC`, etc. | | ✅︎ | | | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | | `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. 
| ✅︎ | ✅︎ | ✅︎ | diff --git a/tests/kernels/core/test_layernorm.py b/tests/kernels/core/test_layernorm.py index 44e8b644ddf2..53e6d793cf2f 100644 --- a/tests/kernels/core/test_layernorm.py +++ b/tests/kernels/core/test_layernorm.py @@ -6,7 +6,7 @@ from tests.kernels.quant_utils import FP8_DTYPE from tests.kernels.utils import opcheck -from vllm.model_executor.layers.layernorm import RMSNorm, PolyNorm +from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] diff --git a/tests/models/registry.py b/tests/models/registry.py index 96aa18d9baa4..079819681171 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -258,7 +258,8 @@ def check_available_online( {"tiny": "TitanML/tiny-mixtral"}), # noqa: E501 "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501 "MotifForCausalLM": _HfExamplesInfo("Motif-Technologies/Motif-2.6B", - trust_remote_code=True), + trust_remote_code=True, + v0_only=True), "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False), "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"), "NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"), diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index bbd3da982af8..97091701e3bd 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -65,8 +65,9 @@ def _initialize_kv_caches_v1(self, vllm_config): _initialize_kv_caches_v1), monkeypatch.context() as m): if model_info.v0_only: m.setenv("VLLM_USE_V1", "0") - if model_arch == "Phi4FlashForCausalLM": - # Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend + if model_arch in ("Phi4FlashForCausalLM", "MotifForCausalLM"): + # Phi4FlashForCausalLM and MotifForCausalLM + # only supports DIFFERENTIAL_FLASH_ATTN backend m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN") if model_arch == "GptOssForCausalLM": # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index fac3c318a87a..e9dc5b546eea 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -734,6 +734,7 @@ def forward_generate_kv_cache( window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, softcap=self.logits_soft_cap, + fa_version=self.vllm_flash_attn_version, ) assert prefill_output.shape == output[: num_prefill_tokens].shape @@ -755,6 +756,7 @@ def forward_generate_kv_cache( window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, softcap=self.logits_soft_cap, + fa_version=self.vllm_flash_attn_version, ).squeeze(1) except Exception as e: logger.error("Error in PagedAttention.forward_decode: %s", @@ -787,6 +789,7 @@ def forward_with_kv_cache_only( window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, softcap=self.logits_soft_cap, + fa_version=self.vllm_flash_attn_version, ).squeeze(1) return output diff --git a/vllm/model_executor/models/motif.py b/vllm/model_executor/models/motif.py index 2d428f340874..41aa0bcd98e3 100644 --- a/vllm/model_executor/models/motif.py +++ b/vllm/model_executor/models/motif.py @@ -19,7 +19,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from 
vllm.model_executor.layers.layernorm import RMSNorm, PolyNorm +from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) @@ -72,17 +72,20 @@ def __init__( prefix=f"{prefix}.down_proj", ) if hidden_act != "poly_norm": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only poly_norm is supported for now.") + raise NotImplementedError(f"Unsupported activation: {hidden_act}. " + "Only poly_norm is supported for now.") self.act_fn = PolyNorm() self.intermediate_size = intermediate_size - self.tp_size = get_tensor_model_parallel_world_size() + tp_size = get_tensor_model_parallel_world_size() + if hidden_act == "poly_norm" and tp_size > 1: + raise NotImplementedError( + "Tensor parallelism for poly_norm is not supported yet. " + "Support will be added in the future.") def forward(self, x): x, _ = self.gate_up_proj(x) x = self.act_fn( - x[..., :self.intermediate_size // - self.tp_size]) * x[..., self.intermediate_size // self.tp_size:] + x[..., :self.intermediate_size]) * x[..., self.intermediate_size:] x, _ = self.down_proj(x) return x @@ -175,7 +178,7 @@ def __init__( self.lambda_k2 = nn.Parameter( torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) - self.subln = RMSNorm(2 * self.head_dim, eps=1e-5) + self.subln = RMSNorm(2 * self.head_dim, eps=config.attn_rms_norm_eps) params = { 'differential_flash_attention_config': { From 135c23f2ade26fa3ab535bdb82dd0dcd5329b4d0 Mon Sep 17 00:00:00 2001 From: ca1207 Date: Fri, 22 Aug 2025 07:18:56 +0000 Subject: [PATCH 05/13] fix typo Signed-off-by: ca1207 --- benchmarks/kernels/benchmark_polynorm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/kernels/benchmark_polynorm.py b/benchmarks/kernels/benchmark_polynorm.py index 6836894ec728..9ac8f5e6594e 100644 --- a/benchmarks/kernels/benchmark_polynorm.py +++ b/benchmarks/kernels/benchmark_polynorm.py @@ -55,10 +55,10 @@ def calculate_diff(batch_size, seq_len, hidden_dim): dtype = torch.bfloat16 x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda") weight = torch.ones(3, dtype=dtype, device="cuda") - bais = torch.ones(1, dtype=dtype, device="cuda") + bias = torch.ones(1, dtype=dtype, device="cuda") - output_naive = polynorm_naive(x, weight, bais) - output_vllm = polynorm_vllm(x, weight, bais) + output_naive = polynorm_naive(x, weight, bias) + output_vllm = polynorm_vllm(x, weight, bias) if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2): print("✅ All implementations match") From 4b20a28c9600280c3c20f34a15af6721b55a917a Mon Sep 17 00:00:00 2001 From: TaehyunKim <73943231+ca1207@users.noreply.github.com> Date: Mon, 1 Sep 2025 11:36:52 +0900 Subject: [PATCH 06/13] Update docs/models/supported_models.md Co-authored-by: Jee Jee Li Signed-off-by: TaehyunKim <73943231+ca1207@users.noreply.github.com> --- docs/models/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 85b64b904eea..3aace6f491d6 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -382,7 +382,7 @@ th { | `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. 
| ✅︎ | ✅︎ | ✅︎ | | `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MotifForCausalLM` | Motif-1-Tiny | `Motif-Technologies/Motif-2.6B`, `Motif-Technologies/Motif-2.6b-v1.1-LC`, etc. | | ✅︎ | | +| `MotifForCausalLM` | Motif-1-Tiny | `Motif-Technologies/Motif-2.6B`, `Motif-Technologies/Motif-2.6b-v1.1-LC`, etc. | ✅︎ | ✅︎ | | | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | | `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ | From 2d07e364b19cfd61c48c3e859bb7084d0435d28d Mon Sep 17 00:00:00 2001 From: ca1207 Date: Mon, 1 Sep 2025 09:34:41 +0000 Subject: [PATCH 07/13] inherit from LlamaModel Signed-off-by: ca1207 --- vllm/model_executor/models/motif.py | 247 +++------------------------- 1 file changed, 20 insertions(+), 227 deletions(-) diff --git a/vllm/model_executor/models/motif.py b/vllm/model_executor/models/motif.py index 41aa0bcd98e3..1cfdc0d09a7f 100644 --- a/vllm/model_executor/models/motif.py +++ b/vllm/model_executor/models/motif.py @@ -7,8 +7,7 @@ # LICENSE: https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/LICENSE """Inference-only Motif model compatible with HuggingFace weights.""" import math -from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any, Optional import torch from torch import nn @@ -16,29 +15,19 @@ from vllm.attention import Attention, AttentionType from vllm.attention.selector import _Backend -from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors +from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel from .adapters import as_seq_cls_model -from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .interfaces import SupportsV0Only +from .utils import extract_layer_index class MotifMLP(nn.Module): @@ -332,227 +321,31 @@ def forward( return hidden_states, residual -@support_torch_compile( - dynamic_arg_dims={ - "input_ids": 0, - # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, - # 
otherwise (seq_len, ). - "positions": -1, - "intermediate_tensors": 0, - "inputs_embeds": 0, - }) -class MotifModel(nn.Module): +class MotifModel(LlamaModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", decoder_layer_type: type[nn.Module] = MotifDecoderLayer): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=layer_type) - self.config = config - self.quant_config = quant_config - self.vocab_size = config.vocab_size - - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens", - ) - else: - self.embed_tokens = PPMissingLayer() - - # Use the provided decoder layer type or default to MotifDecoderLayer - decoder_layer_type = decoder_layer_type or MotifDecoderLayer - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: decoder_layer_type(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), - prefix=f"{prefix}.layers", - ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - if get_pp_group().is_last_rank: - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - else: - self.norm = PPMissingLayer() +# Motif model uses differential attention +# Only supported in v0 (no chunked prefill support) +class MotifForCausalLM(LlamaForCausalLM, SupportsV0Only): - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: - hidden_states, residual = layer( - positions, - hidden_states, - residual, - ) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - 
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class MotifForCausalLM(nn.Module, SupportsLoRA, SupportsPP): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - - self.config = config - self.lora_config = lora_config - - self.quant_config = quant_config - self.model = MotifModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - - if get_pp_group().is_last_rank: - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) - else: - self.lm_head = PPMissingLayer() - - self.logits_processor = LogitsProcessor(config.vocab_size) - - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = MotifDecoderLayer): - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), - ) - return loader.load_weights(weights) + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=layer_type) MotifForSequenceClassification = as_seq_cls_model(MotifForCausalLM) From 32eb6887b33443fb5f7e74886f49e257f60edf95 Mon Sep 17 00:00:00 2001 From: ca1207 Date: Tue, 2 Sep 2025 02:40:11 +0000 Subject: [PATCH 08/13] 
removed redundant code Signed-off-by: ca1207 --- vllm/model_executor/models/motif.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/vllm/model_executor/models/motif.py b/vllm/model_executor/models/motif.py index 1cfdc0d09a7f..411babca855b 100644 --- a/vllm/model_executor/models/motif.py +++ b/vllm/model_executor/models/motif.py @@ -321,18 +321,6 @@ def forward( return hidden_states, residual -class MotifModel(LlamaModel): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - decoder_layer_type: type[nn.Module] = MotifDecoderLayer): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - layer_type=layer_type) - - # Motif model uses differential attention # Only supported in v0 (no chunked prefill support) class MotifForCausalLM(LlamaForCausalLM, SupportsV0Only): From 4288d27f8d5f2c90a3f7c51bce0d3b80b92ed53c Mon Sep 17 00:00:00 2001 From: ca1207 Date: Tue, 2 Sep 2025 02:58:37 +0000 Subject: [PATCH 09/13] removed LlamaModel Signed-off-by: ca1207 --- vllm/model_executor/models/motif.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/motif.py b/vllm/model_executor/models/motif.py index 411babca855b..d68b42a35e76 100644 --- a/vllm/model_executor/models/motif.py +++ b/vllm/model_executor/models/motif.py @@ -23,7 +23,7 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel +from vllm.model_executor.models.llama import LlamaForCausalLM from .adapters import as_seq_cls_model from .interfaces import SupportsV0Only From f2e4d967a2b7d01cf80ba2ad6140674851c7f29c Mon Sep 17 00:00:00 2001 From: TaehyunKim <73943231+ca1207@users.noreply.github.com> Date: Tue, 2 Sep 2025 19:34:22 +0900 Subject: [PATCH 10/13] Update vllm/model_executor/models/motif.py Co-authored-by: Jee Jee Li Signed-off-by: TaehyunKim <73943231+ca1207@users.noreply.github.com> --- vllm/model_executor/models/motif.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/motif.py b/vllm/model_executor/models/motif.py index d68b42a35e76..6fa91c7ffd53 100644 --- a/vllm/model_executor/models/motif.py +++ b/vllm/model_executor/models/motif.py @@ -290,7 +290,7 @@ def __init__( intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - bias=getattr(config, "mlp_bias", False), + bias=getattr(config, "use_bias", False), prefix=f"{prefix}.mlp", ) self.input_layernorm = RMSNorm(config.hidden_size, From d20fab4b00be78c8cab0df8b43e8db1415ca770a Mon Sep 17 00:00:00 2001 From: ca1207 Date: Tue, 2 Sep 2025 10:36:55 +0000 Subject: [PATCH 11/13] rename bias in getattr to use_bias Signed-off-by: ca1207 --- vllm/model_executor/models/motif.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/motif.py b/vllm/model_executor/models/motif.py index 6fa91c7ffd53..007425962360 100644 --- a/vllm/model_executor/models/motif.py +++ b/vllm/model_executor/models/motif.py @@ -255,7 +255,7 @@ def __init__( max_position_embeddings = getattr(config, "max_position_embeddings", 8192) attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "use_bias", False) bias_o_proj = attention_bias if hasattr(config, 'qkv_bias'): attention_bias = config.qkv_bias From 0bd2c0168ac8ccd7404aeea8be4b92bdddd86106 Mon Sep 17 00:00:00 2001 From: ca1207 Date: 
Wed, 3 Sep 2025 04:25:06 +0000 Subject: [PATCH 12/13] add assertions Signed-off-by: ca1207 --- vllm/model_executor/models/motif.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/model_executor/models/motif.py b/vllm/model_executor/models/motif.py index 007425962360..153f36dcf1f5 100644 --- a/vllm/model_executor/models/motif.py +++ b/vllm/model_executor/models/motif.py @@ -331,6 +331,12 @@ def __init__(self, prefix: str = "", layer_type: type[nn.Module] = MotifDecoderLayer): + # Prefix caching and chunked prefill are not supported for this model. + assert not vllm_config.cache_config.enable_prefix_caching, \ + "Motif currently does not support prefix caching" + assert not vllm_config.scheduler_config.chunked_prefill_enabled, \ + "Motif currently does not support chunked prefill" + super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) From c9779bc73342045a1808a47e14df394a5bfb372c Mon Sep 17 00:00:00 2001 From: ca1207 Date: Wed, 10 Sep 2025 11:53:47 +0900 Subject: [PATCH 13/13] add assert Signed-off-by: ca1207 --- csrc/layernorm_kernels.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index b00c8e4d84da..05be023de0f2 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -440,6 +440,7 @@ void poly_norm(torch::Tensor& out, // [..., hidden_size] double epsilon) { TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.data_ptr() != input.data_ptr()); int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size;
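
Reviewer note (not part of the patch series): for anyone trying this branch locally, the sketch below shows one way to run Motif-2.6B under the constraints the series encodes for MotifForCausalLM: V0 engine only, the DIFFERENTIAL_FLASH_ATTN backend, and neither prefix caching nor chunked prefill (see the registry's v0_only flag, the test_initialization backend override, and the assertions added in PATCH 12). The prompt, token counts, and the choice to pass the constraints as standard engine arguments are illustrative assumptions, not requirements taken from the patches.

```python
# Illustrative only: offline inference with Motif under the constraints
# this series encodes. Environment variables must be set before vLLM is
# imported so the engine and attention backend are selected correctly.
import os

os.environ["VLLM_USE_V1"] = "0"  # differential attention is V0-only here
os.environ["VLLM_ATTENTION_BACKEND"] = "DIFFERENTIAL_FLASH_ATTN"

from vllm import LLM, SamplingParams

llm = LLM(
    model="Motif-Technologies/Motif-2.6B",
    trust_remote_code=True,        # custom config/modeling code on the Hub
    enable_prefix_caching=False,   # asserted unsupported in PATCH 12
    enable_chunked_prefill=False,  # asserted unsupported in PATCH 12
)

outputs = llm.generate(
    ["PolyNorm is a trainable activation that"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```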
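
Relatedly, since PATCH 13 guards the CUDA poly_norm kernel against aliased input/output buffers, a quick out-of-place sanity check in the spirit of calculate_diff from benchmarks/kernels/benchmark_polynorm.py can be useful. The sketch assumes PolyNorm from vllm.model_executor.layers.layernorm behaves as a plain module over the last hidden dimension with usable constructor defaults, as the updated kernel test imports it; tolerances mirror the benchmark, and the float32 rerun is only a convenient reference, not the benchmark's polynorm_naive path.

```python
# Illustrative sketch (assumptions noted above): run PolyNorm in bfloat16 and
# compare against a float32 rerun of the same parameters.
import torch
from vllm.model_executor.layers.layernorm import PolyNorm

torch.manual_seed(0)
x = torch.randn(2, 64, 2048, dtype=torch.bfloat16, device="cuda")

layer = PolyNorm().to(device="cuda", dtype=torch.bfloat16)
out_bf16 = layer(x)                      # bfloat16 path (CUDA kernel)

out_fp32 = layer.float()(x.float())      # same weights, float32 reference

assert out_bf16.shape == x.shape
if torch.allclose(out_bf16.float(), out_fp32, atol=1e-2, rtol=1e-2):
    print("✅ bf16 and fp32 PolyNorm outputs match within tolerance")
else:
    print("❌ PolyNorm outputs diverge beyond tolerance")
```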