From 9ec7338a842bb00fbf488439bf78ec0cc9b7cfda Mon Sep 17 00:00:00 2001
From: Alex Brooks
Date: Tue, 8 Oct 2024 08:31:26 -0600
Subject: [PATCH] [Frontend] Add Early Validation For Chat Template / Tool
 Call Parser (#9151)

Signed-off-by: Alex-Brooks
Signed-off-by: Alvant
---
 tests/entrypoints/openai/test_cli_args.py | 178 +++++++++++++---------
 vllm/entrypoints/chat_utils.py            |  23 +++
 vllm/entrypoints/openai/api_server.py     |   4 +-
 vllm/entrypoints/openai/cli_args.py       |  15 ++
 vllm/scripts.py                           |   8 +-
 5 files changed, 156 insertions(+), 72 deletions(-)

diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py
index 8ee7fb8b2c6bf..45e6980a94630 100644
--- a/tests/entrypoints/openai/test_cli_args.py
+++ b/tests/entrypoints/openai/test_cli_args.py
@@ -1,91 +1,131 @@
 import json
-import unittest
-from vllm.entrypoints.openai.cli_args import make_arg_parser
+import pytest
+
+from vllm.entrypoints.openai.cli_args import (make_arg_parser,
+                                              validate_parsed_serve_args)
 from vllm.entrypoints.openai.serving_engine import LoRAModulePath
 from vllm.utils import FlexibleArgumentParser
 
+from ...utils import VLLM_PATH
+
 LORA_MODULE = {
     "name": "module2",
     "path": "/path/to/module2",
     "base_model_name": "llama"
 }
+CHATML_JINJA_PATH = VLLM_PATH / "examples/template_chatml.jinja"
+assert CHATML_JINJA_PATH.exists()
 
-class TestLoraParserAction(unittest.TestCase):
+@pytest.fixture
+def serve_parser():
+    parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
+    return make_arg_parser(parser)
 
-    def setUp(self):
-        # Setting up argparse parser for tests
-        parser = FlexibleArgumentParser(
-            description="vLLM's remote OpenAI server.")
-        self.parser = make_arg_parser(parser)
 
-    def test_valid_key_value_format(self):
-        # Test old format: name=path
-        args = self.parser.parse_args([
-            '--lora-modules',
-            'module1=/path/to/module1',
+### Tests for LoRA module parsing
+def test_valid_key_value_format(serve_parser):
+    # Test old format: name=path
+    args = serve_parser.parse_args([
+        '--lora-modules',
+        'module1=/path/to/module1',
+    ])
+    expected = [LoRAModulePath(name='module1', path='/path/to/module1')]
+    assert args.lora_modules == expected
+
+
+def test_valid_json_format(serve_parser):
+    # Test valid JSON format input
+    args = serve_parser.parse_args([
+        '--lora-modules',
+        json.dumps(LORA_MODULE),
+    ])
+    expected = [
+        LoRAModulePath(name='module2',
+                       path='/path/to/module2',
+                       base_model_name='llama')
+    ]
+    assert args.lora_modules == expected
+
+
+def test_invalid_json_format(serve_parser):
+    # Test invalid JSON format input, missing closing brace
+    with pytest.raises(SystemExit):
+        serve_parser.parse_args([
+            '--lora-modules', '{"name": "module3", "path": "/path/to/module3"'
         ])
-        expected = [LoRAModulePath(name='module1', path='/path/to/module1')]
-        self.assertEqual(args.lora_modules, expected)
 
-    def test_valid_json_format(self):
-        # Test valid JSON format input
-        args = self.parser.parse_args([
+
+def test_invalid_type_error(serve_parser):
+    # Test type error when values are not JSON or key=value
+    with pytest.raises(SystemExit):
+        serve_parser.parse_args([
             '--lora-modules',
-            json.dumps(LORA_MODULE),
+            'invalid_format'  # This is not JSON or key=value format
         ])
-        expected = [
-            LoRAModulePath(name='module2',
-                           path='/path/to/module2',
-                           base_model_name='llama')
-        ]
-        self.assertEqual(args.lora_modules, expected)
-
-    def test_invalid_json_format(self):
-        # Test invalid JSON format input, missing closing brace
-        with self.assertRaises(SystemExit):
-            self.parser.parse_args([
-                '--lora-modules',
-                '{"name": "module3", "path": "/path/to/module3"'
-            ])
-
-    def test_invalid_type_error(self):
-        # Test type error when values are not JSON or key=value
-        with self.assertRaises(SystemExit):
-            self.parser.parse_args([
-                '--lora-modules',
-                'invalid_format'  # This is not JSON or key=value format
-            ])
-
-    def test_invalid_json_field(self):
-        # Test valid JSON format but missing required fields
-        with self.assertRaises(SystemExit):
-            self.parser.parse_args([
-                '--lora-modules',
-                '{"name": "module4"}'  # Missing required 'path' field
-            ])
-
-    def test_empty_values(self):
-        # Test when no LoRA modules are provided
-        args = self.parser.parse_args(['--lora-modules', ''])
-        self.assertEqual(args.lora_modules, [])
-
-    def test_multiple_valid_inputs(self):
-        # Test multiple valid inputs (both old and JSON format)
-        args = self.parser.parse_args([
+
+
+def test_invalid_json_field(serve_parser):
+    # Test valid JSON format but missing required fields
+    with pytest.raises(SystemExit):
+        serve_parser.parse_args([
             '--lora-modules',
-            'module1=/path/to/module1',
-            json.dumps(LORA_MODULE),
+            '{"name": "module4"}'  # Missing required 'path' field
         ])
-        expected = [
-            LoRAModulePath(name='module1', path='/path/to/module1'),
-            LoRAModulePath(name='module2',
-                           path='/path/to/module2',
-                           base_model_name='llama')
-        ]
-        self.assertEqual(args.lora_modules, expected)
 
-if __name__ == '__main__':
-    unittest.main()
+def test_empty_values(serve_parser):
+    # Test when no LoRA modules are provided
+    args = serve_parser.parse_args(['--lora-modules', ''])
+    assert args.lora_modules == []
+
+
+def test_multiple_valid_inputs(serve_parser):
+    # Test multiple valid inputs (both old and JSON format)
+    args = serve_parser.parse_args([
+        '--lora-modules',
+        'module1=/path/to/module1',
+        json.dumps(LORA_MODULE),
+    ])
+    expected = [
+        LoRAModulePath(name='module1', path='/path/to/module1'),
+        LoRAModulePath(name='module2',
+                       path='/path/to/module2',
+                       base_model_name='llama')
+    ]
+    assert args.lora_modules == expected
+
+
+### Tests for serve argument validation that run prior to loading
+def test_enable_auto_choice_fails_without_tool_call_parser(serve_parser):
+    """Ensure validation fails if tool choice is enabled with no call parser"""
+    # If we enable-auto-tool-choice, explode with no tool-call-parser
+    args = serve_parser.parse_args(args=["--enable-auto-tool-choice"])
+    with pytest.raises(TypeError):
+        validate_parsed_serve_args(args)
+
+
+def test_enable_auto_choice_passes_with_tool_call_parser(serve_parser):
+    """Ensure validation passes with tool choice enabled with a call parser"""
+    args = serve_parser.parse_args(args=[
+        "--enable-auto-tool-choice",
+        "--tool-call-parser",
+        "mistral",
+    ])
+    validate_parsed_serve_args(args)
+
+
+def test_chat_template_validation_for_happy_paths(serve_parser):
+    """Ensure validation passes if the chat template exists"""
+    args = serve_parser.parse_args(
+        args=["--chat-template",
+              CHATML_JINJA_PATH.absolute().as_posix()])
+    validate_parsed_serve_args(args)
+
+
+def test_chat_template_validation_for_sad_paths(serve_parser):
+    """Ensure validation fails if the chat template doesn't exist"""
+    args = serve_parser.parse_args(args=["--chat-template", "does/not/exist"])
+    with pytest.raises(ValueError):
+        validate_parsed_serve_args(args)
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 83c4062dd5112..1b82b454aa38d 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -303,6 +303,29 @@ def parse_audio(self, audio_url: str) -> None:
         self._add_placeholder(placeholder)
 
 
+def validate_chat_template(chat_template: Optional[Union[Path, str]]):
+    """Raises if the provided chat template appears invalid."""
+    if chat_template is None:
+        return
+
+    elif isinstance(chat_template, Path):
+        if not chat_template.exists():
+            raise FileNotFoundError(
+                "The supplied chat template path doesn't exist")
+
+    elif isinstance(chat_template, str):
+        JINJA_CHARS = "{}\n"
+        if not any(c in chat_template
+                   for c in JINJA_CHARS) and not Path(chat_template).exists():
+            raise ValueError(
+                f"The supplied chat template string ({chat_template}) "
+                f"appears path-like, but doesn't exist!")
+
+    else:
+        raise TypeError(
+            f"{type(chat_template)} is not a valid chat template type")
+
+
 def load_chat_template(
         chat_template: Optional[Union[Path, str]]) -> Optional[str]:
     if chat_template is None:
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index bf367482cd80c..cda1601549e9e 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -31,7 +31,8 @@
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
-from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.entrypoints.openai.cli_args import (make_arg_parser,
+                                              validate_parsed_serve_args)
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -577,5 +578,6 @@ def signal_handler(*_) -> None:
         description="vLLM OpenAI-Compatible RESTful API server.")
     parser = make_arg_parser(parser)
     args = parser.parse_args()
+    validate_parsed_serve_args(args)
 
     uvloop.run(run_server(args))
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index f59ba4e30accd..a089985ac9758 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -10,6 +10,7 @@
 from typing import List, Optional, Sequence, Union
 
 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
+from vllm.entrypoints.chat_utils import validate_chat_template
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     PromptAdapterPath)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -231,6 +232,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     return parser
 
 
+def validate_parsed_serve_args(args: argparse.Namespace):
+    """Quick checks for model serve args that raise prior to loading."""
+    if hasattr(args, "subparser") and args.subparser != "serve":
+        return
+
+    # Ensure that the chat template is valid; raises if it likely isn't
+    validate_chat_template(args.chat_template)
+
+    # Enable auto tool needs a tool call parser to be valid
+    if args.enable_auto_tool_choice and not args.tool_call_parser:
+        raise TypeError("Error: --enable-auto-tool-choice requires "
+                        "--tool-call-parser")
+
+
 def create_parser_for_docs() -> FlexibleArgumentParser:
     parser_for_docs = FlexibleArgumentParser(
         prog="-m vllm.entrypoints.openai.api_server")
diff --git a/vllm/scripts.py b/vllm/scripts.py
index 7f2ba62695d3e..4e4c071784287 100644
--- a/vllm/scripts.py
+++ b/vllm/scripts.py
@@ -11,7 +11,8 @@
 
 from vllm.engine.arg_utils import EngineArgs
 from vllm.entrypoints.openai.api_server import run_server
-from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.entrypoints.openai.cli_args import (make_arg_parser,
+                                              validate_parsed_serve_args)
 from vllm.logger import init_logger
 from vllm.utils import FlexibleArgumentParser
 
@@ -142,7 +143,7 @@ def main():
     env_setup()
 
     parser = FlexibleArgumentParser(description="vLLM CLI")
-    subparsers = parser.add_subparsers(required=True)
+    subparsers = parser.add_subparsers(required=True, dest="subparser")
 
     serve_parser = subparsers.add_parser(
         "serve",
@@ -186,6 +187,9 @@
     chat_parser.set_defaults(dispatch_function=interactive_cli,
                              command="chat")
     args = parser.parse_args()
+    if args.subparser == "serve":
+        validate_parsed_serve_args(args)
+
     # One of the sub commands should be executed.
     if hasattr(args, "dispatch_function"):
         args.dispatch_function(args)
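
Reviewer note: the behavior added here can be exercised without loading any
model weights. Below is a minimal sketch (not part of the patch) assuming a
vLLM checkout with this change applied, so that make_arg_parser and
validate_parsed_serve_args resolve exactly as in the diff above:

    # Illustrative only: shows the new fail-fast path for bad serve args.
    from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                                  validate_parsed_serve_args)
    from vllm.utils import FlexibleArgumentParser

    parser = make_arg_parser(FlexibleArgumentParser(description="demo"))

    # --enable-auto-tool-choice without --tool-call-parser is rejected
    # at argument-validation time, before any engine initialization.
    args = parser.parse_args(["--enable-auto-tool-choice"])
    try:
        validate_parsed_serve_args(args)
    except TypeError as e:
        print(f"rejected early: {e}")

    # A chat template that looks like a path but doesn't exist is also
    # rejected up front (validate_chat_template raises ValueError here).
    args = parser.parse_args(["--chat-template", "does/not/exist"])
    try:
        validate_parsed_serve_args(args)
    except ValueError as e:
        print(f"rejected early: {e}")

The point of the patch is that both misconfigurations are now caught at
parse time, rather than surfacing only after the server has started loading.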