diff --git a/db.go b/db.go index d78062e5..0b8067bb 100644 --- a/db.go +++ b/db.go @@ -93,6 +93,13 @@ type DbMap struct { tablesDynamic map[string]*TableMap // tables that use same go-struct and different db table names logger GorpLogger logPrefix string + + Cache Cache +} + +type Cache interface { + Load(key interface{}) (value interface{}, ok bool) + Store(key, value interface{}) } func (m *DbMap) dynamicTableAdd(tableName string, tbl *TableMap) { diff --git a/gorp.go b/gorp.go index fc654567..8cdcf724 100644 --- a/gorp.go +++ b/gorp.go @@ -256,7 +256,32 @@ func expandNamedQuery(m *DbMap, query string, keyGetter func(key string) reflect }), args } +type fieldCacheKey struct { + t reflect.Type + name string + cols string +} + +type fieldCacheEntry struct { + mapping [][]int + err error +} + func columnToFieldIndex(m *DbMap, t reflect.Type, name string, cols []string) ([][]int, error) { + var ck fieldCacheKey + var err error + if m.Cache != nil { + ck.t = t + ck.name = name + ck.cols = strings.Join(cols, ",") + + rv, ok := m.Cache.Load(ck) + if ok { + entry := rv.(*fieldCacheEntry) + return entry.mapping, entry.err + } + } + colToFieldIndex := make([][]int, len(cols)) // check if type t is a mapped table - if so we'll @@ -298,13 +323,22 @@ func columnToFieldIndex(m *DbMap, t reflect.Type, name string, cols []string) ([ missingColNames = append(missingColNames, colName) } } + if len(missingColNames) > 0 { - return colToFieldIndex, &NoFieldInTypeError{ + err = &NoFieldInTypeError{ TypeName: t.Name(), MissingColNames: missingColNames, } } - return colToFieldIndex, nil + + if m.Cache != nil { + entry := &fieldCacheEntry{ + mapping: colToFieldIndex, + err: err, + } + m.Cache.Store(ck, entry) + } + return colToFieldIndex, err } func fieldByName(val reflect.Value, fieldName string) *reflect.Value { diff --git a/mapping_test.go b/mapping_test.go new file mode 100644 index 00000000..030c14db --- /dev/null +++ b/mapping_test.go @@ -0,0 +1,140 @@ +package gorp + +import ( + "reflect" + "sync" + "testing" + "time" +) + +type testUser struct { + ID uint64 `db:"id"` + Username string `db:"user_name"` + HashedPassword []byte `db:"hashed_password"` + EMail string `db:"email"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} + +type testCoolUser struct { + testUser + IsCool bool `db:"is_cool"` + BestFriends []string `db:"best_friends"` +} + +func BenchmarkColumnToFieldIndex(b *testing.B) { + structType := reflect.TypeOf(testUser{}) + dbmap := &DbMap{Cache: &sync.Map{}} + b.ResetTimer() + for n := 0; n < b.N; n++ { + _, err := columnToFieldIndex(dbmap, + structType, + "some_table", + []string{ + "user_name", + "email", + "created_at", + "updated_at", + "id", + }) + if err != nil { + panic(err) + } + } +} + +func TestColumnToFieldIndexBasic(t *testing.T) { + structType := reflect.TypeOf(testUser{}) + dbmap := &DbMap{} + cols, err := columnToFieldIndex(dbmap, + structType, + "some_table", + []string{ + "email", + }) + if err != nil { + t.Fatal(err) + } + if len(cols) != 1 { + t.Fatal("cols should have 1 result", cols) + } + if cols[0][0] != 3 { + t.Fatal("cols[0][0] should map to email field in testUser", cols) + } +} + +func TestColumnToFieldIndexSome(t *testing.T) { + structType := reflect.TypeOf(testUser{}) + dbmap := &DbMap{} + cols, err := columnToFieldIndex(dbmap, + structType, + "some_table", + []string{ + "id", + "email", + "created_at", + }) + if err != nil { + t.Fatal(err) + } + if len(cols) != 3 { + t.Fatal("cols should have 3 results", cols) + } + if cols[0][0] != 0 { + t.Fatal("cols[0][0] should map to id field in testUser", cols) + } + if cols[1][0] != 3 { + t.Fatal("cols[1][0] should map to email field in testUser", cols) + } + if cols[2][0] != 4 { + t.Fatal("cols[2][0] should map to created_at field in testUser", cols) + } +} + +func TestColumnToFieldIndexEmbedded(t *testing.T) { + structType := reflect.TypeOf(testCoolUser{}) + dbmap := &DbMap{} + cols, err := columnToFieldIndex(dbmap, + structType, + "some_table", + []string{ + "id", + "email", + "is_cool", + }) + if err != nil { + t.Fatal(err) + } + if len(cols) != 3 { + t.Fatal("cols should have 3 results", cols) + } + if cols[0][0] != 0 && cols[0][1] != 0 { + t.Fatal("cols[0][0] should map to id field in testCoolUser", cols) + } + if cols[1][0] != 0 && cols[1][1] != 3 { + t.Fatal("cols[1][0] should map to email field in testCoolUser", cols) + } + if cols[2][0] != 1 { + t.Fatal("cols[2][0] should map to is_cool field in testCoolUser", cols) + } +} + +func TestColumnToFieldIndexEmbeddedFriends(t *testing.T) { + structType := reflect.TypeOf(testCoolUser{}) + dbmap := &DbMap{} + cols, err := columnToFieldIndex(dbmap, + structType, + "some_table", + []string{ + "best_friends", + }) + if err != nil { + t.Fatal(err) + } + if len(cols) != 1 { + t.Fatal("cols should have 1 results", cols) + } + if cols[0][0] != 2 { + t.Fatal("cols[0][0] should map to BestFriends field in testCoolUser", cols) + } +}