diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 132c325ce591..c350aaf5d3ad 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -576,7 +576,11 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size - elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"): + elif config.architectures[0] in ( + "DeepseekV3ForCausalLM", + "DeepseekV2ForCausalLM", + "Glm4MoeForCausalLM", + ): E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py index dba1f3943b96..4ed690090144 100644 --- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -318,6 +318,7 @@ def main(args: argparse.Namespace): elif ( config.architectures[0] == "DeepseekV3ForCausalLM" or config.architectures[0] == "DeepseekV2ForCausalLM" + or config.architectures[0] == "Glm4MoeForCausalLM" ): E = config.n_routed_experts topk = config.num_experts_per_tok diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 887f754a3d1c..b0be947a66f9 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -577,6 +577,7 @@ Specified using `--task generate`. | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | | `GLM4VForCausalLM`^ | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `THUDM/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Glm4MoeForCausalLM` | GLM-4.5 | T + IE+ + VE+ | `THUDM/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | | `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ | | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 5c546a6c86da..c537438ad6fb 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -364,6 +364,9 @@ def check_available_online( trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 "Glm4vForConditionalGeneration": _HfExamplesInfo("THUDM/GLM-4.1V-9B-Thinking", min_transformers_version="4.53"), # noqa: E501 + "Glm4MoeForCausalLM": _HfExamplesInfo("THUDM/GLM-4.5", + min_transformers_version="4.54", + is_available_online=False), # noqa: E501 "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m", extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501 max_transformers_version="4.48", # noqa: E501 @@ -489,6 +492,10 @@ def check_available_online( is_available_online=False, speculative_model="openbmb/MiniCPM-2B-sft-bf16", tokenizer="openbmb/MiniCPM-2B-sft-bf16"), + "Glm4MoeMTPModel": _HfExamplesInfo("THUDM/GLM-4.5", + speculative_model="THUDM/GLM-4.5", + min_transformers_version="4.54", + is_available_online=False), "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True, speculative_model="XiaomiMiMo/MiMo-7B-RL") diff --git a/tests/tool_use/test_glm4_moe_tool_parser.py b/tests/tool_use/test_glm4_moe_tool_parser.py new file mode 100644 index 000000000000..478f4b916672 --- /dev/null +++ b/tests/tool_use/test_glm4_moe_tool_parser.py @@ -0,0 +1,410 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +import json + +import pytest + +from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall +from vllm.entrypoints.openai.tool_parsers import Glm4MoeModelToolParser +from vllm.transformers_utils.tokenizer import get_tokenizer + +pytest.skip("skip glm4_moe parser test", allow_module_level=True) +# Use a common model that is likely to be available +MODEL = "THUDM/GLM-4.5" + + +@pytest.fixture(scope="module") +def glm4_moe_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL) + + +@pytest.fixture +def glm4_moe_tool_parser(glm4_moe_tokenizer): + return Glm4MoeModelToolParser(glm4_moe_tokenizer) + + +def assert_tool_calls(actual_tool_calls: list[ToolCall], + expected_tool_calls: list[ToolCall]): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip(actual_tool_calls, + expected_tool_calls): + assert isinstance(actual_tool_call.id, str) + assert len(actual_tool_call.id) > 0 + + assert actual_tool_call.type == "function" + assert actual_tool_call.function.name == expected_tool_call.function.name + # Compare arguments as JSON objects to handle formatting differences + actual_args = json.loads(actual_tool_call.function.arguments) + expected_args = json.loads(expected_tool_call.function.arguments) + assert actual_args == expected_args + + +def test_extract_tool_calls_no_tools(glm4_moe_tool_parser): + model_output = "This is a test" + extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output + + +@pytest.mark.parametrize( + ids=[ + "single_tool_call", + "multiple_tool_calls", + "tool_call_with_content_before", + "tool_call_with_mixed_args", + "tool_call_with_chinese_content", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + """get_current_weather + city + Dallas + state + TX + unit + fahrenheit + """, + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + }), + )) + ], + None, + ), + ( + """get_current_weather + city + Dallas + state + TX + unit + fahrenheit + + get_current_weather + city + Orlando + state + FL + unit + fahrenheit + """, + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + }), + )), + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Orlando", + "state": "FL", + "unit": "fahrenheit", + }), + )), + ], + None, + ), + ( + """I'll help you check the weather. get_current_weather + city + Seattle + state + WA + unit + celsius + """, + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Seattle", + "state": "WA", + "unit": "celsius", + }), + )) + ], + "I'll help you check the weather.", + ), + ( + """get_current_weather + city + New York + state + NY + unit + celsius + """, + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "New York", + "state": "NY", + "unit": "celsius", + }), + )) + ], + None, + ), + ("""I will help you get the weather.get_weather + city + Beijing + date + 2025-08-01 + """, [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "city": "Beijing", + "date": "2025-08-01", + }), + )) + ], "I will help you get the weather."), + ], +) +def test_extract_tool_calls(glm4_moe_tool_parser, model_output, + expected_tool_calls, expected_content): + extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +def test_extract_tool_calls_with_thinking_tags(glm4_moe_tool_parser): + """Test tool extraction when thinking tags are present.""" + model_output = """I want to get the weather. + +I will help you get the weather. +get_weather +city +Beijing +date +2025-08-01 +""" + + extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + assert extracted_tool_calls.tools_called + assert len(extracted_tool_calls.tool_calls) == 1 + assert extracted_tool_calls.tool_calls[0].function.name == "get_weather" + + expected_content = """I want to get the weather. + +I will help you get the weather.""" + assert extracted_tool_calls.content == expected_content + + +def test_extract_tool_calls_malformed_xml(glm4_moe_tool_parser): + """Test that malformed XML is handled gracefully.""" + model_output = """get_weather +city +Seattle +incomplete_arg +value +""" + + extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + # Should handle malformed XML gracefully + # The parser should either extract what it can or return no tool calls + # depending on how robust we want the parsing to be + assert isinstance(extracted_tool_calls.tools_called, bool) + assert isinstance(extracted_tool_calls.tool_calls, list) + + +def test_extract_tool_calls_empty_arguments(glm4_moe_tool_parser): + """Test tool calls with no arguments.""" + model_output = """get_current_time +""" + + extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + assert extracted_tool_calls.tools_called + assert len(extracted_tool_calls.tool_calls) == 1 + assert extracted_tool_calls.tool_calls[ + 0].function.name == "get_current_time" + # Empty arguments should result in empty JSON object + assert extracted_tool_calls.tool_calls[0].function.arguments == "{}" + + +def test_extract_tool_calls_mixed_content(glm4_moe_tool_parser): + """Test extraction with mixed content and multiple tool calls.""" + model_output = """I will help you get the weather info. + +get_weather +city +Beijing +date +2025-08-01 + + +meaningwhile, I will also check the weather in Shanghai. + +get_weather +city +Shanghai +date +2025-08-01 +""" + + extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + assert extracted_tool_calls.tools_called + assert len(extracted_tool_calls.tool_calls) == 2 + + # Check first tool call + assert extracted_tool_calls.tool_calls[0].function.name == "get_weather" + args1 = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) + assert args1["city"] == "Beijing" + assert args1["date"] == "2025-08-01" + + # Check second tool call + assert extracted_tool_calls.tool_calls[1].function.name == "get_weather" + args2 = json.loads(extracted_tool_calls.tool_calls[1].function.arguments) + assert args2["city"] == "Shanghai" + assert args2["date"] == "2025-08-01" + + # Content should be everything before the first tool call + assert extracted_tool_calls.content == "I will help you get the weather info." + + +def test_streaming_basic_functionality(glm4_moe_tool_parser): + """Test basic streaming functionality.""" + # Reset streaming state + glm4_moe_tool_parser.current_tool_name_sent = False + glm4_moe_tool_parser.prev_tool_call_arr = [] + glm4_moe_tool_parser.current_tool_id = -1 + glm4_moe_tool_parser.streamed_args_for_tool = [] + + # Test with a simple tool call + current_text = """get_weather +city +Beijing +""" + + # Mock token IDs for testing + tool_call_start_id = glm4_moe_tool_parser.tool_call_start_token_id or 12345 + tool_call_end_id = glm4_moe_tool_parser.tool_call_end_token_id or 12346 + + result = glm4_moe_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text=current_text, + delta_text="", + previous_token_ids=[], + current_token_ids=[tool_call_start_id, tool_call_end_id], + delta_token_ids=[tool_call_end_id], + request=None, + ) + + # The result behavior depends on the streaming state + # This test mainly ensures no exceptions are thrown + assert result is None or hasattr(result, 'tool_calls') or hasattr( + result, 'content') + + +def test_streaming_no_tool_calls(glm4_moe_tool_parser): + """Test streaming when there are no tool calls.""" + current_text = "This is just regular text without any tool calls." + + result = glm4_moe_tool_parser.extract_tool_calls_streaming( + previous_text="This is just regular text", + current_text=current_text, + delta_text=" without any tool calls.", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + # Should return the delta text as content + assert result is not None + assert hasattr(result, 'content') + assert result.content == " without any tool calls." + + +def test_streaming_with_content_before_tool_calls(glm4_moe_tool_parser): + """Test streaming when there's content before tool calls.""" + # Reset streaming state + glm4_moe_tool_parser.current_tool_name_sent = False + glm4_moe_tool_parser.prev_tool_call_arr = [] + glm4_moe_tool_parser.current_tool_id = -1 + glm4_moe_tool_parser.streamed_args_for_tool = [] + + current_text = "I will help you get the weather" + + result = glm4_moe_tool_parser.extract_tool_calls_streaming( + previous_text="I will help you", + current_text=current_text, + delta_text="get the weather.", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + # Should return content when no tool call tokens are detected + assert result is not None + assert hasattr(result, 'content') + assert result.content == "get the weather." + + +def test_extract_tool_calls_special_characters(glm4_moe_tool_parser): + """Test tool calls with special characters and unicode.""" + model_output = """send_message +recipient +Amy +message +It is a nice day +priority +high +""" + + extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + assert extracted_tool_calls.tools_called + assert len(extracted_tool_calls.tool_calls) == 1 + assert extracted_tool_calls.tool_calls[0].function.name == "send_message" + + args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) + assert args["recipient"] == "Amy" + assert args["message"] == "It is a nice day" + assert args["priority"] == "high" + + +def test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser): + """Test incomplete tool calls (missing closing tag).""" + model_output = """get_weather +city +Beijing +date +2025-08-01""" + + extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + # Incomplete tool calls should not be extracted + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output diff --git a/vllm/config.py b/vllm/config.py index 384cb584fa9a..e92c501012a4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1333,7 +1333,8 @@ def get_layers_start_end_indices( self, parallel_config: "ParallelConfig") -> tuple[int, int]: from vllm.distributed.utils import get_pp_indices if (self.hf_text_config.model_type == "deepseek_mtp" - or self.hf_config.model_type == "mimo_mtp"): + or self.hf_config.model_type == "mimo_mtp" + or self.hf_config.model_type == "glm4_moe_mtp"): total_num_hidden_layers = getattr(self.hf_text_config, "num_nextn_predict_layers", 0) else: @@ -2663,7 +2664,15 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: "n_predict": n_predict, "architectures": ["MiMoMTPModel"] }) - return hf_config + + if hf_config.architectures[0] == "Glm4MoeForCausalLM": + hf_config.model_type = "glm4_moe_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["Glm4MoeMTPModel"] + }) return hf_config @@ -2774,7 +2783,7 @@ def __post_init__(self): "mlp_speculator"): self.method = "mlp_speculator" elif (self.draft_model_config.hf_config.model_type - in ("deepseek_mtp", "mimo_mtp")): + in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")): self.method = "deepseek_mtp" if self.num_speculative_tokens > 1: logger.warning( diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 137375b9707c..9eda7155f01f 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -3,6 +3,7 @@ from .abstract_tool_parser import ToolParser, ToolParserManager from .deepseekv3_tool_parser import DeepSeekV3ToolParser +from .glm4_moe_tool_parser import Glm4MoeModelToolParser from .granite_20b_fc_tool_parser import Granite20bFCToolParser from .granite_tool_parser import GraniteToolParser from .hermes_tool_parser import Hermes2ProToolParser @@ -19,10 +20,22 @@ from .xlam_tool_parser import xLAMToolParser __all__ = [ - "ToolParser", "ToolParserManager", "Granite20bFCToolParser", - "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser", - "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser", - "Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser", - "DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser", - "KimiK2ToolParser", "HunyuanA13BToolParser" + "ToolParser", + "ToolParserManager", + "Granite20bFCToolParser", + "GraniteToolParser", + "Hermes2ProToolParser", + "MistralToolParser", + "Internlm2ToolParser", + "Llama3JsonToolParser", + "JambaToolParser", + "Llama4PythonicToolParser", + "PythonicToolParser", + "Phi4MiniJsonToolParser", + "DeepSeekV3ToolParser", + "xLAMToolParser", + "MinimaxToolParser", + "KimiK2ToolParser", + "HunyuanA13BToolParser", + "Glm4MoeModelToolParser", ] diff --git a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py new file mode 100644 index 000000000000..c3f9d7923575 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py @@ -0,0 +1,402 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# code modified from deepseekv3_tool_parser.py + +from collections.abc import Sequence +from typing import Union + +import regex as re + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("glm4_moe") +class Glm4MoeModelToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.current_tool_name_sent = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id = -1 + self.streamed_args_for_tool: list[str] = [] + self.tool_call_start_token = "" + self.tool_call_end_token = "" + + self.tool_calls_start_token = self.tool_call_start_token + + # Updated regex for the XML-based format + self.tool_call_regex = re.compile( + r"\s*" + r"(?P[^\n<]+)\s*" # 函数名(到换行或 <) + r"(?P(?:\s*[^<]+\s*" + r"[^<]*\s*)*)\s*" + r"", + re.DOTALL, + ) + + # Regex for parsing individual arguments + self.arg_regex = re.compile( + r"(?P[^<]+)\s*(?P[^<]*)", + re.DOTALL, + ) + + # Streaming regex + self.stream_tool_call_portion_regex = re.compile( + r"(?P[^\n<]+)\s*" + r"(?P(?:\s*[^<]+\s*" + r"[^<]*\s*)*)", + re.DOTALL, + ) + + # For streaming, we also need a regex to match just the function name + self.stream_tool_call_name_regex = re.compile( + r"(?P[^\n<]+)", + re.DOTALL, + ) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + + def _parse_arguments(self, args_text: str) -> str: + """Parse XML-based arguments into JSON format.""" + if not args_text or not args_text.strip(): + return "{}" + + args_dict = {} + matches = self.arg_regex.findall(args_text) + + for key, value in matches: + args_dict[key.strip()] = value.strip() + + import json + return json.dumps(args_dict, ensure_ascii=False) + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_calls_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + # Find all tool calls in the output + function_call_matches = self.tool_call_regex.findall(model_output) + + logger.debug("function_call_matches: %s", function_call_matches) + + if not function_call_matches: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output, + ) + + tool_calls = [] + for i, match in enumerate(function_call_matches): + function_name, function_args_xml = match + function_name = function_name.strip() + + # Parse XML arguments to JSON + function_args_json = self._parse_arguments(function_args_xml) + + tool_calls.append( + ToolCall( + id=f"call_{i}", + type='function', + function=FunctionCall(name=function_name, + arguments=function_args_json), + )) + + # Extract content before the first tool call + content = model_output[:model_output.find(self. + tool_calls_start_token)] + return ExtractedToolCallInformation( + tools_called=bool(tool_calls), + tool_calls=tool_calls, + content=content.strip() if content.strip() else None, + ) + + except Exception: + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_call_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + delta_text = delta_text.replace(self.tool_calls_start_token, + "").replace(self.tool_call_end_token, + "") + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + tool_call_portion = None + text_portion = None + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text): + logger.debug("Generating text content! skipping tool parsing.") + return DeltaMessage(content=delta_text) + + if self.tool_call_end_token in delta_text: + logger.debug("tool_call_end_token in delta_text") + full_text = current_text + delta_text + tool_call_portion = full_text.split( + self.tool_call_start_token)[-1].split( + self.tool_call_end_token)[0].rstrip() + delta_text = delta_text.split( + self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split( + self.tool_call_end_token)[-1].lstrip() + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count): + if self.prev_tool_call_arr is None or len( + self.prev_tool_call_arr) == 0: + logger.debug( + "attempting to close tool call, but no tool call") + return None + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = (diff.encode("utf-8").decode("unicode_escape") + if diff is str else diff) + if '"}' not in delta_text: + return None + end_loc = delta_text.rindex('"}') + diff = delta_text[:end_loc] + '"}' + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", + diff, + ) + self.streamed_args_for_tool[self.current_tool_id] += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump(exclude_none=True), + ) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + current_tool_call = dict() + if tool_call_portion: + current_tool_call_matches = ( + self.stream_tool_call_portion_regex.match( + tool_call_portion)) + if current_tool_call_matches: + tool_id, tool_args = (current_tool_call_matches.groups()) + tool_name = tool_id.split('.')[1].split(':')[0] + current_tool_call['id'] = tool_id + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = tool_args + else: + current_tool_call_name_matches = ( + self.stream_tool_call_name_regex.match( + tool_call_portion)) + if current_tool_call_name_matches: + tool_id_str, = current_tool_call_name_matches.groups() + tool_name = tool_id_str.split('.')[1].split(':')[0] + current_tool_call['id'] = tool_id_str + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = "" + else: + logger.debug("Not enough token") + return None + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. + if not self.current_tool_name_sent: + if current_tool_call is None: + return None + function_name: Union[str, None] = current_tool_call.get("name") + tool_id = current_tool_call.get("id") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True), + ) + ]) + else: + return None + + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = (DeltaMessage( + content=delta_text) if text_portion is not None else None) + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=cur_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + + # last case -- we have an update to existing arguments. + elif cur_arguments and prev_arguments: + if (isinstance(delta_text, str) + and cur_arguments != prev_arguments + and len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments)): + delta_arguments = cur_arguments[len(prev_arguments):] + logger.debug("got diff %s", delta_text) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + else: + delta = None + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[ + self.current_tool_id] = current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + return None # do not stream a delta. skip this token ID. diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py new file mode 100644 index 000000000000..bdca293d21db --- /dev/null +++ b/vllm/model_executor/models/glm4_moe.py @@ -0,0 +1,685 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The ZhipuAI Team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GLM-4.5 model compatible with HuggingFace weights.""" +import typing +from collections.abc import Callable, Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import (get_ep_group, get_pp_group, + get_tensor_model_parallel_world_size) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Glm4MoeMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Glm4MoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + + # noaux_tc is not set in transformers new config now + self.gate.e_score_correction_bias = (nn.Parameter( + torch.empty(config.n_routed_experts))) + + # Load balancing settings. + vllm_config = get_current_vllm_config() + parallel_config = vllm_config.parallel_config + self.enable_eplb = enable_eplb + + self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func="sigmoid", + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = Glm4MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), + prefix=f"{prefix}.shared_experts", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = ( + self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states)) + return final_hidden_states.view(num_tokens, hidden_dim) + + +class Glm4MoeAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 131072, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-05, + qkv_bias: bool = False, + use_qk_norm: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.use_qk_norm = use_qk_norm + + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + partial_rotary_factor=partial_rotary_factor, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + if self.use_qk_norm: + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.use_qk_norm: + q = self.q_norm(q.reshape(-1, self.num_heads, + self.head_dim)).reshape(q.shape) + k = self.k_norm(k.reshape(-1, self.num_kv_heads, + self.head_dim)).reshape(k.shape) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Glm4MoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 131072) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. + layer_idx = int(prefix.split(sep='.')[-1]) + self.layer_idx = layer_idx + + self.self_attn = Glm4MoeAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + head_dim=config.head_dim, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=config.attention_bias, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + use_qk_norm=config.use_qk_norm, + ) + + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace): + self.mlp = Glm4MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, + ) + else: + self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Glm4MoeModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + enable_eplb = vllm_config.parallel_config.enable_eplb + self.config = config + + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Glm4MoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + enable_eplb=enable_eplb, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. + weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader(param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True) + if success: + name = name_mapped + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class Glm4MoeForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Glm4MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + self.expert_weights = [] + + # Set MoE hyperparameters + self.num_moe_layers = (config.num_hidden_layers - + config.first_k_dense_replace) + self.num_expert_groups = config.n_group + + self.moe_layers: list[FusedMoE] = [] + for layer in self.model.layers: + assert isinstance(layer, Glm4MoeDecoderLayer) + if isinstance(layer.mlp, Glm4MoE): + self.moe_layers.append(layer.mlp.experts) + + # Pick last one layer since the first ones may be dense layers. + example_moe = typing.cast( + Glm4MoE, self.model.layers[config.num_hidden_layers - 1].mlp) + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + +def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, + weight_name: str) -> Optional[int]: + if hasattr(config, + "num_nextn_predict_layers") and (config.num_nextn_predict_layers + > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if f"layers.{layer_idx+i}." in weight_name: + return layer_idx + i + return None diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py new file mode 100644 index 000000000000..0624640054d1 --- /dev/null +++ b/vllm/model_executor/models/glm4_moe_mtp.py @@ -0,0 +1,307 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The ZhipuAI Team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GLM-4.5 MTP model compatible with HuggingFace weights.""" + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .glm4_moe import Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name +from .interfaces import SupportsPP +from .utils import maybe_prefix + + +class SharedHead(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(hidden_states) + + +class Glm4MoeMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.shared_head = SharedHead(config=config, quant_config=quant_config) + self.mtp_block = Glm4MoeDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + residual=None) + hidden_states = residual + hidden_states + return hidden_states + + +class Glm4MoeMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + Glm4MoeMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + current_step_idx = (spec_step_idx % self.num_mtp_layers) + return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( + input_ids, + positions, + previous_hidden_states, + inputs_embeds, + current_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> torch.Tensor: + current_step_idx = (spec_step_idx % self.num_mtp_layers) + mtp_layer = self.layers[str(self.mtp_start_layer_idx + + current_step_idx)] + logits = self.logits_processor(mtp_layer.shared_head.head, + mtp_layer.shared_head(hidden_states), + sampling_metadata) + return logits + + +class Glm4MoeMTP(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.model = Glm4MoeMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, + previous_hidden_states, inputs_embeds, + spec_step_idx) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.model.compute_logits(hidden_states, sampling_metadata, + spec_step_idx) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is None: + continue + name = self._rewrite_spec_layer_name(spec_layer, name) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # According to DeepSeek-V3 Technical Report, MTP modules + # shares embedding layer. We only load the first weights. + if (spec_layer != self.model.mtp_start_layer_idx + and ".layers" not in name): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. + Add .mtp_block for modules in transformer layer block for spec layer + and rename shared layer weights to be top level. + """ + spec_layer_weight_names = [ + "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + ] + shared_weight_names = ["embed_tokens"] + spec_layer_weight = False + shared_weight = False + for weight_name in spec_layer_weight_names: + if weight_name in name: + spec_layer_weight = True + if weight_name in shared_weight_names: + shared_weight = True + break + if not spec_layer_weight: + # treat rest weights as weights for transformer layer block + name = name.replace(f"model.layers.{spec_layer}.", + f"model.layers.{spec_layer}.mtp_block.") + elif shared_weight: + # treat shared weights as top level weights + name = name.replace(f"model.layers.{spec_layer}.", "model.") + return name diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 2ca37867b88c..d47b1cb87964 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -67,6 +67,7 @@ "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501 "GlmForCausalLM": ("glm", "GlmForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), + "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), @@ -245,6 +246,7 @@ "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), + "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"), # Temporarily disabled. # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index 3e5485b883f1..bae593c1dff0 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -3,6 +3,7 @@ from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser +from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser from .granite_reasoning_parser import GraniteReasoningParser from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser from .qwen3_reasoning_parser import Qwen3ReasoningParser @@ -14,4 +15,5 @@ "GraniteReasoningParser", "HunyuanA13BReasoningParser", "Qwen3ReasoningParser", + "Glm4MoeModelReasoningParser", ] diff --git a/vllm/reasoning/glm4_moe_reasoning_parser.py b/vllm/reasoning/glm4_moe_reasoning_parser.py new file mode 100644 index 000000000000..6511fb49d10e --- /dev/null +++ b/vllm/reasoning/glm4_moe_reasoning_parser.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Optional, Union + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("glm4_moe") +class Glm4MoeModelReasoningParser(ReasoningParser): + """ + Reasoning parser for the Glm4MoeModel model. + + The Glm4MoeModel model uses ... tokens to denote reasoning + text within its output. The model provides a strict switch to disable + reasoning output via the 'enable_thinking=False' parameter. This parser + extracts the reasoning content enclosed by and tokens + from the model's output. + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + self.think_start_token = "" + self.think_end_token = "" + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ReasoningParser " + "constructor during construction.") + + self.think_start_token_id = self.vocab.get(self.think_start_token) + self.think_end_token_id = self.vocab.get(self.think_end_token) + if (self.think_start_token_id is None + or self.think_end_token_id is None): + raise RuntimeError( + "Glm4MoeModel reasoning parser could not locate " + "think start/end tokens in the tokenizer!") + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + return self.think_end_token_id in input_ids + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + """ + Extract the content after the end tokens + """ + if self.think_end_token_id not in input_ids[:-1]: + return [] + else: + return input_ids[input_ids.index(self.think_end_token_id) + 1:] + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """ + Extract reasoning content from a delta message. + Handles streaming output where previous + delta = current. + Uses token IDs for faster processing. + For text abcxyz: + - 'abc' goes to reasoning_content + - 'xyz' goes to content + """ + # Skip single special tokens + if len(delta_token_ids) == 1 and (delta_token_ids[0] in [ + self.think_start_token_id, self.think_end_token_id + ]): + return None + + if self.think_start_token_id in previous_token_ids: + if self.think_end_token_id in delta_token_ids: + # in previous, in delta, + # extract reasoning content + end_index = delta_text.find(self.think_end_token) + reasoning_content = delta_text[:end_index] + content = delta_text[end_index + len(self.think_end_token):] + return DeltaMessage(reasoning_content=reasoning_content, + content=content if content else None) + elif self.think_end_token_id in previous_token_ids: + # in previous, in previous, + # reasoning content continues + return DeltaMessage(content=delta_text) + else: + # in previous, no in previous or delta, + # reasoning content continues + return DeltaMessage(reasoning_content=delta_text) + elif self.think_start_token_id in delta_token_ids: + if self.think_end_token_id in delta_token_ids: + # in delta, in delta, extract reasoning content + start_index = delta_text.find(self.think_start_token) + end_index = delta_text.find(self.think_end_token) + reasoning_content = delta_text[start_index + + len(self.think_start_token + ):end_index] + content = delta_text[end_index + len(self.think_end_token):] + return DeltaMessage(reasoning_content=reasoning_content, + content=content if content else None) + else: + # in delta, no in delta, + # reasoning content continues + return DeltaMessage(reasoning_content=delta_text) + else: + # thinking is disabled, just content + return DeltaMessage(content=delta_text) + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> tuple[Optional[str], Optional[str]]: + """ + Extract reasoning content from the model output. + + For text abcxyz: + - 'abc' goes to reasoning_content + - 'xyz' goes to content + + Returns: + tuple[Optional[str], Optional[str]]: reasoning content and content + """ + + # Check if the model output contains the and tokens. + if (self.think_start_token not in model_output + or self.think_end_token not in model_output): + return None, model_output + # Check if the is present in the model output, remove it + # if it is present. + model_output_parts = model_output.partition(self.think_start_token) + model_output = model_output_parts[2] if model_output_parts[ + 1] else model_output_parts[0] + # Check if the model output contains the tokens. + # If the end token is not found, return the model output as is. + if self.think_end_token not in model_output: + return None, model_output + + # Extract reasoning content from the model output. + reasoning_content, _, content = model_output.partition( + self.think_end_token) + + final_content = content or None + return reasoning_content, final_content diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index b2926dbd185a..6b6943d76436 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -77,7 +77,8 @@ def __init__( "mlp_speculator", "eagle", "deepseek_mtp", - "mimo_mtp")) \ + "glm4_moe_mtp", + "mimo_mtp")) \ else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner