Skip to content

Commit

Permalink
Fix repo up (#18)
Browse files Browse the repository at this point in the history
* Fix repo up

* Update ci.yml

* Update .replit
  • Loading branch information
masad-frost authored May 16, 2024
1 parent 3d2cae2 commit 3d769eb
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 19 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ jobs:
- name: Install dependencies
run: |
poetry install --no-interaction
- name: Format check
run: |
poetry run black --check .
- name: Lint check
run: |
poetry run ruff .
- name: Type check
run: |
poetry run mypy .
- name: Test with pytest
run: |
poetry run pytest tests
6 changes: 6 additions & 0 deletions .replit
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
run = "poetry run pytest tests"

modules = ["python-3.11"]

[nix]
channel = "stable-23_11"
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ mypy = "^1.4.0"
black = ">=23.3,<25.0"
pytest-cov = "^4.1.0"
ruff = "^0.0.278"

pytest-mock = "^3.11.1"
pytest-asyncio = "^0.21.1"
types-protobuf = "^4.24.0.20240311"
Expand Down
14 changes: 14 additions & 0 deletions replit.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{pkgs}: {
deps = [
pkgs.borgbackup
pkgs.rustc
pkgs.libiconv
pkgs.cargo
pkgs.libxcrypt
pkgs.zlib
pkgs.pkg-config
pkgs.openssl
pkgs.grpc
pkgs.c-ares
];
}
3 changes: 2 additions & 1 deletion replit_river/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def get_backoff_ms(self, user: str) -> float:
exponent = max(0, self.get_budget_consumed(user) - 1)
jitter = random.randint(0, self.options.max_jitter_ms)
backoff_ms = min(
self.options.base_interval_ms * (2**exponent), self.options.max_backoff_ms
float(self.options.base_interval_ms * (2**exponent)),
float(self.options.max_backoff_ms),
)
return backoff_ms + jitter

Expand Down
36 changes: 24 additions & 12 deletions replit_river/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
STREAM_OPEN_BIT = 0x0002
STREAM_CLOSED_BIT = 0x0004


# Equivalent of https://github.com/replit/river/blob/c1345f1ff6a17a841d4319fad5c153b5bda43827/transport/message.ts#L23-L33


Expand Down Expand Up @@ -178,11 +177,13 @@ def rpc_method_handler(
request_deserializer: Callable[[Any], RequestType],
response_serializer: Callable[[ResponseType], Any],
) -> GenericRpcHandler:

async def wrapped(
peer: str,
input: Channel[Any],
output: Channel[Any],
) -> None:
context = None
try:
context = GrpcContext(peer)
request = request_deserializer(await input.get())
Expand All @@ -191,13 +192,17 @@ async def wrapped(
get_response_or_error_payload(response, response_serializer)
)
except grpc.RpcError:
code = grpc.StatusCode(context._abort_code).name if context else "UNKNOWN"
message = (
f"{method.__name__} threw an exception: "
f"{context._abort_details if context else 'Unknown error details'}"
)
await output.put(
{
"ok": False,
"payload": {
"code": grpc.StatusCode(context._abort_code).name,
"message": f"{method.__name__} threw an exception: "
f"{context._abort_details}",
"code": code,
"message": message,
},
}
)
Expand Down Expand Up @@ -230,6 +235,7 @@ async def wrapped(
input: Channel[Any],
output: Channel[Any],
) -> None:
context = None
try:
context = GrpcContext(peer)
request = request_deserializer(await input.get())
Expand All @@ -238,14 +244,15 @@ async def wrapped(
get_response_or_error_payload(response, response_serializer)
)
except grpc.RpcError:
code = grpc.StatusCode(context._abort_code).name if context else "UNKNOWN"
message = (
f"{method.__name__} threw an exception: "
f"{context._abort_details if context else 'Unknown error details'}"
)
await output.put(
{
"ok": False,
"payload": {
"code": grpc.StatusCode(context._abort_code).name,
"message": f"{method.__name__} threw an exception: "
f"{context._abort_details}",
},
"payload": {"code": code, "message": message},
}
)
except Exception as e:
Expand Down Expand Up @@ -348,6 +355,7 @@ async def wrapped(
output: Channel[Any],
) -> None:
task_manager = BackgroundTaskManager()
context = None
try:
context = GrpcContext(peer)
request: Channel[RequestType] = Channel(MAX_MESSAGE_BUFFER_SIZE)
Expand Down Expand Up @@ -375,13 +383,17 @@ async def _convert_outputs() -> None:
await asyncio.wait((convert_inputs_task, convert_outputs_task))
except grpc.RpcError:
logging.exception("RPC exception in stream")
code = grpc.StatusCode(context._abort_code).name if context else "UNKNOWN"
message = (
f"{method.__name__} threw an exception: "
f"{context._abort_details if context else 'Unknown error details'}"
)
await output.put(
{
"ok": False,
"payload": {
"code": grpc.StatusCode(context._abort_code).name,
"message": f"{method.__name__} threw an exception: "
f"{context._abort_details}",
"code": code,
"message": message,
},
}
)
Expand Down
11 changes: 5 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import logging
from typing import Any, AsyncGenerator
from collections.abc import AsyncIterator
from typing import Any, AsyncGenerator, Generator
from unittest.mock import MagicMock, patch

import nanoid # type: ignore
Expand Down Expand Up @@ -74,17 +75,15 @@ async def subscription_handler(
yield f"Subscription message {i} for {request}"


async def upload_handler(
request: AsyncGenerator[str, None], context: GrpcContext
) -> str:
async def upload_handler(request: AsyncIterator[str], context: Any) -> str:
uploaded_data = []
async for data in request:
uploaded_data.append(data)
return f"Uploaded: {', '.join(uploaded_data)}"


async def stream_handler(
request: AsyncGenerator[str, None], context: GrpcContext
request: AsyncIterator[str], context: GrpcContext
) -> AsyncGenerator[str, None]:
async for data in request:
yield f"Stream response for {data}"
Expand Down Expand Up @@ -130,7 +129,7 @@ def server(transport_options: TransportOptions) -> Server:


@pytest.fixture
def no_logging_error() -> MagicMock:
def no_logging_error() -> Generator[MagicMock, None, None]:
with patch("logging.error") as mock_error:
yield mock_error

Expand Down

0 comments on commit 3d769eb

Please sign in to comment.