From 7d5044b098fd00b38f0e593dad12da7032180820 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Tue, 25 Jun 2024 10:15:10 -0700
Subject: [PATCH] [Core] Add fault tolerance for `RayTokenizerGroupPool` (#5748)

---
 tests/tokenization/test_tokenizer_group.py    |  99 ++++++++++++++++
 vllm/engine/async_llm_engine.py               |   2 +
 vllm/engine/llm_engine.py                     |   2 +
 .../tokenizer_group/base_tokenizer_group.py   |   4 +
 .../tokenizer_group/ray_tokenizer_group.py    | 112 ++++++++++++++----
 5 files changed, 195 insertions(+), 24 deletions(-)

diff --git a/tests/tokenization/test_tokenizer_group.py b/tests/tokenization/test_tokenizer_group.py
index 31571dbfff6f6..1b9a590750429 100644
--- a/tests/tokenization/test_tokenizer_group.py
+++ b/tests/tokenization/test_tokenizer_group.py
@@ -1,5 +1,7 @@
 import asyncio
 import os
+import sys
+from typing import List, Optional
 from unittest.mock import patch
 
 import pytest
@@ -100,3 +102,100 @@ class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
             max_num_seqs=1,
             max_input_length=None)
         tokenizer_pool.ping()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
+async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
+    """Test that Ray tokenizer pool group can recover from failures and
+    if that's not possible, mark itself as unhealthy."""
+
+    class FailingTokenizerGroup(TokenizerGroup):
+
+        def __init__(self,
+                     *args,
+                     fail_at: Optional[List[int]] = None,
+                     **kwargs):
+            super().__init__(*args, **kwargs)
+            self.i = 0
+            self.fail_at = fail_at or []
+
+        def encode(self, *args, **kwargs):
+            self.i += 1
+            if self.i in self.fail_at:
+                sys.exit(1)
+            return super().encode(*args, **kwargs)
+
+    class FailingRayTokenizerGroupPool(RayTokenizerGroupPool):
+        _worker_cls = FailingTokenizerGroup
+
+    # Fail at first iteration
+    fail_at = [1]
+    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
+    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
+        tokenizer_pool_config,
+        tokenizer_id="gpt2",
+        enable_lora=False,
+        max_num_seqs=1,
+        max_input_length=None,
+        fail_at=fail_at)
+    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()
+
+    # Modify fail at to not fail at all (will be re-read when actor is
+    # re-initialized).
+    fail_at[0] = 1000
+
+    # We should recover successfully.
+    await tokenizer_group_pool.encode_async(request_id="1",
+                                            prompt="prompt",
+                                            lora_request=None)
+    await tokenizer_group_pool.encode_async(request_id="1",
+                                            prompt="prompt",
+                                            lora_request=None)
+
+    # Check that we have a new actor
+    assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
+    assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors
+
+    # Fail at first iteration
+    fail_at = [1]
+    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
+        tokenizer_pool_config,
+        tokenizer_id="gpt2",
+        enable_lora=False,
+        max_num_seqs=1,
+        max_input_length=None,
+        fail_at=fail_at)
+
+    # We should fail after re-initialization.
+    with pytest.raises(RuntimeError):
+        await tokenizer_group_pool.encode_async(request_id="1",
+                                                prompt="prompt",
+                                                lora_request=None)
+
+    # check_health should raise the same thing
+    with pytest.raises(RuntimeError):
+        tokenizer_group_pool.check_health()
+
+    # Ensure that non-ActorDiedErrors are still propagated correctly and do not
+    # cause a re-initialization.
+    fail_at = []
+    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
+        tokenizer_pool_config,
+        tokenizer_id="gpt2",
+        enable_lora=False,
+        max_num_seqs=1,
+        max_input_length=2,
+        fail_at=fail_at)
+    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()
+
+    # Prompt too long error
+    with pytest.raises(ValueError):
+        await tokenizer_group_pool.encode_async(request_id="1",
+                                                prompt="prompt" * 100,
+                                                lora_request=None)
+    await tokenizer_group_pool.encode_async(request_id="1",
+                                            prompt="prompt",
+                                            lora_request=None)
+    # Actors should stay the same.
+    assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 11b1c611e4c58..9ee71c1a19fa2 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -310,6 +310,8 @@ async def add_request_async(
         )
 
     async def check_health_async(self) -> None:
+        if self.tokenizer:
+            self.tokenizer.check_health()
         self.model_executor.check_health()
 
 
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index f7eae257fdd16..0ad957ef9f958 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1013,6 +1013,8 @@ def pin_lora(self, lora_id: int) -> bool:
         return self.model_executor.pin_lora(lora_id)
 
     def check_health(self) -> None:
+        if self.tokenizer:
+            self.tokenizer.check_health()
         self.model_executor.check_health()
 
     def is_tracing_enabled(self) -> bool:
diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
index 3cce96e06d1a0..18fbd894f1c0e 100644
--- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
@@ -53,3 +53,7 @@ async def get_lora_tokenizer_async(
     ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
         pass
+
+    def check_health(self):
+        """Raise exception if the tokenizer group is unhealthy."""
+        return
diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
index 7c605416854b8..21ec2b52bc95e 100644
--- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
@@ -2,17 +2,21 @@
 import os
 from typing import List, Optional
 
+from ray.exceptions import ActorDiedError
 from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
 from transformers import PreTrainedTokenizer
 
 from vllm.config import TokenizerPoolConfig
 from vllm.executor.ray_utils import ray
+from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
     BaseTokenizerGroup)
 from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
     TokenizerGroup)
 
+logger = init_logger(__name__)
+
 
 class RayTokenizerGroupPool(BaseTokenizerGroup):
     """A Ray-based pool of TokenizerGroups for async tokenization."""
@@ -46,24 +50,28 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
                  ray_actor_options: dict, **tokenizer_config):
         # Store a local copy of the TokenizerGroup for quick access
         # to underlying HF tokenizers.
+        self._tokenizer_config = {
+            "tokenizer_id": tokenizer_id,
+            "enable_lora": enable_lora,
+            "max_num_seqs": max_num_seqs,
+            "max_input_length": max_input_length,
+            **tokenizer_config
+        }
         self._local_tokenizer_group = self._worker_cls(
-            tokenizer_id=tokenizer_id,
-            enable_lora=enable_lora,
-            max_num_seqs=max_num_seqs,
-            max_input_length=max_input_length,
-            **tokenizer_config,
-        )
-
-        ray_tokenizer_group_cls = ray.remote(
+            **self._tokenizer_config, )
+
+        self._ray_tokenizer_group_cls = ray.remote(
             self._worker_cls).options(**ray_actor_options)
-        self.tokenizer_actors = [
-            ray_tokenizer_group_cls.remote(tokenizer_id, enable_lora,
-                                           max_num_seqs, max_input_length,
-                                           **tokenizer_config)
-            for _ in range(num_actors)
-        ]
+        self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
         self._idle_actors: Optional[asyncio.Queue] = None
 
+        # If set, actor is unhealthy. Will reraise on the next
+        # check_health call.
+        self._exception: Optional[ActorDiedError] = None
+
+    def _init_actor(self) -> ray.ObjectRef:
+        return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config)
+
     @property
     def pool_size(self) -> int:
         return len(self.tokenizer_actors)
@@ -78,6 +86,22 @@ def _ensure_queue_initialized(self):
             for actor in self.tokenizer_actors:
                 self._idle_actors.put_nowait(actor)
 
+    def _finalize_encode(self, actor: ray.ObjectRef,
+                         original_actor: ray.ObjectRef, actor_is_alive: bool):
+        assert self._idle_actors is not None
+        # Cleanup the dead actor.
+        if not actor_is_alive or original_actor is not actor:
+            self.tokenizer_actors.remove(original_actor)
+        if actor_is_alive:
+            # Put the actor back in the queue.
+            # This is done in a finally block to ensure that the actor is
+            # always put back in the queue, even if an exception/cancellation
+            # is raised.
+            self._idle_actors.put_nowait(actor)
+            # Add back the new actor.
+            if original_actor is not actor:
+                self.tokenizer_actors.append(actor)
+
     def encode(self,
                prompt: str,
                request_id: Optional[str] = None,
@@ -88,23 +112,41 @@ def encode(self,
         The actor is then put back in the queue for future use.
         This is blocking.
         """
+        self.check_health()
         self._ensure_queue_initialized()
         assert self._idle_actors is not None
 
         if self._idle_actors.empty():
             raise RuntimeError("No idle actors available.")
         actor = self._idle_actors.get_nowait()
+        actor_is_alive = True
+        original_actor = actor
         try:
             ret = ray.get(
                 actor.encode.remote(request_id=request_id,
                                     prompt=prompt,
                                     lora_request=lora_request))
+        except ActorDiedError as e:
+            # If the actor is dead, we first try to reinitialize it.
+            logger.warning("%s died with ActorDiedError, reinitializing.",
+                           actor,
+                           exc_info=e)
+            actor = self._init_actor()
+            try:
+                ret = ray.get(
+                    actor.encode.remote(request_id=request_id,
+                                        prompt=prompt,
+                                        lora_request=lora_request))
+            except ActorDiedError as e:
+                logger.error(
+                    "%s died for second time in a row, marking "
+                    "RayTokenizerGroupPool as unhealthy.", actor)
+                actor_is_alive = False
+                if not self._exception:
+                    self._exception = e
+                self.check_health()
         finally:
-            # Put the actor back in the queue.
-            # This is done in a finally block to ensure that the actor is
-            # always put back in the queue, even if an exception/cancellation
-            # is raised.
-            self._idle_actors.put_nowait(actor)
+            self._finalize_encode(actor, original_actor, actor_is_alive)
         return ret
 
     async def encode_async(
@@ -120,20 +162,37 @@ async def encode_async(
         The actor is then put back in the queue for future use.
         This is non-blocking.
         """
+        self.check_health()
         self._ensure_queue_initialized()
         assert self._idle_actors is not None
 
         actor = await self._idle_actors.get()
+        actor_is_alive = True
+        original_actor = actor
         try:
             ret = await actor.encode.remote(request_id=request_id,
                                             prompt=prompt,
                                             lora_request=lora_request)
+        except ActorDiedError as e:
+            # If the actor is dead, we first try to reinitialize it.
+            logger.warning("%s died with ActorDiedError, reinitializing.",
+                           actor,
+                           exc_info=e)
+            actor = self._init_actor()
+            try:
+                ret = await actor.encode.remote(request_id=request_id,
+                                                prompt=prompt,
+                                                lora_request=lora_request)
+            except ActorDiedError as e:
+                logger.error(
+                    "%s died for second time in a row, marking "
+                    "RayTokenizerGroupPool as unhealthy.", actor)
+                actor_is_alive = False
+                if not self._exception:
+                    self._exception = e
+                self.check_health()
         finally:
-            # Put the actor back in the queue.
-            # This is done in a finally block to ensure that the actor is
-            # always put back in the queue, even if an exception/cancellation
-            # is raised.
-            self._idle_actors.put_nowait(actor)
+            self._finalize_encode(actor, original_actor, actor_is_alive)
         return ret
 
     def get_max_input_len(self,
@@ -155,6 +214,11 @@ async def get_lora_tokenizer_async(
         return await self._local_tokenizer_group.get_lora_tokenizer_async(
             lora_request)
 
+    def check_health(self):
+        if self._exception:
+            raise RuntimeError(
+                "TokenizerGroupPool is unhealthy.") from self._exception
+
 
 def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
     """Copy over all current process environment variables to the runtime_env.