diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 719254a398c0..19a5ca5e2750 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -418,6 +418,7 @@ def run_test_case(*, expected_penalization: List[bool],
             prompt_len = seq_data.get_prompt_len()
             seq_lens.append(prompt_len)
 
+            assert sgm.sampling_params is not None
             if sgm.sampling_params.prompt_logprobs:
                 # with prompt_logprobs each token in the prompt has a row in
                 # logits
@@ -533,6 +534,8 @@ def test_sampling():
 
     for i, (sequence_output, metadata) in enumerate(
            zip(sampler_output, seq_group_metadata_list)):
+        assert metadata.sampling_params is not None
+
         if metadata.sampling_params.use_beam_search:
             continue
 
@@ -550,6 +553,8 @@ def test_sampling():
         assert expected_tokens_item is not None
 
         for n, nth_output in enumerate(sequence_output.samples):
+            assert metadata.sampling_params is not None
+
             if (metadata.sampling_params.temperature == 0
                     or metadata.sampling_params.seed is not None):
                 # Ensure exact matches for greedy or random with seed
diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py
index b00a61ebfec6..49bb6aeee90b 100644
--- a/vllm/assets/audio.py
+++ b/vllm/assets/audio.py
@@ -19,7 +19,9 @@ def audio_and_sample_rate(self) -> Tuple[np.ndarray, int]:
 
         audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
                                             s3_prefix=ASSET_DIR)
-        return librosa.load(audio_path, sr=None)
+        y, sr = librosa.load(audio_path, sr=None)
+        assert isinstance(sr, int)
+        return y, sr
 
     @property
     def url(self) -> str:
diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py
index dc316ca1160c..a472e12e8ca4 100644
--- a/vllm/entrypoints/openai/rpc/client.py
+++ b/vllm/entrypoints/openai/rpc/client.py
@@ -101,6 +101,7 @@ def __init__(self, rpc_path: str):
         # Maximum number of sockets that can be opened (typically 65536).
         # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
         socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
+        assert isinstance(socket_limit, int)
         if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF:
             raise ValueError(
                 f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
@@ -141,8 +142,8 @@ async def run_proxy(self, socket_from, socket_to):
         poller.register(socket_from, zmq.constants.POLLIN)
         poller.register(socket_to, zmq.constants.POLLIN)
         while True:
-            events = await poller.poll()
-            events = dict(events)
+            events_lst = await poller.poll()
+            events = dict(events_lst)
             if socket_from in events:
                 identity, msg = await socket_from.recv_multipart()
                 await socket_to.send_multipart([identity, msg])
diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py
index 5b00117c64e5..f26e3292c264 100644
--- a/vllm/multimodal/base.py
+++ b/vllm/multimodal/base.py
@@ -14,7 +14,7 @@
 from vllm.config import ModelConfig
 from vllm.inputs import InputContext
 from vllm.logger import init_logger
-from vllm.utils import json_map_leaves
+from vllm.utils import JSONTree, is_list_of, json_map_leaves
 
 logger = init_logger(__name__)
 
@@ -54,13 +54,14 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
             return nested_tensors
 
         stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
-        if any(isinstance(t, list) for t in stacked):
+        if is_list_of(stacked, list):
+            # Do not stack nested lists
             return stacked
 
         tensors_ = cast(List[torch.Tensor], stacked)
         if any(t.shape != tensors_[0].shape for t in tensors_):
             # The tensors have incompatible shapes and can't be stacked.
-            return tensors_
+            return stacked
 
         return torch.stack(tensors_)
 
@@ -101,8 +102,14 @@ def as_kwargs(
         *,
         device: torch.types.Device,
     ) -> BatchedTensorInputs:
-        return json_map_leaves(lambda x: x.to(device, non_blocking=True),
-                               batched_inputs)
+        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
+
+        json_mapped = json_map_leaves(
+            lambda x: x.to(device, non_blocking=True),
+            json_inputs,
+        )
+
+        return cast(BatchedTensorInputs, json_mapped)
 
 
 _T = TypeVar("_T")
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 964072dd7c8f..f289a9aec80c 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -883,7 +883,7 @@ class SequenceGroupMetadata(
     request_id: str
     is_prompt: bool
     seq_data: Dict[int, SequenceData]
-    sampling_params: SamplingParams
+    sampling_params: Optional[SamplingParams]
     block_tables: Dict[int, List[int]]
     do_sample: bool = True
     pooling_params: Optional[PoolingParams] = None