Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to prefer text unmarshaler when data type is String #371

Open
wants to merge 1 commit into
base: v5
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
disallowUnknownFieldsFlag
usePreallocateValues
disableAllocLimitFlag
preferTextUnmarshalerForString
)

type bufReader interface {
Expand Down Expand Up @@ -184,6 +185,20 @@ func (d *Decoder) DisableAllocLimit(on bool) {
}
}

// PreferTextUnmarshalerForString makes the decoder prefer [encoding.TextUnmarshaler]
// over [encoding.BinaryUnmarshaler] when both are implemented, and source
// MessagePack data is a String (as opposed to Binary).
//
// If this option is not enabled, [encoding.BinaryUnmarshaler] will be preferred
// instead, regardless of MessagePack data type.
func (d *Decoder) PreferTextUnmarshalerForString(on bool) {
if on {
d.flags |= preferTextUnmarshalerForString
} else {
d.flags &= ^preferTextUnmarshalerForString
}
}

// Buffered returns a reader of the data remaining in the Decoder's buffer.
// The reader is valid until the next call to Decode.
func (d *Decoder) Buffered() io.Reader {
Expand Down
39 changes: 35 additions & 4 deletions decode_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"errors"
"fmt"
"reflect"

"github.com/vmihailenco/msgpack/v5/msgpcode"
)

var (
Expand Down Expand Up @@ -70,10 +72,16 @@ func _getDecoder(typ reflect.Type) decoderFunc {
if typ.Implements(unmarshalerType) {
return nilAwareDecoder(typ, unmarshalValue)
}
if typ.Implements(binaryUnmarshalerType) {

implementsBinaryUnmarshaler := typ.Implements(binaryUnmarshalerType)
implementsTextUnmarshaler := typ.Implements(textUnmarshalerType)
if implementsBinaryUnmarshaler && implementsTextUnmarshaler {
return nilAwareDecoder(typ, unmarshalBinaryOrTextValue)
}
if implementsBinaryUnmarshaler {
return nilAwareDecoder(typ, unmarshalBinaryValue)
}
if typ.Implements(textUnmarshalerType) {
if implementsTextUnmarshaler {
return nilAwareDecoder(typ, unmarshalTextValue)
}

Expand All @@ -86,10 +94,15 @@ func _getDecoder(typ reflect.Type) decoderFunc {
if ptr.Implements(unmarshalerType) {
return addrDecoder(nilAwareDecoder(typ, unmarshalValue))
}
if ptr.Implements(binaryUnmarshalerType) {
implementsBinaryUnmarshaler := ptr.Implements(binaryUnmarshalerType)
implementsTextUnmarshaler := ptr.Implements(textUnmarshalerType)
if implementsBinaryUnmarshaler && implementsTextUnmarshaler {
return addrDecoder(nilAwareDecoder(typ, unmarshalBinaryOrTextValue))
}
if implementsBinaryUnmarshaler {
return addrDecoder(nilAwareDecoder(typ, unmarshalBinaryValue))
}
if ptr.Implements(textUnmarshalerType) {
if implementsTextUnmarshaler {
return addrDecoder(nilAwareDecoder(typ, unmarshalTextValue))
}
}
Expand Down Expand Up @@ -249,3 +262,21 @@ func unmarshalTextValue(d *Decoder, v reflect.Value) error {
unmarshaler := v.Interface().(encoding.TextUnmarshaler)
return unmarshaler.UnmarshalText(data)
}

func unmarshalBinaryOrTextValue(d *Decoder, v reflect.Value) error {
useText := false
if d.flags&preferTextUnmarshalerForString != 0 {
code, err := d.PeekCode()
if err != nil {
return err
}
if msgpcode.IsString(code) {
useText = true
}
}
if useText {
return unmarshalTextValue(d, v)
} else {
return unmarshalBinaryValue(d, v)
}
}
67 changes: 67 additions & 0 deletions types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package msgpack_test

import (
"bytes"
"encoding"
"encoding/binary"
"encoding/hex"
"fmt"
"math"
Expand Down Expand Up @@ -427,6 +429,8 @@ type typeTest struct {
wantnil bool
wantzero bool
wanted interface{}

preferTextUnmarshalerForString bool
}

func (t typeTest) String() string {
Expand All @@ -442,6 +446,36 @@ func (t *typeTest) requireErr(err error, s string) {
}
}

type binaryTextType uint32

// UnmarshalText implements encoding.TextUnmarshaler
func (v *binaryTextType) UnmarshalText(text []byte) error {
var b [4]byte
n, err := hex.Decode(b[:], text)
if err != nil {
return err
}
if n != 4 {
return fmt.Errorf("invalid length %d", n)
}
*v = binaryTextType(binary.BigEndian.Uint32(b[:]))
return nil
}

// UnmarshalBinary implements encoding.BinaryUnmarshaler
func (v *binaryTextType) UnmarshalBinary(data []byte) error {
if n := len(data); n != 4 {
return fmt.Errorf("invalid length %d", n)
}
*v = binaryTextType(binary.BigEndian.Uint32(data))
return nil
}

var (
_ encoding.TextUnmarshaler = new(binaryTextType)
_ encoding.BinaryUnmarshaler = new(binaryTextType)
)

var (
intSlice = make([]int, 0, 3)
repoURL, _ = url.Parse("https://github.com/vmihailenco/msgpack")
Expand Down Expand Up @@ -622,6 +656,36 @@ var (
},

{in: big.NewInt(123), out: new(big.Int)},

{
in: "deadbeef",
out: new(binaryTextType),
wanted: binaryTextType(0xdeadbeef),
decErr: "invalid length 8",

preferTextUnmarshalerForString: false,
},
{
in: "deadbeef",
out: new(binaryTextType),
wanted: binaryTextType(0xdeadbeef),

preferTextUnmarshalerForString: true,
},
{
in: []byte{0xde, 0xad, 0xbe, 0xef},
out: new(binaryTextType),
wanted: binaryTextType(0xdeadbeef),

preferTextUnmarshalerForString: false,
},
{
in: []byte{0xde, 0xad, 0xbe, 0xef},
out: new(binaryTextType),
wanted: binaryTextType(0xdeadbeef),

preferTextUnmarshalerForString: true,
},
}
)

Expand Down Expand Up @@ -655,6 +719,9 @@ func TestTypes(t *testing.T) {
}

dec := msgpack.NewDecoder(&buf)
if test.preferTextUnmarshalerForString {
dec.PreferTextUnmarshalerForString(true)
}
err = dec.Decode(test.out)
if test.decErr != "" {
test.requireErr(err, test.decErr)
Expand Down