Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python 3.9 support #140

Merged
merged 9 commits into from
Mar 27, 2024
Merged
Prev Previous commit
Next Next commit
Avoid type|type (PEP 604) which wasn't added until Python 3.10
chriso committed Mar 27, 2024
commit 9c5bda456b070bf66ee941b027f5b64f9947b580
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -17,7 +17,8 @@ dependencies = [
"tblib >= 3.0.0",
"docopt >= 0.6.2",
"types-docopt >= 0.6.11.4",
"httpx >= 0.27.0"
"httpx >= 0.27.0",
"typing_extensions >= 4.10"
]

[project.optional-dependencies]
26 changes: 13 additions & 13 deletions src/dispatch/experimental/durable/frame.pyi
Original file line number Diff line number Diff line change
@@ -1,55 +1,55 @@
from types import FrameType
from typing import Any, AsyncGenerator, Coroutine, Generator, Tuple
from typing import Any, AsyncGenerator, Coroutine, Generator, Tuple, Union

def get_frame_ip(frame: FrameType | Coroutine | Generator | AsyncGenerator) -> int:
def get_frame_ip(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator]) -> int:
"""Get instruction pointer of a generator or coroutine."""

def set_frame_ip(frame: FrameType | Coroutine | Generator | AsyncGenerator, ip: int):
def set_frame_ip(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], ip: int):
"""Set instruction pointer of a generator or coroutine."""

def get_frame_sp(frame: FrameType | Coroutine | Generator | AsyncGenerator) -> int:
def get_frame_sp(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator]) -> int:
"""Get stack pointer of a generator or coroutine."""

def set_frame_sp(frame: FrameType | Coroutine | Generator | AsyncGenerator, sp: int):
def set_frame_sp(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], sp: int):
"""Set stack pointer of a generator or coroutine."""

def get_frame_bp(frame: FrameType | Coroutine | Generator | AsyncGenerator) -> int:
def get_frame_bp(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator]) -> int:
"""Get block pointer of a generator or coroutine."""

def set_frame_bp(frame: FrameType | Coroutine | Generator | AsyncGenerator, bp: int):
def set_frame_bp(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], bp: int):
"""Set block pointer of a generator or coroutine."""

def get_frame_stack_at(
frame: FrameType | Coroutine | Generator | AsyncGenerator, index: int
frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], index: int
) -> Tuple[bool, Any]:
"""Get an object from a generator or coroutine's stack, as an (is_null, obj) tuple."""

def set_frame_stack_at(
frame: FrameType | Coroutine | Generator | AsyncGenerator,
frame: Union[FrameType, Coroutine, Generator, AsyncGenerator],
index: int,
unset: bool,
value: Any,
):
"""Set or unset an object on the stack of a generator or coroutine."""

def get_frame_block_at(
frame: FrameType | Coroutine | Generator | AsyncGenerator, index: int
frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], index: int
) -> Tuple[int, int, int]:
"""Get a block from a generator or coroutine."""

def set_frame_block_at(
frame: FrameType | Coroutine | Generator | AsyncGenerator,
frame: Union[FrameType, Coroutine, Generator, AsyncGenerator],
index: int,
value: Tuple[int, int, int],
):
"""Restore a block of a generator or coroutine."""

def get_frame_state(
frame: FrameType | Coroutine | Generator | AsyncGenerator,
frame: Union[FrameType, Coroutine, Generator, AsyncGenerator],
) -> int:
"""Get frame state of a generator or coroutine."""

def set_frame_state(
frame: FrameType | Coroutine | Generator | AsyncGenerator, state: int
frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], state: int
):
"""Set frame state of a generator or coroutine."""
18 changes: 9 additions & 9 deletions src/dispatch/experimental/durable/function.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
MethodType,
TracebackType,
)
from typing import Any, Callable, Coroutine, Generator, TypeVar, Union, cast
from typing import Any, Callable, Coroutine, Generator, TypeVar, Union, cast, Optional

from . import frame as ext
from .registry import RegisteredFunction, lookup_function, register_function
@@ -75,15 +75,15 @@ class Serializable:
"__qualname__",
)

g: GeneratorType | CoroutineType
g: Union[GeneratorType, CoroutineType]
registered_fn: RegisteredFunction
wrapped_coroutine: Union["DurableCoroutine", None]
args: tuple[Any, ...]
kwargs: dict[str, Any]

def __init__(
self,
g: GeneratorType | CoroutineType,
g: Union[GeneratorType, CoroutineType],
registered_fn: RegisteredFunction,
wrapped_coroutine: Union["DurableCoroutine", None],
*args: Any,
@@ -243,7 +243,7 @@ def __await__(self) -> Generator[Any, None, _ReturnT]:
def send(self, send: _SendT) -> _YieldT:
return self.coroutine.send(send)

def throw(self, typ, val=None, tb: TracebackType | None = None) -> _YieldT:
def throw(self, typ, val=None, tb: Optional[TracebackType] = None) -> _YieldT:
return self.coroutine.throw(typ, val, tb)

def close(self) -> None:
@@ -270,11 +270,11 @@ def cr_frame(self) -> FrameType:
return self.coroutine.cr_frame

@property
def cr_await(self) -> Any | None:
def cr_await(self) -> Any:
return self.coroutine.cr_await

@property
def cr_origin(self) -> tuple[tuple[str, int, str], ...] | None:
def cr_origin(self) -> Optional[tuple[tuple[str, int, str], ...]]:
return self.coroutine.cr_origin

def __repr__(self) -> str:
@@ -291,7 +291,7 @@ def __init__(
self,
generator: GeneratorType,
registered_fn: RegisteredFunction,
coroutine: DurableCoroutine | None,
coroutine: Optional[DurableCoroutine],
*args: Any,
**kwargs: Any,
):
@@ -309,7 +309,7 @@ def __next__(self) -> _YieldT:
def send(self, send: _SendT) -> _YieldT:
return self.generator.send(send)

def throw(self, typ, val=None, tb: TracebackType | None = None) -> _YieldT:
def throw(self, typ, val=None, tb: Optional[TracebackType] = None) -> _YieldT:
return self.generator.throw(typ, val, tb)

def close(self) -> None:
@@ -336,7 +336,7 @@ def gi_frame(self) -> FrameType:
return self.generator.gi_frame

@property
def gi_yieldfrom(self) -> GeneratorType | None:
def gi_yieldfrom(self) -> Optional[GeneratorType]:
return self.generator.gi_yieldfrom

def __repr__(self) -> str:
15 changes: 8 additions & 7 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@ def read_root():
import os
from datetime import timedelta
from urllib.parse import urlparse
from typing import Optional, Union

import fastapi
import fastapi.responses
@@ -51,10 +52,10 @@ class Dispatch(Registry):
def __init__(
self,
app: fastapi.FastAPI,
endpoint: str | None = None,
verification_key: Ed25519PublicKey | str | bytes | None = None,
api_key: str | None = None,
api_url: str | None = None,
endpoint: Optional[str] = None,
verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None,
api_key: Optional[str] = None,
api_url: Optional[str] = None,
):
"""Initialize a Dispatch endpoint, and integrate it into a FastAPI app.
@@ -122,8 +123,8 @@ def __init__(


def parse_verification_key(
verification_key: Ed25519PublicKey | str | bytes | None,
) -> Ed25519PublicKey | None:
verification_key: Optional[Union[Ed25519PublicKey, str, bytes]],
) -> Optional[Ed25519PublicKey]:
if isinstance(verification_key, Ed25519PublicKey):
return verification_key

@@ -169,7 +170,7 @@ def __init__(self, status, code, message):
self.message = message


def _new_app(function_registry: Dispatch, verification_key: Ed25519PublicKey | None):
def _new_app(function_registry: Dispatch, verification_key: Optional[Ed25519PublicKey]):
app = fastapi.FastAPI()

@app.exception_handler(_ConnectError)
12 changes: 6 additions & 6 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
@@ -12,11 +12,11 @@
Dict,
Generic,
Iterable,
ParamSpec,
TypeAlias,
TypeVar,
Optional,
overload,
)
from typing_extensions import ParamSpec, TypeAlias
from urllib.parse import urlparse

import grpc
@@ -73,7 +73,7 @@ def _primitive_dispatch(self, input: Any = None) -> DispatchID:
return dispatch_id

def _build_primitive_call(
self, input: Any, correlation_id: int | None = None
self, input: Any, correlation_id: Optional[int] = None
) -> Call:
return Call(
correlation_id=correlation_id,
@@ -137,7 +137,7 @@ def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
return self._primitive_dispatch(Arguments(args, kwargs))

def build_call(
self, *args: P.args, correlation_id: int | None = None, **kwargs: P.kwargs
self, *args: P.args, correlation_id: Optional[int] = None, **kwargs: P.kwargs
) -> Call:
"""Create a Call for this function with the provided input. Useful to
generate calls when using the Client.
@@ -162,7 +162,7 @@ class Registry:
__slots__ = ("functions", "endpoint", "client")

def __init__(
self, endpoint: str, api_key: str | None = None, api_url: str | None = None
self, endpoint: str, api_key: Optional[str] = None, api_url: Optional[str] = None
):
"""Initialize a function registry.
@@ -261,7 +261,7 @@ class Client:

__slots__ = ("api_url", "api_key", "_stub", "api_key_from")

def __init__(self, api_key: None | str = None, api_url: None | str = None):
def __init__(self, api_key: Optional[str] = None, api_url: Optional[str] = None):
"""Create a new Dispatch client.
Args:
2 changes: 1 addition & 1 deletion src/dispatch/id.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypeAlias
from typing_extensions import TypeAlias

DispatchID: TypeAlias = str
"""Unique identifier in Dispatch.
42 changes: 21 additions & 21 deletions src/dispatch/proto.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@
from dataclasses import dataclass
from traceback import format_exception
from types import TracebackType
from typing import Any
from typing import Any, Optional

import google.protobuf.any_pb2
import google.protobuf.message
@@ -97,7 +97,7 @@ def call_results(self) -> list[CallResult]:
return self._call_results

@property
def poll_error(self) -> Error | None:
def poll_error(self) -> Optional[Error]:
self._assert_resume()
return self._poll_error

@@ -125,7 +125,7 @@ def from_poll_results(
function: str,
coroutine_state: Any,
call_results: list[CallResult],
error: Error | None = None,
error: Optional[Error] = None,
):
return Input(
req=function_pb.RunRequest(
@@ -163,7 +163,7 @@ def __init__(self, proto: function_pb.RunResponse):
self._message = proto

@classmethod
def value(cls, value: Any, status: Status | None = None) -> Output:
def value(cls, value: Any, status: Optional[Status] = None) -> Output:
"""Terminally exit the function with the provided return value."""
if status is None:
status = status_for_output(value)
@@ -183,8 +183,8 @@ def tail_call(cls, tail_call: Call) -> Output:
@classmethod
def exit(
cls,
result: CallResult | None = None,
tail_call: Call | None = None,
result: Optional[CallResult] = None,
tail_call: Optional[Call] = None,
status: Status = Status.OK,
) -> Output:
"""Terminally exit the function."""
@@ -201,10 +201,10 @@ def exit(
def poll(
cls,
state: Any,
calls: None | list[Call] = None,
calls: Optional[list[Call]] = None,
min_results: int = 1,
max_results: int = 10,
max_wait_seconds: int | None = None,
max_wait_seconds: Optional[int] = None,
) -> Output:
"""Suspend the function with a set of Calls, instructing the
orchestrator to resume the function with the provided state when
@@ -249,9 +249,9 @@ class Call:
"""

function: str
input: Any | None = None
endpoint: str | None = None
correlation_id: int | None = None
input: Optional[Any] = None
endpoint: Optional[str] = None
correlation_id: Optional[int] = None

def _as_proto(self) -> call_pb.Call:
input_bytes = _pb_any_pickle(self.input)
@@ -267,9 +267,9 @@ def _as_proto(self) -> call_pb.Call:
class CallResult:
"""Result of a Call."""

correlation_id: int | None = None
output: Any | None = None
error: Error | None = None
correlation_id: Optional[int] = None
output: Optional[Any] = None
error: Optional[Error] = None

def _as_proto(self) -> call_pb.CallResult:
output_any = None
@@ -297,11 +297,11 @@ def _from_proto(cls, proto: call_pb.CallResult) -> CallResult:
)

@classmethod
def from_value(cls, output: Any, correlation_id: int | None = None) -> CallResult:
def from_value(cls, output: Any, correlation_id: Optional[int] = None) -> CallResult:
return CallResult(correlation_id=correlation_id, output=output)

@classmethod
def from_error(cls, error: Error, correlation_id: int | None = None) -> CallResult:
def from_error(cls, error: Error, correlation_id: Optional[int] = None) -> CallResult:
return CallResult(correlation_id=correlation_id, error=error)


@@ -316,16 +316,16 @@ class Error:
status: Status
type: str
message: str
value: Exception | None = None
traceback: bytes | None = None
value: Optional[Exception] = None
traceback: Optional[bytes] = None

def __init__(
self,
status: Status,
type: str,
message: str,
value: Exception | None = None,
traceback: bytes | None = None,
value: Optional[Exception] = None,
traceback: Optional[bytes] = None,
):
"""Create a new Error.
@@ -355,7 +355,7 @@ def __init__(
self.traceback = "".join(format_exception(value)).encode("utf-8")

@classmethod
def from_exception(cls, ex: Exception, status: Status | None = None) -> Error:
def from_exception(cls, ex: Exception, status: Optional[Status] = None) -> Error:
"""Create an Error from a Python exception, using its class qualified
named as type.
55 changes: 28 additions & 27 deletions src/dispatch/scheduler.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,8 @@
import pickle
import sys
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Protocol, TypeAlias
from typing import Any, Awaitable, Callable, Protocol, Optional, Union
from typing_extensions import TypeAlias

from dispatch.coroutine import AllDirective, AnyDirective, AnyException, RaceDirective
from dispatch.error import IncompatibleStateError
@@ -22,27 +23,27 @@ class CoroutineResult:
"""The result from running a coroutine to completion."""

coroutine_id: CoroutineID
value: Any | None = None
error: Exception | None = None
value: Optional[Any] = None
error: Optional[Exception] = None


@dataclass
class CallResult:
"""The result of an asynchronous function call."""

call_id: CallID
value: Any | None = None
error: Exception | None = None
value: Optional[Any] = None
error: Optional[Exception] = None


class Future(Protocol):
def add_result(self, result: CallResult | CoroutineResult): ...
def add_result(self, result: Union[CallResult, CoroutineResult]): ...

def add_error(self, error: Exception): ...

def ready(self) -> bool: ...

def error(self) -> Exception | None: ...
def error(self) -> Optional[Exception]: ...

def value(self) -> Any: ...

@@ -51,10 +52,10 @@ def value(self) -> Any: ...
class CallFuture:
"""A future result of a dispatch.coroutine.call() operation."""

result: CallResult | None = None
first_error: Exception | None = None
result: Optional[CallResult] = None
first_error: Optional[Exception] = None

def add_result(self, result: CallResult | CoroutineResult):
def add_result(self, result: Union[CallResult, CoroutineResult]):
assert isinstance(result, CallResult)
if self.result is None:
self.result = result
@@ -68,7 +69,7 @@ def add_error(self, error: Exception):
def ready(self) -> bool:
return self.first_error is not None or self.result is not None

def error(self) -> Exception | None:
def error(self) -> Optional[Exception]:
assert self.ready()
return self.first_error

@@ -85,9 +86,9 @@ class AllFuture:
order: list[CoroutineID] = field(default_factory=list)
waiting: set[CoroutineID] = field(default_factory=set)
results: dict[CoroutineID, CoroutineResult] = field(default_factory=dict)
first_error: Exception | None = None
first_error: Optional[Exception] = None

def add_result(self, result: CallResult | CoroutineResult):
def add_result(self, result: Union[CallResult, CoroutineResult]):
assert isinstance(result, CoroutineResult)

try:
@@ -109,7 +110,7 @@ def add_error(self, error: Exception):
def ready(self) -> bool:
return self.first_error is not None or len(self.waiting) == 0

def error(self) -> Exception | None:
def error(self) -> Optional[Exception]:
assert self.ready()
return self.first_error

@@ -126,11 +127,11 @@ class AnyFuture:

order: list[CoroutineID] = field(default_factory=list)
waiting: set[CoroutineID] = field(default_factory=set)
first_result: CoroutineResult | None = None
first_result: Optional[CoroutineResult] = None
errors: dict[CoroutineID, Exception] = field(default_factory=dict)
generic_error: Exception | None = None
generic_error: Optional[Exception] = None

def add_result(self, result: CallResult | CoroutineResult):
def add_result(self, result: Union[CallResult, CoroutineResult]):
assert isinstance(result, CoroutineResult)

try:
@@ -156,7 +157,7 @@ def ready(self) -> bool:
or len(self.waiting) == 0
)

def error(self) -> Exception | None:
def error(self) -> Optional[Exception]:
assert self.ready()
if self.generic_error is not None:
return self.generic_error
@@ -182,10 +183,10 @@ class RaceFuture:
"""A future result of a dispatch.coroutine.race() operation."""

waiting: set[CoroutineID] = field(default_factory=set)
first_result: CoroutineResult | None = None
first_error: Exception | None = None
first_result: Optional[CoroutineResult] = None
first_error: Optional[Exception] = None

def add_result(self, result: CallResult | CoroutineResult):
def add_result(self, result: Union[CallResult, CoroutineResult]):
assert isinstance(result, CoroutineResult)

if result.error is not None:
@@ -208,7 +209,7 @@ def ready(self) -> bool:
or len(self.waiting) == 0
)

def error(self) -> Exception | None:
def error(self) -> Optional[Exception]:
assert self.ready()
return self.first_error

@@ -222,9 +223,9 @@ class Coroutine:
"""An in-flight coroutine."""

id: CoroutineID
parent_id: CoroutineID | None
coroutine: DurableCoroutine | DurableGenerator
result: Future | None = None
parent_id: Optional[CoroutineID]
coroutine: Union[DurableCoroutine, DurableGenerator]
result: Optional[Future] = None

def run(self) -> Any:
if self.result is None:
@@ -278,7 +279,7 @@ def __init__(
version: str = sys.version,
poll_min_results: int = 1,
poll_max_results: int = 10,
poll_max_wait_seconds: int | None = None,
poll_max_wait_seconds: Optional[int] = None,
):
"""Initialize the scheduler.
@@ -422,7 +423,7 @@ def _run(self, input: Input) -> Output:
assert coroutine.id not in state.suspended

coroutine_yield = None
coroutine_result: CoroutineResult | None = None
coroutine_result: Optional[CoroutineResult] = None
try:
coroutine_yield = coroutine.run()
except StopIteration as e:
5 changes: 3 additions & 2 deletions src/dispatch/signature/digest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import hashlib
import hmac
from typing import Union

import http_sfv
from http_message_signatures import InvalidSignature


def generate_content_digest(body: str | bytes) -> str:
def generate_content_digest(body: Union[str, bytes]) -> str:
"""Returns a SHA-512 Content-Digest header, according to
https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-digest-headers-13
"""
@@ -16,7 +17,7 @@ def generate_content_digest(body: str | bytes) -> str:
return str(http_sfv.Dictionary({"sha-512": digest}))


def verify_content_digest(digest_header: str | bytes, body: str | bytes):
def verify_content_digest(digest_header: Union[str, bytes], body: Union[str, bytes]):
"""Verify a SHA-256 or SHA-512 Content-Digest header matches a
request body."""
if isinstance(body, str):
9 changes: 5 additions & 4 deletions src/dispatch/signature/key.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Union, Optional

from cryptography.hazmat.primitives.asymmetric.ed25519 import (
Ed25519PrivateKey,
@@ -11,7 +12,7 @@
from http_message_signatures import HTTPSignatureKeyResolver


def public_key_from_pem(pem: str | bytes) -> Ed25519PublicKey:
def public_key_from_pem(pem: Union[str, bytes]) -> Ed25519PublicKey:
"""Returns an Ed25519 public key given a PEM representation."""
if isinstance(pem, str):
pem = pem.encode()
@@ -28,7 +29,7 @@ def public_key_from_bytes(key: bytes) -> Ed25519PublicKey:


def private_key_from_pem(
pem: str | bytes, password: bytes | None = None
pem: Union[str, bytes], password: Optional[bytes] = None
) -> Ed25519PrivateKey:
"""Returns an Ed25519 private key given a PEM representation
and optional password."""
@@ -57,8 +58,8 @@ class KeyResolver(HTTPSignatureKeyResolver):
"""

key_id: str
public_key: Ed25519PublicKey | None = None
private_key: Ed25519PrivateKey | None = None
public_key: Optional[Ed25519PublicKey] = None
private_key: Optional[Ed25519PrivateKey] = None

def resolve_public_key(self, key_id: str):
if key_id != self.key_id or self.public_key is None:
3 changes: 2 additions & 1 deletion src/dispatch/signature/request.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Union

from http_message_signatures.structures import CaseInsensitiveDict

@@ -10,4 +11,4 @@ class Request:
method: str
url: str
headers: CaseInsensitiveDict
body: str | bytes
body: Union[str, bytes]
11 changes: 6 additions & 5 deletions src/dispatch/test/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Optional

import fastapi
import grpc
@@ -25,7 +26,7 @@ class EndpointClient:
"""

def __init__(
self, http_client: httpx.Client, signing_key: Ed25519PrivateKey | None = None
self, http_client: httpx.Client, signing_key: Optional[Ed25519PrivateKey] = None
):
"""Initialize the client.
@@ -48,14 +49,14 @@ def run(self, request: function_pb.RunRequest) -> function_pb.RunResponse:
return self._stub.Run(request)

@classmethod
def from_url(cls, url: str, signing_key: Ed25519PrivateKey | None = None):
def from_url(cls, url: str, signing_key: Optional[Ed25519PrivateKey] = None):
"""Returns an EndpointClient for a Dispatch endpoint URL."""
http_client = httpx.Client(base_url=url)
return EndpointClient(http_client, signing_key)

@classmethod
def from_app(
cls, app: fastapi.FastAPI, signing_key: Ed25519PrivateKey | None = None
cls, app: fastapi.FastAPI, signing_key: Optional[Ed25519PrivateKey] = None
):
"""Returns an EndpointClient for a Dispatch endpoint bound to a
FastAPI app instance."""
@@ -65,7 +66,7 @@ def from_app(

class _HttpxGrpcChannel(grpc.Channel):
def __init__(
self, http_client: httpx.Client, signing_key: Ed25519PrivateKey | None = None
self, http_client: httpx.Client, signing_key: Optional[Ed25519PrivateKey] = None
):
self.http_client = http_client
self.signing_key = signing_key
@@ -113,7 +114,7 @@ def __init__(
method,
request_serializer,
response_deserializer,
signing_key: Ed25519PrivateKey | None = None,
signing_key: Optional[Ed25519PrivateKey] = None,
):
self.client = client
self.method = method
9 changes: 5 additions & 4 deletions src/dispatch/test/service.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,8 @@
import time
from collections import OrderedDict
from dataclasses import dataclass
from typing import TypeAlias
from typing import Optional
from typing_extensions import TypeAlias

import grpc
import httpx
@@ -52,8 +53,8 @@ class DispatchService(dispatch_grpc.DispatchServiceServicer):
def __init__(
self,
endpoint_client: EndpointClient,
api_key: str | None = None,
retry_on_status: set[Status] | None = None,
api_key: Optional[str] = None,
retry_on_status: Optional[set[Status]] = None,
collect_roundtrips: bool = False,
):
"""Initialize the Dispatch service.
@@ -90,7 +91,7 @@ def __init__(
if collect_roundtrips:
self.roundtrips = OrderedDict()

self._thread: threading.Thread | None = None
self._thread: threading.Optional[Thread] = None
self._stop_event = threading.Event()
self._work_signal = threading.Condition()

6 changes: 3 additions & 3 deletions tests/dispatch/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from typing import Any, Callable
from typing import Any, Callable, Optional

from dispatch.coroutine import AnyException, any, call, gather, race
from dispatch.experimental.durable import durable
@@ -414,7 +414,7 @@ def resume(
main: Callable,
prev_output: Output,
call_results: list[CallResult],
poll_error: Exception | None = None,
poll_error: Optional[Exception] = None,
):
poll = self.assert_poll(prev_output)
input = Input.from_poll_results(
@@ -444,7 +444,7 @@ def assert_exit_result_value(self, output: Output, expect: Any):
self.assertEqual(expect, any_unpickle(result.output))

def assert_exit_result_error(
self, output: Output, expect: type[Exception], message: str | None = None
self, output: Output, expect: type[Exception], message: Optional[str] = None
):
result = self.assert_exit_result(output)
self.assertFalse(result.HasField("output"))