From 98e9bd8ddee254ae4ae9ad8058c37e993ec097d8 Mon Sep 17 00:00:00 2001 From: Suresh Kumar Date: Wed, 2 Jan 2019 16:10:48 +0530 Subject: [PATCH 1/4] Add QueryerContext interface If we don't support QueryerContext, the db.Query() call will always do "prepare" statement --- sqlhooks.go | 70 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/sqlhooks.go b/sqlhooks.go index 74818ab..4a5ceb6 100644 --- a/sqlhooks.go +++ b/sqlhooks.go @@ -50,10 +50,16 @@ func (drv *Driver) Open(name string) (driver.Conn, error) { } wrapped := &Conn{conn, drv.hooks} - if isExecer(conn) { + if isExecer(conn) && isQueryer(conn) { + return &ExecerQueryerContext{wrapped, &ExecerContext{wrapped}, &QueryerContext{wrapped}}, nil + } else if isExecer(conn) { // If conn implements an Execer interface, return a driver.Conn which // also implements Execer return &ExecerContext{wrapped}, nil + } else if isQueryer(conn) { + // If conn implements an Queryer interface, return a driver.Conn which + // also implements Queryer + return &QueryerContext{wrapped}, nil } return wrapped, nil } @@ -149,6 +155,68 @@ func (conn *ExecerContext) Exec(query string, args []driver.Value) (driver.Resul return nil, errors.New("Exec was called when ExecContext was implemented") } +// QueryerContext implements a database/sql.driver.QueryerContext +type QueryerContext struct { + *Conn +} + +func isQueryer(conn driver.Conn) bool { + switch conn.(type) { + case driver.QueryerContext: + return true + case driver.Queryer: + return true + default: + return false + } +} + +func (conn *QueryerContext) queryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + switch c := conn.Conn.Conn.(type) { + case driver.QueryerContext: + return c.QueryContext(ctx, query, args) + case driver.Queryer: + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + return c.Query(query, dargs) + default: + // This should not happen + return nil, errors.New("QueryerContext created for a non Queryer driver.Conn") + } +} + +func (conn *QueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + var err error + + list := namedToInterface(args) + + // Query `Before` Hooks + if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil { + return nil, err + } + + results, err := conn.queryContext(ctx, query, args) + if err != nil { + return results, handlerErr(ctx, conn.hooks, err, query, list...) + } + + if ctx, err = conn.hooks.After(ctx, query, list...); err != nil { + return nil, err + } + + return results, err +} + +// ExecerQueryerContext implements database/sql.driver.ExecerContext and +// database/sql.driver.QueryerContext +type ExecerQueryerContext struct { + *Conn + *ExecerContext + *QueryerContext +} + // Stmt implements a database/sql/driver.Stmt type Stmt struct { Stmt driver.Stmt From bacd1d2500e94ed838f54b462af0ef92d71d39eb Mon Sep 17 00:00:00 2001 From: Suresh Kumar Date: Tue, 19 Feb 2019 19:33:52 +0530 Subject: [PATCH 2/4] Implement SessionResetter interface and add unit test cases --- sqlhooks.go | 40 +++++++++++- sqlhooks_interface_test.go | 129 +++++++++++++++++++++++++++++++++++++ 2 files changed, 167 insertions(+), 2 deletions(-) create mode 100644 sqlhooks_interface_test.go diff --git a/sqlhooks.go b/sqlhooks.go index 4a5ceb6..e0558f4 100644 --- a/sqlhooks.go +++ b/sqlhooks.go @@ -50,8 +50,13 @@ func (drv *Driver) Open(name string) (driver.Conn, error) { } wrapped := &Conn{conn, drv.hooks} - if isExecer(conn) && isQueryer(conn) { - return &ExecerQueryerContext{wrapped, &ExecerContext{wrapped}, &QueryerContext{wrapped}}, nil + if isExecer(conn) && isQueryer(conn) && isSessionResetter(conn) { + return &ExecerQueryerContextWithSessionResetter{wrapped, + &ExecerContext{wrapped}, &QueryerContext{wrapped}, + &SessionResetter{wrapped}}, nil + } else if isExecer(conn) && isQueryer(conn) { + return &ExecerQueryerContext{wrapped, &ExecerContext{wrapped}, + &QueryerContext{wrapped}}, nil } else if isExecer(conn) { // If conn implements an Execer interface, return a driver.Conn which // also implements Execer @@ -217,6 +222,37 @@ type ExecerQueryerContext struct { *QueryerContext } +// ExecerQueryerContext implements database/sql.driver.ExecerContext and +// database/sql.driver.QueryerContext +type ExecerQueryerContextWithSessionResetter struct { + *Conn + *ExecerContext + *QueryerContext + *SessionResetter +} + +type SessionResetter struct { + *Conn +} + +func isSessionResetter(conn driver.Conn) bool { + switch conn.(type) { + case driver.SessionResetter: + return true + default: + return false + } +} + +func (s *SessionResetter) ResetSession(ctx context.Context) error { + switch c := s.Conn.Conn.(type) { + case driver.SessionResetter: + return c.ResetSession(ctx) + } + + return nil +} + // Stmt implements a database/sql/driver.Stmt type Stmt struct { Stmt driver.Stmt diff --git a/sqlhooks_interface_test.go b/sqlhooks_interface_test.go new file mode 100644 index 0000000..dfb5653 --- /dev/null +++ b/sqlhooks_interface_test.go @@ -0,0 +1,129 @@ +package sqlhooks + +import ( + "context" + "database/sql/driver" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type fakeDriver struct{} + +func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { + switch dsn { + case "Basic": + return &struct{ *FakeConnBasic }{}, nil + case "Execer": + return &struct { + *FakeConnBasic + *FakeConnExecer + }{}, nil + case "ExecerContext": + return &struct { + *FakeConnBasic + *FakeConnExecerContext + }{}, nil + case "Queryer": + return &struct { + *FakeConnBasic + *FakeConnQueryer + }{}, nil + case "QueryerContext": + return &struct { + *FakeConnBasic + *FakeConnQueryerContext + }{}, nil + case "ExecerQueryerContext": + return &struct { + *FakeConnBasic + *FakeConnExecerContext + *FakeConnQueryerContext + }{}, nil + case "ExecerQueryerContextSessionResetter": + return &struct { + *FakeConnBasic + *FakeConnExecer + *FakeConnQueryer + *FakeConnSessionResetter + }{}, nil + } + + return nil, errors.New("Fake driver not implemented") +} + +// Conn implements a database/sql.driver.Conn +type FakeConnBasic struct{} + +func (*FakeConnBasic) Prepare(query string) (driver.Stmt, error) { + return nil, errors.New("Not implemented") +} +func (*FakeConnBasic) Close() error { + return errors.New("Not implemented") +} +func (*FakeConnBasic) Begin() (driver.Tx, error) { + return nil, errors.New("Not implemented") +} + +type FakeConnExecer struct{} + +func (*FakeConnExecer) Exec(query string, args []driver.Value) (driver.Result, error) { + return nil, errors.New("Not implemented") +} + +type FakeConnExecerContext struct{} + +func (*FakeConnExecerContext) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + return nil, errors.New("Not implemented") +} + +type FakeConnQueryer struct{} + +func (*FakeConnQueryer) Query(query string, args []driver.Value) (driver.Rows, error) { + return nil, errors.New("Not implemented") +} + +type FakeConnQueryerContext struct{} + +func (*FakeConnQueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + return nil, errors.New("Not implemented") +} + +type FakeConnSessionResetter struct{} + +func (*FakeConnSessionResetter) ResetSession(ctx context.Context) error { + return errors.New("Not implemented") +} + +func TestInterfaces(t *testing.T) { + drv := Wrap(&fakeDriver{}, &testHooks{}) + + cases := []struct { + name string + expectedInterfaces []interface{} + }{ + {"Basic", []interface{}{(*driver.Conn)(nil)}}, + {"Execer", []interface{}{(*driver.Execer)(nil)}}, + {"ExecerContext", []interface{}{(*driver.ExecerContext)(nil)}}, + {"Queryer", []interface{}{(*driver.QueryerContext)(nil)}}, + {"QueryerContext", []interface{}{(*driver.QueryerContext)(nil)}}, + {"ExecerQueryerContext", []interface{}{ + (*driver.ExecerContext)(nil), + (*driver.QueryerContext)(nil)}}, + {"ExecerQueryerContextSessionResetter", []interface{}{ + (*driver.ExecerContext)(nil), + (*driver.QueryerContext)(nil), + (*driver.SessionResetter)(nil)}}, + } + + for _, c := range cases { + conn, err := drv.Open(c.name) + require.NoErrorf(t, err, "Driver name %s", c.name) + + for _, i := range c.expectedInterfaces { + assert.Implements(t, i, conn) + } + } +} From 11efd5ea7e74dddbf23763819da637f7dc15d929 Mon Sep 17 00:00:00 2001 From: Suresh Kumar Date: Tue, 26 Feb 2019 14:36:16 +0530 Subject: [PATCH 3/4] Assert SessionResetter type --- sqlhooks.go | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/sqlhooks.go b/sqlhooks.go index e0558f4..3343205 100644 --- a/sqlhooks.go +++ b/sqlhooks.go @@ -236,21 +236,13 @@ type SessionResetter struct { } func isSessionResetter(conn driver.Conn) bool { - switch conn.(type) { - case driver.SessionResetter: - return true - default: - return false - } + _, ok := conn.(driver.SessionResetter) + return ok } func (s *SessionResetter) ResetSession(ctx context.Context) error { - switch c := s.Conn.Conn.(type) { - case driver.SessionResetter: - return c.ResetSession(ctx) - } - - return nil + c := s.Conn.Conn.(driver.SessionResetter) + return c.ResetSession(ctx) } // Stmt implements a database/sql/driver.Stmt From 5b518ef36717c446784eb7c2691a4edb4c4b3a14 Mon Sep 17 00:00:00 2001 From: Suresh Kumar Date: Tue, 26 Feb 2019 17:23:46 +0530 Subject: [PATCH 4/4] Make SessionResetter changes conditional compiled in go1.10+ --- sqlhooks.go | 10 ---------- sqlhooks_1_10.go | 18 +++++++++++++++++ sqlhooks_1_10_interface_test.go | 17 +++++++++++++++++ sqlhooks_interface_test.go | 34 +++++++++++++++------------------ sqlhooks_pre_1_10.go | 17 +++++++++++++++++ 5 files changed, 67 insertions(+), 29 deletions(-) create mode 100644 sqlhooks_1_10.go create mode 100644 sqlhooks_1_10_interface_test.go create mode 100644 sqlhooks_pre_1_10.go diff --git a/sqlhooks.go b/sqlhooks.go index 3343205..2964249 100644 --- a/sqlhooks.go +++ b/sqlhooks.go @@ -235,16 +235,6 @@ type SessionResetter struct { *Conn } -func isSessionResetter(conn driver.Conn) bool { - _, ok := conn.(driver.SessionResetter) - return ok -} - -func (s *SessionResetter) ResetSession(ctx context.Context) error { - c := s.Conn.Conn.(driver.SessionResetter) - return c.ResetSession(ctx) -} - // Stmt implements a database/sql/driver.Stmt type Stmt struct { Stmt driver.Stmt diff --git a/sqlhooks_1_10.go b/sqlhooks_1_10.go new file mode 100644 index 0000000..49095a0 --- /dev/null +++ b/sqlhooks_1_10.go @@ -0,0 +1,18 @@ +// +build go1.10 + +package sqlhooks + +import ( + "context" + "database/sql/driver" +) + +func isSessionResetter(conn driver.Conn) bool { + _, ok := conn.(driver.SessionResetter) + return ok +} + +func (s *SessionResetter) ResetSession(ctx context.Context) error { + c := s.Conn.Conn.(driver.SessionResetter) + return c.ResetSession(ctx) +} diff --git a/sqlhooks_1_10_interface_test.go b/sqlhooks_1_10_interface_test.go new file mode 100644 index 0000000..4db927b --- /dev/null +++ b/sqlhooks_1_10_interface_test.go @@ -0,0 +1,17 @@ +// +build go1.10 + +package sqlhooks + +import "database/sql/driver" + +func init() { + interfaceTestCases = append(interfaceTestCases, + struct { + name string + expectedInterfaces []interface{} + }{ + "ExecerQueryerContextSessionResetter", []interface{}{ + (*driver.ExecerContext)(nil), + (*driver.QueryerContext)(nil), + (*driver.SessionResetter)(nil)}}) +} diff --git a/sqlhooks_interface_test.go b/sqlhooks_interface_test.go index dfb5653..462b12b 100644 --- a/sqlhooks_interface_test.go +++ b/sqlhooks_interface_test.go @@ -10,6 +10,20 @@ import ( "github.com/stretchr/testify/require" ) +var interfaceTestCases = []struct { + name string + expectedInterfaces []interface{} +}{ + {"Basic", []interface{}{(*driver.Conn)(nil)}}, + {"Execer", []interface{}{(*driver.Execer)(nil)}}, + {"ExecerContext", []interface{}{(*driver.ExecerContext)(nil)}}, + {"Queryer", []interface{}{(*driver.QueryerContext)(nil)}}, + {"QueryerContext", []interface{}{(*driver.QueryerContext)(nil)}}, + {"ExecerQueryerContext", []interface{}{ + (*driver.ExecerContext)(nil), + (*driver.QueryerContext)(nil)}}, +} + type fakeDriver struct{} func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { @@ -100,25 +114,7 @@ func (*FakeConnSessionResetter) ResetSession(ctx context.Context) error { func TestInterfaces(t *testing.T) { drv := Wrap(&fakeDriver{}, &testHooks{}) - cases := []struct { - name string - expectedInterfaces []interface{} - }{ - {"Basic", []interface{}{(*driver.Conn)(nil)}}, - {"Execer", []interface{}{(*driver.Execer)(nil)}}, - {"ExecerContext", []interface{}{(*driver.ExecerContext)(nil)}}, - {"Queryer", []interface{}{(*driver.QueryerContext)(nil)}}, - {"QueryerContext", []interface{}{(*driver.QueryerContext)(nil)}}, - {"ExecerQueryerContext", []interface{}{ - (*driver.ExecerContext)(nil), - (*driver.QueryerContext)(nil)}}, - {"ExecerQueryerContextSessionResetter", []interface{}{ - (*driver.ExecerContext)(nil), - (*driver.QueryerContext)(nil), - (*driver.SessionResetter)(nil)}}, - } - - for _, c := range cases { + for _, c := range interfaceTestCases { conn, err := drv.Open(c.name) require.NoErrorf(t, err, "Driver name %s", c.name) diff --git a/sqlhooks_pre_1_10.go b/sqlhooks_pre_1_10.go new file mode 100644 index 0000000..70d8ee3 --- /dev/null +++ b/sqlhooks_pre_1_10.go @@ -0,0 +1,17 @@ +// +build !go1.10 + +package sqlhooks + +import ( + "context" + "database/sql/driver" + "errors" +) + +func isSessionResetter(conn driver.Conn) bool { + return false +} + +func (s *SessionResetter) ResetSession(ctx context.Context) error { + return errors.New("SessionResetter not implemented") +}