Skip to content

Commit

Permalink
Revert "Revert "Make client accept a function for websocket uri and h…
Browse files Browse the repository at this point in the history
…adnshakemetadata (#62) (#71)"

This reverts commit 2dc3593.
  • Loading branch information
blast-hardcheese committed Aug 20, 2024
1 parent d5aabb4 commit 1f6efd6
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 12 deletions.
12 changes: 7 additions & 5 deletions replit_river/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from collections.abc import AsyncIterable, AsyncIterator
from collections.abc import AsyncIterable, AsyncIterator, Awaitable
from typing import Any, Callable, Generic, Optional, TypeVar, Union

from replit_river.client_transport import ClientTransport
Expand All @@ -21,20 +21,22 @@
class Client(Generic[HandshakeType]):
def __init__(
self,
websocket_uri: str,
websocket_uri_factory: Callable[[], Awaitable[str]],
client_id: str,
server_id: str,
transport_options: TransportOptions,
handshake_metadata: Optional[HandshakeType] = None,
handshake_metadata_factory: Optional[
Callable[[], Awaitable[HandshakeType]]
] = None,
) -> None:
self._client_id = client_id
self._server_id = server_id
self._transport = ClientTransport[HandshakeType](
websocket_uri=websocket_uri,
websocket_uri_factory=websocket_uri_factory,
client_id=client_id,
server_id=server_id,
transport_options=transport_options,
handshake_metadata=handshake_metadata,
handshake_metadata_factory=handshake_metadata_factory,
)

async def close(self) -> None:
Expand Down
21 changes: 15 additions & 6 deletions replit_river/client_transport.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
from collections.abc import Awaitable, Callable
from typing import Generic, Optional, Tuple, TypeVar

import websockets
Expand Down Expand Up @@ -47,24 +48,26 @@
class ClientTransport(Transport, Generic[HandshakeType]):
def __init__(
self,
websocket_uri: str,
websocket_uri_factory: Callable[[], Awaitable[str]],
client_id: str,
server_id: str,
transport_options: TransportOptions,
handshake_metadata: Optional[HandshakeType] = None,
handshake_metadata_factory: Optional[
Callable[[], Awaitable[HandshakeType]]
] = None,
):
super().__init__(
transport_id=client_id,
transport_options=transport_options,
is_server=False,
)
self._websocket_uri = websocket_uri
self._websocket_uri_factory = websocket_uri_factory
self._client_id = client_id
self._server_id = server_id
self._rate_limiter = LeakyBucketRateLimit(
transport_options.connection_retry_options
)
self._handshake_metadata = handshake_metadata
self._handshake_metadata_factory = handshake_metadata_factory
# We want to make sure there's only one session creation at a time
self._create_session_lock = asyncio.Lock()

Expand Down Expand Up @@ -110,12 +113,18 @@ async def _establish_new_connection(
break
rate_limit.consume_budget(client_id)
try:
ws = await websockets.connect(self._websocket_uri)
websocket_uri = await self._websocket_uri_factory()
ws = await websockets.connect(websocket_uri)
session_id = (
self.generate_session_id()
if not old_session
else old_session.session_id
)

handshake_metadata: Optional[HandshakeType] = None
if self._handshake_metadata_factory is not None:
handshake_metadata = await self._handshake_metadata_factory()

try:
(
handshake_request,
Expand All @@ -124,7 +133,7 @@ async def _establish_new_connection(
self._transport_id,
self._server_id,
session_id,
self._handshake_metadata,
handshake_metadata,
ws,
old_session,
)
Expand Down
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,14 @@ async def client(
transport_options: TransportOptions,
no_logging_error: NoErrors,
) -> AsyncGenerator[Client, None]:

async def websocket_uri_factory() -> str:
return "ws://localhost:8765"

try:
async with serve(server.serve, "localhost", 8765):
client: Client[NoReturn] = Client(
"ws://localhost:8765",
websocket_uri_factory,
client_id="test_client",
server_id="test_server",
transport_options=transport_options,
Expand Down

0 comments on commit 1f6efd6

Please sign in to comment.