Skip to content

Commit 549761b

Browse files
committed
chore: various embedding fixes
This PR does several things: - Nil pointer embedded structs are now handled correctly. - Embedded structs no longer need to have `protobuf:<n>` tags since they didn't use them to begin with. - Embedding primitive types now reports an error instead of silently breaking. Signed-off-by: Dmitriy Matrenichev <dmitry.matrenichev@siderolabs.com>
1 parent ab9b1ff commit 549761b

File tree

4 files changed

+159
-30
lines changed

4 files changed

+159
-30
lines changed

field_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ func TestEncodeNested(t *testing.T) {
5555

5656
//nolint:govet
5757
type StructWithEmbed struct {
58-
A int32 `protobuf:"1"`
59-
*EmbedStruct `protobuf:"2"`
58+
A int32 `protobuf:"1"`
59+
*EmbedStruct
6060
}
6161

6262
type EmbedStruct struct {

marshal.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,23 @@ func (m *marshaller) encodeFields(val reflect.Value, fieldsData []FieldData) {
115115
// fieldByIndex returns the field of the struct by its index if the field is exported.
116116
// Otherwise, it returns empty reflect.Value.
117117
func fieldByIndex(structVal reflect.Value, data FieldData) reflect.Value {
118-
if data.Field.IsExported() && structVal.IsValid() {
119-
return structVal.FieldByIndex(data.FieldIndex)
118+
if !structVal.IsValid() || !data.Field.IsExported() || len(data.FieldIndex) == 0 {
119+
return reflect.Value{}
120120
}
121121

122-
return reflect.Value{}
122+
var result reflect.Value
123+
124+
for i := 0; i < len(data.FieldIndex); i++ {
125+
index := data.FieldIndex[:i+1]
126+
127+
result = structVal.FieldByIndex(index)
128+
if len(data.FieldIndex) > 1 && result.Kind() == reflect.Ptr && result.IsNil() {
129+
// Embedded field is nil, return empty reflect.Value. Avo
130+
return reflect.Value{}
131+
}
132+
}
133+
134+
return result
123135
}
124136

125137
//nolint:cyclop

marshal_test.go

Lines changed: 116 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ func TestStringKey(t *testing.T) {
269269

270270
func TestInternalStructMarshal(t *testing.T) {
271271
encoded := hasInternalCanMarshal[string]{
272-
Field: canMarshal[string]{private: "test for tests"},
272+
Field: Sequence[string]{field: "test for tests"},
273273
Field2: 150,
274274
}
275275

@@ -284,22 +284,22 @@ func TestInternalStructMarshal(t *testing.T) {
284284
}
285285

286286
type hasInternalCanMarshal[T string | []byte] struct {
287-
Field canMarshal[T] `protobuf:"1"`
288-
Field2 int `protobuf:"2"`
287+
Field Sequence[T] `protobuf:"1"`
288+
Field2 int `protobuf:"2"`
289289
}
290290

291-
type canMarshal[T string | []byte] struct {
292-
private T
291+
type Sequence[T string | []byte] struct {
292+
field T
293293
}
294294

295295
//nolint:revive
296-
func (cm *canMarshal[T]) MarshalBinary() ([]byte, error) {
297-
return []byte(cm.private), nil
296+
func (cm *Sequence[T]) MarshalBinary() ([]byte, error) {
297+
return []byte(cm.field), nil
298298
}
299299

300300
//nolint:revive
301-
func (cm *canMarshal[T]) UnmarshalBinary(data []byte) error {
302-
cm.private = T(data)
301+
func (cm *Sequence[T]) UnmarshalBinary(data []byte) error {
302+
cm.field = T(data)
303303

304304
return nil
305305
}
@@ -342,3 +342,110 @@ func TestMarshal(t *testing.T) {
342342
assert.Equal(t, a, testA)
343343
assert.Equal(t, b, testB)
344344
}
345+
346+
type EmbeddedStruct struct {
347+
Value int `protobuf:"1"`
348+
Value2 uint32 `protobuf:"2"`
349+
}
350+
351+
type AnotherEmbeddedStruct struct {
352+
Value1 int `protobuf:"3"`
353+
Value2 uint32 `protobuf:"4"`
354+
}
355+
356+
func TestEmbedding(t *testing.T) {
357+
structs := map[string]struct {
358+
fn func(t *testing.T)
359+
}{
360+
"should embed struct": {
361+
fn: makeEmbedTest(struct {
362+
EmbeddedStruct
363+
}{
364+
EmbeddedStruct: EmbeddedStruct{
365+
Value: 0x11,
366+
Value2: 0x12,
367+
},
368+
}),
369+
},
370+
"should embed struct pointer": {
371+
fn: makeEmbedTest(struct {
372+
*EmbeddedStruct
373+
}{
374+
EmbeddedStruct: &EmbeddedStruct{
375+
Value: 0x15,
376+
Value2: 0x16,
377+
},
378+
}),
379+
},
380+
"should embed nil pointer struct and not nil pointer struct": {
381+
fn: makeEmbedTest(struct {
382+
*EmbeddedStruct
383+
*AnotherEmbeddedStruct
384+
}{
385+
EmbeddedStruct: nil,
386+
AnotherEmbeddedStruct: &AnotherEmbeddedStruct{
387+
Value1: 0x21,
388+
Value2: 0x22,
389+
},
390+
}),
391+
},
392+
"should embed struct with marshaller": {
393+
fn: makeEmbedTest(struct {
394+
Sequence[string]
395+
}{
396+
Sequence: Sequence[string]{
397+
"test",
398+
},
399+
}),
400+
},
401+
"should not embed nil struct pointer": {
402+
fn: makeIncorrectEmbedTest(struct {
403+
*EmbeddedStruct
404+
}{
405+
EmbeddedStruct: nil,
406+
}),
407+
},
408+
"should not embed simple type": {
409+
fn: makeIncorrectEmbedTest(struct {
410+
int
411+
}{
412+
0x11,
413+
}),
414+
},
415+
"should not embed pointer to simple type": {
416+
fn: makeIncorrectEmbedTest(struct {
417+
*int
418+
}{
419+
int: new(int),
420+
}),
421+
},
422+
}
423+
424+
for name, test := range structs {
425+
t.Run(name, test.fn)
426+
}
427+
}
428+
429+
func makeEmbedTest[V any](v V) func(t *testing.T) {
430+
return func(t *testing.T) {
431+
t.Helper()
432+
encoded := must(protoenc.Marshal(&v))(t)
433+
434+
t.Logf("\n%s", hex.Dump(encoded))
435+
436+
var result V
437+
438+
require.NoError(t, protoenc.Unmarshal(encoded, &result))
439+
require.Equal(t, v, result)
440+
}
441+
}
442+
443+
func makeIncorrectEmbedTest[V any](v V) func(t *testing.T) {
444+
return func(t *testing.T) {
445+
t.Helper()
446+
447+
_, err := protoenc.Marshal(&v)
448+
449+
require.Error(t, err)
450+
}
451+
}

type_cache.go

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,24 @@ func structFields(typ reflect.Type) ([]FieldData, error) {
6363
continue
6464
}
6565

66+
if typField.Anonymous {
67+
if deref(typField.Type).Kind() != reflect.Struct {
68+
return nil, fmt.Errorf("%s.%s.%s is not a struct type", typ.PkgPath(), typ.Name(), typField.Name)
69+
}
70+
71+
fields, err := structFields(typField.Type)
72+
if err != nil {
73+
return nil, err
74+
}
75+
76+
for _, innerField := range fields {
77+
innerField.FieldIndex = append([]int{i}, innerField.FieldIndex...)
78+
result = append(result, innerField)
79+
}
80+
81+
continue
82+
}
83+
6684
num := ParseTag(typField)
6785
switch num {
6886
case 0:
@@ -74,23 +92,15 @@ func structFields(typ reflect.Type) ([]FieldData, error) {
7492
return nil, fmt.Errorf("%s.%s.%s has invalid protobuf tag", typ.PkgPath(), typ.Name(), typField.Name)
7593
}
7694

77-
if typField.Anonymous {
78-
fields, err := structFields(typField.Type)
79-
if err != nil {
80-
return nil, err
81-
}
95+
result = append(result, FieldData{
96+
Num: protowire.Number(num),
97+
FieldIndex: []int{i},
98+
Field: typField,
99+
})
100+
}
82101

83-
for _, innerField := range fields {
84-
innerField.FieldIndex = append([]int{i}, innerField.FieldIndex...)
85-
result = append(result, innerField)
86-
}
87-
} else {
88-
result = append(result, FieldData{
89-
Num: protowire.Number(num),
90-
FieldIndex: []int{i},
91-
Field: typField,
92-
})
93-
}
102+
if len(result) == 0 {
103+
return nil, fmt.Errorf("%s.%s has no exported fields", typ.PkgPath(), typ.Name())
94104
}
95105

96106
return result, nil

0 commit comments

Comments
 (0)