Skip to content

Commit

Permalink
refactor: initial migration to paho-mqtt 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
JonathanPlasse committed Mar 18, 2024
1 parent f7697de commit 1862ead
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 105 deletions.
44 changes: 25 additions & 19 deletions aiomqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,16 @@
Generator,
Iterable,
Iterator,
Literal,
TypeVar,
cast,
)

import paho.mqtt.client as mqtt
from paho.mqtt.enums import CallbackAPIVersion
from paho.mqtt.properties import Properties
from paho.mqtt.reasoncodes import ReasonCodes
from paho.mqtt.subscribeoptions import SubscribeOptions

from .exceptions import MqttCodeError, MqttConnectError, MqttError, MqttReentrantError
from .message import Message
Expand Down Expand Up @@ -116,7 +121,7 @@ class Will:
payload: PayloadType | None = None
qos: int = 0
retain: bool = False
properties: mqtt.Properties | None = None
properties: Properties | None = None


class Client:
Expand Down Expand Up @@ -185,7 +190,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
protocol: ProtocolVersion | None = None,
will: Will | None = None,
clean_session: bool | None = None,
transport: str = "tcp",
transport: Literal["tcp", "websockets"] = "tcp",
timeout: float | None = None,
keepalive: int = 60,
bind_address: str = "",
Expand All @@ -195,7 +200,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
max_queued_outgoing_messages: int | None = None,
max_inflight_messages: int | None = None,
max_concurrent_outgoing_calls: int | None = None,
properties: mqtt.Properties | None = None,
properties: Properties | None = None,
tls_context: ssl.SSLContext | None = None,
tls_params: TLSParameters | None = None,
tls_insecure: bool | None = None,
Expand All @@ -220,7 +225,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915

# Pending subscribe, unsubscribe, and publish calls
self._pending_subscribes: dict[
int, asyncio.Future[tuple[int] | list[mqtt.ReasonCodes]]
int, asyncio.Future[tuple[int, ...] | list[ReasonCodes]]
] = {}
self._pending_unsubscribes: dict[int, asyncio.Event] = {}
self._pending_publishes: dict[int, asyncio.Event] = {}
Expand All @@ -247,6 +252,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915

# Create the underlying paho-mqtt client instance
self._client: mqtt.Client = mqtt.Client(
callback_api_version=CallbackAPIVersion.VERSION1,
client_id=identifier,
protocol=protocol,
clean_session=clean_session,
Expand Down Expand Up @@ -322,7 +328,7 @@ def identifier(self) -> str:
Note that paho-mqtt stores the client ID as `bytes` internally. We assume that
the client ID is a UTF8-encoded string and decode it first.
"""
return cast(bytes, self._client._client_id).decode() # type: ignore[attr-defined] # noqa: SLF001
return self._client._client_id.decode() # noqa: SLF001

@property
def _pending_calls(self) -> Generator[int, None, None]:
Expand All @@ -337,12 +343,12 @@ async def subscribe( # noqa: PLR0913
/,
topic: SubscribeTopic,
qos: int = 0,
options: mqtt.SubscribeOptions | None = None,
properties: mqtt.Properties | None = None,
options: SubscribeOptions | None = None,
properties: Properties | None = None,
*args: Any,
timeout: float | None = None,
**kwargs: Any,
) -> tuple[int] | list[mqtt.ReasonCodes]:
) -> tuple[int] | list[ReasonCodes]:
"""Subscribe to a topic or wildcard.
Args:
Expand All @@ -366,7 +372,7 @@ async def subscribe( # noqa: PLR0913
raise MqttCodeError(result, "Could not subscribe to topic")
# Create future for when the on_subscribe callback is called
callback_result: asyncio.Future[
tuple[int] | list[mqtt.ReasonCodes]
tuple[int] | list[ReasonCodes]
] = asyncio.Future()
with self._pending_call(mid, callback_result, self._pending_subscribes):
# Wait for callback_result
Expand All @@ -377,7 +383,7 @@ async def unsubscribe(
self,
/,
topic: str | list[str],
properties: mqtt.Properties | None = None,
properties: Properties | None = None,
*args: Any,
timeout: float | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -412,7 +418,7 @@ async def publish( # noqa: PLR0913
payload: PayloadType = None,
qos: int = 0,
retain: bool = False,
properties: mqtt.Properties | None = None,
properties: Properties | None = None,
*args: Any,
timeout: float | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -518,8 +524,8 @@ def _on_connect( # noqa: PLR0913
client: mqtt.Client,
userdata: Any,
flags: dict[str, int],
rc: int | mqtt.ReasonCodes,
properties: mqtt.Properties | None = None,
rc: int | ReasonCodes,
properties: Properties | None = None,
) -> None:
"""Called when we receive a CONNACK message from the broker."""
# Return early if already connected. Sometimes, paho-mqtt calls _on_connect
Expand All @@ -538,8 +544,8 @@ def _on_disconnect(
self,
client: mqtt.Client,
userdata: Any,
rc: int | mqtt.ReasonCodes | None,
properties: mqtt.Properties | None = None,
rc: int | ReasonCodes | None,
properties: Properties | None = None,
) -> None:
# Return early if the disconnect is already acknowledged.
# Sometimes (e.g., due to timeouts), paho-mqtt calls _on_disconnect
Expand Down Expand Up @@ -570,8 +576,8 @@ def _on_subscribe( # noqa: PLR0913
client: mqtt.Client,
userdata: Any,
mid: int,
granted_qos: tuple[int] | list[mqtt.ReasonCodes],
properties: mqtt.Properties | None = None,
granted_qos: tuple[int, ...] | list[ReasonCodes],
properties: Properties | None = None,
) -> None:
"""Called when we receive a SUBACK message from the broker."""
try:
Expand All @@ -588,8 +594,8 @@ def _on_unsubscribe( # noqa: PLR0913
client: mqtt.Client,
userdata: Any,
mid: int,
properties: mqtt.Properties | None = None,
reason_codes: list[mqtt.ReasonCodes] | mqtt.ReasonCodes | None = None,
properties: Properties | None = None,
reason_codes: list[ReasonCodes] | ReasonCodes | None = None,
) -> None:
"""Called when we receive an UNSUBACK message from the broker."""
try:
Expand Down
9 changes: 5 additions & 4 deletions aiomqtt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,29 @@
from typing import Any

import paho.mqtt.client as mqtt
from paho.mqtt.reasoncodes import ReasonCodes


class MqttError(Exception):
pass


class MqttCodeError(MqttError):
def __init__(self, rc: int | mqtt.ReasonCodes | None, *args: Any) -> None:
def __init__(self, rc: int | ReasonCodes | None, *args: Any) -> None:
super().__init__(*args)
self.rc = rc

def __str__(self) -> str:
if isinstance(self.rc, mqtt.ReasonCodes):
if isinstance(self.rc, ReasonCodes):
return f"[code:{self.rc.value}] {self.rc!s}"
if isinstance(self.rc, int):
return f"[code:{self.rc}] {mqtt.error_string(self.rc)}"
return f"[code:{self.rc}] {super().__str__()}"


class MqttConnectError(MqttCodeError):
def __init__(self, rc: int | mqtt.ReasonCodes) -> None:
if isinstance(rc, mqtt.ReasonCodes):
def __init__(self, rc: int | ReasonCodes) -> None:
if isinstance(rc, ReasonCodes):
super().__init__(rc)
return
msg = "Connection refused"
Expand Down
3 changes: 2 additions & 1 deletion aiomqtt/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys

import paho.mqtt.client as mqtt
from paho.mqtt.properties import Properties

if sys.version_info >= (3, 11):
from typing import Self
Expand Down Expand Up @@ -50,7 +51,7 @@ def __init__( # noqa: PLR0913
qos: int,
retain: bool,
mid: int,
properties: mqtt.Properties | None,
properties: Properties | None,
) -> None:
self.topic = Topic(topic) if not isinstance(topic, Topic) else topic
self.payload = payload
Expand Down
6 changes: 3 additions & 3 deletions aiomqtt/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sys
from typing import Any, Callable, TypeVar

import paho.mqtt.client as mqtt
from paho.mqtt.subscribeoptions import SubscribeOptions

if sys.version_info >= (3, 10):
from typing import ParamSpec, TypeAlias
Expand All @@ -18,10 +18,10 @@
P = ParamSpec("P")

PayloadType: TypeAlias = "str | bytes | bytearray | int | float | None"
SubscribeTopic: TypeAlias = "str | tuple[str, mqtt.SubscribeOptions] | list[tuple[str, mqtt.SubscribeOptions]] | list[tuple[str, int]]"
SubscribeTopic: TypeAlias = "str | tuple[str, SubscribeOptions] | list[tuple[str, SubscribeOptions]] | list[tuple[str, int]]"
WebSocketHeaders: TypeAlias = (
"dict[str, str] | Callable[[dict[str, str]], dict[str, str]]"
)
_PahoSocket: TypeAlias = "socket.socket | ssl.SSLSocket | mqtt.WebsocketWrapper | Any"
_PahoSocket: TypeAlias = "socket.socket | ssl.SSLSocket | Any"
# See the overloads of `socket.setsockopt` for details.
SocketOption: TypeAlias = "tuple[int, int, int | bytes] | tuple[int, int, None, int]"
Loading

0 comments on commit 1862ead

Please sign in to comment.