Skip to content

Commit

Permalink
Implement RPC timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
blast-hardcheese committed Nov 29, 2024
1 parent 76af9db commit 9b57ee1
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 2 deletions.
3 changes: 3 additions & 0 deletions replit_river/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
from contextlib import contextmanager
from datetime import timedelta
from typing import Any, Generator, Generic, Literal, Optional, Union

from opentelemetry import trace
Expand Down Expand Up @@ -60,6 +61,7 @@ async def send_rpc(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
timeout: timedelta,
) -> ResponseType:
with _trace_procedure("rpc", service_name, procedure_name) as span:
session = await self._transport.get_or_create_session()
Expand All @@ -71,6 +73,7 @@ async def send_rpc(
response_deserializer,
error_deserializer,
span,
timeout,
)

async def send_upload(
Expand Down
18 changes: 17 additions & 1 deletion replit_river/client_session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import logging
from collections.abc import AsyncIterable, AsyncIterator
from datetime import timedelta
from typing import Any, Callable, Optional, Union

import nanoid # type: ignore
Expand All @@ -8,6 +10,7 @@
from opentelemetry.trace import Span

from replit_river.error_schema import (
ERROR_CODE_CANCEL,
ERROR_CODE_STREAM_CLOSED,
RiverException,
RiverServiceException,
Expand Down Expand Up @@ -39,6 +42,7 @@ async def send_rpc(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
span: Span,
timeout: timedelta,
) -> ResponseType:
"""Sends a single RPC request to the server.
Expand All @@ -58,7 +62,19 @@ async def send_rpc(
# Handle potential errors during communication
try:
try:
response = await output.get()
async with asyncio.timeout(int(timeout.total_seconds())):
response = await output.get()
except asyncio.CancelledError as e:
# TODO(dstewart) After protocol v2, change this to STREAM_CANCEL_BIT
await self.send_message(
stream_id=stream_id,
control_flags=STREAM_CLOSED_BIT,
payload={"type": "CLOSE"},
service_name=service_name,
procedure_name=procedure_name,
span=span,
)
raise RiverException(ERROR_CODE_CANCEL, str(e)) from e
except ChannelClosed as e:
raise RiverServiceException(
ERROR_CODE_STREAM_CLOSED,
Expand Down
3 changes: 3 additions & 0 deletions replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
# Code generated by river.codegen. DO NOT EDIT.
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any
import datetime
from pydantic import TypeAdapter
Expand Down Expand Up @@ -857,6 +858,7 @@ def __init__(self, client: river.Client[Any]):
async def {name}(
self,
input: {render_type_expr(input_type)},
timeout: datetime.timedelta,
) -> {render_type_expr(output_type)}:
return await self.client.send_rpc(
{repr(schema_name)},
Expand All @@ -865,6 +867,7 @@ async def {name}(
{reindent(" ", render_input_method)},
{reindent(" ", parse_output_method)},
{reindent(" ", parse_error_method)},
timeout,
)
""",
)
Expand Down
2 changes: 1 addition & 1 deletion replit_river/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
]
ACK_BIT = 0x0001
STREAM_OPEN_BIT = 0x0002
STREAM_CLOSED_BIT = 0x0004
STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2

# these codes are retriable
# if the server sends a response with one of these codes,
Expand Down
2 changes: 2 additions & 0 deletions tests/test_communication.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from datetime import timedelta
from typing import AsyncGenerator

import pytest
Expand Down Expand Up @@ -29,6 +30,7 @@ async def test_rpc_method(client: Client) -> None:
serialize_request,
deserialize_response,
deserialize_error,
timedelta(seconds=20),
)
assert response == "Hello, Alice!"

Expand Down
2 changes: 2 additions & 0 deletions tests/test_opentelemetry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import timedelta
from typing import AsyncGenerator, AsyncIterator, Iterator

import grpc
Expand Down Expand Up @@ -38,6 +39,7 @@ async def test_rpc_method_span(
serialize_request,
deserialize_response,
deserialize_error,
timedelta(seconds=20),
)
assert response == "Hello, Alice!"
spans = span_exporter.get_finished_spans()
Expand Down

0 comments on commit 9b57ee1

Please sign in to comment.