diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py
index c9fa03a1ae1f..a32dd8263992 100644
--- a/tests/v1/entrypoints/llm/test_struct_output_generate.py
+++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -23,20 +23,46 @@
 ]
 
 
+class CarType(str, Enum):
+    sedan = "sedan"
+    suv = "SUV"
+    truck = "Truck"
+    coupe = "Coupe"
+
+
+class CarDescription(BaseModel):
+    brand: str
+    model: str
+    car_type: CarType
+
+
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
 @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_completion(
+def test_structured_output(
     monkeypatch: pytest.MonkeyPatch,
     sample_json_schema: dict[str, Any],
+    unsupported_json_schema: dict[str, Any],
+    sample_sql_ebnf: str,
+    sample_sql_lark: str,
+    sample_regex: str,
+    sample_guided_choice: str,
     guided_decoding_backend: str,
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
+
+    # Use a single LLM instance for several scenarios to
+    # speed up the test suite.
     llm = LLM(model=model_name,
+              enforce_eager=True,
               max_model_len=1024,
               guided_decoding_backend=guided_decoding_backend)
+
+    #
+    # Test 1: Generate JSON output based on a provided schema
+    #
     sampling_params = SamplingParams(
         temperature=1.0,
         max_tokens=1000,
@@ -63,20 +89,9 @@ def test_guided_json_completion(
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=sample_json_schema)
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_object(
-    monkeypatch: pytest.MonkeyPatch,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 2: Generate JSON object without a schema
+    #
     sampling_params = SamplingParams(
         temperature=1.0,
         max_tokens=100,
@@ -111,21 +126,9 @@ def test_guided_json_object(
             allowed_types = (dict, list)
         assert isinstance(parsed_json, allowed_types)
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1 + ["auto"])
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_unsupported_schema(
-    monkeypatch: pytest.MonkeyPatch,
-    unsupported_json_schema: dict[str, Any],
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 3: Test a JSON schema incompatible with xgrammar
+    #
     sampling_params = SamplingParams(
         temperature=1.0,
         max_tokens=1000,
@@ -141,8 +144,6 @@ def test_guided_json_unsupported_schema(
                          sampling_params=sampling_params,
                          use_tqdm=True)
     else:
-        # This should work for both "guidance" and "auto".
-
         outputs = llm.generate(
             prompts=("Give an example JSON object for a grade "
                      "that fits this schema: "
@@ -161,21 +162,9 @@ def test_guided_json_unsupported_schema(
             parsed_json = json.loads(generated_text)
             assert isinstance(parsed_json, dict)
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_grammar_ebnf(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_sql_ebnf: str,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 4: Generate SQL statement using EBNF grammar
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -205,21 +194,9 @@ def test_guided_grammar_ebnf(
 
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_grammar_lark(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_sql_lark: str,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 5: Generate SQL statement using Lark grammar
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -254,20 +231,9 @@ def test_guided_grammar_lark(
 
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_grammar_ebnf_invalid(
-    monkeypatch: pytest.MonkeyPatch,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 6: Test invalid grammar input
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -281,21 +247,9 @@ def test_guided_grammar_ebnf_invalid(
             use_tqdm=True,
         )
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_regex(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_regex: str,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 7: Generate text based on a regex pattern
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -319,21 +273,9 @@ def test_guided_regex(
         assert re.fullmatch(sample_regex, generated_text) is not None
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_choice_completion(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_guided_choice: str,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 8: Generate text based on a set of choices
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -353,33 +295,9 @@ def test_guided_choice_completion(
         assert generated_text in sample_guided_choice
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
-
-class CarType(str, Enum):
-    sedan = "sedan"
-    suv = "SUV"
-    truck = "Truck"
-    coupe = "Coupe"
-
-
-class CarDescription(BaseModel):
-    brand: str
-    model: str
-    car_type: CarType
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_completion_with_enum(
-    monkeypatch: pytest.MonkeyPatch,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 9: Generate structured output using a Pydantic model with an enum
+    #
     json_schema = CarDescription.model_json_schema()
     sampling_params = SamplingParams(
         temperature=1.0,
@@ -403,3 +321,41 @@ def test_guided_json_completion_with_enum(
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=json_schema)
+
+
+@pytest.mark.skip_global_cleanup
+@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
+def test_structured_output_auto_mode(
+    monkeypatch: pytest.MonkeyPatch,
+    unsupported_json_schema: dict[str, Any],
+    model_name: str,
+):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend="auto")
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
+
+    # This would fail with the default of "xgrammar", but in "auto"
+    # we will handle fallback automatically.
+    outputs = llm.generate(prompts=("Give an example JSON object for a grade "
+                                    "that fits this schema: "
+                                    f"{unsupported_json_schema}"),
+                           sampling_params=sampling_params,
+                           use_tqdm=True)
+    assert outputs is not None
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        print(generated_text)
+
+        # Parse to verify it is valid JSON
+        parsed_json = json.loads(generated_text)
+        assert isinstance(parsed_json, dict)
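
Note for reviewers: since the patch only shows the consolidated tests, here is a minimal standalone sketch of the structured-output flow they exercise. This is illustrative and not part of the change; the model name and schema are placeholders, while `GuidedDecodingParams`, the `guided_decoding_backend` option, and the `VLLM_USE_V1` environment variable all mirror the test file above.

```python
# Illustrative sketch only -- mirrors the API used in the tests above.
# The model name and schema below are placeholders, not from the diff.
import json
import os

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

os.environ["VLLM_USE_V1"] = "1"  # the tests set this via monkeypatch

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
          max_model_len=1024,
          guided_decoding_backend="auto")  # "auto" falls back when the
                                           # default backend rejects a schema

person_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}

sampling_params = SamplingParams(
    temperature=1.0,
    max_tokens=100,
    guided_decoding=GuidedDecodingParams(json=person_schema))

outputs = llm.generate(
    prompts="Give an example JSON object describing a person: ",
    sampling_params=sampling_params)

# Constrained decoding should yield text that parses as JSON
# conforming to person_schema.
print(json.loads(outputs[0].outputs[0].text))
```

The diff's other speed lever is the same idea at test scope: all nine scenarios reuse a single `LLM` instance (with `enforce_eager=True`) instead of paying engine start-up once per test function.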