From 939e15a40831c97618d1942f6722effc7d8baa17 Mon Sep 17 00:00:00 2001 From: adrianiacobghiula <2491756+adrianiacobghiula@users.noreply.github.com> Date: Mon, 24 Jun 2024 10:23:06 +0200 Subject: [PATCH] add support for recursive schemas & structs --- .gitignore | 1 + README.md | 4 -- codec.go | 116 ++++++++++++++++++++++----------- codec_array.go | 20 +++--- codec_default.go | 7 +- codec_dynamic.go | 4 +- codec_enum.go | 16 ++--- codec_fixed.go | 13 ++-- codec_generic_internal_test.go | 2 +- codec_map.go | 36 +++++----- codec_marshaler.go | 4 +- codec_ptr.go | 8 +-- codec_record.go | 48 +++++++------- codec_union.go | 71 ++++++++++---------- config.go | 5 +- decoder_array_test.go | 35 ++++++++++ decoder_map_test.go | 33 ++++++++++ decoder_union_test.go | 32 +++++++++ encoder_array_test.go | 36 ++++++++++ encoder_map_test.go | 34 ++++++++++ encoder_union_test.go | 32 +++++++++ example_test.go | 47 +++++++++++++ 22 files changed, 437 insertions(+), 167 deletions(-) diff --git a/.gitignore b/.gitignore index e69de29b..723ef36f 100644 --- a/.gitignore +++ b/.gitignore @@ -0,0 +1 @@ +.idea \ No newline at end of file diff --git a/README.md b/README.md index f90c51cf..7aa0d21f 100644 --- a/README.md +++ b/README.md @@ -142,10 +142,6 @@ For security reasons, the configuration `Config.MaxByteSliceSize` restricts the by the `Reader`. The default maximum size is `1MiB` and is configurable. This is required to stop untrusted input from consuming all memory and crashing the application. Should this not be need, setting a negative number will disable the behaviour. -### Recursive Structs - -At this moment recursive structs are not supported. It is planned for the future. - ## Benchmark Benchmark source code can be found at: [https://github.com/nrwiersma/avro-benchmarks](https://github.com/nrwiersma/avro-benchmarks) diff --git a/codec.go b/codec.go index 7657f86f..370a0eae 100644 --- a/codec.go +++ b/codec.go @@ -71,48 +71,88 @@ func (c *frozenConfig) DecoderOf(schema Schema, typ reflect2.Type) ValDecoder { } ptrType := typ.(*reflect2.UnsafePtrType) - decoder = decoderOfType(c, schema, ptrType.Elem()) + decoder = decoderOfType(newDecoderContext(c), schema, ptrType.Elem()) c.addDecoderToCache(schema.CacheFingerprint(), rtype, decoder) return decoder } -func decoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { - if dec := createDecoderOfMarshaler(cfg, schema, typ); dec != nil { +type deferDecoder struct { + decoder ValDecoder +} + +func (d *deferDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + d.decoder.Decode(ptr, r) +} + +type deferEncoder struct { + encoder ValEncoder +} + +func (d *deferEncoder) Encode(ptr unsafe.Pointer, w *Writer) { + d.encoder.Encode(ptr, w) +} + +type decoderContext struct { + cfg *frozenConfig + decoders map[cacheKey]ValDecoder +} + +func newDecoderContext(cfg *frozenConfig) *decoderContext { + return &decoderContext{ + cfg: cfg, + decoders: make(map[cacheKey]ValDecoder), + } +} + +type encoderContext struct { + cfg *frozenConfig + encoders map[cacheKey]ValEncoder +} + +func newEncoderContext(cfg *frozenConfig) *encoderContext { + return &encoderContext{ + cfg: cfg, + encoders: make(map[cacheKey]ValEncoder), + } +} + +func decoderOfType(d *decoderContext, schema Schema, typ reflect2.Type) ValDecoder { + if dec := createDecoderOfMarshaler(schema, typ); dec != nil { return dec } - // Handle eface case when it isnt a union + // Handle eface (empty interface) case when it isn't a union if typ.Kind() == reflect.Interface && schema.Type() != Union { if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok { - return newEfaceDecoder(cfg, schema) + return newEfaceDecoder(d, schema) } } switch schema.Type() { case String, Bytes, Int, Long, Float, Double, Boolean: return createDecoderOfNative(schema.(*PrimitiveSchema), typ) - case Record: - return createDecoderOfRecord(cfg, schema, typ) - + key := cacheKey{fingerprint: schema.Fingerprint(), rtype: typ.RType()} + defDec := &deferDecoder{} + d.decoders[key] = defDec + defDec.decoder = createDecoderOfRecord(d, schema.(*RecordSchema), typ) + return defDec.decoder case Ref: - return decoderOfType(cfg, schema.(*RefSchema).Schema(), typ) - + key := cacheKey{fingerprint: schema.(*RefSchema).Schema().Fingerprint(), rtype: typ.RType()} + if dec, f := d.decoders[key]; f { + return dec + } + return decoderOfType(d, schema.(*RefSchema).Schema(), typ) case Enum: - return createDecoderOfEnum(schema, typ) - + return createDecoderOfEnum(schema.(*EnumSchema), typ) case Array: - return createDecoderOfArray(cfg, schema, typ) - + return createDecoderOfArray(d, schema.(*ArraySchema), typ) case Map: - return createDecoderOfMap(cfg, schema, typ) - + return createDecoderOfMap(d, schema.(*MapSchema), typ) case Union: - return createDecoderOfUnion(cfg, schema, typ) - + return createDecoderOfUnion(d, schema.(*UnionSchema), typ) case Fixed: - return createDecoderOfFixed(schema, typ) - + return createDecoderOfFixed(schema.(*FixedSchema), typ) default: // It is impossible to get here with a valid schema return &errorDecoder{err: fmt.Errorf("avro: schema type %s is unsupported", schema.Type())} @@ -130,7 +170,7 @@ func (c *frozenConfig) EncoderOf(schema Schema, typ reflect2.Type) ValEncoder { return encoder } - encoder = encoderOfType(c, schema, typ) + encoder = encoderOfType(newEncoderContext(c), schema, typ) if typ.LikePtr() { encoder = &onePtrEncoder{encoder} } @@ -146,8 +186,8 @@ func (e *onePtrEncoder) Encode(ptr unsafe.Pointer, w *Writer) { e.enc.Encode(noescape(unsafe.Pointer(&ptr)), w) } -func encoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { - if enc := createEncoderOfMarshaler(cfg, schema, typ); enc != nil { +func encoderOfType(e *encoderContext, schema Schema, typ reflect2.Type) ValEncoder { + if enc := createEncoderOfMarshaler(schema, typ); enc != nil { return enc } @@ -158,28 +198,28 @@ func encoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncod switch schema.Type() { case String, Bytes, Int, Long, Float, Double, Boolean, Null: return createEncoderOfNative(schema, typ) - case Record: - return createEncoderOfRecord(cfg, schema, typ) - + key := cacheKey{fingerprint: schema.Fingerprint(), rtype: typ.RType()} + defEnc := &deferEncoder{} + e.encoders[key] = defEnc + defEnc.encoder = createEncoderOfRecord(e, schema.(*RecordSchema), typ) + return defEnc.encoder case Ref: - return encoderOfType(cfg, schema.(*RefSchema).Schema(), typ) - + key := cacheKey{fingerprint: schema.(*RefSchema).Schema().Fingerprint(), rtype: typ.RType()} + if enc, f := e.encoders[key]; f { + return enc + } + return encoderOfType(e, schema.(*RefSchema).Schema(), typ) case Enum: - return createEncoderOfEnum(schema, typ) - + return createEncoderOfEnum(schema.(*EnumSchema), typ) case Array: - return createEncoderOfArray(cfg, schema, typ) - + return createEncoderOfArray(e, schema.(*ArraySchema), typ) case Map: - return createEncoderOfMap(cfg, schema, typ) - + return createEncoderOfMap(e, schema.(*MapSchema), typ) case Union: - return createEncoderOfUnion(cfg, schema, typ) - + return createEncoderOfUnion(e, schema.(*UnionSchema), typ) case Fixed: - return createEncoderOfFixed(schema, typ) - + return createEncoderOfFixed(schema.(*FixedSchema), typ) default: // It is impossible to get here with a valid schema return &errorEncoder{err: fmt.Errorf("avro: schema type %s is unsupported", schema.Type())} diff --git a/codec_array.go b/codec_array.go index 349b4a41..0b412d93 100644 --- a/codec_array.go +++ b/codec_array.go @@ -10,26 +10,25 @@ import ( "github.com/modern-go/reflect2" ) -func createDecoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func createDecoderOfArray(d *decoderContext, schema *ArraySchema, typ reflect2.Type) ValDecoder { if typ.Kind() == reflect.Slice { - return decoderOfArray(cfg, schema, typ) + return decoderOfArray(d, schema, typ) } return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} } -func createEncoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func createEncoderOfArray(e *encoderContext, schema *ArraySchema, typ reflect2.Type) ValEncoder { if typ.Kind() == reflect.Slice { - return encoderOfArray(cfg, schema, typ) + return encoderOfArray(e, schema, typ) } return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} } -func decoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { - arr := schema.(*ArraySchema) +func decoderOfArray(d *decoderContext, arr *ArraySchema, typ reflect2.Type) ValDecoder { sliceType := typ.(*reflect2.UnsafeSliceType) - decoder := decoderOfType(cfg, arr.Items(), sliceType.Elem()) + decoder := decoderOfType(d, arr.Items(), sliceType.Elem()) return &arrayDecoder{typ: sliceType, decoder: decoder} } @@ -74,13 +73,12 @@ func (d *arrayDecoder) Decode(ptr unsafe.Pointer, r *Reader) { } } -func encoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { - arr := schema.(*ArraySchema) +func encoderOfArray(e *encoderContext, arr *ArraySchema, typ reflect2.Type) ValEncoder { sliceType := typ.(*reflect2.UnsafeSliceType) - encoder := encoderOfType(cfg, arr.Items(), sliceType.Elem()) + encoder := encoderOfType(e, arr.Items(), sliceType.Elem()) return &arrayEncoder{ - blockLength: cfg.getBlockLength(), + blockLength: e.cfg.getBlockLength(), typ: sliceType, encoder: encoder, } diff --git a/codec_default.go b/codec_default.go index 5225c616..a6980eda 100644 --- a/codec_default.go +++ b/codec_default.go @@ -7,13 +7,14 @@ import ( "github.com/modern-go/reflect2" ) -func createDefaultDecoder(cfg *frozenConfig, field *Field, typ reflect2.Type) ValDecoder { +func createDefaultDecoder(d *decoderContext, field *Field, typ reflect2.Type) ValDecoder { + var cfg = d.cfg fn := func(def any) ([]byte, error) { defaultType := reflect2.TypeOf(def) if defaultType == nil { defaultType = reflect2.TypeOf((*null)(nil)) } - defaultEncoder := encoderOfType(cfg, field.Type(), defaultType) + defaultEncoder := encoderOfType(newEncoderContext(cfg), field.Type(), defaultType) if defaultType.LikePtr() { defaultEncoder = &onePtrEncoder{defaultEncoder} } @@ -37,7 +38,7 @@ func createDefaultDecoder(cfg *frozenConfig, field *Field, typ reflect2.Type) Va } return &defaultDecoder{ data: b, - decoder: decoderOfType(cfg, field.Type(), typ), + decoder: decoderOfType(d, field.Type(), typ), } } diff --git a/codec_dynamic.go b/codec_dynamic.go index 229079fc..f14a04ee 100644 --- a/codec_dynamic.go +++ b/codec_dynamic.go @@ -13,9 +13,9 @@ type efaceDecoder struct { dec ValDecoder } -func newEfaceDecoder(cfg *frozenConfig, schema Schema) *efaceDecoder { +func newEfaceDecoder(d *decoderContext, schema Schema) *efaceDecoder { typ, _ := genericReceiver(schema) - dec := decoderOfType(cfg, schema, typ) + dec := decoderOfType(d, schema, typ) return &efaceDecoder{ schema: schema, diff --git a/codec_enum.go b/codec_enum.go index 8f23eb6e..65ab4535 100644 --- a/codec_enum.go +++ b/codec_enum.go @@ -10,27 +10,27 @@ import ( "github.com/modern-go/reflect2" ) -func createDecoderOfEnum(schema Schema, typ reflect2.Type) ValDecoder { +func createDecoderOfEnum(schema *EnumSchema, typ reflect2.Type) ValDecoder { switch { case typ.Kind() == reflect.String: - return &enumCodec{enum: schema.(*EnumSchema)} + return &enumCodec{enum: schema} case typ.Implements(textUnmarshalerType): - return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema)} + return &enumTextMarshalerCodec{typ: typ, enum: schema} case reflect2.PtrTo(typ).Implements(textUnmarshalerType): - return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema), ptr: true} + return &enumTextMarshalerCodec{typ: typ, enum: schema, ptr: true} } return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} } -func createEncoderOfEnum(schema Schema, typ reflect2.Type) ValEncoder { +func createEncoderOfEnum(schema *EnumSchema, typ reflect2.Type) ValEncoder { switch { case typ.Kind() == reflect.String: - return &enumCodec{enum: schema.(*EnumSchema)} + return &enumCodec{enum: schema} case typ.Implements(textMarshalerType): - return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema)} + return &enumTextMarshalerCodec{typ: typ, enum: schema} case reflect2.PtrTo(typ).Implements(textMarshalerType): - return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema), ptr: true} + return &enumTextMarshalerCodec{typ: typ, enum: schema, ptr: true} } return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} diff --git a/codec_fixed.go b/codec_fixed.go index 1467f576..d8b11a51 100644 --- a/codec_fixed.go +++ b/codec_fixed.go @@ -10,8 +10,7 @@ import ( "github.com/modern-go/reflect2" ) -func createDecoderOfFixed(schema Schema, typ reflect2.Type) ValDecoder { - fixed := schema.(*FixedSchema) +func createDecoderOfFixed(fixed *FixedSchema, typ reflect2.Type) ValDecoder { switch typ.Kind() { case reflect.Array: arrayType := typ.(reflect2.ArrayType) @@ -21,7 +20,6 @@ func createDecoderOfFixed(schema Schema, typ reflect2.Type) ValDecoder { return &fixedCodec{arrayType: typ.(*reflect2.UnsafeArrayType)} case reflect.Uint64: - fixed := schema.(*FixedSchema) if fixed.Size() != 8 { break } @@ -44,23 +42,20 @@ func createDecoderOfFixed(schema Schema, typ reflect2.Type) ValDecoder { } return &errorDecoder{ - err: fmt.Errorf("avro: %s is unsupported for Avro %s, size=%d", typ.String(), schema.Type(), fixed.Size()), + err: fmt.Errorf("avro: %s is unsupported for Avro %s, size=%d", typ.String(), fixed.Type(), fixed.Size()), } } -func createEncoderOfFixed(schema Schema, typ reflect2.Type) ValEncoder { - fixed := schema.(*FixedSchema) +func createEncoderOfFixed(fixed *FixedSchema, typ reflect2.Type) ValEncoder { switch typ.Kind() { case reflect.Array: arrayType := typ.(reflect2.ArrayType) - fixed := schema.(*FixedSchema) if arrayType.Elem().Kind() != reflect.Uint8 || arrayType.Len() != fixed.Size() { break } return &fixedCodec{arrayType: typ.(*reflect2.UnsafeArrayType)} case reflect.Uint64: - fixed := schema.(*FixedSchema) if fixed.Size() != 8 { break } @@ -92,7 +87,7 @@ func createEncoderOfFixed(schema Schema, typ reflect2.Type) ValEncoder { } return &errorEncoder{ - err: fmt.Errorf("avro: %s is unsupported for Avro %s, size=%d", typ.String(), schema.Type(), fixed.Size()), + err: fmt.Errorf("avro: %s is unsupported for Avro %s, size=%d", typ.String(), fixed.Type(), fixed.Size()), } } diff --git a/codec_generic_internal_test.go b/codec_generic_internal_test.go index a55f26bd..1d96c7e8 100644 --- a/codec_generic_internal_test.go +++ b/codec_generic_internal_test.go @@ -228,7 +228,7 @@ func TestGenericDecode(t *testing.T) { typ, err := genericReceiver(schema) require.NoError(t, err) - dec := decoderOfType(DefaultConfig.(*frozenConfig), schema, typ) + dec := decoderOfType(newDecoderContext(DefaultConfig.(*frozenConfig)), schema, typ) got := genericDecode(typ, dec, r) diff --git a/codec_map.go b/codec_map.go index 6c888c88..ceefa008 100644 --- a/codec_map.go +++ b/codec_map.go @@ -11,38 +11,37 @@ import ( "github.com/modern-go/reflect2" ) -func createDecoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func createDecoderOfMap(d *decoderContext, schema *MapSchema, typ reflect2.Type) ValDecoder { if typ.Kind() == reflect.Map { keyType := typ.(reflect2.MapType).Key() switch { case keyType.Kind() == reflect.String: - return decoderOfMap(cfg, schema, typ) + return decoderOfMap(d, schema, typ) case keyType.Implements(textUnmarshalerType): - return decoderOfMapUnmarshaler(cfg, schema, typ) + return decoderOfMapUnmarshaler(d, schema, typ) } } return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} } -func createEncoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func createEncoderOfMap(e *encoderContext, schema *MapSchema, typ reflect2.Type) ValEncoder { if typ.Kind() == reflect.Map { keyType := typ.(reflect2.MapType).Key() switch { case keyType.Kind() == reflect.String: - return encoderOfMap(cfg, schema, typ) + return encoderOfMap(e, schema, typ) case keyType.Implements(textMarshalerType): - return encoderOfMapMarshaler(cfg, schema, typ) + return encoderOfMapMarshaler(e, schema, typ) } } return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} } -func decoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { - m := schema.(*MapSchema) +func decoderOfMap(d *decoderContext, m *MapSchema, typ reflect2.Type) ValDecoder { mapType := typ.(*reflect2.UnsafeMapType) - decoder := decoderOfType(cfg, m.Values(), mapType.Elem()) + decoder := decoderOfType(d, m.Values(), mapType.Elem()) return &mapDecoder{ mapType: mapType, @@ -86,10 +85,9 @@ func (d *mapDecoder) Decode(ptr unsafe.Pointer, r *Reader) { } } -func decoderOfMapUnmarshaler(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { - m := schema.(*MapSchema) +func decoderOfMapUnmarshaler(d *decoderContext, m *MapSchema, typ reflect2.Type) ValDecoder { mapType := typ.(*reflect2.UnsafeMapType) - decoder := decoderOfType(cfg, m.Values(), mapType.Elem()) + decoder := decoderOfType(d, m.Values(), mapType.Elem()) return &mapDecoderUnmarshaler{ mapType: mapType, @@ -145,13 +143,12 @@ func (d *mapDecoderUnmarshaler) Decode(ptr unsafe.Pointer, r *Reader) { } } -func encoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { - m := schema.(*MapSchema) +func encoderOfMap(e *encoderContext, m *MapSchema, typ reflect2.Type) ValEncoder { mapType := typ.(*reflect2.UnsafeMapType) - encoder := encoderOfType(cfg, m.Values(), mapType.Elem()) + encoder := encoderOfType(e, m.Values(), mapType.Elem()) return &mapEncoder{ - blockLength: cfg.getBlockLength(), + blockLength: e.cfg.getBlockLength(), mapType: mapType, encoder: encoder, } @@ -190,13 +187,12 @@ func (e *mapEncoder) Encode(ptr unsafe.Pointer, w *Writer) { } } -func encoderOfMapMarshaler(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { - m := schema.(*MapSchema) +func encoderOfMapMarshaler(e *encoderContext, m *MapSchema, typ reflect2.Type) ValEncoder { mapType := typ.(*reflect2.UnsafeMapType) - encoder := encoderOfType(cfg, m.Values(), mapType.Elem()) + encoder := encoderOfType(e, m.Values(), mapType.Elem()) return &mapEncoderMarshaller{ - blockLength: cfg.getBlockLength(), + blockLength: e.cfg.getBlockLength(), mapType: mapType, keyType: mapType.Key(), encoder: encoder, diff --git a/codec_marshaler.go b/codec_marshaler.go index fa705119..d783d177 100644 --- a/codec_marshaler.go +++ b/codec_marshaler.go @@ -12,7 +12,7 @@ var ( textUnmarshalerType = reflect2.TypeOfPtr((*encoding.TextUnmarshaler)(nil)).Elem() ) -func createDecoderOfMarshaler(_ *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func createDecoderOfMarshaler(schema Schema, typ reflect2.Type) ValDecoder { if typ.Implements(textUnmarshalerType) && schema.Type() == String { return &textMarshalerCodec{typ} } @@ -25,7 +25,7 @@ func createDecoderOfMarshaler(_ *frozenConfig, schema Schema, typ reflect2.Type) return nil } -func createEncoderOfMarshaler(_ *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func createEncoderOfMarshaler(schema Schema, typ reflect2.Type) ValEncoder { if typ.Implements(textMarshalerType) && schema.Type() == String { return &textMarshalerCodec{ typ: typ, diff --git a/codec_ptr.go b/codec_ptr.go index fc94a68c..07b099ee 100644 --- a/codec_ptr.go +++ b/codec_ptr.go @@ -7,11 +7,11 @@ import ( "github.com/modern-go/reflect2" ) -func decoderOfPtr(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func decoderOfPtr(d *decoderContext, schema Schema, typ reflect2.Type) ValDecoder { ptrType := typ.(*reflect2.UnsafePtrType) elemType := ptrType.Elem() - decoder := decoderOfType(cfg, schema, elemType) + decoder := decoderOfType(d, schema, elemType) return &dereferenceDecoder{typ: elemType, decoder: decoder} } @@ -34,11 +34,11 @@ func (d *dereferenceDecoder) Decode(ptr unsafe.Pointer, r *Reader) { d.decoder.Decode(*((*unsafe.Pointer)(ptr)), r) } -func encoderOfPtr(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func encoderOfPtr(e *encoderContext, schema Schema, typ reflect2.Type) ValEncoder { ptrType := typ.(*reflect2.UnsafePtrType) elemType := ptrType.Elem() - enc := encoderOfType(cfg, schema, elemType) + enc := encoderOfType(e, schema, elemType) return &dereferenceEncoder{typ: elemType, encoder: enc} } diff --git a/codec_record.go b/codec_record.go index 45ee02f8..7cfdbef3 100644 --- a/codec_record.go +++ b/codec_record.go @@ -10,20 +10,20 @@ import ( "github.com/modern-go/reflect2" ) -func createDecoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func createDecoderOfRecord(d *decoderContext, schema Schema, typ reflect2.Type) ValDecoder { switch typ.Kind() { case reflect.Struct: - return decoderOfStruct(cfg, schema, typ) + return decoderOfStruct(d, schema, typ) case reflect.Map: if typ.(reflect2.MapType).Key().Kind() != reflect.String || typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { break } - return decoderOfRecord(cfg, schema, typ) + return decoderOfRecord(d, schema, typ) case reflect.Ptr: - return decoderOfPtr(cfg, schema, typ) + return decoderOfPtr(d, schema, typ) case reflect.Interface: if ifaceType, ok := typ.(*reflect2.UnsafeIFaceType); ok { @@ -34,28 +34,28 @@ func createDecoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())} } -func createEncoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func createEncoderOfRecord(e *encoderContext, schema *RecordSchema, typ reflect2.Type) ValEncoder { switch typ.Kind() { case reflect.Struct: - return encoderOfStruct(cfg, schema, typ) + return encoderOfStruct(e, schema, typ) case reflect.Map: if typ.(reflect2.MapType).Key().Kind() != reflect.String || typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { break } - return encoderOfRecord(cfg, schema, typ) + return encoderOfRecord(e, schema, typ) case reflect.Ptr: - return encoderOfPtr(cfg, schema, typ) + return encoderOfPtr(e, schema, typ) } return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())} } -func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func decoderOfStruct(d *decoderContext, schema Schema, typ reflect2.Type) ValDecoder { rec := schema.(*RecordSchema) - structDesc := describeStruct(cfg.getTagKey(), typ) + structDesc := describeStruct(d.cfg.getTagKey(), typ) fields := make([]*structFieldDecoder, 0, len(rec.Fields())) @@ -89,14 +89,14 @@ func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec if field.hasDef { fields = append(fields, &structFieldDecoder{ field: sf.Field, - decoder: createDefaultDecoder(cfg, field, sf.Field[len(sf.Field)-1].Type()), + decoder: createDefaultDecoder(d, field, sf.Field[len(sf.Field)-1].Type()), }) continue } } - dec := decoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()) + dec := decoderOfType(d, field.Type(), sf.Field[len(sf.Field)-1].Type()) fields = append(fields, &structFieldDecoder{ field: sf.Field, decoder: dec, @@ -152,9 +152,8 @@ func (d *structDecoder) Decode(ptr unsafe.Pointer, r *Reader) { } } -func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { - rec := schema.(*RecordSchema) - structDesc := describeStruct(cfg.getTagKey(), typ) +func encoderOfStruct(e *encoderContext, rec *RecordSchema, typ reflect2.Type) ValEncoder { + structDesc := describeStruct(e.cfg.getTagKey(), typ) fields := make([]*structFieldEncoder, 0, len(rec.Fields())) for _, field := range rec.Fields() { @@ -162,7 +161,7 @@ func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEnc if sf != nil { fields = append(fields, &structFieldEncoder{ field: sf.Field, - encoder: encoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()), + encoder: encoderOfType(e, field.Type(), sf.Field[len(sf.Field)-1].Type()), }) continue } @@ -184,14 +183,14 @@ func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEnc defaultType := reflect2.TypeOf(&def) fields = append(fields, &structFieldEncoder{ defaultPtr: reflect2.PtrOf(&def), - encoder: encoderOfNullableUnion(cfg, field.Type(), defaultType), + encoder: encoderOfNullableUnion(e, field.Type(), defaultType), }) continue } } defaultType := reflect2.TypeOf(def) - defaultEncoder := encoderOfType(cfg, field.Type(), defaultType) + defaultEncoder := encoderOfType(e, field.Type(), defaultType) if defaultType.LikePtr() { defaultEncoder = &onePtrEncoder{defaultEncoder} } @@ -250,7 +249,7 @@ func (e *structEncoder) Encode(ptr unsafe.Pointer, w *Writer) { } } -func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func decoderOfRecord(d *decoderContext, schema Schema, typ reflect2.Type) ValDecoder { rec := schema.(*RecordSchema) mapType := typ.(*reflect2.UnsafeMapType) @@ -268,7 +267,7 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec if field.hasDef { fields[i] = recordMapDecoderField{ name: field.Name(), - decoder: createDefaultDecoder(cfg, field, mapType.Elem()), + decoder: createDefaultDecoder(d, field, mapType.Elem()), } continue } @@ -276,7 +275,7 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec fields[i] = recordMapDecoderField{ name: field.Name(), - decoder: newEfaceDecoder(cfg, field.Type()), + decoder: newEfaceDecoder(d, field.Type()), } } @@ -319,8 +318,7 @@ func (d *recordMapDecoder) Decode(ptr unsafe.Pointer, r *Reader) { } } -func encoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { - rec := schema.(*RecordSchema) +func encoderOfRecord(e *encoderContext, rec *RecordSchema, typ reflect2.Type) ValEncoder { mapType := typ.(*reflect2.UnsafeMapType) fields := make([]mapEncoderField, len(rec.Fields())) @@ -329,7 +327,7 @@ func encoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEnc name: field.Name(), hasDef: field.HasDefault(), def: field.Default(), - encoder: encoderOfType(cfg, field.Type(), mapType.Elem()), + encoder: encoderOfType(e, field.Type(), mapType.Elem()), } if field.HasDefault() { @@ -344,7 +342,7 @@ func encoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEnc } defaultType := reflect2.TypeOf(fields[i].def) - fields[i].defEncoder = encoderOfType(cfg, field.Type(), defaultType) + fields[i].defEncoder = encoderOfType(e, field.Type(), defaultType) if defaultType.LikePtr() { fields[i].defEncoder = &onePtrEncoder{fields[i].defEncoder} } diff --git a/codec_union.go b/codec_union.go index b6fd1924..7d80b539 100644 --- a/codec_union.go +++ b/codec_union.go @@ -10,27 +10,27 @@ import ( "github.com/modern-go/reflect2" ) -func createDecoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func createDecoderOfUnion(d *decoderContext, schema *UnionSchema, typ reflect2.Type) ValDecoder { switch typ.Kind() { case reflect.Map: if typ.(reflect2.MapType).Key().Kind() != reflect.String || typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { break } - return decoderOfMapUnion(cfg, schema, typ) + return decoderOfMapUnion(d, schema, typ) case reflect.Slice: - if !schema.(*UnionSchema).Nullable() { + if !schema.Nullable() { break } - return decoderOfNullableUnion(cfg, schema, typ) + return decoderOfNullableUnion(d, schema, typ) case reflect.Ptr: - if !schema.(*UnionSchema).Nullable() { + if !schema.Nullable() { break } - return decoderOfNullableUnion(cfg, schema, typ) + return decoderOfNullableUnion(d, schema, typ) case reflect.Interface: if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok { - dec, err := decoderOfResolvedUnion(cfg, schema) + dec, err := decoderOfResolvedUnion(d, schema) if err != nil { return &errorDecoder{err: fmt.Errorf("avro: problem resolving decoder for Avro %s: %w", schema.Type(), err)} } @@ -41,30 +41,29 @@ func createDecoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) V return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} } -func createEncoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func createEncoderOfUnion(e *encoderContext, schema *UnionSchema, typ reflect2.Type) ValEncoder { switch typ.Kind() { case reflect.Map: if typ.(reflect2.MapType).Key().Kind() != reflect.String || typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { break } - return encoderOfMapUnion(cfg, schema, typ) + return encoderOfMapUnion(e, schema, typ) case reflect.Slice: - if !schema.(*UnionSchema).Nullable() { + if !schema.Nullable() { break } - return encoderOfNullableUnion(cfg, schema, typ) + return encoderOfNullableUnion(e, schema, typ) case reflect.Ptr: - if !schema.(*UnionSchema).Nullable() { + if !schema.Nullable() { break } - return encoderOfNullableUnion(cfg, schema, typ) + return encoderOfNullableUnion(e, schema, typ) } - return encoderOfResolverUnion(cfg, schema, typ) + return encoderOfResolverUnion(e, schema, typ) } -func decoderOfMapUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { - union := schema.(*UnionSchema) +func decoderOfMapUnion(d *decoderContext, union *UnionSchema, typ reflect2.Type) ValDecoder { mapType := typ.(*reflect2.UnsafeMapType) typeDecs := make([]ValDecoder, len(union.Types())) @@ -72,11 +71,11 @@ func decoderOfMapUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValD if s.Type() == Null { continue } - typeDecs[i] = newEfaceDecoder(cfg, s) + typeDecs[i] = newEfaceDecoder(d, s) } return &mapUnionDecoder{ - cfg: cfg, + cfg: d.cfg, schema: union, mapType: mapType, elemType: mapType.Elem(), @@ -116,11 +115,9 @@ func (d *mapUnionDecoder) Decode(ptr unsafe.Pointer, r *Reader) { d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr) } -func encoderOfMapUnion(cfg *frozenConfig, schema Schema, _ reflect2.Type) ValEncoder { - union := schema.(*UnionSchema) - +func encoderOfMapUnion(e *encoderContext, union *UnionSchema, _ reflect2.Type) ValEncoder { return &mapUnionEncoder{ - cfg: cfg, + cfg: e.cfg, schema: union, } } @@ -161,14 +158,14 @@ func (e *mapUnionEncoder) Encode(ptr unsafe.Pointer, w *Writer) { elemType := reflect2.TypeOf(val) elemPtr := reflect2.PtrOf(val) - encoder := encoderOfType(e.cfg, schema, elemType) + encoder := encoderOfType(newEncoderContext(e.cfg), schema, elemType) if elemType.LikePtr() { encoder = &onePtrEncoder{encoder} } encoder.Encode(elemPtr, w) } -func decoderOfNullableUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func decoderOfNullableUnion(d *decoderContext, schema Schema, typ reflect2.Type) ValDecoder { union := schema.(*UnionSchema) _, typeIdx := union.Indices() @@ -183,7 +180,7 @@ func decoderOfNullableUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) case *reflect2.UnsafeSliceType: baseTyp = v } - decoder := decoderOfType(cfg, union.Types()[typeIdx], baseTyp) + decoder := decoderOfType(d, union.Types()[typeIdx], baseTyp) return &unionNullableDecoder{ schema: union, @@ -238,7 +235,7 @@ func (d *unionNullableDecoder) Decode(ptr unsafe.Pointer, r *Reader) { d.decoder.Decode(*((*unsafe.Pointer)(ptr)), r) } -func encoderOfNullableUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func encoderOfNullableUnion(e *encoderContext, schema Schema, typ reflect2.Type) ValEncoder { union := schema.(*UnionSchema) nullIdx, typeIdx := union.Indices() @@ -253,7 +250,7 @@ func encoderOfNullableUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) case *reflect2.UnsafeSliceType: baseTyp = v } - encoder := encoderOfType(cfg, union.Types()[typeIdx], baseTyp) + encoder := encoderOfType(e, union.Types()[typeIdx], baseTyp) return &unionNullableEncoder{ schema: union, @@ -286,7 +283,7 @@ func (e *unionNullableEncoder) Encode(ptr unsafe.Pointer, w *Writer) { e.encoder.Encode(newPtr, w) } -func decoderOfResolvedUnion(cfg *frozenConfig, schema Schema) (ValDecoder, error) { +func decoderOfResolvedUnion(d *decoderContext, schema Schema) (ValDecoder, error) { union := schema.(*UnionSchema) types := make([]reflect2.Type, len(union.Types())) @@ -294,13 +291,13 @@ func decoderOfResolvedUnion(cfg *frozenConfig, schema Schema) (ValDecoder, error for i, schema := range union.Types() { name := unionResolutionName(schema) - typ, err := cfg.resolver.Type(name) + typ, err := d.cfg.resolver.Type(name) if err != nil { - if cfg.config.UnionResolutionError { + if d.cfg.config.UnionResolutionError { return nil, err } - if cfg.config.PartialUnionTypeResolution { + if d.cfg.config.PartialUnionTypeResolution { decoders[i] = nil types[i] = nil continue @@ -311,13 +308,13 @@ func decoderOfResolvedUnion(cfg *frozenConfig, schema Schema) (ValDecoder, error break } - decoder := decoderOfType(cfg, schema, typ) + decoder := decoderOfType(d, schema, typ) decoders[i] = decoder types[i] = typ } return &unionResolvedDecoder{ - cfg: cfg, + cfg: d.cfg, schema: union, types: types, decoders: decoders, @@ -358,7 +355,7 @@ func (d *unionResolvedDecoder) Decode(ptr unsafe.Pointer, r *Reader) { r.ReportError("Union", err.Error()) return } - obj[name] = genericDecode(vTyp, decoderOfType(d.cfg, schema, vTyp), r) + obj[name] = genericDecode(vTyp, decoderOfType(newDecoderContext(d.cfg), schema, vTyp), r) *pObj = obj return @@ -408,10 +405,10 @@ func unionResolutionName(schema Schema) string { return name } -func encoderOfResolverUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func encoderOfResolverUnion(e *encoderContext, schema Schema, typ reflect2.Type) ValEncoder { union := schema.(*UnionSchema) - names, err := cfg.resolver.Name(typ) + names, err := e.cfg.resolver.Name(typ) if err != nil { return &errorEncoder{err: err} } @@ -431,7 +428,7 @@ func encoderOfResolverUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) return &errorEncoder{err: fmt.Errorf("avro: unknown union type %s", names[0])} } - encoder := encoderOfType(cfg, schema, typ) + encoder := encoderOfType(e, schema, typ) return &unionResolverEncoder{ pos: pos, diff --git a/config.go b/config.go index 3281b1b4..57b2769a 100644 --- a/config.go +++ b/config.go @@ -129,10 +129,10 @@ type frozenConfig struct { func (c *frozenConfig) Marshal(schema Schema, v any) ([]byte, error) { writer := c.borrowWriter() + defer c.returnWriter(writer) writer.WriteVal(schema, v) if err := writer.Error; err != nil { - c.returnWriter(writer) return nil, err } @@ -140,7 +140,6 @@ func (c *frozenConfig) Marshal(schema Schema, v any) ([]byte, error) { copied := make([]byte, len(result)) copy(copied, result) - c.returnWriter(writer) return copied, nil } @@ -159,10 +158,10 @@ func (c *frozenConfig) returnWriter(writer *Writer) { func (c *frozenConfig) Unmarshal(schema Schema, data []byte, v any) error { reader := c.borrowReader(data) + defer c.returnReader(reader) reader.ReadVal(schema, v) err := reader.Error - c.returnReader(reader) if errors.Is(err, io.EOF) { return nil diff --git a/decoder_array_test.go b/decoder_array_test.go index ca35443d..2cb81a10 100644 --- a/decoder_array_test.go +++ b/decoder_array_test.go @@ -64,6 +64,41 @@ func TestDecoder_ArraySliceOfStruct(t *testing.T) { assert.Equal(t, []TestRecord{{A: 27, B: "foo"}, {A: 27, B: "foo"}}, got) } +func TestDecoder_ArrayRecursiveStruct(t *testing.T) { + defer ConfigTeardown() + + type record struct { + A int `avro:"a"` + B []record `avro:"b"` + } + + data := []byte{0x2, 0x3, 0x8, 0x4, 0x0, 0x6, 0x0, 0x0} + schema := `{ + "type": "record", + "name": "test", + "fields": [ + { + "name": "a", + "type": "int" + }, + { + "name": "b", + "type": { + "type": "array", + "items": "test" + } + } + ] + }` + dec, _ := avro.NewDecoder(schema, bytes.NewReader(data)) + + var got record + err := dec.Decode(&got) + + assert.NoError(t, err) + assert.Equal(t, record{A: 1, B: []record{{A: 2}, {A: 3}}}, got) +} + func TestDecoder_ArraySliceError(t *testing.T) { defer ConfigTeardown() diff --git a/decoder_map_test.go b/decoder_map_test.go index c9944c23..8563a825 100644 --- a/decoder_map_test.go +++ b/decoder_map_test.go @@ -66,6 +66,39 @@ func TestDecoder_MapMapOfStruct(t *testing.T) { assert.Equal(t, map[string]TestRecord{"foo": {A: 27, B: "foo"}}, got) } +func TestDecoder_MapOfRecursiveStruct(t *testing.T) { + defer ConfigTeardown() + + type record struct { + A int `avro:"a"` + B map[string]record `avro:"b"` + } + + data := []byte{0x02, 0x01, 0x0c, 0x06, 0x66, 0x6f, 0x6f, 0x04, 0x0, 0x0} + schema := `{ + "type": "record", + "name": "test", + "fields": [ + { + "name": "a", "type": "int" + }, + { + "name": "b", + "type": { + "type": "map", "values": "test" + } + } + ] + }` + dec, _ := avro.NewDecoder(schema, bytes.NewReader(data)) + + var got record + err := dec.Decode(&got) + + assert.NoError(t, err) + assert.Equal(t, record{A: 1, B: map[string]record{"foo": {A: 2, B: map[string]record{}}}}, got) +} + func TestDecoder_MapMapError(t *testing.T) { defer ConfigTeardown() diff --git a/decoder_union_test.go b/decoder_union_test.go index c09e162a..4af546be 100644 --- a/decoder_union_test.go +++ b/decoder_union_test.go @@ -295,6 +295,38 @@ func TestDecoder_UnionPtrNotNullable(t *testing.T) { assert.Error(t, err) } +func TestDecoder_UnionPtrRecursiveType(t *testing.T) { + defer ConfigTeardown() + + type record struct { + A int `avro:"a"` + B *record `avro:"b"` + } + + data := []byte{0x02, 0x02, 0x04, 0x0} + schema := `{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "int"}, + {"name": "b", "type": [null, "test"]} + ] + }` + dec, _ := avro.NewDecoder(schema, bytes.NewReader(data)) + + var got record + err := dec.Decode(&got) + + require.NoError(t, err) + want := record{ + A: 1, + B: &record{ + A: 2, + }, + } + assert.Equal(t, want, got) +} + func TestDecoder_UnionInterface(t *testing.T) { defer ConfigTeardown() diff --git a/encoder_array_test.go b/encoder_array_test.go index cf07c40e..f877e6d8 100644 --- a/encoder_array_test.go +++ b/encoder_array_test.go @@ -64,6 +64,42 @@ func TestEncoder_ArrayOfStruct(t *testing.T) { assert.Equal(t, []byte{0x03, 0x14, 0x36, 0x06, 0x66, 0x6f, 0x6f, 0x36, 0x06, 0x66, 0x6f, 0x6f, 0x0}, buf.Bytes()) } +func TestEncoder_ArrayRecursiveStruct(t *testing.T) { + defer ConfigTeardown() + + type record struct { + A int `avro:"a"` + B []record `avro:"b"` + } + + schema := `{ + "type": "record", + "name": "test", + "fields": [ + { + "name": "a", + "type": "int" + }, + { + "name": "b", + "type": { + "type": "array", + "items": "test" + } + } + ] + }` + buf := bytes.NewBuffer([]byte{}) + enc, err := avro.NewEncoder(schema, buf) + assert.NoError(t, err) + + rec := record{A: 1, B: []record{{A: 2}, {A: 3}}} + err = enc.Encode(rec) + + assert.NoError(t, err) + assert.Equal(t, []byte{0x2, 0x3, 0x8, 0x4, 0x0, 0x6, 0x0, 0x0}, buf.Bytes()) +} + func TestEncoder_ArrayError(t *testing.T) { defer ConfigTeardown() diff --git a/encoder_map_test.go b/encoder_map_test.go index d591898f..fe241305 100644 --- a/encoder_map_test.go +++ b/encoder_map_test.go @@ -66,6 +66,40 @@ func TestEncoder_MapOfStruct(t *testing.T) { assert.Equal(t, []byte{0x01, 0x12, 0x06, 0x66, 0x6F, 0x6F, 0x36, 0x06, 0x66, 0x6f, 0x6f, 0x0}, buf.Bytes()) } +func TestEncoder_MapOfRecursiveStruct(t *testing.T) { + defer ConfigTeardown() + + type record struct { + A int `avro:"a"` + B map[string]record `avro:"b"` + } + + schema := `{ + "type": "record", + "name": "test", + "fields": [ + { + "name": "a", "type": "int" + }, + { + "name": "b", + "type": { + "type": "map", "values": "test" + } + } + ] + }` + buf := bytes.NewBuffer([]byte{}) + enc, err := avro.NewEncoder(schema, buf) + assert.NoError(t, err) + + rec := record{A: 1, B: map[string]record{"foo": {A: 2}}} + err = enc.Encode(rec) + + assert.NoError(t, err) + assert.Equal(t, []byte{0x02, 0x01, 0x0c, 0x06, 0x66, 0x6f, 0x6f, 0x04, 0x0, 0x0}, buf.Bytes()) +} + func TestEncoder_MapInvalidKeyType(t *testing.T) { defer ConfigTeardown() diff --git a/encoder_union_test.go b/encoder_union_test.go index f7687ad8..ca9327ad 100644 --- a/encoder_union_test.go +++ b/encoder_union_test.go @@ -414,6 +414,38 @@ func TestEncoder_UnionInterfaceNamed(t *testing.T) { assert.Equal(t, []byte{0x02, 0x02}, buf.Bytes()) } +func TestEncoder_UnionInterfaceRecursiveType(t *testing.T) { + defer ConfigTeardown() + + type record struct { + A int `avro:"a"` + B interface{} `avro:"b"` + } + + schema := `{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "int"}, + {"name": "b", "type": [null, "test"]} + ] +}` + buf := bytes.NewBuffer([]byte{}) + enc, err := avro.NewEncoder(schema, buf) + assert.NoError(t, err) + + rec := record{ + A: 1, + B: &record{ + A: 2, + }, + } + err = enc.Encode(rec) + + assert.NoError(t, err) + assert.Equal(t, []byte{0x02, 0x02, 0x04, 0x0}, buf.Bytes()) +} + func TestEncoder_UnionInterfaceWithTime(t *testing.T) { defer ConfigTeardown() diff --git a/example_test.go b/example_test.go index e06790ec..3408ed79 100644 --- a/example_test.go +++ b/example_test.go @@ -3,7 +3,12 @@ package avro_test import ( "bytes" "fmt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "log" + "sync" + "sync/atomic" + "testing" "github.com/hamba/avro/v2" ) @@ -199,3 +204,45 @@ func ExampleMarshal() { // Output: [54 6 102 111 111] } + +func TestEncoderDecoder_Concurrency(t *testing.T) { + schema := avro.MustParse(`{ + "type": "record", + "name": "simple", + "namespace": "org.hamba.avro", + "fields" : [ + {"name": "a", "type": "long"}, + {"name": "b", "type": "string"} + ] + }`) + + var ops atomic.Uint32 + + type SimpleRecord struct { + A int64 `avro:"a"` + B string `avro:"b"` + } + + wg := &sync.WaitGroup{} + for i := 0; i < 1000; i++ { + wg.Add(1) + go func(schema avro.Schema, wg *sync.WaitGroup, idx int64) { + defer wg.Done() + in := SimpleRecord{A: idx, B: fmt.Sprintf("foo-%d", idx)} + + data, err := avro.Marshal(schema, in) + require.NoError(t, err) + + out := SimpleRecord{} + err = avro.Unmarshal(schema, data, &out) + + require.NoError(t, err) + assert.Equal(t, idx, out.A) + assert.Equal(t, fmt.Sprintf("foo-%d", idx), out.B) + ops.Add(1) + }(schema, wg, int64(i)) + } + wg.Wait() + + assert.Equal(t, uint32(1000), ops.Load()) +}