diff --git a/protocol/decoder/real_decoder.go b/protocol/decoder/real_decoder.go index fd0ae3d..3754737 100644 --- a/protocol/decoder/real_decoder.go +++ b/protocol/decoder/real_decoder.go @@ -256,6 +256,11 @@ func (rd *RealDecoder) GetString() (string, error) { return "", err } + if rd.Remaining() < n { + rem := rd.Remaining() + return "", errors.NewPacketDecodingError(fmt.Sprintf("Expected string length to be %d bytes, got %d bytes", n, rem), "STRING") + } + tmpStr := string(rd.raw[rd.off : rd.off+n]) rd.off += n return tmpStr, nil @@ -270,6 +275,11 @@ func (rd *RealDecoder) GetNullableString() (*string, error) { return nil, err } + if rd.Remaining() < n { + rem := rd.Remaining() + return nil, errors.NewPacketDecodingError(fmt.Sprintf("Expected string length to be %d bytes, got %d bytes", n, rem), "NULLABLE_STRING") + } + tmpStr := string(rd.raw[rd.off : rd.off+n]) rd.off += n return &tmpStr, nil @@ -288,6 +298,12 @@ func (rd *RealDecoder) GetCompactString() (string, error) { if length < 0 { return "", errors.NewPacketDecodingError(fmt.Sprintf("Expected COMPACT_STRING length to be > 0, got %d", length), "COMPACT_STRING") } + + if rd.Remaining() < length { + rem := rd.Remaining() + return "", errors.NewPacketDecodingError(fmt.Sprintf("Expected string length to be %d bytes, got %d bytes", length, rem), "COMPACT_STRING") + } + tmpStr := string(rd.raw[rd.off : rd.off+length]) rd.off += length return tmpStr, nil @@ -308,6 +324,11 @@ func (rd *RealDecoder) GetCompactNullableString() (*string, error) { return nil, errors.NewPacketDecodingError(fmt.Sprintf("Expected compact nullable string length to be > 0, got %d", length), "COMPACT_NULLABLE_STRING") } + if rd.Remaining() < length { + rem := rd.Remaining() + return nil, errors.NewPacketDecodingError(fmt.Sprintf("Expected string length to be %d bytes, got %d bytes", length, rem), "COMPACT_NULLABLE_STRING") + } + tmpStr := string(rd.raw[rd.off : rd.off+length]) rd.off += length return &tmpStr, nil @@ -456,9 +477,7 @@ func (rd *RealDecoder) GetRawBytesFromOffset(length int) ([]byte, error) { if rd.off >= len(rd.raw) { return nil, errors.NewPacketDecodingError(fmt.Sprintf("Expected offset to be less than length of raw bytes (%d), got %d", len(rd.raw), rd.off), "RAW_BYTES_FROM_OFFSET") - } - - if rd.off+length > len(rd.raw) { + } else if rd.off+length > len(rd.raw) { return nil, errors.NewPacketDecodingError(fmt.Sprintf("Expected offset to be less than length of raw bytes (%d), got %d", len(rd.raw), rd.off), "RAW_BYTES_FROM_OFFSET") }