Skip to content

Commit

Permalink
Merge branch 'subid'
Browse files Browse the repository at this point in the history
  • Loading branch information
smurfix committed Nov 11, 2024
2 parents 9f750b0 + 4f1d93e commit 04e61db
Show file tree
Hide file tree
Showing 9 changed files with 342 additions and 125 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
- name: Start external services
run: docker compose up -d
- name: Install the project and its dependencies
run: pip install .[test]
run: pip install -e .[test]
- name: Test with pytest
run: coverage run -m pytest
- name: Upload Coverage
Expand Down
4 changes: 4 additions & 0 deletions docs/userguide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ fine too:
hurts code readability, try using :class:`~contextlib.AsyncExitStack` to manage the
context managers.

While the MQTT protocol doesn't support multiple subscriptions to the same
topic pattern, :meth:`~.AsyncMQTTClient.subscribe` hides that from you —
albeit with some restrictions.

.. seealso:: :meth:`~.AsyncMQTTClient.subscribe`

Publishing messages
Expand Down
115 changes: 93 additions & 22 deletions src/mqttproto/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
else:
from typing_extensions import TypeAlias

PropertyValue: TypeAlias = "str | bytes | int | tuple[str, str]"
PropertyValue: TypeAlias = "str | bytes | int | tuple[str, str] | list[int]"

VARIABLE_HEADER_START = b"\x00\x04MQTT\x05"

Expand Down Expand Up @@ -340,6 +340,10 @@ def get(cls, value: int) -> Self:
raise MQTTDecodeError(f"unknown property type: 0x{value:02X}")


# These properties may appear more than once
multi_properties = frozenset((PropertyType.SUBSCRIPTION_IDENTIFIER,))


@define(kw_only=True)
class PropertiesMixin:
allowed_property_types: ClassVar[frozenset[PropertyType]] = frozenset()
Expand All @@ -349,8 +353,13 @@ class PropertiesMixin:
def encode_properties(self, buffer: bytearray) -> None:
internal_buffer = bytearray()
for identifier, value in self.properties.items():
encode_variable_integer(identifier, internal_buffer)
identifier.encoder(value, internal_buffer)
if identifier in multi_properties and isinstance(value, (list, tuple)):
for val in value:
encode_variable_integer(identifier, internal_buffer)
identifier.encoder(val, internal_buffer)
else:
encode_variable_integer(identifier, internal_buffer)
identifier.encoder(value, internal_buffer)

for key, value in self.user_properties.items():
encode_variable_integer(PropertyType.USER_PROPERTY, internal_buffer)
Expand Down Expand Up @@ -380,6 +389,10 @@ def decode_properties(
if property_type is PropertyType.USER_PROPERTY:
key, value = cast("tuple[str, str]", value)
user_properties[key] = value
elif property_type in multi_properties:
if property_type not in properties:
properties[property_type] = []
cast("list[int]", properties[property_type]).append(cast(int, value))
else:
if property_type in properties:
raise MQTTDecodeError(
Expand Down Expand Up @@ -445,35 +458,67 @@ class Subscription:
retain_handling: RetainHandling = field(
kw_only=True, default=RetainHandling.SEND_RETAINED
)
subscription_id: int = field(kw_only=True, default=0)
group_id: str | None = field(init=False, default=None)
_parts: tuple[str, ...] = field(init=False, repr=False, eq=False)
_prefix: str | None = field(init=False, repr=False, eq=False)
_only_hash: bool = field(init=False, repr=False, eq=False, default=False)

def __attrs_post_init__(self) -> None:
self._parts = tuple(self.pattern.split("/"))
for i, part in enumerate(self._parts):
if "+" in part and len(part) != 1:
# MQTT-4.7.1-2
raise InvalidPattern(
"single-level wildcard ('+') must occupy an entire level of the "
"topic filter"
)
parts = tuple(self.pattern.split("/"))
prefix: str | None = ""
n_chop = 0
for i, part in enumerate(parts):
if "+" in part:
if prefix:
prefix += "/"

if len(part) != 1:
# MQTT-4.7.1-2
raise InvalidPattern(
"single-level wildcard ('+') must occupy an entire level of the "
"topic filter"
)

elif "#" in part:
if len(part) != 1:
# MQTT-4.7.1-1
raise InvalidPattern(
"multi-level wildcard ('#') must occupy an entire level of the "
"topic filter"
)
elif i != len(self._parts) - 1:

if i != len(parts) - 1:
# MQTT-4.7.1-1
raise InvalidPattern(
"multi-level wildcard ('#') must be the last character in the "
"topic filter"
)

if prefix is not None:
self._only_hash = True
else:
if prefix is not None:
n_chop += 1
if prefix:
prefix += "/"

prefix += part

continue

if prefix is not None:
self._prefix = prefix
prefix = None

if prefix is not None:
self._prefix = None # no wildcard

# Save the group ID for a shared subscription
if len(self._parts) > 2 and self._parts[0] == "$share":
self.group_id = self._parts[1]
if len(parts) > 2 and parts[0] == "$share":
self.group_id = parts[1]

self._parts = parts[n_chop:]

def __eq__(self, other: object) -> bool:
if isinstance(other, Subscription):
Expand All @@ -500,9 +545,11 @@ def decode(cls, data: memoryview) -> tuple[memoryview, Subscription]:
retain_handling=retain_handling,
)

def encode(self, buffer: bytearray) -> None:
def encode(self, buffer: bytearray, max_qos: QoS | None = None) -> None:
encode_utf8(self.pattern, buffer)
options = self.max_qos | self.retain_handling << 4
options = (
max_qos if max_qos is not None else self.max_qos
) | self.retain_handling << 4
if self.no_local:
options |= self.NO_LOCAL_FLAG

Expand All @@ -520,13 +567,30 @@ def matches(self, publish: MQTTPublishPacket) -> bool:
not
"""
# Don't match if the message's QoS is higher than the accepted maximum in this
# subscription
if publish.qos > self.max_qos:
# Check that the topic filter matches the message's topic.

# No wildcards
if self._prefix is None:
return self.pattern == publish.topic

# Static prefix must be identical
if not publish.topic.startswith(self._prefix):
return False

# Check that the topic filter matches the message's topic
topic_parts = publish.topic.split("/")
topic = publish.topic[len(self._prefix) :]
if self._only_hash:
# 'foo/bar/#' matches 'foo/bar' as well as 'foo/bar/baz'
# thus the prefix is 'foo/bar' and '._only_slash' is set
# so the remainder is either empty or starts with a slash
if self._prefix == "":
# or the pattern is a single '#', in which case we don't
# match '$foo'
return topic[0] != "$"
else:
return topic == "" or topic[0] == "/"

# Now for the complicated part
topic_parts = topic.split("/")
for i, (pattern_part, topic_part) in enumerate(
zip_longest(self._parts, topic_parts)
):
Expand Down Expand Up @@ -1152,6 +1216,7 @@ class MQTTSubscribePacket(MQTTPacket, PropertiesMixin):

packet_id: int
subscriptions: Sequence[Subscription]
max_qos: QoS | None = field(init=True, default=None)

def __attrs_post_init__(self) -> None:
if not self.subscriptions:
Expand All @@ -1169,6 +1234,12 @@ def decode(
subscriptions: list[Subscription] = []
while data:
data, subscription = Subscription.decode(data)
subscr_id = cast(
"list[int] | None", properties.get(PropertyType.SUBSCRIPTION_IDENTIFIER)
)
if subscr_id:
subscription.subscription_id = subscr_id[0]

subscriptions.append(subscription)

return data, MQTTSubscribePacket(
Expand All @@ -1187,7 +1258,7 @@ def encode(self, buffer: bytearray) -> None:

# Encode the payload
for subscription in self.subscriptions:
subscription.encode(internal_buffer)
subscription.encode(internal_buffer, max_qos=self.max_qos)

self.encode_fixed_header(self.expected_reserved_bits, internal_buffer, buffer)

Expand Down
Loading

0 comments on commit 04e61db

Please sign in to comment.