Skip to content

Commit

Permalink
Modified Pre/Post functions into using interface
Browse files Browse the repository at this point in the history
  • Loading branch information
umisama authored and James Cooper committed May 14, 2014
1 parent b5ce3b9 commit f1c93ef
Showing 1 changed file with 76 additions and 49 deletions.
125 changes: 76 additions & 49 deletions gorp.go
Original file line number Diff line number Diff line change
Expand Up @@ -1370,17 +1370,21 @@ func hookedselect(m *DbMap, exec SqlExecutor, i interface{}, query string,
// Determine where the results are: written to i, or returned in list
if t, _ := toSliceType(i); t == nil {
for _, v := range list {
err = runHook("PostGet", reflect.ValueOf(v), hookArg(exec))
if err != nil {
return nil, err
if v, ok := v.(HasPostGet); ok {
err := v.PostGet(exec)
if err != nil {
return nil, err
}
}
}
} else {
resultsValue := reflect.Indirect(reflect.ValueOf(i))
for i := 0; i < resultsValue.Len(); i++ {
err = runHook("PostGet", resultsValue.Index(i), hookArg(exec))
if err != nil {
return nil, err
if v, ok := resultsValue.Index(i).Interface().(HasPostGet); ok {
err := v.PostGet(exec)
if err != nil {
return nil, err
}
}
}
}
Expand Down Expand Up @@ -1708,27 +1712,30 @@ func get(m *DbMap, exec SqlExecutor, i interface{},
}
}

err = runHook("PostGet", v, hookArg(exec))
if err != nil {
return nil, err
if v, ok := v.Interface().(HasPostGet); ok {
err := v.PostGet(exec)
if err != nil {
return nil, err
}
}

return v.Interface(), nil
}

func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {
hookarg := hookArg(exec)
count := int64(0)
for _, ptr := range list {
table, elem, err := m.tableForPointer(ptr, true)
if err != nil {
return -1, err
}

eptr := elem.Addr()
err = runHook("PreDelete", eptr, hookarg)
if err != nil {
return -1, err
eval := elem.Addr().Interface()
if v, ok := eval.(HasPreDelete); ok {
err = v.PreDelete(exec)
if err != nil {
return -1, err
}
}

bi, err := table.bindDelete(elem)
Expand All @@ -1752,28 +1759,31 @@ func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {

count += rows

err = runHook("PostDelete", eptr, hookarg)
if err != nil {
return -1, err
if v, ok := eval.(HasPostDelete); ok {
err := v.PostDelete(exec)
if err != nil {
return -1, err
}
}
}

return count, nil
}

func update(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {
hookarg := hookArg(exec)
count := int64(0)
for _, ptr := range list {
table, elem, err := m.tableForPointer(ptr, true)
if err != nil {
return -1, err
}

eptr := elem.Addr()
err = runHook("PreUpdate", eptr, hookarg)
if err != nil {
return -1, err
eval := elem.Addr().Interface()
if v, ok := eval.(HasPreUpdate); ok {
err = v.PreUpdate(exec)
if err != nil {
return -1, err
}
}

bi, err := table.bindUpdate(elem)
Expand Down Expand Up @@ -1802,26 +1812,29 @@ func update(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {

count += rows

err = runHook("PostUpdate", eptr, hookarg)
if err != nil {
return -1, err
if v, ok := eval.(HasPostUpdate); ok {
err = v.PostUpdate(exec)
if err != nil {
return -1, err
}
}
}
return count, nil
}

func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error {
hookarg := hookArg(exec)
for _, ptr := range list {
table, elem, err := m.tableForPointer(ptr, false)
if err != nil {
return err
}

eptr := elem.Addr()
err = runHook("PreInsert", eptr, hookarg)
if err != nil {
return err
eval := elem.Addr().Interface()
if v, ok := eval.(HasPreInsert); ok {
err := v.PreInsert(exec)
if err != nil {
return err
}
}

bi, err := table.bindInsert(elem)
Expand Down Expand Up @@ -1850,25 +1863,11 @@ func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error {
}
}

err = runHook("PostInsert", eptr, hookarg)
if err != nil {
return err
}
}
return nil
}

func hookArg(exec SqlExecutor) []reflect.Value {
execval := reflect.ValueOf(exec)
return []reflect.Value{execval}
}

func runHook(name string, eptr reflect.Value, arg []reflect.Value) error {
hook := eptr.MethodByName(name)
if hook != zeroVal {
ret := hook.Call(arg)
if len(ret) > 0 && !ret[0].IsNil() {
return ret[0].Interface().(error)
if v, ok := eval.(HasPostInsert); ok {
err := v.PostInsert(exec)
if err != nil {
return err
}
}
}
return nil
Expand All @@ -1889,3 +1888,31 @@ func lockError(m *DbMap, exec SqlExecutor, tableName string,
}
return -1, ole
}

type HasPostGet interface {
PostGet(SqlExecutor) error
}

type HasPostDelete interface {
PostDelete(SqlExecutor) error
}

type HasPostUpdate interface {
PostUpdate(SqlExecutor) error
}

type HasPostInsert interface {
PostInsert(SqlExecutor) error
}

type HasPreDelete interface {
PreDelete(SqlExecutor) error
}

type HasPreUpdate interface {
PreUpdate(SqlExecutor) error
}

type HasPreInsert interface {
PreInsert(SqlExecutor) error
}

0 comments on commit f1c93ef

Please sign in to comment.