From 2a2873bccdbe10d6d66f7436068b7c07c33ab590 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Thu, 13 Mar 2025 19:45:42 +0800
Subject: [PATCH 1/5] add gemma3 gguf support

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/model_executor/model_loader/loader.py | 7 ++++++-
 vllm/model_executor/models/gemma3.py       | 7 +++++++
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index c88af56e1805..77004435e79c 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -1271,12 +1271,15 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
         See "Standardized tensor names" in
         https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
         """
-        config = model_config.hf_config
+        config = copy.deepcopy(model_config.hf_config)
         model_type = config.model_type
         gguf_to_hf_name_map = {}
         # hack: ggufs have a different name than transformers
         if model_type == "cohere":
             model_type = "command-r"
+        # revert sliding_window modifications
+        if model_type == "gemma3_text":
+            model_type = "gemma3"
         if model_type in ("deepseek_v3", "deepseek_v2"):
             model_type = "deepseek2"
         # GGUF layer map assumes that we will have a merged expert weights
@@ -1290,6 +1293,8 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
                     f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
                 gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \
                     f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
+        if hasattr(config, "interleaved_sliding_window"):
+            config.sliding_window = config.interleaved_sliding_window

         arch = None
         for key, value in gguf.MODEL_ARCH_NAMES.items():
diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index f1ecf7fa821d..161a9b964eef 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -344,6 +344,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.embed_tokens",
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
@@ -423,6 +425,11 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight)
                 loaded_params.add(scale_name)
                 continue
+            if self.quant_config and self.quant_config.get_name() == "gguf" \
+                    and name.endswith("norm.weight"):
+                # Revert +1 during llama.cpp conversion
+                # see: https://github.com/ggml-org/llama.cpp/blob/be7c3034108473beda214fd1d7c98fd6a7a3bdf5/convert_hf_to_gguf.py#L3397-L3400
+                loaded_weight -= 1
             for (param_name, shard_name, shard_id) in stacked_params_mapping:
                 if shard_name not in name:
                     continue

From d22f91b988fa36a462d89526a4810e30b2817510 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Thu, 13 Mar 2025 23:34:13 +0800
Subject: [PATCH 2/5] add gemma2 and gemma3 GGUF

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/model_executor/models/gemma2.py | 15 ++++++++++++---
 vllm/model_executor/models/gemma3.py | 14 +++++++++++---
 2 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index cf744fc2b9d1..23f4543b7d4c 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -57,9 +57,9 @@ def __init__(
         self,
         hidden_size: int,
         intermediate_size: int,
-        hidden_act: str,
         hidden_activation: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -70,7 +70,7 @@ def __init__(
                                            hidden_size,
                                            bias=False,
                                            quant_config=quant_config)
-        if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
+        if not (hidden_activation == "gelu_pytorch_tanh"):
             raise ValueError(
                 "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
                 "function. Please set `hidden_act` and `hidden_activation` to "
@@ -128,12 +128,14 @@ def __init__(self,
             self.total_num_kv_heads,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -201,9 +203,9 @@ def __init__(
         self.mlp = Gemma2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
-            hidden_act=config.hidden_act,
             hidden_activation=config.hidden_activation,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                             eps=config.rms_norm_eps)
@@ -253,6 +255,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.embed_tokens",
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
@@ -319,6 +323,11 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
+            if self.quant_config and self.quant_config.get_name() == "gguf" \
+                    and name.endswith("norm.weight"):
+                # Revert +1 during llama.cpp conversion
+                # see: https://github.com/ggerganov/llama.cpp/blob/2e2f8f093cd4fb6bbb87ba84f6b9684fa082f3fa/convert_hf_to_gguf.py#L3313-L3315
+                loaded_weight -= 1
             if (self.quant_config is not None and
                 (scale_name := self.quant_config.get_cache_scale(name))):
                 # Loading kv cache scales for compressed-tensors quantization
diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index 161a9b964eef..badd82fcbade 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -59,16 +59,21 @@ def __init__(
         intermediate_size: int,
         hidden_activation: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
+            hidden_size,
+            [intermediate_size] * 2,
             bias=False,
-            quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
-                                           quant_config=quant_config)
+                                           quant_config=quant_config,
+                                           prefix=f"{prefix}.down_proj")
         if hidden_activation != "gelu_pytorch_tanh":
             raise ValueError(
                 "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
@@ -125,12 +130,14 @@ def __init__(self,
             self.total_num_kv_heads,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
@@ -293,6 +300,7 @@ def __init__(
             intermediate_size=config.intermediate_size,
             hidden_activation=config.hidden_activation,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                             eps=config.rms_norm_eps)

From a9851454d775bc3c9f09b3e09a3b60a5bee1b3d1 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 14 Mar 2025 00:52:46 +0800
Subject: [PATCH 3/5] loose gguf version

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 requirements/common.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/common.txt b/requirements/common.txt
index 13a06011e409..068c9e2f1a53 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -25,7 +25,7 @@ filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/31
 partial-json-parser # used for parsing partial JSON outputs
 pyzmq
 msgspec
-gguf == 0.10.0
+gguf >= 0.10.0
 importlib_metadata
 mistral_common[opencv] >= 1.5.0
 pyyaml

From c0c86c2e5665e077ecae3cfd81c95223f6f9cb7d Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 14 Mar 2025 01:17:25 +0800
Subject: [PATCH 4/5] fix text config

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/model_executor/model_loader/loader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 77004435e79c..29462e93a6ac 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -1271,7 +1271,7 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
         See "Standardized tensor names" in
         https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
         """
-        config = copy.deepcopy(model_config.hf_config)
+        config = copy.deepcopy(model_config.hf_text_config)
         model_type = config.model_type
         gguf_to_hf_name_map = {}
         # hack: ggufs have a different name than transformers

From 6f9adf613da104ba70bdae956869171717eda0e9 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 14 Mar 2025 01:47:00 +0800
Subject: [PATCH 5/5] handle text config

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/model_executor/model_loader/loader.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 29462e93a6ac..ec55b0aed5da 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -1328,6 +1328,14 @@ def download_model(self, model_config: ModelConfig) -> None:
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
+
+        # GGUF hasn't supported multimodal models yet, we need to
+        # extract text_config to only initialize the llm backbone
+        architectures = model_config.hf_config.architectures
+        vllm_config.model_config.hf_config = (
+            vllm_config.model_config.hf_text_config)
+        vllm_config.model_config.hf_config.architectures = architectures
+
         local_model_path = self._prepare_weights(model_config.model)
         gguf_weights_map = self._get_gguf_weights_map(model_config)
         # we can only know if tie word embeddings after mapping weights
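
A minimal sketch of how this series could be exercised once applied, following vLLM's
existing GGUF loading pattern (pass a local GGUF file as `model` and the original
Hugging Face repo as `tokenizer`); the file path, tokenizer repo, and sampling settings
below are placeholders and are not part of these patches:

    from vllm import LLM, SamplingParams

    # Placeholder path to a local Gemma 3 GGUF file; the tokenizer still comes
    # from the original Hugging Face repo, as with other GGUF models in vLLM.
    llm = LLM(model="./gemma-3-4b-it-Q4_K_M.gguf",
              tokenizer="google/gemma-3-4b-it")

    outputs = llm.generate(["Why is the sky blue?"],
                           SamplingParams(temperature=0.0, max_tokens=64))
    print(outputs[0].outputs[0].text)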