From 549761b029e126ee8ba6ee6c967d67c1d7d119a4 Mon Sep 17 00:00:00 2001 From: Dmitriy Matrenichev Date: Fri, 29 Jul 2022 19:24:24 +0300 Subject: [PATCH] chore: various embedding fixes This PR does several things: - Nil pointer embedded structs are now handled correctly. - Embedded structs no longer need to have `protobuf:` tags since they didn't use them to begin with. - Embedding primitive types now reports an error instead of silently breaking. Signed-off-by: Dmitriy Matrenichev --- field_test.go | 4 +- marshal.go | 18 +++++-- marshal_test.go | 125 ++++++++++++++++++++++++++++++++++++++++++++---- type_cache.go | 42 +++++++++------- 4 files changed, 159 insertions(+), 30 deletions(-) diff --git a/field_test.go b/field_test.go index 0a3158a..eb5401c 100644 --- a/field_test.go +++ b/field_test.go @@ -55,8 +55,8 @@ func TestEncodeNested(t *testing.T) { //nolint:govet type StructWithEmbed struct { - A int32 `protobuf:"1"` - *EmbedStruct `protobuf:"2"` + A int32 `protobuf:"1"` + *EmbedStruct } type EmbedStruct struct { diff --git a/marshal.go b/marshal.go index 1f9ad87..216cd25 100644 --- a/marshal.go +++ b/marshal.go @@ -115,11 +115,23 @@ func (m *marshaller) encodeFields(val reflect.Value, fieldsData []FieldData) { // fieldByIndex returns the field of the struct by its index if the field is exported. // Otherwise, it returns empty reflect.Value. func fieldByIndex(structVal reflect.Value, data FieldData) reflect.Value { - if data.Field.IsExported() && structVal.IsValid() { - return structVal.FieldByIndex(data.FieldIndex) + if !structVal.IsValid() || !data.Field.IsExported() || len(data.FieldIndex) == 0 { + return reflect.Value{} } - return reflect.Value{} + var result reflect.Value + + for i := 0; i < len(data.FieldIndex); i++ { + index := data.FieldIndex[:i+1] + + result = structVal.FieldByIndex(index) + if len(data.FieldIndex) > 1 && result.Kind() == reflect.Ptr && result.IsNil() { + // Embedded field is nil, return empty reflect.Value. Avo + return reflect.Value{} + } + } + + return result } //nolint:cyclop diff --git a/marshal_test.go b/marshal_test.go index 315e983..afa6d31 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -269,7 +269,7 @@ func TestStringKey(t *testing.T) { func TestInternalStructMarshal(t *testing.T) { encoded := hasInternalCanMarshal[string]{ - Field: canMarshal[string]{private: "test for tests"}, + Field: Sequence[string]{field: "test for tests"}, Field2: 150, } @@ -284,22 +284,22 @@ func TestInternalStructMarshal(t *testing.T) { } type hasInternalCanMarshal[T string | []byte] struct { - Field canMarshal[T] `protobuf:"1"` - Field2 int `protobuf:"2"` + Field Sequence[T] `protobuf:"1"` + Field2 int `protobuf:"2"` } -type canMarshal[T string | []byte] struct { - private T +type Sequence[T string | []byte] struct { + field T } //nolint:revive -func (cm *canMarshal[T]) MarshalBinary() ([]byte, error) { - return []byte(cm.private), nil +func (cm *Sequence[T]) MarshalBinary() ([]byte, error) { + return []byte(cm.field), nil } //nolint:revive -func (cm *canMarshal[T]) UnmarshalBinary(data []byte) error { - cm.private = T(data) +func (cm *Sequence[T]) UnmarshalBinary(data []byte) error { + cm.field = T(data) return nil } @@ -342,3 +342,110 @@ func TestMarshal(t *testing.T) { assert.Equal(t, a, testA) assert.Equal(t, b, testB) } + +type EmbeddedStruct struct { + Value int `protobuf:"1"` + Value2 uint32 `protobuf:"2"` +} + +type AnotherEmbeddedStruct struct { + Value1 int `protobuf:"3"` + Value2 uint32 `protobuf:"4"` +} + +func TestEmbedding(t *testing.T) { + structs := map[string]struct { + fn func(t *testing.T) + }{ + "should embed struct": { + fn: makeEmbedTest(struct { + EmbeddedStruct + }{ + EmbeddedStruct: EmbeddedStruct{ + Value: 0x11, + Value2: 0x12, + }, + }), + }, + "should embed struct pointer": { + fn: makeEmbedTest(struct { + *EmbeddedStruct + }{ + EmbeddedStruct: &EmbeddedStruct{ + Value: 0x15, + Value2: 0x16, + }, + }), + }, + "should embed nil pointer struct and not nil pointer struct": { + fn: makeEmbedTest(struct { + *EmbeddedStruct + *AnotherEmbeddedStruct + }{ + EmbeddedStruct: nil, + AnotherEmbeddedStruct: &AnotherEmbeddedStruct{ + Value1: 0x21, + Value2: 0x22, + }, + }), + }, + "should embed struct with marshaller": { + fn: makeEmbedTest(struct { + Sequence[string] + }{ + Sequence: Sequence[string]{ + "test", + }, + }), + }, + "should not embed nil struct pointer": { + fn: makeIncorrectEmbedTest(struct { + *EmbeddedStruct + }{ + EmbeddedStruct: nil, + }), + }, + "should not embed simple type": { + fn: makeIncorrectEmbedTest(struct { + int + }{ + 0x11, + }), + }, + "should not embed pointer to simple type": { + fn: makeIncorrectEmbedTest(struct { + *int + }{ + int: new(int), + }), + }, + } + + for name, test := range structs { + t.Run(name, test.fn) + } +} + +func makeEmbedTest[V any](v V) func(t *testing.T) { + return func(t *testing.T) { + t.Helper() + encoded := must(protoenc.Marshal(&v))(t) + + t.Logf("\n%s", hex.Dump(encoded)) + + var result V + + require.NoError(t, protoenc.Unmarshal(encoded, &result)) + require.Equal(t, v, result) + } +} + +func makeIncorrectEmbedTest[V any](v V) func(t *testing.T) { + return func(t *testing.T) { + t.Helper() + + _, err := protoenc.Marshal(&v) + + require.Error(t, err) + } +} diff --git a/type_cache.go b/type_cache.go index 9de58ce..9370d60 100644 --- a/type_cache.go +++ b/type_cache.go @@ -63,6 +63,24 @@ func structFields(typ reflect.Type) ([]FieldData, error) { continue } + if typField.Anonymous { + if deref(typField.Type).Kind() != reflect.Struct { + return nil, fmt.Errorf("%s.%s.%s is not a struct type", typ.PkgPath(), typ.Name(), typField.Name) + } + + fields, err := structFields(typField.Type) + if err != nil { + return nil, err + } + + for _, innerField := range fields { + innerField.FieldIndex = append([]int{i}, innerField.FieldIndex...) + result = append(result, innerField) + } + + continue + } + num := ParseTag(typField) switch num { case 0: @@ -74,23 +92,15 @@ func structFields(typ reflect.Type) ([]FieldData, error) { return nil, fmt.Errorf("%s.%s.%s has invalid protobuf tag", typ.PkgPath(), typ.Name(), typField.Name) } - if typField.Anonymous { - fields, err := structFields(typField.Type) - if err != nil { - return nil, err - } + result = append(result, FieldData{ + Num: protowire.Number(num), + FieldIndex: []int{i}, + Field: typField, + }) + } - for _, innerField := range fields { - innerField.FieldIndex = append([]int{i}, innerField.FieldIndex...) - result = append(result, innerField) - } - } else { - result = append(result, FieldData{ - Num: protowire.Number(num), - FieldIndex: []int{i}, - Field: typField, - }) - } + if len(result) == 0 { + return nil, fmt.Errorf("%s.%s has no exported fields", typ.PkgPath(), typ.Name()) } return result, nil