Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move generic decoding to codec level #336

Merged
merged 2 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions codec_dynamic.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ func (d *efaceDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
pObj := (*any)(ptr)
obj := *pObj
if obj == nil {
*pObj = r.ReadNext(d.schema)
*pObj = genericDecode(d.schema, r)
return
}

typ := reflect2.TypeOf(obj)
if typ.Kind() != reflect.Ptr {
*pObj = r.ReadNext(d.schema)
*pObj = genericDecode(d.schema, r)
return
}

Expand Down
135 changes: 135 additions & 0 deletions codec_generic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package avro

import (
"fmt"
"math/big"
"time"
"unsafe"

"github.com/modern-go/reflect2"
)

// genericDecode reads the next value described by schema from r into a
// dynamically chosen Go receiver and returns it as an any.
//
// The receiver is obtained from genericReceiver and populated by the decoder
// that decoderOfType returns for that receiver type. nil is returned when:
//   - no receiver exists for the schema (the error is reported on r),
//   - r carries an error after decoding, or
//   - the decoded value is itself nil.
//
// A big.Rat receiver is returned as *big.Rat.
func genericDecode(schema Schema, r *Reader) any {
	rPtr, rTyp, err := genericReceiver(schema)
	if err != nil {
		r.ReportError("Read", err.Error())
		return nil
	}

	decoderOfType(r.cfg, schema, rTyp).Decode(rPtr, r)
	if r.Error != nil {
		return nil
	}

	obj := rTyp.UnsafeIndirect(rPtr)
	if reflect2.IsNil(obj) {
		return nil
	}

	// The generic reader hands decimals back as *big.Rat, while this receiver
	// is a big.Rat value. rPtr already addresses the decoded value, so it is
	// returned directly instead of copying the struct out of obj: math/big
	// does not support shallow copies of a Rat.
	if rTyp.Type1() == ratType {
		return (*big.Rat)(rPtr)
	}

	return obj
}

// genericReceiver allocates a fresh receiver for the given schema and returns
// a pointer to it together with its reflect2 type.
//
// The receiver type is selected from the schema type and, when present, its
// logical type:
//
//	boolean                          bool
//	int                              int
//	int + date                       time.Time
//	int + time-millis                time.Duration
//	long                             int64
//	long + time-micros               time.Duration
//	long + timestamp-millis/micros   time.Time
//	float / double                   float32 / float64
//	string, enum                     string
//	bytes                            []byte
//	bytes + decimal                  *big.Rat
//	record, map, union               map[string]any
//	array                            []any
//	fixed                            value produced by byteSliceToArray
//	fixed + duration                 LogicalDuration
//	fixed + decimal                  big.Rat
//	ref                              receiver of the referenced schema
//
// An error is returned when the schema type is not handled, or when a schema
// reports the Ref or Fixed type without being a *RefSchema or *FixedSchema.
func genericReceiver(schema Schema) (unsafe.Pointer, reflect2.Type, error) {
	var ls LogicalSchema
	if lts, ok := schema.(LogicalTypeSchema); ok {
		ls = lts.Logical()
	}

	switch schema.Type() {
	case Boolean:
		var v bool
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Int:
		if ls != nil {
			switch ls.Type() {
			case Date:
				var v time.Time
				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil

			case TimeMillis:
				var v time.Duration
				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
			}
		}
		// No logical type, or one without a dedicated receiver.
		var v int
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Long:
		if ls != nil {
			switch ls.Type() {
			case TimeMicros:
				var v time.Duration
				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil

			case TimestampMillis, TimestampMicros:
				var v time.Time
				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
			}
		}
		// No logical type, or one without a dedicated receiver.
		var v int64
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Float:
		var v float32
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Double:
		var v float64
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case String:
		var v string
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Bytes:
		if ls != nil && ls.Type() == Decimal {
			var v *big.Rat
			return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
		}
		var v []byte
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Record:
		var v map[string]any
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Ref:
		// Guard the assertion: a schema may report the Ref type without being
		// a *RefSchema, and that must surface as an error, not a panic.
		ref, ok := schema.(*RefSchema)
		if !ok {
			return nil, nil, genericReceiverErr(schema, ls)
		}
		return genericReceiver(ref.Schema())
	case Enum:
		var v string
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Array:
		v := make([]any, 0)
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Map:
		var v map[string]any
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Union:
		var v map[string]any
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Fixed:
		// Same guard as for Ref: never panic on an unexpected concrete type.
		fixed, ok := schema.(*FixedSchema)
		if !ok {
			return nil, nil, genericReceiverErr(schema, ls)
		}
		if fls := fixed.Logical(); fls != nil {
			switch fls.Type() {
			case Duration:
				var v LogicalDuration
				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
			case Decimal:
				var v big.Rat
				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
			}
		}
		v := byteSliceToArray(make([]byte, fixed.Size()), fixed.Size())
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	default:
		return nil, nil, genericReceiverErr(schema, ls)
	}
}

// genericReceiverErr builds the "receiver not found" error for schema. The
// schema is named by its type, followed by "." and the logical type when ls is
// non-nil. The name is built here so that it is only computed on failure.
func genericReceiverErr(schema Schema, ls LogicalSchema) error {
	name := string(schema.Type())
	if ls != nil {
		name += "." + string(ls.Type())
	}
	return fmt.Errorf("dynamic receiver not found for schema: %v", name)
}
221 changes: 221 additions & 0 deletions codec_generic_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
package avro

import (
"bytes"
"math/big"
"strconv"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestGenericDecode is a table-driven test of genericDecode. Each case parses
// its schema, decodes the raw bytes in data through a fresh Reader, and then
// checks both the error state left on the Reader and the decoded value.
func TestGenericDecode(t *testing.T) {
	tests := []struct {
		name    string
		data    []byte
		schema  string
		want    any
		wantErr require.ErrorAssertionFunc
	}{
		{
			name:    "Bool",
			data:    []byte{0x01},
			schema:  "boolean",
			want:    true,
			wantErr: require.NoError,
		},
		{
			name:    "Int",
			data:    []byte{0x36},
			schema:  "int",
			want:    27,
			wantErr: require.NoError,
		},
		{
			name:    "Int Date",
			data:    []byte{0xAE, 0x9D, 0x02},
			schema:  `{"type":"int","logicalType":"date"}`,
			want:    time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC),
			wantErr: require.NoError,
		},
		{
			name:    "Int Time-Millis",
			data:    []byte{0xAA, 0xB4, 0xDE, 0x75},
			schema:  `{"type":"int","logicalType":"time-millis"}`,
			want:    123456789 * time.Millisecond,
			wantErr: require.NoError,
		},
		{
			name:    "Long",
			data:    []byte{0x36},
			schema:  "long",
			want:    int64(27),
			wantErr: require.NoError,
		},
		{
			name:    "Long Time-Micros",
			data:    []byte{0x86, 0xEA, 0xC8, 0xE9, 0x97, 0x07},
			schema:  `{"type":"long","logicalType":"time-micros"}`,
			want:    123456789123 * time.Microsecond,
			wantErr: require.NoError,
		},
		{
			name:    "Long Timestamp-Millis",
			data:    []byte{0x90, 0xB2, 0xAE, 0xC3, 0xEC, 0x5B},
			schema:  `{"type":"long","logicalType":"timestamp-millis"}`,
			want:    time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC),
			wantErr: require.NoError,
		},
		{
			name:    "Long Timestamp-Micros",
			data:    []byte{0x80, 0xCD, 0xB7, 0xA2, 0xEE, 0xC7, 0xCD, 0x05},
			schema:  `{"type":"long","logicalType":"timestamp-micros"}`,
			want:    time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC),
			wantErr: require.NoError,
		},
		{
			name:    "Float",
			data:    []byte{0x33, 0x33, 0x93, 0x3F},
			schema:  "float",
			want:    float32(1.15),
			wantErr: require.NoError,
		},
		{
			name:    "Double",
			data:    []byte{0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0xF2, 0x3F},
			schema:  "double",
			want:    float64(1.15),
			wantErr: require.NoError,
		},
		{
			name:    "String",
			data:    []byte{0x06, 0x66, 0x6F, 0x6F},
			schema:  "string",
			want:    "foo",
			wantErr: require.NoError,
		},
		{
			name:    "Bytes",
			data:    []byte{0x08, 0xEC, 0xAB, 0x44, 0x00},
			schema:  "bytes",
			want:    []byte{0xEC, 0xAB, 0x44, 0x00},
			wantErr: require.NoError,
		},
		{
			name:    "Bytes Decimal",
			data:    []byte{0x6, 0x00, 0x87, 0x78},
			schema:  `{"type":"bytes","logicalType":"decimal","precision":4,"scale":2}`,
			want:    big.NewRat(1734, 5),
			wantErr: require.NoError,
		},
		{
			name:    "Record",
			data:    []byte{0x36, 0x06, 0x66, 0x6f, 0x6f},
			schema:  `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": "long"}, {"name": "b", "type": "string"}]}`,
			want:    map[string]any{"a": int64(27), "b": "foo"},
			wantErr: require.NoError,
		},
		{
			name:    "Ref",
			data:    []byte{0x36, 0x06, 0x66, 0x6f, 0x6f, 0x36, 0x06, 0x66, 0x6f, 0x6f},
			schema:  `{"type":"record","name":"parent","fields":[{"name":"a","type":{"type":"record","name":"test","fields":[{"name":"a","type":"long"},{"name":"b","type":"string"}]}},{"name":"b","type":"test"}]}`,
			want:    map[string]any{"a": map[string]any{"a": int64(27), "b": "foo"}, "b": map[string]any{"a": int64(27), "b": "foo"}},
			wantErr: require.NoError,
		},
		{
			name:    "Array",
			data:    []byte{0x04, 0x36, 0x38, 0x0},
			schema:  `{"type":"array", "items": "int"}`,
			want:    []any{27, 28},
			wantErr: require.NoError,
		},
		{
			name:    "Map",
			data:    []byte{0x02, 0x06, 0x66, 0x6F, 0x6F, 0x06, 0x66, 0x6F, 0x6F, 0x00},
			schema:  `{"type":"map", "values": "string"}`,
			want:    map[string]any{"foo": "foo"},
			wantErr: require.NoError,
		},
		{
			name:    "Enum",
			data:    []byte{0x02},
			schema:  `{"type":"enum", "name": "test", "symbols": ["foo", "bar"]}`,
			want:    "bar",
			wantErr: require.NoError,
		},
		{
			name:    "Enum Invalid Symbol",
			data:    []byte{0x04},
			schema:  `{"type":"enum", "name": "test", "symbols": ["foo", "bar"]}`,
			want:    nil,
			wantErr: require.Error,
		},
		{
			name:    "Union",
			data:    []byte{0x02, 0x06, 0x66, 0x6F, 0x6F},
			schema:  `["null", "string"]`,
			want:    map[string]any{"string": "foo"},
			wantErr: require.NoError,
		},
		{
			name:    "Union Nil",
			data:    []byte{0x00},
			schema:  `["null", "string"]`,
			want:    nil,
			wantErr: require.NoError,
		},
		{
			name:    "Union Named",
			data:    []byte{0x02, 0x02},
			schema:  `["null", {"type":"enum", "name": "test", "symbols": ["foo", "bar"]}]`,
			want:    map[string]any{"test": "bar"},
			wantErr: require.NoError,
		},
		{
			name:    "Union Invalid Schema",
			data:    []byte{0x04},
			schema:  `["null", "string"]`,
			want:    nil,
			wantErr: require.Error,
		},
		{
			name:    "Fixed",
			data:    []byte{0x66, 0x6F, 0x6F, 0x66, 0x6F, 0x6F},
			schema:  `{"type":"fixed", "name": "test", "size": 6}`,
			want:    [6]byte{'f', 'o', 'o', 'f', 'o', 'o'},
			wantErr: require.NoError,
		},
		{
			name:    "Fixed Decimal",
			data:    []byte{0x00, 0x00, 0x00, 0x00, 0x87, 0x78},
			schema:  `{"type":"fixed", "name": "test", "size": 6,"logicalType":"decimal","precision":4,"scale":2}`,
			want:    big.NewRat(1734, 5),
			wantErr: require.NoError,
		},
	}

	for i, test := range tests {
		test := test
		// Label each subtest with its index and descriptive name so that a
		// failure identifies the case directly.
		t.Run(strconv.Itoa(i)+"_"+test.name, func(t *testing.T) {
			schema := MustParse(test.schema)
			r := NewReader(bytes.NewReader(test.data), 10)

			got := genericDecode(schema, r)

			test.wantErr(t, r.Error)
			assert.Equal(t, test.want, got)
		})
	}
}

// TestGenericDecode_UnsupportedType checks that decoding against a primitive
// schema built with the made-up type "test" leaves an error on the Reader.
func TestGenericDecode_UnsupportedType(t *testing.T) {
	input := bytes.NewReader([]byte{0x01})
	reader := NewReader(input, 10)
	unknown := NewPrimitiveSchema(Type("test"), nil)

	// Only the error recorded on the reader matters; the value is discarded.
	genericDecode(unknown, reader)

	assert.Error(t, reader.Error)
}
2 changes: 1 addition & 1 deletion codec_union.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ func (d *unionResolvedDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
// We cannot resolve this, set it to the map type
name := schemaTypeName(schema)
obj := map[string]any{}
obj[name] = r.ReadNext(schema)
obj[name] = genericDecode(schema, r)

*pObj = obj
return
Expand Down
Loading