diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index ad3db1cf2100..297d98142b5f 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -401,6 +401,7 @@ th {
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ |
| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ |
| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
diff --git a/tests/models/registry.py b/tests/models/registry.py
index a6d5c305f799..a6f912643be4 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -292,6 +292,9 @@ def check_available_online(
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
+ "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501
+ trust_remote_code=True,
+ is_available_online=False),
"SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
diff --git a/tests/tool_use/test_seed_oss_tool_parser.py b/tests/tool_use/test_seed_oss_tool_parser.py
new file mode 100644
index 000000000000..d85bc9bbf1b3
--- /dev/null
+++ b/tests/tool_use/test_seed_oss_tool_parser.py
@@ -0,0 +1,459 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# ruff: noqa: E501
+
+import json
+from collections.abc import Generator
+from typing import Optional
+
+import pytest
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ ChatCompletionToolsParam,
+ DeltaMessage, FunctionCall,
+ ToolCall)
+from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser
+from vllm.transformers_utils.detokenizer import detokenize_incrementally
+from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
+
+# Use a common model that is likely to be available
+MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct"
+
+
+@pytest.fixture(scope="module")
+def seed_oss_tokenizer():
+ return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
+
+
+@pytest.fixture
+def seed_oss_tool_parser(seed_oss_tokenizer):
+ return SeedOssToolParser(seed_oss_tokenizer)
+
+
+@pytest.fixture
+def sample_tools():
+ return [
+ ChatCompletionToolsParam(
+ type="function",
+ function={
+ "name": "get_weather",
+ "description": "Get current temperature for a given location.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description":
+ "City and country e.g. Bogotá, Colombia"
+ },
+ "unit": {
+ "type": "string",
+ "description": "this is the unit of temperature"
+ }
+ },
+ "required": ["location"],
+ "additionalProperties": False
+ },
+ "returns": {
+ "type": "object",
+ "properties": {
+ "temperature": {
+ "type": "number",
+ "description": "temperature in celsius"
+ }
+ },
+ "required": ["temperature"],
+ "additionalProperties": False
+ },
+ "strict": True
+ }),
+ ]
+
+
+def assert_tool_calls(actual_tool_calls: list[ToolCall],
+ expected_tool_calls: list[ToolCall]):
+ assert len(actual_tool_calls) == len(expected_tool_calls)
+
+ for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
+ expected_tool_calls):
+ # Seed-OSS tool call will not generate id
+ assert actual_tool_call.type == "function"
+ assert actual_tool_call.function == expected_tool_call.function
+
+ assert actual_tool_call.function.name == expected_tool_call.function.name
+ assert actual_tool_call.function.arguments == expected_tool_call.function.arguments
+
+
+def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
+ model_output = "This is a test response without any tool calls"
+ extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
+ model_output, request=None) # type: ignore[arg-type]
+
+ assert not extracted_tool_calls.tools_called
+ assert extracted_tool_calls.tool_calls == []
+ assert extracted_tool_calls.content == model_output
+
+
+@pytest.mark.parametrize(
+ ids=[
+ "tool_call_0_thinking_budget",
+ "tool_call_512_thinkg_budget",
+ "tool_call_unlimited_thinking_budget",
+ ],
+ argnames=["model_output", "expected_tool_calls", "expected_content"],
+ argvalues=[
+ ("""\n\n\n"""
+ """The current thinking budget is 0, so I will directly start answering the question.\n\n"""
+ """\n\n"""
+ """Barcelona, Spain\n\n""",
+ [
+ ToolCall(function=FunctionCall(
+ name="get_weather",
+ arguments=json.dumps({
+ "location": "Barcelona, Spain",
+ }, ),
+ ),
+ type='function')
+ ],
+ """\n\n\n"""
+ """The current thinking budget is 0, so I will directly start answering the question.\n\n"""
+ ),
+ (
+ """The user\'s current thinking budget is 512.\nLet me analyze the """
+ """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
+ """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
+ """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
+ """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
+ """country). \nI have used 131 tokens, and there are 381 tokens remaining for use."""
+ """\n Since the unit isn\'t specified, the function will default to Celsius, which """
+ """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
+ """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
+ """user\'s input has a space, but the function might accept either; to be safe, using the standard format """
+ """with a comma).\nI have used 257 tokens, and there are 255 tokens remaining for """
+ """use.\n The unit parameter can be omitted since it\'s optional.\n"""
+ """\n\nBarcelona, Spain\n"""
+ """\n""",
+ [
+ ToolCall(function=FunctionCall(
+ name="get_weather",
+ arguments=json.dumps({
+ "location": "Barcelona, Spain",
+ }, ),
+ ),
+ type='function')
+ ],
+ """The user\'s current thinking budget is 512.\nLet me analyze the """
+ """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
+ """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
+ """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
+ """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
+ """country). \nI have used 131 tokens, and there are 381 tokens remaining for use."""
+ """\n Since the unit isn\'t specified, the function will default to Celsius, which """
+ """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
+ """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
+ """user\'s input has a space, but the function might accept either; to be safe, using the standard format """
+ """with a comma).\nI have used 257 tokens, and there are 255 tokens remaining for """
+ """use.\n The unit parameter can be omitted since it\'s optional.\n""",
+ ),
+ (
+ """\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
+ """First, I need to remember the function I can use: get_weather. The function requires a """
+ """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
+ """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
+ """let me check the function docstring again. Oh, the function says unit is optional, and """
+ """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
+ """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
+ """The format is \n\nBarcelona, """
+ """Spain\ncelsius\n\n. """
+ """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
+ """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
+ """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
+ """call should be as above. Then wait for the result to come back and tell the user the """
+ """temperature in Celsius.\n\n"""
+ """Barcelona, Spain\ncelsius\n\n""",
+ [
+ ToolCall(function=FunctionCall(
+ name="get_weather",
+ arguments=json.dumps(
+ {
+ "location": "Barcelona, Spain",
+ "unit": "celsius",
+ }, ),
+ ),
+ type='function')
+ ],
+ """\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
+ """First, I need to remember the function I can use: get_weather. The function requires a """
+ """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
+ """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
+ """let me check the function docstring again. Oh, the function says unit is optional, and """
+ """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
+ """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
+ """The format is \n\nBarcelona, """
+ """Spain\ncelsius\n\n. """
+ """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
+ """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
+ """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
+ """call should be as above. Then wait for the result to come back and tell the user the """
+ """temperature in Celsius.""",
+ ),
+ ],
+)
+def test_extract_tool_calls(seed_oss_tool_parser, sample_tools, model_output,
+ expected_tool_calls, expected_content):
+ request = ChatCompletionRequest(model=MODEL,
+ messages=[],
+ tools=sample_tools)
+ extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
+ model_output, request=request) # type: ignore[arg-type]
+ assert extracted_tool_calls.tools_called
+
+ assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
+
+ assert extracted_tool_calls.content == expected_content
+
+
+def test_streaming_tool_calls_no_tools(seed_oss_tool_parser):
+ model_output = "This is a test response without any tool calls"
+
+ result = seed_oss_tool_parser.extract_tool_calls_streaming(
+ previous_text="his is a test response",
+ current_text=model_output,
+ delta_text=" without any tool calls.",
+ previous_token_ids=[],
+ current_token_ids=[],
+ delta_token_ids=[],
+ request=None,
+ )
+
+ # Should return the delta text as content
+ assert result is not None
+ assert hasattr(result, 'content')
+ assert result.content == " without any tool calls."
+
+
+def stream_delta_message_generator(
+ seed_oss_tool_parser: SeedOssToolParser,
+ seed_oss_tokenizer: AnyTokenizer,
+ model_output: str,
+ request: Optional[ChatCompletionRequest] = None
+) -> Generator[DeltaMessage, None, None]:
+ all_token_ids = seed_oss_tokenizer.encode(model_output,
+ add_special_tokens=False)
+
+ previous_text = ""
+ previous_tokens = None
+ prefix_offset = 0
+ read_offset = 0
+ for i, delta_token in enumerate(all_token_ids):
+ delta_token_ids = [delta_token]
+ previous_token_ids = all_token_ids[:i]
+ current_token_ids = all_token_ids[:i + 1]
+
+ (new_tokens, delta_text, new_prefix_offset,
+ new_read_offset) = detokenize_incrementally(
+ tokenizer=seed_oss_tokenizer,
+ all_input_ids=current_token_ids,
+ prev_tokens=previous_tokens,
+ prefix_offset=prefix_offset,
+ read_offset=read_offset,
+ skip_special_tokens=False,
+ spaces_between_special_tokens=True,
+ )
+
+ current_text = previous_text + delta_text
+
+ delta_message = seed_oss_tool_parser.extract_tool_calls_streaming(
+ previous_text,
+ current_text,
+ delta_text,
+ previous_token_ids,
+ current_token_ids,
+ delta_token_ids,
+ request=request,
+ )
+ if delta_message:
+ yield delta_message
+
+ previous_text = current_text
+ previous_tokens = (previous_tokens +
+ new_tokens if previous_tokens else new_tokens)
+ prefix_offset = new_prefix_offset
+ read_offset = new_read_offset
+
+
+@pytest.mark.parametrize(
+ ids=[
+ "tool_call_0_thinking_budget",
+ "tool_call_512_thinkg_budget",
+ "tool_call_unlimited_thinking_budget",
+ ],
+ argnames=["model_output", "expected_tool_calls", "expected_content"],
+ argvalues=[
+ ("""\n\n\n"""
+ """The current thinking budget is 0, so I will directly start answering the question.\n\n"""
+ """\n\n"""
+ """Barcelona, Spain\n\n""",
+ [
+ ToolCall(function=FunctionCall(
+ name="get_weather",
+ arguments=json.dumps({
+ "location": "Barcelona, Spain",
+ }, ),
+ ),
+ type='function')
+ ],
+ """\n\n\n"""
+ """The current thinking budget is 0, so I will directly start answering the question.\n\n"""
+ ),
+ (
+ """The user\'s current thinking budget is 512.\nLet me analyze the """
+ """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
+ """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
+ """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
+ """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
+ """country). \nI have used 131 tokens, and there are 381 tokens remaining for use."""
+ """\n Since the unit isn\'t specified, the function will default to Celsius, which """
+ """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
+ """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
+ """user\'s input has a space, but the function might accept either; to be safe, using the standard format """
+ """with a comma).\nI have used 257 tokens, and there are 255 tokens remaining for """
+ """use.\n The unit parameter can be omitted since it\'s optional.\n"""
+ """\n\nBarcelona, Spain\n"""
+ """\n""",
+ [
+ ToolCall(function=FunctionCall(
+ name="get_weather",
+ arguments=json.dumps({
+ "location": "Barcelona, Spain",
+ }, ),
+ ),
+ type='function')
+ ],
+ """The user\'s current thinking budget is 512.\nLet me analyze the """
+ """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
+ """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
+ """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
+ """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
+ """country). \nI have used 131 tokens, and there are 381 tokens remaining for use."""
+ """\n Since the unit isn\'t specified, the function will default to Celsius, which """
+ """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
+ """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
+ """user\'s input has a space, but the function might accept either; to be safe, using the standard format """
+ """with a comma).\nI have used 257 tokens, and there are 255 tokens remaining for """
+ """use.\n The unit parameter can be omitted since it\'s optional.\n""",
+ ),
+ (
+ """\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
+ """First, I need to remember the function I can use: get_weather. The function requires a """
+ """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
+ """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
+ """let me check the function docstring again. Oh, the function says unit is optional, and """
+ """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
+ """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
+ """The format is \n\nBarcelona, """
+ """Spain\ncelsius\n\n. """
+ """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
+ """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
+ """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
+ """call should be as above. Then wait for the result to come back and tell the user the """
+ """temperature in Celsius.\n\n"""
+ """Barcelona, Spain\ncelsius\n\n""",
+ [
+ ToolCall(function=FunctionCall(
+ name="get_weather",
+ arguments=json.dumps(
+ {
+ "location": "Barcelona, Spain",
+ "unit": "celsius",
+ }, ),
+ ),
+ type='function')
+ ],
+ """\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
+ """First, I need to remember the function I can use: get_weather. The function requires a """
+ """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
+ """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
+ """let me check the function docstring again. Oh, the function says unit is optional, and """
+ """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
+ """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
+ """The format is \n\nBarcelona, """
+ """Spain\ncelsius\n\n. """
+ """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
+ """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
+ """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
+ """call should be as above. Then wait for the result to come back and tell the user the """
+ """temperature in Celsius.""",
+ ),
+ ],
+)
+def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer,
+ sample_tools, model_output, expected_tool_calls,
+ expected_content):
+ """Test incremental streaming behavior"""
+ request = ChatCompletionRequest(model=MODEL,
+ messages=[],
+ tools=sample_tools)
+
+ other_content = ''
+ tool_states = {} # Track state per tool index
+
+ for delta_message in stream_delta_message_generator(
+ seed_oss_tool_parser, seed_oss_tokenizer, model_output, request):
+ # role should never be streamed from tool parser
+ assert not delta_message.role
+
+ if delta_message.content:
+ other_content += delta_message.content
+
+ if delta_message.tool_calls:
+ for tool_call in delta_message.tool_calls:
+ idx = tool_call.index
+
+ # Initialize state for new tool
+ if idx not in tool_states:
+ tool_states[idx] = {
+ "id": None,
+ "name": None,
+ "arguments": "",
+ "type": None
+ }
+
+ # First chunk should have id, name, and type
+ if tool_call.id:
+ tool_states[idx]["id"] = tool_call.id
+
+ if tool_call.type:
+ assert tool_call.type == "function"
+ tool_states[idx]["type"] = tool_call.type
+
+ if tool_call.function:
+ if tool_call.function.name:
+ # Should only be set once
+ assert tool_states[idx]["name"] is None
+ tool_states[idx]["name"] = tool_call.function.name
+
+ if tool_call.function.arguments is not None:
+ # Accumulate arguments incrementally
+ tool_states[idx][
+ "arguments"] += tool_call.function.arguments
+
+ # Verify final content
+ assert other_content == expected_content
+
+ # Verify we got all expected tool calls
+ assert len(tool_states) == len(expected_tool_calls)
+
+ # Verify each tool call
+ for idx, expected_tool in enumerate(expected_tool_calls):
+ state = tool_states[idx]
+ assert state["id"] is not None
+ assert state["type"] == "function"
+ assert state["name"] == expected_tool.function.name
+
+ # Parse accumulated arguments
+ arguments_str = state["arguments"]
+ assert arguments_str is not None
+ actual_args = json.loads(arguments_str)
+ expected_args = json.loads(expected_tool.function.arguments)
+ assert actual_args == expected_args
diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py
index 099e456aa486..468c3799bd1f 100644
--- a/vllm/entrypoints/openai/tool_parsers/__init__.py
+++ b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -18,6 +18,7 @@
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
from .qwen3coder_tool_parser import Qwen3CoderToolParser
+from .seed_oss_tool_parser import SeedOssToolParser
from .step3_tool_parser import Step3ToolParser
from .xlam_tool_parser import xLAMToolParser
@@ -41,5 +42,6 @@
"HunyuanA13BToolParser",
"Glm4MoeModelToolParser",
"Qwen3CoderToolParser",
+ "SeedOssToolParser",
"Step3ToolParser",
]
diff --git a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py
new file mode 100644
index 000000000000..69cf2e68f7c4
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py
@@ -0,0 +1,676 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Adapted from qwen3coder xml parser, All rights reserved.
+# ruff: noqa: E501
+
+import ast
+import json
+import uuid
+from collections.abc import Sequence
+from typing import Any, Optional, Union
+
+import regex as re
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ ChatCompletionToolsParam,
+ DeltaFunctionCall, DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+ ToolParser, ToolParserManager)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module("seed_oss")
+class SeedOssToolParser(ToolParser):
+ TOOL_CALL_START = ""
+ TOOL_CALL_END = ""
+
+ def __init__(self, tokenizer: AnyTokenizer):
+ super().__init__(tokenizer)
+
+ # --- streaming state ---
+ self._reset_streaming_state()
+ self.prev_tool_call_arr: list[dict] = []
+
+ self.tool_call_start_token: str = self.TOOL_CALL_START
+ self.tool_call_end_token: str = self.TOOL_CALL_END
+ # Sentinel tokens for streaming mode
+ self.tool_call_prefix: str = " or its closing tag.")
+
+ tool_start_re = re.escape(self.tool_call_start_token)
+ tool_end_re = re.escape(self.tool_call_end_token)
+
+ self.tool_call_complete_regex = re.compile(
+ rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL)
+ self.tool_call_regex = re.compile(
+ rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$",
+ re.DOTALL)
+
+ self.tool_call_function_regex = re.compile(
+ r"|| str:
+ """Generate a unique tool call ID."""
+ return f"call_{uuid.uuid4().hex[:24]}"
+
+ def _reset_streaming_state(self):
+ """Reset all streaming state."""
+ self.current_tool_index = 0
+ self.is_tool_call_started = False
+ self.header_sent = False
+ self.current_tool_id = -1
+ self.current_function_name = None
+ self.current_param_name = None
+ self.current_param_value = ""
+ self.param_count = 0
+ self.in_param = False
+ self.in_function = False
+ self.accumulated_text = ""
+ self.json_started = False
+ self.json_closed = False
+
+ def _parse_xml_function_call(
+ self, function_call_str: str,
+ tools: Optional[list[ChatCompletionToolsParam]]
+ ) -> Optional[ToolCall]:
+
+ def get_arguments_config(func_name: str) -> dict:
+ if tools is None:
+ return {}
+ for config in tools:
+ if not hasattr(config, "type") or not (
+ hasattr(config, "function")
+ and hasattr(config.function, "name")):
+ continue
+ if (config.type == "function"
+ and config.function.name == func_name):
+ if not hasattr(config.function, "parameters"):
+ return {}
+ params = config.function.parameters
+ if isinstance(params, dict) and "properties" in params:
+ return params["properties"]
+ elif isinstance(params, dict):
+ return params
+ else:
+ return {}
+ logger.warning("Tool '%s' is not defined in the tools list.",
+ func_name)
+ return {}
+
+ def convert_param_value(param_value: str, param_name: str,
+ param_config: dict, func_name: str) -> Any:
+ # Handle null value for any type
+ if param_value.lower() == "null":
+ return None
+
+ if param_name not in param_config:
+ if param_config != {}:
+ logger.warning(
+ "Parsed parameter '%s' is not defined in "
+ "the tool parameters for tool '%s', "
+ "directly returning the string value.", param_name,
+ func_name)
+ return param_value
+
+ if (isinstance(param_config[param_name], dict)
+ and "type" in param_config[param_name]):
+ param_type = str(
+ param_config[param_name]["type"]).strip().lower()
+ else:
+ param_type = "string"
+ if param_type in [
+ "string", "str", "text", "varchar", "char", "enum"
+ ]:
+ return param_value
+ elif (param_type.startswith("int") or param_type.startswith("uint")
+ or param_type.startswith("long")
+ or param_type.startswith("short")
+ or param_type.startswith("unsigned")):
+ try:
+ param_value = int(param_value) # type: ignore
+ except (ValueError, TypeError):
+ logger.warning(
+ "Parsed value '%s' of parameter '%s' is not an integer in tool "
+ "'%s', degenerating to string.", param_value,
+ param_name, func_name)
+ return param_value
+ elif param_type.startswith("num") or param_type.startswith(
+ "float"):
+ try:
+ float_param_value = float(param_value)
+ param_value = float_param_value if float_param_value - int(
+ float_param_value) != 0 else int(
+ float_param_value) # type: ignore
+ except (ValueError, TypeError):
+ logger.warning(
+ "Parsed value '%s' of parameter '%s' is not a float in tool "
+ "'%s', degenerating to string.", param_value,
+ param_name, func_name)
+ return param_value
+ elif param_type in ["boolean", "bool", "binary"]:
+ param_value = param_value.lower()
+ if param_value not in ["true", "false"]:
+ logger.warning(
+ "Parsed value '%s' of parameter '%s' is not a boolean "
+ "(`true` of `false`) in tool '%s', degenerating to false.",
+ param_value, param_name, func_name)
+ return param_value == "true"
+ else:
+ if param_type == "object" or param_type.startswith("dict"):
+ try:
+ param_value = json.loads(param_value)
+ return param_value
+ except (ValueError, TypeError, json.JSONDecodeError):
+ logger.warning(
+ "Parsed value '%s' of parameter '%s' is not a valid JSON "
+ "object in tool '%s', will try other methods to parse it.",
+ param_value, param_name, func_name)
+ try:
+ param_value = ast.literal_eval(param_value)
+ except (ValueError, SyntaxError):
+ logger.warning(
+ "Parsed value '%s' of parameter '%s' cannot be converted via "
+ "Python `ast.literal_eval()` in tool '%s', degenerating to string.",
+ param_value, param_name, func_name)
+ return param_value
+
+ # Extract function name
+ end_index = function_call_str.index(">")
+ function_name = function_call_str[:end_index]
+ param_config = get_arguments_config(function_name)
+ parameters = function_call_str[end_index + 1:]
+ param_dict = {}
+ for match in self.tool_call_parameter_regex.findall(parameters):
+ match_text = match[0] if match[0] else match[1]
+ idx = match_text.index(">")
+ param_name = match_text[:idx]
+ param_value = str(match_text[idx + 1:])
+ # Remove prefix and trailing \n
+ if param_value.startswith("\n"):
+ param_value = param_value[1:]
+ if param_value.endswith("\n"):
+ param_value = param_value[:-1]
+
+ param_dict[param_name] = convert_param_value(
+ param_value, param_name, param_config, function_name)
+ return ToolCall(
+ type="function",
+ function=FunctionCall(name=function_name,
+ arguments=json.dumps(param_dict,
+ ensure_ascii=False)),
+ )
+
+ def _get_function_calls(self, model_output: str) -> list[str]:
+ # Find all tool calls
+ matched_ranges = self.tool_call_regex.findall(model_output)
+ raw_tool_calls = [
+ match[0] if match[0] else match[1] for match in matched_ranges
+ ]
+
+ # Back-off strategy if no tool_call tags found
+ if len(raw_tool_calls) == 0:
+ raw_tool_calls = [model_output]
+
+ raw_function_calls = []
+ for tool_call in raw_tool_calls:
+ raw_function_calls.extend(
+ self.tool_call_function_regex.findall(tool_call))
+
+ function_calls = [
+ match[0] if match[0] else match[1] for match in raw_function_calls
+ ]
+ return function_calls
+
+ def extract_tool_calls(
+ self,
+ model_output: str,
+ request: ChatCompletionRequest,
+ ) -> ExtractedToolCallInformation:
+ # Quick check to avoid unnecessary processing
+ if self.tool_call_prefix not in model_output:
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ # Check if both think start and end tokens are present
+ if (self.think_start_token in model_output
+ and self.think_end_token in model_output):
+ # Find the position of think end token
+ think_end_index = model_output.find(self.think_end_token) + len(
+ self.think_end_token)
+ # Extract content after think end token
+ result_content = model_output[think_end_index:]
+ thinking_content = model_output[:think_end_index]
+
+ try:
+ function_calls = self._get_function_calls(result_content)
+ if len(function_calls) == 0:
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ tool_calls = [
+ self._parse_xml_function_call(function_call_str, request.tools)
+ for function_call_str in function_calls
+ ]
+
+ # Populate prev_tool_call_arr for serving layer to set finish_reason
+ self.prev_tool_call_arr.clear() # Clear previous calls
+ for tool_call in tool_calls:
+ if tool_call:
+ self.prev_tool_call_arr.append({
+ "name":
+ tool_call.function.name,
+ "arguments":
+ tool_call.function.arguments,
+ })
+
+ # Extract content before tool calls
+ tool_call_start_index = result_content.find(
+ self.tool_call_start_token)
+ tool_call_start_index = (
+ tool_call_start_index if tool_call_start_index >= 0 else
+ result_content.find(self.tool_call_prefix))
+ content = thinking_content + result_content[:tool_call_start_index]
+
+ return ExtractedToolCallInformation(
+ tools_called=(len(tool_calls) > 0),
+ tool_calls=tool_calls,
+ content=content if content else None,
+ )
+
+ except Exception:
+ logger.exception("Error in extracting tool call from response.")
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> Union[DeltaMessage, None]:
+ # If no delta text, return None unless
+ # it's an EOS token after tool calls
+ if not delta_text:
+ # Check if this is an EOS token after all tool calls are complete
+ # We check for tool calls in the text even if is_tool_call_started
+ # is False because it might have been reset after processing all tools
+ if (delta_token_ids
+ and self.tool_call_end_token_id not in delta_token_ids):
+ # Count complete tool calls
+ complete_calls = len(
+ self.tool_call_complete_regex.findall(current_text))
+
+ # If we have completed tool calls and populated prev_tool_call_arr
+ if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
+ # Check if all tool calls are closed
+ open_calls = current_text.count(
+ self.tool_call_start_token) - current_text.count(
+ self.tool_call_end_token)
+ if open_calls == 0:
+ # Return empty delta message to allow finish_reason processing
+ return DeltaMessage(content="")
+ elif not self.is_tool_call_started and current_text:
+ # This is a regular content response that's now complete
+ return DeltaMessage(content="")
+ return None
+
+ # Check if this is the first call (reset state if needed)
+ if not previous_text:
+ self._reset_streaming_state()
+
+ # Update accumulated text
+ self.accumulated_text = current_text
+
+ # Check if we need to advance to next tool
+ if self.json_closed and not self.in_function:
+ # Check if this tool call has ended
+ tool_ends = current_text.count(self.tool_call_end_token)
+ if tool_ends > self.current_tool_index:
+ # This tool has ended, advance to next
+ self.current_tool_index += 1
+ self.header_sent = False
+ self.param_count = 0
+ self.json_started = False
+ self.json_closed = False
+
+ # Check if there are more tool calls
+ if self.current_tool_index >= current_text.count(
+ self.tool_call_start_token):
+ # No more tool calls
+ self.is_tool_call_started = False
+ # Continue processing next tool
+ return None
+
+ # Check if end thinking
+ if (not self.is_thinking_end
+ and (self.think_end_token_id in delta_token_ids
+ or self.think_end_token in delta_text)):
+ self.is_thinking_end = True
+
+ # If thinking hasn't ended yet, don't process any tool calls
+ if not self.is_thinking_end:
+ return DeltaMessage(content=delta_text)
+
+ # Handle normal content before tool calls
+ if not self.is_tool_call_started:
+ # Check if tool call is starting
+ if (self.tool_call_start_token_id in delta_token_ids
+ or self.tool_call_start_token in delta_text):
+ self.is_tool_call_started = True
+ # Return any content before the tool call
+ if self.tool_call_start_token in delta_text:
+ content_before = delta_text[:delta_text.index(
+ self.tool_call_start_token)]
+ if content_before:
+ return DeltaMessage(content=content_before)
+ return None
+ else:
+ # Check if we're between tool calls - skip whitespace
+ if (current_text.rstrip().endswith(self.tool_call_end_token)
+ and delta_text.strip() == ""):
+ # We just ended a tool call, skip whitespace
+ return None
+ # Normal content, no tool call
+ return DeltaMessage(content=delta_text)
+
+ # Check if we're between tool calls (waiting for next one)
+ # Count tool calls we've seen vs processed
+ tool_starts_count = current_text.count(self.tool_call_start_token)
+ if self.current_tool_index >= tool_starts_count:
+ # We're past all tool calls, shouldn't be here
+ return None
+
+ # We're in a tool call, find the current tool call portion
+ # Need to find the correct tool call based on current_tool_index
+ # Only process tool calls after think_end_token
+ think_end_index = current_text.find(self.think_end_token) + len(
+ self.think_end_token
+ ) if self.think_end_token in current_text else 0
+ tool_starts: list[int] = []
+ idx = think_end_index
+ while True:
+ idx = current_text.find(self.tool_call_start_token, idx)
+ if idx == -1:
+ break
+ tool_starts.append(idx)
+ idx += len(self.tool_call_start_token)
+
+ if self.current_tool_index >= len(tool_starts):
+ # No more tool calls to process yet
+ return None
+
+ tool_start_idx = tool_starts[self.current_tool_index]
+ # Find where this tool call ends (or current position if not ended yet)
+ tool_end_idx = current_text.find(self.tool_call_end_token,
+ tool_start_idx)
+ if tool_end_idx == -1:
+ tool_text = current_text[tool_start_idx:]
+ else:
+ tool_text = current_text[tool_start_idx:tool_end_idx +
+ len(self.tool_call_end_token)]
+
+ # Looking for function header
+ if not self.header_sent:
+ if self.tool_call_prefix in tool_text:
+ func_start = tool_text.find(self.tool_call_prefix) + len(
+ self.tool_call_prefix)
+ func_end = tool_text.find(">", func_start)
+
+ if func_end != -1:
+ # Found complete function name
+ self.current_function_name = tool_text[func_start:func_end]
+ self.current_tool_id = self._generate_tool_call_id(
+ ) # type: ignore
+ self.header_sent = True
+ self.in_function = True
+
+ # IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call
+ # This ensures finish_reason="tool_calls" even if parsing isn't complete
+ already_added = any(
+ tool.get("name") == self.current_function_name
+ for tool in self.prev_tool_call_arr)
+ if not already_added:
+ self.prev_tool_call_arr.append({
+ "name": self.current_function_name,
+ "arguments":
+ "{}", # Placeholder, will be updated later
+ })
+
+ # Send header with function info
+ return DeltaMessage(tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_index,
+ id=self.current_tool_id,
+ function=DeltaFunctionCall(
+ name=self.current_function_name, arguments=""),
+ type="function",
+ )
+ ])
+ return None
+
+ # We've sent header, now handle function body
+ if self.in_function:
+ # Send opening brace if not sent yet
+ if (not self.json_started
+ and self.parameter_prefix not in delta_text):
+ self.json_started = True
+ return DeltaMessage(tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_index,
+ function=DeltaFunctionCall(arguments="{"),
+ )
+ ])
+
+ # Make sure json_started is set if we're processing parameters
+ if not self.json_started:
+ self.json_started = True
+
+ # Check for function end in accumulated text
+ if not self.json_closed and self.function_end_token in tool_text:
+ # Close JSON
+ self.json_closed = True
+
+ # Extract the complete tool call to update prev_tool_call_arr with final arguments
+ # Find the function content
+ func_start = tool_text.find(self.tool_call_prefix) + len(
+ self.tool_call_prefix)
+ func_content_end = tool_text.find(self.function_end_token,
+ func_start)
+ if func_content_end != -1:
+ func_content = tool_text[func_start:func_content_end]
+ # Parse to get the complete arguments
+ try:
+ parsed_tool = self._parse_xml_function_call(
+ func_content, request.tools if request else None)
+ if parsed_tool:
+ # Update existing entry in prev_tool_call_arr with complete arguments
+ for i, tool in enumerate(self.prev_tool_call_arr):
+ if tool.get(
+ "name") == parsed_tool.function.name:
+ self.prev_tool_call_arr[i]["arguments"] = (
+ parsed_tool.function.arguments)
+ break
+ except Exception:
+ logger.warning(
+ "Failed to parse tool arguments during streaming.",
+ exc_info=True)
+
+ result = DeltaMessage(tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_index,
+ function=DeltaFunctionCall(arguments="}"),
+ )
+ ])
+
+ # Reset state for next tool
+ self.in_function = False
+ self.json_closed = True
+
+ return result
+
+ # Look for parameters
+ # Count how many complete parameters we have processed
+ complete_params = tool_text.count(self.parameter_end_token)
+
+ # Check if we should start a new parameter
+ if not self.in_param and self.param_count < complete_params:
+ # Find the unprocessed parameter
+ # Count parameter starts
+ param_starts = []
+ idx = 0
+ while True:
+ idx = tool_text.find(self.parameter_prefix, idx)
+ if idx == -1:
+ break
+ param_starts.append(idx)
+ idx += len(self.parameter_prefix)
+
+ if len(param_starts) > self.param_count:
+ # Process the next parameter
+ param_idx = param_starts[self.param_count]
+ param_start = param_idx + len(self.parameter_prefix)
+ remaining = tool_text[param_start:]
+
+ if ">" in remaining:
+ # We have the complete parameter name
+ name_end = remaining.find(">")
+ self.current_param_name = remaining[:name_end]
+
+ # Find the parameter value
+ value_start = param_start + name_end + 1
+ value_text = tool_text[value_start:]
+ if value_text.startswith("\n"):
+ value_text = value_text[1:]
+
+ # Find where this parameter ends
+ param_end_idx = value_text.find(
+ self.parameter_end_token)
+ if param_end_idx != -1:
+ # Complete parameter found
+ param_value = value_text[:param_end_idx]
+ if param_value.endswith("\n"):
+ param_value = param_value[:-1]
+
+ # Build complete JSON fragment for this parameter
+ if self.param_count == 0:
+ json_fragment = (
+ '"' + self.current_param_name + '": "' +
+ json.dumps(param_value)[1:-1] + '"')
+ else:
+ json_fragment = (
+ ', "' + self.current_param_name + '": "' +
+ json.dumps(param_value)[1:-1] + '"')
+
+ self.param_count += 1
+
+ return DeltaMessage(tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_index,
+ function=DeltaFunctionCall(
+ arguments=json_fragment),
+ )
+ ])
+
+ # Continue parameter value
+ if self.in_param:
+ if self.parameter_end_token in delta_text:
+ # End of parameter
+ end_idx = delta_text.find(self.parameter_end_token)
+ value_chunk = delta_text[:end_idx]
+
+ # Skip past > if at start
+ if not self.current_param_value and ">" in value_chunk:
+ gt_idx = value_chunk.find(">")
+ value_chunk = value_chunk[gt_idx + 1:]
+
+ if not self.current_param_value and value_chunk.startswith(
+ "\n"):
+ value_chunk = value_chunk[1:]
+
+ # Calculate incremental JSON
+ full_value = self.current_param_value + value_chunk
+ prev_escaped = (json.dumps(self.current_param_value)[1:-1]
+ if self.current_param_value else "")
+ full_escaped = json.dumps(full_value)[1:-1]
+ delta_escaped = full_escaped[len(prev_escaped):]
+
+ self.in_param = False
+ self.current_param_value = ""
+
+ return DeltaMessage(tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_index,
+ function=DeltaFunctionCall(
+ arguments=delta_escaped + '"'),
+ )
+ ])
+ else:
+ # Continue accumulating value
+ value_chunk = delta_text
+
+ # Handle first chunk after param name
+ if not self.current_param_value and ">" in value_chunk:
+ gt_idx = value_chunk.find(">")
+ value_chunk = value_chunk[gt_idx + 1:]
+
+ if not self.current_param_value and value_chunk.startswith(
+ "\n"):
+ value_chunk = value_chunk[1:]
+
+ if value_chunk:
+ # Stream the escaped delta
+ prev_escaped = (json.dumps(
+ self.current_param_value)[1:-1]
+ if self.current_param_value else "")
+ self.current_param_value += value_chunk
+ full_escaped = json.dumps(
+ self.current_param_value)[1:-1]
+ delta_escaped = full_escaped[len(prev_escaped):]
+
+ if delta_escaped:
+ return DeltaMessage(tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_index,
+ function=DeltaFunctionCall(
+ arguments=delta_escaped),
+ )
+ ])
+
+ return None
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 28d7e93af91a..465c25f09480 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -130,6 +130,7 @@
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
+ "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py
new file mode 100644
index 000000000000..34a87a6a69a3
--- /dev/null
+++ b/vllm/model_executor/models/seed_oss.py
@@ -0,0 +1,487 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright 2025 The Seed team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only SeedOss model compatible with HuggingFace weights."""
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig as SeedOssConfig
+
+from vllm.attention import Attention, AttentionType
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import SupportsLoRA, SupportsPP
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory, make_layers,
+ maybe_prefix)
+
+logger = init_logger(__name__)
+
+
+class SeedOssMLP(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.gate_up_proj = MergedColumnParallelLinear(
+ hidden_size,
+ [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_up_proj",
+ )
+ self.down_proj = RowParallelLinear(
+ intermediate_size,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.down_proj",
+ )
+ if hidden_act != "silu":
+ raise ValueError(f"Unsupported activation: {hidden_act}. "
+ "Only silu is supported for now.")
+ self.act_fn = SiluAndMul()
+
+ def forward(self, x):
+ gate_up, _ = self.gate_up_proj(x)
+ x = self.act_fn(gate_up)
+ x, _ = self.down_proj(x)
+ return x
+
+
+class SeedOssAttention(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ head_dim: int,
+ max_position: int = 4096 * 32,
+ rope_theta: float = 10000,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ rope_scaling: Optional[tuple] = None,
+ prefix: str = "",
+ attn_type: str = AttentionType.DECODER,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ tp_size = get_tensor_model_parallel_world_size()
+ self.total_num_heads = num_heads
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ self.total_num_kv_heads = num_kv_heads
+ self.head_dim = head_dim
+ if self.total_num_kv_heads >= tp_size:
+ # Number of KV heads is greater than TP size, so we partition
+ # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = self.head_dim**-0.5
+ self.rope_theta = rope_theta
+
+ self.qkv_proj = QKVParallelLinear(
+ hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ )
+ self.o_proj = RowParallelLinear(
+ self.total_num_heads * self.head_dim,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ )
+
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position=max_position,
+ base=self.rope_theta,
+ rope_scaling=rope_scaling,
+ )
+ self.attn = Attention(
+ self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ attn_type=attn_type,
+ prefix=f"{prefix}.attn",
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ q, k = self.rotary_emb(positions, q, k)
+ attn_output = self.attn(q, k, v)
+ output, _ = self.o_proj(attn_output)
+ return output
+
+
+class SeedOssDecoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: SeedOssConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ # Requires transformers > 4.32.0
+ rope_theta = getattr(config, "rope_theta", 1000000)
+ rope_scaling = getattr(config, "rope_scaling", None)
+
+ # By default, SeedOss uses causal attention as it is a
+ # decoder-only model.
+ # You can override the HF config with `is_causal=False` to enable
+ # bidirectional attention, which is used in some embedding models
+ if getattr(config, "is_causal", True):
+ attn_type = AttentionType.DECODER
+ else:
+ attn_type = AttentionType.ENCODER_ONLY
+
+ self.self_attn = SeedOssAttention(
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ max_position=config.max_position_embeddings,
+ num_kv_heads=config.num_key_value_heads,
+ head_dim=config.head_dim,
+ rope_theta=rope_theta,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ rope_scaling=rope_scaling,
+ prefix=f"{prefix}.self_attn",
+ attn_type=attn_type,
+ )
+ self.mlp = SeedOssMLP(
+ hidden_size=self.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ )
+ self.input_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: Optional[torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # Self Attention
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(
+ hidden_states, residual)
+ hidden_states = self.self_attn(
+ positions=positions,
+ hidden_states=hidden_states,
+ )
+
+ # Fully Connected
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual)
+ hidden_states = self.mlp(hidden_states)
+ return hidden_states, residual
+
+
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ })
+class SeedOssModel(nn.Module):
+
+ def __init__(self,
+ *,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer):
+ super().__init__()
+
+ config = vllm_config.model_config.hf_config
+ cache_config = vllm_config.cache_config
+ quant_config = vllm_config.quant_config
+
+ # TODO (@robertgshaw2): see if this can be moved out
+ if (cache_config.sliding_window is not None
+ and hasattr(config, "max_window_layers")):
+ assert config.max_window_layers == config.num_hidden_layers, (
+ "Sliding window for some but all layers is not supported. "
+ "This model uses sliding window but `max_window_layers` = {} "
+ "is less than `num_hidden_layers` = {}. Please open an issue "
+ "to discuss this feature.".format(
+ config.max_window_layers,
+ config.num_hidden_layers,
+ ))
+
+ self.config = config
+ self.quant_config = quant_config
+ self.vocab_size = config.vocab_size
+
+ if get_pp_group().is_first_rank or (config.tie_word_embeddings
+ and get_pp_group().is_last_rank):
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.embed_tokens",
+ )
+ else:
+ self.embed_tokens = PPMissingLayer()
+
+ # Use the provided decoder layer type or default to SeedDecoderLayer
+ decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ config.num_hidden_layers,
+ lambda prefix: decoder_layer_type(config=config,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=prefix),
+ prefix=f"{prefix}.layers",
+ )
+
+ self.make_empty_intermediate_tensors = (
+ make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], config.hidden_size))
+ if get_pp_group().is_last_rank:
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer()
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.get_input_embeddings(input_ids)
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+ for layer in self.layers[self.start_layer:self.end_layer]:
+ hidden_states, residual = layer(
+ positions,
+ hidden_states,
+ residual,
+ )
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors({
+ "hidden_states": hidden_states,
+ "residual": residual
+ })
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
+ if "rotary_emb.inv_freq" in name:
+ continue
+ if (self.quant_config is not None and
+ (scale_name := self.quant_config.get_cache_scale(name))):
+ # Loading kv cache quantization scales
+ param = params_dict[scale_name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+ loaded_weight[0])
+ weight_loader(param, loaded_weight)
+ loaded_params.add(scale_name)
+ continue
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ lora_config = vllm_config.lora_config
+
+ self.config = config
+ self.lora_config = lora_config
+
+ self.quant_config = quant_config
+ self.model = SeedOssModel(vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"))
+
+ if get_pp_group().is_last_rank:
+ if config.tie_word_embeddings:
+ self.lm_head = self.model.embed_tokens
+ else:
+ self.lm_head = ParallelLMHead(config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=maybe_prefix(
+ prefix, "lm_head"))
+ else:
+ self.lm_head = PPMissingLayer()
+
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors)
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.get_input_embeddings(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ hidden_states = self.model(input_ids, positions, intermediate_tensors,
+ inputs_embeds)
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ logits = self.logits_processor(self.lm_head, hidden_states,
+ sampling_metadata)
+ return logits
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(
+ self,
+ skip_prefixes=(["lm_head."]
+ if self.config.tie_word_embeddings else None),
+ )
+ return loader.load_weights(weights)