2 changes: 1 addition & 1 deletion tests/lora/test_add_lora.py
@@ -12,7 +12,7 @@
from vllm.inputs import TextPrompt
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.utils import merge_async_iterators
from vllm.utils.async_utils import merge_async_iterators

MODEL_PATH = "zai-org/chatglm3-6b"
LORA_RANK = 64
42 changes: 42 additions & 0 deletions tests/utils_/test_async_utils.py
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import AsyncIterator

import pytest

from vllm.utils.async_utils import merge_async_iterators


async def _mock_async_iterator(idx: int):
    try:
        while True:
            yield f"item from iterator {idx}"
            await asyncio.sleep(0.1)
    except asyncio.CancelledError:
        print(f"iterator {idx} cancelled")


@pytest.mark.asyncio
async def test_merge_async_iterators():
    iterators = [_mock_async_iterator(i) for i in range(3)]
    merged_iterator = merge_async_iterators(*iterators)

    async def stream_output(generator: AsyncIterator[tuple[int, str]]):
        async for idx, output in generator:
            print(f"idx: {idx}, output: {output}")

    task = asyncio.create_task(stream_output(merged_iterator))
    await asyncio.sleep(0.5)
    task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await task

    for iterator in iterators:
        try:
            await asyncio.wait_for(anext(iterator), 1)
        except StopAsyncIteration:
            # All iterators should be cancelled and print this message.
            print("Iterator was cancelled normally")
        except (Exception, asyncio.CancelledError) as e:
            raise AssertionError() from e
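For reference, here is a minimal sketch of the fan-in behavior this test exercises: merging several async iterators into one stream of (index, item) pairs and propagating cancellation back to the producers. It assumes a simple asyncio.Queue-based approach and is illustrative only, not vLLM's actual merge_async_iterators implementation (error handling is omitted).

import asyncio
from collections.abc import AsyncIterator


async def merge_iterators_sketch(
    *iterators: AsyncIterator[str],
) -> AsyncIterator[tuple[int, str]]:
    """Yield (index, item) pairs from all iterators as items become available."""
    queue: asyncio.Queue = asyncio.Queue()
    sentinel = object()

    async def drain(idx: int, it: AsyncIterator[str]) -> None:
        # Forward every item, tagged with the producing iterator's index.
        async for item in it:
            await queue.put((idx, item))
        await queue.put(sentinel)

    tasks = [asyncio.create_task(drain(i, it)) for i, it in enumerate(iterators)]
    finished = 0
    try:
        while finished < len(tasks):
            entry = await queue.get()
            if entry is sentinel:
                finished += 1
            else:
                yield entry
    finally:
        # On cancellation (or normal exit), cancel the producer tasks so the
        # underlying iterators observe asyncio.CancelledError, as the test expects.
        for task in tasks:
            task.cancel()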
36 changes: 0 additions & 36 deletions tests/utils_/test_utils.py
@@ -2,14 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa

import asyncio
import hashlib
import json
import os
import pickle
import socket
import tempfile
from collections.abc import AsyncIterator
from pathlib import Path
from unittest.mock import patch

@@ -37,7 +35,6 @@
    make_zmq_path,
    make_zmq_socket,
    memory_profiling,
    merge_async_iterators,
    sha256,
    split_host_port,
    split_zmq_path,
@@ -48,39 +45,6 @@
from ..utils import create_new_process_for_each_test


@pytest.mark.asyncio
async def test_merge_async_iterators():
    async def mock_async_iterator(idx: int):
        try:
            while True:
                yield f"item from iterator {idx}"
                await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            print(f"iterator {idx} cancelled")

    iterators = [mock_async_iterator(i) for i in range(3)]
    merged_iterator = merge_async_iterators(*iterators)

    async def stream_output(generator: AsyncIterator[tuple[int, str]]):
        async for idx, output in generator:
            print(f"idx: {idx}, output: {output}")

    task = asyncio.create_task(stream_output(merged_iterator))
    await asyncio.sleep(0.5)
    task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await task

    for iterator in iterators:
        try:
            await asyncio.wait_for(anext(iterator), 1)
        except StopAsyncIteration:
            # All iterators should be cancelled and print this message.
            print("Iterator was cancelled normally")
        except (Exception, asyncio.CancelledError) as e:
            raise AssertionError() from e


def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:
        m.setenv("VLLM_PORT", "5678")
2 changes: 1 addition & 1 deletion vllm/benchmarks/throughput.py
@@ -34,7 +34,7 @@
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils import merge_async_iterators
from vllm.utils.async_utils import merge_async_iterators


def run_vllm(
3 changes: 2 additions & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -34,7 +34,8 @@
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import as_list, merge_async_iterators
from vllm.utils import as_list
from vllm.utils.async_utils import merge_async_iterators

logger = init_logger(__name__)

3 changes: 1 addition & 2 deletions vllm/entrypoints/openai/serving_embedding.py
@@ -40,6 +40,7 @@
)
from vllm.pooling_params import PoolingParams
from vllm.utils import chunk_list
from vllm.utils.async_utils import merge_async_iterators

logger = init_logger(__name__)

@@ -387,8 +388,6 @@ async def _prepare_generators(
)
generators.append(generator)

from vllm.utils import merge_async_iterators

ctx.result_generator = merge_async_iterators(*generators)

return None
7 changes: 3 additions & 4 deletions vllm/entrypoints/openai/serving_engine.py
@@ -90,14 +90,13 @@
    log_tracing_disabled_warning,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import (
from vllm.utils import is_list_of, random_uuid
from vllm.utils.async_utils import (
    AsyncMicrobatchTokenizer,
    collect_from_async_generator,
    is_list_of,
    make_async,
    merge_async_iterators,
    random_uuid,
)
from vllm.utils.func import make_async
from vllm.v1.engine import EngineCoreRequest

logger = init_logger(__name__)
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_pooling.py
@@ -36,7 +36,7 @@
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.tasks import SupportedTask
from vllm.utils import merge_async_iterators
from vllm.utils.async_utils import merge_async_iterators

logger = init_logger(__name__)

3 changes: 1 addition & 2 deletions vllm/entrypoints/openai/serving_score.py
@@ -37,8 +37,7 @@
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import merge_async_iterators
from vllm.utils.func import make_async
from vllm.utils.async_utils import make_async, merge_async_iterators

logger = init_logger(__name__)

2 changes: 1 addition & 1 deletion vllm/entrypoints/renderer.py
@@ -17,7 +17,7 @@
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AsyncMicrobatchTokenizer
from vllm.utils.async_utils import AsyncMicrobatchTokenizer


@dataclass(frozen=True)
2 changes: 1 addition & 1 deletion vllm/executor/executor_base.py
@@ -17,7 +17,7 @@
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest
from vllm.tasks import SupportedTask
from vllm.utils.func import make_async
from vllm.utils.async_utils import make_async
from vllm.v1.outputs import SamplerOutput
from vllm.v1.worker.worker_base import WorkerBase

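For context on the new import path, here is a rough sketch of what a make_async-style wrapper typically does: run a blocking callable in an executor so it can be awaited without blocking the event loop. This assumes the common run_in_executor pattern; the actual helper in vllm.utils.async_utils may differ.

import asyncio
import functools
from collections.abc import Awaitable, Callable
from typing import Any


def make_async_sketch(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
    """Wrap a blocking callable so it can be awaited from async code (sketch only)."""

    def _wrapper(*args: Any, **kwargs: Any) -> Awaitable[Any]:
        # Must be called from within a running event loop; the blocking call is
        # dispatched to the default thread-pool executor.
        loop = asyncio.get_running_loop()
        return loop.run_in_executor(None, functools.partial(func, *args, **kwargs))

    return _wrapper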
9 changes: 7 additions & 2 deletions vllm/executor/ray_distributed_executor.py
@@ -20,12 +20,11 @@
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (
    _run_task_with_lock,
    get_distributed_init_method,
    get_ip,
    get_open_port,
)
from vllm.utils.func import make_async
from vllm.utils.async_utils import make_async
from vllm.v1.outputs import SamplerOutput

if ray is not None:
@@ -748,3 +747,9 @@ def check_health(self) -> None:
        # Assume that the Ray workers are healthy.
        # TODO: check the health of the Ray workers
        return


async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs):
    """Run an async task while holding the given lock."""
    async with lock:
        return await task(*args, **kwargs)
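A hypothetical usage sketch of the helper defined above (the coroutine and result strings here are illustrative, not taken from the PR): the shared lock serializes concurrently scheduled calls.

import asyncio


async def _example() -> None:
    lock = asyncio.Lock()

    async def execute(step: int) -> str:
        return f"executed step {step}"

    # Both calls are scheduled concurrently, but the lock ensures the wrapped
    # coroutines run one at a time.
    results = await asyncio.gather(
        _run_task_with_lock(execute, lock, 0),
        _run_task_with_lock(execute, lock, 1),
    )
    print(results)  # ['executed step 0', 'executed step 1']


asyncio.run(_example())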