Skip to content

Commit

Permalink
[tests] handler decomposition (#125)
Browse files Browse the repository at this point in the history
Why
===

Striking a balance between test handler reuse and colocating the handler
specifiers with the tests that use them.

What changed
============

- Decomposing monolithic server handler specifiers into directly DI-ing
reusable components
- For tests that need bespoke handlers, defining those alongside the
methods that use them.

Test plan
=========

CI
  • Loading branch information
blast-hardcheese authored Nov 28, 2024
1 parent 88f4347 commit 502278f
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 93 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ dev-dependencies = [
"pytest-mock>=3.11.1",
"ruff>=0.0.278",
"types-protobuf>=4.24.0.20240311",
"types-nanoid>=2.0.0.20240601",
]

[tool.ruff]
Expand Down
82 changes: 82 additions & 0 deletions tests/common_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from typing import Any, AsyncGenerator, AsyncIterator, Iterator

import grpc
import grpc.aio

from replit_river.rpc import (
rpc_method_handler,
stream_method_handler,
subscription_method_handler,
upload_method_handler,
)
from tests.conftest import HandlerMapping, deserialize_request, serialize_response


async def rpc_handler(request: str, context: grpc.aio.ServicerContext) -> str:
return f"Hello, {request}!"


basic_rpc_method: HandlerMapping = {
("test_service", "rpc_method"): (
"rpc",
rpc_method_handler(rpc_handler, deserialize_request, serialize_response),
)
}


async def upload_handler(
request: Iterator[str] | AsyncIterator[str], context: Any
) -> str:
uploaded_data = []
if isinstance(request, AsyncIterator):
async for data in request:
uploaded_data.append(data)
else:
for data in request:
uploaded_data.append(data)
return f"Uploaded: {', '.join(uploaded_data)}"


basic_upload: HandlerMapping = {
("test_service", "upload_method"): (
"upload",
upload_method_handler(upload_handler, deserialize_request, serialize_response),
),
}


async def subscription_handler(
request: str, context: grpc.aio.ServicerContext
) -> AsyncGenerator[str, None]:
for i in range(5):
yield f"Subscription message {i} for {request}"


basic_subscription: HandlerMapping = {
("test_service", "subscription_method"): (
"subscription",
subscription_method_handler(
subscription_handler, deserialize_request, serialize_response
),
),
}


async def stream_handler(
request: Iterator[str] | AsyncIterator[str],
context: grpc.aio.ServicerContext,
) -> AsyncGenerator[str, None]:
if isinstance(request, AsyncIterator):
async for data in request:
yield f"Stream response for {data}"
else:
for data in request:
yield f"Stream response for {data}"


basic_stream: HandlerMapping = {
("test_service", "stream_method"): (
"stream",
stream_method_handler(stream_handler, deserialize_request, serialize_response),
),
}
104 changes: 15 additions & 89 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import asyncio
import logging
from collections.abc import AsyncIterator
from typing import Any, AsyncGenerator, Iterator, Literal
from typing import Any, AsyncGenerator, Literal, Mapping

import grpc.aio
import nanoid # type: ignore
import nanoid
import pytest
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
Expand All @@ -14,13 +12,10 @@

from replit_river.client import Client
from replit_river.client_transport import UriAndMetadata
from replit_river.error_schema import RiverError, RiverException
from replit_river.error_schema import RiverError
from replit_river.rpc import (
GenericRpcHandler,
TransportMessage,
rpc_method_handler,
stream_method_handler,
subscription_method_handler,
upload_method_handler,
)
from replit_river.server import Server
from replit_river.transport_options import TransportOptions
Expand All @@ -29,6 +24,8 @@
# Modular fixtures
pytest_plugins = ["tests.river_fixtures.logging"]

HandlerMapping = Mapping[tuple[str, str], tuple[str, GenericRpcHandler]]


def transport_message(
seq: int = 0,
Expand Down Expand Up @@ -71,93 +68,22 @@ def deserialize_error(response: dict) -> RiverError:
return RiverError.model_validate(response)


# RPC method handlers for testing
async def rpc_handler(request: str, context: grpc.aio.ServicerContext) -> str:
return f"Hello, {request}!"


async def subscription_handler(
request: str, context: grpc.aio.ServicerContext
) -> AsyncGenerator[str, None]:
for i in range(5):
yield f"Subscription message {i} for {request}"


async def upload_handler(
request: Iterator[str] | AsyncIterator[str], context: Any
) -> str:
uploaded_data = []
if isinstance(request, AsyncIterator):
async for data in request:
uploaded_data.append(data)
else:
for data in request:
uploaded_data.append(data)
return f"Uploaded: {', '.join(uploaded_data)}"


async def stream_handler(
request: Iterator[str] | AsyncIterator[str],
context: grpc.aio.ServicerContext,
) -> AsyncGenerator[str, None]:
if isinstance(request, AsyncIterator):
async for data in request:
yield f"Stream response for {data}"
else:
for data in request:
yield f"Stream response for {data}"


async def stream_error_handler(
request: Iterator[str] | AsyncIterator[str],
context: grpc.aio.ServicerContext,
) -> AsyncGenerator[str, None]:
raise RiverException("INJECTED_ERROR", "test error")
yield "test" # appease the type checker


@pytest.fixture
def transport_options() -> TransportOptions:
return TransportOptions()


@pytest.fixture
def server(transport_options: TransportOptions) -> Server:
def server_handlers(handlers: HandlerMapping) -> HandlerMapping:
return handlers


@pytest.fixture
def server(
transport_options: TransportOptions, server_handlers: HandlerMapping
) -> Server:
server = Server(server_id="test_server", transport_options=transport_options)
server.add_rpc_handlers(
{
("test_service", "rpc_method"): (
"rpc",
rpc_method_handler(
rpc_handler, deserialize_request, serialize_response
),
),
("test_service", "subscription_method"): (
"subscription",
subscription_method_handler(
subscription_handler, deserialize_request, serialize_response
),
),
("test_service", "upload_method"): (
"upload",
upload_method_handler(
upload_handler, deserialize_request, serialize_response
),
),
("test_service", "stream_method"): (
"stream",
stream_method_handler(
stream_handler, deserialize_request, serialize_response
),
),
("test_service", "stream_method_error"): (
"stream",
stream_method_handler(
stream_error_handler, deserialize_request, serialize_response
),
),
}
)
server.add_rpc_handlers(server_handlers)
return server


Expand Down
20 changes: 19 additions & 1 deletion tests/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,21 @@
from replit_river.client import Client
from replit_river.error_schema import RiverError
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE
from tests.conftest import deserialize_error, deserialize_response, serialize_request
from tests.common_handlers import (
basic_rpc_method,
basic_stream,
basic_subscription,
basic_upload,
)
from tests.conftest import (
deserialize_error,
deserialize_response,
serialize_request,
)


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_rpc_method}])
async def test_rpc_method(client: Client) -> None:
response = await client.send_rpc(
"test_service",
Expand All @@ -23,6 +34,7 @@ async def test_rpc_method(client: Client) -> None:


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_upload}])
async def test_upload_method(client: Client) -> None:
async def upload_data() -> AsyncGenerator[str, None]:
yield "Data 1"
Expand All @@ -43,6 +55,7 @@ async def upload_data() -> AsyncGenerator[str, None]:


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_upload}])
async def test_upload_more_than_send_buffer_max(client: Client) -> None:
iterations = MAX_MESSAGE_BUFFER_SIZE * 2

Expand All @@ -64,6 +77,7 @@ async def upload_data() -> AsyncGenerator[str, None]:


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_upload}])
async def test_upload_empty(client: Client) -> None:
async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]:
if enabled:
Expand All @@ -83,6 +97,7 @@ async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]:


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_subscription}])
async def test_subscription_method(client: Client) -> None:
async for response in client.send_subscription(
"test_service",
Expand All @@ -97,6 +112,7 @@ async def test_subscription_method(client: Client) -> None:


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_stream}])
async def test_stream_method(client: Client) -> None:
async def stream_data() -> AsyncGenerator[str, None]:
yield "Stream 1"
Expand Down Expand Up @@ -125,6 +141,7 @@ async def stream_data() -> AsyncGenerator[str, None]:


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_stream}])
async def test_stream_empty(client: Client) -> None:
async def stream_data(enabled: bool = False) -> AsyncGenerator[str, None]:
if enabled:
Expand All @@ -147,6 +164,7 @@ async def stream_data(enabled: bool = False) -> AsyncGenerator[str, None]:


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_upload, **basic_stream}])
async def test_multiplexing(client: Client) -> None:
async def upload_data() -> AsyncGenerator[str, None]:
yield "Upload Data 1"
Expand Down
1 change: 1 addition & 0 deletions tests/test_handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def transport_options() -> TransportOptions:


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{}])
async def test_handshake_timeout(server: Server) -> None:
async with serve(server.serve, "localhost", 8765):
start = time()
Expand Down
Loading

0 comments on commit 502278f

Please sign in to comment.