Skip to content

Commit

Permalink
Fix #96: Process incoming topic_alias in rigth way
Browse files Browse the repository at this point in the history
  • Loading branch information
mitu committed Jul 31, 2020
1 parent da0624f commit 67deb47
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 13 deletions.
2 changes: 1 addition & 1 deletion gmqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"Mikhail Turchunovich",
"Elena Nikolaichik"
]
__version__ = "0.6.6"
__version__ = "0.6.7"


__all__ = [
Expand Down
3 changes: 3 additions & 0 deletions gmqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ async def connect(self, host, port=1883, ssl=False, keepalive=60, version=MQTTv5
async def _create_connection(self, host, port, ssl, clean_session, keepalive):
# important for reconnects, make sure u know what u are doing if wanna change :(
self._exit_reconnecting_state()
self._clear_topics_aliases()
connection = await MQTTConnection.create_connection(host, port, ssl, clean_session, keepalive)
connection.set_handler(self)
return connection
Expand Down Expand Up @@ -213,6 +214,8 @@ async def disconnect(self, reason_code=0, **properties):
await self._disconnect(reason_code=reason_code, **properties)

async def _disconnect(self, reason_code=0, **properties):
self._clear_topics_aliases()

self._connected.clear()
if self._connection:
self._connection.send_disconnect(reason_code=reason_code, **properties)
Expand Down
40 changes: 28 additions & 12 deletions gmqtt/mqtt/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def __init__(self, *args, **kwargs):
self._handler_cache = {}
self._error = None
self._connection = None
self._server_topics_aliases = {}

self._id_generator = IdGenerator(max=kwargs.get('receive_maximum', 65535))

Expand All @@ -179,6 +180,9 @@ def __init__(self, *args, **kwargs):
else:
self._optimistic_acknowledgement = True

def _clear_topics_aliases(self):
self._server_topics_aliases = {}

def _send_command_with_mid(self, cmd, mid, dup, reason_code=0):
raise NotImplementedError

Expand Down Expand Up @@ -219,6 +223,9 @@ def _default_handler(self, cmd, packet):
logger.warning('[UNKNOWN CMD] %s %s', hex(cmd), packet)

def _handle_disconnect_packet(self, cmd, packet):
# reset server topics on disconnect
self._clear_topics_aliases()

future = asyncio.ensure_future(self.reconnect(delay=True))
future.add_done_callback(self._handle_exception_in_future)
self.on_disconnect(self, packet)
Expand Down Expand Up @@ -297,20 +304,9 @@ def _handle_publish_packet(self, cmd, raw_packet):
pack_format = '!' + str(slen) + 's' + str(len(packet) - slen) + 's'
(topic, packet) = struct.unpack(pack_format, packet)

if not topic:
logger.warning('[MQTT ERR PROTO] topic name is empty')
return

try:
print_topic = topic.decode('utf-8')
except UnicodeDecodeError as exc:
logger.warning('[INVALID CHARACTER IN TOPIC] %s', topic, exc_info=exc)
print_topic = topic

# we will change the packet ref, let's save origin
payload = packet

logger.debug('[RECV %s with QoS: %s] %s', print_topic, qos, payload)

if qos > 0:
pack_format = "!H" + str(len(packet) - 2) + 's'
(mid, packet) = struct.unpack(pack_format, packet)
Expand All @@ -325,6 +321,26 @@ def _handle_publish_packet(self, cmd, raw_packet):
logger.critical('[INVALID MESSAGE] skipping: {}'.format(raw_packet))
return

if 'topic_alias' in properties:
# TODO: need to add validation (topic alias must be greater than 0 and less than topic_alias_maximum)
topic_alias = properties['topic_alias'][0]
if topic:
self._server_topics_aliases[topic_alias] = topic
else:
topic = self._server_topics_aliases.get(topic_alias, None)

if not topic:
logger.warning('[MQTT ERR PROTO] topic name is empty (or server has send invalid topic alias)')
return

try:
print_topic = topic.decode('utf-8')
except UnicodeDecodeError as exc:
logger.warning('[INVALID CHARACTER IN TOPIC] %s', topic, exc_info=exc)
print_topic = topic

logger.debug('[RECV %s with QoS: %s] %s', print_topic, qos, payload)

if qos == 0:
run_coroutine_or_function(self.on_message, self, print_topic, packet, qos, properties)
elif qos == 1:
Expand Down

0 comments on commit 67deb47

Please sign in to comment.