diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index 132c325ce591..c350aaf5d3ad 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -576,7 +576,11 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
- elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
+ elif config.architectures[0] in (
+ "DeepseekV3ForCausalLM",
+ "DeepseekV2ForCausalLM",
+ "Glm4MoeForCausalLM",
+ ):
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py
index dba1f3943b96..4ed690090144 100644
--- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py
+++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py
@@ -318,6 +318,7 @@ def main(args: argparse.Namespace):
elif (
config.architectures[0] == "DeepseekV3ForCausalLM"
or config.architectures[0] == "DeepseekV2ForCausalLM"
+ or config.architectures[0] == "Glm4MoeForCausalLM"
):
E = config.n_routed_experts
topk = config.num_experts_per_tok
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 887f754a3d1c..b0be947a66f9 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -577,6 +577,7 @@ Specified using `--task generate`.
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
| `GLM4VForCausalLM`^ | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `THUDM/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Glm4MoeForCausalLM` | GLM-4.5 | T | `THUDM/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 5c546a6c86da..c537438ad6fb 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -364,6 +364,9 @@ def check_available_online(
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
"Glm4vForConditionalGeneration": _HfExamplesInfo("THUDM/GLM-4.1V-9B-Thinking", min_transformers_version="4.53"), # noqa: E501
+ "Glm4MoeForCausalLM": _HfExamplesInfo("THUDM/GLM-4.5",
+ min_transformers_version="4.54",
+ is_available_online=False), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
max_transformers_version="4.48", # noqa: E501
@@ -489,6 +492,10 @@ def check_available_online(
is_available_online=False,
speculative_model="openbmb/MiniCPM-2B-sft-bf16",
tokenizer="openbmb/MiniCPM-2B-sft-bf16"),
+ "Glm4MoeMTPModel": _HfExamplesInfo("THUDM/GLM-4.5",
+ speculative_model="THUDM/GLM-4.5",
+ min_transformers_version="4.54",
+ is_available_online=False),
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL")
diff --git a/tests/tool_use/test_glm4_moe_tool_parser.py b/tests/tool_use/test_glm4_moe_tool_parser.py
new file mode 100644
index 000000000000..478f4b916672
--- /dev/null
+++ b/tests/tool_use/test_glm4_moe_tool_parser.py
@@ -0,0 +1,410 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# ruff: noqa: E501
+
+import json
+
+import pytest
+
+from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
+from vllm.entrypoints.openai.tool_parsers import Glm4MoeModelToolParser
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+pytest.skip("skip glm4_moe parser test", allow_module_level=True)
+# Tokenizer repo that provides the GLM-4.5 tool-call special tokens
+MODEL = "THUDM/GLM-4.5"
+
+
+@pytest.fixture(scope="module")
+def glm4_moe_tokenizer():
+ return get_tokenizer(tokenizer_name=MODEL)
+
+
+@pytest.fixture
+def glm4_moe_tool_parser(glm4_moe_tokenizer):
+ return Glm4MoeModelToolParser(glm4_moe_tokenizer)
+
+
+def assert_tool_calls(actual_tool_calls: list[ToolCall],
+ expected_tool_calls: list[ToolCall]):
+ assert len(actual_tool_calls) == len(expected_tool_calls)
+
+ for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
+ expected_tool_calls):
+ assert isinstance(actual_tool_call.id, str)
+ assert len(actual_tool_call.id) > 0
+
+ assert actual_tool_call.type == "function"
+ assert actual_tool_call.function.name == expected_tool_call.function.name
+ # Compare arguments as JSON objects to handle formatting differences
+ actual_args = json.loads(actual_tool_call.function.arguments)
+ expected_args = json.loads(expected_tool_call.function.arguments)
+ assert actual_args == expected_args
+
+
+def test_extract_tool_calls_no_tools(glm4_moe_tool_parser):
+ model_output = "This is a test"
+ extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
+ model_output, request=None) # type: ignore[arg-type]
+ assert not extracted_tool_calls.tools_called
+ assert extracted_tool_calls.tool_calls == []
+ assert extracted_tool_calls.content == model_output
+
+
+@pytest.mark.parametrize(
+ ids=[
+ "single_tool_call",
+ "multiple_tool_calls",
+ "tool_call_with_content_before",
+ "tool_call_with_mixed_args",
+ "tool_call_with_chinese_content",
+ ],
+ argnames=["model_output", "expected_tool_calls", "expected_content"],
+ argvalues=[
+ (
+            """<tool_call>get_current_weather
+<arg_key>city</arg_key>
+<arg_value>Dallas</arg_value>
+<arg_key>state</arg_key>
+<arg_value>TX</arg_value>
+<arg_key>unit</arg_key>
+<arg_value>fahrenheit</arg_value>
+</tool_call>""",
+ [
+ ToolCall(function=FunctionCall(
+ name="get_current_weather",
+ arguments=json.dumps({
+ "city": "Dallas",
+ "state": "TX",
+ "unit": "fahrenheit",
+ }),
+ ))
+ ],
+ None,
+ ),
+ (
+            """<tool_call>get_current_weather
+<arg_key>city</arg_key>
+<arg_value>Dallas</arg_value>
+<arg_key>state</arg_key>
+<arg_value>TX</arg_value>
+<arg_key>unit</arg_key>
+<arg_value>fahrenheit</arg_value>
+</tool_call>
+<tool_call>get_current_weather
+<arg_key>city</arg_key>
+<arg_value>Orlando</arg_value>
+<arg_key>state</arg_key>
+<arg_value>FL</arg_value>
+<arg_key>unit</arg_key>
+<arg_value>fahrenheit</arg_value>
+</tool_call>""",
+ [
+ ToolCall(function=FunctionCall(
+ name="get_current_weather",
+ arguments=json.dumps({
+ "city": "Dallas",
+ "state": "TX",
+ "unit": "fahrenheit",
+ }),
+ )),
+ ToolCall(function=FunctionCall(
+ name="get_current_weather",
+ arguments=json.dumps({
+ "city": "Orlando",
+ "state": "FL",
+ "unit": "fahrenheit",
+ }),
+ )),
+ ],
+ None,
+ ),
+ (
+            """I'll help you check the weather. <tool_call>get_current_weather
+<arg_key>city</arg_key>
+<arg_value>Seattle</arg_value>
+<arg_key>state</arg_key>
+<arg_value>WA</arg_value>
+<arg_key>unit</arg_key>
+<arg_value>celsius</arg_value>
+</tool_call>""",
+ [
+ ToolCall(function=FunctionCall(
+ name="get_current_weather",
+ arguments=json.dumps({
+ "city": "Seattle",
+ "state": "WA",
+ "unit": "celsius",
+ }),
+ ))
+ ],
+ "I'll help you check the weather.",
+ ),
+ (
+            """<tool_call>get_current_weather
+<arg_key>city</arg_key>
+<arg_value>New York</arg_value>
+<arg_key>state</arg_key>
+<arg_value>NY</arg_value>
+<arg_key>unit</arg_key>
+<arg_value>celsius</arg_value>
+</tool_call>""",
+ [
+ ToolCall(function=FunctionCall(
+ name="get_current_weather",
+ arguments=json.dumps({
+ "city": "New York",
+ "state": "NY",
+ "unit": "celsius",
+ }),
+ ))
+ ],
+ None,
+ ),
+        ("""I will help you get the weather.<tool_call>get_weather
+<arg_key>city</arg_key>
+<arg_value>Beijing</arg_value>
+<arg_key>date</arg_key>
+<arg_value>2025-08-01</arg_value>
+</tool_call>""", [
+ ToolCall(function=FunctionCall(
+ name="get_weather",
+ arguments=json.dumps({
+ "city": "Beijing",
+ "date": "2025-08-01",
+ }),
+ ))
+ ], "I will help you get the weather."),
+ ],
+)
+def test_extract_tool_calls(glm4_moe_tool_parser, model_output,
+ expected_tool_calls, expected_content):
+ extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
+ model_output, request=None) # type: ignore[arg-type]
+ assert extracted_tool_calls.tools_called
+ assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
+
+ assert extracted_tool_calls.content == expected_content
+
+
+def test_extract_tool_calls_with_thinking_tags(glm4_moe_tool_parser):
+ """Test tool extraction when thinking tags are present."""
+    model_output = """<think>I want to get the weather.</think>
+
+I will help you get the weather.
+<tool_call>get_weather
+<arg_key>city</arg_key>
+<arg_value>Beijing</arg_value>
+<arg_key>date</arg_key>
+<arg_value>2025-08-01</arg_value>
+</tool_call>"""
+
+ extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
+ model_output, request=None) # type: ignore[arg-type]
+
+ assert extracted_tool_calls.tools_called
+ assert len(extracted_tool_calls.tool_calls) == 1
+ assert extracted_tool_calls.tool_calls[0].function.name == "get_weather"
+
+    expected_content = """<think>I want to get the weather.</think>
+
+I will help you get the weather."""
+ assert extracted_tool_calls.content == expected_content
+
+
+def test_extract_tool_calls_malformed_xml(glm4_moe_tool_parser):
+ """Test that malformed XML is handled gracefully."""
+    model_output = """<tool_call>get_weather
+<arg_key>city</arg_key>
+<arg_value>Seattle</arg_value>
+<arg_key>incomplete_arg
+<arg_value>value</arg_value>
+</tool_call>"""
+
+ extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
+ model_output, request=None) # type: ignore[arg-type]
+
+ # Should handle malformed XML gracefully
+ # The parser should either extract what it can or return no tool calls
+ # depending on how robust we want the parsing to be
+ assert isinstance(extracted_tool_calls.tools_called, bool)
+ assert isinstance(extracted_tool_calls.tool_calls, list)
+
+
+def test_extract_tool_calls_empty_arguments(glm4_moe_tool_parser):
+ """Test tool calls with no arguments."""
+    model_output = """<tool_call>get_current_time
+</tool_call>"""
+
+ extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
+ model_output, request=None) # type: ignore[arg-type]
+
+ assert extracted_tool_calls.tools_called
+ assert len(extracted_tool_calls.tool_calls) == 1
+ assert extracted_tool_calls.tool_calls[
+ 0].function.name == "get_current_time"
+ # Empty arguments should result in empty JSON object
+ assert extracted_tool_calls.tool_calls[0].function.arguments == "{}"
+
+
+def test_extract_tool_calls_mixed_content(glm4_moe_tool_parser):
+ """Test extraction with mixed content and multiple tool calls."""
+    model_output = """I will help you get the weather info.
+
+<tool_call>get_weather
+<arg_key>city</arg_key>
+<arg_value>Beijing</arg_value>
+<arg_key>date</arg_key>
+<arg_value>2025-08-01</arg_value>
+</tool_call>
+
+Meanwhile, I will also check the weather in Shanghai.
+
+<tool_call>get_weather
+<arg_key>city</arg_key>
+<arg_value>Shanghai</arg_value>
+<arg_key>date</arg_key>
+<arg_value>2025-08-01</arg_value>
+</tool_call>"""
+
+ extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
+ model_output, request=None) # type: ignore[arg-type]
+
+ assert extracted_tool_calls.tools_called
+ assert len(extracted_tool_calls.tool_calls) == 2
+
+ # Check first tool call
+ assert extracted_tool_calls.tool_calls[0].function.name == "get_weather"
+ args1 = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
+ assert args1["city"] == "Beijing"
+ assert args1["date"] == "2025-08-01"
+
+ # Check second tool call
+ assert extracted_tool_calls.tool_calls[1].function.name == "get_weather"
+ args2 = json.loads(extracted_tool_calls.tool_calls[1].function.arguments)
+ assert args2["city"] == "Shanghai"
+ assert args2["date"] == "2025-08-01"
+
+ # Content should be everything before the first tool call
+ assert extracted_tool_calls.content == "I will help you get the weather info."
+
+
+def test_streaming_basic_functionality(glm4_moe_tool_parser):
+ """Test basic streaming functionality."""
+ # Reset streaming state
+ glm4_moe_tool_parser.current_tool_name_sent = False
+ glm4_moe_tool_parser.prev_tool_call_arr = []
+ glm4_moe_tool_parser.current_tool_id = -1
+ glm4_moe_tool_parser.streamed_args_for_tool = []
+
+ # Test with a simple tool call
+    current_text = """<tool_call>get_weather
+<arg_key>city</arg_key>
+<arg_value>Beijing</arg_value>
+</tool_call>"""
+
+ # Mock token IDs for testing
+ tool_call_start_id = glm4_moe_tool_parser.tool_call_start_token_id or 12345
+ tool_call_end_id = glm4_moe_tool_parser.tool_call_end_token_id or 12346
+
+ result = glm4_moe_tool_parser.extract_tool_calls_streaming(
+ previous_text="",
+ current_text=current_text,
+ delta_text="",
+ previous_token_ids=[],
+ current_token_ids=[tool_call_start_id, tool_call_end_id],
+ delta_token_ids=[tool_call_end_id],
+ request=None,
+ )
+
+ # The result behavior depends on the streaming state
+ # This test mainly ensures no exceptions are thrown
+ assert result is None or hasattr(result, 'tool_calls') or hasattr(
+ result, 'content')
+
+
+def test_streaming_no_tool_calls(glm4_moe_tool_parser):
+ """Test streaming when there are no tool calls."""
+ current_text = "This is just regular text without any tool calls."
+
+ result = glm4_moe_tool_parser.extract_tool_calls_streaming(
+ previous_text="This is just regular text",
+ current_text=current_text,
+ delta_text=" without any tool calls.",
+ previous_token_ids=[],
+ current_token_ids=[],
+ delta_token_ids=[],
+ request=None,
+ )
+
+ # Should return the delta text as content
+ assert result is not None
+ assert hasattr(result, 'content')
+ assert result.content == " without any tool calls."
+
+
+def test_streaming_with_content_before_tool_calls(glm4_moe_tool_parser):
+ """Test streaming when there's content before tool calls."""
+ # Reset streaming state
+ glm4_moe_tool_parser.current_tool_name_sent = False
+ glm4_moe_tool_parser.prev_tool_call_arr = []
+ glm4_moe_tool_parser.current_tool_id = -1
+ glm4_moe_tool_parser.streamed_args_for_tool = []
+
+ current_text = "I will help you get the weather"
+
+ result = glm4_moe_tool_parser.extract_tool_calls_streaming(
+ previous_text="I will help you",
+ current_text=current_text,
+ delta_text="get the weather.",
+ previous_token_ids=[],
+ current_token_ids=[],
+ delta_token_ids=[],
+ request=None,
+ )
+
+ # Should return content when no tool call tokens are detected
+ assert result is not None
+ assert hasattr(result, 'content')
+ assert result.content == "get the weather."
+
+
+def test_extract_tool_calls_special_characters(glm4_moe_tool_parser):
+ """Test tool calls with special characters and unicode."""
+    model_output = """<tool_call>send_message
+<arg_key>recipient</arg_key>
+<arg_value>Amy</arg_value>
+<arg_key>message</arg_key>
+<arg_value>It's a nice day! 😊</arg_value>
+<arg_key>priority</arg_key>
+<arg_value>high</arg_value>
+</tool_call>"""
+
+ extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
+ model_output, request=None) # type: ignore[arg-type]
+
+ assert extracted_tool_calls.tools_called
+ assert len(extracted_tool_calls.tool_calls) == 1
+ assert extracted_tool_calls.tool_calls[0].function.name == "send_message"
+
+ args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
+ assert args["recipient"] == "Amy"
+    assert args["message"] == "It's a nice day! 😊"
+ assert args["priority"] == "high"
+
+
+def test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser):
+ """Test incomplete tool calls (missing closing tag)."""
+    model_output = """<tool_call>get_weather
+<arg_key>city</arg_key>
+<arg_value>Beijing</arg_value>
+<arg_key>date</arg_key>
+<arg_value>2025-08-01</arg_value>"""
+
+ extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
+ model_output, request=None) # type: ignore[arg-type]
+
+ # Incomplete tool calls should not be extracted
+ assert not extracted_tool_calls.tools_called
+ assert extracted_tool_calls.tool_calls == []
+ assert extracted_tool_calls.content == model_output
diff --git a/vllm/config.py b/vllm/config.py
index 384cb584fa9a..e92c501012a4 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1333,7 +1333,8 @@ def get_layers_start_end_indices(
self, parallel_config: "ParallelConfig") -> tuple[int, int]:
from vllm.distributed.utils import get_pp_indices
if (self.hf_text_config.model_type == "deepseek_mtp"
- or self.hf_config.model_type == "mimo_mtp"):
+ or self.hf_config.model_type == "mimo_mtp"
+ or self.hf_config.model_type == "glm4_moe_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
else:
@@ -2663,7 +2664,15 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
"n_predict": n_predict,
"architectures": ["MiMoMTPModel"]
})
- return hf_config
+
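+        # GLM-4.5 stores its MTP (multi-token prediction) weights in the same
+        # checkpoint as the base model; expose them as a draft model whose
+        # number of speculative layers comes from `num_nextn_predict_layers`.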
+ if hf_config.architectures[0] == "Glm4MoeForCausalLM":
+ hf_config.model_type = "glm4_moe_mtp"
+ n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
+ hf_config.update({
+ "num_hidden_layers": 0,
+ "n_predict": n_predict,
+ "architectures": ["Glm4MoeMTPModel"]
+ })
return hf_config
@@ -2774,7 +2783,7 @@ def __post_init__(self):
"mlp_speculator"):
self.method = "mlp_speculator"
elif (self.draft_model_config.hf_config.model_type
- in ("deepseek_mtp", "mimo_mtp")):
+ in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
self.method = "deepseek_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py
index 137375b9707c..9eda7155f01f 100644
--- a/vllm/entrypoints/openai/tool_parsers/__init__.py
+++ b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -3,6 +3,7 @@
from .abstract_tool_parser import ToolParser, ToolParserManager
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
+from .glm4_moe_tool_parser import Glm4MoeModelToolParser
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
from .granite_tool_parser import GraniteToolParser
from .hermes_tool_parser import Hermes2ProToolParser
@@ -19,10 +20,22 @@
from .xlam_tool_parser import xLAMToolParser
__all__ = [
- "ToolParser", "ToolParserManager", "Granite20bFCToolParser",
- "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
- "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
- "Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
- "DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser",
- "KimiK2ToolParser", "HunyuanA13BToolParser"
+ "ToolParser",
+ "ToolParserManager",
+ "Granite20bFCToolParser",
+ "GraniteToolParser",
+ "Hermes2ProToolParser",
+ "MistralToolParser",
+ "Internlm2ToolParser",
+ "Llama3JsonToolParser",
+ "JambaToolParser",
+ "Llama4PythonicToolParser",
+ "PythonicToolParser",
+ "Phi4MiniJsonToolParser",
+ "DeepSeekV3ToolParser",
+ "xLAMToolParser",
+ "MinimaxToolParser",
+ "KimiK2ToolParser",
+ "HunyuanA13BToolParser",
+ "Glm4MoeModelToolParser",
]
diff --git a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py
new file mode 100644
index 000000000000..c3f9d7923575
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py
@@ -0,0 +1,402 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# code modified from deepseekv3_tool_parser.py
+
+import json
+from collections.abc import Sequence
+from typing import Union
+
+import regex as re
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaFunctionCall, DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+ ToolParser, ToolParserManager)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module("glm4_moe")
+class Glm4MoeModelToolParser(ToolParser):
+
+ def __init__(self, tokenizer: AnyTokenizer):
+ super().__init__(tokenizer)
+ self.current_tool_name_sent = False
+ self.prev_tool_call_arr: list[dict] = []
+ self.current_tool_id = -1
+ self.streamed_args_for_tool: list[str] = []
+        self.tool_call_start_token = "<tool_call>"
+        self.tool_call_end_token = "</tool_call>"
+
+ self.tool_calls_start_token = self.tool_call_start_token
+
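+        # GLM-4.5 emits tool calls in an XML-style format, e.g.:
+        #   <tool_call>get_weather
+        #   <arg_key>city</arg_key>
+        #   <arg_value>Beijing</arg_value>
+        #   </tool_call>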
+        # Regex for one complete <tool_call> block
+        self.tool_call_regex = re.compile(
+            r"<tool_call>\s*"
+            # function name: everything up to the first newline or '<'
+            r"(?P<function_name>[^\n<]+)\s*"
+            r"(?P<arguments>(?:<arg_key>[^<]+</arg_key>\s*"
+            r"<arg_value>[^<]*</arg_value>\s*)*)\s*"
+            r"</tool_call>",
+            re.DOTALL,
+        )
+
+ # Regex for parsing individual arguments
+        self.arg_regex = re.compile(
+            r"<arg_key>(?P<key>[^<]+)</arg_key>\s*"
+            r"<arg_value>(?P<value>[^<]*)</arg_value>",
+            re.DOTALL,
+        )
+
+ # Streaming regex
+        self.stream_tool_call_portion_regex = re.compile(
+            r"(?P<function_name>[^\n<]+)\s*"
+            r"(?P<arguments>(?:<arg_key>[^<]+</arg_key>\s*"
+            r"<arg_value>[^<]*</arg_value>\s*)*)",
+            re.DOTALL,
+        )
+
+ # For streaming, we also need a regex to match just the function name
+        self.stream_tool_call_name_regex = re.compile(
+            r"(?P<function_name>[^\n<]+)",
+            re.DOTALL,
+        )
+
+ if not self.model_tokenizer:
+ raise ValueError(
+ "The model tokenizer must be passed to the ToolParser "
+ "constructor during construction.")
+
+ self.tool_call_start_token_id = self.vocab.get(
+ self.tool_call_start_token)
+ self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
+
+ def _parse_arguments(self, args_text: str) -> str:
+ """Parse XML-based arguments into JSON format."""
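+        # e.g. "<arg_key>city</arg_key><arg_value>Beijing</arg_value>"
+        # becomes '{"city": "Beijing"}'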
+ if not args_text or not args_text.strip():
+ return "{}"
+
+ args_dict = {}
+ matches = self.arg_regex.findall(args_text)
+
+ for key, value in matches:
+ args_dict[key.strip()] = value.strip()
+
+        return json.dumps(args_dict, ensure_ascii=False)
+
+ def extract_tool_calls(
+ self,
+ model_output: str,
+ request: ChatCompletionRequest,
+ ) -> ExtractedToolCallInformation:
+
+ # sanity check; avoid unnecessary processing
+ if self.tool_calls_start_token not in model_output:
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ try:
+ # Find all tool calls in the output
+ function_call_matches = self.tool_call_regex.findall(model_output)
+
+ logger.debug("function_call_matches: %s", function_call_matches)
+
+ if not function_call_matches:
+ return ExtractedToolCallInformation(
+ tools_called=False,
+ tool_calls=[],
+ content=model_output,
+ )
+
+ tool_calls = []
+ for i, match in enumerate(function_call_matches):
+ function_name, function_args_xml = match
+ function_name = function_name.strip()
+
+ # Parse XML arguments to JSON
+ function_args_json = self._parse_arguments(function_args_xml)
+
+ tool_calls.append(
+ ToolCall(
+ id=f"call_{i}",
+ type='function',
+ function=FunctionCall(name=function_name,
+ arguments=function_args_json),
+ ))
+
+ # Extract content before the first tool call
+ content = model_output[:model_output.find(self.
+ tool_calls_start_token)]
+ return ExtractedToolCallInformation(
+ tools_called=bool(tool_calls),
+ tool_calls=tool_calls,
+ content=content.strip() if content.strip() else None,
+ )
+
+ except Exception:
+ logger.exception("Error in extracting tool call from response.")
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> Union[DeltaMessage, None]:
+
+ logger.debug("delta_text: %s", delta_text)
+ logger.debug("delta_token_ids: %s", delta_token_ids)
+        # check to see if we should be streaming a tool call - is there a
+        # tool-call start token anywhere in the output generated so far?
+ if self.tool_call_start_token_id not in current_token_ids:
+ logger.debug("No tool call tokens found!")
+ return DeltaMessage(content=delta_text)
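+        # strip the tool-call marker tokens so they are never streamed back
+        # as part of the content or the argument text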
+ delta_text = delta_text.replace(self.tool_calls_start_token,
+ "").replace(self.tool_call_end_token,
+ "")
+ try:
+
+ # figure out where we are in the parsing by counting tool call
+ # start & end tags
+ prev_tool_start_count = previous_token_ids.count(
+ self.tool_call_start_token_id)
+ prev_tool_end_count = previous_token_ids.count(
+ self.tool_call_end_token_id)
+ cur_tool_start_count = current_token_ids.count(
+ self.tool_call_start_token_id)
+ cur_tool_end_count = current_token_ids.count(
+ self.tool_call_end_token_id)
+ tool_call_portion = None
+ text_portion = None
+
+ # case: if we're generating text, OR rounding out a tool call
+ if (cur_tool_start_count == cur_tool_end_count
+ and prev_tool_end_count == cur_tool_end_count
+ and self.tool_call_end_token not in delta_text):
+ logger.debug("Generating text content! skipping tool parsing.")
+ return DeltaMessage(content=delta_text)
+
+ if self.tool_call_end_token in delta_text:
+ logger.debug("tool_call_end_token in delta_text")
+ full_text = current_text + delta_text
+ tool_call_portion = full_text.split(
+ self.tool_call_start_token)[-1].split(
+ self.tool_call_end_token)[0].rstrip()
+ delta_text = delta_text.split(
+ self.tool_call_end_token)[0].rstrip()
+ text_portion = delta_text.split(
+ self.tool_call_end_token)[-1].lstrip()
+
+ # case -- we're starting a new tool call
+ if (cur_tool_start_count > cur_tool_end_count
+ and cur_tool_start_count > prev_tool_start_count):
+ if len(delta_token_ids) > 1:
+ tool_call_portion = current_text.split(
+ self.tool_call_start_token)[-1]
+ else:
+ tool_call_portion = None
+ delta = None
+
+ text_portion = None
+
+ # set cursors and state appropriately
+ self.current_tool_id += 1
+ self.current_tool_name_sent = False
+ self.streamed_args_for_tool.append("")
+ logger.debug("Starting on a new tool %s", self.current_tool_id)
+
+ # case -- we're updating an existing tool call
+ elif (cur_tool_start_count > cur_tool_end_count
+ and cur_tool_start_count == prev_tool_start_count):
+
+ # get the portion of the text that's the tool call
+ tool_call_portion = current_text.split(
+ self.tool_call_start_token)[-1]
+ text_portion = None
+
+ # case -- the current tool call is being closed.
+ elif (cur_tool_start_count == cur_tool_end_count
+ and cur_tool_end_count >= prev_tool_end_count):
+ if self.prev_tool_call_arr is None or len(
+ self.prev_tool_call_arr) == 0:
+ logger.debug(
+ "attempting to close tool call, but no tool call")
+ return None
+ diff = self.prev_tool_call_arr[self.current_tool_id].get(
+ "arguments")
+ if diff:
+ diff = (diff.encode("utf-8").decode("unicode_escape")
+ if diff is str else diff)
+ if '"}' not in delta_text:
+ return None
+ end_loc = delta_text.rindex('"}')
+ diff = delta_text[:end_loc] + '"}'
+ logger.debug(
+ "Finishing tool and found diff that had not "
+ "been streamed yet: %s",
+ diff,
+ )
+ self.streamed_args_for_tool[self.current_tool_id] += diff
+ return DeltaMessage(tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=diff).model_dump(exclude_none=True),
+ )
+ ])
+
+ # case -- otherwise we're just generating text
+ else:
+ text = delta_text.replace(self.tool_call_start_token, "")
+ text = text.replace(self.tool_call_end_token, "")
+ delta = DeltaMessage(tool_calls=[], content=text)
+ return delta
+
+ current_tool_call = dict()
+ if tool_call_portion:
+ current_tool_call_matches = (
+ self.stream_tool_call_portion_regex.match(
+ tool_call_portion))
+ if current_tool_call_matches:
+                    tool_name, tool_args = current_tool_call_matches.groups()
+                    current_tool_call["id"] = f"call_{self.current_tool_id}"
+                    current_tool_call["name"] = tool_name.strip()
+                    current_tool_call["arguments"] = tool_args
+                else:
+                    current_tool_call_name_matches = (
+                        self.stream_tool_call_name_regex.match(
+                            tool_call_portion))
+                    if current_tool_call_name_matches:
+                        tool_name, = current_tool_call_name_matches.groups()
+                        current_tool_call["id"] = (
+                            f"call_{self.current_tool_id}")
+                        current_tool_call["name"] = tool_name.strip()
+                        current_tool_call["arguments"] = ""
+                    else:
+                        logger.debug(
+                            "Not enough text yet to parse a tool call")
+                        return None
+
+ # case - we haven't sent the tool name yet. If it's available, send
+ # it. otherwise, wait until it's available.
+ if not self.current_tool_name_sent:
+ if current_tool_call is None:
+ return None
+ function_name: Union[str, None] = current_tool_call.get("name")
+ tool_id = current_tool_call.get("id")
+ if function_name:
+ self.current_tool_name_sent = True
+ return DeltaMessage(tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ type="function",
+ id=tool_id,
+ function=DeltaFunctionCall(
+ name=function_name).model_dump(
+ exclude_none=True),
+ )
+ ])
+ else:
+ return None
+
+ # case -- otherwise, send the tool call delta
+
+ # if the tool call portion is None, send the delta as text
+ if tool_call_portion is None:
+ # if there's text but not tool calls, send that -
+ # otherwise None to skip chunk
+ delta = (DeltaMessage(
+ content=delta_text) if text_portion is not None else None)
+ return delta
+
+ # now, the nitty-gritty of tool calls
+ # now we have the portion to parse as tool call.
+
+ logger.debug("Trying to parse current tool call with ID %s",
+ self.current_tool_id)
+
+ # if we're starting a new tool call, push an empty object in as
+ # a placeholder for the arguments
+ if len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+
+ # main logic for tool parsing here - compare prev. partially-parsed
+ # JSON to the current partially-parsed JSON
+ prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
+ "arguments")
+ cur_arguments = current_tool_call.get("arguments")
+
+ logger.debug("diffing old arguments: %s", prev_arguments)
+ logger.debug("against new ones: %s", cur_arguments)
+
+ # case -- no arguments have been created yet. skip sending a delta.
+ if not cur_arguments and not prev_arguments:
+ logger.debug("Skipping text %s - no arguments", delta_text)
+ delta = None
+
+            # case -- prev arguments are defined, but none are now.
+ # probably impossible, but not a fatal error - just keep going
+ elif not cur_arguments and prev_arguments:
+ logger.error("should be impossible to have arguments reset "
+ "mid-call. skipping streaming anything.")
+ delta = None
+
+ # case -- we now have the first info about arguments available from
+ # autocompleting the JSON
+ elif cur_arguments and not prev_arguments:
+
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=cur_arguments).model_dump(
+ exclude_none=True),
+ )
+ ])
+ self.streamed_args_for_tool[
+ self.current_tool_id] = cur_arguments
+
+ # last case -- we have an update to existing arguments.
+ elif cur_arguments and prev_arguments:
+ if (isinstance(delta_text, str)
+ and cur_arguments != prev_arguments
+ and len(cur_arguments) > len(prev_arguments)
+ and cur_arguments.startswith(prev_arguments)):
+ delta_arguments = cur_arguments[len(prev_arguments):]
+ logger.debug("got diff %s", delta_text)
+
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=delta_arguments).model_dump(
+ exclude_none=True),
+ )
+ ])
+ self.streamed_args_for_tool[
+ self.current_tool_id] = cur_arguments
+ else:
+ delta = None
+
+ # handle saving the state for the current tool into
+ # the "prev" list for use in diffing for the next iteration
+ if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
+ self.prev_tool_call_arr[
+ self.current_tool_id] = current_tool_call
+ else:
+ self.prev_tool_call_arr.append(current_tool_call)
+
+ return delta
+
+ except Exception:
+ logger.exception("Error trying to handle streaming tool call.")
+ return None # do not stream a delta. skip this token ID.
diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py
new file mode 100644
index 000000000000..bdca293d21db
--- /dev/null
+++ b/vllm/model_executor/models/glm4_moe.py
@@ -0,0 +1,685 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright 2025 The ZhipuAI Team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only GLM-4.5 model compatible with HuggingFace weights."""
+import typing
+from collections.abc import Callable, Iterable
+from typing import Any, Optional, Union
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from vllm.attention import Attention
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (get_ep_group, get_pp_group,
+ get_tensor_model_parallel_world_size)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import SupportsPP
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory, make_layers,
+ maybe_prefix)
+
+logger = init_logger(__name__)
+
+
+class Glm4MoeMLP(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ quant_config: Optional[QuantizationConfig] = None,
+ reduce_results: bool = True,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.gate_up_proj = MergedColumnParallelLinear(
+ hidden_size, [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_up_proj")
+ self.down_proj = RowParallelLinear(intermediate_size,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ reduce_results=reduce_results,
+ prefix=f"{prefix}.down_proj")
+ if hidden_act != "silu":
+ raise ValueError(f"Unsupported activation: {hidden_act}. "
+ "Only silu is supported for now.")
+ self.act_fn = SiluAndMul()
+
+ def forward(self, x):
+ gate_up, _ = self.gate_up_proj(x)
+ x = self.act_fn(gate_up)
+ x, _ = self.down_proj(x)
+ return x
+
+
+class Glm4MoE(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ enable_eplb: bool = False,
+ ):
+ super().__init__()
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.routed_scaling_factor = config.routed_scaling_factor
+
+ self.ep_group = get_ep_group().device_group
+ self.ep_rank = self.ep_group.rank()
+ self.ep_size = self.ep_group.size()
+ self.n_routed_experts: int = config.n_routed_experts
+ self.n_shared_experts: int = config.n_shared_experts
+
+ if config.hidden_act != "silu":
+ raise ValueError(f"Unsupported activation: {config.hidden_act}. "
+ "Only silu is supported for now.")
+
+ self.gate = ReplicatedLinear(config.hidden_size,
+ config.n_routed_experts,
+ bias=False,
+ quant_config=None,
+ prefix=f"{prefix}.gate")
+
+        # The GLM-4.5 HF config does not expose a `noaux_tc` flag, so always
+        # create the expert-score correction bias used by grouped top-k
+        # routing (aux-loss-free load balancing).
+        self.gate.e_score_correction_bias = nn.Parameter(
+            torch.empty(config.n_routed_experts))
+
+ # Load balancing settings.
+ vllm_config = get_current_vllm_config()
+ parallel_config = vllm_config.parallel_config
+ self.enable_eplb = enable_eplb
+
+ self.n_redundant_experts = parallel_config.num_redundant_experts
+ self.n_logical_experts = self.n_routed_experts
+ self.n_physical_experts = (self.n_logical_experts +
+ self.n_redundant_experts)
+ self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
+ self.physical_expert_start = (self.ep_rank *
+ self.n_local_physical_experts)
+ self.physical_expert_end = (self.physical_expert_start +
+ self.n_local_physical_experts)
+
+ self.experts = FusedMoE(
+ num_experts=config.n_routed_experts,
+ top_k=config.num_experts_per_tok,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.moe_intermediate_size,
+ reduce_results=False,
+ renormalize=config.norm_topk_prob,
+ quant_config=quant_config,
+ use_grouped_topk=True,
+ num_expert_group=config.n_group,
+ topk_group=config.topk_group,
+ prefix=f"{prefix}.experts",
+ scoring_func="sigmoid",
+ e_score_correction_bias=self.gate.e_score_correction_bias,
+ enable_eplb=self.enable_eplb,
+ num_redundant_experts=self.n_redundant_experts)
+
+ if config.n_shared_experts is not None:
+ intermediate_size = (config.moe_intermediate_size *
+ config.n_shared_experts)
+ self.shared_experts = Glm4MoeMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ reduce_results=self.experts.must_reduce_shared_expert_outputs(
+ ),
+ prefix=f"{prefix}.shared_experts",
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ num_tokens, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+
+        shared_output = None
+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states)
+ router_logits, _ = self.gate(hidden_states)
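+        # The routed-expert output is scaled by `routed_scaling_factor`;
+        # the shared-expert output (if any) is added unscaled.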
+ final_hidden_states = self.experts(
+ hidden_states=hidden_states,
+ router_logits=router_logits) * self.routed_scaling_factor
+ if shared_output is not None:
+ final_hidden_states = final_hidden_states + shared_output
+ if self.tp_size > 1:
+ final_hidden_states = (
+ self.experts.maybe_all_reduce_tensor_model_parallel(
+ final_hidden_states))
+ return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+class Glm4MoeAttention(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ rope_theta: float = 10000,
+ rope_scaling: Optional[dict[str, Any]] = None,
+ max_position_embeddings: int = 131072,
+ head_dim: Optional[int] = None,
+ rms_norm_eps: float = 1e-05,
+ qkv_bias: bool = False,
+ use_qk_norm: bool = False,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ tp_size = get_tensor_model_parallel_world_size()
+ self.total_num_heads = num_heads
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ self.total_num_kv_heads = num_kv_heads
+ if self.total_num_kv_heads >= tp_size:
+ # Number of KV heads is greater than TP size, so we partition
+ # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+ self.head_dim = head_dim or (hidden_size // self.total_num_heads)
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = self.head_dim**-0.5
+ self.rope_theta = rope_theta
+ self.max_position_embeddings = max_position_embeddings
+ self.use_qk_norm = use_qk_norm
+
+ self.qkv_proj = QKVParallelLinear(hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=qkv_bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj")
+
+ self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj")
+
+ partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position=max_position_embeddings,
+ base=rope_theta,
+ rope_scaling=rope_scaling,
+ partial_rotary_factor=partial_rotary_factor,
+ )
+ self.attn = Attention(
+ self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn",
+ )
+
+ if self.use_qk_norm:
+ self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+ self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ if self.use_qk_norm:
+ q = self.q_norm(q.reshape(-1, self.num_heads,
+ self.head_dim)).reshape(q.shape)
+ k = self.k_norm(k.reshape(-1, self.num_kv_heads,
+ self.head_dim)).reshape(k.shape)
+
+ q, k = self.rotary_emb(positions, q, k)
+ attn_output = self.attn(q, k, v)
+ output, _ = self.o_proj(attn_output)
+ return output
+
+
+class Glm4MoeDecoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ enable_eplb: bool = False,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ rope_theta = getattr(config, "rope_theta", 10000)
+ rope_scaling = getattr(config, "rope_scaling", None)
+ max_position_embeddings = getattr(config, "max_position_embeddings",
+ 131072)
+ # DecoderLayers are created with `make_layers` which passes the prefix
+ # with the layer's index.
+ layer_idx = int(prefix.split(sep='.')[-1])
+ self.layer_idx = layer_idx
+
+ self.self_attn = Glm4MoeAttention(
+ config=config,
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ num_kv_heads=config.num_key_value_heads,
+ rope_theta=rope_theta,
+ rope_scaling=rope_scaling,
+ max_position_embeddings=max_position_embeddings,
+ head_dim=config.head_dim,
+ rms_norm_eps=config.rms_norm_eps,
+ qkv_bias=config.attention_bias,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.self_attn",
+ use_qk_norm=config.use_qk_norm,
+ )
+
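+        # The first `first_k_dense_replace` layers use a dense MLP; all later
+        # layers use the routed MoE block.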
+ if (config.n_routed_experts is not None
+ and layer_idx >= config.first_k_dense_replace):
+ self.mlp = Glm4MoE(
+ config=config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ enable_eplb=enable_eplb,
+ )
+ else:
+ self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp")
+
+ self.input_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.routed_scaling_factor = config.routed_scaling_factor
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: Optional[torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(
+ hidden_states, residual)
+ hidden_states = self.self_attn(positions=positions,
+ hidden_states=hidden_states)
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual)
+ hidden_states = self.mlp(hidden_states)
+ return hidden_states, residual
+
+
+@support_torch_compile
+class Glm4MoeModel(nn.Module):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+
+ config = vllm_config.model_config.hf_config
+ cache_config = vllm_config.cache_config
+ quant_config = vllm_config.quant_config
+ enable_eplb = vllm_config.parallel_config.enable_eplb
+ self.config = config
+
+ self.vocab_size = config.vocab_size
+
+ if get_pp_group().is_first_rank:
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.embed_tokens")
+ else:
+ self.embed_tokens = PPMissingLayer()
+
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ config.num_hidden_layers,
+ lambda prefix: Glm4MoeDecoderLayer(
+ config=config,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=prefix,
+ enable_eplb=enable_eplb,
+ ),
+ prefix=f"{prefix}.layers")
+
+ if get_pp_group().is_last_rank:
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer()
+ self.make_empty_intermediate_tensors = (
+ make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], config.hidden_size))
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.get_input_embeddings(input_ids)
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+
+ for i in range(self.start_layer, self.end_layer):
+ layer = self.layers[i]
+ hidden_states, residual = layer(positions, hidden_states, residual)
+
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors({
+ "hidden_states": hidden_states,
+ "residual": residual
+ })
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+ def make_empty_intermediate_tensors(
+ self, batch_size: int, dtype: torch.dtype,
+ device: torch.device) -> IntermediateTensors:
+ return IntermediateTensors({
+ "hidden_states":
+ torch.zeros((batch_size, self.config.hidden_size),
+ dtype=dtype,
+ device=device),
+ "residual":
+ torch.zeros((batch_size, self.config.hidden_size),
+ dtype=dtype,
+ device=device),
+ })
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+
+ # Params for weights, fp8 weight scales, fp8 activation scales
+ # (param_name, weight_name, expert_id, shard_id)
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=self.config.n_routed_experts)
+
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
+ spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+ if spec_layer is not None:
+ continue
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ # Skip non-stacked layers and experts (experts handled below).
+ if weight_name not in name:
+ continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+ if (("mlp.experts." in name) and name not in params_dict):
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ is_expert_weight = False
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in name:
+ continue
+
+ # Anyway, this is an expert weight and should not be
+ # attempted to load as other weights later
+ is_expert_weight = True
+
+ # Do not modify `name` since the loop may continue here
+ # Instead, create a new variable
+ name_mapped = name.replace(weight_name, param_name)
+
+ if is_pp_missing_parameter(name_mapped, self):
+ continue
+
+ param = params_dict[name_mapped]
+ # We should ask the weight loader to return success or not
+ # here since otherwise we may skip experts with other
+ # available replicas.
+ weight_loader = typing.cast(Callable[..., bool],
+ param.weight_loader)
+ success = weight_loader(param,
+ loaded_weight,
+ name_mapped,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ return_success=True)
+ if success:
+ name = name_mapped
+ break
+ else:
+ if is_expert_weight:
+ # We've checked that this is an expert weight
+ # However it's not mapped locally to this rank
+ # So we simply skip it
+ continue
+
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+
+ return loaded_params
+
+
+class Glm4MoeForCausalLM(nn.Module, SupportsPP):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+
+ fall_back_to_pt_during_load = False
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ self.config = config
+ self.quant_config = quant_config
+ self.model = Glm4MoeModel(vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"))
+ if get_pp_group().is_last_rank:
+ self.lm_head = ParallelLMHead(config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config)
+ else:
+ self.lm_head = PPMissingLayer()
+ if self.config.tie_word_embeddings:
+ self.lm_head.weight = self.model.embed_tokens.weight
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors)
+ self.expert_weights = []
+
+ # Set MoE hyperparameters
+ self.num_moe_layers = (config.num_hidden_layers -
+ config.first_k_dense_replace)
+ self.num_expert_groups = config.n_group
+
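+        # Collect the FusedMoE modules so EPLB can rebalance experts at
+        # runtime (see set_eplb_state below).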
+ self.moe_layers: list[FusedMoE] = []
+ for layer in self.model.layers:
+ assert isinstance(layer, Glm4MoeDecoderLayer)
+ if isinstance(layer.mlp, Glm4MoE):
+ self.moe_layers.append(layer.mlp.experts)
+
+        # Pick the last layer, since the first few layers may be dense.
+ example_moe = typing.cast(
+ Glm4MoE, self.model.layers[config.num_hidden_layers - 1].mlp)
+ self.num_logical_experts = example_moe.n_logical_experts
+ self.num_physical_experts = example_moe.n_physical_experts
+ self.num_local_physical_experts = example_moe.n_local_physical_experts
+ self.num_routed_experts = example_moe.n_routed_experts
+ self.num_shared_experts = example_moe.n_shared_experts
+ self.num_redundant_experts = example_moe.n_redundant_experts
+
+ def set_eplb_state(
+ self,
+ expert_load_view: torch.Tensor,
+ logical_to_physical_map: torch.Tensor,
+ logical_replica_count: torch.Tensor,
+ ) -> None:
+ for layer_idx, layer in enumerate(self.moe_layers):
+ # Register the expert weights.
+ self.expert_weights.append(layer.get_expert_weights())
+ layer.set_eplb_state(
+ moe_layer_idx=layer_idx,
+ expert_load_view=expert_load_view,
+ logical_to_physical_map=logical_to_physical_map,
+ logical_replica_count=logical_replica_count,
+ )
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.get_input_embeddings(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ hidden_states = self.model(input_ids, positions, intermediate_tensors,
+ inputs_embeds)
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ logits = self.logits_processor(self.lm_head, hidden_states,
+ sampling_metadata)
+ return logits
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights)
+
+
+def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
+ weight_name: str) -> Optional[int]:
+ if hasattr(config,
+ "num_nextn_predict_layers") and (config.num_nextn_predict_layers
+ > 0):
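+        # MTP weights are stored as extra decoder layers appended after the
+        # base model's layers; load_weights() skips them so that only
+        # Glm4MoeMTP picks them up.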
+ layer_idx = config.num_hidden_layers
+ for i in range(config.num_nextn_predict_layers):
+ if f"layers.{layer_idx+i}." in weight_name:
+ return layer_idx + i
+ return None
diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py
new file mode 100644
index 000000000000..0624640054d1
--- /dev/null
+++ b/vllm/model_executor/models/glm4_moe_mtp.py
@@ -0,0 +1,307 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright 2025 The ZhipuAI Team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only GLM-4.5 MTP model compatible with HuggingFace weights."""
+
+from collections.abc import Iterable
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from transformers import PretrainedConfig
+
+from vllm.config import CacheConfig, VllmConfig
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .glm4_moe import Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name
+from .interfaces import SupportsPP
+from .utils import maybe_prefix
+
+
+class SharedHead(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> None:
+ super().__init__()
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.head = ParallelLMHead(config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config)
+
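+    # forward() only normalizes the hidden states; the LM head itself is
+    # applied later in Glm4MoeMultiTokenPredictor.compute_logits().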
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return self.norm(hidden_states)
+
+
+class Glm4MoeMultiTokenPredictorLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ prefix: str,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> None:
+ super().__init__()
+ self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.eh_proj = nn.Linear(config.hidden_size * 2,
+ config.hidden_size,
+ bias=False)
+ self.shared_head = SharedHead(config=config, quant_config=quant_config)
+ self.mtp_block = Glm4MoeDecoderLayer(config=config,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=prefix)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ previous_hidden_states: torch.Tensor,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ spec_step_index: int = 0,
+ ) -> torch.Tensor:
+ assert inputs_embeds is not None
+ # masking inputs at position 0, as not needed by MTP
+ inputs_embeds[positions == 0] = 0
+ inputs_embeds = self.enorm(inputs_embeds)
+ previous_hidden_states = self.hnorm(previous_hidden_states)
+
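+        # Project the concatenation of the token embedding and the previous
+        # step's hidden state back to the model width (DeepSeek-style MTP).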
+ hidden_states = self.eh_proj(
+ torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
+
+ hidden_states, residual = self.mtp_block(positions=positions,
+ hidden_states=hidden_states,
+ residual=None)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class Glm4MoeMultiTokenPredictor(nn.Module):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ self.mtp_start_layer_idx = config.num_hidden_layers
+ self.num_mtp_layers = config.num_nextn_predict_layers
+        # Key the ModuleDict by the absolute layer index so module names
+        # match the layer indices used in the checkpoint weights.
+ self.layers = torch.nn.ModuleDict({
+ str(idx):
+ Glm4MoeMultiTokenPredictorLayer(
+ config,
+ f"{prefix}.layers.{idx}",
+ cache_config=vllm_config.cache_config,
+ quant_config=vllm_config.quant_config,
+ )
+ for idx in range(self.mtp_start_layer_idx,
+ self.mtp_start_layer_idx + self.num_mtp_layers)
+ })
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ )
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ previous_hidden_states: torch.Tensor,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ spec_step_idx: int = 0,
+ ) -> torch.Tensor:
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
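+        # Wrap the speculative step index so additional draft steps reuse
+        # the available MTP layers.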
+ current_step_idx = (spec_step_idx % self.num_mtp_layers)
+ return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
+ input_ids,
+ positions,
+ previous_hidden_states,
+ inputs_embeds,
+ current_step_idx,
+ )
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ spec_step_idx: int = 0,
+ ) -> torch.Tensor:
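+        # Pick the MTP layer for this step and compute logits via its shared
+        # head: RMSNorm (shared_head(...)) then the parallel LM head applied
+        # by the logits processor.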
+ current_step_idx = (spec_step_idx % self.num_mtp_layers)
+ mtp_layer = self.layers[str(self.mtp_start_layer_idx +
+ current_step_idx)]
+ logits = self.logits_processor(mtp_layer.shared_head.head,
+ mtp_layer.shared_head(hidden_states),
+ sampling_metadata)
+ return logits
+
+
+class Glm4MoeMTP(nn.Module, SupportsPP):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ self.config = vllm_config.model_config.hf_config
+ self.model = Glm4MoeMultiTokenPredictor(vllm_config=vllm_config,
+ prefix=maybe_prefix(
+ prefix, "model"))
+
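+    # The MTP head consumes the target model's hidden states
+    # (previous_hidden_states) together with the draft token ids and returns
+    # the hidden states used to predict the next token.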
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ previous_hidden_states: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ spec_step_idx: int = 0,
+ ) -> torch.Tensor:
+ hidden_states = self.model(input_ids, positions,
+ previous_hidden_states, inputs_embeds,
+ spec_step_idx)
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ spec_step_idx: int = 0,
+ ) -> Optional[torch.Tensor]:
+ return self.model.compute_logits(hidden_states, sampling_metadata,
+ spec_step_idx)
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+
+ # Params for weights, fp8 weight scales, fp8 activation scales
+ # (param_name, weight_name, expert_id, shard_id)
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=self.config.n_routed_experts)
+
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
+ spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+ if spec_layer is None:
+ continue
+ name = self._rewrite_spec_layer_name(spec_layer, name)
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ # Skip non-stacked layers and experts (experts handled below).
+ if weight_name not in name:
+ continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+ if (("mlp.experts." in name) and name not in params_dict):
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param,
+ loaded_weight,
+ name,
+ shard_id=shard_id,
+ expert_id=expert_id)
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+                    # According to the DeepSeek-V3 Technical Report, the MTP
+                    # modules share the embedding layer; shared weights are
+                    # only loaded from the first MTP layer.
+ if (spec_layer != self.model.mtp_start_layer_idx
+ and ".layers" not in name):
+ continue
+
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
+ """
+        Rewrite the weight name to match this model's module layout:
+        add .mtp_block for modules inside the transformer block of a spec
+        layer, and rename shared weights (e.g. embed_tokens) to top level.
+ """
+ spec_layer_weight_names = [
+ "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
+ ]
+ shared_weight_names = ["embed_tokens"]
+ spec_layer_weight = False
+ shared_weight = False
+ for weight_name in spec_layer_weight_names:
+ if weight_name in name:
+ spec_layer_weight = True
+ if weight_name in shared_weight_names:
+ shared_weight = True
+ break
+ if not spec_layer_weight:
+            # treat the remaining weights as transformer layer block weights
+ name = name.replace(f"model.layers.{spec_layer}.",
+ f"model.layers.{spec_layer}.mtp_block.")
+ elif shared_weight:
+            # treat shared weights as top-level weights
+ name = name.replace(f"model.layers.{spec_layer}.", "model.")
+ return name
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 2ca37867b88c..d47b1cb87964 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -67,6 +67,7 @@
"Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
+ "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
@@ -245,6 +246,7 @@
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
+ "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"MedusaModel": ("medusa", "Medusa"),
# Temporarily disabled.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py
index 3e5485b883f1..bae593c1dff0 100644
--- a/vllm/reasoning/__init__.py
+++ b/vllm/reasoning/__init__.py
@@ -3,6 +3,7 @@
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
+from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
from .granite_reasoning_parser import GraniteReasoningParser
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
from .qwen3_reasoning_parser import Qwen3ReasoningParser
@@ -14,4 +15,5 @@
"GraniteReasoningParser",
"HunyuanA13BReasoningParser",
"Qwen3ReasoningParser",
+ "Glm4MoeModelReasoningParser",
]
diff --git a/vllm/reasoning/glm4_moe_reasoning_parser.py b/vllm/reasoning/glm4_moe_reasoning_parser.py
new file mode 100644
index 000000000000..6511fb49d10e
--- /dev/null
+++ b/vllm/reasoning/glm4_moe_reasoning_parser.py
@@ -0,0 +1,151 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+from typing import Optional, Union
+
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaMessage)
+from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
+
+logger = init_logger(__name__)
+
+
+@ReasoningParserManager.register_module("glm4_moe")
+class Glm4MoeModelReasoningParser(ReasoningParser):
+ """
+ Reasoning parser for the Glm4MoeModel model.
+
+    The Glm4MoeModel model uses <think>...</think> tokens to denote
+    reasoning text within its output. The model provides a strict switch to
+    disable reasoning output via the 'enable_thinking=False' parameter. This
+    parser extracts the reasoning content enclosed by <think> and </think>
+    tokens from the model's output.
+ """
+
+ def __init__(self, tokenizer: PreTrainedTokenizerBase):
+ super().__init__(tokenizer)
+        self.think_start_token = "<think>"
+        self.think_end_token = "</think>"
+
+ if not self.model_tokenizer:
+ raise ValueError(
+ "The model tokenizer must be passed to the ReasoningParser "
+ "constructor during construction.")
+
+ self.think_start_token_id = self.vocab.get(self.think_start_token)
+ self.think_end_token_id = self.vocab.get(self.think_end_token)
+ if (self.think_start_token_id is None
+ or self.think_end_token_id is None):
+ raise RuntimeError(
+ "Glm4MoeModel reasoning parser could not locate "
+ "think start/end tokens in the tokenizer!")
+
+ def is_reasoning_end(self, input_ids: list[int]) -> bool:
+ return self.think_end_token_id in input_ids
+
+ def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+ """
+        Extract the content after the </think> token
+ """
+ if self.think_end_token_id not in input_ids[:-1]:
+ return []
+ else:
+ return input_ids[input_ids.index(self.think_end_token_id) + 1:]
+
+ def extract_reasoning_content_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ ) -> Union[DeltaMessage, None]:
+ """
+ Extract reasoning content from a delta message.
+ Handles streaming output where previous + delta = current.
+ Uses token IDs for faster processing.
+        For text <think>abc</think>xyz:
+ - 'abc' goes to reasoning_content
+ - 'xyz' goes to content
+ """
+ # Skip single special tokens
+ if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
+ self.think_start_token_id, self.think_end_token_id
+ ]):
+ return None
+
+ if self.think_start_token_id in previous_token_ids:
+ if self.think_end_token_id in delta_token_ids:
+                # <think> in previous, </think> in delta,
+ # extract reasoning content
+ end_index = delta_text.find(self.think_end_token)
+ reasoning_content = delta_text[:end_index]
+ content = delta_text[end_index + len(self.think_end_token):]
+ return DeltaMessage(reasoning_content=reasoning_content,
+ content=content if content else None)
+ elif self.think_end_token_id in previous_token_ids:
+                # <think> in previous, </think> in previous,
+                # already past reasoning; content continues
+ return DeltaMessage(content=delta_text)
+ else:
+                # <think> in previous, no </think> in previous or delta,
+ # reasoning content continues
+ return DeltaMessage(reasoning_content=delta_text)
+ elif self.think_start_token_id in delta_token_ids:
+ if self.think_end_token_id in delta_token_ids:
+                # <think> in delta, </think> in delta,
+                # extract reasoning content
+ start_index = delta_text.find(self.think_start_token)
+ end_index = delta_text.find(self.think_end_token)
+ reasoning_content = delta_text[start_index +
+ len(self.think_start_token
+ ):end_index]
+ content = delta_text[end_index + len(self.think_end_token):]
+ return DeltaMessage(reasoning_content=reasoning_content,
+ content=content if content else None)
+ else:
+                # <think> in delta, no </think> in delta,
+ # reasoning content continues
+ return DeltaMessage(reasoning_content=delta_text)
+ else:
+ # thinking is disabled, just content
+ return DeltaMessage(content=delta_text)
+
+ def extract_reasoning_content(
+ self, model_output: str, request: ChatCompletionRequest
+ ) -> tuple[Optional[str], Optional[str]]:
+ """
+ Extract reasoning content from the model output.
+
+        For text <think>abc</think>xyz:
+ - 'abc' goes to reasoning_content
+ - 'xyz' goes to content
+
+ Returns:
+ tuple[Optional[str], Optional[str]]: reasoning content and content
+ """
+
+        # Check if the model output contains the <think> and </think>
+        # tokens.
+ if (self.think_start_token not in model_output
+ or self.think_end_token not in model_output):
+ return None, model_output
+        # Check if <think> is present in the model output; remove it
+        # if it is present.
+ model_output_parts = model_output.partition(self.think_start_token)
+ model_output = model_output_parts[2] if model_output_parts[
+ 1] else model_output_parts[0]
+        # Check if the model output contains the </think> token.
+ # If the end token is not found, return the model output as is.
+ if self.think_end_token not in model_output:
+ return None, model_output
+
+ # Extract reasoning content from the model output.
+ reasoning_content, _, content = model_output.partition(
+ self.think_end_token)
+
+ final_content = content or None
+ return reasoning_content, final_content
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index b2926dbd185a..6b6943d76436 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -77,7 +77,8 @@ def __init__(
"mlp_speculator",
"eagle",
"deepseek_mtp",
- "mimo_mtp")) \
+ "glm4_moe_mtp",
+ "mimo_mtp")) \
else {"return_hidden_states": True}
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner