Skip to content

Commit

Permalink
chore: various embedding fixes
Browse files Browse the repository at this point in the history
This PR does several things:
- Nil pointer embedded structs are now handled correctly.
- Embedded structs no longer need to have `protobuf:<n>` tags since they didn't use them to begin with.
- Embedding primitive types now reports an error instead of silently breaking.

Signed-off-by: Dmitriy Matrenichev <dmitry.matrenichev@siderolabs.com>
  • Loading branch information
DmitriyMV committed Jul 29, 2022
1 parent ab9b1ff commit 549761b
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 30 deletions.
4 changes: 2 additions & 2 deletions field_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ func TestEncodeNested(t *testing.T) {

//nolint:govet
type StructWithEmbed struct {
A int32 `protobuf:"1"`
*EmbedStruct `protobuf:"2"`
A int32 `protobuf:"1"`
*EmbedStruct
}

type EmbedStruct struct {
Expand Down
18 changes: 15 additions & 3 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,23 @@ func (m *marshaller) encodeFields(val reflect.Value, fieldsData []FieldData) {
// fieldByIndex returns the field of the struct by its index if the field is exported.
// Otherwise, it returns empty reflect.Value.
func fieldByIndex(structVal reflect.Value, data FieldData) reflect.Value {
if data.Field.IsExported() && structVal.IsValid() {
return structVal.FieldByIndex(data.FieldIndex)
if !structVal.IsValid() || !data.Field.IsExported() || len(data.FieldIndex) == 0 {
return reflect.Value{}
}

return reflect.Value{}
var result reflect.Value

for i := 0; i < len(data.FieldIndex); i++ {
index := data.FieldIndex[:i+1]

result = structVal.FieldByIndex(index)
if len(data.FieldIndex) > 1 && result.Kind() == reflect.Ptr && result.IsNil() {
// Embedded field is nil, return empty reflect.Value. Avo
return reflect.Value{}
}
}

return result
}

//nolint:cyclop
Expand Down
125 changes: 116 additions & 9 deletions marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func TestStringKey(t *testing.T) {

func TestInternalStructMarshal(t *testing.T) {
encoded := hasInternalCanMarshal[string]{
Field: canMarshal[string]{private: "test for tests"},
Field: Sequence[string]{field: "test for tests"},
Field2: 150,
}

Expand All @@ -284,22 +284,22 @@ func TestInternalStructMarshal(t *testing.T) {
}

type hasInternalCanMarshal[T string | []byte] struct {
Field canMarshal[T] `protobuf:"1"`
Field2 int `protobuf:"2"`
Field Sequence[T] `protobuf:"1"`
Field2 int `protobuf:"2"`
}

type canMarshal[T string | []byte] struct {
private T
type Sequence[T string | []byte] struct {
field T
}

//nolint:revive
func (cm *canMarshal[T]) MarshalBinary() ([]byte, error) {
return []byte(cm.private), nil
func (cm *Sequence[T]) MarshalBinary() ([]byte, error) {
return []byte(cm.field), nil
}

//nolint:revive
func (cm *canMarshal[T]) UnmarshalBinary(data []byte) error {
cm.private = T(data)
func (cm *Sequence[T]) UnmarshalBinary(data []byte) error {
cm.field = T(data)

return nil
}
Expand Down Expand Up @@ -342,3 +342,110 @@ func TestMarshal(t *testing.T) {
assert.Equal(t, a, testA)
assert.Equal(t, b, testB)
}

type EmbeddedStruct struct {
Value int `protobuf:"1"`
Value2 uint32 `protobuf:"2"`
}

type AnotherEmbeddedStruct struct {
Value1 int `protobuf:"3"`
Value2 uint32 `protobuf:"4"`
}

func TestEmbedding(t *testing.T) {
structs := map[string]struct {
fn func(t *testing.T)
}{
"should embed struct": {
fn: makeEmbedTest(struct {
EmbeddedStruct
}{
EmbeddedStruct: EmbeddedStruct{
Value: 0x11,
Value2: 0x12,
},
}),
},
"should embed struct pointer": {
fn: makeEmbedTest(struct {
*EmbeddedStruct
}{
EmbeddedStruct: &EmbeddedStruct{
Value: 0x15,
Value2: 0x16,
},
}),
},
"should embed nil pointer struct and not nil pointer struct": {
fn: makeEmbedTest(struct {
*EmbeddedStruct
*AnotherEmbeddedStruct
}{
EmbeddedStruct: nil,
AnotherEmbeddedStruct: &AnotherEmbeddedStruct{
Value1: 0x21,
Value2: 0x22,
},
}),
},
"should embed struct with marshaller": {
fn: makeEmbedTest(struct {
Sequence[string]
}{
Sequence: Sequence[string]{
"test",
},
}),
},
"should not embed nil struct pointer": {
fn: makeIncorrectEmbedTest(struct {
*EmbeddedStruct
}{
EmbeddedStruct: nil,
}),
},
"should not embed simple type": {
fn: makeIncorrectEmbedTest(struct {
int
}{
0x11,
}),
},
"should not embed pointer to simple type": {
fn: makeIncorrectEmbedTest(struct {
*int
}{
int: new(int),
}),
},
}

for name, test := range structs {
t.Run(name, test.fn)
}
}

func makeEmbedTest[V any](v V) func(t *testing.T) {
return func(t *testing.T) {
t.Helper()
encoded := must(protoenc.Marshal(&v))(t)

t.Logf("\n%s", hex.Dump(encoded))

var result V

require.NoError(t, protoenc.Unmarshal(encoded, &result))
require.Equal(t, v, result)
}
}

func makeIncorrectEmbedTest[V any](v V) func(t *testing.T) {
return func(t *testing.T) {
t.Helper()

_, err := protoenc.Marshal(&v)

require.Error(t, err)
}
}
42 changes: 26 additions & 16 deletions type_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,24 @@ func structFields(typ reflect.Type) ([]FieldData, error) {
continue
}

if typField.Anonymous {
if deref(typField.Type).Kind() != reflect.Struct {
return nil, fmt.Errorf("%s.%s.%s is not a struct type", typ.PkgPath(), typ.Name(), typField.Name)
}

fields, err := structFields(typField.Type)
if err != nil {
return nil, err
}

for _, innerField := range fields {
innerField.FieldIndex = append([]int{i}, innerField.FieldIndex...)
result = append(result, innerField)
}

continue
}

num := ParseTag(typField)
switch num {
case 0:
Expand All @@ -74,23 +92,15 @@ func structFields(typ reflect.Type) ([]FieldData, error) {
return nil, fmt.Errorf("%s.%s.%s has invalid protobuf tag", typ.PkgPath(), typ.Name(), typField.Name)
}

if typField.Anonymous {
fields, err := structFields(typField.Type)
if err != nil {
return nil, err
}
result = append(result, FieldData{
Num: protowire.Number(num),
FieldIndex: []int{i},
Field: typField,
})
}

for _, innerField := range fields {
innerField.FieldIndex = append([]int{i}, innerField.FieldIndex...)
result = append(result, innerField)
}
} else {
result = append(result, FieldData{
Num: protowire.Number(num),
FieldIndex: []int{i},
Field: typField,
})
}
if len(result) == 0 {
return nil, fmt.Errorf("%s.%s has no exported fields", typ.PkgPath(), typ.Name())
}

return result, nil
Expand Down

0 comments on commit 549761b

Please sign in to comment.