diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index e2c83b9c4004..9d5ee84a1956 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from http import HTTPStatus from typing import Optional -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest @@ -83,20 +83,31 @@ def register_mock_resolver(): def mock_serving_setup(): """Provides a mocked engine and serving completion instance.""" mock_engine = MagicMock(spec=AsyncLLM) - mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False - def mock_add_lora_side_effect(lora_request: LoRARequest): + tokenizer = get_tokenizer(MODEL_NAME) + mock_engine.get_tokenizer = AsyncMock(return_value=tokenizer) + + async def mock_add_lora_side_effect(lora_request: LoRARequest): """Simulate engine behavior when adding LoRAs.""" if lora_request.lora_name == "test-lora": # Simulate successful addition - return - elif lora_request.lora_name == "invalid-lora": + return True + if lora_request.lora_name == "invalid-lora": # Simulate failure during addition (e.g. invalid format) raise ValueError(f"Simulated failure adding LoRA: " f"{lora_request.lora_name}") + return True + + mock_engine.add_lora = AsyncMock(side_effect=mock_add_lora_side_effect) + + async def mock_generate(*args, **kwargs): + for _ in []: + yield _ + + mock_engine.generate = MagicMock(spec=AsyncLLM.generate, + side_effect=mock_generate) - mock_engine.add_lora.side_effect = mock_add_lora_side_effect mock_engine.generate.reset_mock() mock_engine.add_lora.reset_mock() @@ -131,7 +142,7 @@ async def test_serving_completion_with_lora_resolver(mock_serving_setup, with suppress(Exception): await serving_completion.create_completion(req_found) - mock_engine.add_lora.assert_called_once() + mock_engine.add_lora.assert_awaited_once() called_lora_request = mock_engine.add_lora.call_args[0][0] assert isinstance(called_lora_request, LoRARequest) assert called_lora_request.lora_name == lora_model_name @@ -157,7 +168,7 @@ async def test_serving_completion_resolver_not_found(mock_serving_setup, response = await serving_completion.create_completion(req) - mock_engine.add_lora.assert_not_called() + mock_engine.add_lora.assert_not_awaited() mock_engine.generate.assert_not_called() assert isinstance(response, ErrorResponse) @@ -181,7 +192,7 @@ async def test_serving_completion_resolver_add_lora_fails( response = await serving_completion.create_completion(req) # Assert add_lora was called before the failure - mock_engine.add_lora.assert_called_once() + mock_engine.add_lora.assert_awaited_once() called_lora_request = mock_engine.add_lora.call_args[0][0] assert isinstance(called_lora_request, LoRARequest) assert called_lora_request.lora_name == invalid_model