From 0d8fe6ceabf6c8939b562c10f225805fddec9b48 Mon Sep 17 00:00:00 2001 From: ffa500 Date: Wed, 15 Feb 2023 13:10:26 +0100 Subject: [PATCH] fix: correct decoding of packets including Properties exceeding 127 bytes in length --- packets/packets.go | 14 +++++++------- packets/properties.go | 19 ++++++++++--------- packets/properties_test.go | 2 +- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/packets/packets.go b/packets/packets.go index 7ce992b3..36125d73 100644 --- a/packets/packets.go +++ b/packets/packets.go @@ -383,7 +383,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error { if err != nil { return fmt.Errorf("%s: %w", err, ErrMalformedProperties) } - offset += n + 1 + offset += n } pk.Connect.ClientIdentifier, offset, err = decodeString(buf, offset) //[MQTT-3.1.3-1] [MQTT-3.1.3-2] [MQTT-3.1.3-3] [MQTT-3.1.3-4] @@ -397,7 +397,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error { if err != nil { return ErrMalformedWillProperties } - offset += n + 1 + offset += n } pk.Connect.WillTopic, offset, err = decodeString(buf, offset) @@ -644,7 +644,7 @@ func (pk *Packet) PublishDecode(buf []byte) error { return fmt.Errorf("%s: %w", err, ErrMalformedProperties) } - offset += n + 1 + offset += n } pk.Payload = buf[offset:] @@ -861,7 +861,7 @@ func (pk *Packet) SubackDecode(buf []byte) error { if err != nil { return fmt.Errorf("%s: %w", err, ErrMalformedProperties) } - offset += n + 1 + offset += n } pk.ReasonCodes = buf[offset:] @@ -918,7 +918,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error { if err != nil { return fmt.Errorf("%s: %w", err, ErrMalformedProperties) } - offset += n + 1 + offset += n } var filter string @@ -1014,7 +1014,7 @@ func (pk *Packet) UnsubackDecode(buf []byte) error { return fmt.Errorf("%s: %w", err, ErrMalformedProperties) } - offset += n + 1 + offset += n pk.ReasonCodes = buf[offset:] } @@ -1066,7 +1066,7 @@ func (pk *Packet) UnsubscribeDecode(buf []byte) error { if err != nil { return fmt.Errorf("%s: %w", err, ErrMalformedProperties) } - offset += n + 1 + offset += n } var filter string diff --git a/packets/properties.go b/packets/properties.go index 5405bcf2..6c8d56b1 100644 --- a/packets/properties.go +++ b/packets/properties.go @@ -366,13 +366,14 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) { return 0, nil } - n, _, err = DecodeLength(b) + var bu int + n, bu, err = DecodeLength(b) if err != nil { - return n, err + return n + bu, err } if n == 0 { - return n, nil + return n + bu, nil } bt := b.Bytes() @@ -380,11 +381,11 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) { for offset := 0; offset < n; { k, offset, err = decodeByte(bt, offset) if err != nil { - return n, err + return n + bu, err } if _, ok := validPacketProperties[k][pk]; !ok { - return n, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty) + return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty) } switch k { @@ -406,7 +407,7 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) { n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:])) if err != nil { - return n, err + return n + bu, err } p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n) offset += bu @@ -452,7 +453,7 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) { var k, v string k, offset, err = decodeString(bt, offset) if err != nil { - return n, err + return n + bu, err } v, offset, err = decodeString(bt, offset) p.User = append(p.User, UserProperty{Key: k, Val: v}) @@ -470,9 +471,9 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) { } if err != nil { - return n, err + return n + bu, err } } - return n, nil + return n + bu, nil } diff --git a/packets/properties_test.go b/packets/properties_test.go index 672530e0..d10445f0 100644 --- a/packets/properties_test.go +++ b/packets/properties_test.go @@ -250,7 +250,7 @@ func TestDecodeProperties(t *testing.T) { props := new(Properties) n, err := props.Decode(Reserved, b) require.NoError(t, err) - require.Equal(t, 172, n) + require.Equal(t, 172 + 2, n) require.EqualValues(t, propertiesStruct, *props) }