-
Notifications
You must be signed in to change notification settings - Fork 149
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
301 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
package mysql | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/zhufuyi/sponge/pkg/mysql/query" | ||
|
||
"gorm.io/gorm" | ||
) | ||
|
||
// TableName get table name | ||
func TableName(table interface{}) string { | ||
return GetTableName(table) | ||
} | ||
|
||
// Create a new record | ||
// the param of 'table' must be pointer, eg: &StructName | ||
func Create(ctx context.Context, db *gorm.DB, table interface{}) error { | ||
return db.WithContext(ctx).Create(table).Error | ||
} | ||
|
||
// Delete record | ||
// the param of 'table' must be pointer, eg: &StructName | ||
func Delete(ctx context.Context, db *gorm.DB, table interface{}, queryCondition interface{}, args ...interface{}) error { | ||
return db.WithContext(ctx).Where(queryCondition, args...).Delete(table).Error | ||
} | ||
|
||
// DeleteByID delete record by id | ||
// the param of 'table' must be pointer, eg: &StructName | ||
func DeleteByID(ctx context.Context, db *gorm.DB, table interface{}, id interface{}) error { | ||
return db.WithContext(ctx).Where("id = ?", id).Delete(table).Error | ||
} | ||
|
||
// Update record | ||
// the param of 'table' must be pointer, eg: &StructName | ||
func Update(ctx context.Context, db *gorm.DB, table interface{}, column string, value interface{}, queryCondition interface{}, args ...interface{}) error { | ||
return db.WithContext(ctx).Model(table).Where(queryCondition, args...).Update(column, value).Error | ||
} | ||
|
||
// Updates record | ||
// the param of 'table' must be pointer, eg: &StructName | ||
func Updates(ctx context.Context, db *gorm.DB, table interface{}, update KV, queryCondition interface{}, args ...interface{}) error { | ||
return db.WithContext(ctx).Model(table).Where(queryCondition, args...).Updates(update).Error | ||
} | ||
|
||
// Get one record | ||
// the param of 'table' must be pointer, eg: &StructName | ||
func Get(ctx context.Context, db *gorm.DB, table interface{}, queryCondition interface{}, args ...interface{}) error { | ||
return db.WithContext(ctx).Where(queryCondition, args...).First(table).Error | ||
} | ||
|
||
// GetByID get record by id | ||
func GetByID(ctx context.Context, db *gorm.DB, table interface{}, id interface{}) error { | ||
return db.WithContext(ctx).Where("id = ?", id).First(table).Error | ||
} | ||
|
||
// List multiple records, starting from page 0 | ||
// the param of 'tables' must be a slice, eg: []StructName | ||
func List(ctx context.Context, db *gorm.DB, tables interface{}, page *query.Page, queryCondition interface{}, args ...interface{}) error { | ||
return db.WithContext(ctx).Order(page.Sort()).Limit(page.Size()).Offset(page.Offset()).Where(queryCondition, args...).Find(tables).Error | ||
} | ||
|
||
// Count number of records | ||
// the param of 'table' must be pointer, eg: &StructName | ||
func Count(ctx context.Context, db *gorm.DB, table interface{}, queryCondition interface{}, args ...interface{}) (int64, error) { | ||
var count int64 | ||
err := db.WithContext(ctx).Model(table).Where(queryCondition, args...).Count(&count).Error | ||
return count, err | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
package mysql | ||
|
||
import ( | ||
"fmt" | ||
"testing" | ||
"time" | ||
|
||
"github.com/zhufuyi/sponge/pkg/gotest" | ||
"github.com/zhufuyi/sponge/pkg/mysql/query" | ||
|
||
"github.com/DATA-DOG/go-sqlmock" | ||
"github.com/stretchr/testify/assert" | ||
"gorm.io/gorm" | ||
) | ||
|
||
var table = &userExample{} | ||
|
||
type userExample struct { | ||
Model `gorm:"embedded"` | ||
|
||
Name string `gorm:"type:varchar(40);unique_index;not null" json:"name"` | ||
Age int `gorm:"not null" json:"age"` | ||
Gender string `gorm:"type:varchar(10);not null" json:"gender"` | ||
} | ||
|
||
func newUserExampleDao() *gotest.Dao { | ||
testData := &userExample{Name: "ZhangSan", Age: 20, Gender: "male"} | ||
testData.ID = 1 | ||
testData.CreatedAt = time.Now() | ||
testData.UpdatedAt = testData.CreatedAt | ||
|
||
// init mock dao | ||
d := gotest.NewDao(nil, testData) | ||
|
||
return d | ||
} | ||
|
||
func TestTableName(t *testing.T) { | ||
t.Logf("table name = %s", TableName(&userExample{})) | ||
} | ||
|
||
func TestCreate(t *testing.T) { | ||
d := newUserExampleDao() | ||
defer d.Close() | ||
testData := d.TestData.(*userExample) | ||
|
||
d.SQLMock.ExpectBegin() | ||
d.SQLMock.ExpectExec("INSERT INTO .*"). | ||
WithArgs(d.GetAnyArgs(testData)...). | ||
WillReturnResult(sqlmock.NewResult(1, 1)) | ||
d.SQLMock.ExpectCommit() | ||
|
||
err := Create(d.Ctx, d.DB, testData) | ||
assert.NoError(t, err) | ||
} | ||
|
||
func TestDelete(t *testing.T) { | ||
d := newUserExampleDao() | ||
defer d.Close() | ||
testData := d.TestData.(*userExample) | ||
|
||
d.SQLMock.ExpectBegin() | ||
d.SQLMock.ExpectExec("UPDATE .*"). | ||
WithArgs(d.AnyTime, testData.Name). | ||
WillReturnResult(sqlmock.NewResult(int64(testData.ID), 1)) | ||
d.SQLMock.ExpectCommit() | ||
|
||
err := Delete(d.Ctx, d.DB, table, "name = ?", testData.Name) | ||
assert.NoError(t, err) | ||
} | ||
|
||
func TestDeleteByID(t *testing.T) { | ||
d := newUserExampleDao() | ||
defer d.Close() | ||
testData := d.TestData.(*userExample) | ||
|
||
d.SQLMock.ExpectBegin() | ||
d.SQLMock.ExpectExec("UPDATE .*"). | ||
WithArgs(d.AnyTime, testData.ID). | ||
WillReturnResult(sqlmock.NewResult(int64(testData.ID), 1)) | ||
d.SQLMock.ExpectCommit() | ||
|
||
err := Delete(d.Ctx, d.DB, table, "id = ?", testData.ID) | ||
assert.NoError(t, err) | ||
} | ||
|
||
func TestUpdate(t *testing.T) { | ||
d := newUserExampleDao() | ||
defer d.Close() | ||
testData := d.TestData.(*userExample) | ||
|
||
d.SQLMock.ExpectBegin() | ||
d.SQLMock.ExpectExec("UPDATE .*"). | ||
WithArgs(sqlmock.AnyArg(), d.AnyTime, testData.Name). | ||
WillReturnResult(sqlmock.NewResult(int64(testData.ID), 1)) | ||
d.SQLMock.ExpectCommit() | ||
|
||
err := Update(d.Ctx, d.DB, table, "age", gorm.Expr("age + ?", 1), "name = ?", testData.Name) | ||
assert.NoError(t, err) | ||
} | ||
|
||
func TestUpdates(t *testing.T) { | ||
d := newUserExampleDao() | ||
defer d.Close() | ||
testData := d.TestData.(*userExample) | ||
|
||
d.SQLMock.ExpectBegin() | ||
d.SQLMock.ExpectExec("UPDATE .*"). | ||
WithArgs(sqlmock.AnyArg(), d.AnyTime, testData.Gender). | ||
WillReturnResult(sqlmock.NewResult(int64(testData.ID), 1)) | ||
d.SQLMock.ExpectCommit() | ||
|
||
update := KV{"age": gorm.Expr("age + ?", 1)} | ||
err := Updates(d.Ctx, d.DB, table, update, "gender = ?", testData.Gender) | ||
assert.NoError(t, err) | ||
} | ||
|
||
func TestGetByID(t *testing.T) { | ||
d := newUserExampleDao() | ||
defer d.Close() | ||
testData := d.TestData.(*userExample) | ||
|
||
rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "name", "age", "gender"}). | ||
AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt, testData.Name, testData.Age, testData.Gender) | ||
|
||
d.SQLMock.ExpectQuery("SELECT .*").WithArgs(testData.ID).WillReturnRows(rows) | ||
|
||
err := GetByID(d.Ctx, d.DB, table, testData.ID) | ||
assert.NoError(t, err) | ||
|
||
t.Logf("%+v", table) | ||
} | ||
|
||
func TestGet(t *testing.T) { | ||
d := newUserExampleDao() | ||
defer d.Close() | ||
testData := d.TestData.(*userExample) | ||
|
||
rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "name", "age", "gender"}). | ||
AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt, testData.Name, testData.Age, testData.Gender) | ||
|
||
d.SQLMock.ExpectQuery("SELECT .*").WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnRows(rows) // adjusted for number of fields | ||
|
||
err := Get(d.Ctx, d.DB, table, "name = ?", testData.Name) | ||
assert.NoError(t, err) | ||
|
||
t.Logf("%+v", table) | ||
} | ||
|
||
func TestList(t *testing.T) { | ||
d := newUserExampleDao() | ||
defer d.Close() | ||
testData := d.TestData.(*userExample) | ||
|
||
rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "name", "age", "gender"}). | ||
AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt, testData.Name, testData.Age, testData.Gender) | ||
|
||
d.SQLMock.ExpectQuery("SELECT .*").WillReturnRows(rows) | ||
|
||
page := query.NewPage(0, 10, "") | ||
tables := []userExample{} | ||
err := List(d.Ctx, d.DB, &tables, page, "") | ||
assert.NoError(t, err) | ||
|
||
for _, user := range tables { | ||
t.Logf("%+v", user) | ||
} | ||
} | ||
|
||
func TestCount(t *testing.T) { | ||
d := newUserExampleDao() | ||
defer d.Close() | ||
testData := d.TestData.(*userExample) | ||
|
||
rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "name", "age", "gender"}). | ||
AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt, testData.Name, testData.Age, testData.Gender) | ||
|
||
d.SQLMock.ExpectQuery("SELECT .*"). | ||
WithArgs(sqlmock.AnyArg()). | ||
WillReturnRows(rows) | ||
|
||
count, err := Count(d.Ctx, d.DB, table, "id > ?", 0) | ||
assert.NotNil(t, err) | ||
|
||
t.Logf("count=%d", count) | ||
} | ||
|
||
func TestTx(t *testing.T) { | ||
err := createUser() | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
} | ||
|
||
func createUser() error { | ||
d := newUserExampleDao() | ||
defer d.Close() | ||
testData := d.TestData.(*userExample) | ||
rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "name", "age", "gender"}). | ||
AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt, testData.Name, testData.Age, testData.Gender) | ||
d.SQLMock.ExpectBegin() | ||
d.SQLMock.ExpectQuery("SELECT .*").WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnRows(rows) // adjusted for number of fields | ||
d.SQLMock.ExpectCommit() | ||
|
||
// note that you should use tx as the database handle when you are in a transaction | ||
tx := d.DB.Begin() | ||
defer func() { | ||
if err := recover(); err != nil { // rollback after a panic during transaction execution | ||
tx.Rollback() | ||
fmt.Printf("transaction failed, err = %v\n", err) | ||
} | ||
}() | ||
|
||
var err error | ||
if err = tx.Error; err != nil { | ||
return err | ||
} | ||
|
||
if err = tx.WithContext(d.Ctx).Where("id = ?", testData.ID).First(table).Error; err != nil { | ||
tx.Rollback() | ||
return err | ||
} | ||
|
||
panic("mock panic") | ||
|
||
if err = tx.WithContext(d.Ctx).Create(&userExample{Name: "lisi", Age: table.Age + 2, Gender: "male"}).Error; err != nil { | ||
tx.Rollback() | ||
return err | ||
} | ||
|
||
return tx.Commit().Error | ||
} |