
Commit f03dd93

shen-shanshan authored and yangw-dev committed
[V1][Structured Output] Add supports_structured_output() method to Platform (vllm-project#16148)
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Yang Wang <elainewy@meta.com>
1 parent de4c7ec commit f03dd93

9 files changed: 41 additions (+41), 3 deletions (-3)


vllm/platforms/cpu.py

Lines changed: 4 additions & 0 deletions
@@ -180,3 +180,7 @@ def get_device_communicator_cls(cls) -> str:
         Get device specific communicator class for distributed communication.
         """
         return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa
+
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        return True

vllm/platforms/cuda.py

Lines changed: 4 additions & 0 deletions
@@ -308,6 +308,10 @@ def supports_fp8(cls) -> bool:
     def supports_v1(cls, model_config: ModelConfig) -> bool:
         return True
 
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        return True
+
     @classmethod
     def use_custom_allreduce(cls) -> bool:
         return True

vllm/platforms/hpu.py

Lines changed: 4 additions & 0 deletions
@@ -92,3 +92,7 @@ def get_punica_wrapper(cls) -> str:
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator"  # noqa
+
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        return True

vllm/platforms/interface.py

Lines changed: 7 additions & 0 deletions
@@ -379,6 +379,13 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
         """
         return False
 
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        """
+        Returns whether the current platform can support structured output.
+        """
+        return False
+
     @classmethod
     def use_custom_allreduce(cls) -> bool:
         """

vllm/platforms/neuron.py

Lines changed: 4 additions & 0 deletions
@@ -67,3 +67,7 @@ def get_device_communicator_cls(cls) -> str:
     @classmethod
     def use_all_gather(cls) -> bool:
         return True
+
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        return True

vllm/platforms/rocm.py

Lines changed: 4 additions & 0 deletions
@@ -303,6 +303,10 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
         # V1 support on AMD gpus is experimental
         return True
 
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        return True
+
     @classmethod
     def use_custom_allreduce(cls) -> bool:
         # We only enable custom allreduce for MI300 series

vllm/platforms/tpu.py

Lines changed: 5 additions & 0 deletions
@@ -133,3 +133,8 @@ def use_all_gather(cls) -> bool:
     def supports_v1(cls, model_config: ModelConfig) -> bool:
         # V1 support on TPU is experimental
         return True
+
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        # Structured output is not supported on TPU.
+        return False

vllm/platforms/xpu.py

Lines changed: 4 additions & 0 deletions
@@ -140,3 +140,7 @@ def device_support_bf16(cls) -> bool:
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa
+
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        return True

vllm/v1/engine/processor.py

Lines changed: 5 additions & 3 deletions
@@ -136,9 +136,11 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
                     f" != {engine_level_backend}")
             else:
                 params.guided_decoding.backend = engine_level_backend
-        import vllm.platforms
-        if vllm.platforms.current_platform.is_tpu():
-            raise ValueError("Structured output is not supported on TPU.")
+
+        from vllm.platforms import current_platform
+        if not current_platform.supports_structured_output():
+            raise ValueError("Structured output is not supported on "
+                             f"{current_platform.device_name}.")
 
         # Request content validation
         if engine_level_backend.startswith("xgrammar"):
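With this change the V1 processor no longer hard-codes a TPU check; any platform whose supports_structured_output() returns False is rejected with the same error message. A minimal sketch of the resulting check pattern, reusable at other call sites (the helper name is hypothetical; current_platform, supports_structured_output(), and device_name are the names used in the diff above):

# Minimal sketch, assuming a vLLM install; only names shown in the diff are used.
from vllm.platforms import current_platform


def ensure_structured_output_supported() -> None:  # hypothetical helper
    # Raise the same error the V1 processor raises for unsupported platforms.
    if not current_platform.supports_structured_output():
        raise ValueError("Structured output is not supported on "
                         f"{current_platform.device_name}.")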

0 commit comments
