 from contextlib import suppress
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Optional
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock

 import pytest
 import pytest_asyncio
@@ -230,6 +230,7 @@ class MockHFConfig:
 @dataclass
 class MockModelConfig:
     task = "generate"
+    runner_type = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
@@ -244,11 +245,33 @@ class MockModelConfig:
     encoder_config = None
     generation_config: str = "auto"
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
+    skip_tokenizer_init = False

     def get_diff_sampling_param(self):
         return self.diff_sampling_param or {}


+def _build_serving_chat(engine: AsyncLLM,
+                        model_config: MockModelConfig) -> OpenAIServingChat:
+    models = OpenAIServingModels(engine_client=engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=model_config)
+    serving_chat = OpenAIServingChat(engine,
+                                     model_config,
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+
+    async def _fake_process_inputs(request_id, engine_prompt, sampling_params,
+                                   *, lora_request, trace_headers, priority):
+        return dict(engine_prompt), {}
+
+    serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
+    return serving_chat
+
+
 @dataclass
 class MockEngine:

@@ -282,16 +305,7 @@ async def test_serving_chat_returns_correct_model_name():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False

-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=MockModelConfig())
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     MockModelConfig(),
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, MockModelConfig())
     messages = [{"role": "user", "content": "what is 1+1?"}]

     async def return_model_name(*args):
@@ -318,16 +332,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False

-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=MockModelConfig())
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     MockModelConfig(),
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, MockModelConfig())

     req = ChatCompletionRequest(
         model=MODEL_NAME,
@@ -361,16 +366,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.errored = False

     # Initialize the serving chat
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=mock_model_config)
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     mock_model_config,
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, mock_model_config)

     # Test Case 1: No max_tokens specified in request
     req = ChatCompletionRequest(
@@ -415,16 +411,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.errored = False

     # Initialize the serving chat
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=mock_model_config)
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     mock_model_config,
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, mock_model_config)

     # Test case 1: No max_tokens specified, defaults to context_window
     req = ChatCompletionRequest(
@@ -471,16 +458,7 @@ async def test_serving_chat_could_load_correct_generation_config():
     mock_engine.errored = False

     # Initialize the serving chat
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=mock_model_config)
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     mock_model_config,
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, mock_model_config)

     req = ChatCompletionRequest(
         model=MODEL_NAME,
@@ -525,17 +503,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False

-    # Initialize the serving chat
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=mock_model_config)
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     mock_model_config,
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, mock_model_config)

     # Test cache_salt
     req = ChatCompletionRequest(
@@ -549,10 +517,12 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
     # By default, cache_salt in the engine prompt is not set
     with suppress(Exception):
         await serving_chat.create_chat_completion(req)
-    assert "cache_salt" not in mock_engine.generate.call_args.args[0]
+    engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
+    assert "cache_salt" not in engine_prompt

     # Test with certain cache_salt
     req.cache_salt = "test_salt"
     with suppress(Exception):
         await serving_chat.create_chat_completion(req)
-    assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
+    engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
+    assert engine_prompt.get("cache_salt") == "test_salt"
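
For reference, a minimal self-contained sketch (separate from the diff above) of the AsyncMock side_effect pattern the new _build_serving_chat helper relies on: an async function is installed as side_effect, the mock awaits it on each call, and await_args_list later exposes the arguments each call received. The coroutine and its arguments below are illustrative only, not part of vLLM's API.

    import asyncio
    from unittest.mock import AsyncMock


    async def fake_process(request_id, engine_prompt):
        # Stand-in coroutine: AsyncMock awaits this and returns its result.
        return dict(engine_prompt), {}


    mock = AsyncMock(side_effect=fake_process)


    async def main():
        prompt_out, extras = await mock("req-1", {"prompt_token_ids": [1, 2, 3]})
        assert prompt_out == {"prompt_token_ids": [1, 2, 3]}
        # await_args_list records every awaited call, which is how the
        # cache_salt test reads back the engine prompt it was handed.
        assert mock.await_args_list[0].args[1] == {"prompt_token_ids": [1, 2, 3]}


    asyncio.run(main())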