Skip to content

Commit

Permalink
fix: embedding of scanonly fields
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Jan 29, 2024
1 parent 9052fc4 commit ed6ed74
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 11 deletions.
20 changes: 10 additions & 10 deletions schema/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,18 +163,18 @@ func (t *Table) processFields(
continue
}

typ := sf.Type
if sf.Type.Kind() == reflect.Ptr {
typ = sf.Type.Elem()
sfType := sf.Type
if sfType.Kind() == reflect.Ptr {
sfType = sfType.Elem()
}

if typ.Kind() != reflect.Struct { // ignore unexported non-struct types
if sfType.Kind() != reflect.Struct { // ignore unexported non-struct types
continue
}

subtable := newTable(t.dialect, typ, seen, canAddr)
subtable := newTable(t.dialect, sfType, seen, canAddr)

for _, subfield := range subtable.Fields {
for _, subfield := range subtable.allFields {
embedded = append(embedded, embeddedField{
index: sf.Index,
unexported: unexported,
Expand Down Expand Up @@ -253,7 +253,7 @@ func (t *Table) processFields(
}

for _, embfield := range embedded {
subfield := *embfield.subfield
subfield := embfield.subfield.Clone()

if ambiguousNames[subfield.Name] > 1 &&
!(!subfield.Tag.IsZero() && ambiguousTags[subfield.Name] == 1) {
Expand All @@ -265,7 +265,7 @@ func (t *Table) processFields(
subfield.Name = embfield.prefix + subfield.Name
subfield.SQLName = t.quoteIdent(subfield.Name)
}
t.addField(&subfield)
t.addField(subfield)
}
}

Expand Down Expand Up @@ -306,8 +306,8 @@ func (t *Table) addField(field *Field) {
}

t.FieldMap[field.Name] = field
if v, ok := field.Tag.Option("alt"); ok {
t.FieldMap[v] = field
if altName, ok := field.Tag.Option("alt"); ok {
t.FieldMap[altName] = field
}

if field.Tag.HasOption("scanonly") {
Expand Down
32 changes: 31 additions & 1 deletion schema/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,28 @@ func TestTable(t *testing.T) {
require.Equal(t, []int{1, 0}, barView.Index)
})

t.Run("embed scanonly", func(t *testing.T) {
type Model1 struct {
Foo string
Bar string `bun:",scanonly"`
}

type Model2 struct {
Model1
}

table := tables.Get(reflect.TypeOf((*Model2)(nil)))
require.Len(t, table.FieldMap, 2)

foo, ok := table.FieldMap["foo"]
require.True(t, ok)
require.Equal(t, []int{0, 0}, foo.Index)

bar, ok := table.FieldMap["bar"]
require.True(t, ok)
require.Equal(t, []int{0, 1}, bar.Index)
})

t.Run("scanonly", func(t *testing.T) {
type Model1 struct {
Foo string
Expand All @@ -120,17 +142,25 @@ func TestTable(t *testing.T) {

type Model2 struct {
XXX Model1 `bun:",scanonly"`
Baz string `bun:",scanonly"`
}

table := tables.Get(reflect.TypeOf((*Model2)(nil)))

require.Len(t, table.StructMap, 1)
require.NotNil(t, table.StructMap["xxx"])

require.Len(t, table.FieldMap, 2)
baz := table.FieldMap["baz"]
require.NotNil(t, baz)
require.Equal(t, []int{1}, baz.Index)

foo := table.LookupField("xxx__foo")
require.NotNil(t, foo)
require.Equal(t, []int{0, 0}, foo.Index)

bar := table.LookupField("xxx__bar")
require.NotNil(t, foo)
require.NotNil(t, bar)
require.Equal(t, []int{0, 1}, bar.Index)
})

Expand Down

0 comments on commit ed6ed74

Please sign in to comment.