Skip to content

Commit

Permalink
Merge pull request #4 from twharmon/no-globals
Browse files Browse the repository at this point in the history
remove some global variables
twharmon authored Jan 16, 2020
2 parents a5edb14 + d89dc1f commit b873252
Showing 7 changed files with 175 additions and 197 deletions.
115 changes: 85 additions & 30 deletions db.go
Original file line number Diff line number Diff line change
@@ -2,43 +2,89 @@ package gosql

import (
"database/sql"
"fmt"
"reflect"
"regexp"
"strings"
)

// DB is a wrapper around sql.DB.
type DB struct {
db *sql.DB
}

// SetMaxOpenConns sets the maximum number of open connections to the database.
//
// If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than
// MaxIdleConns, then MaxIdleConns will be reduced to match the new
// MaxOpenConns limit.
//
// If n <= 0, then there is no limit on the number of open connections.
// The default is 0 (unlimited).
func (db *DB) SetMaxOpenConns(max int) {
db.db.SetMaxOpenConns(max)
}

// SetMaxIdleConns sets the maximum number of connections in the idle
// connection pool.
//
// If MaxOpenConns is greater than 0 but less than the new MaxIdleConns,
// then the new MaxIdleConns will be reduced to match the MaxOpenConns limit.
//
// If n <= 0, no idle connections are retained.
//
// The default max idle connections is currently 2. This may change in
// a future release.
func (db *DB) SetMaxIdleConns(max int) {
db.db.SetMaxIdleConns(max)
db *sql.DB
models map[string]*model
}

// Register .
func (db *DB) Register(structs ...interface{}) error {
for _, s := range structs {
if err := db.register(s); err != nil {
return err
}
}
return nil
}

func (db *DB) register(s interface{}) error {
typ := reflect.TypeOf(s)
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
if typ.Kind() != reflect.Struct {
return fmt.Errorf("you can only register structs, %s found", reflect.TypeOf(s).Kind())
}
m := new(model)
m.typ = typ
m.name = m.typ.Name()
m.table = toSnakeCase(m.name)
m.primaryFieldIndex = -1
for i := 0; i < m.typ.NumField(); i++ {
f := m.typ.Field(i)
tag, ok := f.Tag.Lookup("gosql")
if ok && tag == "-" {
continue
}
if ok && tag == "primary" {
m.primaryFieldIndex = i
}
m.fields = append(m.fields, toSnakeCase(f.Name))
}
if err := db.mustBeValid(m); err != nil {
return err
}
m.fieldCount = len(m.fields)
db.models[m.name] = m
return nil
}

func (db *DB) getModelOf(obj interface{}) (*model, error) {
t := reflect.TypeOf(obj)
if t.Kind() != reflect.Ptr {
return nil, fmt.Errorf("obj must be a pointer to your model struct")
}
t = t.Elem()
if t.Kind() != reflect.Struct {
return nil, fmt.Errorf("obj must be a pointer to your model struct")
}
m := db.models[t.Name()]
if m == nil {
return nil, fmt.Errorf("you must first register %s", t.Name())
}
return m, nil
}

func (db *DB) mustBeValid(m *model) error {
if db.models[m.name] != nil {
return fmt.Errorf("model %s found more than once", m.name)
}
if m.primaryFieldIndex < 0 {
return fmt.Errorf("model %s must have one and only one field tagged `gosql:\"primary\"`", m.name)
}
return nil
}

// Insert .
func (db *DB) Insert(obj interface{}) (sql.Result, error) {
m, err := getModelOf(obj)
m, err := db.getModelOf(obj)
if err != nil {
return nil, err
}
@@ -48,7 +94,7 @@ func (db *DB) Insert(obj interface{}) (sql.Result, error) {

// Update .
func (db *DB) Update(obj interface{}) (sql.Result, error) {
m, err := getModelOf(obj)
m, err := db.getModelOf(obj)
if err != nil {
return nil, err
}
@@ -58,7 +104,7 @@ func (db *DB) Update(obj interface{}) (sql.Result, error) {

// Delete .
func (db *DB) Delete(obj interface{}) (sql.Result, error) {
m, err := getModelOf(obj)
m, err := db.getModelOf(obj)
if err != nil {
return nil, err
}
@@ -113,3 +159,12 @@ func (db *DB) ManualDelete(table string) *DeleteQuery {
dq.table = table
return dq
}

var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")
var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")

func toSnakeCase(str string) string {
snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}")
snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}")
return strings.ToLower(snake)
}
96 changes: 55 additions & 41 deletions db_test.go
Original file line number Diff line number Diff line change
@@ -11,71 +11,60 @@ import (
"github.com/twharmon/gosql"
)

type User struct {
ID int `gosql:"primary"`
Name string
}

func init() {
if err := gosql.Register(User{}); err != nil {
panic(err)
}
}

func TestDelete(t *testing.T) {
type DeleteModel struct {
ID int `gosql:"primary"`
}
check(t, gosql.Register(DeleteModel{}))
deleteModel := DeleteModel{5}
db, mock, err := getMockDB()
check(t, err)
mock.ExpectExec(`^delete from delete_model where id = \?$`).WithArgs(deleteModel.ID).WillReturnResult(sqlmock.NewResult(0, 1))
type T struct {
ID int `gosql:"primary"`
}
check(t, db.Register(T{}))
deleteModel := T{5}
mock.ExpectExec(`^delete from t where id = \?$`).WithArgs(deleteModel.ID).WillReturnResult(sqlmock.NewResult(0, 1))
_, err = db.Delete(&deleteModel)
check(t, err)
check(t, mock.ExpectationsWereMet())
}

func TestUpdate(t *testing.T) {
type UpdateModel struct {
db, mock, err := getMockDB()
check(t, err)
type T struct {
ID int `gosql:"primary"`
Name string
}
check(t, gosql.Register(UpdateModel{}))
updateModel := UpdateModel{5, "foo"}
db, mock, err := getMockDB()
check(t, err)
mock.ExpectExec(`^update update_model set name = \? where id = \?$`).WithArgs(updateModel.Name, updateModel.ID).WillReturnResult(sqlmock.NewResult(0, 1))
check(t, db.Register(T{}))
updateModel := T{5, "foo"}
mock.ExpectExec(`^update t set name = \? where id = \?$`).WithArgs(updateModel.Name, updateModel.ID).WillReturnResult(sqlmock.NewResult(0, 1))
_, err = db.Update(&updateModel)
check(t, err)
check(t, mock.ExpectationsWereMet())
}

func TestInsert(t *testing.T) {
type InsertModel struct {
db, mock, err := getMockDB()
check(t, err)
type T struct {
ID int `gosql:"primary"`
Name string
}
check(t, gosql.Register(InsertModel{}))
insertModel := InsertModel{Name: "foo"}
db, mock, err := getMockDB()
check(t, err)
mock.ExpectExec(`^insert into insert_model \(name\) values \(\?\)$`).WithArgs(insertModel.Name).WillReturnResult(sqlmock.NewResult(0, 1))
check(t, db.Register(T{}))
insertModel := T{Name: "foo"}
mock.ExpectExec(`^insert into t \(name\) values \(\?\)$`).WithArgs(insertModel.Name).WillReturnResult(sqlmock.NewResult(0, 1))
_, err = db.Insert(&insertModel)
check(t, err)
check(t, mock.ExpectationsWereMet())
}

func TestInsertWithPrimary(t *testing.T) {
type InsertWithPrimaryModel struct {
db, mock, err := getMockDB()
check(t, err)
type T struct {
ID int `gosql:"primary"`
Name string
}
check(t, gosql.Register(InsertWithPrimaryModel{}))
insertModelWithPrimary := InsertWithPrimaryModel{5, "foo"}
db, mock, err := getMockDB()
check(t, err)
mock.ExpectExec(`^insert into insert_with_primary_model \(id, name\) values \(\?, \?\)$`).WithArgs(insertModelWithPrimary.ID, insertModelWithPrimary.Name).WillReturnResult(sqlmock.NewResult(0, 1))
check(t, db.Register(T{}))
insertModelWithPrimary := T{5, "foo"}
mock.ExpectExec(`^insert into t \(id, name\) values \(\?, \?\)$`).WithArgs(insertModelWithPrimary.ID, insertModelWithPrimary.Name).WillReturnResult(sqlmock.NewResult(0, 1))
_, err = db.Insert(&insertModelWithPrimary)
check(t, err)
check(t, mock.ExpectationsWereMet())
@@ -85,12 +74,12 @@ func ExampleDB_Insert() {
os.Remove("/tmp/foo.db")
sqliteDB, _ := sql.Open("sqlite3", "/tmp/foo.db")
sqliteDB.Exec("create table user (id integer not null primary key, name text); delete from user")
db := gosql.Conn(sqliteDB)
db := gosql.New(sqliteDB)
type User struct {
ID int `gosql:"primary"`
Name string
}
gosql.Register(User{})
db.Register(User{})
db.Insert(&User{Name: "Gopher"})
var user User
db.Select("*").To(&user)
@@ -102,12 +91,12 @@ func ExampleDB_Update() {
os.Remove("/tmp/foo.db")
sqliteDB, _ := sql.Open("sqlite3", "/tmp/foo.db")
sqliteDB.Exec("create table user (id integer not null primary key, name text); delete from user")
db := gosql.Conn(sqliteDB)
db := gosql.New(sqliteDB)
type User struct {
ID int `gosql:"primary"`
Name string
}
gosql.Register(User{})
db.Register(User{})
user := User{ID: 5, Name: "Gopher"}
db.Insert(&user)
user.Name = "Gofer"
@@ -122,12 +111,12 @@ func ExampleDB_Delete() {
os.Remove("/tmp/foo.db")
sqliteDB, _ := sql.Open("sqlite3", "/tmp/foo.db")
sqliteDB.Exec("create table user (id integer not null primary key, name text); delete from user")
db := gosql.Conn(sqliteDB)
db := gosql.New(sqliteDB)
type User struct {
ID int `gosql:"primary"`
Name string
}
gosql.Register(User{})
db.Register(User{})
user := User{ID: 5, Name: "Gopher"}
db.Insert(&user)
db.Delete(&user)
@@ -139,6 +128,11 @@ func ExampleDB_Delete() {

func BenchmarkInsert(b *testing.B) {
db := getSQLiteDB(b, "create table user (id integer not null primary key, name text); delete from user")
type User struct {
ID int `gosql:"primary"`
Name string
}
db.Register(User{})
user := User{Name: "Gopher"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -149,6 +143,11 @@ func BenchmarkInsert(b *testing.B) {

func BenchmarkUpdate(b *testing.B) {
db := getSQLiteDB(b, "create table user (id integer not null primary key, name text); delete from user")
type User struct {
ID int `gosql:"primary"`
Name string
}
db.Register(User{})
user := User{Name: "Gopher"}
_, err := db.Insert(&user)
check(b, err)
@@ -161,6 +160,11 @@ func BenchmarkUpdate(b *testing.B) {

func BenchmarkSelect(b *testing.B) {
db := getSQLiteDB(b, "create table user (id integer not null primary key, name text); delete from user")
type User struct {
ID int `gosql:"primary"`
Name string
}
db.Register(User{})
user := User{ID: 5, Name: "Gopher"}
_, err := db.Insert(&user)
check(b, err)
@@ -173,6 +177,11 @@ func BenchmarkSelect(b *testing.B) {

func BenchmarkSelectMany(b *testing.B) {
db := getSQLiteDB(b, "create table user (id integer not null primary key, name text); delete from user")
type User struct {
ID int `gosql:"primary"`
Name string
}
db.Register(User{})
user := User{Name: "Gopher"}
for i := 0; i < 100; i++ {
_, err := db.Insert(&user)
@@ -187,6 +196,11 @@ func BenchmarkSelectMany(b *testing.B) {

func BenchmarkSelectManyPtrs(b *testing.B) {
db := getSQLiteDB(b, "create table user (id integer not null primary key, name text); delete from user")
type User struct {
ID int `gosql:"primary"`
Name string
}
db.Register(User{})
user := User{Name: "Gopher"}
for i := 0; i < 100; i++ {
_, err := db.Insert(&user)
65 changes: 5 additions & 60 deletions gosql.go
Original file line number Diff line number Diff line change
@@ -3,73 +3,18 @@ package gosql
import (
"database/sql"
"errors"
"fmt"
"reflect"
"regexp"
"strings"

// mysql driver
_ "github.com/go-sql-driver/mysql"
)

// SizeOfFunc .
type SizeOfFunc func(reflect.StructField) uint64

// ErrNotFound .
var ErrNotFound = errors.New("no result found")

// Register .
func Register(structs ...interface{}) error {
for _, s := range structs {
if err := register(s); err != nil {
return err
}
}
return nil
}

func register(s interface{}) error {
typ := reflect.TypeOf(s)
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
if typ.Kind() != reflect.Struct {
return fmt.Errorf("you can only register structs, %s found", reflect.TypeOf(s).Kind())
}
m := new(model)
m.typ = typ
m.name = m.typ.Name()
m.table = toSnakeCase(m.name)
m.primaryFieldIndex = -1
for i := 0; i < m.typ.NumField(); i++ {
f := m.typ.Field(i)
tag, ok := f.Tag.Lookup("gosql")
if ok && tag == "-" {
continue
}
if ok && tag == "primary" {
m.primaryFieldIndex = i
}
m.fields = append(m.fields, toSnakeCase(f.Name))
// New returns a reference to DB.
func New(db *sql.DB) *DB {
return &DB{
db: db,
models: make(map[string]*model),
}
if err := m.mustBeValid(); err != nil {
return err
}
m.fieldCount = len(m.fields)
models[m.name] = m
return nil
}

// Conn returns a reference to DB.
func Conn(db *sql.DB) *DB {
return &DB{db}
}

var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")
var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")

func toSnakeCase(str string) string {
snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}")
snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}")
return strings.ToLower(snake)
}
35 changes: 0 additions & 35 deletions model.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package gosql

import (
"fmt"
"reflect"
"strings"
)
@@ -15,14 +14,6 @@ type model struct {
primaryFieldIndex int
}

type modelMap map[string]*model

var models modelMap

func init() {
models = make(modelMap)
}

func (m *model) getInsertQuery(v reflect.Value) string {
var query strings.Builder
var values strings.Builder
@@ -83,16 +74,6 @@ func (m *model) getUpdateQuery() string {
return query.String()
}

func (m *model) mustBeValid() error {
if models[m.name] != nil {
return fmt.Errorf("model %s found more than once", m.name)
}
if m.primaryFieldIndex < 0 {
return fmt.Errorf("model %s must have one and only one field tagged `gosql:\"primary\"`", m.name)
}
return nil
}

func (m *model) getFieldIndexByName(name string) int {
for i, f := range m.fields {
if name == f || strings.HasSuffix(name, "."+f) {
@@ -137,19 +118,3 @@ func (m *model) getArgsPrimaryLast(v reflect.Value) []interface{} {
args[m.fieldCount-1] = primArg
return args
}

func getModelOf(obj interface{}) (*model, error) {
t := reflect.TypeOf(obj)
if t.Kind() != reflect.Ptr {
return nil, fmt.Errorf("obj must be a pointer to your model struct")
}
t = t.Elem()
if t.Kind() != reflect.Struct {
return nil, fmt.Errorf("obj must be a pointer to your model struct")
}
m := models[t.Name()]
if m == nil {
return nil, fmt.Errorf("you must first register %s", t.Name())
}
return m, nil
}
6 changes: 3 additions & 3 deletions select_query.go
Original file line number Diff line number Diff line change
@@ -84,7 +84,7 @@ func (sq *SelectQuery) To(out interface{}) error {
t = t.Elem()
switch t.Kind() {
case reflect.Struct:
sq.model = models[t.Name()]
sq.model = sq.db.models[t.Name()]
if sq.model == nil {
return fmt.Errorf("you must first register %s", t.Name())
}
@@ -100,13 +100,13 @@ func (sq *SelectQuery) To(out interface{}) error {
if el.Kind() != reflect.Struct {
break
}
sq.model = models[el.Name()]
sq.model = sq.db.models[el.Name()]
if sq.model == nil {
return fmt.Errorf("you must first register %s", el.Name())
}
return sq.toMany(t, out)
case reflect.Struct:
sq.model = models[el.Name()]
sq.model = sq.db.models[el.Name()]
if sq.model == nil {
return fmt.Errorf("you must first register %s", el.Name())
}
51 changes: 25 additions & 26 deletions select_query_test.go
Original file line number Diff line number Diff line change
@@ -4,54 +4,53 @@ import (
"testing"

"github.com/DATA-DOG/go-sqlmock"
"github.com/twharmon/gosql"
)

func TestSelectQueryOne(t *testing.T) {
type SelectQueryOneModel struct {
db, mock, err := getMockDB()
check(t, err)
type T struct {
ID int `gosql:"primary"`
Name string
}
check(t, gosql.Register(SelectQueryOneModel{}))
control := SelectQueryOneModel{
check(t, db.Register(T{}))
control := T{
ID: 5,
Name: "foo",
}
db, mock, err := getMockDB()
check(t, err)
rows := sqlmock.NewRows([]string{"id", "name"})
rows.AddRow(control.ID, control.Name)
mock.ExpectQuery(`^select \* from select_query_one_model where id = \? limit 1$`).WithArgs(control.ID).WillReturnRows(rows)
var test SelectQueryOneModel
mock.ExpectQuery(`^select \* from t where id = \? limit 1$`).WithArgs(control.ID).WillReturnRows(rows)
var test T
check(t, db.Select("*").Where("id = ?", control.ID).To(&test))
check(t, mock.ExpectationsWereMet())
equals(t, control, test)
}

func TestSelectQueryMany(t *testing.T) {
type SelectQueryManyModel struct {
db, mock, err := getMockDB()
check(t, err)
type T struct {
ID int `gosql:"primary"`
Name string
}
check(t, gosql.Register(SelectQueryManyModel{}))
control := []*SelectQueryManyModel{
&SelectQueryManyModel{
check(t, db.Register(T{}))
control := []*T{
&T{
ID: 5,
Name: "foo",
},
&SelectQueryManyModel{
&T{
ID: 6,
Name: "bar",
},
}
db, mock, err := getMockDB()
check(t, err)
rows := sqlmock.NewRows([]string{"id", "name"})
for _, c := range control {
rows.AddRow(c.ID, c.Name)
}
mock.ExpectQuery(`^select \* from select_query_many_model limit 10$`).WillReturnRows(rows)
var test []*SelectQueryManyModel
mock.ExpectQuery(`^select \* from t limit 10$`).WillReturnRows(rows)
var test []*T
check(t, db.Select("*").Limit(10).To(&test))
check(t, mock.ExpectationsWereMet())
for i := 0; i < len(control); i++ {
@@ -60,29 +59,29 @@ func TestSelectQueryMany(t *testing.T) {
}

func TestSelectQueryManyValues(t *testing.T) {
type SelectQueryManyValuesModel struct {
db, mock, err := getMockDB()
check(t, err)
type T struct {
ID int `gosql:"primary"`
Name string
}
check(t, gosql.Register(SelectQueryManyValuesModel{}))
control := []SelectQueryManyValuesModel{
SelectQueryManyValuesModel{
check(t, db.Register(T{}))
control := []T{
T{
ID: 5,
Name: "foo",
},
SelectQueryManyValuesModel{
T{
ID: 6,
Name: "bar",
},
}
db, mock, err := getMockDB()
check(t, err)
rows := sqlmock.NewRows([]string{"id", "name"})
for _, c := range control {
rows.AddRow(c.ID, c.Name)
}
mock.ExpectQuery(`^select \* from select_query_many_values_model limit 10$`).WillReturnRows(rows)
var test []SelectQueryManyValuesModel
mock.ExpectQuery(`^select \* from t limit 10$`).WillReturnRows(rows)
var test []T
check(t, db.Select("*").Limit(10).To(&test))
check(t, mock.ExpectationsWereMet())
for i := 0; i < len(control); i++ {
4 changes: 2 additions & 2 deletions testing_test.go
Original file line number Diff line number Diff line change
@@ -35,13 +35,13 @@ func getMockDB() (*gosql.DB, sqlmock.Sqlmock, error) {
if err != nil {
panic(err)
}
return gosql.Conn(db), mock, err
return gosql.New(db), mock, err
}

func getSQLiteDB(f fataler, q string) *gosql.DB {
os.Remove("/tmp/foo.db")
sqliteDB, err := sql.Open("sqlite3", "/tmp/foo.db")
check(f, err)
sqliteDB.Exec(q)
return gosql.Conn(sqliteDB)
return gosql.New(sqliteDB)
}

0 comments on commit b873252

Please sign in to comment.