Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
zhufuyi committed Feb 4, 2024
1 parent 19d9d35 commit 61e8328
Show file tree
Hide file tree
Showing 2 changed files with 301 additions and 0 deletions.
69 changes: 69 additions & 0 deletions pkg/mysql/crud.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package mysql

import (
"context"

"github.com/zhufuyi/sponge/pkg/mysql/query"

"gorm.io/gorm"
)

// TableName get table name
func TableName(table interface{}) string {
return GetTableName(table)
}

// Create a new record
// the param of 'table' must be pointer, eg: &StructName
func Create(ctx context.Context, db *gorm.DB, table interface{}) error {
return db.WithContext(ctx).Create(table).Error
}

// Delete record
// the param of 'table' must be pointer, eg: &StructName
func Delete(ctx context.Context, db *gorm.DB, table interface{}, queryCondition interface{}, args ...interface{}) error {
return db.WithContext(ctx).Where(queryCondition, args...).Delete(table).Error
}

// DeleteByID delete record by id
// the param of 'table' must be pointer, eg: &StructName
func DeleteByID(ctx context.Context, db *gorm.DB, table interface{}, id interface{}) error {
return db.WithContext(ctx).Where("id = ?", id).Delete(table).Error
}

// Update record
// the param of 'table' must be pointer, eg: &StructName
func Update(ctx context.Context, db *gorm.DB, table interface{}, column string, value interface{}, queryCondition interface{}, args ...interface{}) error {
return db.WithContext(ctx).Model(table).Where(queryCondition, args...).Update(column, value).Error
}

// Updates record
// the param of 'table' must be pointer, eg: &StructName
func Updates(ctx context.Context, db *gorm.DB, table interface{}, update KV, queryCondition interface{}, args ...interface{}) error {
return db.WithContext(ctx).Model(table).Where(queryCondition, args...).Updates(update).Error
}

// Get one record
// the param of 'table' must be pointer, eg: &StructName
func Get(ctx context.Context, db *gorm.DB, table interface{}, queryCondition interface{}, args ...interface{}) error {
return db.WithContext(ctx).Where(queryCondition, args...).First(table).Error
}

// GetByID get record by id
func GetByID(ctx context.Context, db *gorm.DB, table interface{}, id interface{}) error {
return db.WithContext(ctx).Where("id = ?", id).First(table).Error
}

// List multiple records, starting from page 0
// the param of 'tables' must be a slice, eg: []StructName
func List(ctx context.Context, db *gorm.DB, tables interface{}, page *query.Page, queryCondition interface{}, args ...interface{}) error {
return db.WithContext(ctx).Order(page.Sort()).Limit(page.Size()).Offset(page.Offset()).Where(queryCondition, args...).Find(tables).Error
}

// Count number of records
// the param of 'table' must be pointer, eg: &StructName
func Count(ctx context.Context, db *gorm.DB, table interface{}, queryCondition interface{}, args ...interface{}) (int64, error) {
var count int64
err := db.WithContext(ctx).Model(table).Where(queryCondition, args...).Count(&count).Error
return count, err
}
232 changes: 232 additions & 0 deletions pkg/mysql/crud_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
package mysql

import (
"fmt"
"testing"
"time"

"github.com/zhufuyi/sponge/pkg/gotest"
"github.com/zhufuyi/sponge/pkg/mysql/query"

"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)

var table = &userExample{}

type userExample struct {
Model `gorm:"embedded"`

Name string `gorm:"type:varchar(40);unique_index;not null" json:"name"`
Age int `gorm:"not null" json:"age"`
Gender string `gorm:"type:varchar(10);not null" json:"gender"`
}

func newUserExampleDao() *gotest.Dao {
testData := &userExample{Name: "ZhangSan", Age: 20, Gender: "male"}
testData.ID = 1
testData.CreatedAt = time.Now()
testData.UpdatedAt = testData.CreatedAt

// init mock dao
d := gotest.NewDao(nil, testData)

return d
}

func TestTableName(t *testing.T) {
t.Logf("table name = %s", TableName(&userExample{}))
}

func TestCreate(t *testing.T) {
d := newUserExampleDao()
defer d.Close()
testData := d.TestData.(*userExample)

d.SQLMock.ExpectBegin()
d.SQLMock.ExpectExec("INSERT INTO .*").
WithArgs(d.GetAnyArgs(testData)...).
WillReturnResult(sqlmock.NewResult(1, 1))
d.SQLMock.ExpectCommit()

err := Create(d.Ctx, d.DB, testData)
assert.NoError(t, err)
}

func TestDelete(t *testing.T) {
d := newUserExampleDao()
defer d.Close()
testData := d.TestData.(*userExample)

d.SQLMock.ExpectBegin()
d.SQLMock.ExpectExec("UPDATE .*").
WithArgs(d.AnyTime, testData.Name).
WillReturnResult(sqlmock.NewResult(int64(testData.ID), 1))
d.SQLMock.ExpectCommit()

err := Delete(d.Ctx, d.DB, table, "name = ?", testData.Name)
assert.NoError(t, err)
}

func TestDeleteByID(t *testing.T) {
d := newUserExampleDao()
defer d.Close()
testData := d.TestData.(*userExample)

d.SQLMock.ExpectBegin()
d.SQLMock.ExpectExec("UPDATE .*").
WithArgs(d.AnyTime, testData.ID).
WillReturnResult(sqlmock.NewResult(int64(testData.ID), 1))
d.SQLMock.ExpectCommit()

err := Delete(d.Ctx, d.DB, table, "id = ?", testData.ID)
assert.NoError(t, err)
}

func TestUpdate(t *testing.T) {
d := newUserExampleDao()
defer d.Close()
testData := d.TestData.(*userExample)

d.SQLMock.ExpectBegin()
d.SQLMock.ExpectExec("UPDATE .*").
WithArgs(sqlmock.AnyArg(), d.AnyTime, testData.Name).
WillReturnResult(sqlmock.NewResult(int64(testData.ID), 1))
d.SQLMock.ExpectCommit()

err := Update(d.Ctx, d.DB, table, "age", gorm.Expr("age + ?", 1), "name = ?", testData.Name)
assert.NoError(t, err)
}

func TestUpdates(t *testing.T) {
d := newUserExampleDao()
defer d.Close()
testData := d.TestData.(*userExample)

d.SQLMock.ExpectBegin()
d.SQLMock.ExpectExec("UPDATE .*").
WithArgs(sqlmock.AnyArg(), d.AnyTime, testData.Gender).
WillReturnResult(sqlmock.NewResult(int64(testData.ID), 1))
d.SQLMock.ExpectCommit()

update := KV{"age": gorm.Expr("age + ?", 1)}
err := Updates(d.Ctx, d.DB, table, update, "gender = ?", testData.Gender)
assert.NoError(t, err)
}

func TestGetByID(t *testing.T) {
d := newUserExampleDao()
defer d.Close()
testData := d.TestData.(*userExample)

rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "name", "age", "gender"}).
AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt, testData.Name, testData.Age, testData.Gender)

d.SQLMock.ExpectQuery("SELECT .*").WithArgs(testData.ID).WillReturnRows(rows)

err := GetByID(d.Ctx, d.DB, table, testData.ID)
assert.NoError(t, err)

t.Logf("%+v", table)
}

func TestGet(t *testing.T) {
d := newUserExampleDao()
defer d.Close()
testData := d.TestData.(*userExample)

rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "name", "age", "gender"}).
AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt, testData.Name, testData.Age, testData.Gender)

d.SQLMock.ExpectQuery("SELECT .*").WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnRows(rows) // adjusted for number of fields

err := Get(d.Ctx, d.DB, table, "name = ?", testData.Name)
assert.NoError(t, err)

t.Logf("%+v", table)
}

func TestList(t *testing.T) {
d := newUserExampleDao()
defer d.Close()
testData := d.TestData.(*userExample)

rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "name", "age", "gender"}).
AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt, testData.Name, testData.Age, testData.Gender)

d.SQLMock.ExpectQuery("SELECT .*").WillReturnRows(rows)

page := query.NewPage(0, 10, "")
tables := []userExample{}
err := List(d.Ctx, d.DB, &tables, page, "")
assert.NoError(t, err)

for _, user := range tables {
t.Logf("%+v", user)
}
}

func TestCount(t *testing.T) {
d := newUserExampleDao()
defer d.Close()
testData := d.TestData.(*userExample)

rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "name", "age", "gender"}).
AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt, testData.Name, testData.Age, testData.Gender)

d.SQLMock.ExpectQuery("SELECT .*").
WithArgs(sqlmock.AnyArg()).
WillReturnRows(rows)

count, err := Count(d.Ctx, d.DB, table, "id > ?", 0)
assert.NotNil(t, err)

t.Logf("count=%d", count)
}

func TestTx(t *testing.T) {
err := createUser()
if err != nil {
t.Fatal(err)
}
}

func createUser() error {
d := newUserExampleDao()
defer d.Close()
testData := d.TestData.(*userExample)
rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "name", "age", "gender"}).
AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt, testData.Name, testData.Age, testData.Gender)
d.SQLMock.ExpectBegin()
d.SQLMock.ExpectQuery("SELECT .*").WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnRows(rows) // adjusted for number of fields
d.SQLMock.ExpectCommit()

// note that you should use tx as the database handle when you are in a transaction
tx := d.DB.Begin()
defer func() {
if err := recover(); err != nil { // rollback after a panic during transaction execution
tx.Rollback()
fmt.Printf("transaction failed, err = %v\n", err)
}
}()

var err error
if err = tx.Error; err != nil {
return err
}

if err = tx.WithContext(d.Ctx).Where("id = ?", testData.ID).First(table).Error; err != nil {
tx.Rollback()
return err
}

panic("mock panic")

if err = tx.WithContext(d.Ctx).Create(&userExample{Name: "lisi", Age: table.Age + 2, Gender: "male"}).Error; err != nil {
tx.Rollback()
return err
}

return tx.Commit().Error
}

0 comments on commit 61e8328

Please sign in to comment.