Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 87 additions & 1 deletion sqlhooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,21 @@ func (drv *Driver) Open(name string) (driver.Conn, error) {
}

wrapped := &Conn{conn, drv.hooks}
if isExecer(conn) {
if isExecer(conn) && isQueryer(conn) && isSessionResetter(conn) {
return &ExecerQueryerContextWithSessionResetter{wrapped,
&ExecerContext{wrapped}, &QueryerContext{wrapped},
&SessionResetter{wrapped}}, nil
} else if isExecer(conn) && isQueryer(conn) {
return &ExecerQueryerContext{wrapped, &ExecerContext{wrapped},
&QueryerContext{wrapped}}, nil
} else if isExecer(conn) {
// If conn implements an Execer interface, return a driver.Conn which
// also implements Execer
return &ExecerContext{wrapped}, nil
} else if isQueryer(conn) {
// If conn implements an Queryer interface, return a driver.Conn which
// also implements Queryer
return &QueryerContext{wrapped}, nil
}
return wrapped, nil
}
Expand Down Expand Up @@ -149,6 +160,81 @@ func (conn *ExecerContext) Exec(query string, args []driver.Value) (driver.Resul
return nil, errors.New("Exec was called when ExecContext was implemented")
}

// QueryerContext implements a database/sql.driver.QueryerContext
type QueryerContext struct {
*Conn
}

func isQueryer(conn driver.Conn) bool {
switch conn.(type) {
case driver.QueryerContext:
return true
case driver.Queryer:
return true
default:
return false
}
}

func (conn *QueryerContext) queryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
switch c := conn.Conn.Conn.(type) {
case driver.QueryerContext:
return c.QueryContext(ctx, query, args)
case driver.Queryer:
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
return c.Query(query, dargs)
default:
// This should not happen
return nil, errors.New("QueryerContext created for a non Queryer driver.Conn")
}
}

func (conn *QueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
var err error

list := namedToInterface(args)

// Query `Before` Hooks
if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil {
return nil, err
}

results, err := conn.queryContext(ctx, query, args)
if err != nil {
return results, handlerErr(ctx, conn.hooks, err, query, list...)
}

if ctx, err = conn.hooks.After(ctx, query, list...); err != nil {
return nil, err
}

return results, err
}

// ExecerQueryerContext implements database/sql.driver.ExecerContext and
// database/sql.driver.QueryerContext
type ExecerQueryerContext struct {
*Conn
*ExecerContext
*QueryerContext
}

// ExecerQueryerContext implements database/sql.driver.ExecerContext and
// database/sql.driver.QueryerContext
type ExecerQueryerContextWithSessionResetter struct {
*Conn
*ExecerContext
*QueryerContext
*SessionResetter
}

type SessionResetter struct {
*Conn
}

// Stmt implements a database/sql/driver.Stmt
type Stmt struct {
Stmt driver.Stmt
Expand Down
18 changes: 18 additions & 0 deletions sqlhooks_1_10.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// +build go1.10

package sqlhooks

import (
"context"
"database/sql/driver"
)

func isSessionResetter(conn driver.Conn) bool {
_, ok := conn.(driver.SessionResetter)
return ok
}

func (s *SessionResetter) ResetSession(ctx context.Context) error {
c := s.Conn.Conn.(driver.SessionResetter)
return c.ResetSession(ctx)
}
17 changes: 17 additions & 0 deletions sqlhooks_1_10_interface_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// +build go1.10

package sqlhooks

import "database/sql/driver"

func init() {
interfaceTestCases = append(interfaceTestCases,
struct {
name string
expectedInterfaces []interface{}
}{
"ExecerQueryerContextSessionResetter", []interface{}{
(*driver.ExecerContext)(nil),
(*driver.QueryerContext)(nil),
(*driver.SessionResetter)(nil)}})
}
125 changes: 125 additions & 0 deletions sqlhooks_interface_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package sqlhooks

import (
"context"
"database/sql/driver"
"errors"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var interfaceTestCases = []struct {
name string
expectedInterfaces []interface{}
}{
{"Basic", []interface{}{(*driver.Conn)(nil)}},
{"Execer", []interface{}{(*driver.Execer)(nil)}},
{"ExecerContext", []interface{}{(*driver.ExecerContext)(nil)}},
{"Queryer", []interface{}{(*driver.QueryerContext)(nil)}},
{"QueryerContext", []interface{}{(*driver.QueryerContext)(nil)}},
{"ExecerQueryerContext", []interface{}{
(*driver.ExecerContext)(nil),
(*driver.QueryerContext)(nil)}},
}

type fakeDriver struct{}

func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
switch dsn {
case "Basic":
return &struct{ *FakeConnBasic }{}, nil
case "Execer":
return &struct {
*FakeConnBasic
*FakeConnExecer
}{}, nil
case "ExecerContext":
return &struct {
*FakeConnBasic
*FakeConnExecerContext
}{}, nil
case "Queryer":
return &struct {
*FakeConnBasic
*FakeConnQueryer
}{}, nil
case "QueryerContext":
return &struct {
*FakeConnBasic
*FakeConnQueryerContext
}{}, nil
case "ExecerQueryerContext":
return &struct {
*FakeConnBasic
*FakeConnExecerContext
*FakeConnQueryerContext
}{}, nil
case "ExecerQueryerContextSessionResetter":
return &struct {
*FakeConnBasic
*FakeConnExecer
*FakeConnQueryer
*FakeConnSessionResetter
}{}, nil
}

return nil, errors.New("Fake driver not implemented")
}

// Conn implements a database/sql.driver.Conn
type FakeConnBasic struct{}

func (*FakeConnBasic) Prepare(query string) (driver.Stmt, error) {
return nil, errors.New("Not implemented")
}
func (*FakeConnBasic) Close() error {
return errors.New("Not implemented")
}
func (*FakeConnBasic) Begin() (driver.Tx, error) {
return nil, errors.New("Not implemented")
}

type FakeConnExecer struct{}

func (*FakeConnExecer) Exec(query string, args []driver.Value) (driver.Result, error) {
return nil, errors.New("Not implemented")
}

type FakeConnExecerContext struct{}

func (*FakeConnExecerContext) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
return nil, errors.New("Not implemented")
}

type FakeConnQueryer struct{}

func (*FakeConnQueryer) Query(query string, args []driver.Value) (driver.Rows, error) {
return nil, errors.New("Not implemented")
}

type FakeConnQueryerContext struct{}

func (*FakeConnQueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
return nil, errors.New("Not implemented")
}

type FakeConnSessionResetter struct{}

func (*FakeConnSessionResetter) ResetSession(ctx context.Context) error {
return errors.New("Not implemented")
}

func TestInterfaces(t *testing.T) {
drv := Wrap(&fakeDriver{}, &testHooks{})

for _, c := range interfaceTestCases {
conn, err := drv.Open(c.name)
require.NoErrorf(t, err, "Driver name %s", c.name)

for _, i := range c.expectedInterfaces {
assert.Implements(t, i, conn)
}
}
}
17 changes: 17 additions & 0 deletions sqlhooks_pre_1_10.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// +build !go1.10

package sqlhooks

import (
"context"
"database/sql/driver"
"errors"
)

func isSessionResetter(conn driver.Conn) bool {
return false
}

func (s *SessionResetter) ResetSession(ctx context.Context) error {
return errors.New("SessionResetter not implemented")
}