diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..a214dbb --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,45 @@ +name: test suite + +on: + push: + branches: [master] + pull_request: + +jobs: + test: + strategy: + matrix: + os: [ubuntu-latest] + python-version: + ["3.7", "3.8", "3.9", "3.10", "3.11", pypy-3.8, pypy-3.9] + include: + - os: macos-latest + python-version: "3.7" + - os: macos-latest + python-version: "3.11" + - os: windows-latest + python-version: "3.7" + - os: windows-latest + python-version: "3.11" + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: pip-test-${{ matrix.python-version }}-${{ matrix.os }} + - name: Install dependencies + run: pip install .[tests] + - name: Start Mosquitto MQTT Broker + uses: Namoshek/mosquitto-github-action@v1 + - name: Test with pytest + run: pytest --cov=asyncio_mqtt --cov-report=xml + - name: Upload coverage + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + name: ${{ matrix.os }} Python ${{ matrix.python-version }} diff --git a/.gitignore b/.gitignore index b5ee0d5..4f2a1f6 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ dist local_test.py .venv .idea/ +.coverage diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 499cad9..505bdb5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,6 +25,8 @@ repos: hooks: - id: mypy additional_dependencies: + - anyio + - pytest - types-paho-mqtt == 1.6.0.1 - repo: https://github.com/pre-commit/mirrors-prettier diff --git a/README.md b/README.md index 9a067e3..87ba5a7 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,8 @@ License: BSD-3-Clause PyPI version PyPI downloads + Coverage + Coverage pre-commit.ci status Typing: strict Code Style: Black diff --git a/asyncio_mqtt/client.py b/asyncio_mqtt/client.py index ca7c140..c2b0b58 100644 --- a/asyncio_mqtt/client.py +++ b/asyncio_mqtt/client.py @@ -29,14 +29,14 @@ cast, ) -if sys.version_info >= (3, 10): +if sys.version_info >= (3, 10): # pragma: no cover from typing import ParamSpec -else: +else: # pragma: no cover from typing_extensions import ParamSpec if sys.version_info >= (3, 7): from contextlib import asynccontextmanager -else: +else: # pragma: no cover from async_generator import asynccontextmanager as _asynccontextmanager _P = ParamSpec("_P") @@ -415,7 +415,7 @@ async def publish( @asynccontextmanager async def filtered_messages( self, topic_filter: str, *, queue_maxsize: int = 0 - ) -> AsyncIterator[AsyncGenerator[mqtt.MQTTMessage, None]]: + ) -> AsyncGenerator[AsyncGenerator[mqtt.MQTTMessage, None], None]: """Return async generator of messages that match the given filter. Use queue_maxsize to restrict the queue size. If the queue is full, @@ -441,7 +441,7 @@ async def filtered_messages( @asynccontextmanager async def unfiltered_messages( self, *, queue_maxsize: int = 0 - ) -> AsyncIterator[AsyncGenerator[mqtt.MQTTMessage, None]]: + ) -> AsyncGenerator[AsyncGenerator[mqtt.MQTTMessage, None], None]: """Return async generator of all messages that are not caught in filters.""" # Early out if self._client.on_message is not None: diff --git a/setup.py b/setup.py index 176d37d..7bd009b 100644 --- a/setup.py +++ b/setup.py @@ -42,5 +42,6 @@ extras_require={ "lint": ["mypy>=0.982", "flake8>=5.0.4", "types-paho-mqtt>=1.6.0.1"], "format": ["black>=22.10.0", "isort>=5.10.1"], + "tests": ["pytest>=7.2.0", "pytest-cov>=4.0.0", "anyio>=3.6.2"], }, ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5c53fe0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.fixture +def anyio_backend() -> str: + return "asyncio" diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..829ae16 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,109 @@ +import anyio +import anyio.abc +import pytest + +from asyncio_mqtt import Client +from asyncio_mqtt.client import ProtocolVersion, Will + +pytestmark = pytest.mark.anyio + + +async def test_client_filtered_messages() -> None: + topic_header = "tests/asyncio_mqtt/filtered_messages/" + good_topic = topic_header + "good" + bad_topic = topic_header + "bad" + + async def handle_messages(tg: anyio.abc.TaskGroup) -> None: + async with client.filtered_messages(good_topic) as messages: + async for message in messages: + assert message.topic == good_topic + tg.cancel_scope.cancel() + + async with Client("localhost") as client: + async with anyio.create_task_group() as tg: + await client.subscribe(topic_header + "#") + tg.start_soon(handle_messages, tg) + await client.publish(bad_topic) + await client.publish(good_topic) + + +async def test_client_unfiltered_messages() -> None: + topic_header = "tests/asyncio_mqtt/unfiltered_messages/" + topic_filtered = topic_header + "filtered" + topic_unfiltered = topic_header + "unfiltered" + + async def handle_unfiltered_messages(tg: anyio.abc.TaskGroup) -> None: + async with client.unfiltered_messages() as messages: + async for message in messages: + assert message.topic == topic_unfiltered + tg.cancel_scope.cancel() + + async def handle_filtered_messages() -> None: + async with client.filtered_messages(topic_filtered) as messages: + async for message in messages: + assert message.topic == topic_filtered + + async with Client("localhost") as client: + async with anyio.create_task_group() as tg: + await client.subscribe(topic_header + "#") + tg.start_soon(handle_filtered_messages) + tg.start_soon(handle_unfiltered_messages, tg) + await client.publish(topic_filtered) + await client.publish(topic_unfiltered) + + +async def test_client_unsubscribe() -> None: + topic_header = "tests/asyncio_mqtt/unsubscribe/" + topic1 = topic_header + "1" + topic2 = topic_header + "2" + + async def handle_messages(tg: anyio.abc.TaskGroup) -> None: + async with client.unfiltered_messages() as messages: + i = 0 + async for message in messages: + if i == 0: + assert message.topic == topic1 + elif i == 1: + assert message.topic == topic2 + tg.cancel_scope.cancel() + i += 1 + + async with Client("localhost") as client: + async with anyio.create_task_group() as tg: + await client.subscribe(topic1) + await client.subscribe(topic2) + tg.start_soon(handle_messages, tg) + await client.publish(topic1) + await client.unsubscribe(topic1) + await client.publish(topic1) + await client.publish(topic2) + + +@pytest.mark.parametrize( + "protocol, length", + ((ProtocolVersion.V31, 22), (ProtocolVersion.V311, 0), (ProtocolVersion.V5, 0)), +) +async def test_client_id(protocol: ProtocolVersion, length: int) -> None: + client = Client("localhost", protocol=protocol) + assert len(client.id) == length + + +async def test_client_will() -> None: + topic = "tests/asyncio_mqtt/will" + event = anyio.Event() + + async def launch_client() -> None: + with anyio.CancelScope(shield=True) as cs: + async with Client("localhost") as client: + await client.subscribe(topic) + event.set() + async with client.filtered_messages(topic) as messages: + async for message in messages: + assert message.topic == topic + cs.cancel() + + async with anyio.create_task_group() as tg: + tg.start_soon(launch_client) + await event.wait() + async with Client("localhost", will=Will(topic)) as client: + client._client._sock_close() # type: ignore[attr-defined]