@@ -4,10 +4,12 @@
 
 import json
 import re
+from enum import Enum
 from typing import Any
 
 import jsonschema
 import pytest
+from pydantic import BaseModel
 
 from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
@@ -390,3 +392,54 @@ def test_guided_choice_completion(
     assert generated_text is not None
     assert generated_text in sample_guided_choice
     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+
+class CarType(str, Enum):
+    sedan = "sedan"
+    suv = "SUV"
+    truck = "Truck"
+    coupe = "Coupe"
+
+
+class CarDescription(BaseModel):
+    brand: str
+    model: str
+    car_type: CarType
+
+
+@pytest.mark.skip_global_cleanup
+@pytest.mark.parametrize("guided_decoding_backend",
+                         GUIDED_DECODING_BACKENDS_V1)
+@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
+def test_guided_json_completion_with_enum(
+    monkeypatch: pytest.MonkeyPatch,
+    guided_decoding_backend: str,
+    model_name: str,
+):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    json_schema = CarDescription.model_json_schema()
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=json_schema))
+    outputs = llm.generate(
+        prompts="Generate a JSON with the brand, model and car_type of "
+        "the most iconic car from the 90's",
+        sampling_params=sampling_params,
+        use_tqdm=True)
+
+    assert outputs is not None
+
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        prompt = output.prompt
+
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        output_json = json.loads(generated_text)
+        jsonschema.validate(instance=output_json, schema=json_schema)
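
To make the test's expectations concrete, here is a minimal standalone sketch (not part of the diff) of the schema that CarDescription.model_json_schema() emits and of the check the final jsonschema.validate call performs. It assumes pydantic v2 and the jsonschema package; the car values ("Toyota", "Supra", "hatchback") are illustrative only.

import json
from enum import Enum

import jsonschema
from pydantic import BaseModel


class CarType(str, Enum):
    sedan = "sedan"
    suv = "SUV"
    truck = "Truck"
    coupe = "Coupe"


class CarDescription(BaseModel):
    brand: str
    model: str
    car_type: CarType


schema = CarDescription.model_json_schema()
# Under pydantic v2 the schema comes out roughly as:
# {
#     "$defs": {"CarType": {"enum": ["sedan", "SUV", "Truck", "Coupe"],
#                           "title": "CarType", "type": "string"}},
#     "properties": {"brand": {"title": "Brand", "type": "string"},
#                    "model": {"title": "Model", "type": "string"},
#                    "car_type": {"$ref": "#/$defs/CarType"}},
#     "required": ["brand", "model", "car_type"],
#     "title": "CarDescription",
#     "type": "object",
# }

# A completion that sticks to the enum passes the test's final assertion.
good = json.loads('{"brand": "Toyota", "model": "Supra", "car_type": "sedan"}')
jsonschema.validate(instance=good, schema=schema)

# A car_type outside the enum raises jsonschema.ValidationError, so the
# test would fail if the backend ever produced such a value.
bad = {"brand": "Toyota", "model": "Supra", "car_type": "hatchback"}
try:
    jsonschema.validate(instance=bad, schema=schema)
except jsonschema.ValidationError as exc:
    print(f"rejected as expected: {exc.message}")

The point of routing the pydantic schema through GuidedDecodingParams(json=...) is that the backend constrains token selection to the schema, so car_type can only ever be one of the four enum literals; the jsonschema check then confirms end to end that the constraint held.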