diff --git a/AWSIoTPythonSDK/core/protocol/paho/client.py b/AWSIoTPythonSDK/core/protocol/paho/client.py index 1c60c81..acaa067 100755 --- a/AWSIoTPythonSDK/core/protocol/paho/client.py +++ b/AWSIoTPythonSDK/core/protocol/paho/client.py @@ -49,11 +49,6 @@ from AWSIoTPythonSDK.core.protocol.connection.cores import SecuredWebSocketCore from AWSIoTPythonSDK.core.protocol.connection.alpn import SSLContextBuilder -VERSION_MAJOR=1 -VERSION_MINOR=0 -VERSION_REVISION=0 -VERSION_NUMBER=(VERSION_MAJOR*1000000+VERSION_MINOR*1000+VERSION_REVISION) - MQTTv31 = 3 MQTTv311 = 4 @@ -497,6 +492,7 @@ def __init__(self, client_id="", clean_session=True, userdata=None, protocol=MQT self._msgtime_mutex = threading.Lock() self._out_message_mutex = threading.Lock() self._in_message_mutex = threading.Lock() + self._mid_generate_mutex = threading.Lock() self._thread = None self._thread_terminate = False self._ssl = None @@ -515,7 +511,8 @@ def __init__(self, client_id="", clean_session=True, userdata=None, protocol=MQT self._alpn_protocols = None def __del__(self): - pass + # Closes socket in client destructor to avoid FD leak. + self._reset_sockets() def setBackoffTiming(self, srcBaseReconnectTimeSecond, srcMaximumReconnectTimeSecond, srcMinimumConnectTimeSecond): @@ -547,7 +544,8 @@ def config_alpn_protocols(self, alpn_protocols): """ self._alpn_protocols = alpn_protocols - def reinitialise(self, client_id="", clean_session=True, userdata=None): + # Closes socket in client destructor to avoid FD leak. + def _reset_sockets(self): if self._ssl: self._ssl.close() self._ssl = None @@ -562,6 +560,9 @@ def reinitialise(self, client_id="", clean_session=True, userdata=None): self._sockpairW.close() self._sockpairW = None + # Closes socket in client destructor to avoid FD leak. + def reinitialise(self, client_id="", clean_session=True, userdata=None): + self._reset_sockets() self.__init__(client_id, clean_session, userdata) def tls_set(self, ca_certs, certfile=None, keyfile=None, cert_reqs=cert_reqs, tls_version=tls_version, ciphers=None): @@ -831,24 +832,14 @@ def reconnect(self): verify_hostname = False # Since check_hostname in SSLContext is already set to True, no need to verify it again self._ssl.do_handshake() else: - if force_ssl_context: - ssl_context = ssl.SSLContext(self._tls_version) - ssl_context.load_cert_chain(self._tls_certfile, self._tls_keyfile) - ssl_context.load_verify_locations(self._tls_ca_certs) - ssl_context.verify_mode = self._tls_cert_reqs - if self._tls_ciphers is not None: - ssl_context.set_ciphers(self._tls_ciphers) - - self._ssl = ssl_context.wrap_socket(sock) - else: - self._ssl = ssl.wrap_socket( - sock, - certfile=self._tls_certfile, - keyfile=self._tls_keyfile, - ca_certs=self._tls_ca_certs, - cert_reqs=self._tls_cert_reqs, - ssl_version=self._tls_version, - ciphers=self._tls_ciphers) + # ssl.wrap_socket is deprecated in Python 3.7+. Use SSLContext instead. + ssl_context = ssl.SSLContext(self._tls_version) + ssl_context.load_cert_chain(self._tls_certfile, self._tls_keyfile) + ssl_context.load_verify_locations(self._tls_ca_certs) + ssl_context.verify_mode = self._tls_cert_reqs + if self._tls_ciphers is not None: + ssl_context.set_ciphers(self._tls_ciphers) + self._ssl = ssl_context.wrap_socket(sock) if verify_hostname: if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and sys.version_info[1] < 5): # No IP host match before 3.5.x @@ -924,6 +915,9 @@ def loop(self, timeout=1.0, max_packets=1): # Can occur if we just reconnected but rlist/wlist contain a -1 for # some reason. return MQTT_ERR_CONN_LOST + except KeyboardInterrupt: + # Allow ^C to interrupt + raise except: return MQTT_ERR_UNKNOWN @@ -981,8 +975,9 @@ def publish(self, topic, payload=None, qos=0, retain=False): raise ValueError('Invalid QoS level.') if isinstance(payload, str) or isinstance(payload, bytearray): local_payload = payload - elif sys.version_info[0] < 3 and isinstance(payload, unicode): - local_payload = payload + # Client.publish() now accepts bytes() payloads on Python 3. + elif sys.version_info[0] == 3 and isinstance(payload, bytes): + local_payload = bytearray(payload) elif isinstance(payload, int) or isinstance(payload, float): local_payload = str(payload) elif payload is None: @@ -1047,10 +1042,15 @@ def username_pw_set(self, username, password=None): Requires a broker that supports MQTT v3.1. username: The username to authenticate with. Need have no relationship to the client id. + [MQTT-3.1.3-11]. + Set to None to reset client back to not using username/password for broker authentication. password: The password to authenticate with. Optional, set to None if not required. """ - self._username = username.encode('utf-8') + # [MQTT-3.1.3-11] User name must be UTF-8 encoded string + self._username = None if username is None else username.encode('utf-8') self._password = password + if isinstance(self._password, str): + self._password = self._password.encode('utf-8') def socket_factory_set(self, socket_factory): """Set a socket factory to custom configure a different socket type for @@ -1117,7 +1117,7 @@ def subscribe(self, topic, qos=0): zero string length, or if topic is not a string, tuple or list. """ topic_qos_list = None - if isinstance(topic, str): + if isinstance(topic, str) : if qos<0 or qos>2: raise ValueError('Invalid QoS level.') if topic is None or len(topic) == 0: @@ -1165,7 +1165,7 @@ def unsubscribe(self, topic): topic_list = None if topic is None: raise ValueError('Invalid topic.') - if isinstance(topic, str): + if isinstance(topic, str) : if len(topic) == 0: raise ValueError('Invalid topic.') topic_list = [topic.encode('utf-8')] @@ -1453,8 +1453,10 @@ def loop_stop(self, force=False): return MQTT_ERR_INVAL self._thread_terminate = True - self._thread.join() - self._thread = None + # Don't attempt to join() own thread. + if threading.current_thread() != self._thread: + self._thread.join() + self._thread = None def message_callback_add(self, sub, callback): """Register a message callback for a specific topic. @@ -1704,6 +1706,10 @@ def _easy_log(self, level, buf): self.on_log(self, self._userdata, level, buf) def _check_keepalive(self): + # Fix for keepalive=0 causing an infinite disconnect/reconnect loop. + if self._keepalive == 0: + return MQTT_ERR_SUCCESS + now = time.time() self._msgtime_mutex.acquire() last_msg_out = self._last_msg_out @@ -1736,10 +1742,12 @@ def _check_keepalive(self): self._callback_mutex.release() def _mid_generate(self): - self._last_mid = self._last_mid + 1 - if self._last_mid == 65536: - self._last_mid = 1 - return self._last_mid + # Make sure mid generation that was thread-safe. + with self._mid_generate_mutex: + self._last_mid += 1 + if self._last_mid == 65536: + self._last_mid = 1 + return self._last_mid def _topic_wildcard_len_check(self, topic): # Search for + or # in a topic. Return MQTT_ERR_INVAL if found. @@ -1903,11 +1911,11 @@ def _send_connect(self, keepalive, clean_session): connect_flags = connect_flags | 0x04 | ((self._will_qos&0x03) << 3) | ((self._will_retain&0x01) << 5) if self._username: - remaining_length = remaining_length + 2+len(self._username) + remaining_length += 2+len(self._username) connect_flags = connect_flags | 0x80 if self._password: connect_flags = connect_flags | 0x40 - remaining_length = remaining_length + 2+len(self._password) + remaining_length += 2+len(self._password) command = CONNECT packet = bytearray()