From 97ff3cbc81e6a65e3887bd59cf562a6686ff7043 Mon Sep 17 00:00:00 2001
From: doublehomixide
Date: Thu, 9 Jan 2025 21:36:51 +0300
Subject: [PATCH] Type-hints + refactoring (#2018)

---
 faststream/_internal/fastapi/_compat.py | 2 +-
 faststream/_internal/fastapi/route.py | 6 +-
 faststream/_internal/utils/functions.py | 7 +-
 faststream/confluent/client.py | 14 +-
 faststream/middlewares/exception.py | 4 +-
 faststream/rabbit/broker/broker.py | 161 +++++++-----------
 faststream/rabbit/publisher/usecase.py | 11 +-
 faststream/rabbit/schemas/queue.py | 4 +-
 faststream/redis/broker/broker.py | 107 +++++++-----
 faststream/redis/publisher/factory.py | 8 +-
 faststream/redis/publisher/producer.py | 17 +-
 faststream/redis/publisher/usecase.py | 4 +-
 faststream/redis/subscriber/factory.py | 16 +-
 .../specification/asyncapi/v2_6_0/generate.py | 1 +
 .../asyncapi/v2_6_0/schema/contact.py | 2 +-
 .../asyncapi/v2_6_0/schema/docs.py | 2 +-
 .../asyncapi/v2_6_0/schema/license.py | 2 +-
 .../asyncapi/v2_6_0/schema/tag.py | 2 +-
 tests/a_docs/rabbit/test_bind.py | 4 +-
 tests/brokers/kafka/test_consume.py | 4 +-
 tests/brokers/nats/test_consume.py | 1 -
 tests/brokers/rabbit/test_consume.py | 11 +-
 tests/brokers/redis/test_consume.py | 7 +-
 tests/cli/utils/test_imports.py | 8 +-
 tests/prometheus/basic.py | 10 +-
 tests/prometheus/utils.py | 30 ++--
 26 files changed, 223 insertions(+), 222 deletions(-)

diff --git a/faststream/_internal/fastapi/_compat.py b/faststream/_internal/fastapi/_compat.py
index f4b423fb44..864492373f 100644
--- a/faststream/_internal/fastapi/_compat.py
+++ b/faststream/_internal/fastapi/_compat.py
@@ -92,7 +92,7 @@ async def solve_faststream_dependency(
         **kwargs,
     )
     values, errors, background = (
-        solved_result.values,  # noqa: PD011
+        solved_result.values,
        solved_result.errors,
         solved_result.background_tasks,
     )
diff --git a/faststream/_internal/fastapi/route.py b/faststream/_internal/fastapi/route.py
index a132daf3f9..70573fc818 100644
--- a/faststream/_internal/fastapi/route.py
+++ b/faststream/_internal/fastapi/route.py
@@ -78,9 +78,9 @@ def wrap_callable_to_fastapi_compatible(
     response_model_exclude_none: bool,
     state: "DIState",
 ) -> Callable[["NativeMessage[Any]"], Awaitable[Any]]:
-    __magic_attr = "__faststream_consumer__"
+    magic_attr = "__faststream_consumer__"
 
-    if getattr(user_callable, __magic_attr, False):
+    if getattr(user_callable, magic_attr, False):
         return user_callable  # type: ignore[return-value]
 
     if response_model:
@@ -105,7 +105,7 @@ def wrap_callable_to_fastapi_compatible(
         state=state,
     )
 
-    setattr(parsed_callable, __magic_attr, True)
+    setattr(parsed_callable, magic_attr, True)
     return wraps(user_callable)(parsed_callable)
 
diff --git a/faststream/_internal/utils/functions.py b/faststream/_internal/utils/functions.py
index efea5541d6..b824955efd 100644
--- a/faststream/_internal/utils/functions.py
+++ b/faststream/_internal/utils/functions.py
@@ -85,6 +85,11 @@ async def return_input(x: Any) -> Any:
     return x
 
 
-async def run_in_executor(executor: Optional[Executor], func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
+async def run_in_executor(
+    executor: Optional[Executor],
+    func: Callable[P, T],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> T:
     loop = asyncio.get_running_loop()
     return await loop.run_in_executor(executor, partial(func, *args, **kwargs))
diff --git a/faststream/confluent/client.py b/faststream/confluent/client.py
index 2931a32b67..92b91dd3e1 100644
--- a/faststream/confluent/client.py
+++ b/faststream/confluent/client.py
@@ -157,7 +157,9 @@ async def send(
         
def ack_callback(err: Any, msg: Optional[Message]) -> None: if err or (msg is not None and (err := msg.error())): - loop.call_soon_threadsafe(result_future.set_exception, KafkaException(err)) + loop.call_soon_threadsafe( + result_future.set_exception, KafkaException(err) + ) else: loop.call_soon_threadsafe(result_future.set_result, msg) @@ -357,7 +359,9 @@ async def start(self) -> None: async def commit(self, asynchronous: bool = True) -> None: """Commits the offsets of all messages returned by the last poll operation.""" - await run_in_executor(self._thread_pool, self.consumer.commit, asynchronous=asynchronous) + await run_in_executor( + self._thread_pool, self.consumer.commit, asynchronous=asynchronous + ) async def stop(self) -> None: """Stops the Kafka consumer and releases all resources.""" @@ -382,7 +386,7 @@ async def stop(self) -> None: # Wrap calls to async to make method cancelable by timeout # We shouldn't read messages and close consumer concurrently # https://github.com/airtai/faststream/issues/1904#issuecomment-2506990895 - # Now it works withouth lock due `ThreadPoolExecutor(max_workers=1)` + # Now it works without lock due `ThreadPoolExecutor(max_workers=1)` # that makes all calls to consumer sequential await run_in_executor(self._thread_pool, self.consumer.close) @@ -414,7 +418,9 @@ async def seek(self, topic: str, partition: int, offset: int) -> None: partition=partition, offset=offset, ) - await run_in_executor(self._thread_pool, self.consumer.seek, topic_partition.to_confluent()) + await run_in_executor( + self._thread_pool, self.consumer.seek, topic_partition.to_confluent() + ) def check_msg_error(msg: Optional[Message]) -> Optional[Message]: diff --git a/faststream/middlewares/exception.py b/faststream/middlewares/exception.py index 4172fb41bb..30add14ba7 100644 --- a/faststream/middlewares/exception.py +++ b/faststream/middlewares/exception.py @@ -30,10 +30,10 @@ Callable[..., None], Callable[..., Awaitable[None]], ] -PublishingExceptionHandler: TypeAlias = Callable[..., "Any"] +PublishingExceptionHandler: TypeAlias = Callable[..., Any] CastedGeneralExceptionHandler: TypeAlias = Callable[..., Awaitable[None]] -CastedPublishingExceptionHandler: TypeAlias = Callable[..., Awaitable["Any"]] +CastedPublishingExceptionHandler: TypeAlias = Callable[..., Awaitable[Any]] CastedHandlers: TypeAlias = list[ tuple[ type[Exception], diff --git a/faststream/rabbit/broker/broker.py b/faststream/rabbit/broker/broker.py index 084758401d..95d54fce37 100644 --- a/faststream/rabbit/broker/broker.py +++ b/faststream/rabbit/broker/broker.py @@ -536,104 +536,30 @@ async def start(self) -> None: logger_state.log(f"Set max consumers to {self._max_consumers}") @override - async def publish( # type: ignore[override] + async def publish( self, - message: Annotated[ - "AioPikaSendableMessage", - Doc("Message body to send."), - ] = None, - queue: Annotated[ - Union["RabbitQueue", str], - Doc("Message routing key to publish with."), - ] = "", - exchange: Annotated[ - Union["RabbitExchange", str, None], - Doc("Target exchange to publish message to."), - ] = None, + message: "AioPikaSendableMessage" = None, + queue: Union["RabbitQueue", str] = "", + exchange: Union["RabbitExchange", str, None] = None, *, - routing_key: Annotated[ - str, - Doc( - "Message routing key to publish with. " - "Overrides `queue` option if presented.", - ), - ] = "", - mandatory: Annotated[ - bool, - Doc( - "Client waits for confirmation that the message is placed to some queue. 
" - "RabbitMQ returns message to client if there is no suitable queue.", - ), - ] = True, - immediate: Annotated[ - bool, - Doc( - "Client expects that there is consumer ready to take the message to work. " - "RabbitMQ returns message to client if there is no suitable consumer.", - ), - ] = False, - timeout: Annotated[ - "TimeoutType", - Doc("Send confirmation time from RabbitMQ."), - ] = None, - persist: Annotated[ - bool, - Doc("Restore the message on RabbitMQ reboot."), - ] = False, - reply_to: Annotated[ - Optional[str], - Doc( - "Reply message routing key to send with (always sending to default exchange).", - ), - ] = None, - # message args - correlation_id: Annotated[ - Optional[str], - Doc( - "Manual message **correlation_id** setter. " - "**correlation_id** is a useful option to trace messages.", - ), - ] = None, - headers: Annotated[ - Optional["HeadersType"], - Doc("Message headers to store metainformation."), - ] = None, - content_type: Annotated[ - Optional[str], - Doc( - "Message **content-type** header. " - "Used by application, not core RabbitMQ. " - "Will be set automatically if not specified.", - ), - ] = None, - content_encoding: Annotated[ - Optional[str], - Doc("Message body content encoding, e.g. **gzip**."), - ] = None, - expiration: Annotated[ - Optional["DateType"], - Doc("Message expiration (lifetime) in seconds (or datetime or timedelta)."), - ] = None, - message_id: Annotated[ - Optional[str], - Doc("Arbitrary message id. Generated automatically if not presented."), - ] = None, - timestamp: Annotated[ - Optional["DateType"], - Doc("Message publish timestamp. Generated automatically if not presented."), - ] = None, - message_type: Annotated[ - Optional[str], - Doc("Application-specific message type, e.g. **orders.created**."), - ] = None, - user_id: Annotated[ - Optional[str], - Doc("Publisher connection User ID, validated if set."), - ] = None, - priority: Annotated[ - Optional[int], - Doc("The message priority (0 by default)."), - ] = None, + routing_key: str = "", + # publish options + mandatory: bool = True, + immediate: bool = False, + timeout: "TimeoutType" = None, + persist: bool = False, + reply_to: Optional[str] = None, + correlation_id: Optional[str] = None, + # message options + headers: Optional["HeadersType"] = None, + content_type: Optional[str] = None, + content_encoding: Optional[str] = None, + expiration: Optional["DateType"] = None, + message_id: Optional[str] = None, + timestamp: Optional["DateType"] = None, + message_type: Optional[str] = None, + user_id: Optional[str] = None, + priority: Optional[int] = None, ) -> Optional["aiormq.abc.ConfirmationFrameType"]: """Publish message directly. @@ -641,6 +567,49 @@ async def publish( # type: ignore[override] applications or to publish messages from time to time. Please, use `@broker.publisher(...)` or `broker.publisher(...).publish(...)` instead in a regular way. + + Args: + message: + Message body to send. + queue: + Message routing key to publish with. + exchange: + Target exchange to publish message to. + routing_key: + Message routing key to publish with. Overrides `queue` option if presented. + mandatory: + Client waits for confirmation that the message is placed to some queue. RabbitMQ returns message to client if there is no suitable queue. + immediate: + Client expects that there is consumer ready to take the message to work. RabbitMQ returns message to client if there is no suitable consumer. + timeout: + Send confirmation time from RabbitMQ. 

+            persist:
+                Restore the message on RabbitMQ reboot.
+            reply_to:
+                Reply message routing key to send with (always sent to the default exchange).
+            correlation_id:
+                Manual message **correlation_id** setter. **correlation_id** is a useful option to trace messages.
+            headers:
+                Message headers to store metainformation.
+            content_type:
+                Message **content-type** header. Used by application, not core RabbitMQ. Will be set automatically if not specified.
+            content_encoding:
+                Message body content encoding, e.g. **gzip**.
+            expiration:
+                Message expiration (lifetime) in seconds (or datetime or timedelta).
+            message_id:
+                Arbitrary message id. Generated automatically if not provided.
+            timestamp:
+                Message publish timestamp. Generated automatically if not provided.
+            message_type:
+                Application-specific message type, e.g. **orders.created**.
+            user_id:
+                Publisher connection User ID, validated if set.
+            priority:
+                The message priority (0 by default).
+
+        Returns:
+            An optional `aiormq.abc.ConfirmationFrameType` representing the confirmation frame if RabbitMQ is configured to send confirmations.
         """
         cmd = RabbitPublishCommand(
             message,
diff --git a/faststream/rabbit/publisher/usecase.py b/faststream/rabbit/publisher/usecase.py
index d4b736cfa4..a898e968d0 100644
--- a/faststream/rabbit/publisher/usecase.py
+++ b/faststream/rabbit/publisher/usecase.py
@@ -70,15 +70,14 @@ def __init__(
             middlewares=middlewares,
         )
 
-        request_options = dict(message_kwargs)
-        self.headers = request_options.pop("headers") or {}
-        self.reply_to = request_options.pop("reply_to", None) or ""
-        self.timeout = request_options.pop("timeout", None)
+        self.headers = message_kwargs.pop("headers") or {}
+        self.reply_to: str = message_kwargs.pop("reply_to", None) or ""
+        self.timeout = message_kwargs.pop("timeout", None)
 
-        message_options, _ = filter_by_dict(MessageOptions, request_options)
+        message_options, _ = filter_by_dict(MessageOptions, dict(message_kwargs))
         self.message_options = message_options
 
-        publish_options, _ = filter_by_dict(PublishOptions, request_options)
+        publish_options, _ = filter_by_dict(PublishOptions, dict(message_kwargs))
         self.publish_options = publish_options
 
         self.app_id = None
diff --git a/faststream/rabbit/schemas/queue.py b/faststream/rabbit/schemas/queue.py
index 8aa5480f97..0fff0a83ca 100644
--- a/faststream/rabbit/schemas/queue.py
+++ b/faststream/rabbit/schemas/queue.py
@@ -163,8 +163,8 @@ def __init__(
             if durable is EMPTY:
                 durable = True
             elif not durable:
-                _error_msg = "Quorum and Stream queues must be durable"
-                raise SetupError(_error_msg)
+                error_msg = "Quorum and Stream queues must be durable"
+                raise SetupError(error_msg)
         elif durable is EMPTY:
             durable = False
diff --git a/faststream/redis/broker/broker.py b/faststream/redis/broker/broker.py
index aa90c1d5ca..f3faffad89 100644
--- a/faststream/redis/broker/broker.py
+++ b/faststream/redis/broker/broker.py
@@ -21,7 +21,7 @@
     parse_url,
 )
 from redis.exceptions import ConnectionError
-from typing_extensions import Doc, TypeAlias, override
+from typing_extensions import Doc, TypeAlias, overload, override
 
 from faststream.__about__ import __version__
 from faststream._internal.broker.broker import BrokerUsecase
@@ -368,55 +368,72 @@ def _subscriber_setup_extra(self) -> "AnyDict":
             "connection": self._connection,
         }
 
+    @overload
+    async def publish(
+        self,
+        message: "SendableMessage" = None,
+        channel: Optional[str] = None,
+        *,
+        reply_to: str = "",
+        headers: Optional["AnyDict"] = None,
+        correlation_id: Optional[str] = None,
+        list: Optional[str] 
= None,
+        stream: None = None,
+        maxlen: Optional[int] = None,
+    ) -> int: ...
+
+    @overload
+    async def publish(
+        self,
+        message: "SendableMessage" = None,
+        channel: Optional[str] = None,
+        *,
+        reply_to: str = "",
+        headers: Optional["AnyDict"] = None,
+        correlation_id: Optional[str] = None,
+        list: Optional[str] = None,
+        stream: str,
+        maxlen: Optional[int] = None,
+    ) -> bytes: ...
+
     @override
-    async def publish(  # type: ignore[override]
+    async def publish(
         self,
-        message: Annotated[
-            "SendableMessage",
-            Doc("Message body to send."),
-        ] = None,
-        channel: Annotated[
-            Optional[str],
-            Doc("Redis PubSub object name to send message."),
-        ] = None,
+        message: "SendableMessage" = None,
+        channel: Optional[str] = None,
         *,
-        reply_to: Annotated[
-            str,
-            Doc("Reply message destination PubSub object name."),
-        ] = "",
-        headers: Annotated[
-            Optional["AnyDict"],
-            Doc("Message headers to store metainformation."),
-        ] = None,
-        correlation_id: Annotated[
-            Optional[str],
-            Doc(
-                "Manual message **correlation_id** setter. "
-                "**correlation_id** is a useful option to trace messages.",
-            ),
-        ] = None,
-        list: Annotated[
-            Optional[str],
-            Doc("Redis List object name to send message."),
-        ] = None,
-        stream: Annotated[
-            Optional[str],
-            Doc("Redis Stream object name to send message."),
-        ] = None,
-        maxlen: Annotated[
-            Optional[int],
-            Doc(
-                "Redis Stream maxlen publish option. "
-                "Remove eldest message if maxlen exceeded.",
-            ),
-        ] = None,
-    ) -> int:
+        reply_to: str = "",
+        headers: Optional["AnyDict"] = None,
+        correlation_id: Optional[str] = None,
+        list: Optional[str] = None,
+        stream: Optional[str] = None,
+        maxlen: Optional[int] = None,
+    ) -> Union[int, bytes]:
         """Publish message directly.
 
-        This method allows you to publish message in not AsyncAPI-documented way. You can use it in another frameworks
-        applications or to publish messages from time to time.
-
-        Please, use `@broker.publisher(...)` or `broker.publisher(...).publish(...)` instead in a regular way.
+        This method allows you to publish a message in a non-AsyncAPI-documented way.
+        It can be used in other frameworks or to publish messages at specific intervals.
+
+        Args:
+            message:
+                Message body to send.
+            channel:
+                Redis PubSub object name to send message.
+            reply_to:
+                Reply message destination PubSub object name.
+            headers:
+                Message headers to store metainformation.
+            correlation_id:
+                Manual message correlation_id setter. correlation_id is a useful option to trace messages.
+            list:
+                Redis List object name to send message.
+            stream:
+                Redis Stream object name to send message.
+            maxlen:
+                Redis Stream maxlen publish option. Remove the oldest messages if maxlen is exceeded.
+
+        Returns:
+            int or bytes: The number of clients that received the message (for a channel), the new list length (for a list), or the id of the created stream entry (for a stream).
        
""" cmd = RedisPublishCommand( message, diff --git a/faststream/redis/publisher/factory.py b/faststream/redis/publisher/factory.py index 3ddd3d2d3f..1f886f5885 100644 --- a/faststream/redis/publisher/factory.py +++ b/faststream/redis/publisher/factory.py @@ -21,10 +21,10 @@ PublisherType: TypeAlias = Union[ - "SpecificationChannelPublisher", - "SpecificationStreamPublisher", - "SpecificationListPublisher", - "SpecificationListBatchPublisher", + SpecificationChannelPublisher, + SpecificationStreamPublisher, + SpecificationListPublisher, + SpecificationListBatchPublisher, ] diff --git a/faststream/redis/publisher/producer.py b/faststream/redis/publisher/producer.py index 17bbe5077b..0de145f094 100644 --- a/faststream/redis/publisher/producer.py +++ b/faststream/redis/publisher/producer.py @@ -123,7 +123,9 @@ async def publish_batch( ] return await self._connection.client.rpush(cmd.destination, *batch) - async def __publish(self, msg: bytes, cmd: "RedisPublishCommand") -> Union[int, bytes]: + async def __publish( + self, msg: bytes, cmd: "RedisPublishCommand" + ) -> Union[int, bytes]: if cmd.destination_type is DestinationType.Channel: return await self._connection.client.publish(cmd.destination, msg) @@ -131,11 +133,14 @@ async def __publish(self, msg: bytes, cmd: "RedisPublishCommand") -> Union[int, return await self._connection.client.rpush(cmd.destination, msg) if cmd.destination_type is DestinationType.Stream: - return cast("bytes", await self._connection.client.xadd( - name=cmd.destination, - fields={DATA_KEY: msg}, - maxlen=cmd.maxlen, - )) + return cast( + "bytes", + await self._connection.client.xadd( + name=cmd.destination, + fields={DATA_KEY: msg}, + maxlen=cmd.maxlen, + ), + ) error_msg = "unreachable" raise AssertionError(error_msg) diff --git a/faststream/redis/publisher/usecase.py b/faststream/redis/publisher/usecase.py index 06dd333a74..82fcfd0b2b 100644 --- a/faststream/redis/publisher/usecase.py +++ b/faststream/redis/publisher/usecase.py @@ -1,7 +1,7 @@ from abc import abstractmethod from collections.abc import Iterable, Sequence from copy import deepcopy -from typing import TYPE_CHECKING, Annotated, Any, Optional, Union +from typing import TYPE_CHECKING, Annotated, Optional, Union from typing_extensions import Doc, override @@ -428,7 +428,7 @@ async def publish( "Remove eldest message if maxlen exceeded.", ), ] = None, - ) -> Any: + ) -> bytes: cmd = RedisPublishCommand( message, stream=stream or self.stream.name, diff --git a/faststream/redis/subscriber/factory.py b/faststream/redis/subscriber/factory.py index e598fe30cd..504aff488f 100644 --- a/faststream/redis/subscriber/factory.py +++ b/faststream/redis/subscriber/factory.py @@ -27,14 +27,14 @@ from faststream.redis.message import UnifyRedisDict SubsciberType: TypeAlias = Union[ - "SpecificationChannelSubscriber", - "SpecificationStreamBatchSubscriber", - "SpecificationStreamSubscriber", - "SpecificationListBatchSubscriber", - "SpecificationListSubscriber", - "SpecificationChannelConcurrentSubscriber", - "SpecificationListConcurrentSubscriber", - "SpecificationStreamConcurrentSubscriber", + SpecificationChannelSubscriber, + SpecificationStreamBatchSubscriber, + SpecificationStreamSubscriber, + SpecificationListBatchSubscriber, + SpecificationListSubscriber, + SpecificationChannelConcurrentSubscriber, + SpecificationListConcurrentSubscriber, + SpecificationStreamConcurrentSubscriber, ] diff --git a/faststream/specification/asyncapi/v2_6_0/generate.py b/faststream/specification/asyncapi/v2_6_0/generate.py index 
d0f4010da1..139c1ff4d0 100644 --- a/faststream/specification/asyncapi/v2_6_0/generate.py +++ b/faststream/specification/asyncapi/v2_6_0/generate.py @@ -157,6 +157,7 @@ def get_broker_channels( RuntimeWarning, stacklevel=1, ) + channels[key] = Channel.from_sub(sub) for p in broker._publishers: diff --git a/faststream/specification/asyncapi/v2_6_0/schema/contact.py b/faststream/specification/asyncapi/v2_6_0/schema/contact.py index da07f90910..d71cbdb781 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/contact.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/contact.py @@ -64,7 +64,7 @@ def from_spec( email=contact.email, ) - contact = cast(AnyDict, contact) + contact = cast("AnyDict", contact) contact_data, custom_data = filter_by_dict(ContactDict, contact) if custom_data: diff --git a/faststream/specification/asyncapi/v2_6_0/schema/docs.py b/faststream/specification/asyncapi/v2_6_0/schema/docs.py index 5bf8ebb458..0bbb933f6f 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/docs.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/docs.py @@ -58,7 +58,7 @@ def from_spec( if isinstance(docs, SpecDocs): return cls(url=docs.url, description=docs.description) - docs = cast(AnyDict, docs) + docs = cast("AnyDict", docs) docs_data, custom_data = filter_by_dict(ExternalDocsDict, docs) if custom_data: diff --git a/faststream/specification/asyncapi/v2_6_0/schema/license.py b/faststream/specification/asyncapi/v2_6_0/schema/license.py index a713b75fe4..fee3db4012 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/license.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/license.py @@ -64,7 +64,7 @@ def from_spec( url=license.url, ) - license = cast(AnyDict, license) + license = cast("AnyDict", license) license_data, custom_data = filter_by_dict(LicenseDict, license) if custom_data: diff --git a/faststream/specification/asyncapi/v2_6_0/schema/tag.py b/faststream/specification/asyncapi/v2_6_0/schema/tag.py index ba2ac8e17f..86dff2e613 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/tag.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/tag.py @@ -56,7 +56,7 @@ def from_spec(cls, tag: Union[SpecTag, TagDict, AnyDict]) -> Union[Self, AnyDict externalDocs=ExternalDocs.from_spec(tag.external_docs), ) - tag = cast(AnyDict, tag) + tag = cast("AnyDict", tag) tag_data, custom_data = filter_by_dict(TagDict, tag) if custom_data: diff --git a/tests/a_docs/rabbit/test_bind.py b/tests/a_docs/rabbit/test_bind.py index d2656a6f5c..76c7b8d6fd 100644 --- a/tests/a_docs/rabbit/test_bind.py +++ b/tests/a_docs/rabbit/test_bind.py @@ -7,8 +7,8 @@ from tests.marks import require_aiopika -@pytest.mark.asyncio -@pytest.mark.rabbit +@pytest.mark.asyncio() +@pytest.mark.rabbit() @require_aiopika async def test_bind(monkeypatch, async_mock: AsyncMock): from docs.docs_src.rabbit.bind import app, broker, some_exchange, some_queue diff --git a/tests/brokers/kafka/test_consume.py b/tests/brokers/kafka/test_consume.py index 1c8ee55f99..d9160ddca0 100644 --- a/tests/brokers/kafka/test_consume.py +++ b/tests/brokers/kafka/test_consume.py @@ -3,6 +3,7 @@ import pytest from aiokafka import AIOKafkaConsumer +from aiokafka.structs import RecordMetadata from faststream import AckPolicy from faststream.exceptions import AckMessage @@ -38,7 +39,7 @@ async def pattern_handler(msg) -> None: async with self.patch_broker(consume_broker) as br: await br.start() - await br.publish(1, topic=queue) + result = await br.publish(1, topic=queue) await asyncio.wait( ( @@ -48,6 +49,7 @@ 
async def pattern_handler(msg) -> None: ), timeout=3, ) + assert isinstance(result, RecordMetadata), result assert event.is_set() assert pattern_event.is_set() diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index f03daaec4c..f81bd7671f 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -84,7 +84,6 @@ def subscriber(m) -> None: ) assert isinstance(result, PubAck), result - assert event.is_set() async def test_consume_with_filter( diff --git a/tests/brokers/rabbit/test_consume.py b/tests/brokers/rabbit/test_consume.py index 45b0ad0026..c89e429a70 100644 --- a/tests/brokers/rabbit/test_consume.py +++ b/tests/brokers/rabbit/test_consume.py @@ -3,6 +3,7 @@ import pytest from aio_pika import IncomingMessage, Message +from aiormq.abc import ConfirmationFrameType from faststream import AckPolicy from faststream.exceptions import AckMessage, NackMessage, RejectMessage, SkipMessage @@ -32,15 +33,13 @@ def h(m) -> None: async with self.patch_broker(consume_broker) as br: await br.start() + + result = await br.publish("hello", queue=queue, exchange=exchange) await asyncio.wait( - ( - asyncio.create_task( - br.publish("hello", queue=queue, exchange=exchange), - ), - asyncio.create_task(event.wait()), - ), + (asyncio.create_task(event.wait()),), timeout=3, ) + assert isinstance(result, ConfirmationFrameType), result assert event.is_set() diff --git a/tests/brokers/redis/test_consume.py b/tests/brokers/redis/test_consume.py index f0da8be4eb..78efae46c2 100644 --- a/tests/brokers/redis/test_consume.py +++ b/tests/brokers/redis/test_consume.py @@ -31,13 +31,12 @@ async def handler(msg) -> None: async with self.patch_broker(consume_broker) as br: await br.start() + result = await br._connection.publish(queue, "hello") await asyncio.wait( - ( - asyncio.create_task(br._connection.publish(queue, "hello")), - asyncio.create_task(event.wait()), - ), + (asyncio.create_task(event.wait()),), timeout=3, ) + assert result == 1, result mock.assert_called_once_with(b"hello") diff --git a/tests/cli/utils/test_imports.py b/tests/cli/utils/test_imports.py index c87cd81667..11263bc3c7 100644 --- a/tests/cli/utils/test_imports.py +++ b/tests/cli/utils/test_imports.py @@ -35,14 +35,14 @@ def test_import_wrong() -> None: ), ), ) -def test_get_app_path(test_input, exp_module, exp_app) -> None: +def test_get_app_path(test_input: str, exp_module: str, exp_app: str) -> None: dir, app = _get_obj_path(test_input) assert app == exp_app assert dir == Path.cwd() / exp_module def test_get_app_path_wrong() -> None: - with pytest.raises(ValueError, match="`module.app` is not a path to object"): + with pytest.raises(ValueError, match=r"`module.app` is not a path to object"): _get_obj_path("module.app") @@ -62,7 +62,7 @@ def test_import_from_string_import_wrong() -> None: @require_nats @require_aiopika @require_aiokafka -def test_import_from_string(test_input, exp_module) -> None: +def test_import_from_string(test_input: str, exp_module: str) -> None: module, app = import_from_string(test_input) assert isinstance(app, FastStream) assert module == (Path.cwd() / exp_module).parent @@ -91,7 +91,7 @@ def test_import_from_string(test_input, exp_module) -> None: @require_nats @require_aiopika @require_aiokafka -def test_import_module(test_input, exp_module) -> None: +def test_import_module(test_input: str, exp_module: str) -> None: module, app = import_from_string(test_input) assert isinstance(app, FastStream) assert module == (Path.cwd() / exp_module).parent diff 
--git a/tests/prometheus/basic.py b/tests/prometheus/basic.py index fcaaf12a1f..a48d54906f 100644 --- a/tests/prometheus/basic.py +++ b/tests/prometheus/basic.py @@ -178,7 +178,7 @@ def assert_metrics( app_name="faststream", broker=settings_provider.messaging_system, queue=consume_attrs["destination_name"], - duration=cast(float, IsPositiveFloat), + duration=cast("float", IsPositiveFloat), ) ) @@ -221,7 +221,7 @@ def assert_metrics( metrics_prefix="faststream", app_name="faststream", broker=settings_provider.messaging_system, - queue=cast(str, IsStr), + queue=cast("str", IsStr), status=PublishingStatus.success, messages_amount=consume_attrs["messages_count"], ) @@ -231,8 +231,8 @@ def assert_metrics( metrics_prefix="faststream", app_name="faststream", broker=settings_provider.messaging_system, - queue=cast(str, IsStr), - duration=cast(float, IsPositiveFloat), + queue=cast("str", IsStr), + duration=cast("float", IsPositiveFloat), ) ) @@ -240,7 +240,7 @@ def assert_metrics( metrics_prefix="faststream", app_name="faststream", broker=settings_provider.messaging_system, - queue=cast(str, IsStr), + queue=cast("str", IsStr), exception_type=None, ) diff --git a/tests/prometheus/utils.py b/tests/prometheus/utils.py index cb1e0738c0..29a813e927 100644 --- a/tests/prometheus/utils.py +++ b/tests/prometheus/utils.py @@ -33,7 +33,7 @@ def get_received_messages_metric( Sample( name=f"{metrics_prefix}_received_messages_created", labels={"app_name": app_name, "broker": broker, "handler": queue}, - value=cast(float, IsPositiveFloat), + value=cast("float", IsPositiveFloat), timestamp=None, exemplar=None, ), @@ -66,7 +66,7 @@ def get_received_messages_size_bytes_metric( "app_name": app_name, "broker": broker, "handler": queue, - "le": cast(str, IsStr), + "le": cast("str", IsStr), }, value=float(messages_amount), timestamp=None, @@ -91,7 +91,7 @@ def get_received_messages_size_bytes_metric( Sample( name=f"{metrics_prefix}_received_messages_size_bytes_created", labels={"app_name": app_name, "broker": broker, "handler": queue}, - value=cast(float, IsPositiveFloat), + value=cast("float", IsPositiveFloat), timestamp=None, exemplar=None, ), @@ -163,7 +163,7 @@ def get_received_processed_messages_metric( "handler": queue, "status": status.value, }, - value=cast(float, IsPositiveFloat), + value=cast("float", IsPositiveFloat), timestamp=None, exemplar=None, ), @@ -194,9 +194,9 @@ def get_received_processed_messages_duration_seconds_metric( "app_name": app_name, "broker": broker, "handler": queue, - "le": cast(str, IsStr), + "le": cast("str", IsStr), }, - value=cast(float, IsFloat), + value=cast("float", IsFloat), timestamp=None, exemplar=None, ) @@ -205,7 +205,7 @@ def get_received_processed_messages_duration_seconds_metric( Sample( name=f"{metrics_prefix}_received_processed_messages_duration_seconds_count", labels={"app_name": app_name, "broker": broker, "handler": queue}, - value=cast(float, IsPositiveFloat), + value=cast("float", IsPositiveFloat), timestamp=None, exemplar=None, ), @@ -219,7 +219,7 @@ def get_received_processed_messages_duration_seconds_metric( Sample( name=f"{metrics_prefix}_received_processed_messages_duration_seconds_created", labels={"app_name": app_name, "broker": broker, "handler": queue}, - value=cast(float, IsPositiveFloat), + value=cast("float", IsPositiveFloat), timestamp=None, exemplar=None, ), @@ -265,7 +265,7 @@ def get_received_processed_messages_exceptions_metric( "handler": queue, "exception_type": exception_type, }, - value=cast(float, IsPositiveFloat), + value=cast("float", 
IsPositiveFloat), timestamp=None, exemplar=None, ), @@ -313,7 +313,7 @@ def get_published_messages_metric( "destination": queue, "status": status.value, }, - value=cast(float, IsPositiveFloat), + value=cast("float", IsPositiveFloat), timestamp=None, exemplar=None, ), @@ -344,9 +344,9 @@ def get_published_messages_duration_seconds_metric( "app_name": app_name, "broker": broker, "destination": queue, - "le": cast(str, IsStr), + "le": cast("str", IsStr), }, - value=cast(float, IsFloat), + value=cast("float", IsFloat), timestamp=None, exemplar=None, ) @@ -355,7 +355,7 @@ def get_published_messages_duration_seconds_metric( Sample( name=f"{metrics_prefix}_published_messages_duration_seconds_count", labels={"app_name": app_name, "broker": broker, "destination": queue}, - value=cast(float, IsPositiveFloat), + value=cast("float", IsPositiveFloat), timestamp=None, exemplar=None, ), @@ -369,7 +369,7 @@ def get_published_messages_duration_seconds_metric( Sample( name=f"{metrics_prefix}_published_messages_duration_seconds_created", labels={"app_name": app_name, "broker": broker, "destination": queue}, - value=cast(float, IsPositiveFloat), + value=cast("float", IsPositiveFloat), timestamp=None, exemplar=None, ), @@ -414,7 +414,7 @@ def get_published_messages_exceptions_metric( "destination": queue, "exception_type": exception_type, }, - value=cast(float, IsPositiveFloat), + value=cast("float", IsPositiveFloat), timestamp=None, exemplar=None, ),
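

Notes on the helpers touched above. The reformatted `run_in_executor` signature and the lock comment in `faststream/confluent/client.py` rest on one property: a `ThreadPoolExecutor(max_workers=1)` runs submitted calls strictly one after another. A minimal self-contained sketch (the `blocking_call` function, the pool, and the printed values are illustrative only, not part of the patch):

import asyncio
import time
from concurrent.futures import Executor, ThreadPoolExecutor
from functools import partial
from typing import Callable, Optional, TypeVar

from typing_extensions import ParamSpec

P = ParamSpec("P")
T = TypeVar("T")


async def run_in_executor(
    executor: Optional[Executor],
    func: Callable[P, T],
    *args: P.args,
    **kwargs: P.kwargs,
) -> T:
    # partial() is needed because loop.run_in_executor() accepts no kwargs.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, partial(func, *args, **kwargs))


def blocking_call(name: str, delay: float = 0.1) -> str:
    time.sleep(delay)  # stands in for a blocking confluent-kafka call
    return name


async def main() -> None:
    # A single worker serializes every submitted call, which is why the
    # consumer above no longer needs an explicit lock around poll/close.
    pool = ThreadPoolExecutor(max_workers=1)
    results = await asyncio.gather(
        run_in_executor(pool, blocking_call, "poll"),
        run_in_executor(pool, blocking_call, "commit", delay=0.05),
    )
    print(results)  # ['poll', 'commit'], executed sequentially


asyncio.run(main())

Because both calls share the single worker, "commit" cannot start until "poll" returns, which is exactly the lock-free serialization the comment describes.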
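
The flattened `RabbitBroker.publish(...)` signature keeps the same call shape as before; only the `Annotated`/`Doc` wrappers moved into the docstring. A minimal usage sketch, assuming a RabbitMQ instance at the default local URL (the queue name and correlation id are placeholders):

import asyncio

from faststream.rabbit import RabbitBroker


async def main() -> None:
    async with RabbitBroker("amqp://guest:guest@localhost:5672/") as broker:
        # The return value is the broker's confirmation frame (or None),
        # matching the Optional[ConfirmationFrameType] annotation above.
        confirmation = await broker.publish(
            "Hi!",
            queue="test-queue",
            persist=True,
            correlation_id="example-correlation-id",
        )
        print(confirmation)


asyncio.run(main())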
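
The new `@overload` pair on `RedisBroker.publish(...)` lets type checkers narrow the return type at the call site: passing `stream=` selects the `bytes` overload (the created stream entry id), while channel and list destinations resolve to `int`. A sketch under the same assumptions (local Redis, placeholder object names):

import asyncio

from faststream.redis import RedisBroker


async def main() -> None:
    async with RedisBroker("redis://localhost:6379") as broker:
        # Pub/Sub channel: resolves to the int overload,
        # the number of subscribers that received the message.
        receivers = await broker.publish("ping", channel="test-channel")

        # Stream: resolves to the bytes overload,
        # the id Redis assigned to the new stream entry.
        entry_id = await broker.publish("ping", stream="test-stream", maxlen=1000)

        print(receivers, entry_id)


asyncio.run(main())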