11# SPDX-License-Identifier: Apache-2.0
22
3- import asyncio
4- import os
5- import sys
6- from typing import Optional
7- from unittest .mock import patch
8-
93import pytest
104from transformers import AutoTokenizer , PreTrainedTokenizerBase
115
12- from vllm .transformers_utils .tokenizer_group import (TokenizerGroup ,
13- get_tokenizer_group )
14- from vllm .transformers_utils .tokenizer_group .ray_tokenizer_group import (
15- RayTokenizerGroupPool )
16-
17- from ..conftest import get_tokenizer_pool_config
18-
19-
class CustomTokenizerGroup(TokenizerGroup):
    """TokenizerGroup subclass that counts encode() invocations.

    Used by the test below to verify that a user-supplied tokenizer
    group class is actually instantiated and exercised.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Number of encode() calls observed so far.
        self._i = 0

    def encode(self, *args, **kwargs):
        # Record the call, then delegate to the real implementation.
        self._i += 1
        return super().encode(*args, **kwargs)
6+ from vllm .transformers_utils .tokenizer_group import TokenizerGroup
297
308
319@pytest .mark .asyncio
32- @pytest .mark .parametrize ("tokenizer_group_type" ,
33- [None , "ray" , CustomTokenizerGroup ])
34- async def test_tokenizer_group (tokenizer_group_type ):
10+ async def test_tokenizer_group ():
3511 reference_tokenizer = AutoTokenizer .from_pretrained ("gpt2" )
36- tokenizer_group = get_tokenizer_group (
37- get_tokenizer_pool_config (tokenizer_group_type ),
12+ tokenizer_group = TokenizerGroup (
3813 tokenizer_id = "gpt2" ,
3914 enable_lora = False ,
4015 max_num_seqs = 1 ,
@@ -49,159 +24,3 @@ async def test_tokenizer_group(tokenizer_group_type):
4924 PreTrainedTokenizerBase )
5025 assert tokenizer_group .get_lora_tokenizer (
5126 None ) == await tokenizer_group .get_lora_tokenizer_async (None )
52- if tokenizer_group_type is CustomTokenizerGroup :
53- assert tokenizer_group ._i > 0
54-
55-
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_pool(tokenizer_group_type):
    """Flood the pool with more concurrent encode requests than workers
    and check every result against a reference tokenizer."""
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer_group_pool = get_tokenizer_group(
        get_tokenizer_pool_config(tokenizer_group_type),
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )
    # Oversubscribe the pool (5x its size) so requests must queue;
    # all of them should still be processed correctly.
    num_requests = tokenizer_group_pool.pool_size * 5
    prompts = [f"prompt {i}" for i in range(num_requests)]
    pending = [
        tokenizer_group_pool.encode_async(prompt=prompt, lora_request=None)
        for prompt in prompts
    ]
    results = await asyncio.gather(*pending)
    expected_results = [reference_tokenizer.encode(prompt) for prompt in prompts]
    assert results == expected_results
81-
82-
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_env_var_propagation(
        tokenizer_group_type):
    """Test that env vars from caller process are propagated to
    tokenizer Ray actors."""
    env_var = "MY_ENV_VAR"

    class EnvVarCheckerTokenizerGroup(TokenizerGroup):
        # Worker-side probe: ping() asserts the env var reached the actor.

        def ping(self):
            assert os.environ.get(env_var) == "1"
            return super().ping()

    class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
        _worker_cls = EnvVarCheckerTokenizerGroup

    def _make_pool():
        # Read the pool config inside the call so each pool is built
        # under the environment active at construction time.
        config = get_tokenizer_pool_config(tokenizer_group_type)
        return EnvVarCheckerRayTokenizerGroupPool.from_config(
            config,
            tokenizer_id="gpt2",
            enable_lora=False,
            max_num_seqs=1,
            max_input_length=None)

    # Without the env var set, the worker-side assertion must fire.
    tokenizer_pool = _make_pool()
    with pytest.raises(AssertionError):
        tokenizer_pool.ping()

    # With the env var set in the caller, the actors must see it too.
    with patch.dict(os.environ, {env_var: "1"}):
        tokenizer_pool = _make_pool()
        tokenizer_pool.ping()
119-
120-
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
    """Test that Ray tokenizer pool group can recover from failures and
    if that's not possible, mark itself as unhealthy."""

    class FailingTokenizerGroup(TokenizerGroup):
        # Worker that kills its own process on scheduled encode() calls.

        def __init__(self,
                     *args,
                     fail_at: Optional[list[int]] = None,
                     **kwargs):
            super().__init__(*args, **kwargs)
            self.i = 0
            self.fail_at = fail_at or []

        def encode(self, *args, **kwargs):
            self.i += 1
            if self.i in self.fail_at:
                # Simulate an actor death, not a Python exception.
                sys.exit(1)
            return super().encode(*args, **kwargs)

    class FailingRayTokenizerGroupPool(RayTokenizerGroupPool):
        _worker_cls = FailingTokenizerGroup

    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)

    def _make_pool(fail_at, max_input_length=None):
        # All pools in this test are identical except for the failure
        # schedule and the optional input-length cap.
        return FailingRayTokenizerGroupPool.from_config(
            tokenizer_pool_config,
            tokenizer_id="gpt2",
            enable_lora=False,
            max_num_seqs=1,
            max_input_length=max_input_length,
            fail_at=fail_at)

    # Fail at first iteration.
    fail_at = [1]
    tokenizer_group_pool = _make_pool(fail_at)
    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Modify fail_at so the worker no longer fails (the list is re-read
    # when the actor is re-initialized).
    fail_at[0] = 1000

    # We should recover successfully.
    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)

    # Check that we have a new actor (same count, different members).
    assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
    assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors

    # Fail at first iteration again — this schedule keeps failing even
    # after the actor is replaced.
    fail_at = [1]
    tokenizer_group_pool = _make_pool(fail_at)

    # We should fail after re-initialization.
    with pytest.raises(RuntimeError):
        await tokenizer_group_pool.encode_async(prompt="prompt",
                                                lora_request=None)

    # check_health should raise the same thing.
    with pytest.raises(RuntimeError):
        tokenizer_group_pool.check_health()

    # Ensure that non-ActorDiedErrors are still propagated correctly and do not
    # cause a re-initialization.
    fail_at = []
    tokenizer_group_pool = _make_pool(fail_at, max_input_length=2)
    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Prompt too long error.
    with pytest.raises(ValueError):
        await tokenizer_group_pool.encode_async(prompt="prompt" * 100,
                                                lora_request=None)
    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
    # Actors should stay the same.
    assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors
0 commit comments