Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace messages generator with iterator class that implements len() #323

Merged
merged 2 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 38 additions & 32 deletions aiomqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from types import TracebackType
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Expand Down Expand Up @@ -125,7 +125,7 @@ class Will:


class Client:
"""The async context manager that manages the connection to the broker.
"""Asynchronous context manager for the connection to the MQTT broker.

Args:
hostname: The hostname or IP address of the remote broker.
Expand Down Expand Up @@ -320,10 +320,6 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
timeout = 10
self.timeout = timeout

@property
def messages(self) -> AsyncGenerator[Message, None]:
return self._messages()

@property
def identifier(self) -> str:
"""Return the client identifier.
Expand All @@ -333,6 +329,42 @@ def identifier(self) -> str:
"""
return self._client._client_id.decode() # noqa: SLF001

class MessagesIterator:
"""Dynamic view of the message queue."""

def __init__(self, client: Client) -> None:
self._client = client

def __aiter__(self) -> AsyncIterator[Message]:
return self

async def __anext__(self) -> Message:
# Wait until we either (1) receive a message or (2) disconnect
task = self._client._loop.create_task(self._client._queue.get()) # noqa: SLF001
try:
done, _ = await asyncio.wait(
(task, self._client._disconnected), # noqa: SLF001
return_when=asyncio.FIRST_COMPLETED,
)
# If the asyncio.wait is cancelled, we must also cancel the queue task
except asyncio.CancelledError:
task.cancel()
raise
# When we receive a message, return it
if task in done:
return task.result()
# If we disconnect from the broker, stop the generator with an exception
task.cancel()
msg = "Disconnected during message iteration"
raise MqttError(msg)

def __len__(self) -> int:
return self._client._queue.qsize() # noqa: SLF001

@property
def messages(self) -> MessagesIterator:
return self.MessagesIterator(self)

@property
def _pending_calls(self) -> Generator[int, None, None]:
"""Yield all message IDs with pending calls."""
Expand Down Expand Up @@ -456,32 +488,6 @@ async def publish( # noqa: PLR0913
# Wait for confirmation
await self._wait_for(confirmation.wait(), timeout=timeout)

async def _messages(self) -> AsyncGenerator[Message, None]:
"""Async generator that yields messages from the underlying message queue."""
while True:
# Wait until we either:
# 1. Receive a message
# 2. Disconnect from the broker
task = self._loop.create_task(self._queue.get())
try:
done, _ = await asyncio.wait(
(task, self._disconnected), return_when=asyncio.FIRST_COMPLETED
)
except asyncio.CancelledError:
# If the asyncio.wait is cancelled, we must make sure
# to also cancel the underlying tasks.
task.cancel()
raise
if task in done:
# We received a message. Return the result.
yield task.result()
else:
# We were disconnected from the broker
task.cancel()
# Stop the generator with an exception
msg = "Disconnected during message iteration"
raise MqttError(msg)

async def _wait_for(
self, fut: Awaitable[T], timeout: float | None, **kwargs: Any
) -> T:
Expand Down
34 changes: 33 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import asyncio
import logging
import pathlib
import ssl
import sys
from typing import Any

import anyio
import anyio.abc
Expand Down Expand Up @@ -413,7 +415,7 @@ async def test_messages_view_is_reusable() -> None:
@pytest.mark.network
async def test_messages_view_multiple_tasks_concurrently() -> None:
"""Test that ``.messages`` can be used concurrently by multiple tasks."""
topic = TOPIC_PREFIX + "test_messages_generator_is_reentrant"
topic = TOPIC_PREFIX + "test_messages_view_multiple_tasks_concurrently"
async with Client(HOSTNAME) as client, anyio.create_task_group() as tg:

async def handle() -> None:
Expand All @@ -426,3 +428,33 @@ async def handle() -> None:
await client.subscribe(topic)
await client.publish(topic, "foo")
await client.publish(topic, "bar")


@pytest.mark.network
async def test_messages_view_len() -> None:
"""Test that the ``__len__`` method of the messages view works correctly."""
topic = TOPIC_PREFIX + "test_messages_view_len"
count = 3

class TestClient(Client):
fut: asyncio.Future[None] = asyncio.Future()

def _on_message(
self, client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage
) -> None:
super()._on_message(client, userdata, message)
self.fut.set_result(None)
self.fut = asyncio.Future()

async with TestClient(HOSTNAME) as client:
assert len(client.messages) == 0
await client.subscribe(topic, qos=2)
# Publish a message and wait for it to arrive
for index in range(count):
await client.publish(topic, None, qos=2)
await asyncio.wait_for(client.fut, timeout=1)
assert len(client.messages) == index + 1
# Empty the queue
for _ in range(count):
await client.messages.__anext__()
assert len(client.messages) == 0
Loading