Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for recursive schemas & structs #413

Merged
merged 4 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,6 @@ For security reasons, the configuration `Config.MaxByteSliceSize` restricts the
by the `Reader`. The default maximum size is `1MiB` and is configurable. This is required to stop untrusted input from consuming all memory and
crashing the application. Should this not be need, setting a negative number will disable the behaviour.

### Recursive Structs

At this moment recursive structs are not supported. It is planned for the future.

## Benchmark

Benchmark source code can be found at: [https://github.com/nrwiersma/avro-benchmarks](https://github.com/nrwiersma/avro-benchmarks)
Expand Down
116 changes: 78 additions & 38 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,48 +71,88 @@ func (c *frozenConfig) DecoderOf(schema Schema, typ reflect2.Type) ValDecoder {
}

ptrType := typ.(*reflect2.UnsafePtrType)
decoder = decoderOfType(c, schema, ptrType.Elem())
decoder = decoderOfType(newDecoderContext(c), schema, ptrType.Elem())
c.addDecoderToCache(schema.CacheFingerprint(), rtype, decoder)
return decoder
}

func decoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
if dec := createDecoderOfMarshaler(cfg, schema, typ); dec != nil {
type deferDecoder struct {
decoder ValDecoder
}

func (d *deferDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
d.decoder.Decode(ptr, r)
}

type deferEncoder struct {
encoder ValEncoder
}

func (d *deferEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
d.encoder.Encode(ptr, w)
}

type decoderContext struct {
cfg *frozenConfig
decoders map[cacheKey]ValDecoder
}

func newDecoderContext(cfg *frozenConfig) *decoderContext {
return &decoderContext{
cfg: cfg,
decoders: make(map[cacheKey]ValDecoder),
}
}

type encoderContext struct {
cfg *frozenConfig
encoders map[cacheKey]ValEncoder
}

func newEncoderContext(cfg *frozenConfig) *encoderContext {
return &encoderContext{
cfg: cfg,
encoders: make(map[cacheKey]ValEncoder),
}
}

func decoderOfType(d *decoderContext, schema Schema, typ reflect2.Type) ValDecoder {
if dec := createDecoderOfMarshaler(schema, typ); dec != nil {
return dec
}

// Handle eface case when it isnt a union
// Handle eface (empty interface) case when it isn't a union
if typ.Kind() == reflect.Interface && schema.Type() != Union {
if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok {
return newEfaceDecoder(cfg, schema)
return newEfaceDecoder(d, schema)
}
}

switch schema.Type() {
case String, Bytes, Int, Long, Float, Double, Boolean:
return createDecoderOfNative(schema.(*PrimitiveSchema), typ)

case Record:
return createDecoderOfRecord(cfg, schema, typ)

key := cacheKey{fingerprint: schema.CacheFingerprint(), rtype: typ.RType()}
defDec := &deferDecoder{}
d.decoders[key] = defDec
defDec.decoder = createDecoderOfRecord(d, schema.(*RecordSchema), typ)
return defDec.decoder
case Ref:
return decoderOfType(cfg, schema.(*RefSchema).Schema(), typ)

key := cacheKey{fingerprint: schema.(*RefSchema).Schema().CacheFingerprint(), rtype: typ.RType()}
if dec, f := d.decoders[key]; f {
return dec
}
return decoderOfType(d, schema.(*RefSchema).Schema(), typ)
case Enum:
return createDecoderOfEnum(schema, typ)

return createDecoderOfEnum(schema.(*EnumSchema), typ)
case Array:
return createDecoderOfArray(cfg, schema, typ)

return createDecoderOfArray(d, schema.(*ArraySchema), typ)
case Map:
return createDecoderOfMap(cfg, schema, typ)

return createDecoderOfMap(d, schema.(*MapSchema), typ)
case Union:
return createDecoderOfUnion(cfg, schema, typ)

return createDecoderOfUnion(d, schema.(*UnionSchema), typ)
case Fixed:
return createDecoderOfFixed(schema, typ)

return createDecoderOfFixed(schema.(*FixedSchema), typ)
default:
// It is impossible to get here with a valid schema
return &errorDecoder{err: fmt.Errorf("avro: schema type %s is unsupported", schema.Type())}
Expand All @@ -130,7 +170,7 @@ func (c *frozenConfig) EncoderOf(schema Schema, typ reflect2.Type) ValEncoder {
return encoder
}

encoder = encoderOfType(c, schema, typ)
encoder = encoderOfType(newEncoderContext(c), schema, typ)
if typ.LikePtr() {
encoder = &onePtrEncoder{encoder}
}
Expand All @@ -146,8 +186,8 @@ func (e *onePtrEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
e.enc.Encode(noescape(unsafe.Pointer(&ptr)), w)
}

func encoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
if enc := createEncoderOfMarshaler(cfg, schema, typ); enc != nil {
func encoderOfType(e *encoderContext, schema Schema, typ reflect2.Type) ValEncoder {
if enc := createEncoderOfMarshaler(schema, typ); enc != nil {
return enc
}

Expand All @@ -158,28 +198,28 @@ func encoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncod
switch schema.Type() {
case String, Bytes, Int, Long, Float, Double, Boolean, Null:
return createEncoderOfNative(schema, typ)

case Record:
return createEncoderOfRecord(cfg, schema, typ)

key := cacheKey{fingerprint: schema.Fingerprint(), rtype: typ.RType()}
defEnc := &deferEncoder{}
e.encoders[key] = defEnc
defEnc.encoder = createEncoderOfRecord(e, schema.(*RecordSchema), typ)
return defEnc.encoder
case Ref:
return encoderOfType(cfg, schema.(*RefSchema).Schema(), typ)

key := cacheKey{fingerprint: schema.(*RefSchema).Schema().Fingerprint(), rtype: typ.RType()}
if enc, f := e.encoders[key]; f {
return enc
}
return encoderOfType(e, schema.(*RefSchema).Schema(), typ)
case Enum:
return createEncoderOfEnum(schema, typ)

return createEncoderOfEnum(schema.(*EnumSchema), typ)
case Array:
return createEncoderOfArray(cfg, schema, typ)

return createEncoderOfArray(e, schema.(*ArraySchema), typ)
case Map:
return createEncoderOfMap(cfg, schema, typ)

return createEncoderOfMap(e, schema.(*MapSchema), typ)
case Union:
return createEncoderOfUnion(cfg, schema, typ)

return createEncoderOfUnion(e, schema.(*UnionSchema), typ)
case Fixed:
return createEncoderOfFixed(schema, typ)

return createEncoderOfFixed(schema.(*FixedSchema), typ)
default:
// It is impossible to get here with a valid schema
return &errorEncoder{err: fmt.Errorf("avro: schema type %s is unsupported", schema.Type())}
Expand Down
20 changes: 9 additions & 11 deletions codec_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,25 @@ import (
"github.com/modern-go/reflect2"
)

func createDecoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
func createDecoderOfArray(d *decoderContext, schema *ArraySchema, typ reflect2.Type) ValDecoder {
if typ.Kind() == reflect.Slice {
return decoderOfArray(cfg, schema, typ)
return decoderOfArray(d, schema, typ)
}

return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
}

func createEncoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
func createEncoderOfArray(e *encoderContext, schema *ArraySchema, typ reflect2.Type) ValEncoder {
if typ.Kind() == reflect.Slice {
return encoderOfArray(cfg, schema, typ)
return encoderOfArray(e, schema, typ)
}

return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
}

func decoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
arr := schema.(*ArraySchema)
func decoderOfArray(d *decoderContext, arr *ArraySchema, typ reflect2.Type) ValDecoder {
sliceType := typ.(*reflect2.UnsafeSliceType)
decoder := decoderOfType(cfg, arr.Items(), sliceType.Elem())
decoder := decoderOfType(d, arr.Items(), sliceType.Elem())

return &arrayDecoder{typ: sliceType, decoder: decoder}
}
Expand Down Expand Up @@ -74,13 +73,12 @@ func (d *arrayDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
}
}

func encoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
arr := schema.(*ArraySchema)
func encoderOfArray(e *encoderContext, arr *ArraySchema, typ reflect2.Type) ValEncoder {
sliceType := typ.(*reflect2.UnsafeSliceType)
encoder := encoderOfType(cfg, arr.Items(), sliceType.Elem())
encoder := encoderOfType(e, arr.Items(), sliceType.Elem())

return &arrayEncoder{
blockLength: cfg.getBlockLength(),
blockLength: e.cfg.getBlockLength(),
typ: sliceType,
encoder: encoder,
}
Expand Down
7 changes: 4 additions & 3 deletions codec_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ import (
"github.com/modern-go/reflect2"
)

func createDefaultDecoder(cfg *frozenConfig, field *Field, typ reflect2.Type) ValDecoder {
func createDefaultDecoder(d *decoderContext, field *Field, typ reflect2.Type) ValDecoder {
cfg := d.cfg
fn := func(def any) ([]byte, error) {
defaultType := reflect2.TypeOf(def)
if defaultType == nil {
defaultType = reflect2.TypeOf((*null)(nil))
}
defaultEncoder := encoderOfType(cfg, field.Type(), defaultType)
defaultEncoder := encoderOfType(newEncoderContext(cfg), field.Type(), defaultType)
if defaultType.LikePtr() {
defaultEncoder = &onePtrEncoder{defaultEncoder}
}
Expand All @@ -37,7 +38,7 @@ func createDefaultDecoder(cfg *frozenConfig, field *Field, typ reflect2.Type) Va
}
return &defaultDecoder{
data: b,
decoder: decoderOfType(cfg, field.Type(), typ),
decoder: decoderOfType(d, field.Type(), typ),
}
}

Expand Down
1 change: 0 additions & 1 deletion codec_default_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,6 @@ func TestDecoder_DefaultEnum(t *testing.T) {

require.NoError(t, err)
assert.Equal(t, TestRecord{B: "bar", A: "foo"}, got)

})

t.Run("TextUnmarshaler", func(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions codec_dynamic.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ type efaceDecoder struct {
dec ValDecoder
}

func newEfaceDecoder(cfg *frozenConfig, schema Schema) *efaceDecoder {
func newEfaceDecoder(d *decoderContext, schema Schema) *efaceDecoder {
typ, _ := genericReceiver(schema)
dec := decoderOfType(cfg, schema, typ)
dec := decoderOfType(d, schema, typ)

return &efaceDecoder{
schema: schema,
Expand Down
16 changes: 8 additions & 8 deletions codec_enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,27 @@ import (
"github.com/modern-go/reflect2"
)

func createDecoderOfEnum(schema Schema, typ reflect2.Type) ValDecoder {
func createDecoderOfEnum(schema *EnumSchema, typ reflect2.Type) ValDecoder {
switch {
case typ.Kind() == reflect.String:
return &enumCodec{enum: schema.(*EnumSchema)}
return &enumCodec{enum: schema}
case typ.Implements(textUnmarshalerType):
return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema)}
return &enumTextMarshalerCodec{typ: typ, enum: schema}
case reflect2.PtrTo(typ).Implements(textUnmarshalerType):
return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema), ptr: true}
return &enumTextMarshalerCodec{typ: typ, enum: schema, ptr: true}
}

return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
}

func createEncoderOfEnum(schema Schema, typ reflect2.Type) ValEncoder {
func createEncoderOfEnum(schema *EnumSchema, typ reflect2.Type) ValEncoder {
switch {
case typ.Kind() == reflect.String:
return &enumCodec{enum: schema.(*EnumSchema)}
return &enumCodec{enum: schema}
case typ.Implements(textMarshalerType):
return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema)}
return &enumTextMarshalerCodec{typ: typ, enum: schema}
case reflect2.PtrTo(typ).Implements(textMarshalerType):
return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema), ptr: true}
return &enumTextMarshalerCodec{typ: typ, enum: schema, ptr: true}
}

return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
Expand Down
13 changes: 4 additions & 9 deletions codec_fixed.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ import (
"github.com/modern-go/reflect2"
)

func createDecoderOfFixed(schema Schema, typ reflect2.Type) ValDecoder {
fixed := schema.(*FixedSchema)
func createDecoderOfFixed(fixed *FixedSchema, typ reflect2.Type) ValDecoder {
switch typ.Kind() {
case reflect.Array:
arrayType := typ.(reflect2.ArrayType)
Expand All @@ -21,7 +20,6 @@ func createDecoderOfFixed(schema Schema, typ reflect2.Type) ValDecoder {
return &fixedCodec{arrayType: typ.(*reflect2.UnsafeArrayType)}

case reflect.Uint64:
fixed := schema.(*FixedSchema)
if fixed.Size() != 8 {
break
}
Expand All @@ -44,23 +42,20 @@ func createDecoderOfFixed(schema Schema, typ reflect2.Type) ValDecoder {
}

return &errorDecoder{
err: fmt.Errorf("avro: %s is unsupported for Avro %s, size=%d", typ.String(), schema.Type(), fixed.Size()),
err: fmt.Errorf("avro: %s is unsupported for Avro %s, size=%d", typ.String(), fixed.Type(), fixed.Size()),
}
}

func createEncoderOfFixed(schema Schema, typ reflect2.Type) ValEncoder {
fixed := schema.(*FixedSchema)
func createEncoderOfFixed(fixed *FixedSchema, typ reflect2.Type) ValEncoder {
switch typ.Kind() {
case reflect.Array:
arrayType := typ.(reflect2.ArrayType)
fixed := schema.(*FixedSchema)
if arrayType.Elem().Kind() != reflect.Uint8 || arrayType.Len() != fixed.Size() {
break
}
return &fixedCodec{arrayType: typ.(*reflect2.UnsafeArrayType)}

case reflect.Uint64:
fixed := schema.(*FixedSchema)
if fixed.Size() != 8 {
break
}
Expand Down Expand Up @@ -92,7 +87,7 @@ func createEncoderOfFixed(schema Schema, typ reflect2.Type) ValEncoder {
}

return &errorEncoder{
err: fmt.Errorf("avro: %s is unsupported for Avro %s, size=%d", typ.String(), schema.Type(), fixed.Size()),
err: fmt.Errorf("avro: %s is unsupported for Avro %s, size=%d", typ.String(), fixed.Type(), fixed.Size()),
}
}

Expand Down
3 changes: 1 addition & 2 deletions codec_generic_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ func TestGenericDecode(t *testing.T) {
want any
wantErr require.ErrorAssertionFunc
}{

{
name: "Bool",
data: []byte{0x01},
Expand Down Expand Up @@ -228,7 +227,7 @@ func TestGenericDecode(t *testing.T) {

typ, err := genericReceiver(schema)
require.NoError(t, err)
dec := decoderOfType(DefaultConfig.(*frozenConfig), schema, typ)
dec := decoderOfType(newDecoderContext(DefaultConfig.(*frozenConfig)), schema, typ)

got := genericDecode(typ, dec, r)

Expand Down
Loading
Loading