diff --git a/codec_default_internal_test.go b/codec_default_internal_test.go index 5902f76..96eaaf9 100644 --- a/codec_default_internal_test.go +++ b/codec_default_internal_test.go @@ -741,15 +741,15 @@ func TestDecoder_DefaultFixed(t *testing.T) { schema.(*RecordSchema).Fields()[1].action = FieldSetDefault type TestRecord struct { - A string `avro:"a"` - B big.Rat `avro:"b"` + A string `avro:"a"` + B *big.Rat `avro:"b"` } var got TestRecord err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) require.NoError(t, err) - assert.Equal(t, big.NewRat(1734, 5), &got.B) + assert.Equal(t, big.NewRat(1734, 5), got.B) assert.Equal(t, "foo", got.A) }) } diff --git a/codec_fixed.go b/codec_fixed.go index d8b11a5..887defd 100644 --- a/codec_fixed.go +++ b/codec_fixed.go @@ -18,27 +18,34 @@ func createDecoderOfFixed(fixed *FixedSchema, typ reflect2.Type) ValDecoder { break } return &fixedCodec{arrayType: typ.(*reflect2.UnsafeArrayType)} - case reflect.Uint64: if fixed.Size() != 8 { break } return &fixedUint64Codec{} + case reflect.Ptr: + ptrType := typ.(*reflect2.UnsafePtrType) + elemType := ptrType.Elem() + ls := fixed.Logical() + tpy1 := elemType.Type1() + if elemType.Kind() != reflect.Struct || !tpy1.ConvertibleTo(ratType) || ls == nil || + ls.Type() != Decimal { + break + } + dec := ls.(*DecimalLogicalSchema) + return &fixedDecimalCodec{prec: dec.Precision(), scale: dec.Scale(), size: fixed.Size()} case reflect.Struct: ls := fixed.Logical() if ls == nil { break } typ1 := typ.Type1() - switch { - case typ1.ConvertibleTo(durType) && ls.Type() == Duration: - return &fixedDurationCodec{} - case typ1.ConvertibleTo(ratType) && ls.Type() == Decimal: - dec := ls.(*DecimalLogicalSchema) - return &fixedDecimalCodec{prec: dec.Precision(), scale: dec.Scale(), size: fixed.Size()} + if !typ1.ConvertibleTo(durType) || ls.Type() != Duration { + break } + return &fixedDurationCodec{} } return &errorDecoder{ @@ -54,14 +61,12 @@ func createEncoderOfFixed(fixed *FixedSchema, typ reflect2.Type) ValEncoder { break } return &fixedCodec{arrayType: typ.(*reflect2.UnsafeArrayType)} - case reflect.Uint64: if fixed.Size() != 8 { break } return &fixedUint64Codec{} - case reflect.Ptr: ptrType := typ.(*reflect2.UnsafePtrType) elemType := ptrType.Elem() @@ -131,7 +136,7 @@ type fixedDecimalCodec struct { func (c *fixedDecimalCodec) Decode(ptr unsafe.Pointer, r *Reader) { b := make([]byte, c.size) r.Read(b) - *((*big.Rat)(ptr)) = *ratFromBytes(b, c.scale) + *((**big.Rat)(ptr)) = ratFromBytes(b, c.scale) } func (c *fixedDecimalCodec) Encode(ptr unsafe.Pointer, w *Writer) { diff --git a/codec_generic.go b/codec_generic.go index 36bf9ee..1384e9e 100644 --- a/codec_generic.go +++ b/codec_generic.go @@ -19,13 +19,6 @@ func genericDecode(typ reflect2.Type, dec ValDecoder, r *Reader) any { if reflect2.IsNil(obj) { return nil } - - // Generic reader returns a different result from the - // codec in the case of a big.Rat. Handle this. - if typ.Type1() == ratType { - dec := obj.(big.Rat) - return &dec - } return obj } @@ -125,7 +118,7 @@ func genericReceiver(schema Schema) (reflect2.Type, error) { var v LogicalDuration return reflect2.TypeOf(v), nil case Decimal: - var v big.Rat + var v *big.Rat return reflect2.TypeOf(v), nil } } diff --git a/codec_native.go b/codec_native.go index 3678de6..e4c5a41 100644 --- a/codec_native.go +++ b/codec_native.go @@ -592,7 +592,7 @@ func (c *bytesDecimalCodec) Decode(ptr unsafe.Pointer, r *Reader) { if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 { i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8)) } - *((*big.Rat)(ptr)) = *ratFromBytes(b, c.scale) + *((**big.Rat)(ptr)) = ratFromBytes(b, c.scale) } func ratFromBytes(b []byte, scale int) *big.Rat { diff --git a/decoder_fixed_test.go b/decoder_fixed_test.go index 11f978e..ee5f194 100644 --- a/decoder_fixed_test.go +++ b/decoder_fixed_test.go @@ -49,7 +49,7 @@ func TestDecoder_FixedRat_Positive(t *testing.T) { require.NoError(t, err) got := &big.Rat{} - err = dec.Decode(got) + err = dec.Decode(&got) require.NoError(t, err) assert.Equal(t, big.NewRat(1734, 5), got) @@ -64,7 +64,7 @@ func TestDecoder_FixedRat_Negative(t *testing.T) { require.NoError(t, err) got := &big.Rat{} - err = dec.Decode(got) + err = dec.Decode(&got) require.NoError(t, err) assert.Equal(t, big.NewRat(-1734, 5), got) @@ -79,7 +79,7 @@ func TestDecoder_FixedRat_Zero(t *testing.T) { require.NoError(t, err) got := &big.Rat{} - err = dec.Decode(got) + err = dec.Decode(&got) require.NoError(t, err) assert.Equal(t, big.NewRat(0, 1), got) @@ -94,7 +94,7 @@ func TestDecoder_FixedRatInvalidLogicalSchema(t *testing.T) { require.NoError(t, err) got := &big.Rat{} - err = dec.Decode(got) + err = dec.Decode(&got) assert.Error(t, err) } diff --git a/decoder_native_test.go b/decoder_native_test.go index a8aa773..423b794 100644 --- a/decoder_native_test.go +++ b/decoder_native_test.go @@ -768,7 +768,7 @@ func TestDecoder_BytesRat_Positive(t *testing.T) { require.NoError(t, err) got := &big.Rat{} - err = dec.Decode(got) + err = dec.Decode(&got) require.NoError(t, err) assert.Equal(t, big.NewRat(1734, 5), got) @@ -783,7 +783,7 @@ func TestDecoder_BytesRat_Negative(t *testing.T) { require.NoError(t, err) got := &big.Rat{} - err = dec.Decode(got) + err = dec.Decode(&got) require.NoError(t, err) assert.Equal(t, big.NewRat(-1734, 5), got) @@ -798,7 +798,7 @@ func TestDecoder_BytesRat_Zero(t *testing.T) { require.NoError(t, err) got := &big.Rat{} - err = dec.Decode(got) + err = dec.Decode(&got) require.NoError(t, err) assert.Equal(t, big.NewRat(0, 1), got) @@ -813,7 +813,7 @@ func TestDecoder_BytesRatInvalidSchema(t *testing.T) { require.NoError(t, err) got := &big.Rat{} - err = dec.Decode(got) + err = dec.Decode(&got) assert.Error(t, err) } @@ -827,7 +827,7 @@ func TestDecoder_BytesRatInvalidLogicalSchema(t *testing.T) { require.NoError(t, err) got := &big.Rat{} - err = dec.Decode(got) + err = dec.Decode(&got) assert.Error(t, err) }