diff --git a/sqlhooks.go b/sqlhooks.go index 74818ab..2964249 100644 --- a/sqlhooks.go +++ b/sqlhooks.go @@ -50,10 +50,21 @@ func (drv *Driver) Open(name string) (driver.Conn, error) { } wrapped := &Conn{conn, drv.hooks} - if isExecer(conn) { + if isExecer(conn) && isQueryer(conn) && isSessionResetter(conn) { + return &ExecerQueryerContextWithSessionResetter{wrapped, + &ExecerContext{wrapped}, &QueryerContext{wrapped}, + &SessionResetter{wrapped}}, nil + } else if isExecer(conn) && isQueryer(conn) { + return &ExecerQueryerContext{wrapped, &ExecerContext{wrapped}, + &QueryerContext{wrapped}}, nil + } else if isExecer(conn) { // If conn implements an Execer interface, return a driver.Conn which // also implements Execer return &ExecerContext{wrapped}, nil + } else if isQueryer(conn) { + // If conn implements an Queryer interface, return a driver.Conn which + // also implements Queryer + return &QueryerContext{wrapped}, nil } return wrapped, nil } @@ -149,6 +160,81 @@ func (conn *ExecerContext) Exec(query string, args []driver.Value) (driver.Resul return nil, errors.New("Exec was called when ExecContext was implemented") } +// QueryerContext implements a database/sql.driver.QueryerContext +type QueryerContext struct { + *Conn +} + +func isQueryer(conn driver.Conn) bool { + switch conn.(type) { + case driver.QueryerContext: + return true + case driver.Queryer: + return true + default: + return false + } +} + +func (conn *QueryerContext) queryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + switch c := conn.Conn.Conn.(type) { + case driver.QueryerContext: + return c.QueryContext(ctx, query, args) + case driver.Queryer: + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + return c.Query(query, dargs) + default: + // This should not happen + return nil, errors.New("QueryerContext created for a non Queryer driver.Conn") + } +} + +func (conn *QueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + var err error + + list := namedToInterface(args) + + // Query `Before` Hooks + if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil { + return nil, err + } + + results, err := conn.queryContext(ctx, query, args) + if err != nil { + return results, handlerErr(ctx, conn.hooks, err, query, list...) + } + + if ctx, err = conn.hooks.After(ctx, query, list...); err != nil { + return nil, err + } + + return results, err +} + +// ExecerQueryerContext implements database/sql.driver.ExecerContext and +// database/sql.driver.QueryerContext +type ExecerQueryerContext struct { + *Conn + *ExecerContext + *QueryerContext +} + +// ExecerQueryerContext implements database/sql.driver.ExecerContext and +// database/sql.driver.QueryerContext +type ExecerQueryerContextWithSessionResetter struct { + *Conn + *ExecerContext + *QueryerContext + *SessionResetter +} + +type SessionResetter struct { + *Conn +} + // Stmt implements a database/sql/driver.Stmt type Stmt struct { Stmt driver.Stmt diff --git a/sqlhooks_1_10.go b/sqlhooks_1_10.go new file mode 100644 index 0000000..49095a0 --- /dev/null +++ b/sqlhooks_1_10.go @@ -0,0 +1,18 @@ +// +build go1.10 + +package sqlhooks + +import ( + "context" + "database/sql/driver" +) + +func isSessionResetter(conn driver.Conn) bool { + _, ok := conn.(driver.SessionResetter) + return ok +} + +func (s *SessionResetter) ResetSession(ctx context.Context) error { + c := s.Conn.Conn.(driver.SessionResetter) + return c.ResetSession(ctx) +} diff --git a/sqlhooks_1_10_interface_test.go b/sqlhooks_1_10_interface_test.go new file mode 100644 index 0000000..4db927b --- /dev/null +++ b/sqlhooks_1_10_interface_test.go @@ -0,0 +1,17 @@ +// +build go1.10 + +package sqlhooks + +import "database/sql/driver" + +func init() { + interfaceTestCases = append(interfaceTestCases, + struct { + name string + expectedInterfaces []interface{} + }{ + "ExecerQueryerContextSessionResetter", []interface{}{ + (*driver.ExecerContext)(nil), + (*driver.QueryerContext)(nil), + (*driver.SessionResetter)(nil)}}) +} diff --git a/sqlhooks_interface_test.go b/sqlhooks_interface_test.go new file mode 100644 index 0000000..462b12b --- /dev/null +++ b/sqlhooks_interface_test.go @@ -0,0 +1,125 @@ +package sqlhooks + +import ( + "context" + "database/sql/driver" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var interfaceTestCases = []struct { + name string + expectedInterfaces []interface{} +}{ + {"Basic", []interface{}{(*driver.Conn)(nil)}}, + {"Execer", []interface{}{(*driver.Execer)(nil)}}, + {"ExecerContext", []interface{}{(*driver.ExecerContext)(nil)}}, + {"Queryer", []interface{}{(*driver.QueryerContext)(nil)}}, + {"QueryerContext", []interface{}{(*driver.QueryerContext)(nil)}}, + {"ExecerQueryerContext", []interface{}{ + (*driver.ExecerContext)(nil), + (*driver.QueryerContext)(nil)}}, +} + +type fakeDriver struct{} + +func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { + switch dsn { + case "Basic": + return &struct{ *FakeConnBasic }{}, nil + case "Execer": + return &struct { + *FakeConnBasic + *FakeConnExecer + }{}, nil + case "ExecerContext": + return &struct { + *FakeConnBasic + *FakeConnExecerContext + }{}, nil + case "Queryer": + return &struct { + *FakeConnBasic + *FakeConnQueryer + }{}, nil + case "QueryerContext": + return &struct { + *FakeConnBasic + *FakeConnQueryerContext + }{}, nil + case "ExecerQueryerContext": + return &struct { + *FakeConnBasic + *FakeConnExecerContext + *FakeConnQueryerContext + }{}, nil + case "ExecerQueryerContextSessionResetter": + return &struct { + *FakeConnBasic + *FakeConnExecer + *FakeConnQueryer + *FakeConnSessionResetter + }{}, nil + } + + return nil, errors.New("Fake driver not implemented") +} + +// Conn implements a database/sql.driver.Conn +type FakeConnBasic struct{} + +func (*FakeConnBasic) Prepare(query string) (driver.Stmt, error) { + return nil, errors.New("Not implemented") +} +func (*FakeConnBasic) Close() error { + return errors.New("Not implemented") +} +func (*FakeConnBasic) Begin() (driver.Tx, error) { + return nil, errors.New("Not implemented") +} + +type FakeConnExecer struct{} + +func (*FakeConnExecer) Exec(query string, args []driver.Value) (driver.Result, error) { + return nil, errors.New("Not implemented") +} + +type FakeConnExecerContext struct{} + +func (*FakeConnExecerContext) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + return nil, errors.New("Not implemented") +} + +type FakeConnQueryer struct{} + +func (*FakeConnQueryer) Query(query string, args []driver.Value) (driver.Rows, error) { + return nil, errors.New("Not implemented") +} + +type FakeConnQueryerContext struct{} + +func (*FakeConnQueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + return nil, errors.New("Not implemented") +} + +type FakeConnSessionResetter struct{} + +func (*FakeConnSessionResetter) ResetSession(ctx context.Context) error { + return errors.New("Not implemented") +} + +func TestInterfaces(t *testing.T) { + drv := Wrap(&fakeDriver{}, &testHooks{}) + + for _, c := range interfaceTestCases { + conn, err := drv.Open(c.name) + require.NoErrorf(t, err, "Driver name %s", c.name) + + for _, i := range c.expectedInterfaces { + assert.Implements(t, i, conn) + } + } +} diff --git a/sqlhooks_pre_1_10.go b/sqlhooks_pre_1_10.go new file mode 100644 index 0000000..70d8ee3 --- /dev/null +++ b/sqlhooks_pre_1_10.go @@ -0,0 +1,17 @@ +// +build !go1.10 + +package sqlhooks + +import ( + "context" + "database/sql/driver" + "errors" +) + +func isSessionResetter(conn driver.Conn) bool { + return false +} + +func (s *SessionResetter) ResetSession(ctx context.Context) error { + return errors.New("SessionResetter not implemented") +}