diff --git a/codec_generic_internal_test.go b/codec_generic_internal_test.go index 98a5855..a55f26b 100644 --- a/codec_generic_internal_test.go +++ b/codec_generic_internal_test.go @@ -173,6 +173,14 @@ func TestGenericDecode(t *testing.T) { want: map[string]any{"string": "foo"}, wantErr: require.NoError, }, + { + name: "Union Zero Index", + // 0x80 represents 128. So the bytes below will result in 0 + // as a result of zig-zag encoding. + data: []byte{0x80, 0x80, 0x80, 0x80, 0x30}, + schema: `["null"]`, + wantErr: require.NoError, + }, { name: "Union Nil", data: []byte{0x00}, diff --git a/codec_union.go b/codec_union.go index 9e3649a..b6fd192 100644 --- a/codec_union.go +++ b/codec_union.go @@ -152,7 +152,7 @@ func (e *mapUnionEncoder) Encode(ptr unsafe.Pointer, w *Writer) { return } - w.WriteLong(int64(pos)) + w.WriteInt(int32(pos)) if schema.Type() == Null && val == nil { return @@ -259,8 +259,8 @@ func encoderOfNullableUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) schema: union, encoder: encoder, isPtr: isPtr, - nullIdx: int64(nullIdx), - typeIdx: int64(typeIdx), + nullIdx: int32(nullIdx), + typeIdx: int32(typeIdx), } } @@ -268,17 +268,17 @@ type unionNullableEncoder struct { schema *UnionSchema encoder ValEncoder isPtr bool - nullIdx int64 - typeIdx int64 + nullIdx int32 + typeIdx int32 } func (e *unionNullableEncoder) Encode(ptr unsafe.Pointer, w *Writer) { if *((*unsafe.Pointer)(ptr)) == nil { - w.WriteLong(e.nullIdx) + w.WriteInt(e.nullIdx) return } - w.WriteLong(e.typeIdx) + w.WriteInt(e.typeIdx) newPtr := ptr if e.isPtr { newPtr = *((*unsafe.Pointer)(ptr)) @@ -445,7 +445,7 @@ type unionResolverEncoder struct { } func (e *unionResolverEncoder) Encode(ptr unsafe.Pointer, w *Writer) { - w.WriteLong(int64(e.pos)) + w.WriteInt(int32(e.pos)) e.encoder.Encode(ptr, w) } @@ -453,7 +453,7 @@ func (e *unionResolverEncoder) Encode(ptr unsafe.Pointer, w *Writer) { func getUnionSchema(schema *UnionSchema, r *Reader) (int, Schema) { types := schema.Types() - idx := int(r.ReadLong()) + idx := int(r.ReadInt()) if idx < 0 || idx > len(types)-1 { r.ReportError("decode union type", "unknown union type") return 0, nil