From a19a5a450da5ca872336b08fa7c9b07b739cd9cf Mon Sep 17 00:00:00 2001 From: Ellis Percival Date: Wed, 24 Feb 2021 10:18:40 +0000 Subject: [PATCH 1/7] Add type stubs for client.py and error.py --- asyncio_mqtt/client.pyi | 67 +++++++++++++++++++++++++++++++++++++++++ asyncio_mqtt/error.pyi | 10 ++++++ 2 files changed, 77 insertions(+) create mode 100644 asyncio_mqtt/client.pyi create mode 100644 asyncio_mqtt/error.pyi diff --git a/asyncio_mqtt/client.pyi b/asyncio_mqtt/client.pyi new file mode 100644 index 0000000..50bb0ce --- /dev/null +++ b/asyncio_mqtt/client.pyi @@ -0,0 +1,67 @@ +import asyncio +import logging +import ssl +from types import TracebackType +from typing import Any, Generator, List, Optional, Type, Union + +from paho.mqtt import client as paho # type: ignore + +from .error import MqttCodeError as MqttCodeError +from .error import MqttConnectError as MqttConnectError +from .error import MqttError as MqttError + +MQTT_LOGGER: logging.Logger + +ProtocolType = Union[paho.MQTTv31, paho.MQTTv311, paho.MQTTv5] + +class Client: + def __init__( + self, + hostname: str, + port: int = ..., + *, + username: Optional[str] = ..., + password: Optional[str] = ..., + logger: Optional[logging.Logger] = ..., + client_id: Optional[str] = ..., + tls_context: Optional[ssl.SSLContext] = ..., + protocol: Optional[ProtocolType] = ..., + will: Optional[Will] = ..., + clean_session: Optional[bool] = ..., + transport: str = ... + ) -> None: ... + @property + def id(self) -> str: ... + async def connect(self, *, timeout: int = ...) -> None: ... + async def disconnect(self, *, timeout: int = ...) -> None: ... + async def force_disconnect(self) -> None: ... + async def subscribe( + self, *args: Any, timeout: int = ..., **kwargs: Any + ) -> List[int]: ... + async def unsubscribe(self, *args: Any, timeout: int = ...) -> None: ... + async def publish(self, *args: Any, timeout: int = ..., **kwargs: Any) -> None: ... + async def filtered_messages( + self, topic_filter: str, *, queue_maxsize: int = ... + ) -> Generator[asyncio.Queue[paho.MQTTMessage], None, None]: ... + async def unfiltered_messages( + self, *, queue_maxsize: int = ... + ) -> Generator[asyncio.Queue[paho.MQTTMessage], None, None]: ... + async def __aenter__(self) -> Client: ... + async def __aexit__( + self, exc_type: Type[Exception], exc: Exception, tb: TracebackType + ) -> None: ... + +class Will: + topic: str = ... + payload: Optional[Union[str, bytes, bytearray, int, float]] = ... + qos: int = ... + retain: bool = ... + properties: Optional[paho.Properties] = ... + def __init__( + self, + topic: str, + payload: Optional[Union[str, bytes, bytearray, int, float]] = ..., + qos: int = ..., + retain: bool = ..., + properties: Optional[paho.Properties] = ..., + ) -> None: ... diff --git a/asyncio_mqtt/error.pyi b/asyncio_mqtt/error.pyi new file mode 100644 index 0000000..6f7e8c8 --- /dev/null +++ b/asyncio_mqtt/error.pyi @@ -0,0 +1,10 @@ +from typing import Any + +class MqttError(Exception): ... + +class MqttCodeError(MqttError): + rc: int = ... + def __init__(self, rc: int, *args: Any) -> None: ... + +class MqttConnectError(MqttCodeError): + def __init__(self, rc: int) -> None: ... From e62cd1033f92fc511fc38fa8d026b91a001e22c1 Mon Sep 17 00:00:00 2001 From: Ellis Percival Date: Wed, 24 Feb 2021 10:25:23 +0000 Subject: [PATCH 2/7] Add py.typed to package. --- asyncio_mqtt/py.typed | 0 setup.py | 3 +++ 2 files changed, 3 insertions(+) create mode 100644 asyncio_mqtt/py.typed diff --git a/asyncio_mqtt/py.typed b/asyncio_mqtt/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/setup.py b/setup.py index 9743312..6dbfa5e 100644 --- a/setup.py +++ b/setup.py @@ -11,6 +11,9 @@ name="asyncio_mqtt", version=__version__, packages=find_packages(), + package_data={ + "asyncio_mqtt": ["py.typed"], + }, url="https://github.com/sbtinstruments/asyncio-mqtt", author="Frederik Aalund", author_email="fpa@sbtinstruments.com", From 0a9f0363b92768015aa2473921a60d58b18d0e21 Mon Sep 17 00:00:00 2001 From: Ellis Percival Date: Fri, 26 Feb 2021 15:54:18 +0000 Subject: [PATCH 3/7] Add inline type hints. --- asyncio_mqtt/client.py | 256 +++++++++++++++++++++++++++------------- asyncio_mqtt/client.pyi | 67 ----------- asyncio_mqtt/error.py | 11 +- asyncio_mqtt/error.pyi | 10 -- asyncio_mqtt/types.py | 8 ++ 5 files changed, 192 insertions(+), 160 deletions(-) delete mode 100644 asyncio_mqtt/client.pyi delete mode 100644 asyncio_mqtt/error.pyi create mode 100644 asyncio_mqtt/types.py diff --git a/asyncio_mqtt/client.py b/asyncio_mqtt/client.py index c5c7191..db0607e 100644 --- a/asyncio_mqtt/client.py +++ b/asyncio_mqtt/client.py @@ -2,49 +2,90 @@ import asyncio import logging import socket +import ssl from contextlib import contextmanager, suppress +from types import FunctionType, TracebackType +from typing import ( + Any, + AsyncGenerator, + AsyncIterator, + Awaitable, + Callable, + Dict, + Iterator, + List, + Optional, + Set, + Tuple, + Type, + cast, +) try: from contextlib import asynccontextmanager except ImportError: - from async_generator import asynccontextmanager -import paho.mqtt.client as mqtt -from .error import MqttError, MqttCodeError, MqttConnectError + from async_generator import asynccontextmanager # type: ignore +import paho.mqtt.client as mqtt # type: ignore + +from .error import MqttCodeError, MqttConnectError, MqttError +from .types import ( + PayloadType, + ProtocolType, + T, +) MQTT_LOGGER = logging.getLogger("mqtt") MQTT_LOGGER.setLevel(logging.WARNING) +# TODO: This should be a (frozen) dataclass (from Python 3.7) +# when we drop Python 3.6 support +class Will: + def __init__( + self, + topic: str, + payload: PayloadType = None, + qos: int = 0, + retain: bool = False, + properties: mqtt.Properties = None, + ): + self.topic = topic + self.payload = payload + self.qos = qos + self.retain = retain + self.properties = properties + + class Client: def __init__( self, - hostname, - port=1883, + hostname: str, + port: int = 1883, *, - username=None, - password=None, - logger=None, - client_id=None, - tls_context=None, - protocol=None, - will=None, - clean_session=None, - transport="tcp", + username: Optional[str] = None, + password: Optional[str] = None, + logger: logging.Logger = MQTT_LOGGER, + client_id: Optional[str] = None, + tls_context: Optional[ssl.SSLContext] = None, + protocol: ProtocolType = mqtt.MQTTv311, + will: Optional[Will] = None, + clean_session: Optional[bool] = None, + transport: str = "tcp", ): self._hostname = hostname self._port = port self._loop = asyncio.get_event_loop() - self._connected = asyncio.Future() - self._disconnected = asyncio.Future() - self._pending_calls = {} # Pending subscribe, unsubscribe, and publish calls - self._pending_calls_threshold = 10 - self._misc_task = None - - if protocol is None: - protocol = mqtt.MQTTv311 - - self._client = mqtt.Client( + self._connected: "asyncio.Future[int]" = asyncio.Future() + self._disconnected: "asyncio.Future[Optional[int]]" = asyncio.Future() + # Pending subscribe, unsubscribe, and publish calls + self._pending_subscribes: Dict[int, "asyncio.Future[int]"] = {} + self._pending_unsubscribes: Dict[int, asyncio.Event] = {} + self._pending_publishes: Dict[int, asyncio.Event] = {} + self._pending_calls_threshold: int = 10 + self._misc_task: Optional["asyncio.Task[None]"] = None + + self._client: mqtt.Client = mqtt.Client( client_id=client_id, protocol=protocol, clean_session=clean_session, @@ -62,8 +103,6 @@ def __init__( self._client.on_socket_register_write = self._on_socket_register_write self._client.on_socket_unregister_write = self._on_socket_unregister_write - if logger is None: - logger = MQTT_LOGGER self._client.enable_logger(logger) if username is not None and password is not None: @@ -78,16 +117,27 @@ def __init__( ) @property - def id(self): + def id(self) -> str: """Return the client ID. Note that paho-mqtt stores the client ID as `bytes` internally. We assume that the client ID is a UTF8-encoded string and decode it first. """ - return self._client._client_id.decode() + return cast(bytes, self._client._client_id).decode() - async def connect(self, *, timeout=10): + @property + def _pending_calls(self) -> Set[int]: + """ + Return a set of all message IDs with pending calls. + """ + mids: Set[int] = set() + mids.update(self._pending_subscribes.keys()) + mids.update(self._pending_unsubscribes.keys()) + mids.update(self._pending_publishes.keys()) + return mids + + async def connect(self, *, timeout: int = 10) -> None: try: loop = asyncio.get_running_loop() # [3] Run connect() within an executor thread, since it blocks on socket @@ -106,7 +156,7 @@ async def connect(self, *, timeout=10): raise MqttError(str(error)) await self._wait_for(self._connected, timeout=timeout) - async def disconnect(self, *, timeout=10): + async def disconnect(self, *, timeout: int = 10) -> None: rc = self._client.disconnect() # Early out on error if rc != mqtt.MQTT_ERR_SUCCESS: @@ -114,32 +164,32 @@ async def disconnect(self, *, timeout=10): # Wait for acknowledgement await self._wait_for(self._disconnected, timeout=timeout) - async def force_disconnect(self): + async def force_disconnect(self) -> None: self._disconnected.set_result(None) - async def subscribe(self, *args, timeout=10, **kwargs): + async def subscribe(self, *args: Any, timeout: int = 10, **kwargs: Any) -> int: result, mid = self._client.subscribe(*args, **kwargs) # Early out on error if result != mqtt.MQTT_ERR_SUCCESS: raise MqttCodeError(result, "Could not subscribe to topic") # Create future for when the on_subscribe callback is called - cb_result = asyncio.Future() - with self._pending_call(mid, cb_result): + cb_result: "asyncio.Future[int]" = asyncio.Future() + with self._pending_call(mid, cb_result, self._pending_subscribes): # Wait for cb_result return await self._wait_for(cb_result, timeout=timeout) - async def unsubscribe(self, *args, timeout=10): + async def unsubscribe(self, *args: Any, timeout: int = 10) -> None: result, mid = self._client.unsubscribe(*args) # Early out on error if result != mqtt.MQTT_ERR_SUCCESS: raise MqttCodeError(result, "Could not unsubscribe from topic") # Create event for when the on_unsubscribe callback is called confirmation = asyncio.Event() - with self._pending_call(mid, confirmation): + with self._pending_call(mid, confirmation, self._pending_unsubscribes): # Wait for confirmation await self._wait_for(confirmation.wait(), timeout=timeout) - async def publish(self, *args, timeout=10, **kwargs): + async def publish(self, *args: Any, timeout: int = 10, **kwargs: Any) -> None: info = self._client.publish(*args, **kwargs) # [2] # Early out on error if info.rc != mqtt.MQTT_ERR_SUCCESS: @@ -149,12 +199,14 @@ async def publish(self, *args, timeout=10, **kwargs): return # Create event for when the on_publish callback is called confirmation = asyncio.Event() - with self._pending_call(info.mid, confirmation): + with self._pending_call(info.mid, confirmation, self._pending_publishes): # Wait for confirmation await self._wait_for(confirmation.wait(), timeout=timeout) @asynccontextmanager - async def filtered_messages(self, topic_filter, *, queue_maxsize=0): + async def filtered_messages( + self, topic_filter: str, *, queue_maxsize: int = 0 + ) -> AsyncIterator[AsyncGenerator[mqtt.MQTTMessage, None]]: """Return async generator of messages that match the given filter. Use queue_maxsize to restrict the queue size. If the queue is full, @@ -178,7 +230,9 @@ async def filtered_messages(self, topic_filter, *, queue_maxsize=0): self._client.message_callback_remove(topic_filter) @asynccontextmanager - async def unfiltered_messages(self, *, queue_maxsize=0): + async def unfiltered_messages( + self, *, queue_maxsize: int = 0 + ) -> AsyncIterator[AsyncGenerator[mqtt.MQTTMessage, None]]: """Return async generator of all messages that are not caught in filters.""" # Early out if self._client.on_message is not None: @@ -197,11 +251,20 @@ async def unfiltered_messages(self, *, queue_maxsize=0): # We are exitting the with statement. Unset the callback. self._client.on_message = None - def _cb_and_generator(self, *, log_context, queue_maxsize=0): + def _cb_and_generator( + self, *, log_context: str, queue_maxsize: int = 0 + ) -> Tuple[ + Callable[[mqtt.Client, Any, mqtt.MQTTMessage], None], + AsyncGenerator[mqtt.MQTTMessage, None], + ]: # Queue to hold the incoming messages - messages = asyncio.Queue(maxsize=queue_maxsize) + messages: "asyncio.Queue[mqtt.MQTTMessage]" = asyncio.Queue( + maxsize=queue_maxsize + ) # Callback for the underlying API - def _put_in_queue(client, userdata, msg): + def _put_in_queue( + client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage + ) -> None: try: messages.put_nowait(msg) except asyncio.QueueFull: @@ -210,13 +273,15 @@ def _put_in_queue(client, userdata, msg): ) # The generator that we give to the caller - async def _message_generator(): + async def _message_generator() -> AsyncGenerator[mqtt.MQTTMessage, None]: # Forward all messages from the queue while True: # Wait until we either: # 1. Receive a message # 2. Disconnect from the broker - get = self._loop.create_task(messages.get()) + get: "asyncio.Task[mqtt.MQTTMessage]" = self._loop.create_task( + messages.get() + ) try: done, _ = await asyncio.wait( (get, self._disconnected), return_when=asyncio.FIRST_COMPLETED @@ -237,19 +302,23 @@ async def _message_generator(): return _put_in_queue, _message_generator() - async def _wait_for(self, *args, **kwargs): + async def _wait_for( + self, fut: Awaitable[T], timeout: Optional[float], **kwargs: Any + ) -> T: try: - return await asyncio.wait_for(*args, **kwargs) + return await asyncio.wait_for(fut, timeout=timeout, **kwargs) except asyncio.TimeoutError: raise MqttError("Operation timed out") @contextmanager - def _pending_call(self, mid, value): + def _pending_call( + self, mid: int, value: T, pending_dict: Dict[int, T] + ) -> Iterator[None]: if mid in self._pending_calls: raise RuntimeError( f'There already exists a pending call for message ID "{mid}"' ) - self._pending_calls[mid] = value # [1] + pending_dict[mid] = value # [1] try: # Log a warning if there is a concerning number of pending calls pending = len(self._pending_calls) @@ -264,9 +333,19 @@ def _pending_call(self, mid, value): # # However, if the callback doesn't get called (e.g., due to a # network error) we still need to remove the item from the dict. - self._pending_calls.pop(mid, None) + try: + del pending_dict[mid] + except KeyError: + pass - def _on_connect(self, client, userdata, flags, rc, properties=None): + def _on_connect( + self, + client: mqtt.Client, + userdata: Any, + flags: Dict[str, int], + rc: int, + properties: Optional[mqtt.Properties] = None, + ) -> None: # Return early if already connected. Sometimes, paho-mqtt calls _on_connect # multiple times. Maybe because we receive multiple CONNACK messages # from the server. In any case, we return early so that we don't set @@ -279,7 +358,13 @@ def _on_connect(self, client, userdata, flags, rc, properties=None): else: self._connected.set_exception(MqttConnectError(rc)) - def _on_disconnect(self, client, userdata, rc, properties=None): + def _on_disconnect( + self, + client: mqtt.Client, + userdata: Any, + rc: int, + properties: Optional[mqtt.Properties] = None, + ) -> None: # Return early if the disconnect is already acknowledged. # Sometimes (e.g., due to timeouts), paho-mqtt calls _on_disconnect # twice. We return early to avoid setting self._disconnected twice @@ -302,67 +387,91 @@ def _on_disconnect(self, client, userdata, rc, properties=None): else: self._disconnected.set_exception(MqttCodeError(rc, "Unexpected disconnect")) - def _on_subscribe(self, client, userdata, mid, granted_qos, properties=None): + def _on_subscribe( + self, + client: mqtt.Client, + userdata: Any, + mid: int, + granted_qos: int, + properties: mqtt.Properties = None, + ) -> None: try: - self._pending_calls.pop(mid).set_result(granted_qos) + self._pending_subscribes.pop(mid).set_result(granted_qos) except KeyError: MQTT_LOGGER.error(f'Unexpected message ID "{mid}" in on_subscribe callback') - def _on_unsubscribe(self, client, userdata, mid, properties=None, reasonCodes=None): + def _on_unsubscribe( + self, + client: mqtt.Client, + userdata: Any, + mid: int, + properties: Optional[mqtt.Properties] = None, + reasonCodes: Optional[List[mqtt.ReasonCodes]] = None, + ) -> None: try: - self._pending_calls.pop(mid).set() + self._pending_unsubscribes.pop(mid).set() except KeyError: MQTT_LOGGER.error( f'Unexpected message ID "{mid}" in on_unsubscribe callback' ) - def _on_publish(self, client, userdata, mid): + def _on_publish(self, client: mqtt.Client, userdata: Any, mid: int) -> None: try: - self._pending_calls.pop(mid).set() + self._pending_publishes.pop(mid).set() except KeyError: # Do nothing since [2] may call on_publish before it even returns. # That is, the message may already be published before we even get a # chance to set up the 'pending_call' logic. pass - def _on_socket_open(self, client, userdata, sock): - def cb(): + def _on_socket_open( + self, client: mqtt.Client, userdata: Any, sock: socket.socket + ) -> None: + def cb() -> None: client.loop_read() self._loop.add_reader(sock.fileno(), cb) # paho-mqtt calls this function from the executor thread on which we've called # `self._client.connect()` (see [3]), so we create a callback function to schedule # `_misc_loop()` and run it on the loop thread-safely. - def create_task_cb(): + def create_task_cb() -> None: self._misc_task = self._loop.create_task(self._misc_loop()) self._loop.call_soon_threadsafe(create_task_cb) - def _on_socket_close(self, client, userdata, sock): + def _on_socket_close( + self, client: mqtt.Client, userdata: Any, sock: socket.socket + ) -> None: self._loop.remove_reader(sock.fileno()) if self._misc_task is not None: with suppress(asyncio.CancelledError): self._misc_task.cancel() - def _on_socket_register_write(self, client, userdata, sock): - def cb(): + def _on_socket_register_write( + self, client: mqtt.Client, userdata: Any, sock: socket.socket + ) -> None: + def cb() -> None: client.loop_write() self._loop.add_writer(sock, cb) - def _on_socket_unregister_write(self, client, userdata, sock): + def _on_socket_unregister_write( + self, client: mqtt.Client, userdata: Any, sock: socket.socket + ) -> None: self._loop.remove_writer(sock) - async def _misc_loop(self): + async def _misc_loop(self) -> None: while self._client.loop_misc() == mqtt.MQTT_ERR_SUCCESS: await asyncio.sleep(1) - async def __aenter__(self): + async def __aenter__(self) -> "Client": """Connect to the broker.""" await self.connect() return self - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__( + self, exc_type: Type[Exception], exc: Exception, tb: TracebackType + ) -> None: """Disconnect from the broker.""" # Early out if already disconnected... if self._disconnected.done(): @@ -381,14 +490,3 @@ async def __aexit__(self, exc_type, exc, tb): f'Could not gracefully disconnect due to "{error}". Forcing disconnection.' ) await self.force_disconnect() - - -# TODO: This should be a (frozen) dataclass (from Python 3.7) -# when we drop Python 3.6 support -class Will: - def __init__(self, topic, payload=None, qos=0, retain=False, properties=None): - self.topic = topic - self.payload = payload - self.qos = qos - self.retain = retain - self.properties = properties diff --git a/asyncio_mqtt/client.pyi b/asyncio_mqtt/client.pyi deleted file mode 100644 index 50bb0ce..0000000 --- a/asyncio_mqtt/client.pyi +++ /dev/null @@ -1,67 +0,0 @@ -import asyncio -import logging -import ssl -from types import TracebackType -from typing import Any, Generator, List, Optional, Type, Union - -from paho.mqtt import client as paho # type: ignore - -from .error import MqttCodeError as MqttCodeError -from .error import MqttConnectError as MqttConnectError -from .error import MqttError as MqttError - -MQTT_LOGGER: logging.Logger - -ProtocolType = Union[paho.MQTTv31, paho.MQTTv311, paho.MQTTv5] - -class Client: - def __init__( - self, - hostname: str, - port: int = ..., - *, - username: Optional[str] = ..., - password: Optional[str] = ..., - logger: Optional[logging.Logger] = ..., - client_id: Optional[str] = ..., - tls_context: Optional[ssl.SSLContext] = ..., - protocol: Optional[ProtocolType] = ..., - will: Optional[Will] = ..., - clean_session: Optional[bool] = ..., - transport: str = ... - ) -> None: ... - @property - def id(self) -> str: ... - async def connect(self, *, timeout: int = ...) -> None: ... - async def disconnect(self, *, timeout: int = ...) -> None: ... - async def force_disconnect(self) -> None: ... - async def subscribe( - self, *args: Any, timeout: int = ..., **kwargs: Any - ) -> List[int]: ... - async def unsubscribe(self, *args: Any, timeout: int = ...) -> None: ... - async def publish(self, *args: Any, timeout: int = ..., **kwargs: Any) -> None: ... - async def filtered_messages( - self, topic_filter: str, *, queue_maxsize: int = ... - ) -> Generator[asyncio.Queue[paho.MQTTMessage], None, None]: ... - async def unfiltered_messages( - self, *, queue_maxsize: int = ... - ) -> Generator[asyncio.Queue[paho.MQTTMessage], None, None]: ... - async def __aenter__(self) -> Client: ... - async def __aexit__( - self, exc_type: Type[Exception], exc: Exception, tb: TracebackType - ) -> None: ... - -class Will: - topic: str = ... - payload: Optional[Union[str, bytes, bytearray, int, float]] = ... - qos: int = ... - retain: bool = ... - properties: Optional[paho.Properties] = ... - def __init__( - self, - topic: str, - payload: Optional[Union[str, bytes, bytearray, int, float]] = ..., - qos: int = ..., - retain: bool = ..., - properties: Optional[paho.Properties] = ..., - ) -> None: ... diff --git a/asyncio_mqtt/error.py b/asyncio_mqtt/error.py index 8c70415..72992b3 100644 --- a/asyncio_mqtt/error.py +++ b/asyncio_mqtt/error.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: BSD-3-Clause +from typing import Any, Dict + + class MqttError(Exception): """Base exception for all asyncio-mqtt exceptions.""" @@ -8,16 +11,16 @@ class MqttError(Exception): class MqttCodeError(MqttError): - def __init__(self, rc, *args): + def __init__(self, rc: int, *args: Any): super().__init__(*args) self.rc = rc - def __str__(self): + def __str__(self) -> str: return f"[code:{self.rc}] {super().__str__()}" class MqttConnectError(MqttCodeError): - def __init__(self, rc): + def __init__(self, rc: int): msg = "Connection refused" try: msg += f": {_CONNECT_RC_STRINGS[rc]}" @@ -26,7 +29,7 @@ def __init__(self, rc): super().__init__(rc, msg) -_CONNECT_RC_STRINGS = { +_CONNECT_RC_STRINGS: Dict[int, str] = { # Reference: https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L1898 # 0: Connection successful # 1: Connection refused - incorrect protocol version diff --git a/asyncio_mqtt/error.pyi b/asyncio_mqtt/error.pyi deleted file mode 100644 index 6f7e8c8..0000000 --- a/asyncio_mqtt/error.pyi +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Any - -class MqttError(Exception): ... - -class MqttCodeError(MqttError): - rc: int = ... - def __init__(self, rc: int, *args: Any) -> None: ... - -class MqttConnectError(MqttCodeError): - def __init__(self, rc: int) -> None: ... diff --git a/asyncio_mqtt/types.py b/asyncio_mqtt/types.py new file mode 100644 index 0000000..8dbdf29 --- /dev/null +++ b/asyncio_mqtt/types.py @@ -0,0 +1,8 @@ +import asyncio +from typing import Any, Awaitable, Generator, Optional, TypeVar, Union +from paho.mqtt import client as paho # type: ignore + +T = TypeVar("T") + +ProtocolType = Union[paho.MQTTv31, paho.MQTTv311, paho.MQTTv5] +PayloadType = Optional[Union[str, bytes, bytearray, int, float]] From 6756d62b6743e841af091effd3b07308cf2cf284 Mon Sep 17 00:00:00 2001 From: Ellis Percival Date: Fri, 26 Feb 2021 16:57:18 +0000 Subject: [PATCH 4/7] < 500 is still good! --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d0bf2ed..b8b9483 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ asyncio-mqtt combines the stability of the time-proven [paho-mqtt](https://githu * Compatible with `async` code * Did we mention no more callbacks? -The whole thing is less than [400 lines of code](https://github.com/sbtinstruments/asyncio-mqtt/blob/master/asyncio_mqtt/client.py). +The whole thing is less than [500 lines of code](https://github.com/sbtinstruments/asyncio-mqtt/blob/master/asyncio_mqtt/client.py). ## Installation 📚 From f73bd651477d5095e90d40db9e54ab5283ddfeef Mon Sep 17 00:00:00 2001 From: Ellis Percival Date: Fri, 26 Feb 2021 16:58:36 +0000 Subject: [PATCH 5/7] Add type-hinting to list of features --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index b8b9483..07509cd 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ asyncio-mqtt combines the stability of the time-proven [paho-mqtt](https://githu * No more return codes (welcome to the `MqttError`) * Graceful disconnection (forget about `on_unsubscribe`, `on_disconnect`, etc.) * Compatible with `async` code +* Fully type-hinted * Did we mention no more callbacks? The whole thing is less than [500 lines of code](https://github.com/sbtinstruments/asyncio-mqtt/blob/master/asyncio_mqtt/client.py). From d3815b24f246dbdee7ba1c9ac9d21325ec6f0fbe Mon Sep 17 00:00:00 2001 From: Ellis Percival Date: Fri, 26 Feb 2021 17:26:37 +0000 Subject: [PATCH 6/7] Remove unused import --- asyncio_mqtt/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio_mqtt/client.py b/asyncio_mqtt/client.py index db0607e..85ad86f 100644 --- a/asyncio_mqtt/client.py +++ b/asyncio_mqtt/client.py @@ -4,7 +4,7 @@ import socket import ssl from contextlib import contextmanager, suppress -from types import FunctionType, TracebackType +from types import TracebackType from typing import ( Any, AsyncGenerator, From 5231cf6293428b30a4c51de9932ebc1738b08337 Mon Sep 17 00:00:00 2001 From: Ellis Percival Date: Fri, 26 Feb 2021 19:37:16 +0000 Subject: [PATCH 7/7] Fix issues and suggestions from review --- asyncio_mqtt/client.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/asyncio_mqtt/client.py b/asyncio_mqtt/client.py index 85ad86f..68093cb 100644 --- a/asyncio_mqtt/client.py +++ b/asyncio_mqtt/client.py @@ -12,6 +12,7 @@ Awaitable, Callable, Dict, + Generator, Iterator, List, Optional, @@ -45,10 +46,10 @@ class Will: def __init__( self, topic: str, - payload: PayloadType = None, + payload: Optional[PayloadType] = None, qos: int = 0, retain: bool = False, - properties: mqtt.Properties = None, + properties: Optional[mqtt.Properties] = None, ): self.topic = topic self.payload = payload @@ -65,10 +66,10 @@ def __init__( *, username: Optional[str] = None, password: Optional[str] = None, - logger: logging.Logger = MQTT_LOGGER, + logger: Optional[logging.Logger] = None, client_id: Optional[str] = None, tls_context: Optional[ssl.SSLContext] = None, - protocol: ProtocolType = mqtt.MQTTv311, + protocol: Optional[ProtocolType] = None, will: Optional[Will] = None, clean_session: Optional[bool] = None, transport: str = "tcp", @@ -85,6 +86,9 @@ def __init__( self._pending_calls_threshold: int = 10 self._misc_task: Optional["asyncio.Task[None]"] = None + if protocol is None: + protocol = mqtt.MQTTv311 + self._client: mqtt.Client = mqtt.Client( client_id=client_id, protocol=protocol, @@ -103,6 +107,8 @@ def __init__( self._client.on_socket_register_write = self._on_socket_register_write self._client.on_socket_unregister_write = self._on_socket_unregister_write + if logger is None: + logger = MQTT_LOGGER self._client.enable_logger(logger) if username is not None and password is not None: @@ -127,15 +133,13 @@ def id(self) -> str: return cast(bytes, self._client._client_id).decode() @property - def _pending_calls(self) -> Set[int]: + def _pending_calls(self) -> Generator[int, None, None]: """ - Return a set of all message IDs with pending calls. + Yield all message IDs with pending calls. """ - mids: Set[int] = set() - mids.update(self._pending_subscribes.keys()) - mids.update(self._pending_unsubscribes.keys()) - mids.update(self._pending_publishes.keys()) - return mids + yield from self._pending_subscribes.keys() + yield from self._pending_unsubscribes.keys() + yield from self._pending_publishes.keys() async def connect(self, *, timeout: int = 10) -> None: try: @@ -321,7 +325,7 @@ def _pending_call( pending_dict[mid] = value # [1] try: # Log a warning if there is a concerning number of pending calls - pending = len(self._pending_calls) + pending = len(list(self._pending_calls)) if pending > self._pending_calls_threshold: MQTT_LOGGER.warning(f"There are {pending} pending publish calls.") # Back to the caller (run whatever is inside the with statement) @@ -393,7 +397,7 @@ def _on_subscribe( userdata: Any, mid: int, granted_qos: int, - properties: mqtt.Properties = None, + properties: Optional[mqtt.Properties] = None, ) -> None: try: self._pending_subscribes.pop(mid).set_result(granted_qos)