From 7db612232f504175e79ea65bc2176b682ce11d8b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 27 Jun 2024 08:02:09 -0700 Subject: [PATCH 1/5] [Model] Add Gemma 2 --- docs/source/models/supported_models.rst | 4 + requirements-common.txt | 2 +- vllm/config.py | 29 +- vllm/model_executor/layers/layernorm.py | 46 ++ .../model_executor/layers/logits_processor.py | 10 +- .../model_executor/layers/rotary_embedding.py | 10 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/gemma2.py | 403 ++++++++++++++++++ 8 files changed, 496 insertions(+), 9 deletions(-) create mode 100644 vllm/model_executor/models/gemma2.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 47737ae525209..544322582f8e9 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -55,6 +55,10 @@ Alongside each architecture, we include some popular models that use it. - Gemma - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc. - ✅︎ + * - :code:`Gemma2ForCausalLM` + - Gemma2 + - :code:`google/gemma-2-9b`, :code:`google/gemma-2-27b`, etc. + - ✅︎ * - :code:`GPT2LMHeadModel` - GPT-2 - :code:`gpt2`, :code:`gpt2-xl`, etc. diff --git a/requirements-common.txt b/requirements-common.txt index 05969cfa5d65f..636f85343e1f2 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -6,7 +6,7 @@ numpy < 2.0.0 requests tqdm py-cpuinfo -transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3. +transformers >= 4.42.0 # Required for Gemma 2. tokenizers >= 0.19.1 # Required for Llama 3. fastapi aiohttp diff --git a/vllm/config.py b/vllm/config.py index 0c4d770e46847..486e0e84c8347 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -14,7 +14,7 @@ from vllm.tracing import is_otel_installed from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu, - is_hip, is_neuron, is_tpu, is_xpu, + is_hip, is_neuron, is_tpu, is_xpu, print_warning_once, update_environment_variables) if TYPE_CHECKING: @@ -257,8 +257,17 @@ def verify_with_parallel_config( "BitAndBytes quantization with TP or PP is not supported yet.") def get_hf_config_sliding_window(self) -> Optional[int]: - """Get the sliding window size, or None if disabled. - """ + """Get the sliding window size, or None if disabled.""" + + if (self.hf_text_config.model_type == "gemma2" + and self.hf_text_config.sliding_window is not None): + print_warning_once( + "While Gemma 2 uses sliding window attention for every odd " + "layer, vLLM currently ignores it and uses global attention " + "for all layers. This might affect the model's behavior when " + "the context length is larger than the sliding window size " + f"({self.hf_text_config.sliding_window}).") + return None # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in # addition to sliding window size. We check if that field is present @@ -1252,10 +1261,16 @@ def _get_and_verify_dtype( dtype = dtype.lower() if dtype == "auto": if config_dtype == torch.float32: - # Following the common practice, we use float16 for float32 - # models. - logger.info("Casting torch.float32 to torch.float16.") - torch_dtype = torch.float16 + if config.model_type == "gemma2": + logger.info( + "For Gemma 2, we downcast float32 to bfloat16 instead " + "of float16 by default. Please specify `dtype` if you " + "want to use float16.") + torch_dtype = torch.bfloat16 + else: + # Following the common practice, we use float16 for float32 + # models. + torch_dtype = torch.float16 else: torch_dtype = config_dtype else: diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 14f5e2378a421..7a8699e3932cb 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -95,3 +95,49 @@ def extra_repr(self) -> str: s = f"hidden_size={self.weight.data.size(0)}" s += f", eps={self.variance_epsilon}" return s + + +class GemmaRMSNorm(CustomOp): + """RMS normalization for Gemma. + + Two differences from the above RMSNorm: + 1. x * (1 + w) instead of x * w. + 2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w. + """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + orig_dtype = x.dtype + if residual is not None: + x = x + residual + residual = x + + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + x = x * (1.0 + self.weight.float()) + x = x.to(orig_dtype) + return x if residual is None else (x, residual) + + def forward_cuda( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + # TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm. + return self.forward_native(x, residual) diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 7eee599473a11..8062bfb5194bc 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -22,7 +22,8 @@ def __init__(self, vocab_size: int, org_vocab_size: Optional[int] = None, scale: float = 1.0, - logits_as_input: bool = False) -> None: + logits_as_input: bool = False, + soft_cap: Optional[float] = None) -> None: """ Args: scale: A scaling factor to apply to the logits. @@ -34,6 +35,8 @@ def __init__(self, self.logits_as_input = logits_as_input # original vocabulary size (without LoRA). self.org_vocab_size = org_vocab_size or vocab_size + # Soft cap the logits. Used in Gemma 2. + self.soft_cap = soft_cap def forward( self, @@ -52,6 +55,11 @@ def forward( logits = self._get_logits(hidden_states, embedding, embedding_bias) if logits is not None: + if self.soft_cap is not None: + logits = logits / self.soft_cap + logits = torch.tanh(logits) + logits = logits * self.soft_cap + if self.scale != 1.0: logits *= self.scale diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index a0b19046b7491..9e53deef0fbf1 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -610,6 +610,16 @@ def forward( return query.flatten(-2), key.flatten(-2) +class GemmaRotaryEmbedding(RotaryEmbedding): + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107 + inv_freq = 1.0 / (base**( + torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() / + self.rotary_dim)) + return inv_freq + + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 5afb2e1d44d39..e7ced618c7be7 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -23,6 +23,7 @@ "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), + "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py new file mode 100644 index 0000000000000..e909090e4bef6 --- /dev/null +++ b/vllm/model_executor/models/gemma2.py @@ -0,0 +1,403 @@ +# coding=utf-8 +# Copyright 2024 The vLLM team. +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Iterable, List, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import Gemma2Config + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import GeluAndMul +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput +from vllm.utils import print_warning_once + +from .interfaces import SupportsLoRA + + +class Gemma2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + hidden_activation: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"): + raise ValueError( + "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation " + "function. Please set `hidden_act` and `hidden_activation` to " + "`gelu_pytorch_tanh`.") + self.act_fn = GeluAndMul(approximate="tanh") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Gemma2Attention(nn.Module): + + def __init__(self, + layer_idx: int, + config: Gemma2Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + rope_theta: float, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.query_pre_attn_scalar**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + ) + # TODO(woosuk): Use the `get_rope` interface. + self.rotary_emb = GemmaRotaryEmbedding( + self.head_dim, + self.head_dim, + max_position_embeddings, + base=self.rope_theta, + is_neox_style=True, + dtype=torch.get_default_dtype(), + ) + + if self.config.attn_logit_softcapping is not None: + print_warning_once( + "Gemma 2 normally uses attention logit soft-capping; " + "soft-capping is currently incompatible with the flash " + "attention kernels, so vLLM removes it to enable speed and " + "efficiency gains of flash attention.") + # FIXME(woosuk): While Gemma 2 uses sliding window attention for every + # odd layer, vLLM currently ignores it and uses global attention for + # all layers. + use_sliding_window = (layer_idx % 2 == 1 + and config.sliding_window is not None) + del use_sliding_window # Unused. + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Gemma2DecoderLayer(nn.Module): + + def __init__( + self, + layer_idx: int, + config: Gemma2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Gemma2Attention( + layer_idx=layer_idx, + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + rope_theta=config.rope_theta, + cache_config=cache_config, + quant_config=quant_config, + ) + self.hidden_size = config.hidden_size + self.mlp = Gemma2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + hidden_activation=config.hidden_activation, + quant_config=quant_config, + ) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states, residual = self.pre_feedforward_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + return hidden_states, residual + + +class Gemma2Model(nn.Module): + + def __init__( + self, + config: Gemma2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Normalize the embedding by sqrt(hidden_size) + # The normalizer's data type should be downcasted to the model's + # data type such as bfloat16, not float32. + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = self.config.hidden_size**0.5 + self.register_buffer("normalizer", torch.tensor(normalizer)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + hidden_states *= self.normalizer + + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i], + attn_metadata, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Gemma2ForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + # Gemma does not apply LoRA to the embedding layer. + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: Gemma2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + del lora_config # Unused. + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = Gemma2Model(config, cache_config, quant_config) + self.logits_processor = LogitsProcessor( + config.vocab_size, soft_cap=config.final_logit_softcapping) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.model.embed_tokens.weight, + hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # lm_head is not used in vllm as it is tied with embed_token. + # To prevent errors, skip loading lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + raise RuntimeError( + "Some weights are not initialized from checkpoints: " + f"{unloaded_params}") From df2c007974b09777259f5bc55ba4b94092f9f475 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 27 Jun 2024 16:11:25 +0000 Subject: [PATCH 2/5] Remove supports_lora=True --- vllm/model_executor/models/gemma2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index e909090e4bef6..4e35a9ec34069 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -293,8 +293,6 @@ def forward( class Gemma2ForCausalLM(nn.Module, SupportsLoRA): - supports_lora = True - packed_modules_mapping = { "qkv_proj": [ "q_proj", From a1768031919a7b6485a6850e108a7d336804e0db Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 27 Jun 2024 16:35:01 +0000 Subject: [PATCH 3/5] Move warning --- vllm/config.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 486e0e84c8347..84d2b52d5e358 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -141,6 +141,18 @@ def __init__( code_revision, rope_scaling, rope_theta) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + + if (not self.disable_sliding_window + and self.hf_text_config.model_type == "gemma2" + and self.hf_text_config.sliding_window is not None): + print_warning_once( + "While Gemma 2 uses sliding window attention for every odd " + "layer, vLLM currently ignores it and uses global attention " + "for all layers. This might affect the model's behavior when " + "the context length is larger than the sliding window size " + f"({self.hf_text_config.sliding_window}).") + self.disable_sliding_window = True + self.max_model_len = _get_and_verify_max_len( hf_config=self.hf_text_config, max_model_len=max_model_len, @@ -259,16 +271,6 @@ def verify_with_parallel_config( def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled.""" - if (self.hf_text_config.model_type == "gemma2" - and self.hf_text_config.sliding_window is not None): - print_warning_once( - "While Gemma 2 uses sliding window attention for every odd " - "layer, vLLM currently ignores it and uses global attention " - "for all layers. This might affect the model's behavior when " - "the context length is larger than the sliding window size " - f"({self.hf_text_config.sliding_window}).") - return None - # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in # addition to sliding window size. We check if that field is present # and if it's False, return None. From a1ddec810d27f25d04260c1af2c3813cf220bd4e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 27 Jun 2024 16:53:46 +0000 Subject: [PATCH 4/5] Fix warning msg --- vllm/config.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 84d2b52d5e358..903be2a16ee66 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -146,10 +146,9 @@ def __init__( and self.hf_text_config.model_type == "gemma2" and self.hf_text_config.sliding_window is not None): print_warning_once( - "While Gemma 2 uses sliding window attention for every odd " - "layer, vLLM currently ignores it and uses global attention " - "for all layers. This might affect the model's behavior when " - "the context length is larger than the sliding window size " + "Gemma 2 uses sliding window attention for every odd layer, " + "which is currently not supported by vLLM. Disabling sliding " + "window and capping the max length to the sliding window size " f"({self.hf_text_config.sliding_window}).") self.disable_sliding_window = True From 7fbcf48044c4993edbe7567f85316457e648191f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 27 Jun 2024 18:36:28 +0000 Subject: [PATCH 5/5] Add soft_cap to LoRA Logits Processor --- vllm/lora/layers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index e4a23273f7282..2fddfccaf1e4c 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1069,6 +1069,10 @@ def vocab_size(self): def scale(self): return self.base_layer.scale + @property + def soft_cap(self): + return self.base_layer.soft_cap + @property def org_vocab_size(self): return self.base_layer.org_vocab_size