Skip to content

Commit 2999d38

Browse files
hmellorxuebwang-amd
authored andcommitted
Add backward compatibility for GuidedDecodingParams (vllm-project#25422)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 860ecfd commit 2999d38

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

tests/v1/entrypoints/llm/test_struct_output_generate.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import json
8+
from dataclasses import fields
89
from enum import Enum
910
from typing import TYPE_CHECKING, Any
1011

@@ -21,7 +22,8 @@
2122
from vllm.outputs import RequestOutput
2223
from vllm.platforms import current_platform
2324
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
24-
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
25+
from vllm.sampling_params import (GuidedDecodingParams, SamplingParams,
26+
StructuredOutputsParams)
2527

2628
if TYPE_CHECKING:
2729
from vllm.config import TokenizerMode
@@ -89,6 +91,26 @@ def _load_json(s: str, backend: str) -> str:
8991
return json.loads(s)
9092

9193

94+
def test_guided_decoding_deprecated():
95+
with pytest.warns(DeprecationWarning,
96+
match="GuidedDecodingParams is deprecated.*"):
97+
guided_decoding = GuidedDecodingParams(json_object=True)
98+
99+
structured_outputs = StructuredOutputsParams(json_object=True)
100+
assert fields(guided_decoding) == fields(structured_outputs)
101+
102+
with pytest.warns(DeprecationWarning,
103+
match="guided_decoding is deprecated.*"):
104+
sp1 = SamplingParams(guided_decoding=guided_decoding)
105+
106+
with pytest.warns(DeprecationWarning,
107+
match="guided_decoding is deprecated.*"):
108+
sp2 = SamplingParams.from_optional(guided_decoding=guided_decoding)
109+
110+
assert sp1 == sp2
111+
assert sp1.structured_outputs == guided_decoding
112+
113+
92114
@pytest.mark.skip_global_cleanup
93115
@pytest.mark.parametrize(
94116
"model_name, backend, tokenizer_mode, speculative_config",

vllm/sampling_params.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
"""Sampling parameters for text generation."""
44
import copy
5+
import warnings
56
from dataclasses import field
67
from enum import Enum, IntEnum
78
from functools import cached_property
@@ -59,6 +60,19 @@ def __post_init__(self):
5960
f"but multiple are specified: {self.__dict__}")
6061

6162

63+
@dataclass
64+
class GuidedDecodingParams(StructuredOutputsParams):
65+
66+
def __post_init__(self):
67+
warnings.warn(
68+
"GuidedDecodingParams is deprecated. This will be removed in "
69+
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
70+
"StructuredOutputsParams instead.",
71+
DeprecationWarning,
72+
stacklevel=2)
73+
return super().__post_init__()
74+
75+
6276
class RequestOutputKind(Enum):
6377
# Return entire output so far in every RequestOutput
6478
CUMULATIVE = 0
@@ -179,6 +193,8 @@ class SamplingParams(
179193
# Fields used to construct logits processors
180194
structured_outputs: Optional[StructuredOutputsParams] = None
181195
"""Parameters for configuring structured outputs."""
196+
guided_decoding: Optional[GuidedDecodingParams] = None
197+
"""Deprecated alias for structured_outputs."""
182198
logit_bias: Optional[dict[int, float]] = None
183199
"""If provided, the engine will construct a logits processor that applies
184200
these logit biases."""
@@ -227,6 +243,7 @@ def from_optional(
227243
ge=-1)]] = None,
228244
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
229245
structured_outputs: Optional[StructuredOutputsParams] = None,
246+
guided_decoding: Optional[GuidedDecodingParams] = None,
230247
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
231248
allowed_token_ids: Optional[list[int]] = None,
232249
extra_args: Optional[dict[str, Any]] = None,
@@ -238,6 +255,15 @@ def from_optional(
238255
int(token): min(100.0, max(-100.0, bias))
239256
for token, bias in logit_bias.items()
240257
}
258+
if guided_decoding is not None:
259+
warnings.warn(
260+
"guided_decoding is deprecated. This will be removed in "
261+
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
262+
"structured_outputs instead.",
263+
DeprecationWarning,
264+
stacklevel=2)
265+
structured_outputs = guided_decoding
266+
guided_decoding = None
241267

242268
return SamplingParams(
243269
n=1 if n is None else n,
@@ -334,6 +360,16 @@ def __post_init__(self) -> None:
334360
# eos_token_id is added to this by the engine
335361
self._all_stop_token_ids.update(self.stop_token_ids)
336362

363+
if self.guided_decoding is not None:
364+
warnings.warn(
365+
"guided_decoding is deprecated. This will be removed in "
366+
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
367+
"structured_outputs instead.",
368+
DeprecationWarning,
369+
stacklevel=2)
370+
self.structured_outputs = self.guided_decoding
371+
self.guided_decoding = None
372+
337373
def _verify_args(self) -> None:
338374
if not isinstance(self.n, int):
339375
raise ValueError(f"n must be an int, but is of "

0 commit comments

Comments
 (0)