5 changes: 5 additions & 0 deletions tests/samplers/test_sampler.py
@@ -418,6 +418,7 @@ def run_test_case(*, expected_penalization: List[bool],
prompt_len = seq_data.get_prompt_len()
seq_lens.append(prompt_len)

+        assert sgm.sampling_params is not None
if sgm.sampling_params.prompt_logprobs:
# with prompt_logprobs each token in the prompt has a row in
# logits
@@ -533,6 +534,8 @@ def test_sampling():

for i, (sequence_output, metadata) in enumerate(
zip(sampler_output, seq_group_metadata_list)):
+        assert metadata.sampling_params is not None
+
if metadata.sampling_params.use_beam_search:
continue

@@ -550,6 +553,8 @@ def test_sampling():
assert expected_tokens_item is not None

for n, nth_output in enumerate(sequence_output.samples):
+            assert metadata.sampling_params is not None
+
if (metadata.sampling_params.temperature == 0
or metadata.sampling_params.seed is not None):
# Ensure exact matches for greedy or random with seed
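Note: these asserts are mypy type-narrowing guards rather than behavioral checks. Since `SequenceGroupMetadata.sampling_params` becomes `Optional[SamplingParams]` (see the `vllm/sequence.py` change below), every attribute access must first rule out `None`. A minimal sketch of the pattern, using a stand-in class rather than the real `SamplingParams`:

```python
from typing import Optional


class SamplingParams:  # stand-in for vllm.sampling_params.SamplingParams
    temperature: float = 1.0


def pick_temperature(params: Optional[SamplingParams]) -> float:
    # Without this assert, mypy rejects the attribute access below,
    # because `params` may still be None at that point.
    assert params is not None
    return params.temperature  # narrowed to SamplingParams
```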
4 changes: 3 additions & 1 deletion vllm/assets/audio.py
@@ -19,7 +19,9 @@ def audio_and_sample_rate(self) -> Tuple[np.ndarray, int]:

audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
s3_prefix=ASSET_DIR)
-        return librosa.load(audio_path, sr=None)
+        y, sr = librosa.load(audio_path, sr=None)
+        assert isinstance(sr, int)
+        return y, sr

@property
def url(self) -> str:
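For context: recent `librosa` type stubs annotate the returned sample rate as `float`, so returning `librosa.load(...)` directly no longer satisfies the declared `Tuple[np.ndarray, int]`. With `sr=None` the file's native (integer) rate is preserved at runtime, and the assert narrows the static type to match. A standalone sketch of the same pattern, assuming an audio file exists at the given path:

```python
from typing import Tuple

import librosa
import numpy as np


def load_native(path: str) -> Tuple[np.ndarray, int]:
    # sr=None keeps the file's native sample rate instead of
    # resampling to librosa's 22050 Hz default.
    y, sr = librosa.load(path, sr=None)
    # The stubs type `sr` as float; narrow it explicitly.
    assert isinstance(sr, int)
    return y, sr
```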
5 changes: 3 additions & 2 deletions vllm/entrypoints/openai/rpc/client.py
@@ -101,6 +101,7 @@ def __init__(self, rpc_path: str):
# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
+        assert isinstance(socket_limit, int)
if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF:
raise ValueError(
f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
@@ -141,8 +142,8 @@ async def run_proxy(self, socket_from, socket_to):
poller.register(socket_from, zmq.constants.POLLIN)
poller.register(socket_to, zmq.constants.POLLIN)
while True:
-            events = await poller.poll()
-            events = dict(events)
+            events_lst = await poller.poll()
+            events = dict(events_lst)
if socket_from in events:
identity, msg = await socket_from.recv_multipart()
await socket_to.send_multipart([identity, msg])
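Both hunks here are narrowing fixes: `Context.get` is annotated to return a union (the option value may be `int` or `bytes` depending on the option), so the `isinstance` assert narrows it before the `<` comparison; and binding the poll result to a fresh name avoids re-assigning one variable with two different types, which mypy flags. A rough sketch of the second pattern, using standard `pyzmq` asyncio APIs:

```python
from typing import Dict

import zmq
import zmq.asyncio


async def poll_events(poller: zmq.asyncio.Poller) -> Dict[zmq.Socket, int]:
    # poll() resolves to a list of (socket, event-mask) pairs; giving
    # the list and the dict separate names keeps each type stable.
    events_lst = await poller.poll()
    return dict(events_lst)
```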
17 changes: 12 additions & 5 deletions vllm/multimodal/base.py
@@ -14,7 +14,7 @@
from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
-from vllm.utils import json_map_leaves
+from vllm.utils import JSONTree, is_list_of, json_map_leaves

logger = init_logger(__name__)

@@ -54,13 +54,14 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
return nested_tensors

stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
-        if any(isinstance(t, list) for t in stacked):
+        if is_list_of(stacked, list):
# Do not stack nested lists
return stacked

tensors_ = cast(List[torch.Tensor], stacked)
if any(t.shape != tensors_[0].shape for t in tensors_):
# The tensors have incompatible shapes and can't be stacked.
-            return tensors_
+            return stacked

return torch.stack(tensors_)

@@ -101,8 +102,14 @@ def as_kwargs(
*,
device: torch.types.Device,
) -> BatchedTensorInputs:
-        return json_map_leaves(lambda x: x.to(device, non_blocking=True),
-                               batched_inputs)
+        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
+
+        json_mapped = json_map_leaves(
+            lambda x: x.to(device, non_blocking=True),
+            json_inputs,
+        )
+
+        return cast(BatchedTensorInputs, json_mapped)


_T = TypeVar("_T")
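Two narrowing devices in this file: `is_list_of(stacked, list)` replaces the `any(...)` check with a helper whose result the type checker can act on, and `as_kwargs` round-trips through `JSONTree[torch.Tensor]` because `json_map_leaves` is typed over JSON-like trees rather than `BatchedTensorInputs`. Below is a hypothetical `TypeGuard`-based reimplementation of the helper; the real signature in `vllm.utils` may differ (for example, it may only sample the first element):

```python
from typing import List, Type, TypeVar

from typing_extensions import TypeGuard

T = TypeVar("T")


def is_list_of(value: object, typ: Type[T]) -> TypeGuard[List[T]]:
    # Hypothetical reimplementation: True only if `value` is a list
    # whose elements are all instances of `typ`.
    return isinstance(value, list) and all(
        isinstance(v, typ) for v in value)
```

Inside an `if is_list_of(stacked, list):` branch, mypy then treats `stacked` as `List[list]` rather than the wider union, which is what lets the early `return stacked` type-check.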
2 changes: 1 addition & 1 deletion vllm/sequence.py
@@ -883,7 +883,7 @@ class SequenceGroupMetadata(
request_id: str
is_prompt: bool
seq_data: Dict[int, SequenceData]
-    sampling_params: SamplingParams
+    sampling_params: Optional[SamplingParams]
block_tables: Dict[int, List[int]]
do_sample: bool = True
pooling_params: Optional[PoolingParams] = None
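Making this field optional reflects that some sequence groups (e.g. pooling/embedding requests, which carry `pooling_params` instead) have no sampling parameters, which is why every access in this PR now narrows first. A hedged sketch of how a caller might branch on it, with `metadata` standing in for a real `SequenceGroupMetadata`:

```python
def describe_request(metadata) -> str:
    # Pooling/embedding requests may have no sampling parameters;
    # narrow before touching any SamplingParams attribute.
    if metadata.sampling_params is None:
        return f"{metadata.request_id}: pooling request"
    return (f"{metadata.request_id}: "
            f"temperature={metadata.sampling_params.temperature}")
```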