Skip to content

Commit

Permalink
Fix: middleware type hints (#2033)
Browse files Browse the repository at this point in the history
* feat: Mypy tests for broker middleware.

* refactor: Move BaseMiddleware to _internal module.

* fix: Added missed second type variable to BrokerMiddleware protocol.

* fix: Broker middleware mypy issues.

* fix: Added missed type vars to base prometheus middleware classes.

* fix: Redis prometheus middleware type hints.

* fix: Rabbit prometheus middleware type hints.

* fix: Nats prometheus middleware type-hints.

* fix: Kafka broker public middleware contract.

* fix: Confluent broker middleware public contract.

* fix: Confluent prometheus middleware type hints.

* fix: Kafka pometheus middleware type hints.

* fix: Linter issues.
  • Loading branch information
DABND19 authored Jan 14, 2025
1 parent 21976c2 commit 57c0633
Show file tree
Hide file tree
Showing 22 changed files with 143 additions and 65 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from collections.abc import Awaitable
from typing import TYPE_CHECKING, Any, Callable, Generic, Optional

from typing_extensions import (
Self,
TypeVar as TypeVar313,
)
from typing_extensions import Self

from faststream.response import PublishCommand
from faststream._internal.types import AnyMsg, PublishCommandType

if TYPE_CHECKING:
from types import TracebackType
Expand All @@ -16,19 +13,12 @@
from faststream.message import StreamMessage


PublishCommandType = TypeVar313(
"PublishCommandType",
bound=PublishCommand,
default=PublishCommand,
)


class BaseMiddleware(Generic[PublishCommandType]):
class BaseMiddleware(Generic[PublishCommandType, AnyMsg]):
"""A base middleware class."""

def __init__(
self,
msg: Optional[Any],
msg: Optional[AnyMsg],
/,
*,
context: "ContextRepo",
Expand Down Expand Up @@ -63,8 +53,8 @@ async def __aexit__(

async def on_consume(
self,
msg: "StreamMessage[Any]",
) -> "StreamMessage[Any]":
msg: "StreamMessage[AnyMsg]",
) -> "StreamMessage[AnyMsg]":
"""This option was deprecated and will be removed in 0.7.0. Please, use `consume_scope` instead."""
return msg

Expand All @@ -76,7 +66,7 @@ async def after_consume(self, err: Optional[Exception]) -> None:
async def consume_scope(
self,
call_next: "AsyncFuncAny",
msg: "StreamMessage[Any]",
msg: "StreamMessage[AnyMsg]",
) -> Any:
"""Asynchronously consumes a message and returns an asynchronous iterator of decoded messages."""
err: Optional[Exception] = None
Expand Down
20 changes: 16 additions & 4 deletions faststream/_internal/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Awaitable
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Expand All @@ -11,18 +12,29 @@
from typing_extensions import (
ParamSpec,
TypeAlias,
TypeVar as TypeVar313,
)

from faststream._internal.basic_types import AsyncFuncAny
from faststream._internal.context.repository import ContextRepo
from faststream.message import StreamMessage
from faststream.middlewares import BaseMiddleware
from faststream.response.response import PublishCommand

if TYPE_CHECKING:
from faststream._internal.middlewares import BaseMiddleware


AnyMsg = TypeVar313("AnyMsg", default=Any)
AnyMsg_contra = TypeVar313("AnyMsg_contra", default=Any, contravariant=True)
MsgType = TypeVar("MsgType")
Msg_contra = TypeVar("Msg_contra", contravariant=True)
StreamMsg = TypeVar("StreamMsg", bound=StreamMessage[Any])
ConnectionType = TypeVar("ConnectionType")
PublishCommandType = TypeVar313(
"PublishCommandType",
bound=PublishCommand,
default=Any,
)

SyncFilter: TypeAlias = Callable[[StreamMsg], bool]
AsyncFilter: TypeAlias = Callable[[StreamMsg], Awaitable[bool]]
Expand Down Expand Up @@ -66,16 +78,16 @@
]


class BrokerMiddleware(Protocol[Msg_contra]):
class BrokerMiddleware(Protocol[AnyMsg_contra, PublishCommandType]):
"""Middleware builder interface."""

def __call__(
self,
msg: Optional[Msg_contra],
msg: Optional[AnyMsg_contra],
/,
*,
context: ContextRepo,
) -> BaseMiddleware: ...
) -> "BaseMiddleware[PublishCommandType]": ...


SubscriberMiddleware: TypeAlias = Callable[
Expand Down
7 changes: 1 addition & 6 deletions faststream/confluent/broker/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,7 @@ def __init__(
Doc("Dependencies to apply to all broker subscribers."),
] = (),
middlewares: Annotated[
Sequence[
Union[
"BrokerMiddleware[Message]",
"BrokerMiddleware[tuple[Message, ...]]",
]
],
Sequence["BrokerMiddleware[Union[Message, tuple[Message, ...]]]"],
Doc("Middlewares to apply to all broker publishers/subscribers."),
] = (),
routers: Annotated[
Expand Down
13 changes: 10 additions & 3 deletions faststream/confluent/prometheus/middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

from confluent_kafka import Message

from faststream._internal.constants import EMPTY
from faststream.confluent.prometheus.provider import settings_provider_factory
Expand All @@ -10,7 +12,12 @@
from prometheus_client import CollectorRegistry


class KafkaPrometheusMiddleware(PrometheusMiddleware[KafkaPublishCommand]):
class KafkaPrometheusMiddleware(
PrometheusMiddleware[
KafkaPublishCommand,
Union[Message, Sequence[Message]],
]
):
def __init__(
self,
*,
Expand All @@ -20,7 +27,7 @@ def __init__(
received_messages_size_buckets: Optional[Sequence[float]] = None,
) -> None:
super().__init__(
settings_provider_factory=settings_provider_factory,
settings_provider_factory=settings_provider_factory, # type: ignore[arg-type]
registry=registry,
app_name=app_name,
metrics_prefix=metrics_prefix,
Expand Down
5 changes: 1 addition & 4 deletions faststream/kafka/broker/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,10 +449,7 @@ def __init__(
] = (),
middlewares: Annotated[
Sequence[
Union[
"BrokerMiddleware[ConsumerRecord]",
"BrokerMiddleware[tuple[ConsumerRecord, ...]]",
]
"BrokerMiddleware[Union[ConsumerRecord, tuple[ConsumerRecord, ...]]]"
],
Doc("Middlewares to apply to all broker publishers/subscribers."),
] = (),
Expand Down
13 changes: 10 additions & 3 deletions faststream/kafka/prometheus/middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

from aiokafka import ConsumerRecord

from faststream._internal.constants import EMPTY
from faststream.kafka.prometheus.provider import settings_provider_factory
Expand All @@ -10,7 +12,12 @@
from prometheus_client import CollectorRegistry


class KafkaPrometheusMiddleware(PrometheusMiddleware[KafkaPublishCommand]):
class KafkaPrometheusMiddleware(
PrometheusMiddleware[
KafkaPublishCommand,
Union[ConsumerRecord, Sequence[ConsumerRecord]],
],
):
def __init__(
self,
*,
Expand All @@ -20,7 +27,7 @@ def __init__(
received_messages_size_buckets: Optional[Sequence[float]] = None,
) -> None:
super().__init__(
settings_provider_factory=settings_provider_factory,
settings_provider_factory=settings_provider_factory, # type: ignore[arg-type]
registry=registry,
app_name=app_name,
metrics_prefix=metrics_prefix,
Expand Down
2 changes: 1 addition & 1 deletion faststream/middlewares/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from faststream._internal.middlewares import BaseMiddleware
from faststream.middlewares.acknowledgement.conf import AckPolicy
from faststream.middlewares.acknowledgement.middleware import AcknowledgementMiddleware
from faststream.middlewares.base import BaseMiddleware
from faststream.middlewares.exception import ExceptionMiddleware

__all__ = (
Expand Down
2 changes: 1 addition & 1 deletion faststream/middlewares/acknowledgement/middleware.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import logging
from typing import TYPE_CHECKING, Any, Optional

from faststream._internal.middlewares import BaseMiddleware
from faststream.exceptions import (
AckMessage,
HandlerException,
NackMessage,
RejectMessage,
)
from faststream.middlewares.acknowledgement.conf import AckPolicy
from faststream.middlewares.base import BaseMiddleware

if TYPE_CHECKING:
from types import TracebackType
Expand Down
2 changes: 1 addition & 1 deletion faststream/middlewares/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@

from typing_extensions import Literal, TypeAlias

from faststream._internal.middlewares import BaseMiddleware
from faststream._internal.utils import apply_types
from faststream._internal.utils.functions import sync_fake_context, to_async
from faststream.exceptions import IgnoredException
from faststream.middlewares.base import BaseMiddleware

if TYPE_CHECKING:
from contextlib import AbstractContextManager
Expand Down
3 changes: 1 addition & 2 deletions faststream/middlewares/logging.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import logging
from typing import TYPE_CHECKING, Any, Optional

from faststream._internal.middlewares import BaseMiddleware
from faststream.exceptions import IgnoredException
from faststream.message.source_type import SourceType

from .base import BaseMiddleware

if TYPE_CHECKING:
from types import TracebackType

Expand Down
10 changes: 7 additions & 3 deletions faststream/nats/prometheus/middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

from nats.aio.msg import Msg

from faststream._internal.constants import EMPTY
from faststream.nats.prometheus.provider import settings_provider_factory
Expand All @@ -10,7 +12,9 @@
from prometheus_client import CollectorRegistry


class NatsPrometheusMiddleware(PrometheusMiddleware[NatsPublishCommand]):
class NatsPrometheusMiddleware(
PrometheusMiddleware[NatsPublishCommand, Union[Msg, Sequence[Msg]]]
):
def __init__(
self,
*,
Expand All @@ -20,7 +24,7 @@ def __init__(
received_messages_size_buckets: Optional[Sequence[float]] = None,
) -> None:
super().__init__(
settings_provider_factory=settings_provider_factory,
settings_provider_factory=settings_provider_factory, # type: ignore[arg-type]
registry=registry,
app_name=app_name,
metrics_prefix=metrics_prefix,
Expand Down
3 changes: 2 additions & 1 deletion faststream/opentelemetry/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from opentelemetry.trace import Link, Span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

from faststream.middlewares.base import BaseMiddleware, PublishCommandType
from faststream._internal.middlewares import BaseMiddleware
from faststream._internal.types import PublishCommandType
from faststream.opentelemetry.baggage import Baggage
from faststream.opentelemetry.consts import (
ERROR_TYPE,
Expand Down
24 changes: 14 additions & 10 deletions faststream/prometheus/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from typing import TYPE_CHECKING, Any, Callable, Generic, Optional

from faststream._internal.constants import EMPTY
from faststream._internal.middlewares import BaseMiddleware
from faststream._internal.types import AnyMsg, PublishCommandType
from faststream.exceptions import IgnoredException
from faststream.message import SourceType
from faststream.middlewares.base import BaseMiddleware, PublishCommandType
from faststream.prometheus.consts import (
PROCESSING_STATUS_BY_ACK_STATUS,
PROCESSING_STATUS_BY_HANDLER_EXCEPTION_MAP,
Expand All @@ -24,15 +25,15 @@
from faststream.message.message import StreamMessage


class PrometheusMiddleware(Generic[PublishCommandType]):
class PrometheusMiddleware(Generic[PublishCommandType, AnyMsg]):
__slots__ = ("_metrics_container", "_metrics_manager", "_settings_provider_factory")

def __init__(
self,
*,
settings_provider_factory: Callable[
[Any],
Optional[MetricsSettingsProvider[Any, PublishCommandType]],
[Optional[AnyMsg]],
Optional[MetricsSettingsProvider[AnyMsg, PublishCommandType]],
],
registry: "CollectorRegistry",
app_name: str = EMPTY,
Expand All @@ -55,7 +56,7 @@ def __init__(

def __call__(
self,
msg: Optional[Any],
msg: Optional[AnyMsg],
/,
*,
context: "ContextRepo",
Expand All @@ -68,15 +69,18 @@ def __call__(
)


class BasePrometheusMiddleware(BaseMiddleware[PublishCommandType]):
class BasePrometheusMiddleware(
BaseMiddleware[PublishCommandType, AnyMsg],
Generic[PublishCommandType, AnyMsg],
):
def __init__(
self,
msg: Optional[Any],
msg: Optional[AnyMsg],
/,
*,
settings_provider_factory: Callable[
[Any],
Optional[MetricsSettingsProvider[Any, PublishCommandType]],
[Optional[AnyMsg]],
Optional[MetricsSettingsProvider[AnyMsg, PublishCommandType]],
],
metrics_manager: MetricsManager,
context: "ContextRepo",
Expand All @@ -88,7 +92,7 @@ def __init__(
async def consume_scope(
self,
call_next: "AsyncFuncAny",
msg: "StreamMessage[Any]",
msg: "StreamMessage[AnyMsg]",
) -> Any:
if self._settings_provider is None or msg._source_type is SourceType.RESPONSE:
return await call_next(msg)
Expand Down
7 changes: 4 additions & 3 deletions faststream/prometheus/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from typing_extensions import TypeVar as TypeVar313

from faststream.message.message import MsgType, StreamMessage
from faststream._internal.types import AnyMsg
from faststream.response.response import PublishCommand

if TYPE_CHECKING:
from faststream.message.message import StreamMessage
from faststream.prometheus import ConsumeAttrs


Expand All @@ -17,12 +18,12 @@
)


class MetricsSettingsProvider(Protocol[MsgType, PublishCommandType_contra]):
class MetricsSettingsProvider(Protocol[AnyMsg, PublishCommandType_contra]):
messaging_system: str

def get_consume_attrs_from_message(
self,
msg: "StreamMessage[MsgType]",
msg: "StreamMessage[AnyMsg]",
) -> "ConsumeAttrs": ...

def get_publish_destination_name_from_cmd(
Expand Down
Loading

0 comments on commit 57c0633

Please sign in to comment.