From 5e7d43723ab7e7d7d576e72d64cb484b46416644 Mon Sep 17 00:00:00 2001 From: toxazhl Date: Thu, 5 Dec 2024 07:10:32 +0200 Subject: [PATCH 1/2] add property type checking in package encoding --- src/mqttproto/_exceptions.py | 2 +- src/mqttproto/_types.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/mqttproto/_exceptions.py b/src/mqttproto/_exceptions.py index c944e65..28eb980 100644 --- a/src/mqttproto/_exceptions.py +++ b/src/mqttproto/_exceptions.py @@ -56,7 +56,7 @@ class InsufficientData(MQTTDecodeError): class MQTTUnsupportedPropertyType(MQTTDecodeError): """ - Raised when decoding an MQTT packet and it contains a property of a type not + Raised when decoding or encoding an MQTT packet and it contains a property of a type not supported by that packet type. """ diff --git a/src/mqttproto/_types.py b/src/mqttproto/_types.py index ae9c161..87eaf8e 100644 --- a/src/mqttproto/_types.py +++ b/src/mqttproto/_types.py @@ -349,6 +349,9 @@ class PropertiesMixin: def encode_properties(self, buffer: bytearray) -> None: internal_buffer = bytearray() for identifier, value in self.properties.items(): + if identifier not in self.allowed_property_types: + raise MQTTUnsupportedPropertyType(identifier, self.__class__) + encode_variable_integer(identifier, internal_buffer) identifier.encoder(value, internal_buffer) From d5ff5563c80bc9f8a5e05f287d639caa570a9d26 Mon Sep 17 00:00:00 2001 From: toxazhl Date: Thu, 5 Dec 2024 09:18:39 +0200 Subject: [PATCH 2/2] add propetries for client.publish #28 --- src/mqttproto/async_client.py | 6 +++++- src/mqttproto/client_state_machine.py | 9 ++++++++- src/mqttproto/sync_client.py | 14 ++++++++++++-- 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/mqttproto/async_client.py b/src/mqttproto/async_client.py index 857eac0..eed1164 100644 --- a/src/mqttproto/async_client.py +++ b/src/mqttproto/async_client.py @@ -42,6 +42,7 @@ MQTTSubscribeAckPacket, MQTTUnsubscribeAckPacket, PropertyType, + PropertyValue, QoS, ReasonCode, RetainHandling, @@ -498,6 +499,7 @@ async def publish( *, qos: QoS = QoS.AT_MOST_ONCE, retain: bool = False, + properties: dict[PropertyType, PropertyValue] | None = None, ) -> None: """ Publish a message to the given topic. @@ -510,7 +512,9 @@ async def publish( before the subscription happened """ - packet_id = self._state_machine.publish(topic, payload, qos=qos, retain=retain) + packet_id = self._state_machine.publish( + topic, payload, qos=qos, retain=retain, properties=properties + ) if qos is QoS.EXACTLY_ONCE: assert packet_id is not None await self._run_operation(MQTTQoS2PublishOperation(packet_id)) diff --git a/src/mqttproto/client_state_machine.py b/src/mqttproto/client_state_machine.py index 8b1dd39..ebed5a6 100644 --- a/src/mqttproto/client_state_machine.py +++ b/src/mqttproto/client_state_machine.py @@ -22,6 +22,7 @@ MQTTUnsubscribeAckPacket, MQTTUnsubscribePacket, PropertyType, + PropertyValue, QoS, ReasonCode, Subscription, @@ -153,6 +154,7 @@ def publish( *, qos: QoS = QoS.AT_MOST_ONCE, retain: bool = False, + properties: dict[PropertyType, PropertyValue] | None = None, ) -> int | None: """ Send a ``PUBLISH`` request. @@ -171,7 +173,12 @@ def publish( self._out_require_state(MQTTClientState.CONNECTED) packet_id = self._generate_packet_id() if qos > QoS.AT_MOST_ONCE else None packet = MQTTPublishPacket( - topic=topic, payload=payload, qos=qos, retain=retain, packet_id=packet_id + topic=topic, + payload=payload, + qos=qos, + retain=retain, + packet_id=packet_id, + properties=properties if properties is not None else {}, ) packet.encode(self._out_buffer) if packet_id is not None: diff --git a/src/mqttproto/sync_client.py b/src/mqttproto/sync_client.py index 9eebfb7..a44efff 100644 --- a/src/mqttproto/sync_client.py +++ b/src/mqttproto/sync_client.py @@ -10,7 +10,14 @@ from anyio.from_thread import BlockingPortal, BlockingPortalProvider from attrs import define -from ._types import MQTTPublishPacket, QoS, RetainHandling, Will +from ._types import ( + MQTTPublishPacket, + PropertyType, + PropertyValue, + QoS, + RetainHandling, + Will, +) from .async_client import AsyncMQTTClient, AsyncMQTTSubscription if sys.version_info >= (3, 11): @@ -104,9 +111,12 @@ def publish( *, qos: QoS = QoS.AT_MOST_ONCE, retain: bool = False, + properties: dict[PropertyType, PropertyValue] | None = None, ) -> None: return self._portal.call( - lambda: self._async_client.publish(topic, payload, qos=qos, retain=retain) + lambda: self._async_client.publish( + topic, payload, qos=qos, retain=retain, properties=properties + ) ) @contextmanager