Skip to content

Commit

Permalink
Add support for local scoped injection contexts.
Browse files Browse the repository at this point in the history
  • Loading branch information
FasterSpeeding committed Nov 24, 2024
1 parent 42974aa commit 86bbe2d
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 67 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- Support for local scoped injection contexts.

### Removed
- Support for Python 3.9 and 3.10.

Expand Down
8 changes: 4 additions & 4 deletions alluka/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@
_TypeT = type[_T]


class _NoDefaultEnum(enum.Enum):
class _NoValueEnum(enum.Enum):
VALUE = object()


_NO_VALUE: typing.Literal[_NoDefaultEnum.VALUE] = _NoDefaultEnum.VALUE
_NoValueOr = _T | typing.Literal[_NoDefaultEnum.VALUE]
_NO_VALUE: typing.Literal[_NoValueEnum.VALUE] = _NoValueEnum.VALUE
_NoValue = typing.Literal[_NoValueEnum.VALUE]


@typing.overload
Expand Down Expand Up @@ -283,7 +283,7 @@ def get_type_dependency(self, type_: type[_T], /) -> _T: ...
@typing.overload
def get_type_dependency(self, type_: type[_T], /, *, default: _DefaultT) -> _T | _DefaultT: ...

def get_type_dependency(self, type_: type[_T], /, *, default: _NoValueOr[_DefaultT] = _NO_VALUE) -> _T | _DefaultT:
def get_type_dependency(self, type_: type[_T], /, *, default: _DefaultT | _NoValue = _NO_VALUE) -> _T | _DefaultT:
# <<inherited docstring from alluka.abc.Client>>.
result = self._type_dependencies.get(type_, default)

Expand Down
16 changes: 8 additions & 8 deletions alluka/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@
_DefaultT = typing.TypeVar("_DefaultT")


class _NoDefaultEnum(enum.Enum):
class _NoValueEnum(enum.Enum):
VALUE = object()


_NO_VALUE: typing.Literal[_NoDefaultEnum.VALUE] = _NoDefaultEnum.VALUE
_NoValueOr = _T | typing.Literal[_NoDefaultEnum.VALUE]
_NO_VALUE: typing.Literal[_NoValueEnum.VALUE] = _NoValueEnum.VALUE
_NoValue = typing.Literal[_NoValueEnum.VALUE]


class Context(alluka.Context):
Expand Down Expand Up @@ -87,7 +87,7 @@ def get_type_dependency(self, type_: type[_T], /) -> _T: ...
@typing.overload
def get_type_dependency(self, type_: type[_T], /, *, default: _DefaultT) -> _T | _DefaultT: ...

def get_type_dependency(self, type_: type[_T], /, *, default: _NoValueOr[_DefaultT] = _NO_VALUE) -> _T | _DefaultT:
def get_type_dependency(self, type_: type[_T], /, *, default: _DefaultT | _NoValue = _NO_VALUE) -> _T | _DefaultT:
# <<inherited docstring from alluka.abc.Context>>.
if type_ is alluka.Context:
return self # type: ignore
Expand Down Expand Up @@ -127,7 +127,7 @@ def get_cached_result(self, callback: alluka.CallbackSig[_T], /) -> _T: ...
def get_cached_result(self, callback: alluka.CallbackSig[_T], /, *, default: _DefaultT) -> _T | _DefaultT: ...

def get_cached_result(
self, callback: alluka.CallbackSig[_T], /, *, default: _NoValueOr[_DefaultT] = _NO_VALUE
self, callback: alluka.CallbackSig[_T], /, *, default: _DefaultT | _NoValue = _NO_VALUE
) -> _T | _DefaultT:
# <<inherited docstring from alluka.abc.Context>>.
result = self._result_cache.get(callback, default)
Expand Down Expand Up @@ -206,7 +206,7 @@ def get_cached_result(self, callback: alluka.CallbackSig[_T], /) -> _T: ...
def get_cached_result(self, callback: alluka.CallbackSig[_T], /, *, default: _DefaultT) -> _T | _DefaultT: ...

def get_cached_result(
self, callback: alluka.CallbackSig[_T], /, *, default: _NoValueOr[_DefaultT] = _NO_VALUE
self, callback: alluka.CallbackSig[_T], /, *, default: _DefaultT | _NoValue = _NO_VALUE
) -> _T | _DefaultT:
value = self._context.get_cached_result(callback, default=default)

Expand All @@ -224,7 +224,7 @@ def get_type_dependency(self, type_: type[_T], /) -> _T: ...
@typing.overload
def get_type_dependency(self, type_: type[_T], /, *, default: _DefaultT) -> _T | _DefaultT: ...

def get_type_dependency(self, type_: type[_T], /, *, default: _NoValueOr[_DefaultT] = _NO_VALUE) -> _T | _DefaultT:
def get_type_dependency(self, type_: type[_T], /, *, default: _DefaultT | _NoValue = _NO_VALUE) -> _T | _DefaultT:
# <<inherited docstring from alluka.abc.Context>>.
value = self._overrides.get(type_, default)
if value is default:
Expand Down Expand Up @@ -277,7 +277,7 @@ def get_type_dependency(self, type_: type[_T], /) -> _T: ...
@typing.overload
def get_type_dependency(self, type_: type[_T], /, *, default: _DefaultT) -> _T | _DefaultT: ...

def get_type_dependency(self, type_: type[_T], /, *, default: _NoValueOr[_DefaultT] = _NO_VALUE) -> _T | _DefaultT:
def get_type_dependency(self, type_: type[_T], /, *, default: _DefaultT | _NoValue = _NO_VALUE) -> _T | _DefaultT:
# <<inherited docstring from alluka.abc.Context>>.
value = self._special_case_types.get(type_, default)
if value is default:
Expand Down
4 changes: 2 additions & 2 deletions alluka/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class _UndefinedEnum(enum.Enum):

UNDEFINED = _UndefinedEnum.UNDEFINED
"""Singleton used internally to indicate that a value is undefined."""
UndefinedOr = _T | typing.Literal[_UndefinedEnum.UNDEFINED]
Undefined = typing.Literal[_UndefinedEnum.UNDEFINED]
"""Union for a value which may be undefined."""


Expand Down Expand Up @@ -138,7 +138,7 @@ def __init__(
types: collections.Sequence[type[typing.Any]],
/,
*,
default: UndefinedOr[typing.Any] = UNDEFINED,
default: typing.Any | Undefined = UNDEFINED,
) -> None:
"""Initialize the type descriptor.
Expand Down
6 changes: 3 additions & 3 deletions alluka/_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def parameters(self) -> collections.Mapping[str, inspect.Parameter]:
def accept(self, visitor: ParameterVisitor, /) -> dict[str, _types.InjectedTuple]:
return visitor.visit_callback(self)

def resolve_annotation(self, name: str, /) -> _types.UndefinedOr[typing.Any]:
def resolve_annotation(self, name: str, /) -> typing.Any | _types.Undefined:
if self._signature is None:
return _types.UNDEFINED

Expand Down Expand Up @@ -155,7 +155,7 @@ class ParameterVisitor:
_NODES: list[collections.Callable[[Callback, str], Node]] = [Default, Annotation]

def _parse_type(
self, type_: typing.Any, *, default: _types.UndefinedOr[typing.Any] = _types.UNDEFINED
self, type_: typing.Any, *, default: typing.Any | _types.Undefined = _types.UNDEFINED
) -> _types.InjectedTuple:
if typing.get_origin(type_) not in _UnionTypes:
return (_types.InjectedTypes.TYPE, _types.InjectedType(type_, [type_], default=default))
Expand All @@ -171,7 +171,7 @@ def _parse_type(
return (_types.InjectedTypes.TYPE, _types.InjectedType(type_, sub_types, default=default))

def _annotation_to_type(
self, value: typing.Any, /, default: _types.UndefinedOr[typing.Any] = _types.UNDEFINED
self, value: typing.Any, /, default: typing.Any | _types.Undefined = _types.UNDEFINED
) -> _types.InjectedTuple:
if typing.get_origin(value) is typing.Annotated:
args = typing.get_args(value)
Expand Down
8 changes: 4 additions & 4 deletions alluka/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@
_SyncCallbackT = typing.TypeVar("_SyncCallbackT", bound=collections.Callable[..., typing.Any])


class _NoDefaultEnum(enum.Enum):
class _NoValueEnum(enum.Enum):
VALUE = object()


_NO_VALUE: typing.Literal[_NoDefaultEnum.VALUE] = _NoDefaultEnum.VALUE
_NoValueOr = _T | typing.Literal[_NoDefaultEnum.VALUE]
_NO_VALUE: typing.Literal[_NoValueEnum.VALUE] = _NoValueEnum.VALUE
_NoValue = typing.Literal[_NoValueEnum.VALUE]


CallbackSig = collections.Callable[..., _CoroT[_T] | _T]
Expand Down Expand Up @@ -553,7 +553,7 @@ def get_cached_result(self, callback: CallbackSig[_T], /) -> _T: ...
def get_cached_result(self, callback: CallbackSig[_T], /, *, default: _DefaultT) -> _T | _DefaultT: ...

def get_cached_result(
self, callback: CallbackSig[_T], /, *, default: _NoValueOr[_DefaultT] = _NO_VALUE
self, callback: CallbackSig[_T], /, *, default: _DefaultT | _NoValue = _NO_VALUE
) -> _T | _DefaultT:
"""Get the cached result of a callback.
Expand Down
Loading

0 comments on commit 86bbe2d

Please sign in to comment.