29 changes: 20 additions & 9 deletions tests/entrypoints/openai/test_lora_resolvers.py
@@ -5,7 +5,7 @@
 from dataclasses import dataclass, field
 from http import HTTPStatus
 from typing import Optional
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock
 
 import pytest
 
@@ -83,20 +83,31 @@ def register_mock_resolver():
 def mock_serving_setup():
     """Provides a mocked engine and serving completion instance."""
     mock_engine = MagicMock(spec=AsyncLLM)
-    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
 
-    def mock_add_lora_side_effect(lora_request: LoRARequest):
+    tokenizer = get_tokenizer(MODEL_NAME)
+    mock_engine.get_tokenizer = AsyncMock(return_value=tokenizer)
+
+    async def mock_add_lora_side_effect(lora_request: LoRARequest):
         """Simulate engine behavior when adding LoRAs."""
         if lora_request.lora_name == "test-lora":
             # Simulate successful addition
-            return
-        elif lora_request.lora_name == "invalid-lora":
+            return True
+        if lora_request.lora_name == "invalid-lora":
             # Simulate failure during addition (e.g. invalid format)
             raise ValueError(f"Simulated failure adding LoRA: "
                              f"{lora_request.lora_name}")
+        return True
+
+    mock_engine.add_lora = AsyncMock(side_effect=mock_add_lora_side_effect)
+
+    async def mock_generate(*args, **kwargs):
+        for _ in []:
+            yield _
+
+    mock_engine.generate = MagicMock(spec=AsyncLLM.generate,
+                                     side_effect=mock_generate)
 
-    mock_engine.add_lora.side_effect = mock_add_lora_side_effect
     mock_engine.generate.reset_mock()
     mock_engine.add_lora.reset_mock()

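Note (reviewer aside, not part of the diff): the updated fixture leans on two unittest.mock patterns: an AsyncMock whose side_effect is an async function, and a MagicMock whose side_effect is an empty async generator standing in for AsyncLLM.generate. A minimal, self-contained sketch of both patterns; the names below (fake_add_lora, empty_stream, the adapter strings) are illustrative and not taken from this test file.

import asyncio
from unittest.mock import AsyncMock, MagicMock

# Illustrative sketch only; these helpers are not from the PR.

async def fake_add_lora(name: str):
    # Async side_effect: AsyncMock awaits this coroutine when the mock is awaited.
    if name == "bad-adapter":
        raise ValueError(f"failed to add {name}")
    return True

async def empty_stream(*args, **kwargs):
    # Async generator that yields nothing; calling the mock returns it directly.
    for item in []:
        yield item

async def main():
    add_lora = AsyncMock(side_effect=fake_add_lora)
    generate = MagicMock(side_effect=empty_stream)

    assert await add_lora("good-adapter") is True
    add_lora.assert_awaited_once()

    outputs = [chunk async for chunk in generate("prompt")]
    assert outputs == []        # the stream is empty by construction
    generate.assert_called_once()

asyncio.run(main())

Because add_lora is now a real AsyncMock, the await-aware assertions used in the test changes further down (assert_awaited_once, assert_not_awaited) become available.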
@@ -131,7 +142,7 @@ async def test_serving_completion_with_lora_resolver(mock_serving_setup,
     with suppress(Exception):
         await serving_completion.create_completion(req_found)
 
-    mock_engine.add_lora.assert_called_once()
+    mock_engine.add_lora.assert_awaited_once()
     called_lora_request = mock_engine.add_lora.call_args[0][0]
     assert isinstance(called_lora_request, LoRARequest)
     assert called_lora_request.lora_name == lora_model_name
@@ -157,7 +168,7 @@ async def test_serving_completion_resolver_not_found(mock_serving_setup,
 
     response = await serving_completion.create_completion(req)
 
-    mock_engine.add_lora.assert_not_called()
+    mock_engine.add_lora.assert_not_awaited()
     mock_engine.generate.assert_not_called()
 
     assert isinstance(response, ErrorResponse)
@@ -181,7 +192,7 @@ async def test_serving_completion_resolver_add_lora_fails(
     response = await serving_completion.create_completion(req)
 
     # Assert add_lora was called before the failure
-    mock_engine.add_lora.assert_called_once()
+    mock_engine.add_lora.assert_awaited_once()
     called_lora_request = mock_engine.add_lora.call_args[0][0]
     assert isinstance(called_lora_request, LoRARequest)
    assert called_lora_request.lora_name == invalid_model
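Note (reviewer aside, not part of the diff): assert_called_once only checks that the mock was invoked, while assert_awaited_once additionally requires that the returned coroutine was actually awaited, which is the stronger guarantee these tests want once add_lora is an AsyncMock. A small sketch of the difference, assuming nothing beyond the standard library:

import asyncio
from unittest.mock import AsyncMock

async def main():
    add_lora = AsyncMock(return_value=True)

    coro = add_lora("some-adapter")   # the call is recorded immediately...
    add_lora.assert_called_once()
    assert add_lora.await_count == 0  # ...but nothing has been awaited yet

    await coro                        # now the coroutine actually runs
    add_lora.assert_awaited_once()

asyncio.run(main())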