5 changes: 5 additions & 0 deletions tests/v1/tpu/test_sampler.py
@@ -34,3 +34,8 @@ def test_sampler_different(model_name: str):
     sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
     output2 = llm.generate(prompts, sampling_params)
     assert output[0].outputs[0].text != output2[0].outputs[0].text
+
+    with pytest.raises(ValueError):
+        # Unsupported `seed` param.
+        sampling_params = SamplingParams(temperature=0.3, seed=42)
+        output2 = llm.generate(prompts, sampling_params)
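The new assertions exercise the platform-level rejection added in vllm/platforms/tpu.py below: building a SamplingParams with `seed` and sending it through `llm.generate` on a TPU backend now raises ValueError. As a hedged illustration (not part of the PR), calling code could strip the unsupported field up front; the helper name and the `llm`/`prompts` placeholders here are hypothetical:

```python
# Illustrative sketch only: drop the unsupported `seed` field before building
# SamplingParams for a TPU backend, falling back to unseeded sampling instead
# of letting the request fail platform validation.
from vllm import SamplingParams

def tpu_safe_params(**kwargs) -> SamplingParams:
    kwargs.pop("seed", None)  # per-request seeds raise ValueError on TPU
    return SamplingParams(**kwargs)

# Hypothetical usage, mirroring the test above:
# sampling_params = tpu_safe_params(temperature=0.3, seed=42, max_tokens=64)
# output = llm.generate(prompts, sampling_params)
```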
13 changes: 8 additions & 5 deletions vllm/platforms/tpu.py
@@ -7,7 +7,7 @@
 import vllm.envs as envs
 from vllm.inputs import PromptType
 from vllm.logger import init_logger
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType
 
 from .interface import Platform, PlatformEnum, _Backend
 
@@ -145,7 +145,10 @@ def validate_request(
         params: Union[SamplingParams, PoolingParams],
     ) -> None:
         """Raises if this request is unsupported on this platform"""
-        if isinstance(params,
-                      SamplingParams) and params.guided_decoding is not None:
-            raise ValueError("Structured output is not supported on "
-                             f"{cls.device_name}.")
+        if isinstance(params, SamplingParams):
+            if params.guided_decoding is not None:
+                raise ValueError("Structured output is not supported on "
+                                 f"{cls.device_name}.")
+            if params.sampling_type == SamplingType.RANDOM_SEED:
+                raise ValueError(
+                    "Torch XLA does not support per-request seed.")
16 changes: 10 additions & 6 deletions vllm/v1/sample/tpu/metadata.py
@@ -33,10 +33,6 @@ class TPUSupportedSamplingMetadata:
     # Greedy sampling flag for compiling single xla graph.
     all_greedy: bool = True
 
-    # Generator not supported by xla
-    generators: dict[int,
-                     torch.Generator] = field(default_factory=lambda: dict())
-
     # unsupported, you need to return an extra tensor of static size BxV
     max_num_logprobs = None
 
@@ -57,6 +53,15 @@ class TPUSupportedSamplingMetadata:
     allowed_token_ids_mask = None
     bad_words_token_ids = None
 
+    # Generator not supported by xla
+    _generators: dict[int,
+                      torch.Generator] = field(default_factory=lambda: dict())
+
+    @property
+    def generators(self) -> dict[int, torch.Generator]:
+        # Generator not supported by torch/xla. This field must be immutable.
+        return self._generators
+
     @classmethod
     def from_input_batch(
         cls,
@@ -109,5 +114,4 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
             top_p=None,  # input_batch.top_p[:padded_num_reqs],
             top_k=None,  # input_batch.top_k[:padded_num_reqs],
             min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
-                xla_device),
-            generators=input_batch.generators)
+                xla_device))
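The rename to a private `_generators` field plus a read-only property keeps the interface callers already expect (`metadata.generators`) while making clear the mapping is never populated on XLA. A small self-contained sketch of the same pattern; the class name is chosen here purely for illustration:

```python
# Sketch of the read-only `generators` accessor used above; the class name is
# illustrative, and the mapping is expected to stay empty on XLA.
from dataclasses import dataclass, field

import torch


@dataclass
class SketchSamplingMetadata:
    _generators: dict[int, torch.Generator] = field(default_factory=dict)

    @property
    def generators(self) -> dict[int, torch.Generator]:
        # Exposed for interface parity with the GPU sampling metadata, but
        # per-request generators are not supported by torch/xla.
        return self._generators


meta = SketchSamplingMetadata()
assert meta.generators == {}  # no per-request generators on TPU
```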
8 changes: 1 addition & 7 deletions vllm/v1/worker/tpu_model_runner.py
@@ -23,7 +23,6 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
@@ -267,11 +266,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         for new_req_data in scheduler_output.scheduled_new_reqs:
             req_id = new_req_data.req_id
             sampling_params = new_req_data.sampling_params
-            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-                generator = torch.Generator(device=self.device)
-                generator.manual_seed(sampling_params.seed)
-            else:
-                generator = None
 
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
@@ -280,7 +274,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
                 mm_inputs=new_req_data.mm_inputs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
-                generator=generator,
+                generator=None,
                 block_ids=new_req_data.block_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],
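The deleted branch is safe to remove because seeded requests are now rejected by the platform validation before they ever reach the TPU model runner, so the cached request state can hold `generator=None` unconditionally. A self-contained sketch of that flow; the names here are illustrative, not vLLM's internal call graph:

```python
# Illustrative flow: platform validation runs at admission time, so the
# runner never sees a seeded request and always caches generator=None.
from dataclasses import dataclass
from typing import Optional


@dataclass
class StubParams:
    seed: Optional[int] = None


def tpu_validate(params: StubParams) -> None:
    if params.seed is not None:
        raise ValueError("Torch XLA does not support per-request seed.")


def runner_cache_request(params: StubParams) -> dict:
    # Mirrors the simplified _update_states above: no branch on the sampling
    # type, the generator slot is always None on TPU.
    return {"sampling_params": params, "generator": None}


def admit(params: StubParams) -> dict:
    tpu_validate(params)  # raises before the runner is involved
    return runner_cache_request(params)


print(admit(StubParams()))       # accepted, generator is None
try:
    admit(StubParams(seed=42))
except ValueError as err:
    print(err)                   # rejected at admission time
```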