Skip to content

Commit

Permalink
feat: support slices for nullable unions (#372)
Browse files Browse the repository at this point in the history
  • Loading branch information
nrwiersma authored Apr 18, 2024
1 parent 15d2425 commit 7a2eb5f
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 29 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ When a non-`nil` union value is encountered, a single key is en/decoded. The key
type name, or scheam full name in the case of a named schema (enum, fixed or record).
* ***T:** This is allowed in a "nullable" union. A nullable union is defined as a two schema union,
with one of the types being `null` (ie. `["null", "string"]` or `["string", "null"]`), in this case
a `*T` is allowed, with `T` matching the conversion table above.
a `*T` is allowed, with `T` matching the conversion table above. In the case of a slice, the slice can be used
directly.
* **any:** An `interface` can be provided and the type or name resolved. Primitive types
are pre-registered, but named types, maps and slices will need to be registered with the `Register` function.
In the case of arrays and maps the enclosed schema type or name is postfix to the type with a `:` separator,
Expand Down
2 changes: 1 addition & 1 deletion codec_record.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEnc
defaultType := reflect2.TypeOf(&def)
fields = append(fields, &structFieldEncoder{
defaultPtr: reflect2.PtrOf(&def),
encoder: encoderOfPtrUnion(cfg, field.Type(), defaultType),
encoder: encoderOfNullableUnion(cfg, field.Type(), defaultType),
})
continue
}
Expand Down
97 changes: 73 additions & 24 deletions codec_union.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,22 @@ func createDecoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) V
break
}
return decoderOfMapUnion(cfg, schema, typ)

case reflect.Slice:
if !schema.(*UnionSchema).Nullable() {
break
}
return decoderOfNullableUnion(cfg, schema, typ)
case reflect.Ptr:
if !schema.(*UnionSchema).Nullable() {
break
}
return decoderOfPtrUnion(cfg, schema, typ)

return decoderOfNullableUnion(cfg, schema, typ)
case reflect.Interface:
if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok {
dec, err := decoderOfResolvedUnion(cfg, schema)
if err != nil {
return &errorDecoder{err: fmt.Errorf("avro: problem resolving decoder for Avro %s: %w", schema.Type(), err)}
}

return dec
}
}
Expand All @@ -47,14 +49,17 @@ func createEncoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) V
break
}
return encoderOfMapUnion(cfg, schema, typ)

case reflect.Slice:
if !schema.(*UnionSchema).Nullable() {
break
}
return encoderOfNullableUnion(cfg, schema, typ)
case reflect.Ptr:
if !schema.(*UnionSchema).Nullable() {
break
}
return encoderOfPtrUnion(cfg, schema, typ)
return encoderOfNullableUnion(cfg, schema, typ)
}

return encoderOfResolverUnion(cfg, schema, typ)
}

Expand Down Expand Up @@ -163,27 +168,39 @@ func (e *mapUnionEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
encoder.Encode(elemPtr, w)
}

func decoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
func decoderOfNullableUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
union := schema.(*UnionSchema)
_, typeIdx := union.Indices()
ptrType := typ.(*reflect2.UnsafePtrType)
elemType := ptrType.Elem()
decoder := decoderOfType(cfg, union.Types()[typeIdx], elemType)

return &unionPtrDecoder{
var (
baseTyp reflect2.Type
isPtr bool
)
switch v := typ.(type) {
case *reflect2.UnsafePtrType:
baseTyp = v.Elem()
isPtr = true
case *reflect2.UnsafeSliceType:
baseTyp = v
}
decoder := decoderOfType(cfg, union.Types()[typeIdx], baseTyp)

return &unionNullableDecoder{
schema: union,
typ: elemType,
typ: baseTyp,
isPtr: isPtr,
decoder: decoder,
}
}

type unionPtrDecoder struct {
type unionNullableDecoder struct {
schema *UnionSchema
typ reflect2.Type
isPtr bool
decoder ValDecoder
}

func (d *unionPtrDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
func (d *unionNullableDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
_, schema := getUnionSchema(d.schema, r)
if schema == nil {
return
Expand All @@ -194,47 +211,79 @@ func (d *unionPtrDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
return
}

// Handle the non-ptr case separately.
if !d.isPtr {
if d.typ.UnsafeIsNil(ptr) {
// Create a new instance.
newPtr := d.typ.UnsafeNew()
d.decoder.Decode(newPtr, r)
d.typ.UnsafeSet(ptr, newPtr)
return
}

// Reuse the existing instance.
d.decoder.Decode(ptr, r)
return
}

if *((*unsafe.Pointer)(ptr)) == nil {
// Create new instance
// Create new instance.
newPtr := d.typ.UnsafeNew()
d.decoder.Decode(newPtr, r)
*((*unsafe.Pointer)(ptr)) = newPtr
return
}

// Reuse existing instance
// Reuse existing instance.
d.decoder.Decode(*((*unsafe.Pointer)(ptr)), r)
}

func encoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
func encoderOfNullableUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
union := schema.(*UnionSchema)
nullIdx, typeIdx := union.Indices()
ptrType := typ.(*reflect2.UnsafePtrType)
encoder := encoderOfType(cfg, union.Types()[typeIdx], ptrType.Elem())

return &unionPtrEncoder{
var (
baseTyp reflect2.Type
isPtr bool
)
switch v := typ.(type) {
case *reflect2.UnsafePtrType:
baseTyp = v.Elem()
isPtr = true
case *reflect2.UnsafeSliceType:
baseTyp = v
}
encoder := encoderOfType(cfg, union.Types()[typeIdx], baseTyp)

return &unionNullableEncoder{
schema: union,
encoder: encoder,
isPtr: isPtr,
nullIdx: int64(nullIdx),
typeIdx: int64(typeIdx),
}
}

type unionPtrEncoder struct {
type unionNullableEncoder struct {
schema *UnionSchema
encoder ValEncoder
isPtr bool
nullIdx int64
typeIdx int64
}

func (e *unionPtrEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
func (e *unionNullableEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
if *((*unsafe.Pointer)(ptr)) == nil {
w.WriteLong(e.nullIdx)
return
}

w.WriteLong(e.typeIdx)
e.encoder.Encode(*((*unsafe.Pointer)(ptr)), w)
newPtr := ptr
if e.isPtr {
newPtr = *((*unsafe.Pointer)(ptr))
}
e.encoder.Encode(newPtr, w)
}

func decoderOfResolvedUnion(cfg *frozenConfig, schema Schema) (ValDecoder, error) {
Expand Down
50 changes: 47 additions & 3 deletions decoder_union_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,17 +182,17 @@ func TestDecoder_UnionPtrReversed(t *testing.T) {
func TestDecoder_UnionPtrReuseInstance(t *testing.T) {
defer ConfigTeardown()

avro.Register("test", &TestRecord{})

data := []byte{0x02, 0x36, 0x06, 0x66, 0x6F, 0x6F}
schema := `["null", {"type": "record", "name": "test", "fields" : [{"name": "a", "type": "long"}, {"name": "b", "type": "string"}]}]`
dec, _ := avro.NewDecoder(schema, bytes.NewReader(data))

got := &TestRecord{}
var original TestRecord
got := &original
err := dec.Decode(&got)

require.NoError(t, err)
assert.IsType(t, &TestRecord{}, got)
assert.Same(t, &original, got)
assert.Equal(t, int64(27), got.A)
assert.Equal(t, "foo", got.B)
}
Expand Down Expand Up @@ -225,6 +225,50 @@ func TestDecoder_UnionPtrReversedNull(t *testing.T) {
assert.Nil(t, got)
}

func TestDecoder_UnionNullableSlice(t *testing.T) {
defer ConfigTeardown()

data := []byte{0x02, 0x06, 0x66, 0x6F, 0x6F}
schema := `["null", "bytes"]`
dec, _ := avro.NewDecoder(schema, bytes.NewReader(data))

var got []byte
err := dec.Decode(&got)

want := []byte("foo")
require.NoError(t, err)
assert.Equal(t, want, got)
}

func TestDecoder_UnionNullableSliceNull(t *testing.T) {
defer ConfigTeardown()

data := []byte{0x00}
schema := `["null", "bytes"]`
dec, _ := avro.NewDecoder(schema, bytes.NewReader(data))

var got []byte
err := dec.Decode(&got)

require.NoError(t, err)
assert.Nil(t, got)
}

func TestDecoder_UnionNullableSliceNotNullButEmpty(t *testing.T) {
defer ConfigTeardown()

data := []byte{0x02, 0x00}
schema := `["null", "bytes"]`
dec, _ := avro.NewDecoder(schema, bytes.NewReader(data))

var got []byte
err := dec.Decode(&got)

require.NoError(t, err)
assert.NotNil(t, got)
assert.Empty(t, got)
}

func TestDecoder_UnionPtrInvalidSchema(t *testing.T) {
defer ConfigTeardown()

Expand Down
30 changes: 30 additions & 0 deletions encoder_union_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,36 @@ func TestEncoder_UnionPtrNotNullable(t *testing.T) {
assert.Error(t, err)
}

func TestEncoder_UnionNullableSlice(t *testing.T) {
defer ConfigTeardown()

schema := `["null", "bytes"]`
buf := bytes.NewBuffer([]byte{})
enc, err := avro.NewEncoder(schema, buf)
require.NoError(t, err)

b := []byte("foo")
err = enc.Encode(b)

require.NoError(t, err)
assert.Equal(t, []byte{0x02, 0x06, 0x66, 0x6F, 0x6F}, buf.Bytes())
}

func TestEncoder_UnionNullableSliceNull(t *testing.T) {
defer ConfigTeardown()

schema := `["null", "bytes"]`
buf := bytes.NewBuffer([]byte{})
enc, err := avro.NewEncoder(schema, buf)
require.NoError(t, err)

var b []byte
err = enc.Encode(b)

require.NoError(t, err)
assert.Equal(t, []byte{0x00}, buf.Bytes())
}

func TestEncoder_UnionInterface(t *testing.T) {
defer ConfigTeardown()

Expand Down

0 comments on commit 7a2eb5f

Please sign in to comment.