cleanup: fix test naming convention (#190)
aniketmaurya authored Aug 6, 2024
1 parent e96f886 commit 4823630
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions tests/test_batch.py
@@ -22,9 +22,10 @@
 from asgi_lifespan import LifespanManager
 from fastapi import Request, Response
 from httpx import AsyncClient
+
 from litserve import LitAPI, LitServer
-from tests.conftest import wrap_litserve_start
 from litserve.server import run_batched_loop
+from tests.conftest import wrap_litserve_start


 class Linear(nn.Module):
@@ -38,7 +38,7 @@ def forward(self, x):
         return self.linear(x)


-class SimpleLitAPI(LitAPI):
+class SimpleBatchLitAPI(LitAPI):
     def setup(self, device):
         self.model = Linear().to(device)
         self.device = device
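The new name calls out what this class is for: it is the batch-capable API in this test module. Its collapsed middle (the hooks between setup and encode_response) is not shown in this view; as a rough guide, a batch-capable LitAPI implements batch and unbatch in addition to the usual hooks. A minimal sketch, assuming per-request torch tensors and a {"input": <float>} payload (the hook bodies and the class name below are illustrative, not the committed code):

    import torch
    from litserve import LitAPI

    class BatchCapableAPI(LitAPI):
        def setup(self, device):
            self.model = Linear().to(device)
            self.device = device

        def decode_request(self, request):
            # Assumed payload: {"input": <float>} -> 1-element tensor.
            return torch.tensor([request["input"]], device=self.device)

        def batch(self, inputs):
            # Stack per-request tensors into one batch for a single forward pass.
            return torch.stack(inputs)

        def predict(self, x):
            return self.model(x)

        def unbatch(self, output):
            # Split the batched output back into one result per request.
            return list(output)

        def encode_response(self, output):
            return {"output": float(output)}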
@@ -63,7 +64,7 @@ def encode_response(self, output) -> Response:
         return {"output": float(output)}


-class SimpleLitAPI2(LitAPI):
+class SimpleTorchAPI(LitAPI):
     def setup(self, device):
         self.model = Linear().to(device)
         self.device = device
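SimpleTorchAPI (formerly SimpleLitAPI2) is the unbatched counterpart, which the warning test further down relies on. Assuming the same shape as the sketch above, its distinguishing feature is simply the absence of batch and unbatch, e.g. (illustrative, not the committed code):

    import torch
    from litserve import LitAPI

    class UnbatchedAPI(LitAPI):
        def setup(self, device):
            self.model = Linear().to(device)
            self.device = device

        def decode_request(self, request):
            return torch.tensor([request["input"]], device=self.device)

        # No batch()/unbatch(): the server feeds predict() one request at a time.
        def predict(self, x):
            return self.model(x)

        def encode_response(self, output):
            return {"output": float(output)}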
@@ -82,7 +83,7 @@ def encode_response(self, output) -> Response:

 @pytest.mark.asyncio()
 async def test_batched():
-    api = SimpleLitAPI()
+    api = SimpleBatchLitAPI()
     server = LitServer(api, accelerator="cpu", devices=1, timeout=10, max_batch_size=2, batch_timeout=4)

     with wrap_litserve_start(server) as server:
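With max_batch_size=2 and batch_timeout=4, the server may hold the first request for up to 4 seconds while waiting for a second one to batch with it. The body of test_batched is collapsed in this view; a test of this shape typically continues the visible with-block by firing concurrent requests so they land in the same batch, roughly as follows (the endpoint path and payloads are assumptions):

    import asyncio

    # Continuation sketch inside `with wrap_litserve_start(server) as server:`.
    async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
        resp1, resp2 = await asyncio.gather(
            ac.post("/predict", json={"input": 4.0}),
            ac.post("/predict", json={"input": 5.0}),
        )
        assert resp1.status_code == 200 and resp2.status_code == 200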
@@ -97,7 +98,7 @@ async def test_batched():

 @pytest.mark.asyncio()
 async def test_unbatched():
-    api = SimpleLitAPI2()
+    api = SimpleTorchAPI()
     server = LitServer(api, accelerator="cpu", devices=1, timeout=10, max_batch_size=1)
     with wrap_litserve_start(server) as server:
         async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
@@ -111,13 +112,13 @@ async def test_unbatched():

 def test_max_batch_size():
     with pytest.raises(ValueError, match="must be"):
-        LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=2, max_batch_size=0)
+        LitServer(SimpleBatchLitAPI(), accelerator="cpu", devices=1, timeout=2, max_batch_size=0)

     with pytest.raises(ValueError, match="must be"):
-        LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=2, max_batch_size=-1)
+        LitServer(SimpleBatchLitAPI(), accelerator="cpu", devices=1, timeout=2, max_batch_size=-1)

     with pytest.raises(ValueError, match="must be"):
-        LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=2, max_batch_size=2, batch_timeout=5)
+        LitServer(SimpleBatchLitAPI(), accelerator="cpu", devices=1, timeout=2, max_batch_size=2, batch_timeout=5)


 def test_max_batch_size_warning():
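Taken together, the three ValueError cases pin down the constructor's validation as these tests exercise it: max_batch_size must be a positive integer, and batch_timeout must not exceed timeout (5 vs 2 in the last case). For contrast, the configuration from test_batched above satisfies both constraints:

    # Valid: max_batch_size >= 1, and batch_timeout (4) < timeout (10).
    server = LitServer(SimpleBatchLitAPI(), accelerator="cpu", devices=1,
                       timeout=10, max_batch_size=2, batch_timeout=4)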
@@ -126,21 +127,21 @@ def test_max_batch_size_warning():
         UserWarning,
         match=warning,
     ):
-        LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=2)
+        LitServer(SimpleBatchLitAPI(), accelerator="cpu", devices=1, timeout=2)

-    # Test no warnings are raised when max_batch_size is set
+    # Test no warnings are raised when max_batch_size is set and max_batch_size is not set
     with pytest.raises(pytest.fail.Exception), pytest.warns(
         UserWarning,
         match=warning,
     ):
-        LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=2, max_batch_size=2)
+        LitServer(SimpleBatchLitAPI(), accelerator="cpu", devices=1, timeout=2, max_batch_size=2)

-    # Test no max_batch_size warnings are raised with a different API
+    # Test no warning is set when LitAPI doesn't implement batch and unbatch
     with pytest.raises(pytest.fail.Exception), pytest.warns(
         UserWarning,
         match=warning,
     ):
-        LitServer(SimpleLitAPI2(), accelerator="cpu", devices=1, timeout=2)
+        LitServer(SimpleTorchAPI(), accelerator="cpu", devices=1, timeout=2)


 class FakeResponseQueue:
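A note on the idiom repeated throughout test_max_batch_size_warning: pytest.warns fails by raising pytest.fail.Exception when the expected warning is never emitted, so wrapping it in pytest.raises(pytest.fail.Exception) inverts it into an assertion that the warning does not fire. A self-contained illustration:

    import warnings

    import pytest

    def test_warning_absent():
        # pytest.warns expects a matching UserWarning; none is emitted here,
        # so it raises pytest.fail.Exception, which the outer pytest.raises
        # catches -- the test passes only when the warning is absent.
        with pytest.raises(pytest.fail.Exception), pytest.warns(UserWarning, match="batch"):
            warnings.warn("unrelated", DeprecationWarning)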
