diff --git a/marshal.go b/marshal.go index 67ee110..43396e3 100644 --- a/marshal.go +++ b/marshal.go @@ -151,12 +151,12 @@ func (m *marshaller) encodeValue(num protowire.Number, val reflect.Value) { putBool(m, val.Bool()) case reflect.Int8, reflect.Int16: - putTag(m, num, protowire.Fixed32Type) - putInt32(m, int32(val.Int())) + putTag(m, num, protowire.VarintType) + putUVarint(m, val.Int()) case reflect.Uint8, reflect.Uint16: - putTag(m, num, protowire.Fixed32Type) - putInt32(m, int32(val.Uint())) + putTag(m, num, protowire.VarintType) + putUVarint(m, val.Uint()) case reflect.Int, reflect.Int32, reflect.Int64: putTag(m, num, protowire.VarintType) @@ -399,12 +399,12 @@ func (m *marshaller) sliceReflect(key protowire.Number, val reflect.Value) { switch elem.Kind() { //nolint:exhaustive case reflect.Int8, reflect.Int16: for i := 0; i < sliceLen; i++ { - putInt32(&result, int32(val.Index(i).Int())) + putUVarint(&result, val.Index(i).Int()) } case reflect.Uint8, reflect.Uint16: for i := 0; i < sliceLen; i++ { - putInt32(&result, uint32(val.Index(i).Uint())) + putUVarint(&result, val.Index(i).Uint()) } case reflect.Bool: diff --git a/messages/helpers_test.go b/messages/helpers_test.go index 17984a2..3ea24f2 100644 --- a/messages/helpers_test.go +++ b/messages/helpers_test.go @@ -47,12 +47,22 @@ type msg[T any] interface { proto.Message } +func runTestPipe[R any, RP msg[R], T any](t *testing.T, original T) { + encoded1 := must(protoenc.Marshal(&original))(t) + decoded := protoUnmarshal[R, RP](t, encoded1) + encoded2 := must(proto.Marshal(decoded))(t) + result := ourUnmarshal[T](t, encoded2) + + shouldBeEqual(t, original, result) +} + func protoUnmarshal[T any, V msg[T]](t *testing.T, data []byte) V { t.Helper() var msg T - err := proto.Unmarshal(data, V(&msg)) + err := proto.UnmarshalOptions{DiscardUnknown: true}.Unmarshal(data, V(&msg)) + require.NoError(t, err) return &msg diff --git a/messages/messages_test.go b/messages/messages_test.go index 7825012..3183751 100644 --- a/messages/messages_test.go +++ b/messages/messages_test.go @@ -5,11 +5,9 @@ package messages_test import ( - "encoding/hex" "testing" "github.com/stretchr/testify/require" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/emptypb" "github.com/siderolabs/protoenc" @@ -18,15 +16,6 @@ import ( // TODO: ensure that binary output is also the same -func runTestPipe[R any, RP msg[R], T any](t *testing.T, original T) { - encoded1 := must(protoenc.Marshal(&original))(t) - decoded := protoUnmarshal[R, RP](t, encoded1) - encoded2 := must(proto.Marshal(decoded))(t) - result := ourUnmarshal[T](t, encoded2) - - shouldBeEqual(t, original, result) -} - //nolint:govet type BasicMessage struct { Int64 int64 `protobuf:"1"` @@ -320,29 +309,36 @@ func TestEmptyMessage(t *testing.T) { }) } -func TestEnumMessage(t *testing.T) { - // This test ensures that we can decode a message with an enum field. - // Even tho we use fixed 32-bit values for encoding enums (unlike protobuf) decoding into int8-16s should still work. +func TestEnumMessage_CompatibleOldScheme(t *testing.T) { + // This test ensures that we can decode a message with an enum field encoded by previus version of our encoder. t.Parallel() + encoded := []byte{0x0d, 0x01, 0x00, 0x00, 0x00} + type Enum int8 type EnumMessage struct { EnumField Enum `protobuf:"1"` } - original := messages.EnumMessage{ - EnumField: messages.Enum_ENUM2, - } + dest := EnumMessage{} - encoded, err := proto.Marshal(&original) + err := protoenc.Unmarshal(encoded, &dest) require.NoError(t, err) - t.Log("\n", hex.Dump(encoded)) + require.EqualValues(t, dest.EnumField, 1) +} - decoded := EnumMessage{} - err = protoenc.Unmarshal(encoded, &decoded) - require.NoError(t, err) +func TestEnumMessage(t *testing.T) { + t.Parallel() + + type Enum int8 + + type EnumMessage struct { + EnumField Enum `protobuf:"1"` + } - require.EqualValues(t, original.EnumField, decoded.EnumField) + runTestPipe[messages.EnumMessage](t, EnumMessage{ + EnumField: 1, + }) } diff --git a/scanner.go b/scanner.go index dff5ee0..46a47a2 100644 --- a/scanner.go +++ b/scanner.go @@ -300,7 +300,7 @@ func (s *dataScanner) Wiretype() protowire.Type { func getDataScannerFor(eltype reflect.Type, buf []byte) (dataScanner, bool, error) { switch eltype.Kind() { //nolint:exhaustive case reflect.Uint8, reflect.Uint16, reflect.Int8, reflect.Int16: - return makeDataScanner(protowire.Fixed32Type, buf), true, nil + return makeDataScanner(protowire.VarintType, buf), true, nil case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint32, reflect.Uint64, reflect.Uint: diff --git a/slice_test.go b/slice_test.go index 7d946df..cb1851b 100644 --- a/slice_test.go +++ b/slice_test.go @@ -198,10 +198,10 @@ func testSliceEncodingResult[T any](slc []T, expected []byte) func(t *testing.T) func TestSmallIntegers(t *testing.T) { t.Parallel() - encodedBytes := hexToBytes(t, "0a 03 01 FF 03") - encodedFixed := hexToBytes(t, "0a 0c [01 00 00 00] [ff 00 00 00] [03 00 00 00]") - encodedFixedNegative := hexToBytes(t, "0a 0c [01 00 00 00] [ff ff ff ff] [03 00 00 00]") - encodedUint16s := hexToBytes(t, "0a 0c [01 00 00 00] [ff ff 00 00] [03 00 00 00]") + encodedBytes := hexToBytes(t, "0A 03 01 FF 03") + encodedFixed := hexToBytes(t, "0A 04 01 FF 01 03") + encodedFixedNegative := hexToBytes(t, "0A 0C [01] [FF FF FF FF FF FF FF FF FF 01] [03]") + encodedUint16s := hexToBytes(t, "0A 05 [01] [FF FF 03] [03]") type customByte byte @@ -215,7 +215,7 @@ func TestSmallIntegers(t *testing.T) { CustomByte customByte `protobuf:"5"` } - encodedCustomType := hexToBytes(t, "0a 19 [0d [ff ff ff ff]] [1d [ff ff 00 00]] [15 [ff ff ff ff]] [25 [ff 00 00 00]] [2d [ff 00 00 00]]") + encodedCustomType := hexToBytes(t, "0a 20 [08 [FF FF FF FF FF FF FF FF FF 01] 18 [FF FF 03] 10 [FF FF FF FF FF FF FF FF FF 01] 20 [FF 01] 28 [FF 01]]") tests := []struct { //nolint:govet name string @@ -226,23 +226,23 @@ func TestSmallIntegers(t *testing.T) { testEncodeDecodeWrapped([...]byte{1, 0xFF, 3}, encodedBytes), }, { - "array of custom byte types should be encoded in 'fixed32' form", + "array of custom byte type should be encoded in 'varint' form", testEncodeDecodeWrapped([...]customByte{1, 0xFF, 3}, encodedFixed), }, { - "slice of custom byte type should be encoded in 'fixed32' form", + "slice of custom byte type should be encoded in 'varint' form", testEncodeDecodeWrapped([]customByte{1, 0xFF, 3}, encodedFixed), }, { - "slice of int8 should be encoded in 'fixed32' form", + "slice of int8 should be encoded in 'varint' form", testEncodeDecodeWrapped([]int8{1, -1, 3}, encodedFixedNegative), }, { - "slice of int16 type should be encoded in 'fixed32' form", + "slice of int16 type should be encoded in 'varint' form", testEncodeDecodeWrapped([]int16{1, -1, 3}, encodedFixedNegative), }, { - "slice of uint16 type should be encoded in 'fixed32' form", + "slice of uint16 type should be encoded in 'varint' form", testEncodeDecodeWrapped([]uint16{1, 0xFFFF, 3}, encodedUint16s), }, { @@ -250,7 +250,7 @@ func TestSmallIntegers(t *testing.T) { testEncodeDecodeWrapped(customSlice{1, 0xFF, 3}, encodedBytes), }, { - "customType should be encoded in 'fixed32' form", + "customType should be encoded in 'varint' form", testEncodeDecodeWrapped(customType{ Int16: -1, Uint16: 0xFFFF, @@ -269,6 +269,7 @@ func TestSmallIntegers(t *testing.T) { func testEncodeDecodeWrapped[T any](slc T, expected []byte) func(t *testing.T) { return func(t *testing.T) { + t.Helper() t.Parallel() original := Value[T]{V: slc} diff --git a/unmarshal.go b/unmarshal.go index 78ad4dd..a2b0ddf 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -439,6 +439,7 @@ func unmarshalByteSeqeunce(dst reflect.Value, val complexValue) error { } func slice(dst reflect.Value, val complexValue) error { + // TODO: this code doesn't support the case when slice is encoded in several chunks across the message elemType := dst.Type().Elem() // we only decode bytes as []byte or [n]byte field