diff --git a/pkg/db/client.go b/pkg/db/client.go index bbd23b326..19c79dbcb 100644 --- a/pkg/db/client.go +++ b/pkg/db/client.go @@ -74,7 +74,7 @@ type Client interface { Select(ctx context.Context, dest any, query string, args ...any) error NamedSelect(ctx context.Context, dest any, query string, arg any) error Get(ctx context.Context, dest any, query string, args ...any) error - WithTx(ctx context.Context, ops *sql.TxOptions, do func(ctx context.Context, tx *sql.Tx) error) error + WithTx(ctx context.Context, ops *sql.TxOptions, do func(ctx context.Context, tx *sqlx.Tx) error) error } type ClientSqlx struct { @@ -226,7 +226,7 @@ func (c *ClientSqlx) NamedExec(ctx context.Context, query string, arg any) (sql. } func (c *ClientSqlx) ExecMultiInTx(ctx context.Context, sqlers ...Sqler) (results []sql.Result, err error) { - var tx *sql.Tx + var tx *sqlx.Tx var res sql.Result var queries []string var argss [][]any @@ -405,21 +405,21 @@ func (c *ClientSqlx) Get(ctx context.Context, dest any, query string, args ...an return err } -func (c *ClientSqlx) BeginTx(ctx context.Context, ops *sql.TxOptions) (*sql.Tx, error) { +func (c *ClientSqlx) BeginTx(ctx context.Context, ops *sql.TxOptions) (*sqlx.Tx, error) { c.logger.Debug("start tx") res, err := c.executor.Execute(ctx, func(ctx context.Context) (any, error) { - return c.db.BeginTx(ctx, ops) + return c.db.BeginTxx(ctx, ops) }) if err != nil { return nil, err } - return res.(*sql.Tx), err + return res.(*sqlx.Tx), err } -func (c *ClientSqlx) WithTx(ctx context.Context, ops *sql.TxOptions, do func(ctx context.Context, tx *sql.Tx) error) (err error) { - var tx *sql.Tx +func (c *ClientSqlx) WithTx(ctx context.Context, ops *sql.TxOptions, do func(ctx context.Context, tx *sqlx.Tx) error) (err error) { + var tx *sqlx.Tx tx, err = c.BeginTx(ctx, ops) if err != nil { return err diff --git a/pkg/db/driver_cratedb.go b/pkg/db/driver_cratedb.go index 5390983f4..855c6b2c8 100644 --- a/pkg/db/driver_cratedb.go +++ b/pkg/db/driver_cratedb.go @@ -15,7 +15,7 @@ const DriverNameCrateDb = "cratedb" func init() { sql.Register(DriverNameCrateDb, stdlib.GetDefaultDriver()) - connectionFactories[DriverNameCrateDb] = NewCrateDbDriver + AddDriverFactory(DriverNameCrateDb, NewCrateDbDriver) } type crateDbDriver struct{} diff --git a/pkg/db/driver_factory.go b/pkg/db/driver_factory.go index 1cf6eb856..90118507c 100644 --- a/pkg/db/driver_factory.go +++ b/pkg/db/driver_factory.go @@ -15,7 +15,11 @@ type Driver interface { GetMigrationDriver(db *sql.DB, database string, migrationsTable string) (database.Driver, error) } -var connectionFactories = map[string]DriverFactory{} +var driverFactories = map[string]DriverFactory{} + +func AddDriverFactory(name string, factory DriverFactory) { + driverFactories[name] = factory +} func GetDriver(logger log.Logger, driverName string) (Driver, error) { var ok bool @@ -23,7 +27,7 @@ func GetDriver(logger log.Logger, driverName string) (Driver, error) { var factory DriverFactory var driver Driver - if factory, ok = connectionFactories[driverName]; !ok { + if factory, ok = driverFactories[driverName]; !ok { return nil, fmt.Errorf("no driver factory defined for %s", driverName) } diff --git a/pkg/db/driver_mysql.go b/pkg/db/driver_mysql.go index 2ea52529f..07f203c99 100644 --- a/pkg/db/driver_mysql.go +++ b/pkg/db/driver_mysql.go @@ -15,7 +15,7 @@ import ( const DriverMysql = "mysql" func init() { - connectionFactories[DriverMysql] = NewMysqlDriver + AddDriverFactory(DriverMysql, NewMysqlDriver) } func NewMysqlDriver(logger log.Logger) (Driver, error) { @@ -70,7 +70,7 @@ type mysqlLogger struct { logger log.Logger } -func (m mysqlLogger) Print(v ...interface{}) { +func (m mysqlLogger) Print(v ...any) { msg := fmt.Sprint(v...) m.logger.Error(msg) } diff --git a/pkg/db/driver_redshift.go b/pkg/db/driver_redshift.go index e2732551a..6beec8c7f 100644 --- a/pkg/db/driver_redshift.go +++ b/pkg/db/driver_redshift.go @@ -16,7 +16,7 @@ const DriverNameRedshift = "redshift" func init() { sql.Register(DriverNameRedshift, &pq.Driver{}) - connectionFactories[DriverNameRedshift] = NewRedshiftDriver + AddDriverFactory(DriverNameRedshift, NewRedshiftDriver) } func NewRedshiftDriver(logger log.Logger) (Driver, error) { diff --git a/pkg/db/mocks/Client.go b/pkg/db/mocks/Client.go index 80767c9b3..541f824ce 100644 --- a/pkg/db/mocks/Client.go +++ b/pkg/db/mocks/Client.go @@ -1038,7 +1038,7 @@ func (_c *Client_Select_Call) RunAndReturn(run func(context.Context, interface{} } // WithTx provides a mock function with given fields: ctx, ops, do -func (_m *Client) WithTx(ctx context.Context, ops *sql.TxOptions, do func(context.Context, *sql.Tx) error) error { +func (_m *Client) WithTx(ctx context.Context, ops *sql.TxOptions, do func(context.Context, *sqlx.Tx) error) error { ret := _m.Called(ctx, ops, do) if len(ret) == 0 { @@ -1046,7 +1046,7 @@ func (_m *Client) WithTx(ctx context.Context, ops *sql.TxOptions, do func(contex } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *sql.TxOptions, func(context.Context, *sql.Tx) error) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, *sql.TxOptions, func(context.Context, *sqlx.Tx) error) error); ok { r0 = rf(ctx, ops, do) } else { r0 = ret.Error(0) @@ -1063,14 +1063,14 @@ type Client_WithTx_Call struct { // WithTx is a helper method to define mock.On call // - ctx context.Context // - ops *sql.TxOptions -// - do func(context.Context , *sql.Tx) error +// - do func(context.Context , *sqlx.Tx) error func (_e *Client_Expecter) WithTx(ctx interface{}, ops interface{}, do interface{}) *Client_WithTx_Call { return &Client_WithTx_Call{Call: _e.mock.On("WithTx", ctx, ops, do)} } -func (_c *Client_WithTx_Call) Run(run func(ctx context.Context, ops *sql.TxOptions, do func(context.Context, *sql.Tx) error)) *Client_WithTx_Call { +func (_c *Client_WithTx_Call) Run(run func(ctx context.Context, ops *sql.TxOptions, do func(context.Context, *sqlx.Tx) error)) *Client_WithTx_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*sql.TxOptions), args[2].(func(context.Context, *sql.Tx) error)) + run(args[0].(context.Context), args[1].(*sql.TxOptions), args[2].(func(context.Context, *sqlx.Tx) error)) }) return _c } @@ -1080,7 +1080,7 @@ func (_c *Client_WithTx_Call) Return(_a0 error) *Client_WithTx_Call { return _c } -func (_c *Client_WithTx_Call) RunAndReturn(run func(context.Context, *sql.TxOptions, func(context.Context, *sql.Tx) error) error) *Client_WithTx_Call { +func (_c *Client_WithTx_Call) RunAndReturn(run func(context.Context, *sql.TxOptions, func(context.Context, *sqlx.Tx) error) error) *Client_WithTx_Call { _c.Call.Return(run) return _c }