Skip to content

Commit

Permalink
db: using transactions from the sqlx package instead of native sql;
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Kamieth authored and j4k4 committed Nov 8, 2024
1 parent f54d343 commit 9a5db74
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 19 deletions.
14 changes: 7 additions & 7 deletions pkg/db/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ type Client interface {
Select(ctx context.Context, dest any, query string, args ...any) error
NamedSelect(ctx context.Context, dest any, query string, arg any) error
Get(ctx context.Context, dest any, query string, args ...any) error
WithTx(ctx context.Context, ops *sql.TxOptions, do func(ctx context.Context, tx *sql.Tx) error) error
WithTx(ctx context.Context, ops *sql.TxOptions, do func(ctx context.Context, tx *sqlx.Tx) error) error
}

type ClientSqlx struct {
Expand Down Expand Up @@ -226,7 +226,7 @@ func (c *ClientSqlx) NamedExec(ctx context.Context, query string, arg any) (sql.
}

func (c *ClientSqlx) ExecMultiInTx(ctx context.Context, sqlers ...Sqler) (results []sql.Result, err error) {
var tx *sql.Tx
var tx *sqlx.Tx
var res sql.Result
var queries []string
var argss [][]any
Expand Down Expand Up @@ -405,21 +405,21 @@ func (c *ClientSqlx) Get(ctx context.Context, dest any, query string, args ...an
return err
}

func (c *ClientSqlx) BeginTx(ctx context.Context, ops *sql.TxOptions) (*sql.Tx, error) {
func (c *ClientSqlx) BeginTx(ctx context.Context, ops *sql.TxOptions) (*sqlx.Tx, error) {
c.logger.Debug("start tx")

res, err := c.executor.Execute(ctx, func(ctx context.Context) (any, error) {
return c.db.BeginTx(ctx, ops)
return c.db.BeginTxx(ctx, ops)
})
if err != nil {
return nil, err
}

return res.(*sql.Tx), err
return res.(*sqlx.Tx), err
}

func (c *ClientSqlx) WithTx(ctx context.Context, ops *sql.TxOptions, do func(ctx context.Context, tx *sql.Tx) error) (err error) {
var tx *sql.Tx
func (c *ClientSqlx) WithTx(ctx context.Context, ops *sql.TxOptions, do func(ctx context.Context, tx *sqlx.Tx) error) (err error) {
var tx *sqlx.Tx
tx, err = c.BeginTx(ctx, ops)
if err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion pkg/db/driver_cratedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ const DriverNameCrateDb = "cratedb"

func init() {
sql.Register(DriverNameCrateDb, stdlib.GetDefaultDriver())
connectionFactories[DriverNameCrateDb] = NewCrateDbDriver
AddDriverFactory(DriverNameCrateDb, NewCrateDbDriver)
}

type crateDbDriver struct{}
Expand Down
8 changes: 6 additions & 2 deletions pkg/db/driver_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,19 @@ type Driver interface {
GetMigrationDriver(db *sql.DB, database string, migrationsTable string) (database.Driver, error)
}

var connectionFactories = map[string]DriverFactory{}
var driverFactories = map[string]DriverFactory{}

func AddDriverFactory(name string, factory DriverFactory) {
driverFactories[name] = factory
}

func GetDriver(logger log.Logger, driverName string) (Driver, error) {
var ok bool
var err error
var factory DriverFactory
var driver Driver

if factory, ok = connectionFactories[driverName]; !ok {
if factory, ok = driverFactories[driverName]; !ok {
return nil, fmt.Errorf("no driver factory defined for %s", driverName)
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/db/driver_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
const DriverMysql = "mysql"

func init() {
connectionFactories[DriverMysql] = NewMysqlDriver
AddDriverFactory(DriverMysql, NewMysqlDriver)
}

func NewMysqlDriver(logger log.Logger) (Driver, error) {
Expand Down Expand Up @@ -70,7 +70,7 @@ type mysqlLogger struct {
logger log.Logger
}

func (m mysqlLogger) Print(v ...interface{}) {
func (m mysqlLogger) Print(v ...any) {
msg := fmt.Sprint(v...)
m.logger.Error(msg)
}
2 changes: 1 addition & 1 deletion pkg/db/driver_redshift.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ const DriverNameRedshift = "redshift"

func init() {
sql.Register(DriverNameRedshift, &pq.Driver{})
connectionFactories[DriverNameRedshift] = NewRedshiftDriver
AddDriverFactory(DriverNameRedshift, NewRedshiftDriver)
}

func NewRedshiftDriver(logger log.Logger) (Driver, error) {
Expand Down
12 changes: 6 additions & 6 deletions pkg/db/mocks/Client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 9a5db74

Please sign in to comment.