Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Read/WriteExtensionRaw #329

Merged
merged 2 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 112 additions & 130 deletions msgp/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,87 +121,97 @@ func (r *RawExtension) UnmarshalBinary(b []byte) error {
return nil
}

// WriteExtension writes an extension type to the writer
func (mw *Writer) WriteExtension(e Extension) error {
l := e.Len()
var err error
switch l {
func (mw *Writer) writeExtensionHeader(length int, extType int8) error {
switch length {
case 0:
o, err := mw.require(3)
if err != nil {
return err
}
mw.buf[o] = mext8
mw.buf[o+1] = 0
mw.buf[o+2] = byte(e.ExtensionType())
mw.buf[o+2] = byte(extType)
case 1:
o, err := mw.require(2)
if err != nil {
return err
}
mw.buf[o] = mfixext1
mw.buf[o+1] = byte(e.ExtensionType())
mw.buf[o+1] = byte(extType)
case 2:
o, err := mw.require(2)
if err != nil {
return err
}
mw.buf[o] = mfixext2
mw.buf[o+1] = byte(e.ExtensionType())
mw.buf[o+1] = byte(extType)
case 4:
o, err := mw.require(2)
if err != nil {
return err
}
mw.buf[o] = mfixext4
mw.buf[o+1] = byte(e.ExtensionType())
mw.buf[o+1] = byte(extType)
case 8:
o, err := mw.require(2)
if err != nil {
return err
}
mw.buf[o] = mfixext8
mw.buf[o+1] = byte(e.ExtensionType())
mw.buf[o+1] = byte(extType)
case 16:
o, err := mw.require(2)
if err != nil {
return err
}
mw.buf[o] = mfixext16
mw.buf[o+1] = byte(e.ExtensionType())
mw.buf[o+1] = byte(extType)
default:
switch {
case l < math.MaxUint8:
case length < math.MaxUint8:
o, err := mw.require(3)
if err != nil {
return err
}
mw.buf[o] = mext8
mw.buf[o+1] = byte(uint8(l))
mw.buf[o+2] = byte(e.ExtensionType())
case l < math.MaxUint16:
mw.buf[o+1] = byte(uint8(length))
mw.buf[o+2] = byte(extType)
case length < math.MaxUint16:
o, err := mw.require(4)
if err != nil {
return err
}
mw.buf[o] = mext16
big.PutUint16(mw.buf[o+1:], uint16(l))
mw.buf[o+3] = byte(e.ExtensionType())
big.PutUint16(mw.buf[o+1:], uint16(length))
mw.buf[o+3] = byte(extType)
default:
o, err := mw.require(6)
if err != nil {
return err
}
mw.buf[o] = mext32
big.PutUint32(mw.buf[o+1:], uint32(l))
mw.buf[o+5] = byte(e.ExtensionType())
big.PutUint32(mw.buf[o+1:], uint32(length))
mw.buf[o+5] = byte(extType)
}
}

return nil
}

// WriteExtension writes an extension type to the writer
func (mw *Writer) WriteExtension(e Extension) error {
length := e.Len()

err := mw.writeExtensionHeader(length, e.ExtensionType())
if err != nil {
return err
}

// we can only write directly to the
// buffer if we're sure that it
// fits the object
if l <= mw.bufsize() {
o, err := mw.require(l)
if length <= mw.bufsize() {
o, err := mw.require(length)
if err != nil {
return err
}
Expand All @@ -214,36 +224,35 @@ func (mw *Writer) WriteExtension(e Extension) error {
if err != nil {
return err
}
buf := make([]byte, l)
buf := make([]byte, length)
err = e.MarshalBinaryTo(buf)
if err != nil {
return err
}
mw.buf = buf
mw.wloc = l
mw.wloc = length
return nil
}

// WriteExtensionRaw writes an extension type to the writer
func (mw *Writer) WriteExtensionRaw(extType int8, payload []byte) error {
if err := mw.writeExtensionHeader(len(payload), extType); err != nil {
return err
}

if _, err := mw.Write(payload); err != nil {
return err
}

return nil
}

// peek at the extension type, assuming the next
// kind to be read is Extension
func (m *Reader) peekExtensionType() (int8, error) {
p, err := m.R.Peek(2)
if err != nil {
return 0, err
}
spec := getBytespec(p[0])
if spec.typ != ExtensionType {
return 0, badPrefix(ExtensionType, p[0])
}
if spec.extra == constsize {
return int8(p[1]), nil
}
size := spec.size
p, err = m.R.Peek(int(size))
if err != nil {
return 0, err
}
return int8(p[size-1]), nil
_, _, extType, err := m.peekExtensionHeader()

return extType, err
}

// peekExtension peeks at the extension encoding type
Expand All @@ -268,145 +277,118 @@ func peekExtension(b []byte) (int8, error) {
return int8(b[size-1]), nil
}

// ReadExtension reads the next object from the reader
// as an extension. ReadExtension will fail if the next
// object in the stream is not an extension, or if
// e.Type() is not the same as the wire type.
func (m *Reader) ReadExtension(e Extension) (err error) {
func (m *Reader) peekExtensionHeader() (offset int, length int, extType int8, err error) {
var p []byte
p, err = m.R.Peek(2)
if err != nil {
return
}

offset = 2

lead := p[0]
var read int
var off int
switch lead {
case mfixext1:
if int8(p[1]) != e.ExtensionType() {
err = errExt(int8(p[1]), e.ExtensionType())
return
}
p, err = m.R.Peek(3)
if err != nil {
return
}
err = e.UnmarshalBinary(p[2:])
if err == nil {
_, err = m.R.Skip(3)
}
extType = int8(p[1])
length = 1
return

case mfixext2:
if int8(p[1]) != e.ExtensionType() {
err = errExt(int8(p[1]), e.ExtensionType())
return
}
p, err = m.R.Peek(4)
if err != nil {
return
}
err = e.UnmarshalBinary(p[2:])
if err == nil {
_, err = m.R.Skip(4)
}
extType = int8(p[1])
length = 2
return

case mfixext4:
if int8(p[1]) != e.ExtensionType() {
err = errExt(int8(p[1]), e.ExtensionType())
return
}
p, err = m.R.Peek(6)
if err != nil {
return
}
err = e.UnmarshalBinary(p[2:])
if err == nil {
_, err = m.R.Skip(6)
}
extType = int8(p[1])
length = 4
return

case mfixext8:
if int8(p[1]) != e.ExtensionType() {
err = errExt(int8(p[1]), e.ExtensionType())
return
}
p, err = m.R.Peek(10)
if err != nil {
return
}
err = e.UnmarshalBinary(p[2:])
if err == nil {
_, err = m.R.Skip(10)
}
extType = int8(p[1])
length = 8
return

case mfixext16:
if int8(p[1]) != e.ExtensionType() {
err = errExt(int8(p[1]), e.ExtensionType())
return
}
p, err = m.R.Peek(18)
if err != nil {
return
}
err = e.UnmarshalBinary(p[2:])
if err == nil {
_, err = m.R.Skip(18)
}
extType = int8(p[1])
length = 16
return

case mext8:
p, err = m.R.Peek(3)
if err != nil {
return
}
if int8(p[2]) != e.ExtensionType() {
err = errExt(int8(p[2]), e.ExtensionType())
return
}
read = int(uint8(p[1]))
off = 3
offset = 3
extType = int8(p[2])
length = int(uint8(p[1]))

case mext16:
p, err = m.R.Peek(4)
if err != nil {
return
}
if int8(p[3]) != e.ExtensionType() {
err = errExt(int8(p[3]), e.ExtensionType())
return
}
read = int(big.Uint16(p[1:]))
off = 4
offset = 4
extType = int8(p[3])
length = int(big.Uint16(p[1:]))

case mext32:
p, err = m.R.Peek(6)
if err != nil {
return
}
if int8(p[5]) != e.ExtensionType() {
err = errExt(int8(p[5]), e.ExtensionType())
return
}
read = int(big.Uint32(p[1:]))
off = 6
offset = 6
extType = int8(p[5])
length = int(big.Uint32(p[1:]))

default:
err = badPrefix(ExtensionType, lead)
return
}

p, err = m.R.Peek(read + off)
return
}

// ReadExtension reads the next object from the reader
// as an extension. ReadExtension will fail if the next
// object in the stream is not an extension, or if
// e.Type() is not the same as the wire type.
func (m *Reader) ReadExtension(e Extension) error {
offset, length, extType, err := m.peekExtensionHeader()
if err != nil {
return
return err
}

if expectedType := e.ExtensionType(); extType != expectedType {
return errExt(extType, expectedType)
}
err = e.UnmarshalBinary(p[off:])

p, err := m.R.Peek(offset + length)
if err != nil {
return err
}
err = e.UnmarshalBinary(p[offset:])
if err == nil {
_, err = m.R.Skip(read + off)
// consume the peeked bytes
_, err = m.R.Skip(offset + length)
}
return
return err
}

// ReadExtensionRaw reads the next object from the reader
// as an extension. The returned slice is only
// valid until the next *Reader method call.
func (m *Reader) ReadExtensionRaw() (int8, []byte, error) {
offset, length, extType, err := m.peekExtensionHeader()
if err != nil {
return 0, nil, err
}

payload, err := m.R.Next(offset + length)
if err != nil {
return 0, nil, err
}

return extType, payload[offset:], nil
}

// AppendExtension appends a MessagePack extension to the provided slice
Expand Down
Loading