From 8f5c2cec184db714925f50e50f9d1c25829af16c Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Mon, 7 Apr 2025 02:38:36 +0000 Subject: [PATCH 1/2] add supports_structured_output() method to platform Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm/platforms/cpu.py | 4 ++++ vllm/platforms/cuda.py | 4 ++++ vllm/platforms/hpu.py | 4 ++++ vllm/platforms/interface.py | 7 +++++++ vllm/platforms/neuron.py | 4 ++++ vllm/platforms/rocm.py | 4 ++++ vllm/platforms/tpu.py | 5 +++++ vllm/platforms/xpu.py | 4 ++++ vllm/v1/engine/processor.py | 7 ++++--- 9 files changed, 40 insertions(+), 3 deletions(-) diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 67466bdb9807..cfd7bc2a4057 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -180,3 +180,7 @@ def get_device_communicator_cls(cls) -> str: Get device specific communicator class for distributed communication. """ return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator" # noqa + + @classmethod + def supports_structured_output(cls) -> bool: + return True diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 0576022be448..053cf74ebceb 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -308,6 +308,10 @@ def supports_fp8(cls) -> bool: def supports_v1(cls, model_config: ModelConfig) -> bool: return True + @classmethod + def supports_structured_output(cls) -> bool: + return True + @classmethod def use_custom_allreduce(cls) -> bool: return True diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index 4c842b525110..f011f14029a3 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -92,3 +92,7 @@ def get_punica_wrapper(cls) -> str: @classmethod def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator" # noqa + + @classmethod + def supports_structured_output(cls) -> bool: + return True diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index b6f6029de9c8..2bb543bd73f7 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -379,6 +379,13 @@ def supports_v1(cls, model_config: ModelConfig) -> bool: """ return False + @classmethod + def supports_structured_output(cls) -> bool: + """ + Returns whether the current platform can support structured output. + """ + return False + @classmethod def use_custom_allreduce(cls) -> bool: """ diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index c1f426e5b880..93657881cbdd 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -67,3 +67,7 @@ def get_device_communicator_cls(cls) -> str: @classmethod def use_all_gather(cls) -> bool: return True + + @classmethod + def supports_structured_output(cls) -> bool: + return True diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index d18b7c26f7ec..a2fbf416ecf2 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -303,6 +303,10 @@ def supports_v1(cls, model_config: ModelConfig) -> bool: # V1 support on AMD gpus is experimental return True + @classmethod + def supports_structured_output(cls) -> bool: + return True + @classmethod def use_custom_allreduce(cls) -> bool: # We only enable custom allreduce for MI300 series diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 43d3044cb93e..9b4156529f08 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -133,3 +133,8 @@ def use_all_gather(cls) -> bool: def supports_v1(cls, model_config: ModelConfig) -> bool: # V1 support on TPU is experimental return True + + @classmethod + def supports_structured_output(cls) -> bool: + logger.warning("Structured output is not supported on TPU.") + return False diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 225e756cd7ce..c4bd639384a4 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -140,3 +140,7 @@ def device_support_bf16(cls) -> bool: @classmethod def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa + + @classmethod + def supports_structured_output(cls) -> bool: + return True diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 0d2892837eb2..c914ccad52e1 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -136,9 +136,10 @@ def _validate_structured_output(self, params: SamplingParams) -> None: f" != {engine_level_backend}") else: params.guided_decoding.backend = engine_level_backend - import vllm.platforms - if vllm.platforms.current_platform.is_tpu(): - raise ValueError("Structured output is not supported on TPU.") + + from vllm.platforms import current_platform + if not current_platform.supports_structured_output(): + return # Request content validation if engine_level_backend.startswith("xgrammar"): From cbb079be1bedc25fe0cfa876d5bae31cf80caef4 Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Mon, 7 Apr 2025 02:57:06 +0000 Subject: [PATCH 2/2] update Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm/platforms/tpu.py | 2 +- vllm/v1/engine/processor.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 9b4156529f08..eeadb4a71e5e 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -136,5 +136,5 @@ def supports_v1(cls, model_config: ModelConfig) -> bool: @classmethod def supports_structured_output(cls) -> bool: - logger.warning("Structured output is not supported on TPU.") + # Structured output is not supported on TPU. return False diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index c914ccad52e1..403edddfcbee 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -139,7 +139,8 @@ def _validate_structured_output(self, params: SamplingParams) -> None: from vllm.platforms import current_platform if not current_platform.supports_structured_output(): - return + raise ValueError("Structured output is not supported on " + f"{current_platform.device_name}.") # Request content validation if engine_level_backend.startswith("xgrammar"):