|
26 | 26 | * Adafruit CircuitPython firmware for the supported boards: |
27 | 27 | https://github.com/adafruit/circuitpython/releases |
28 | 28 |
|
| 29 | +* Adafruit's Connection Manager library: |
| 30 | + https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager |
| 31 | +
|
29 | 32 | """ |
30 | 33 | import errno |
31 | 34 | import struct |
32 | 35 | import time |
33 | 36 | from random import randint |
34 | 37 |
|
| 38 | +from adafruit_connection_manager import get_connection_manager |
| 39 | + |
35 | 40 | try: |
36 | 41 | from typing import List, Optional, Tuple, Type, Union |
37 | 42 | except ImportError: |
|
82 | 87 | class MMQTTException(Exception): |
83 | 88 | """MiniMQTT Exception class.""" |
84 | 89 |
|
85 | | - # pylint: disable=unnecessary-pass |
86 | | - # pass |
87 | | - |
88 | | - |
89 | | -class TemporaryError(Exception): |
90 | | - """Temporary error class used for handling reconnects.""" |
91 | | - |
92 | | - |
93 | | -# Legacy ESP32SPI Socket API |
94 | | -def set_socket(sock, iface=None) -> None: |
95 | | - """Legacy API for setting the socket and network interface. |
96 | | -
|
97 | | - :param sock: socket object. |
98 | | - :param iface: internet interface object |
99 | | -
|
100 | | - """ |
101 | | - global _default_sock # pylint: disable=invalid-name, global-statement |
102 | | - global _fake_context # pylint: disable=invalid-name, global-statement |
103 | | - _default_sock = sock |
104 | | - if iface: |
105 | | - _default_sock.set_interface(iface) |
106 | | - _fake_context = _FakeSSLContext(iface) |
107 | | - |
108 | | - |
109 | | -class _FakeSSLSocket: |
110 | | - def __init__(self, socket, tls_mode) -> None: |
111 | | - self._socket = socket |
112 | | - self._mode = tls_mode |
113 | | - self.settimeout = socket.settimeout |
114 | | - self.send = socket.send |
115 | | - self.recv = socket.recv |
116 | | - self.close = socket.close |
117 | | - |
118 | | - def connect(self, address): |
119 | | - """connect wrapper to add non-standard mode parameter""" |
120 | | - try: |
121 | | - return self._socket.connect(address, self._mode) |
122 | | - except RuntimeError as error: |
123 | | - raise OSError(errno.ENOMEM) from error |
124 | | - |
125 | | - |
126 | | -class _FakeSSLContext: |
127 | | - def __init__(self, iface) -> None: |
128 | | - self._iface = iface |
129 | | - |
130 | | - def wrap_socket(self, socket, server_hostname=None) -> _FakeSSLSocket: |
131 | | - """Return the same socket""" |
132 | | - # pylint: disable=unused-argument |
133 | | - return _FakeSSLSocket(socket, self._iface.TLS_MODE) |
134 | | - |
135 | 90 |
|
136 | 91 | class NullLogger: |
137 | 92 | """Fake logger class that does not do anything""" |
138 | 93 |
|
139 | 94 | # pylint: disable=unused-argument |
140 | 95 | def nothing(self, msg: str, *args) -> None: |
141 | 96 | """no action""" |
142 | | - pass |
143 | 97 |
|
144 | 98 | def __init__(self) -> None: |
145 | 99 | for log_level in ["debug", "info", "warning", "error", "critical"]: |
@@ -194,6 +148,7 @@ def __init__( |
194 | 148 | user_data=None, |
195 | 149 | use_imprecise_time: Optional[bool] = None, |
196 | 150 | ) -> None: |
| 151 | + self._connection_manager = get_connection_manager(socket_pool) |
197 | 152 | self._socket_pool = socket_pool |
198 | 153 | self._ssl_context = ssl_context |
199 | 154 | self._sock = None |
@@ -300,75 +255,6 @@ def get_monotonic_time(self) -> float: |
300 | 255 |
|
301 | 256 | return time.monotonic() |
302 | 257 |
|
303 | | - # pylint: disable=too-many-branches |
304 | | - def _get_connect_socket(self, host: str, port: int, *, timeout: int = 1): |
305 | | - """Obtains a new socket and connects to a broker. |
306 | | -
|
307 | | - :param str host: Desired broker hostname |
308 | | - :param int port: Desired broker port |
309 | | - :param int timeout: Desired socket timeout, in seconds |
310 | | - """ |
311 | | - # For reconnections - check if we're using a socket already and close it |
312 | | - if self._sock: |
313 | | - self._sock.close() |
314 | | - self._sock = None |
315 | | - |
316 | | - # Legacy API - use the interface's socket instead of a passed socket pool |
317 | | - if self._socket_pool is None: |
318 | | - self._socket_pool = _default_sock |
319 | | - |
320 | | - # Legacy API - fake the ssl context |
321 | | - if self._ssl_context is None: |
322 | | - self._ssl_context = _fake_context |
323 | | - |
324 | | - if not isinstance(port, int): |
325 | | - raise RuntimeError("Port must be an integer") |
326 | | - |
327 | | - if self._is_ssl and not self._ssl_context: |
328 | | - raise RuntimeError( |
329 | | - "ssl_context must be set before using adafruit_mqtt for secure MQTT." |
330 | | - ) |
331 | | - |
332 | | - if self._is_ssl: |
333 | | - self.logger.info(f"Establishing a SECURE SSL connection to {host}:{port}") |
334 | | - else: |
335 | | - self.logger.info(f"Establishing an INSECURE connection to {host}:{port}") |
336 | | - |
337 | | - addr_info = self._socket_pool.getaddrinfo( |
338 | | - host, port, 0, self._socket_pool.SOCK_STREAM |
339 | | - )[0] |
340 | | - |
341 | | - try: |
342 | | - sock = self._socket_pool.socket(addr_info[0], addr_info[1]) |
343 | | - except OSError as exc: |
344 | | - # Do not consider this for back-off. |
345 | | - self.logger.warning( |
346 | | - f"Failed to create socket for host {addr_info[0]} and port {addr_info[1]}" |
347 | | - ) |
348 | | - raise TemporaryError from exc |
349 | | - |
350 | | - connect_host = addr_info[-1][0] |
351 | | - if self._is_ssl: |
352 | | - sock = self._ssl_context.wrap_socket(sock, server_hostname=host) |
353 | | - connect_host = host |
354 | | - sock.settimeout(timeout) |
355 | | - |
356 | | - try: |
357 | | - sock.connect((connect_host, port)) |
358 | | - except MemoryError as exc: |
359 | | - sock.close() |
360 | | - self.logger.warning(f"Failed to allocate memory for connect: {exc}") |
361 | | - # Do not consider this for back-off. |
362 | | - raise TemporaryError from exc |
363 | | - except OSError as exc: |
364 | | - sock.close() |
365 | | - self.logger.warning(f"Failed to connect: {exc}") |
366 | | - # Do not consider this for back-off. |
367 | | - raise TemporaryError from exc |
368 | | - |
369 | | - self._backwards_compatible_sock = not hasattr(sock, "recv_into") |
370 | | - return sock |
371 | | - |
372 | 258 | def __enter__(self): |
373 | 259 | return self |
374 | 260 |
|
@@ -538,8 +424,8 @@ def connect( |
538 | 424 | ) |
539 | 425 | self._reset_reconnect_backoff() |
540 | 426 | return ret |
541 | | - except TemporaryError as e: |
542 | | - self.logger.warning(f"temporary error when connecting: {e}") |
| 427 | + except RuntimeError as e: |
| 428 | + self.logger.warning(f"Socket error when connecting: {e}") |
543 | 429 | backoff = False |
544 | 430 | except MMQTTException as e: |
545 | 431 | last_exception = e |
@@ -587,9 +473,15 @@ def _connect( |
587 | 473 | time.sleep(self._reconnect_timeout) |
588 | 474 |
|
589 | 475 | # Get a new socket |
590 | | - self._sock = self._get_connect_socket( |
591 | | - self.broker, self.port, timeout=self._socket_timeout |
| 476 | + self._sock = self._connection_manager.get_socket( |
| 477 | + self.broker, |
| 478 | + self.port, |
| 479 | + proto="mqtt:", |
| 480 | + timeout=self._socket_timeout, |
| 481 | + is_ssl=self._is_ssl, |
| 482 | + ssl_context=self._ssl_context, |
592 | 483 | ) |
| 484 | + self._backwards_compatible_sock = not hasattr(self._sock, "recv_into") |
593 | 485 |
|
594 | 486 | fixed_header = bytearray([0x10]) |
595 | 487 |
|
@@ -686,7 +578,7 @@ def disconnect(self) -> None: |
686 | 578 | except RuntimeError as e: |
687 | 579 | self.logger.warning(f"Unable to send DISCONNECT packet: {e}") |
688 | 580 | self.logger.debug("Closing socket") |
689 | | - self._sock.close() |
| 581 | + self._connection_manager.free_socket(self._sock) |
690 | 582 | self._is_connected = False |
691 | 583 | self._subscribed_topics = [] |
692 | 584 | self._last_msg_sent_timestamp = 0 |
|
0 commit comments