diff --git a/aiomqtt/client.py b/aiomqtt/client.py index 8d64de4..b647a3b 100644 --- a/aiomqtt/client.py +++ b/aiomqtt/client.py @@ -129,6 +129,8 @@ class Client: password: The password to authenticate with. logger: Custom logger instance. identifier: The client identifier. Generated automatically if ``None``. + reconnect: If ``True``, the client will automatically reconnect to the broker + if the connection is lost. Defaults to ``False``. queue_type: The class to use for the queue. The default is ``asyncio.Queue``, which stores messages in FIFO order. For LIFO order, you can use ``asyncio.LifoQueue``; For priority order you can subclass @@ -181,6 +183,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 password: str | None = None, logger: logging.Logger | None = None, identifier: str | None = None, + reconnect: bool = False, queue_type: type[asyncio.Queue[Message]] | None = None, protocol: ProtocolVersion | None = None, will: Will | None = None, @@ -206,6 +209,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 ) -> None: self._hostname = hostname self._port = port + self._reconnect = reconnect self._keepalive = keepalive self._bind_address = bind_address self._bind_port = bind_port @@ -225,7 +229,10 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 self._pending_unsubscribes: dict[int, asyncio.Event] = {} self._pending_publishes: dict[int, asyncio.Event] = {} self.pending_calls_threshold: int = 10 + + # Background tasks self._misc_task: asyncio.Task[None] | None = None + self._reconnection_task: asyncio.Task[None] | None = None # Queue that holds incoming messages if queue_type is None: @@ -432,9 +439,17 @@ async def publish( # noqa: PLR0913 **kwargs: Additional keyword arguments to pass to paho-mqtt's publish method. """ - info = self._client.publish( - topic, payload, qos, retain, properties, *args, **kwargs - ) # [2] + while True: + info = self._client.publish( + topic, payload, qos, retain, properties, *args, **kwargs + ) # [2] + if not (info.rc == mqtt.MQTT_ERR_NO_CONN and self._reconnect): + break + while True: + with contextlib.suppress(asyncio.CancelledError): + await self._connected + break + self._connected = asyncio.Future() # Early out on error if info.rc != mqtt.MQTT_ERR_SUCCESS: raise MqttCodeError(info.rc, "Could not publish message") @@ -677,43 +692,65 @@ async def _misc_loop(self) -> None: while self._client.loop_misc() == mqtt.MQTT_ERR_SUCCESS: await asyncio.sleep(1) + async def _connect(self) -> None: + """Connect to the broker. Retry indefinitely if self._reconnect is True.""" + while True: + try: + try: + loop = asyncio.get_running_loop() + # [3] Run connect() within an executor thread, since it blocks on socket + # connection for up to `keepalive` seconds: https://git.io/Jt5Yc + await loop.run_in_executor( + None, + self._client.connect, + self._hostname, + self._port, + self._keepalive, + self._bind_address, + self._bind_port, + self._clean_start, + self._properties, + ) + _set_client_socket_defaults(self._client.socket(), self._socket_options) + # Convert all possible paho-mqtt Client.connect exceptions to our MqttError + # See: https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L1770 + except (OSError, mqtt.WebsocketConnectionError) as exc: + raise MqttError(str(exc)) from None + await self._wait_for(self._connected, timeout=None) + self._logger.info("Successfully connected to the broker.") + break + except MqttError: + # Reset internal state if the connection attempt failed + if self._connected.done(): + self._connected = asyncio.Future() + if self._disconnected.done(): + self._disconnected = asyncio.Future() + if not self._reconnect: + self._lock.release() + raise + self._logger.warning("Failed to connect. Trying again in 2 seconds...") + await asyncio.sleep(2) + + async def _reconnection(self) -> None: + """Reconnect when the connection is lost.""" + while True: + with contextlib.suppress(MqttError): + await self._disconnected + self._logger.warning("Connection lost. Reconnecting...") + self._connected = asyncio.Future() + self._disconnected = asyncio.Future() + await self._connect() + async def __aenter__(self) -> Self: """Connect to the broker.""" if self._lock.locked(): msg = "The client context manager is reusable, but not reentrant" raise MqttReentrantError(msg) await self._lock.acquire() - try: - loop = asyncio.get_running_loop() - # [3] Run connect() within an executor thread, since it blocks on socket - # connection for up to `keepalive` seconds: https://git.io/Jt5Yc - await loop.run_in_executor( - None, - self._client.connect, - self._hostname, - self._port, - self._keepalive, - self._bind_address, - self._bind_port, - self._clean_start, - self._properties, - ) - _set_client_socket_defaults(self._client.socket(), self._socket_options) - # Convert all possible paho-mqtt Client.connect exceptions to our MqttError - # See: https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L1770 - except (OSError, mqtt.WebsocketConnectionError) as exc: - self._lock.release() - raise MqttError(str(exc)) from None - try: - await self._wait_for(self._connected, timeout=None) - except MqttError: - # Reset state if connection attempt times out or CONNACK returns negative - self._lock.release() - self._connected = asyncio.Future() - raise - # Reset `_disconnected` if it's already in completed state after connecting - if self._disconnected.done(): - self._disconnected = asyncio.Future() + await self._connect() + # Start the reconnection task + if self._reconnect: + self._reconnection_task = asyncio.create_task(self._reconnection()) return self async def __aexit__( @@ -723,8 +760,10 @@ async def __aexit__( tb: TracebackType | None, ) -> None: """Disconnect from the broker.""" + if self._reconnect: + self._reconnection_task.cancel() + # Return early if the client is already disconnected if self._disconnected.done(): - # Return early if the client is already disconnected if self._lock.locked(): self._lock.release() if (exc := self._disconnected.exception()) is not None: