Skip to content

Commit e4e23d2

Browse files
authored
fix: nested preload with join panic when find (#6877)
1 parent c4c9aa4 commit e4e23d2

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

callbacks/preload.go

+17-4
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,23 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
121121
}
122122
} else if rel := relationships.Relations[name]; rel != nil {
123123
if joined, nestedJoins := isJoined(name); joined {
124-
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
125-
tx := preloadDB(db, reflectValue, reflectValue.Interface())
126-
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
127-
return err
124+
switch rv := db.Statement.ReflectValue; rv.Kind() {
125+
case reflect.Slice, reflect.Array:
126+
for i := 0; i < rv.Len(); i++ {
127+
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
128+
tx := preloadDB(db, reflectValue, reflectValue.Interface())
129+
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
130+
return err
131+
}
132+
}
133+
case reflect.Struct:
134+
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
135+
tx := preloadDB(db, reflectValue, reflectValue.Interface())
136+
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
137+
return err
138+
}
139+
default:
140+
return gorm.ErrInvalidData
128141
}
129142
} else {
130143
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})

tests/preload_test.go

+10
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"sync"
99
"testing"
1010

11+
"github.com/stretchr/testify/require"
12+
1113
"gorm.io/gorm"
1214
"gorm.io/gorm/clause"
1315
. "gorm.io/gorm/utils/tests"
@@ -362,6 +364,14 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) {
362364
t.Errorf("failed to find value, got err: %v", err)
363365
}
364366
AssertEqual(t, find2, value)
367+
368+
var finds []Value
369+
err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error
370+
if err != nil {
371+
t.Errorf("failed to find value, got err: %v", err)
372+
}
373+
require.Len(t, finds, 1)
374+
AssertEqual(t, finds[0], value)
365375
}
366376

367377
func TestEmbedPreload(t *testing.T) {

0 commit comments

Comments
 (0)