diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000..a214dbb
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,45 @@
+name: test suite
+
+on:
+ push:
+ branches: [master]
+ pull_request:
+
+jobs:
+ test:
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ python-version:
+ ["3.7", "3.8", "3.9", "3.10", "3.11", pypy-3.8, pypy-3.9]
+ include:
+ - os: macos-latest
+ python-version: "3.7"
+ - os: macos-latest
+ python-version: "3.11"
+ - os: windows-latest
+ python-version: "3.7"
+ - os: windows-latest
+ python-version: "3.11"
+ runs-on: ${{ matrix.os }}
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/cache@v3
+ with:
+ path: ~/.cache/pip
+ key: pip-test-${{ matrix.python-version }}-${{ matrix.os }}
+ - name: Install dependencies
+ run: pip install .[tests]
+ - name: Start Mosquitto MQTT Broker
+ uses: Namoshek/mosquitto-github-action@v1
+ - name: Test with pytest
+ run: pytest --cov=asyncio_mqtt --cov-report=xml
+ - name: Upload coverage
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+ name: ${{ matrix.os }} Python ${{ matrix.python-version }}
diff --git a/.gitignore b/.gitignore
index b5ee0d5..4f2a1f6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,4 @@ dist
local_test.py
.venv
.idea/
+.coverage
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 499cad9..505bdb5 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -25,6 +25,8 @@ repos:
hooks:
- id: mypy
additional_dependencies:
+ - anyio
+ - pytest
- types-paho-mqtt == 1.6.0.1
- repo: https://github.com/pre-commit/mirrors-prettier
diff --git a/README.md b/README.md
index 9a067e3..87ba5a7 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,8 @@
+
+
diff --git a/asyncio_mqtt/client.py b/asyncio_mqtt/client.py
index ca7c140..c2b0b58 100644
--- a/asyncio_mqtt/client.py
+++ b/asyncio_mqtt/client.py
@@ -29,14 +29,14 @@
cast,
)
-if sys.version_info >= (3, 10):
+if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
-else:
+else: # pragma: no cover
from typing_extensions import ParamSpec
if sys.version_info >= (3, 7):
from contextlib import asynccontextmanager
-else:
+else: # pragma: no cover
from async_generator import asynccontextmanager as _asynccontextmanager
_P = ParamSpec("_P")
@@ -415,7 +415,7 @@ async def publish(
@asynccontextmanager
async def filtered_messages(
self, topic_filter: str, *, queue_maxsize: int = 0
- ) -> AsyncIterator[AsyncGenerator[mqtt.MQTTMessage, None]]:
+ ) -> AsyncGenerator[AsyncGenerator[mqtt.MQTTMessage, None], None]:
"""Return async generator of messages that match the given filter.
Use queue_maxsize to restrict the queue size. If the queue is full,
@@ -441,7 +441,7 @@ async def filtered_messages(
@asynccontextmanager
async def unfiltered_messages(
self, *, queue_maxsize: int = 0
- ) -> AsyncIterator[AsyncGenerator[mqtt.MQTTMessage, None]]:
+ ) -> AsyncGenerator[AsyncGenerator[mqtt.MQTTMessage, None], None]:
"""Return async generator of all messages that are not caught in filters."""
# Early out
if self._client.on_message is not None:
diff --git a/setup.py b/setup.py
index 176d37d..7bd009b 100644
--- a/setup.py
+++ b/setup.py
@@ -42,5 +42,6 @@
extras_require={
"lint": ["mypy>=0.982", "flake8>=5.0.4", "types-paho-mqtt>=1.6.0.1"],
"format": ["black>=22.10.0", "isort>=5.10.1"],
+ "tests": ["pytest>=7.2.0", "pytest-cov>=4.0.0", "anyio>=3.6.2"],
},
)
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..5c53fe0
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,6 @@
+import pytest
+
+
+@pytest.fixture
+def anyio_backend() -> str:
+ return "asyncio"
diff --git a/tests/test_client.py b/tests/test_client.py
new file mode 100644
index 0000000..829ae16
--- /dev/null
+++ b/tests/test_client.py
@@ -0,0 +1,109 @@
+import anyio
+import anyio.abc
+import pytest
+
+from asyncio_mqtt import Client
+from asyncio_mqtt.client import ProtocolVersion, Will
+
+pytestmark = pytest.mark.anyio
+
+
+async def test_client_filtered_messages() -> None:
+ topic_header = "tests/asyncio_mqtt/filtered_messages/"
+ good_topic = topic_header + "good"
+ bad_topic = topic_header + "bad"
+
+ async def handle_messages(tg: anyio.abc.TaskGroup) -> None:
+ async with client.filtered_messages(good_topic) as messages:
+ async for message in messages:
+ assert message.topic == good_topic
+ tg.cancel_scope.cancel()
+
+ async with Client("localhost") as client:
+ async with anyio.create_task_group() as tg:
+ await client.subscribe(topic_header + "#")
+ tg.start_soon(handle_messages, tg)
+ await client.publish(bad_topic)
+ await client.publish(good_topic)
+
+
+async def test_client_unfiltered_messages() -> None:
+ topic_header = "tests/asyncio_mqtt/unfiltered_messages/"
+ topic_filtered = topic_header + "filtered"
+ topic_unfiltered = topic_header + "unfiltered"
+
+ async def handle_unfiltered_messages(tg: anyio.abc.TaskGroup) -> None:
+ async with client.unfiltered_messages() as messages:
+ async for message in messages:
+ assert message.topic == topic_unfiltered
+ tg.cancel_scope.cancel()
+
+ async def handle_filtered_messages() -> None:
+ async with client.filtered_messages(topic_filtered) as messages:
+ async for message in messages:
+ assert message.topic == topic_filtered
+
+ async with Client("localhost") as client:
+ async with anyio.create_task_group() as tg:
+ await client.subscribe(topic_header + "#")
+ tg.start_soon(handle_filtered_messages)
+ tg.start_soon(handle_unfiltered_messages, tg)
+ await client.publish(topic_filtered)
+ await client.publish(topic_unfiltered)
+
+
+async def test_client_unsubscribe() -> None:
+ topic_header = "tests/asyncio_mqtt/unsubscribe/"
+ topic1 = topic_header + "1"
+ topic2 = topic_header + "2"
+
+ async def handle_messages(tg: anyio.abc.TaskGroup) -> None:
+ async with client.unfiltered_messages() as messages:
+ i = 0
+ async for message in messages:
+ if i == 0:
+ assert message.topic == topic1
+ elif i == 1:
+ assert message.topic == topic2
+ tg.cancel_scope.cancel()
+ i += 1
+
+ async with Client("localhost") as client:
+ async with anyio.create_task_group() as tg:
+ await client.subscribe(topic1)
+ await client.subscribe(topic2)
+ tg.start_soon(handle_messages, tg)
+ await client.publish(topic1)
+ await client.unsubscribe(topic1)
+ await client.publish(topic1)
+ await client.publish(topic2)
+
+
+@pytest.mark.parametrize(
+ "protocol, length",
+ ((ProtocolVersion.V31, 22), (ProtocolVersion.V311, 0), (ProtocolVersion.V5, 0)),
+)
+async def test_client_id(protocol: ProtocolVersion, length: int) -> None:
+ client = Client("localhost", protocol=protocol)
+ assert len(client.id) == length
+
+
+async def test_client_will() -> None:
+ topic = "tests/asyncio_mqtt/will"
+ event = anyio.Event()
+
+ async def launch_client() -> None:
+ with anyio.CancelScope(shield=True) as cs:
+ async with Client("localhost") as client:
+ await client.subscribe(topic)
+ event.set()
+ async with client.filtered_messages(topic) as messages:
+ async for message in messages:
+ assert message.topic == topic
+ cs.cancel()
+
+ async with anyio.create_task_group() as tg:
+ tg.start_soon(launch_client)
+ await event.wait()
+ async with Client("localhost", will=Will(topic)) as client:
+ client._client._sock_close() # type: ignore[attr-defined]