Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/31
partial-json-parser # used for parsing partial JSON outputs
pyzmq
msgspec
gguf == 0.10.0
gguf >= 0.10.0
importlib_metadata
mistral_common[opencv] >= 1.5.0
pyyaml
Expand Down
15 changes: 14 additions & 1 deletion vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1271,12 +1271,15 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
See "Standardized tensor names" in
https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
"""
config = model_config.hf_config
config = copy.deepcopy(model_config.hf_text_config)
model_type = config.model_type
gguf_to_hf_name_map = {}
# hack: ggufs have a different name than transformers
if model_type == "cohere":
model_type = "command-r"
# revert sliding_window modifications
if model_type == "gemma3_text":
model_type = "gemma3"
if model_type in ("deepseek_v3", "deepseek_v2"):
model_type = "deepseek2"
# GGUF layer map assumes that we will have a merged expert weights
Expand All @@ -1290,6 +1293,8 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \
f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
if hasattr(config, "interleaved_sliding_window"):
config.sliding_window = config.interleaved_sliding_window

arch = None
for key, value in gguf.MODEL_ARCH_NAMES.items():
Expand Down Expand Up @@ -1323,6 +1328,14 @@ def download_model(self, model_config: ModelConfig) -> None:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config

# GGUF hasn't supported multimodal models yet, we need to
# extract text_config to only initialize the llm backbone
architectures = model_config.hf_config.architectures
vllm_config.model_config.hf_config = (
vllm_config.model_config.hf_text_config)
vllm_config.model_config.hf_config.architectures = architectures

local_model_path = self._prepare_weights(model_config.model)
gguf_weights_map = self._get_gguf_weights_map(model_config)
# we can only know if tie word embeddings after mapping weights
Expand Down
15 changes: 12 additions & 3 deletions vllm/model_executor/models/gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
hidden_activation: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
Expand All @@ -70,7 +70,7 @@ def __init__(
hidden_size,
bias=False,
quant_config=quant_config)
if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
if not (hidden_activation == "gelu_pytorch_tanh"):
raise ValueError(
"Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
"function. Please set `hidden_act` and `hidden_activation` to "
Expand Down Expand Up @@ -128,12 +128,14 @@ def __init__(self,
self.total_num_kv_heads,
bias=config.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=config.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
Expand Down Expand Up @@ -201,9 +203,9 @@ def __init__(
self.mlp = Gemma2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
hidden_activation=config.hidden_activation,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
Expand Down Expand Up @@ -253,6 +255,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
Expand Down Expand Up @@ -319,6 +323,11 @@ def load_weights(self, weights: Iterable[Tuple[str,
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if self.quant_config and self.quant_config.get_name() == "gguf" \
and name.endswith("norm.weight"):
# Revert +1 during llama.cpp conversion
# see: https://github.com/ggerganov/llama.cpp/blob/2e2f8f093cd4fb6bbb87ba84f6b9684fa082f3fa/convert_hf_to_gguf.py#L3313-L3315
loaded_weight -= 1
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for compressed-tensors quantization
Expand Down
21 changes: 18 additions & 3 deletions vllm/model_executor/models/gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,21 @@ def __init__(
intermediate_size: int,
hidden_activation: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
if hidden_activation != "gelu_pytorch_tanh":
raise ValueError(
"Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
Expand Down Expand Up @@ -125,12 +130,14 @@ def __init__(self,
self.total_num_kv_heads,
bias=config.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=config.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)

self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
Expand Down Expand Up @@ -293,6 +300,7 @@ def __init__(
intermediate_size=config.intermediate_size,
hidden_activation=config.hidden_activation,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
Expand Down Expand Up @@ -344,6 +352,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
Expand Down Expand Up @@ -423,6 +433,11 @@ def load_weights(self, weights: Iterable[Tuple[str,
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if self.quant_config and self.quant_config.get_name() == "gguf" \
and name.endswith("norm.weight"):
# Revert +1 during llama.cpp conversion
# see: https://github.com/ggml-org/llama.cpp/blob/be7c3034108473beda214fd1d7c98fd6a7a3bdf5/convert_hf_to_gguf.py#L3397-L3400
loaded_weight -= 1
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
Expand Down