4 changes: 0 additions & 4 deletions vllm/platforms/cpu.py
@@ -180,7 +180,3 @@ def get_device_communicator_cls(cls) -> str:
         Get device specific communicator class for distributed communication.
         """
         return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator" # noqa
-
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
4 changes: 0 additions & 4 deletions vllm/platforms/cuda.py
@@ -308,10 +308,6 @@ def supports_fp8(cls) -> bool:
     def supports_v1(cls, model_config: ModelConfig) -> bool:
         return True
 
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
-
     @classmethod
     def use_custom_allreduce(cls) -> bool:
         return True
4 changes: 0 additions & 4 deletions vllm/platforms/hpu.py
@@ -92,7 +92,3 @@ def get_punica_wrapper(cls) -> str:
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator" # noqa
-
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
23 changes: 15 additions & 8 deletions vllm/platforms/interface.py
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-
 import enum
 import platform
 import random
@@ -9,14 +8,21 @@
 import numpy as np
 import torch
 
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
+    from vllm.lora.request import LoRARequest
+    from vllm.pooling_params import PoolingParams
+    from vllm.sampling_params import SamplingParams
     from vllm.utils import FlexibleArgumentParser
 else:
     ModelConfig = None
     VllmConfig = None
+    LoRARequest = None
+    PoolingParams = None
+    SamplingParams = None
     FlexibleArgumentParser = None
 
 logger = init_logger(__name__)
@@ -379,20 +385,21 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
         """
         return False
 
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        """
-        Returns whether the current platform can support structured output.
-        """
-        return False
-
     @classmethod
     def use_custom_allreduce(cls) -> bool:
         """
        Returns if custom allreduce is supported on the current platform
         """
         return False
 
+    @classmethod
+    def validate_request(
+        cls,
+        prompt: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+    ) -> None:
+        """Raises if this request is unsupported on this platform"""
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
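The replacement hook is deliberately a no-op in the base class, so callers can invoke it unconditionally and only platforms that override it can reject a request. A minimal sketch of the calling convention (a plain string is a valid `PromptType`; the prompt and params below are illustrative):

```python
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams

params = SamplingParams()  # no structured output requested

# No-op on platforms that keep the base-class implementation; platforms
# that override validate_request() raise ValueError for unsupported requests.
current_platform.validate_request(prompt="Hello", params=params)
```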
4 changes: 0 additions & 4 deletions vllm/platforms/neuron.py
@@ -67,7 +67,3 @@ def get_device_communicator_cls(cls) -> str:
     @classmethod
     def use_all_gather(cls) -> bool:
         return True
-
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
4 changes: 0 additions & 4 deletions vllm/platforms/rocm.py
@@ -303,10 +303,6 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
         # V1 support on AMD gpus is experimental
         return True
 
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
-
     @classmethod
     def use_custom_allreduce(cls) -> bool:
         # We only enable custom allreduce for MI300 series
22 changes: 18 additions & 4 deletions vllm/platforms/tpu.py
@@ -1,19 +1,26 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 
 import vllm.envs as envs
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 
 from .interface import Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
+    from vllm.lora.request import LoRARequest
+    from vllm.pooling_params import PoolingParams
+    from vllm.sampling_params import SamplingParams
 else:
     ModelConfig = None
     VllmConfig = None
+    LoRARequest = None
+    PoolingParams = None
+    SamplingParams = None
 
 logger = init_logger(__name__)

@@ -135,6 +142,13 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
         return True
 
     @classmethod
-    def supports_structured_output(cls) -> bool:
-        # Structured output is not supported on TPU.
-        return False
+    def validate_request(
+        cls,
+        prompt: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+    ) -> None:
+        """Raises if this request is unsupported on this platform"""
+        if isinstance(params,
+                      SamplingParams) and params.guided_decoding is not None:
+            raise ValueError("Structured output is not supported on "
+                             f"{cls.device_name}.")
4 changes: 0 additions & 4 deletions vllm/platforms/xpu.py
@@ -140,7 +140,3 @@ def device_support_bf16(cls) -> bool:
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa
-
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
10 changes: 5 additions & 5 deletions vllm/v1/engine/processor.py
@@ -137,11 +137,6 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
         else:
             params.guided_decoding.backend = engine_level_backend
 
-        from vllm.platforms import current_platform
-        if not current_platform.supports_structured_output():
-            raise ValueError("Structured output is not supported on "
-                             f"{current_platform.device_name}.")
-
         # Request content validation
         if engine_level_backend.startswith("xgrammar"):
             # xgrammar with no fallback
@@ -183,6 +178,11 @@ def process_inputs(
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.
 
+        from vllm.platforms import current_platform
+        current_platform.validate_request(
+            prompt=prompt,
+            params=params,
+        )
         self._validate_lora(lora_request)
         self._validate_params(params)
         if priority != 0:

Review thread on the `current_platform.validate_request(...)` call:

Member: Wondering whether we should remove the call to `supports_structured_output` and have the default impl of `validate_request` call that instead. Actually, maybe we could remove the `supports_structured_output` interface method and have `validate_request` only call it if it exists in the same class?

Collaborator (Author): Sure. Maybe the simplest thing to do is to just add an impl for the TPU backend and have it reject structured output requests?

Collaborator (Author): @njhill I went with the 🔥🔥🔥 option, WDYT? The only difference in behavior now should be that all out-of-tree platforms will need to explicitly reject structured output in `validate_request` instead of inheriting the default impl of `supports_structured_output`.
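As the author's last comment notes, an out-of-tree platform that cannot serve structured output must now reject it explicitly in `validate_request` rather than inheriting a capability default. A hypothetical sketch of such an override (the `MyOOTPlatform` class and its `device_name` are illustrative, not part of this PR; `PlatformEnum.OOT` is vLLM's marker for out-of-tree platforms):

```python
from typing import Union

from vllm.inputs import PromptType
from vllm.platforms.interface import Platform, PlatformEnum
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams


class MyOOTPlatform(Platform):
    """Hypothetical out-of-tree platform without structured output support."""
    _enum = PlatformEnum.OOT
    device_name = "my_accelerator"

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
    ) -> None:
        # Before this PR the processor consulted supports_structured_output(),
        # whose default returned False; now the platform itself must raise to
        # reject structured-output requests.
        if (isinstance(params, SamplingParams)
                and params.guided_decoding is not None):
            raise ValueError(
                f"Structured output is not supported on {cls.device_name}.")
```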