Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reapply "Make client accept a function for websocket uri and hadnshakemetadata (#62) (#71) #73

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions replit_river/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from collections.abc import AsyncIterable, AsyncIterator
from collections.abc import AsyncIterable, AsyncIterator, Awaitable
from typing import Any, Callable, Generic, Optional, TypeVar, Union

from replit_river.client_transport import ClientTransport
Expand All @@ -21,20 +21,22 @@
class Client(Generic[HandshakeType]):
def __init__(
self,
websocket_uri: str,
websocket_uri_factory: Callable[[], Awaitable[str]],
client_id: str,
server_id: str,
transport_options: TransportOptions,
handshake_metadata: Optional[HandshakeType] = None,
handshake_metadata_factory: Optional[
Callable[[], Awaitable[HandshakeType]]
] = None,
) -> None:
self._client_id = client_id
self._server_id = server_id
self._transport = ClientTransport[HandshakeType](
websocket_uri=websocket_uri,
websocket_uri_factory=websocket_uri_factory,
client_id=client_id,
server_id=server_id,
transport_options=transport_options,
handshake_metadata=handshake_metadata,
handshake_metadata_factory=handshake_metadata_factory,
)

async def close(self) -> None:
Expand Down
21 changes: 15 additions & 6 deletions replit_river/client_transport.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
from collections.abc import Awaitable, Callable
from typing import Generic, Optional, Tuple, TypeVar

import websockets
Expand Down Expand Up @@ -47,24 +48,26 @@
class ClientTransport(Transport, Generic[HandshakeType]):
def __init__(
self,
websocket_uri: str,
websocket_uri_factory: Callable[[], Awaitable[str]],
client_id: str,
server_id: str,
transport_options: TransportOptions,
handshake_metadata: Optional[HandshakeType] = None,
handshake_metadata_factory: Optional[
Callable[[], Awaitable[HandshakeType]]
] = None,
):
super().__init__(
transport_id=client_id,
transport_options=transport_options,
is_server=False,
)
self._websocket_uri = websocket_uri
self._websocket_uri_factory = websocket_uri_factory
self._client_id = client_id
self._server_id = server_id
self._rate_limiter = LeakyBucketRateLimit(
transport_options.connection_retry_options
)
self._handshake_metadata = handshake_metadata
self._handshake_metadata_factory = handshake_metadata_factory
# We want to make sure there's only one session creation at a time
self._create_session_lock = asyncio.Lock()

Expand Down Expand Up @@ -110,12 +113,18 @@ async def _establish_new_connection(
break
rate_limit.consume_budget(client_id)
try:
ws = await websockets.connect(self._websocket_uri)
websocket_uri = await self._websocket_uri_factory()
ws = await websockets.connect(websocket_uri)
session_id = (
self.generate_session_id()
if not old_session
else old_session.session_id
)

handshake_metadata: Optional[HandshakeType] = None
if self._handshake_metadata_factory is not None:
handshake_metadata = await self._handshake_metadata_factory()

try:
(
handshake_request,
Expand All @@ -124,7 +133,7 @@ async def _establish_new_connection(
self._transport_id,
self._server_id,
session_id,
self._handshake_metadata,
handshake_metadata,
ws,
old_session,
)
Expand Down
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,14 @@ async def client(
transport_options: TransportOptions,
no_logging_error: NoErrors,
) -> AsyncGenerator[Client, None]:

async def websocket_uri_factory() -> str:
return "ws://localhost:8765"

try:
async with serve(server.serve, "localhost", 8765):
client: Client[NoReturn] = Client(
"ws://localhost:8765",
websocket_uri_factory,
client_id="test_client",
server_id="test_server",
transport_options=transport_options,
Expand Down
Loading