diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py index 2f28253bce53..9a82ab99ea9c 100644 --- a/tests/lora/test_add_lora.py +++ b/tests/lora/test_add_lora.py @@ -12,7 +12,7 @@ from vllm.inputs import TextPrompt from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams -from vllm.utils import merge_async_iterators +from vllm.utils.async_utils import merge_async_iterators MODEL_PATH = "zai-org/chatglm3-6b" LORA_RANK = 64 diff --git a/tests/utils_/test_async_utils.py b/tests/utils_/test_async_utils.py new file mode 100644 index 000000000000..03d116bdfd81 --- /dev/null +++ b/tests/utils_/test_async_utils.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +from collections.abc import AsyncIterator + +import pytest + +from vllm.utils.async_utils import merge_async_iterators + + +async def _mock_async_iterator(idx: int): + try: + while True: + yield f"item from iterator {idx}" + await asyncio.sleep(0.1) + except asyncio.CancelledError: + print(f"iterator {idx} cancelled") + + +@pytest.mark.asyncio +async def test_merge_async_iterators(): + iterators = [_mock_async_iterator(i) for i in range(3)] + merged_iterator = merge_async_iterators(*iterators) + + async def stream_output(generator: AsyncIterator[tuple[int, str]]): + async for idx, output in generator: + print(f"idx: {idx}, output: {output}") + + task = asyncio.create_task(stream_output(merged_iterator)) + await asyncio.sleep(0.5) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + for iterator in iterators: + try: + await asyncio.wait_for(anext(iterator), 1) + except StopAsyncIteration: + # All iterators should be cancelled and print this message. + print("Iterator was cancelled normally") + except (Exception, asyncio.CancelledError) as e: + raise AssertionError() from e diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index b4883a4fea31..3bc4d3536d58 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -2,14 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # ruff: noqa -import asyncio import hashlib import json import os import pickle import socket import tempfile -from collections.abc import AsyncIterator from pathlib import Path from unittest.mock import patch @@ -37,7 +35,6 @@ make_zmq_path, make_zmq_socket, memory_profiling, - merge_async_iterators, sha256, split_host_port, split_zmq_path, @@ -48,39 +45,6 @@ from ..utils import create_new_process_for_each_test -@pytest.mark.asyncio -async def test_merge_async_iterators(): - async def mock_async_iterator(idx: int): - try: - while True: - yield f"item from iterator {idx}" - await asyncio.sleep(0.1) - except asyncio.CancelledError: - print(f"iterator {idx} cancelled") - - iterators = [mock_async_iterator(i) for i in range(3)] - merged_iterator = merge_async_iterators(*iterators) - - async def stream_output(generator: AsyncIterator[tuple[int, str]]): - async for idx, output in generator: - print(f"idx: {idx}, output: {output}") - - task = asyncio.create_task(stream_output(merged_iterator)) - await asyncio.sleep(0.5) - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task - - for iterator in iterators: - try: - await asyncio.wait_for(anext(iterator), 1) - except StopAsyncIteration: - # All iterators should be cancelled and print this message. - print("Iterator was cancelled normally") - except (Exception, asyncio.CancelledError) as e: - raise AssertionError() from e - - def test_get_open_port(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: m.setenv("VLLM_PORT", "5678") diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index ad111a1ebd5b..866365ac18eb 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -34,7 +34,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams -from vllm.utils import merge_async_iterators +from vllm.utils.async_utils import merge_async_iterators def run_vllm( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 7cbe9c69435c..f33fce7716a9 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -34,7 +34,8 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import as_list, merge_async_iterators +from vllm.utils import as_list +from vllm.utils.async_utils import merge_async_iterators logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index e2b940ef00c0..4c05d9f57fa6 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -40,6 +40,7 @@ ) from vllm.pooling_params import PoolingParams from vllm.utils import chunk_list +from vllm.utils.async_utils import merge_async_iterators logger = init_logger(__name__) @@ -387,8 +388,6 @@ async def _prepare_generators( ) generators.append(generator) - from vllm.utils import merge_async_iterators - ctx.result_generator = merge_async_iterators(*generators) return None diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index c318c0f425bd..6464d4f9e675 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -90,14 +90,13 @@ log_tracing_disabled_warning, ) from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import ( +from vllm.utils import is_list_of, random_uuid +from vllm.utils.async_utils import ( AsyncMicrobatchTokenizer, collect_from_async_generator, - is_list_of, + make_async, merge_async_iterators, - random_uuid, ) -from vllm.utils.func import make_async from vllm.v1.engine import EngineCoreRequest logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index aa81a233b297..7a27348da35b 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -36,7 +36,7 @@ from vllm.logger import init_logger from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.tasks import SupportedTask -from vllm.utils import merge_async_iterators +from vllm.utils.async_utils import merge_async_iterators logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index e5c7f80a1753..9cbfc9791819 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -37,8 +37,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import merge_async_iterators -from vllm.utils.func import make_async +from vllm.utils.async_utils import make_async, merge_async_iterators logger = init_logger(__name__) diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py index 4f1213b09730..63487a6ed007 100644 --- a/vllm/entrypoints/renderer.py +++ b/vllm/entrypoints/renderer.py @@ -17,7 +17,7 @@ from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.parse import get_prompt_components, parse_raw_prompts from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import AsyncMicrobatchTokenizer +from vllm.utils.async_utils import AsyncMicrobatchTokenizer @dataclass(frozen=True) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 093d5e97fd3e..9de2249f6c05 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -17,7 +17,7 @@ from vllm.lora.request import LoRARequest from vllm.sequence import ExecuteModelRequest from vllm.tasks import SupportedTask -from vllm.utils.func import make_async +from vllm.utils.async_utils import make_async from vllm.v1.outputs import SamplerOutput from vllm.v1.worker.worker_base import WorkerBase diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index a57b64152f49..b41466a6a770 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -20,12 +20,11 @@ from vllm.ray.ray_env import get_env_vars_to_copy from vllm.sequence import ExecuteModelRequest from vllm.utils import ( - _run_task_with_lock, get_distributed_init_method, get_ip, get_open_port, ) -from vllm.utils.func import make_async +from vllm.utils.async_utils import make_async from vllm.v1.outputs import SamplerOutput if ray is not None: @@ -748,3 +747,9 @@ def check_health(self) -> None: # Assume that the Ray workers are healthy. # TODO: check the health of the Ray workers return + + +async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs): + """Utility function to run async task in a lock""" + async with lock: + return await task(*args, **kwargs) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 5fd94b7b4049..99a9225cb6a4 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio import contextlib import datetime import enum @@ -38,10 +37,8 @@ RawDescriptionHelpFormatter, _ArgumentGroup, ) -from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import UserDict, defaultdict from collections.abc import ( - AsyncGenerator, Callable, Collection, Generator, @@ -51,7 +48,6 @@ Mapping, Sequence, ) -from concurrent.futures import ThreadPoolExecutor from concurrent.futures.process import ProcessPoolExecutor from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps @@ -82,7 +78,6 @@ from packaging import version from packaging.version import Version from torch.library import Library -from transformers.tokenization_utils_base import BatchEncoding from typing_extensions import Never, TypeIs, assert_never import vllm.envs as envs @@ -223,278 +218,12 @@ def random_uuid() -> str: return str(uuid.uuid4().hex) -class AsyncMicrobatchTokenizer: - """Asynchronous tokenizer with micro-batching. - - Pulls pending encode/decode requests from a queue and batches them - up to reduce overhead. A single-thread ThreadPoolExecutor is used - so the event loop stays responsive. - """ - - def __init__( - self, - tokenizer, - max_batch_size: int = 32, - batch_wait_timeout_s: float = 0.002, - ) -> None: - self.tokenizer = tokenizer - self.max_batch_size = max_batch_size - self.batch_wait_timeout_s = batch_wait_timeout_s - - self._loop = asyncio.get_running_loop() - self._queues: dict[ - tuple, - asyncio.Queue[ - tuple[str, dict, asyncio.Future] | tuple[list[int], asyncio.Future] - ], - ] = {} - self._batcher_tasks: list[asyncio.Task] = [] - - # Single-thread executor for blocking tokenizer calls. - self._executor = ThreadPoolExecutor(max_workers=1) - - # === Public async API === - async def __call__(self, prompt, **kwargs): - result_future: asyncio.Future = self._loop.create_future() - key = self._queue_key("encode", kwargs) - queue = self._get_queue(self._loop, key) - await queue.put((prompt, kwargs, result_future)) - return await result_future - - async def decode(self, token_ids, **kwargs): - result_future: asyncio.Future = self._loop.create_future() - key = self._queue_key("decode", kwargs) - queue = self._get_queue(self._loop, key) - await queue.put((token_ids, result_future)) - return await result_future - - # === Internal helpers === - def _get_queue( - self, loop: asyncio.AbstractEventLoop, key: tuple - ) -> asyncio.Queue[ - tuple[str, dict, asyncio.Future] | tuple[list[int], asyncio.Future] - ]: - """Get the request queue for the given operation key, creating a new - queue and batcher task if needed.""" - queue = self._queues.get(key) - if queue is None: - self._queues[key] = queue = asyncio.Queue() - if key[0] == "encode": - can_batch = key[1] != "other" - coro = self._batch_encode_loop(queue, can_batch) - else: - assert key[0] == "decode", f"Unknown operation type: {key[0]}." - coro = self._batch_decode_loop(queue) - self._batcher_tasks.append(loop.create_task(coro)) - return queue - - async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool): - """Batch incoming encode requests for efficiency.""" - while True: - prompt, kwargs, result_future = await queue.get() - prompts = [prompt] - kwargs_list = [kwargs] - result_futures = [result_future] - deadline = self._loop.time() + self.batch_wait_timeout_s - - while len(prompts) < self.max_batch_size: - timeout = deadline - self._loop.time() - if timeout <= 0: - break - try: - prompt, kwargs, result_future = await asyncio.wait_for( - queue.get(), timeout - ) - prompts.append(prompt) - result_futures.append(result_future) - if not can_batch: - kwargs_list.append(kwargs) - except asyncio.TimeoutError: - break - - try: - # If every request uses identical kwargs we can run a single - # batched tokenizer call for a big speed-up. - if can_batch and len(prompts) > 1: - batch_encode_fn = partial(self.tokenizer, prompts, **kwargs) - results = await self._loop.run_in_executor( - self._executor, batch_encode_fn - ) - - for i, fut in enumerate(result_futures): - if not fut.done(): - data = {k: v[i] for k, v in results.items()} - fut.set_result(BatchEncoding(data)) - else: - encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [ - self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs) - ] - results = await self._loop.run_in_executor( - self._executor, encode_fn - ) - - for fut, res in zip(result_futures, results): - if not fut.done(): - fut.set_result(res) - except Exception as e: - for fut in result_futures: - if not fut.done(): - fut.set_exception(e) - - async def _batch_decode_loop(self, queue: asyncio.Queue): - """Batch incoming decode requests for efficiency.""" - while True: - token_ids, result_future = await queue.get() - token_ids_list = [token_ids] - result_futures = [result_future] - deadline = self._loop.time() + self.batch_wait_timeout_s - - while len(token_ids_list) < self.max_batch_size: - timeout = deadline - self._loop.time() - if timeout <= 0: - break - try: - token_ids, result_future = await asyncio.wait_for( - queue.get(), timeout - ) - token_ids_list.append(token_ids) - result_futures.append(result_future) - except asyncio.TimeoutError: - break - - try: - # Perform a single batched decode call for all requests - results = await self._loop.run_in_executor( - self._executor, self.tokenizer.batch_decode, token_ids_list - ) - for fut, res in zip(result_futures, results): - if not fut.done(): - fut.set_result(res) - except Exception as e: - for fut in result_futures: - if not fut.done(): - fut.set_exception(e) - - def _queue_key(self, op: str, kwargs: dict) -> tuple: - """ - Return a normalized key describing operation + kwargs. - - - `add_special_tokens`: {True/False} - - `truncation`: {True/False} - - If `truncation` is False (`max_length` is None), - returns a key for a can_batch queue. - - If `truncation` is True and `max_length` is None or equals - `tokenizer.model_max_length`, returns a key for a can_batch queue. - - Otherwise, returns a key for a cannot_batch queue. - - Examples: - - Decode: ("decode",) - - Encode typical: - ("encode", add_special_tokens, bool_truncation, max_length_label) - - Fallback: ("encode", "other") - """ - - if op == "decode": - return ("decode",) - - add_special_tokens = kwargs.get("add_special_tokens", True) - truncation = kwargs.get("truncation", False) - max_length = kwargs.get("max_length") - - if not truncation: - return "encode", add_special_tokens, False, None - - model_max = getattr(self.tokenizer, "model_max_length", None) - if max_length is None or (model_max is not None and max_length == model_max): - return "encode", add_special_tokens, True, "model_max" - - return "encode", "other" - - def __del__(self): - if ( - (tasks := getattr(self, "_batcher_tasks", None)) - and (loop := getattr(self, "_loop", None)) - and not loop.is_closed() - ): - - def cancel_tasks(): - for task in tasks: - task.cancel() - - loop.call_soon_threadsafe(cancel_tasks) - - -def cancel_task_threadsafe(task: Task): - if task and not task.done(): - run_in_loop(task.get_loop(), task.cancel) - - def close_sockets(sockets: Sequence[zmq.Socket | zmq.asyncio.Socket]): for sock in sockets: if sock is not None: sock.close(linger=0) -def run_in_loop(loop: AbstractEventLoop, function: Callable, *args): - if in_loop(loop): - function(*args) - elif not loop.is_closed(): - loop.call_soon_threadsafe(function, *args) - - -def in_loop(event_loop: AbstractEventLoop) -> bool: - try: - return asyncio.get_running_loop() == event_loop - except RuntimeError: - return False - - -async def merge_async_iterators( - *iterators: AsyncGenerator[T, None], -) -> AsyncGenerator[tuple[int, T], None]: - """Merge multiple asynchronous iterators into a single iterator. - - This method handle the case where some iterators finish before others. - When it yields, it yields a tuple (i, item) where i is the index of the - iterator that yields the item. - """ - if len(iterators) == 1: - # Fast-path single iterator case. - async for item in iterators[0]: - yield 0, item - return - - loop = asyncio.get_running_loop() - - awaits = {loop.create_task(anext(it)): (i, it) for i, it in enumerate(iterators)} - try: - while awaits: - done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED) - for d in done: - pair = awaits.pop(d) - try: - item = await d - i, it = pair - awaits[loop.create_task(anext(it))] = pair - yield i, item - except StopAsyncIteration: - pass - finally: - # Cancel any remaining iterators - for f, (_, it) in awaits.items(): - with contextlib.suppress(BaseException): - f.cancel() - await it.aclose() - - -async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]: - """Collect all items from an async generator into a list.""" - items = [] - async for item in iterator: - items.append(item) - return items - - def get_ip() -> str: host_ip = envs.VLLM_HOST_IP if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ: @@ -1803,12 +1532,6 @@ def load_config_file(self, file_path: str) -> list[str]: return processed_args -async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs): - """Utility function to run async task in a lock""" - async with lock: - return await task(*args, **kwargs) - - # Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. # In particular, the FakeScalarType is not supported for earlier versions of # PyTorch which breaks dynamo for any ops registered using ScalarType. diff --git a/vllm/utils/async_utils.py b/vllm/utils/async_utils.py new file mode 100644 index 000000000000..aeabd808add5 --- /dev/null +++ b/vllm/utils/async_utils.py @@ -0,0 +1,299 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Contains helpers related to asynchronous code.""" + +import asyncio +import contextlib +from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task +from collections.abc import AsyncGenerator, Awaitable, Callable +from concurrent.futures import Executor, ThreadPoolExecutor +from functools import partial +from typing import TypeVar + +from transformers.tokenization_utils_base import BatchEncoding +from typing_extensions import ParamSpec + +P = ParamSpec("P") +T = TypeVar("T") + + +class AsyncMicrobatchTokenizer: + """Asynchronous tokenizer with micro-batching. + + Pulls pending encode/decode requests from a queue and batches them + up to reduce overhead. A single-thread ThreadPoolExecutor is used + so the event loop stays responsive. + """ + + def __init__( + self, + tokenizer, + max_batch_size: int = 32, + batch_wait_timeout_s: float = 0.002, + ) -> None: + self.tokenizer = tokenizer + self.max_batch_size = max_batch_size + self.batch_wait_timeout_s = batch_wait_timeout_s + + self._loop = asyncio.get_running_loop() + self._queues: dict[ + tuple, + asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]], + ] = {} + self._batcher_tasks: list[Task] = [] + + # Single-thread executor for blocking tokenizer calls. + self._executor = ThreadPoolExecutor(max_workers=1) + + # === Public async API === + async def __call__(self, prompt, **kwargs): + result_future: Future = self._loop.create_future() + key = self._queue_key("encode", kwargs) + queue = self._get_queue(self._loop, key) + await queue.put((prompt, kwargs, result_future)) + return await result_future + + async def decode(self, token_ids, **kwargs): + result_future: Future = self._loop.create_future() + key = self._queue_key("decode", kwargs) + queue = self._get_queue(self._loop, key) + await queue.put((token_ids, result_future)) + return await result_future + + # === Internal helpers === + def _get_queue( + self, loop: asyncio.AbstractEventLoop, key: tuple + ) -> asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]]: + """Get the request queue for the given operation key, creating a new + queue and batcher task if needed.""" + queue = self._queues.get(key) + if queue is None: + self._queues[key] = queue = asyncio.Queue() + if key[0] == "encode": + can_batch = key[1] != "other" + coro = self._batch_encode_loop(queue, can_batch) + else: + assert key[0] == "decode", f"Unknown operation type: {key[0]}." + coro = self._batch_decode_loop(queue) + self._batcher_tasks.append(loop.create_task(coro)) + return queue + + async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool): + """Batch incoming encode requests for efficiency.""" + while True: + prompt, kwargs, result_future = await queue.get() + prompts = [prompt] + kwargs_list = [kwargs] + result_futures = [result_future] + deadline = self._loop.time() + self.batch_wait_timeout_s + + while len(prompts) < self.max_batch_size: + timeout = deadline - self._loop.time() + if timeout <= 0: + break + try: + prompt, kwargs, result_future = await asyncio.wait_for( + queue.get(), timeout + ) + prompts.append(prompt) + result_futures.append(result_future) + if not can_batch: + kwargs_list.append(kwargs) + except asyncio.TimeoutError: + break + + try: + # If every request uses identical kwargs we can run a single + # batched tokenizer call for a big speed-up. + if can_batch and len(prompts) > 1: + batch_encode_fn = partial(self.tokenizer, prompts, **kwargs) + results = await self._loop.run_in_executor( + self._executor, batch_encode_fn + ) + + for i, fut in enumerate(result_futures): + if not fut.done(): + data = {k: v[i] for k, v in results.items()} + fut.set_result(BatchEncoding(data)) + else: + encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [ + self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs) + ] + results = await self._loop.run_in_executor( + self._executor, encode_fn + ) + + for fut, res in zip(result_futures, results): + if not fut.done(): + fut.set_result(res) + except Exception as e: + for fut in result_futures: + if not fut.done(): + fut.set_exception(e) + + async def _batch_decode_loop(self, queue: asyncio.Queue): + """Batch incoming decode requests for efficiency.""" + while True: + token_ids, result_future = await queue.get() + token_ids_list = [token_ids] + result_futures = [result_future] + deadline = self._loop.time() + self.batch_wait_timeout_s + + while len(token_ids_list) < self.max_batch_size: + timeout = deadline - self._loop.time() + if timeout <= 0: + break + try: + token_ids, result_future = await asyncio.wait_for( + queue.get(), timeout + ) + token_ids_list.append(token_ids) + result_futures.append(result_future) + except asyncio.TimeoutError: + break + + try: + # Perform a single batched decode call for all requests + results = await self._loop.run_in_executor( + self._executor, self.tokenizer.batch_decode, token_ids_list + ) + for fut, res in zip(result_futures, results): + if not fut.done(): + fut.set_result(res) + except Exception as e: + for fut in result_futures: + if not fut.done(): + fut.set_exception(e) + + def _queue_key(self, op: str, kwargs: dict) -> tuple: + """ + Return a normalized key describing operation + kwargs. + + - `add_special_tokens`: {True/False} + - `truncation`: {True/False} + - If `truncation` is False (`max_length` is None), + returns a key for a can_batch queue. + - If `truncation` is True and `max_length` is None or equals + `tokenizer.model_max_length`, returns a key for a can_batch queue. + - Otherwise, returns a key for a cannot_batch queue. + + Examples: + - Decode: ("decode",) + - Encode typical: + ("encode", add_special_tokens, bool_truncation, max_length_label) + - Fallback: ("encode", "other") + """ + + if op == "decode": + return ("decode",) + + add_special_tokens = kwargs.get("add_special_tokens", True) + truncation = kwargs.get("truncation", False) + max_length = kwargs.get("max_length") + + if not truncation: + return "encode", add_special_tokens, False, None + + model_max = getattr(self.tokenizer, "model_max_length", None) + if max_length is None or (model_max is not None and max_length == model_max): + return "encode", add_special_tokens, True, "model_max" + + return "encode", "other" + + def __del__(self): + if ( + (tasks := getattr(self, "_batcher_tasks", None)) + and (loop := getattr(self, "_loop", None)) + and not loop.is_closed() + ): + + def cancel_tasks(): + for task in tasks: + task.cancel() + + loop.call_soon_threadsafe(cancel_tasks) + + +def cancel_task_threadsafe(task: Task): + if task and not task.done(): + run_in_loop(task.get_loop(), task.cancel) + + +def make_async( + func: Callable[P, T], + executor: Executor | None = None, +) -> Callable[P, Awaitable[T]]: + """ + Take a blocking function, and run it on in an executor thread. + + This function prevents the blocking function from blocking the + asyncio event loop. + The code in this function needs to be thread safe. + """ + + def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Future[T]: + loop = asyncio.get_event_loop() + p_func = partial(func, *args, **kwargs) + return loop.run_in_executor(executor=executor, func=p_func) + + return _async_wrapper + + +def run_in_loop(loop: AbstractEventLoop, function: Callable, *args): + if in_loop(loop): + function(*args) + elif not loop.is_closed(): + loop.call_soon_threadsafe(function, *args) + + +def in_loop(event_loop: AbstractEventLoop) -> bool: + try: + return asyncio.get_running_loop() == event_loop + except RuntimeError: + return False + + +async def merge_async_iterators( + *iterators: AsyncGenerator[T, None], +) -> AsyncGenerator[tuple[int, T], None]: + """Merge multiple asynchronous iterators into a single iterator. + + This method handle the case where some iterators finish before others. + When it yields, it yields a tuple (i, item) where i is the index of the + iterator that yields the item. + """ + if len(iterators) == 1: + # Fast-path single iterator case. + async for item in iterators[0]: + yield 0, item + return + + loop = asyncio.get_running_loop() + + awaits = {loop.create_task(anext(it)): (i, it) for i, it in enumerate(iterators)} + try: + while awaits: + done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED) + for d in done: + pair = awaits.pop(d) + try: + item = await d + i, it = pair + awaits[loop.create_task(anext(it))] = pair + yield i, item + except StopAsyncIteration: + pass + finally: + # Cancel any remaining iterators + for f, (_, it) in awaits.items(): + with contextlib.suppress(BaseException): + f.cancel() + await it.aclose() + + +async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]: + """Collect all items from an async generator into a list.""" + items = [] + async for item in iterator: + items.append(item) + return items diff --git a/vllm/utils/func.py b/vllm/utils/func.py index bd26b29d5f6d..c061a0dad552 100644 --- a/vllm/utils/func.py +++ b/vllm/utils/func.py @@ -6,12 +6,10 @@ This is similar in concept to the `functools` module. """ -import asyncio -import concurrent.futures import inspect import threading import warnings -from collections.abc import Awaitable, Callable, Mapping +from collections.abc import Callable, Mapping from functools import lru_cache, partial, wraps from typing import Any, TypeVar @@ -32,26 +30,6 @@ def identity(value: T, **kwargs) -> T: return value -def make_async( - func: Callable[P, T], - executor: concurrent.futures.Executor | None = None, -) -> Callable[P, Awaitable[T]]: - """ - Take a blocking function, and run it on in an executor thread. - - This function prevents the blocking function from blocking the - asyncio event loop. - The code in this function needs to be thread safe. - """ - - def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future[T]: - loop = asyncio.get_event_loop() - p_func = partial(func, *args, **kwargs) - return loop.run_in_executor(executor=executor, func=p_func) - - return _async_wrapper - - def run_once(f: Callable[P, None]) -> Callable[P, None]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: if wrapper.has_run: # type: ignore[attr-defined] diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index c8fb30f96c0a..ed9d82ca5373 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -29,7 +29,8 @@ from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv +from vllm.utils import Device, as_list, cdiv +from vllm.utils.async_utils import cancel_task_threadsafe from vllm.utils.func import deprecate_kwargs from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core_client import EngineCoreClient diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index c800d0d279af..a9deebc7e1f5 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -27,9 +27,9 @@ close_sockets, get_open_port, get_open_zmq_inproc_path, - in_loop, make_zmq_socket, ) +from vllm.utils.async_utils import in_loop from vllm.v1.engine import ( EngineCoreOutputs, EngineCoreRequest,