diff --git a/aiomqtt/client.py b/aiomqtt/client.py index 4b4439e..cf551d4 100644 --- a/aiomqtt/client.py +++ b/aiomqtt/client.py @@ -14,7 +14,7 @@ from types import TracebackType from typing import ( Any, - AsyncGenerator, + AsyncIterator, Awaitable, Callable, Coroutine, @@ -125,7 +125,7 @@ class Will: class Client: - """The async context manager that manages the connection to the broker. + """Asynchronous context manager for the connection to the MQTT broker. Args: hostname: The hostname or IP address of the remote broker. @@ -320,10 +320,6 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 timeout = 10 self.timeout = timeout - @property - def messages(self) -> AsyncGenerator[Message, None]: - return self._messages() - @property def identifier(self) -> str: """Return the client identifier. @@ -333,6 +329,42 @@ def identifier(self) -> str: """ return self._client._client_id.decode() # noqa: SLF001 + class MessagesIterator: + """Dynamic view of the message queue.""" + + def __init__(self, client: Client) -> None: + self._client = client + + def __aiter__(self) -> AsyncIterator[Message]: + return self + + async def __anext__(self) -> Message: + # Wait until we either (1) receive a message or (2) disconnect + task = self._client._loop.create_task(self._client._queue.get()) # noqa: SLF001 + try: + done, _ = await asyncio.wait( + (task, self._client._disconnected), # noqa: SLF001 + return_when=asyncio.FIRST_COMPLETED, + ) + # If the asyncio.wait is cancelled, we must also cancel the queue task + except asyncio.CancelledError: + task.cancel() + raise + # When we receive a message, return it + if task in done: + return task.result() + # If we disconnect from the broker, stop the generator with an exception + task.cancel() + msg = "Disconnected during message iteration" + raise MqttError(msg) + + def __len__(self) -> int: + return self._client._queue.qsize() # noqa: SLF001 + + @property + def messages(self) -> MessagesIterator: + return self.MessagesIterator(self) + @property def _pending_calls(self) -> Generator[int, None, None]: """Yield all message IDs with pending calls.""" @@ -456,32 +488,6 @@ async def publish( # noqa: PLR0913 # Wait for confirmation await self._wait_for(confirmation.wait(), timeout=timeout) - async def _messages(self) -> AsyncGenerator[Message, None]: - """Async generator that yields messages from the underlying message queue.""" - while True: - # Wait until we either: - # 1. Receive a message - # 2. Disconnect from the broker - task = self._loop.create_task(self._queue.get()) - try: - done, _ = await asyncio.wait( - (task, self._disconnected), return_when=asyncio.FIRST_COMPLETED - ) - except asyncio.CancelledError: - # If the asyncio.wait is cancelled, we must make sure - # to also cancel the underlying tasks. - task.cancel() - raise - if task in done: - # We received a message. Return the result. - yield task.result() - else: - # We were disconnected from the broker - task.cancel() - # Stop the generator with an exception - msg = "Disconnected during message iteration" - raise MqttError(msg) - async def _wait_for( self, fut: Awaitable[T], timeout: float | None, **kwargs: Any ) -> T: diff --git a/tests/test_client.py b/tests/test_client.py index 6ffc877..4a54782 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,9 +1,11 @@ from __future__ import annotations +import asyncio import logging import pathlib import ssl import sys +from typing import Any import anyio import anyio.abc @@ -413,7 +415,7 @@ async def test_messages_view_is_reusable() -> None: @pytest.mark.network async def test_messages_view_multiple_tasks_concurrently() -> None: """Test that ``.messages`` can be used concurrently by multiple tasks.""" - topic = TOPIC_PREFIX + "test_messages_generator_is_reentrant" + topic = TOPIC_PREFIX + "test_messages_view_multiple_tasks_concurrently" async with Client(HOSTNAME) as client, anyio.create_task_group() as tg: async def handle() -> None: @@ -426,3 +428,33 @@ async def handle() -> None: await client.subscribe(topic) await client.publish(topic, "foo") await client.publish(topic, "bar") + + +@pytest.mark.network +async def test_messages_view_len() -> None: + """Test that the ``__len__`` method of the messages view works correctly.""" + topic = TOPIC_PREFIX + "test_messages_view_len" + count = 3 + + class TestClient(Client): + fut: asyncio.Future[None] = asyncio.Future() + + def _on_message( + self, client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage + ) -> None: + super()._on_message(client, userdata, message) + self.fut.set_result(None) + self.fut = asyncio.Future() + + async with TestClient(HOSTNAME) as client: + assert len(client.messages) == 0 + await client.subscribe(topic, qos=2) + # Publish a message and wait for it to arrive + for index in range(count): + await client.publish(topic, None, qos=2) + await asyncio.wait_for(client.fut, timeout=1) + assert len(client.messages) == index + 1 + # Empty the queue + for _ in range(count): + await client.messages.__anext__() + assert len(client.messages) == 0