Skip to content

Commit

Permalink
feat: support TextMarshaller for map key (#328)
Browse files Browse the repository at this point in the history
  • Loading branch information
nrwiersma authored Nov 11, 2023
1 parent 210b6b9 commit 8c60780
Show file tree
Hide file tree
Showing 3 changed files with 288 additions and 4 deletions.
139 changes: 135 additions & 4 deletions codec_map.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package avro

import (
"encoding"
"errors"
"fmt"
"io"
Expand All @@ -11,16 +12,28 @@ import (
)

// createDecoderOfMap picks the ValDecoder used to decode an Avro map into typ.
//
// NOTE(review): this span is rendered from a diff view, so the first two lines
// of the body look like the removed pre-change condition shown next to its
// replacement — confirm against the real file before relying on it.
//
// For map types the key type decides the decoder:
//   - keys of string kind use decoderOfMap;
//   - keys implementing textUnmarshalerType use decoderOfMapUnmarshaler.
//
// Every other type yields an errorDecoder carrying an "unsupported" error.
func createDecoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
if typ.Kind() == reflect.Map && typ.(reflect2.MapType).Key().Kind() == reflect.String {
return decoderOfMap(cfg, schema, typ)
if typ.Kind() == reflect.Map {
keyType := typ.(reflect2.MapType).Key()
switch {
case keyType.Kind() == reflect.String:
return decoderOfMap(cfg, schema, typ)
case keyType.Implements(textUnmarshalerType):
return decoderOfMapUnmarshaler(cfg, schema, typ)
}
}

// Not a map, or a map whose key is neither a string nor a text unmarshaler.
return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
}

func createEncoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
if typ.Kind() == reflect.Map && typ.(reflect2.MapType).Key().Kind() == reflect.String {
return encoderOfMap(cfg, schema, typ)
if typ.Kind() == reflect.Map {
keyType := typ.(reflect2.MapType).Key()
switch {
case keyType.Kind() == reflect.String:
return encoderOfMap(cfg, schema, typ)
case keyType.Implements(textMarshalerType):
return encoderOfMapMarshaler(cfg, schema, typ)
}
}

return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
Expand Down Expand Up @@ -69,6 +82,65 @@ func (d *mapDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
}
}

// decoderOfMapUnmarshaler builds the map decoder whose keys are filled in
// through encoding.TextUnmarshaler rather than being assigned as strings.
//
// schema must be a *MapSchema and typ a *reflect2.UnsafeMapType; both type
// assertions panic otherwise. Values are decoded with the decoder resolved
// for the schema's value schema and the map's element type.
func decoderOfMapUnmarshaler(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
	mapSchema := schema.(*MapSchema)
	unsafeMap := typ.(*reflect2.UnsafeMapType)
	valueType := unsafeMap.Elem()

	return &mapDecoderUnmarshaler{
		mapType:  unsafeMap,
		keyType:  unsafeMap.Key(),
		elemType: valueType,
		decoder:  decoderOfType(cfg, mapSchema.Values(), valueType),
	}
}

// mapDecoderUnmarshaler decodes an Avro map into a Go map whose keys are
// populated via encoding.TextUnmarshaler.
type mapDecoderUnmarshaler struct {
// mapType is the target Go map type; it is used to allocate the map, insert
// entries and label wrapped errors.
mapType *reflect2.UnsafeMapType
// keyType is the map's key type; a fresh key is allocated from it per entry.
keyType reflect2.Type
// elemType is the map's value type; a fresh value is allocated from it per entry.
elemType reflect2.Type
// decoder decodes each map value into the freshly allocated element.
decoder ValDecoder
}

// Decode reads an Avro map from r into the Go map stored at ptr.
//
// A nil map is replaced by a freshly made one. Blocks are read until a block
// header with a zero count is seen. For every entry the key is allocated,
// filled via encoding.TextUnmarshaler from the string read off r, the value is
// decoded with d.decoder, and the pair is inserted into the map.
//
// Failures are reported through r: a key that cannot be unmarshalled (or
// allocated) is reported with ReportError and decoding stops immediately; any
// other non-EOF reader error is wrapped with the map type once the loop ends.
func (d *mapDecoderUnmarshaler) Decode(ptr unsafe.Pointer, r *Reader) {
	if d.mapType.UnsafeIsNil(ptr) {
		d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0))
	}

	for {
		// Only the entry count is needed; the second result is unused here.
		l, _ := r.ReadBlockHeader()
		if l == 0 {
			break
		}

		for i := int64(0); i < l; i++ {
			keyPtr := d.keyType.UnsafeNew()
			keyObj := d.keyType.UnsafeIndirect(keyPtr)
			if reflect2.IsNil(keyObj) {
				// A nil key can only be populated when the key is a pointer:
				// allocate the pointee and store its address in the key slot.
				// Other nil-able key kinds cannot be allocated this way, so
				// report an error instead of panicking on the type assertion.
				ptrType, ok := d.keyType.(*reflect2.UnsafePtrType)
				if !ok {
					r.ReportError("mapDecoderUnmarshaler", "cannot allocate nil key of type "+d.keyType.String())
					return
				}
				newPtr := ptrType.Elem().UnsafeNew()
				*((*unsafe.Pointer)(keyPtr)) = newPtr
				keyObj = d.keyType.UnsafeIndirect(keyPtr)
			}
			unmarshaler := keyObj.(encoding.TextUnmarshaler)
			err := unmarshaler.UnmarshalText([]byte(r.ReadString()))
			if err != nil {
				r.ReportError("mapDecoderUnmarshaler", err.Error())
				return
			}

			elemPtr := d.elemType.UnsafeNew()
			d.decoder.Decode(elemPtr, r)

			d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr)
		}
	}

	if r.Error != nil && !errors.Is(r.Error, io.EOF) {
		r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error)
	}
}

func encoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
m := schema.(*MapSchema)
mapType := typ.(*reflect2.UnsafeMapType)
Expand Down Expand Up @@ -113,3 +185,62 @@ func (e *mapEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
w.Error = fmt.Errorf("%v: %w", e.mapType, w.Error)
}
}

// encoderOfMapMarshaler builds the map encoder whose keys are written as the
// text produced by encoding.TextMarshaler.
//
// schema must be a *MapSchema and typ a *reflect2.UnsafeMapType; both type
// assertions panic otherwise. Values are encoded with the encoder resolved
// for the schema's value schema and the map's element type, and the block
// length is taken from cfg.
func encoderOfMapMarshaler(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
	mapSchema := schema.(*MapSchema)
	unsafeMap := typ.(*reflect2.UnsafeMapType)

	enc := &mapEncoderMarshaller{
		mapType: unsafeMap,
		keyType: unsafeMap.Key(),
	}
	enc.encoder = encoderOfType(cfg, mapSchema.Values(), unsafeMap.Elem())
	enc.blockLength = cfg.getBlockLength()

	return enc
}

// mapEncoderMarshaller encodes a Go map whose keys implement
// encoding.TextMarshaler as an Avro map.
type mapEncoderMarshaller struct {
// blockLength caps how many entries are written per block.
blockLength int
// mapType is the source Go map type; it is used to iterate the map and to
// label wrapped errors.
mapType *reflect2.UnsafeMapType
// keyType is the map's key type, used to turn a key pointer into a value.
keyType reflect2.Type
// encoder encodes each map value.
encoder ValEncoder
}

// Encode writes the Go map at ptr to w as a sequence of blocks.
//
// Each block holds at most e.blockLength entries. Every key is converted with
// encoding.TextMarshaler and written as a string, followed by its value
// encoded with e.encoder. Writing stops once a block reports zero entries,
// which happens when the map is exhausted or when a key fails.
//
// A nil nullable key or a MarshalText failure is stored in w.Error and ends
// the current block with a zero count. Any non-EOF error left on w is wrapped
// with the map type before returning.
func (e *mapEncoderMarshaller) Encode(ptr unsafe.Pointer, w *Writer) {
	limit := e.blockLength
	entries := e.mapType.UnsafeIterate(ptr)

	// writeBlock emits up to limit entries and returns how many it wrote.
	writeBlock := func(w *Writer) int64 {
		n := 0
		for ; entries.HasNext() && n < limit; n++ {
			keyPtr, valPtr := entries.UnsafeNext()

			key := e.keyType.UnsafeIndirect(keyPtr)
			if e.keyType.IsNullable() && reflect2.IsNil(key) {
				w.Error = errors.New("avro: mapEncoderMarshaller: encoding nil TextMarshaller")
				return 0
			}

			text, err := key.(encoding.TextMarshaler).MarshalText()
			if err != nil {
				w.Error = err
				return 0
			}
			w.WriteString(string(text))

			e.encoder.Encode(valPtr, w)
		}
		return int64(n)
	}

	for {
		if w.WriteBlockCB(writeBlock) == 0 {
			break
		}
	}

	if err := w.Error; err != nil && !errors.Is(err, io.EOF) {
		w.Error = fmt.Errorf("%v: %w", e.mapType, err)
	}
}
77 changes: 77 additions & 0 deletions decoder_map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package avro_test

import (
"bytes"
"errors"
"strconv"
"testing"

"github.com/hamba/avro/v2"
Expand Down Expand Up @@ -65,6 +67,81 @@ func TestDecoder_MapMapError(t *testing.T) {
assert.Error(t, err)
}

// textUnmarshallerInt is an int-backed map key that is populated from its
// decimal text form.
type textUnmarshallerInt int

// UnmarshalText parses text as a base-10 integer and stores it in the
// receiver. On a parse error the receiver is left untouched and the error is
// returned.
func (t *textUnmarshallerInt) UnmarshalText(text []byte) error {
	parsed, err := strconv.Atoi(string(text))
	if err == nil {
		*t = textUnmarshallerInt(parsed)
	}
	return err
}

// TestDecoder_MapUnmarshallerMap checks that a map keyed by a pointer type
// implementing encoding.TextUnmarshaler is decoded, with the key built from
// the encoded key text.
func TestDecoder_MapUnmarshallerMap(t *testing.T) {
	defer ConfigTeardown()

	// One map entry: key "1" (0x31), value "test" (0x74 0x65 0x73 0x74).
	data := []byte{0x1, 0xe, 0x2, 0x31, 0x8, 0x74, 0x65, 0x73, 0x74, 0x0}
	schema := `{"type":"map", "values": "string"}`
	dec, err := avro.NewDecoder(schema, bytes.NewReader(data))
	require.NoError(t, err)

	var got map[*textUnmarshallerInt]string
	err = dec.Decode(&got)

	require.NoError(t, err)
	want := map[textUnmarshallerInt]string{1: "test"}
	// Keys are pointers, so the maps cannot be compared directly. Pin the size
	// first: without it the loop below would pass vacuously on an empty map.
	require.Len(t, got, len(want))
	for k, v := range got {
		require.NotNil(t, k)
		wantVal, ok := want[*k]
		assert.True(t, ok)
		assert.Equal(t, wantVal, v)
	}
}

// textUnmarshallerNope is an int-backed map key whose UnmarshalText has a
// value receiver, so unmarshalling can never change the key.
type textUnmarshallerNope int

// UnmarshalText parses text as a base-10 integer and returns any parse error.
// The receiver is a copy, so the parsed value is discarded when the method
// returns.
func (t textUnmarshallerNope) UnmarshalText(text []byte) error {
	parsed, err := strconv.Atoi(string(text))
	if err == nil {
		// Only the local copy is updated; the caller's value is unaffected.
		t = textUnmarshallerNope(parsed)
	}
	return err
}

// TestDecoder_MapUnmarshallerMapImpossible decodes into a map whose key type
// has a value-receiver UnmarshalText. The key cannot be updated, so the entry
// is expected under the zero key.
func TestDecoder_MapUnmarshallerMapImpossible(t *testing.T) {
	defer ConfigTeardown()

	const schema = `{"type":"map", "values": "string"}`
	payload := []byte{0x1, 0xe, 0x2, 0x31, 0x8, 0x74, 0x65, 0x73, 0x74, 0x0}
	dec, _ := avro.NewDecoder(schema, bytes.NewReader(payload))

	var decoded map[textUnmarshallerNope]string
	require.NoError(t, dec.Decode(&decoded))

	expected := map[textUnmarshallerNope]string{0: "test"}
	assert.Equal(t, expected, decoded)
}

// textUnmarshallerError is a map key type whose UnmarshalText always fails.
type textUnmarshallerError int

// UnmarshalText ignores its input and always returns the error "test".
func (*textUnmarshallerError) UnmarshalText(_ []byte) error {
	err := errors.New("test")
	return err
}

// TestDecoder_MapUnmarshallerKeyError checks that an error returned by the
// key's UnmarshalText surfaces as a decode error.
func TestDecoder_MapUnmarshallerKeyError(t *testing.T) {
	defer ConfigTeardown()

	const schema = `{"type":"map", "values": "string"}`
	payload := []byte{0x1, 0xe, 0x2, 0x31, 0x8, 0x74, 0x65, 0x73, 0x74, 0x0}
	dec, _ := avro.NewDecoder(schema, bytes.NewReader(payload))

	var decoded map[*textUnmarshallerError]string
	require.Error(t, dec.Decode(&decoded))
}

func TestDecoder_MapInvalidKeyType(t *testing.T) {
defer ConfigTeardown()

Expand Down
76 changes: 76 additions & 0 deletions encoder_map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package avro_test

import (
"bytes"
"errors"
"strconv"
"testing"

"github.com/hamba/avro/v2"
Expand Down Expand Up @@ -113,3 +115,77 @@ func TestEncoder_MapWithMoreThanBlockLengthKeys(t *testing.T) {
return (foobar || barfoo)
})
}

// textMarshallerInt is an int-backed map key that marshals to its decimal
// text form.
type textMarshallerInt int

// MarshalText returns the base-10 representation of t; it never fails.
func (t textMarshallerInt) MarshalText() (text []byte, err error) {
	digits := strconv.AppendInt(nil, int64(t), 10)
	return digits, nil
}

// textMarshallerError is a map key type whose MarshalText always fails.
type textMarshallerError int

// MarshalText always returns nil text and the error "test".
func (textMarshallerError) MarshalText() (text []byte, err error) {
	err = errors.New("test")
	return nil, err
}

// TestEncoder_MapMarshaller checks that a map keyed by a
// encoding.TextMarshaler type is encoded with the marshalled text as the key.
func TestEncoder_MapMarshaller(t *testing.T) {
	defer ConfigTeardown()

	const schema = `{"type":"map", "values": "string"}`
	var out bytes.Buffer
	enc, err := avro.NewEncoder(schema, &out)
	require.NoError(t, err)

	input := map[textMarshallerInt]string{1: "test"}
	require.NoError(t, enc.Encode(input))

	// One map entry: key "1" (0x31), value "test" (0x74 0x65 0x73 0x74).
	expected := []byte{0x1, 0xe, 0x2, 0x31, 0x8, 0x74, 0x65, 0x73, 0x74, 0x0}
	assert.Equal(t, expected, out.Bytes())
}

// TestEncoder_MapMarshallerNil checks that encoding a map with a nil pointer
// key fails.
func TestEncoder_MapMarshallerNil(t *testing.T) {
	defer ConfigTeardown()

	schema := `{"type":"map", "values": "string"}`
	buf := bytes.NewBuffer([]byte{})
	enc, err := avro.NewEncoder(schema, buf)
	require.NoError(t, err)

	// The value type matches the schema's "string" values so that the nil key
	// is the only possible source of the expected error.
	err = enc.Encode(map[*textMarshallerError]string{
		nil: "test",
	})

	require.Error(t, err)
}

// TestEncoder_MapMarshallerKeyError checks that an error returned by the
// key's MarshalText surfaces as an encode error.
func TestEncoder_MapMarshallerKeyError(t *testing.T) {
	defer ConfigTeardown()

	schema := `{"type":"map", "values": "string"}`
	buf := bytes.NewBuffer([]byte{})
	enc, err := avro.NewEncoder(schema, buf)
	require.NoError(t, err)

	// The value type matches the schema's "string" values so that the failing
	// key marshaller is the only possible source of the expected error.
	err = enc.Encode(map[textMarshallerError]string{
		1: "test",
	})

	require.Error(t, err)
}

// TestEncoder_MapMarshallerError checks that encoding fails even though the
// key marshals cleanly.
// NOTE(review): presumably the failure comes from the int values not matching
// the schema's "string" values — confirm.
func TestEncoder_MapMarshallerError(t *testing.T) {
	defer ConfigTeardown()

	const schema = `{"type":"map", "values": "string"}`
	out := bytes.NewBuffer([]byte{})
	enc, err := avro.NewEncoder(schema, out)
	require.NoError(t, err)

	input := map[textMarshallerInt]int{1: 1}
	require.Error(t, enc.Encode(input))
}

0 comments on commit 8c60780

Please sign in to comment.