Skip to content

Commit

Permalink
Merge pull request #23 from twharmon/transactions
Browse files Browse the repository at this point in the history
Fix transaction approach
  • Loading branch information
twharmon authored Sep 3, 2020
2 parents e31536b + 29fca13 commit 5308b6d
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 7 deletions.
9 changes: 8 additions & 1 deletion count_query.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
package gosql

import (
"database/sql"
"fmt"
"strings"
)

// QueryRower .
type QueryRower interface {
QueryRow(query string, args ...interface{}) *sql.Row
}

// CountQuery is a query for counting rows in a table.
type CountQuery struct {
db *DB
queryRower QueryRower
count string
table string
joins []string
Expand Down Expand Up @@ -55,7 +62,7 @@ func (cq *CountQuery) LeftJoin(join string) *CountQuery {
// Exec executes the query.
func (cq *CountQuery) Exec() (int64, error) {
var count int64
row := cq.db.db.QueryRow(cq.String(), cq.whereArgs...)
row := cq.queryRower.QueryRow(cq.String(), cq.whereArgs...)
err := row.Scan(&count)
return count, err
}
Expand Down
4 changes: 4 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row {
func (db *DB) Select(fields ...string) *SelectQuery {
sq := new(SelectQuery)
sq.db = db
sq.querier = db.db
sq.fields = fields
return sq
}
Expand All @@ -147,6 +148,7 @@ func (db *DB) Select(fields ...string) *SelectQuery {
func (db *DB) ManualUpdate(table string) *UpdateQuery {
uq := new(UpdateQuery)
uq.db = db
uq.execer = db.db
uq.table = table
return uq
}
Expand All @@ -155,6 +157,7 @@ func (db *DB) ManualUpdate(table string) *UpdateQuery {
func (db *DB) Count(table string, count string) *CountQuery {
cq := new(CountQuery)
cq.db = db
cq.queryRower = db.db
cq.table = table
cq.count = count
return cq
Expand All @@ -164,6 +167,7 @@ func (db *DB) Count(table string, count string) *CountQuery {
func (db *DB) ManualDelete(table string) *DeleteQuery {
dq := new(DeleteQuery)
dq.db = db
dq.execer = db.db
dq.table = table
return dq
}
Expand Down
15 changes: 15 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,21 @@ func TestInsertWithPrimary(t *testing.T) {
check(t, mock.ExpectationsWereMet())
}

func TestInsertWithAllFieldsPrimary(t *testing.T) {
db, mock, err := getMockDB()
check(t, err)
type T struct {
ID int `gosql:"primary"`
Name string `gosql:"primary"`
}
check(t, db.Register(T{}))
model := T{5, "foo"}
mock.ExpectExec(`^insert into t \(id, name\) values \(\?, \?\)$`).WithArgs(model.ID, model.Name).WillReturnResult(sqlmock.NewResult(0, 1))
_, err = db.Insert(&model)
check(t, err)
check(t, mock.ExpectationsWereMet())
}

func ExampleDB_Insert() {
os.Remove("/tmp/foo.db")
sqliteDB, _ := sql.Open("sqlite3", "/tmp/foo.db")
Expand Down
3 changes: 2 additions & 1 deletion delete_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
// DeleteQuery is a query for deleting rows from a table.
type DeleteQuery struct {
db *DB
execer Execer
table string
joins []string
wheres []*where
Expand Down Expand Up @@ -51,7 +52,7 @@ func (dq *DeleteQuery) LeftJoin(join string) *DeleteQuery {

// Exec executes the query.
func (dq *DeleteQuery) Exec() (sql.Result, error) {
return dq.db.db.Exec(dq.String(), dq.whereArgs...)
return dq.execer.Exec(dq.String(), dq.whereArgs...)
}

// String returns the string representation of DeleteQuery.
Expand Down
2 changes: 1 addition & 1 deletion model.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (m *model) getInsertQuery(v reflect.Value) string {
if i == len(m.fields)-1 {
query.WriteString(") ")
values.WriteString(")")
} else if !isIntIn(i+1, m.primaryFieldIndecies) || !isIntIn(len(m.fields)-1, m.primaryFieldIndecies) {
} else if !isIntIn(i+1, m.primaryFieldIndecies) || !isIntIn(len(m.fields)-1, m.primaryFieldIndecies) || !v.Field(i).IsZero() {
query.WriteString(", ")
values.WriteString(", ")
}
Expand Down
13 changes: 10 additions & 3 deletions select_query.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gosql

import (
"database/sql"
"errors"
"fmt"
"reflect"
Expand All @@ -18,9 +19,15 @@ type having struct {
condition string
}

// Querier .
type Querier interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
}

// SelectQuery holds information for a select query.
type SelectQuery struct {
db *DB
querier Querier
model *model
fields []string
joins []string
Expand Down Expand Up @@ -168,7 +175,7 @@ func (sq *SelectQuery) toOne(out interface{}) error {
}
args := sq.whereArgs
args = append(args, sq.havingArgs...)
rows, err := sq.db.db.Query(sq.String(), args...)
rows, err := sq.querier.Query(sq.String(), args...)
if err != nil {
return err
}
Expand Down Expand Up @@ -196,7 +203,7 @@ func (sq *SelectQuery) toMany(sliceType reflect.Type, outs interface{}) error {
sq.many = true
args := sq.whereArgs
args = append(args, sq.havingArgs...)
rows, err := sq.db.db.Query(sq.String(), args...)
rows, err := sq.querier.Query(sq.String(), args...)
if err != nil {
return err
}
Expand Down Expand Up @@ -231,7 +238,7 @@ func (sq *SelectQuery) toManyValues(sliceType reflect.Type, outs interface{}) er
sq.many = true
args := sq.whereArgs
args = append(args, sq.havingArgs...)
rows, err := sq.db.db.Query(sq.String(), args...)
rows, err := sq.querier.Query(sq.String(), args...)
if err != nil {
return err
}
Expand Down
108 changes: 108 additions & 0 deletions tx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package gosql

import (
"database/sql"
"reflect"
)

// Tx .
type Tx struct {
tx *sql.Tx
db *DB
}

// Commit .
func (t *Tx) Commit() error {
return t.tx.Commit()
}

// Rollback .
func (t *Tx) Rollback() error {
return t.tx.Rollback()
}

// Insert insterts a row in the database.
func (t *Tx) Insert(obj interface{}) (sql.Result, error) {
m, err := t.db.getModelOf(obj)
if err != nil {
return nil, err
}
v := reflect.ValueOf(obj).Elem()
return t.tx.Exec(m.getInsertQuery(v), m.getArgs(v)...)
}

// Update updates a row in the database.
func (t *Tx) Update(obj interface{}) (sql.Result, error) {
m, err := t.db.getModelOf(obj)
if err != nil {
return nil, err
}
v := reflect.ValueOf(obj).Elem()
return t.tx.Exec(m.getUpdateQuery(), m.getArgsPrimaryLast(v)...)
}

// Delete deletes a row from the database.
func (t *Tx) Delete(obj interface{}) (sql.Result, error) {
m, err := t.db.getModelOf(obj)
if err != nil {
return nil, err
}
v := reflect.ValueOf(obj).Elem()
var inserts []interface{}
for _, i := range m.primaryFieldIndecies {
inserts = append(inserts, v.Field(i).Interface())
}
return t.tx.Exec(m.getDeleteQuery(), inserts...)
}

// Exec is a wrapper around sql.DB.Exec().
func (t *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
return t.tx.Exec(query, args...)
}

// Query is a wrapper around sql.DB.Query().
func (t *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
return t.tx.Query(query, args...)
}

// QueryRow is a wrapper around sql.DB.QueryRow().
func (t *Tx) QueryRow(query string, args ...interface{}) *sql.Row {
return t.tx.QueryRow(query, args...)
}

// Select selects columns of a table.
func (t *Tx) Select(fields ...string) *SelectQuery {
sq := new(SelectQuery)
sq.db = t.db
sq.querier = t.tx
sq.fields = fields
return sq
}

// ManualUpdate starts a query for manually updating rows in a table.
func (t *Tx) ManualUpdate(table string) *UpdateQuery {
uq := new(UpdateQuery)
uq.db = t.db
uq.execer = t.tx
uq.table = table
return uq
}

// Count starts a query for counting rows in a table.
func (t *Tx) Count(table string, count string) *CountQuery {
cq := new(CountQuery)
cq.db = t.db
cq.queryRower = t.tx
cq.table = table
cq.count = count
return cq
}

// ManualDelete starts a query for manually deleting rows in a table.
func (t *Tx) ManualDelete(table string) *DeleteQuery {
dq := new(DeleteQuery)
dq.db = t.db
dq.execer = t.tx
dq.table = table
return dq
}
8 changes: 7 additions & 1 deletion update_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
// UpdateQuery holds information for an update query.
type UpdateQuery struct {
db *DB
execer Execer
table string
joins []string
wheres []*where
Expand All @@ -17,6 +18,11 @@ type UpdateQuery struct {
setArgs []interface{}
}

// Execer .
type Execer interface {
Exec(query string, args ...interface{}) (sql.Result, error)
}

// Where specifies which rows will be returned.
func (uq *UpdateQuery) Where(condition string, args ...interface{}) *UpdateQuery {
w := &where{
Expand Down Expand Up @@ -62,7 +68,7 @@ func (uq *UpdateQuery) LeftJoin(join string) *UpdateQuery {
func (uq *UpdateQuery) Exec() (sql.Result, error) {
args := uq.setArgs
args = append(args, uq.whereArgs...)
return uq.db.db.Exec(uq.String(), args...)
return uq.execer.Exec(uq.String(), args...)
}

// String returns the string representation of UpdateQuery.
Expand Down

0 comments on commit 5308b6d

Please sign in to comment.