From 9b57ee10510dd24bd2f584c1475f174d244c040c Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 27 Nov 2024 10:58:24 -0800 Subject: [PATCH] Implement RPC timeouts --- replit_river/client.py | 3 +++ replit_river/client_session.py | 18 +++++++++++++++++- replit_river/codegen/client.py | 3 +++ replit_river/rpc.py | 2 +- tests/test_communication.py | 2 ++ tests/test_opentelemetry.py | 2 ++ 6 files changed, 28 insertions(+), 2 deletions(-) diff --git a/replit_river/client.py b/replit_river/client.py index ec8b1dc..5f2f480 100644 --- a/replit_river/client.py +++ b/replit_river/client.py @@ -1,6 +1,7 @@ import logging from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable from contextlib import contextmanager +from datetime import timedelta from typing import Any, Generator, Generic, Literal, Optional, Union from opentelemetry import trace @@ -60,6 +61,7 @@ async def send_rpc( request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], + timeout: timedelta, ) -> ResponseType: with _trace_procedure("rpc", service_name, procedure_name) as span: session = await self._transport.get_or_create_session() @@ -71,6 +73,7 @@ async def send_rpc( response_deserializer, error_deserializer, span, + timeout, ) async def send_upload( diff --git a/replit_river/client_session.py b/replit_river/client_session.py index fbffa59..dddf6a2 100644 --- a/replit_river/client_session.py +++ b/replit_river/client_session.py @@ -1,5 +1,7 @@ +import asyncio import logging from collections.abc import AsyncIterable, AsyncIterator +from datetime import timedelta from typing import Any, Callable, Optional, Union import nanoid # type: ignore @@ -8,6 +10,7 @@ from opentelemetry.trace import Span from replit_river.error_schema import ( + ERROR_CODE_CANCEL, ERROR_CODE_STREAM_CLOSED, RiverException, RiverServiceException, @@ -39,6 +42,7 @@ async def send_rpc( response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], span: Span, + timeout: timedelta, ) -> ResponseType: """Sends a single RPC request to the server. @@ -58,7 +62,19 @@ async def send_rpc( # Handle potential errors during communication try: try: - response = await output.get() + async with asyncio.timeout(int(timeout.total_seconds())): + response = await output.get() + except asyncio.CancelledError as e: + # TODO(dstewart) After protocol v2, change this to STREAM_CANCEL_BIT + await self.send_message( + stream_id=stream_id, + control_flags=STREAM_CLOSED_BIT, + payload={"type": "CLOSE"}, + service_name=service_name, + procedure_name=procedure_name, + span=span, + ) + raise RiverException(ERROR_CODE_CANCEL, str(e)) from e except ChannelClosed as e: raise RiverServiceException( ERROR_CODE_STREAM_CLOSED, diff --git a/replit_river/codegen/client.py b/replit_river/codegen/client.py index f7b6ae5..74fe78b 100644 --- a/replit_river/codegen/client.py +++ b/replit_river/codegen/client.py @@ -56,6 +56,7 @@ # Code generated by river.codegen. DO NOT EDIT. from collections.abc import AsyncIterable, AsyncIterator from typing import Any +import datetime from pydantic import TypeAdapter @@ -857,6 +858,7 @@ def __init__(self, client: river.Client[Any]): async def {name}( self, input: {render_type_expr(input_type)}, + timeout: datetime.timedelta, ) -> {render_type_expr(output_type)}: return await self.client.send_rpc( {repr(schema_name)}, @@ -865,6 +867,7 @@ async def {name}( {reindent(" ", render_input_method)}, {reindent(" ", parse_output_method)}, {reindent(" ", parse_error_method)}, + timeout, ) """, ) diff --git a/replit_river/rpc.py b/replit_river/rpc.py index 53152c4..0c8d11f 100644 --- a/replit_river/rpc.py +++ b/replit_river/rpc.py @@ -51,7 +51,7 @@ ] ACK_BIT = 0x0001 STREAM_OPEN_BIT = 0x0002 -STREAM_CLOSED_BIT = 0x0004 +STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2 # these codes are retriable # if the server sends a response with one of these codes, diff --git a/tests/test_communication.py b/tests/test_communication.py index d50ff9c..fc86d27 100644 --- a/tests/test_communication.py +++ b/tests/test_communication.py @@ -1,4 +1,5 @@ import asyncio +from datetime import timedelta from typing import AsyncGenerator import pytest @@ -29,6 +30,7 @@ async def test_rpc_method(client: Client) -> None: serialize_request, deserialize_response, deserialize_error, + timedelta(seconds=20), ) assert response == "Hello, Alice!" diff --git a/tests/test_opentelemetry.py b/tests/test_opentelemetry.py index 0b47982..9cc227d 100644 --- a/tests/test_opentelemetry.py +++ b/tests/test_opentelemetry.py @@ -1,3 +1,4 @@ +from datetime import timedelta from typing import AsyncGenerator, AsyncIterator, Iterator import grpc @@ -38,6 +39,7 @@ async def test_rpc_method_span( serialize_request, deserialize_response, deserialize_error, + timedelta(seconds=20), ) assert response == "Hello, Alice!" spans = span_exporter.get_finished_spans()