diff --git a/example/bun/user_model.gen.go b/example/bun/user_model.gen.go index d79d380..9b7b212 100644 --- a/example/bun/user_model.gen.go +++ b/example/bun/user_model.gen.go @@ -4,9 +4,11 @@ package model import ( "context" + "database/sql" "fmt" "time" + "github.com/shopspring/decimal" "github.com/uptrace/bun" ) @@ -18,7 +20,7 @@ type UserModel struct { // User represents a user struct data. type User struct { bun.BaseModel `bun:"table:user"` - Id uint64 `bun:"id,pk,autoincrement;" json:"id"` + Id uint64 `bun:"id,pk,autoincrement" json:"id"` Name string `bun:"name" json:"name"` Password string `bun:"password" json:"password"` Mobile string `bun:"mobile" json:"mobile"` @@ -137,7 +139,7 @@ type FindOnePartWhereParameter struct { // FindAllCountResult is a find all count result. type FindAllCountResult struct { bun.BaseModel `bun:"table:user"` - CountID uint64 `bun:"countID" json:"countID"` + CountID sql.NullInt64 `bun:"countID" json:"countID"` } // FindAllCountWhereWhereParameter is a where parameter structure. @@ -148,25 +150,25 @@ type FindAllCountWhereWhereParameter struct { // FindAllCountWhereResult is a find all count where result. type FindAllCountWhereResult struct { bun.BaseModel `bun:"table:user"` - CountID uint64 `bun:"countID" json:"countID"` + CountID sql.NullInt64 `bun:"countID" json:"countID"` } // FindMaxIDResult is a find max id result. type FindMaxIDResult struct { bun.BaseModel `bun:"table:user"` - MaxID uint64 `bun:"maxID" json:"maxID"` + MaxID sql.NullInt64 `bun:"maxID" json:"maxID"` } // FindMinIDResult is a find min id result. type FindMinIDResult struct { bun.BaseModel `bun:"table:user"` - MinID uint64 `bun:"minID" json:"minID"` + MinID sql.NullInt64 `bun:"minID" json:"minID"` } // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { bun.BaseModel `bun:"table:user"` - AvgID uint64 `bun:"avgID" json:"avgID"` + AvgID decimal.NullDecimal `bun:"avgID" json:"avgID"` } // UpdateWhereParameter is a where parameter structure. @@ -222,11 +224,7 @@ func (m *UserModel) Create(ctx context.Context, data ...*User) error { return fmt.Errorf("data is empty") } - var list []User - for _, v := range data { - list = append(list, *v) - } - + list := data[:] _, err := m.db.NewInsert().Model(&list).Exec(ctx) return err } @@ -461,7 +459,8 @@ func (m *UserModel) FindAvgID(ctx context.Context) (*FindAvgIDResult, error) { // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ?; func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWhereParameter) error { var db = m.db.NewUpdate() - db.Model(map[string]interface{}{ + db.Table("user") + db.Model(&map[string]interface{}{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -480,7 +479,8 @@ func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWherePar // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where UpdateOrderByIdDescWhereParameter) error { var db = m.db.NewUpdate() - db.Model(map[string]interface{}{ + db.Table("user") + db.Model(&map[string]interface{}{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -496,10 +496,11 @@ func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where U } // UpdateOrderByIdDescLimitCount is generated from sql: -// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; +// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc limit ?; func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter) error { var db = m.db.NewUpdate() - db.Model(map[string]interface{}{ + db.Table("user") + db.Model(&map[string]interface{}{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, diff --git a/example/bun/user_model.go b/example/bun/user_model.go index 8528c27..b7ce929 100644 --- a/example/bun/user_model.go +++ b/example/bun/user_model.go @@ -3,6 +3,6 @@ package model import "context" // TODO(sqlgen): Add your own customize code here. -func (m *UserModel) Customize(ctx context.Context, args ...any) { +func (m *UserModel) Customize(ctx context.Context, args ...interface{}) { } diff --git a/example/example.sql b/example/example.sql index 03488ed..79c0d7a 100644 --- a/example/example.sql +++ b/example/example.sql @@ -10,7 +10,6 @@ CREATE TABLE `user` `create_at` timestamp NULL, `update_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, UNIQUE KEY `name_index` (`name`), - UNIQUE KEY `type_index` (`type`), UNIQUE KEY `mobile_index` (`mobile`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT 'user table' COLLATE=utf8mb4_general_ci; @@ -43,7 +42,7 @@ update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickn -- test case: update one with order by desc, limit count clause. -- fn: UpdateOrderByIdDescLimitCount -update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; +update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc limit ?; -- operation: read diff --git a/example/example_test/NOTES b/example/example_test/NOTES new file mode 100644 index 0000000..c214b87 --- /dev/null +++ b/example/example_test/NOTES @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 zeromicro + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/example/example_test/bun/mock.go b/example/example_test/bun/mock.go new file mode 100644 index 0000000..3687a30 --- /dev/null +++ b/example/example_test/bun/mock.go @@ -0,0 +1,23 @@ +package sql + +import ( + "time" + + model "github.com/anqiansong/sqlgen/example/bun" + uuid "github.com/satori/go.uuid" +) + +func mustMockUser() *model.User { + uid := uuid.NewV4().String() + now := time.Now() + return &model.User{ + Name: uid, + Password: "bar", + Mobile: uid, + Gender: "male", + Nickname: "test", + Type: 1, + CreateAt: now, + UpdateAt: now, + } +} diff --git a/example/example_test/bun/user_model.gen_test.go b/example/example_test/bun/user_model.gen_test.go new file mode 100644 index 0000000..a390b56 --- /dev/null +++ b/example/example_test/bun/user_model.gen_test.go @@ -0,0 +1,560 @@ +package sql + +import ( + "context" + "database/sql" + "fmt" + "log" + "sort" + "testing" + "time" + + model "github.com/anqiansong/sqlgen/example/bun" + _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/mysqldialect" + "github.com/uptrace/bun/extra/bundebug" +) + +var ( + um *model.UserModel + ctx = context.TODO() + db *bun.DB +) + +func TestMain(m *testing.M) { + conn, err := sql.Open("mysql", "root:mysqlpw@tcp(127.0.0.1:55000)/test?charset=utf8mb4&parseTime=true&loc=Local") + if err != nil { + log.Fatalln(err) + } + + db = bun.NewDB(conn, mysqldialect.New()) + err = db.Ping() + if err != nil { + fmt.Println("ping error") + return + } + + db.AddQueryHook(bundebug.NewQueryHook( + bundebug.WithVerbose(true), + )) + um = model.NewUserModel(db) + m.Run() +} + +func mustInitDB(db *bun.DB) { + err := db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.ExecContext(ctx, `SET SESSION sql_mode=(SELECT REPLACE(@@sql_mode,'ONLY_FULL_GROUP_BY,',''))`) + if err != nil { + return err + } + _, err = tx.ExecContext(ctx, `truncate table user`) + if err != nil { + log.Fatalln(err) + } + _, err = tx.ExecContext(ctx, `alter table user auto_increment=1`) + return err + }) + if err != nil { + log.Fatalln(err) + } +} + +func TestCreate(t *testing.T) { + t.Run("emptyData", initAndRun(func(t *testing.T) { + err := um.Create(ctx) + assert.Contains(t, err.Error(), "empty") + })) + + t.Run("createOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + assert.Equal(t, uint64(1), mockUser.Id) + })) + t.Run("createMultiple", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 1; i <= 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + for idx, item := range list { + assert.Equal(t, uint64(idx+1), item.Id) + } + })) +} + +func TestFindOne(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: 1}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("findOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOneGroupByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByNameHavingName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: "foo"}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, sql.ErrNoRows) + + mockUser := mustMockUser() + err = um.Create(ctx, mockUser) + assert.NoError(t, err) + _, err = um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, sql.ErrNoRows) + + })) + + t.Run("FindOneGroupByNameHavingName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: mockUser.Name, + }) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindAll(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindAll", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list, actual) + })) +} + +func TestFindLimit(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 1, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimit", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) + + t.Run("FindLimitOffset1", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindGroupHavingLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindGroupHavingOrderDescLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingOrderDescLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + sort.Slice(list, func(i, j int) bool { + return list[i].Id > list[j].Id + }) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindOnePart(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 1}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOnePart", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, mockUser.Name, actual.Name) + assert.Equal(t, mockUser.Password, actual.Password) + assert.Equal(t, mockUser.Mobile, actual.Mobile) + })) +} + +func TestFindAllCount(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindAllCountWhere(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCountWhere", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindMaxID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + maxID, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), maxID.MaxID.Int64) + })) + + t.Run("FindMaxID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(5), actual.MaxID.Int64) + })) +} + +func TestFindMinID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), minID.MinID.Int64) + })) + + t.Run("FindMinID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.MinID.Int64) + })) +} + +func TestFindAvgID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "0", minID.AvgID.Decimal.String()) + })) + + t.Run("FindAvgID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "3", actual.AvgID.Decimal.String()) + })) +} + +func TestUpdate(t *testing.T) { + t.Run("Update", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.Update(ctx, newUser, model.UpdateWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDesc(ctx, newUser, model.UpdateOrderByIdDescWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDescLimitCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDescLimitCount(ctx, newUser, model.UpdateOrderByIdDescLimitCountWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) +} + +func TestDelete(t *testing.T) { + t.Run("DeleteOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOne(ctx, model.DeleteOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneByName(ctx, model.DeleteOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDAsc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDAsc(ctx, model.DeleteOneOrderByIDAscWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDDesc(ctx, model.DeleteOneOrderByIDDescWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDDescLimitCount", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + + err = um.DeleteOneOrderByIDDescLimitCount(ctx, model.DeleteOneOrderByIDDescLimitCountWhereParameter{NameEqual: list[0].Name}) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list[1:], actual) + })) +} + +func assertUserEqual(t *testing.T, expected, actual *model.User) { + now := time.Now() + expected.CreateAt = now + expected.UpdateAt = now + actual.CreateAt = now + actual.UpdateAt = now + assert.Equal(t, *expected, *actual) +} + +func assertUsersEqual(t *testing.T, expected, actual []*model.User) { + assert.Equal(t, len(expected), len(actual)) + for idx, expectedItem := range expected { + actual := actual[idx] + assertUserEqual(t, expectedItem, actual) + } +} + +func initAndRun(f func(t *testing.T)) func(t *testing.T) { + mustInitDB(db) + return func(t *testing.T) { + f(t) + } +} diff --git a/example/example_test/gorm/mock.go b/example/example_test/gorm/mock.go new file mode 100644 index 0000000..8bd62b8 --- /dev/null +++ b/example/example_test/gorm/mock.go @@ -0,0 +1,23 @@ +package sql + +import ( + "time" + + model "github.com/anqiansong/sqlgen/example/gorm" + uuid "github.com/satori/go.uuid" +) + +func mustMockUser() *model.User { + uid := uuid.NewV4().String() + now := time.Now() + return &model.User{ + Name: uid, + Password: "bar", + Mobile: uid, + Gender: "male", + Nickname: "test", + Type: 1, + CreateAt: now, + UpdateAt: now, + } +} diff --git a/example/example_test/gorm/user_model.gen_test.go b/example/example_test/gorm/user_model.gen_test.go new file mode 100644 index 0000000..96aa3d1 --- /dev/null +++ b/example/example_test/gorm/user_model.gen_test.go @@ -0,0 +1,545 @@ +package sql + +import ( + "context" + "fmt" + "log" + "sort" + "testing" + "time" + + model "github.com/anqiansong/sqlgen/example/gorm" + "github.com/stretchr/testify/assert" + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +var ( + um *model.UserModel + ctx = context.TODO() + db *gorm.DB +) + +func TestMain(m *testing.M) { + var err error + conn := mysql.Open("root:mysqlpw@tcp(127.0.0.1:55000)/test?charset=utf8mb4&parseTime=true&loc=Local") + db, err = gorm.Open(conn, &gorm.Config{}) + if err != nil { + fmt.Println("gorm open error:", err) + return + } + + um = model.NewUserModel(db) + m.Run() +} + +func mustInitDB(db *gorm.DB) { + err := db.Transaction(func(tx *gorm.DB) error { + tx.Exec(`SET SESSION sql_mode=(SELECT REPLACE(@@sql_mode,'ONLY_FULL_GROUP_BY,',''))`) + tx.Exec(`truncate table user`) + tx.Exec(`alter table user auto_increment=1`) + return nil + }) + if err != nil { + log.Fatal(err) + } +} + +func TestCreate(t *testing.T) { + t.Run("emptyData", initAndRun(func(t *testing.T) { + err := um.Create(ctx) + assert.Contains(t, err.Error(), "empty") + })) + + t.Run("createOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + assert.Equal(t, uint64(1), mockUser.Id) + })) + + t.Run("createMultiple", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 1; i <= 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + for idx, item := range list { + assert.Equal(t, uint64(idx+1), item.Id) + } + })) +} + +func TestFindOne(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: 1}) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + })) + + t.Run("findOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + })) + + t.Run("FindOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + })) + + t.Run("FindOneGroupByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByNameHavingName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: "foo"}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + + mockUser := mustMockUser() + err = um.Create(ctx, mockUser) + assert.NoError(t, err) + _, err = um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + + })) + + t.Run("FindOneGroupByNameHavingName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: mockUser.Name, + }) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindAll(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindAll", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list, actual) + })) +} + +func TestFindLimit(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 1, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimit", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) + + t.Run("FindLimitOffset1", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindGroupHavingLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindGroupHavingOrderDescLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingOrderDescLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + sort.Slice(list, func(i, j int) bool { + return list[i].Id > list[j].Id + }) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindOnePart(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 1}) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + })) + + t.Run("FindOnePart", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, mockUser.Name, actual.Name) + assert.Equal(t, mockUser.Password, actual.Password) + assert.Equal(t, mockUser.Mobile, actual.Mobile) + })) +} + +func TestFindAllCount(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindAllCountWhere(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCountWhere", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindMaxID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + maxID, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), maxID.MaxID.Int64) + })) + + t.Run("FindMaxID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(5), actual.MaxID.Int64) + })) +} + +func TestFindMinID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), minID.MinID.Int64) + })) + + t.Run("FindMinID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.MinID.Int64) + })) +} + +func TestFindAvgID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "0", minID.AvgID.Decimal.String()) + })) + + t.Run("FindAvgID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "3", actual.AvgID.Decimal.String()) + })) +} + +func TestUpdate(t *testing.T) { + t.Run("Update", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.Update(ctx, newUser, model.UpdateWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDesc(ctx, newUser, model.UpdateOrderByIdDescWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDescLimitCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDescLimitCount(ctx, newUser, model.UpdateOrderByIdDescLimitCountWhereParameter{IdEqual: mockUser.Id}, model.UpdateOrderByIdDescLimitCountLimitParameter{Count: 1}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) +} + +func TestDelete(t *testing.T) { + t.Run("DeleteOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOne(ctx, model.DeleteOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + })) + + t.Run("DeleteOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneByName(ctx, model.DeleteOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + })) + + t.Run("DeleteOneOrderByIDAsc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDAsc(ctx, model.DeleteOneOrderByIDAscWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + })) + + t.Run("DeleteOneOrderByIDDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDDesc(ctx, model.DeleteOneOrderByIDDescWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + })) + + t.Run("DeleteOneOrderByIDDescLimitCount", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + + err = um.DeleteOneOrderByIDDescLimitCount(ctx, model.DeleteOneOrderByIDDescLimitCountWhereParameter{NameEqual: list[0].Name}, model.DeleteOneOrderByIDDescLimitCountLimitParameter{Count: 1}) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list[1:], actual) + })) +} + +func assertUserEqual(t *testing.T, expected, actual *model.User) { + now := time.Now() + expected.CreateAt = now + expected.UpdateAt = now + actual.CreateAt = now + actual.UpdateAt = now + assert.Equal(t, *expected, *actual) +} + +func assertUsersEqual(t *testing.T, expected, actual []*model.User) { + assert.Equal(t, len(expected), len(actual)) + for idx, expectedItem := range expected { + actual := actual[idx] + assertUserEqual(t, expectedItem, actual) + } +} + +func initAndRun(f func(t *testing.T)) func(t *testing.T) { + mustInitDB(db) + return func(t *testing.T) { + f(t) + } +} diff --git a/example/example_test/readme.md b/example/example_test/readme.md new file mode 100644 index 0000000..930af29 --- /dev/null +++ b/example/example_test/readme.md @@ -0,0 +1,27 @@ +# example_test + +## before test +1. started docker +2. run a mysql container which dsn is `root:mysqlpw@(localhost:55000)` +3. new a schema `test` +4. create a table use the following sql +```sql +CREATE TABLE `user` +( + `id` bigint(10) unsigned NOT NULL AUTO_INCREMENT primary key, + `name` varchar(255) COLLATE utf8mb4_general_ci NULL COMMENT 'The username', + `password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT 'The \n user password', + `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT 'The mobile phone number', + `gender` char(10) COLLATE utf8mb4_general_ci NOT NULL COMMENT 'gender,male|female|unknown', + `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT 'The nickname', + `type` tinyint(1) COLLATE utf8mb4_general_ci DEFAULT 0 COMMENT 'The user type, 0:normal,1:vip, for test golang keyword', + `create_at` timestamp NULL, + `update_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + UNIQUE KEY `name_index` (`name`), + UNIQUE KEY `mobile_index` (`mobile`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT 'user table' COLLATE=utf8mb4_general_ci; +``` +5. clean the test data and set auto_increment to `1` +6. run the test + + diff --git a/example/example_test/sql/mock.go b/example/example_test/sql/mock.go new file mode 100644 index 0000000..6e7a213 --- /dev/null +++ b/example/example_test/sql/mock.go @@ -0,0 +1,23 @@ +package sql + +import ( + "time" + + model "github.com/anqiansong/sqlgen/example/sql" + uuid "github.com/satori/go.uuid" +) + +func mustMockUser() *model.User { + uid := uuid.NewV4().String() + now := time.Now() + return &model.User{ + Name: uid, + Password: "bar", + Mobile: uid, + Gender: "male", + Nickname: "test", + Type: 1, + CreateAt: now, + UpdateAt: now, + } +} diff --git a/example/example_test/sql/scanner.go b/example/example_test/sql/scanner.go new file mode 100644 index 0000000..253653c --- /dev/null +++ b/example/example_test/sql/scanner.go @@ -0,0 +1,138 @@ +package sql + +import ( + "database/sql" + "errors" + "reflect" + + model "github.com/anqiansong/sqlgen/example/sql" + "github.com/iancoleman/strcase" +) + +type customScanner struct { +} + +func (c customScanner) ColumnMapper(colName string) string { + return strcase.ToCamel(colName) +} + +func (c customScanner) TagKey() string { + return `db` +} + +func (c customScanner) getRowElem(rows *sql.Rows, v interface{}) ([]interface{}, error) { + var elem reflect.Value + value, ok := v.(reflect.Value) + if !ok { + elem = reflect.ValueOf(v) + } else { + elem = value + } + + switch elem.Kind() { + case reflect.Pointer: + return c.getRowElem(rows, elem.Elem()) + case reflect.Struct: + var list []interface{} + cols, err := rows.Columns() + if err != nil { + return nil, err + } + + targetField := make(map[string]reflect.Value) + for i := 0; i < elem.NumField(); i++ { + f := elem.Field(i) + t := elem.Type().Field(i) + tag, ok := t.Tag.Lookup(c.TagKey()) + if ok { + targetField[tag] = f + } + } + + for _, name := range cols { + f, ok := targetField[name] + if !ok { + f = elem.FieldByName(c.ColumnMapper(name)) + } + if f.CanAddr() { + list = append(list, f.Addr().Interface()) + } + } + return list, nil + default: + return nil, errors.New("expect a struct") + } +} + +// getRowsElem is inspired by https://github.com/zeromicro/go-zero/blob/8ed22eafdda04c4526164450d7c13c2f4b0f076c/core/stores/sqlx/orm.go#L163 +func (c customScanner) getRowsElem(rows *sql.Rows, v interface{}) error { + valueOf := reflect.ValueOf(v) + if valueOf.Kind() != reflect.Ptr { + return errors.New("expect a pointer") + } + + typeOf := reflect.TypeOf(v) + sliceTypeOf := typeOf.Elem() + sliceValueOf := valueOf.Elem() + + if sliceTypeOf.Kind() != reflect.Slice { + return errors.New("expect a slice") + } + if !sliceValueOf.CanSet() { + return errors.New("expect a settable slice") + } + isASlicePointer := sliceTypeOf.Elem().Kind() == reflect.Ptr + + var itemReceiver reflect.Type + itemType := sliceTypeOf.Elem() + if itemType.Kind() == reflect.Ptr { + itemReceiver = itemType.Elem() + } else { + itemReceiver = itemType + } + if itemReceiver.Kind() != reflect.Struct { + return errors.New("expect a struct") + } + + for rows.Next() { + value := reflect.New(itemReceiver) + dest, err := c.getRowElem(rows, value) + if err != nil { + return err + } + + err = rows.Scan(dest...) + if err != nil { + return err + } + + if isASlicePointer { + sliceValueOf.Set(reflect.Append(sliceValueOf, value)) + } else { + sliceValueOf.Set(reflect.Append(sliceValueOf, reflect.Indirect(sliceValueOf))) + } + } + + return nil +} + +func (c customScanner) ScanRow(rows *sql.Rows, v interface{}) error { + if !rows.Next() { + return sql.ErrNoRows + } + + dest, err := c.getRowElem(rows, v) + if err != nil { + return err + } + + return rows.Scan(dest...) +} + +func (c customScanner) ScanRows(rows *sql.Rows, v interface{}) error { + return c.getRowsElem(rows, v) +} + +func getScanner() model.Scanner { + return customScanner{} +} diff --git a/example/example_test/sql/user_model.gen_test.go b/example/example_test/sql/user_model.gen_test.go new file mode 100644 index 0000000..d793a0c --- /dev/null +++ b/example/example_test/sql/user_model.gen_test.go @@ -0,0 +1,563 @@ +package sql + +import ( + "context" + "database/sql" + "fmt" + "log" + "sort" + "testing" + "time" + + model "github.com/anqiansong/sqlgen/example/sql" + _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" +) + +var ( + um *model.UserModel + ctx = context.TODO() + db *sql.DB +) + +func TestMain(m *testing.M) { + var err error + db, err = sql.Open("mysql", "root:mysqlpw@tcp(127.0.0.1:55000)/test?charset=utf8mb4&parseTime=true&loc=Local") + if err != nil { + log.Fatalln(err) + } + + err = db.Ping() + if err != nil { + fmt.Println("ping error") + return + } + + um = model.NewUserModel(db, getScanner()) + m.Run() +} + +func mustInitDB(db *sql.DB) { + tx, err := db.Begin() + if err != nil { + log.Fatalln(err) + } + + _, err = tx.ExecContext(ctx, `SET SESSION sql_mode=(SELECT REPLACE(@@sql_mode,'ONLY_FULL_GROUP_BY,',''))`) + if err != nil { + log.Fatalln(err) + } + + _, err = tx.ExecContext(ctx, `truncate table user`) + if err != nil { + log.Fatalln(err) + } + + _, err = tx.ExecContext(ctx, `alter table user auto_increment=1`) + if err != nil { + log.Fatalln(err) + } + + err = tx.Commit() + if err != nil { + log.Fatalln(err) + } +} + +func TestCreate(t *testing.T) { + t.Run("emptyData", initAndRun(func(t *testing.T) { + err := um.Create(ctx) + assert.Contains(t, err.Error(), "empty") + })) + + t.Run("createOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + assert.Equal(t, uint64(1), mockUser.Id) + })) + t.Run("createMultiple", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 1; i <= 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + for idx, item := range list { + assert.Equal(t, uint64(idx+1), item.Id) + } + })) +} + +func TestFindOne(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: 1}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("findOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOneGroupByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByNameHavingName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: "foo"}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, sql.ErrNoRows) + + mockUser := mustMockUser() + err = um.Create(ctx, mockUser) + assert.NoError(t, err) + _, err = um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, sql.ErrNoRows) + + })) + + t.Run("FindOneGroupByNameHavingName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: mockUser.Name, + }) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindAll(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindAll", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list, actual) + })) +} + +func TestFindLimit(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 1, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimit", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) + + t.Run("FindLimitOffset1", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindGroupHavingLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindGroupHavingOrderDescLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingOrderDescLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + sort.Slice(list, func(i, j int) bool { + return list[i].Id > list[j].Id + }) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindOnePart(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 1}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOnePart", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, mockUser.Name, actual.Name) + assert.Equal(t, mockUser.Password, actual.Password) + assert.Equal(t, mockUser.Mobile, actual.Mobile) + })) +} + +func TestFindAllCount(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindAllCountWhere(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCountWhere", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindMaxID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + maxID, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), maxID.MaxID.Int64) + })) + + t.Run("FindMaxID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(5), actual.MaxID.Int64) + })) +} + +func TestFindMinID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), minID.MinID.Int64) + })) + + t.Run("FindMinID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.MinID.Int64) + })) +} + +func TestFindAvgID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "0", minID.AvgID.Decimal.String()) + })) + + t.Run("FindAvgID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "3", actual.AvgID.Decimal.String()) + })) +} + +func TestUpdate(t *testing.T) { + t.Run("Update", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.Update(ctx, newUser, model.UpdateWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDesc(ctx, newUser, model.UpdateOrderByIdDescWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDescLimitCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDescLimitCount(ctx, newUser, model.UpdateOrderByIdDescLimitCountWhereParameter{IdEqual: mockUser.Id}, model.UpdateOrderByIdDescLimitCountLimitParameter{Count: 1}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) +} + +func TestDelete(t *testing.T) { + t.Run("DeleteOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOne(ctx, model.DeleteOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneByName(ctx, model.DeleteOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDAsc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDAsc(ctx, model.DeleteOneOrderByIDAscWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDDesc(ctx, model.DeleteOneOrderByIDDescWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDDescLimitCount", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + + err = um.DeleteOneOrderByIDDescLimitCount(ctx, model.DeleteOneOrderByIDDescLimitCountWhereParameter{NameEqual: list[0].Name}, model.DeleteOneOrderByIDDescLimitCountLimitParameter{Count: 1}) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list[1:], actual) + })) +} + +func assertUserEqual(t *testing.T, expected, actual *model.User) { + now := time.Now() + expected.CreateAt = now + expected.UpdateAt = now + actual.CreateAt = now + actual.UpdateAt = now + assert.Equal(t, *expected, *actual) +} + +func assertUsersEqual(t *testing.T, expected, actual []*model.User) { + assert.Equal(t, len(expected), len(actual)) + for idx, expectedItem := range expected { + actual := actual[idx] + assertUserEqual(t, expectedItem, actual) + } +} + +func initAndRun(f func(t *testing.T)) func(t *testing.T) { + mustInitDB(db) + return func(t *testing.T) { + f(t) + } +} diff --git a/example/example_test/sqlx/mock.go b/example/example_test/sqlx/mock.go new file mode 100644 index 0000000..a438320 --- /dev/null +++ b/example/example_test/sqlx/mock.go @@ -0,0 +1,23 @@ +package sql + +import ( + "time" + + model "github.com/anqiansong/sqlgen/example/sqlx" + uuid "github.com/satori/go.uuid" +) + +func mustMockUser() *model.User { + uid := uuid.NewV4().String() + now := time.Now() + return &model.User{ + Name: uid, + Password: "bar", + Mobile: uid, + Gender: "male", + Nickname: "test", + Type: 1, + CreateAt: now, + UpdateAt: now, + } +} diff --git a/example/example_test/sqlx/user_model.gen_test.go b/example/example_test/sqlx/user_model.gen_test.go new file mode 100644 index 0000000..d1a8f11 --- /dev/null +++ b/example/example_test/sqlx/user_model.gen_test.go @@ -0,0 +1,559 @@ +package sql + +import ( + "context" + "database/sql" + "fmt" + "log" + "sort" + "testing" + "time" + + model "github.com/anqiansong/sqlgen/example/sqlx" + _ "github.com/go-sql-driver/mysql" + "github.com/jmoiron/sqlx" + "github.com/stretchr/testify/assert" +) + +var ( + um *model.UserModel + ctx = context.TODO() + db *sqlx.DB +) + +func TestMain(m *testing.M) { + db = sqlx.MustOpen("mysql", "root:mysqlpw@tcp(127.0.0.1:55000)/test?charset=utf8mb4&parseTime=true&loc=Local") + err := db.Ping() + if err != nil { + fmt.Println("ping error") + return + } + + um = model.NewUserModel(db) + m.Run() +} + +func mustInitDB(db *sqlx.DB) { + tx, err := db.Begin() + if err != nil { + log.Fatalln(err) + } + + _, err = tx.ExecContext(ctx, `SET SESSION sql_mode=(SELECT REPLACE(@@sql_mode,'ONLY_FULL_GROUP_BY,',''))`) + if err != nil { + log.Fatalln(err) + } + + _, err = tx.ExecContext(ctx, `truncate table user`) + if err != nil { + log.Fatalln(err) + } + + _, err = tx.ExecContext(ctx, `alter table user auto_increment=1`) + if err != nil { + log.Fatalln(err) + } + + err = tx.Commit() + if err != nil { + log.Fatalln(err) + } +} + +func TestCreate(t *testing.T) { + t.Run("emptyData", initAndRun(func(t *testing.T) { + err := um.Create(ctx) + assert.Contains(t, err.Error(), "empty") + })) + + t.Run("createOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + assert.Equal(t, uint64(1), mockUser.Id) + })) + t.Run("createMultiple", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 1; i <= 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + for idx, item := range list { + assert.Equal(t, uint64(idx+1), item.Id) + } + })) +} + +func TestFindOne(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: 1}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("findOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOneGroupByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByNameHavingName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: "foo"}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, sql.ErrNoRows) + + mockUser := mustMockUser() + err = um.Create(ctx, mockUser) + assert.NoError(t, err) + _, err = um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, sql.ErrNoRows) + + })) + + t.Run("FindOneGroupByNameHavingName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: mockUser.Name, + }) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindAll(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindAll", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list, actual) + })) +} + +func TestFindLimit(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 1, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimit", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) + + t.Run("FindLimitOffset1", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindGroupHavingLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindGroupHavingOrderDescLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingOrderDescLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + sort.Slice(list, func(i, j int) bool { + return list[i].Id > list[j].Id + }) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindOnePart(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 1}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOnePart", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, mockUser.Name, actual.Name) + assert.Equal(t, mockUser.Password, actual.Password) + assert.Equal(t, mockUser.Mobile, actual.Mobile) + })) +} + +func TestFindAllCount(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindAllCountWhere(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCountWhere", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindMaxID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + maxID, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), maxID.MaxID.Int64) + })) + + t.Run("FindMaxID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(5), actual.MaxID.Int64) + })) +} + +func TestFindMinID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), minID.MinID.Int64) + })) + + t.Run("FindMinID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.MinID.Int64) + })) +} + +func TestFindAvgID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "0", minID.AvgID.Decimal.String()) + })) + + t.Run("FindAvgID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "3", actual.AvgID.Decimal.String()) + })) +} + +func TestUpdate(t *testing.T) { + t.Run("Update", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.Update(ctx, newUser, model.UpdateWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDesc(ctx, newUser, model.UpdateOrderByIdDescWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDescLimitCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDescLimitCount(ctx, newUser, model.UpdateOrderByIdDescLimitCountWhereParameter{IdEqual: mockUser.Id}, model.UpdateOrderByIdDescLimitCountLimitParameter{Count: 1}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) +} + +func TestDelete(t *testing.T) { + t.Run("DeleteOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOne(ctx, model.DeleteOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneByName(ctx, model.DeleteOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDAsc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDAsc(ctx, model.DeleteOneOrderByIDAscWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDDesc(ctx, model.DeleteOneOrderByIDDescWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDDescLimitCount", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + + err = um.DeleteOneOrderByIDDescLimitCount(ctx, model.DeleteOneOrderByIDDescLimitCountWhereParameter{NameEqual: list[0].Name}, model.DeleteOneOrderByIDDescLimitCountLimitParameter{Count: 1}) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list[1:], actual) + })) +} + +func assertUserEqual(t *testing.T, expected, actual *model.User) { + now := time.Now() + expected.CreateAt = now + expected.UpdateAt = now + actual.CreateAt = now + actual.UpdateAt = now + assert.Equal(t, *expected, *actual) +} + +func assertUsersEqual(t *testing.T, expected, actual []*model.User) { + assert.Equal(t, len(expected), len(actual)) + for idx, expectedItem := range expected { + actual := actual[idx] + assertUserEqual(t, expectedItem, actual) + } +} + +func initAndRun(f func(t *testing.T)) func(t *testing.T) { + mustInitDB(db) + return func(t *testing.T) { + f(t) + } +} diff --git a/example/example_test/xorm/mock.go b/example/example_test/xorm/mock.go new file mode 100644 index 0000000..8dee030 --- /dev/null +++ b/example/example_test/xorm/mock.go @@ -0,0 +1,23 @@ +package sql + +import ( + "time" + + model "github.com/anqiansong/sqlgen/example/xorm" + uuid "github.com/satori/go.uuid" +) + +func mustMockUser() *model.User { + uid := uuid.NewV4().String() + now := time.Now() + return &model.User{ + Name: uid, + Password: "bar", + Mobile: uid, + Gender: "male", + Nickname: "test", + Type: 1, + CreateAt: now, + UpdateAt: now, + } +} diff --git a/example/example_test/xorm/user_model.gen_test.go b/example/example_test/xorm/user_model.gen_test.go new file mode 100644 index 0000000..e0264e7 --- /dev/null +++ b/example/example_test/xorm/user_model.gen_test.go @@ -0,0 +1,569 @@ +package sql + +import ( + "context" + "database/sql" + "fmt" + "log" + "sort" + "testing" + "time" + + model "github.com/anqiansong/sqlgen/example/xorm" + _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" + "xorm.io/xorm" +) + +var ( + um *model.UserModel + ctx = context.TODO() + db xorm.EngineInterface +) + +func TestMain(m *testing.M) { + var err error + db, err = xorm.NewEngine("mysql", "root:mysqlpw@tcp(127.0.0.1:55000)/test?charset=utf8mb4&parseTime=true&loc=Local") + if err != nil { + log.Fatalln(err) + } + + err = db.Ping() + if err != nil { + fmt.Println("ping error") + return + } + + um = model.NewUserModel(db) + m.Run() +} + +func mustInitDB(t *testing.T, db xorm.EngineInterface) { + session := db.NewSession() + t.Cleanup(func() { + session.Close() + }) + + err := session.Begin() + if err != nil { + log.Fatalln(err) + } + + _, err = session.Exec(`SET SESSION sql_mode=(SELECT REPLACE(@@sql_mode,'ONLY_FULL_GROUP_BY,',''))`) + if err != nil { + log.Fatalln(err) + } + + _, err = session.Exec(`truncate table user`) + if err != nil { + log.Fatalln(err) + } + + _, err = session.Exec(`alter table user auto_increment=1`) + if err != nil { + log.Fatalln(err) + } + + err = session.Commit() + if err != nil { + log.Fatalln(err) + } +} + +func TestCreate(t *testing.T) { + t.Run("emptyData", initAndRun(func(t *testing.T) { + err := um.Create(ctx) + assert.Contains(t, err.Error(), "empty") + })) + + t.Run("createOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + assert.Equal(t, uint64(1), mockUser.Id) + })) + t.Run("createMultiple", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 1; i <= 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + for idx, item := range list { + assert.Equal(t, uint64(idx+1), item.Id) + } + })) +} + +func TestFindOne(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: 1}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("findOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOneGroupByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByNameHavingName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: "foo"}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, sql.ErrNoRows) + + mockUser := mustMockUser() + err = um.Create(ctx, mockUser) + assert.NoError(t, err) + _, err = um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, sql.ErrNoRows) + + })) + + t.Run("FindOneGroupByNameHavingName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: mockUser.Name, + }) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindAll(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindAll", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list, actual) + })) +} + +func TestFindLimit(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 1, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimit", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) + + t.Run("FindLimitOffset1", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindGroupHavingLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindGroupHavingOrderDescLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingOrderDescLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + sort.Slice(list, func(i, j int) bool { + return list[i].Id > list[j].Id + }) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindOnePart(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 1}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOnePart", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, mockUser.Name, actual.Name) + assert.Equal(t, mockUser.Password, actual.Password) + assert.Equal(t, mockUser.Mobile, actual.Mobile) + })) +} + +func TestFindAllCount(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindAllCountWhere(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCountWhere", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindMaxID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + maxID, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), maxID.MaxID.Int64) + })) + + t.Run("FindMaxID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(5), actual.MaxID.Int64) + })) +} + +func TestFindMinID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), minID.MinID.Int64) + })) + + t.Run("FindMinID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.MinID.Int64) + })) +} + +func TestFindAvgID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "0", minID.AvgID.Decimal.String()) + })) + + t.Run("FindAvgID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "3", actual.AvgID.Decimal.String()) + })) +} + +func TestUpdate(t *testing.T) { + t.Run("Update", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.Update(ctx, newUser, model.UpdateWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDesc(ctx, newUser, model.UpdateOrderByIdDescWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDescLimitCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDescLimitCount(ctx, newUser, model.UpdateOrderByIdDescLimitCountWhereParameter{IdEqual: mockUser.Id}, model.UpdateOrderByIdDescLimitCountLimitParameter{Count: 1}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) +} + +func TestDelete(t *testing.T) { + t.Run("DeleteOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOne(ctx, model.DeleteOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneByName(ctx, model.DeleteOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDAsc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDAsc(ctx, model.DeleteOneOrderByIDAscWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDDesc(ctx, model.DeleteOneOrderByIDDescWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDDescLimitCount", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + + err = um.DeleteOneOrderByIDDescLimitCount(ctx, model.DeleteOneOrderByIDDescLimitCountWhereParameter{NameEqual: list[0].Name}, model.DeleteOneOrderByIDDescLimitCountLimitParameter{Count: 1}) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list[1:], actual) + })) +} + +func assertUserEqual(t *testing.T, expected, actual *model.User) { + now := time.Now() + expected.CreateAt = now + expected.UpdateAt = now + actual.CreateAt = now + actual.UpdateAt = now + assert.Equal(t, *expected, *actual) +} + +func assertUsersEqual(t *testing.T, expected, actual []*model.User) { + assert.Equal(t, len(expected), len(actual)) + for idx, expectedItem := range expected { + actual := actual[idx] + assertUserEqual(t, expectedItem, actual) + } +} + +func initAndRun(f func(t *testing.T)) func(t *testing.T) { + return func(t *testing.T) { + mustInitDB(t, db) + f(t) + } +} diff --git a/example/go.mod b/example/go.mod index 9445122..937ae7c 100644 --- a/example/go.mod +++ b/example/go.mod @@ -3,24 +3,39 @@ module github.com/anqiansong/sqlgen/example go 1.18 require ( + github.com/go-sql-driver/mysql v1.6.0 + github.com/iancoleman/strcase v0.2.0 github.com/jmoiron/sqlx v1.3.5 + github.com/satori/go.uuid v1.2.0 + github.com/shopspring/decimal v1.2.0 + github.com/stretchr/testify v1.7.0 github.com/uptrace/bun v1.1.7 + github.com/uptrace/bun/dialect/mysqldialect v1.1.7 + github.com/uptrace/bun/extra/bundebug v1.1.7 + gorm.io/driver/mysql v1.3.6 gorm.io/gorm v1.23.8 xorm.io/builder v0.3.12 xorm.io/xorm v1.3.1 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fatih/color v1.13.0 // indirect github.com/goccy/go-json v0.8.1 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/jinzhu/inflection v1.0.0 // indirect - github.com/jinzhu/now v1.1.4 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/mattn/go-colorable v0.1.12 // indirect + github.com/mattn/go-isatty v0.0.14 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/syndtr/goleveldb v1.0.0 // indirect github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect - golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6 // indirect + golang.org/x/mod v0.5.1 // indirect + golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect ) diff --git a/example/go.sum b/example/go.sum index 8a1e25d..79fe62d 100644 --- a/example/go.sum +++ b/example/go.sum @@ -58,6 +58,8 @@ github.com/envoyproxy/go-control-plane v0.6.9/go.mod h1:SBwIajubJHhxtWwsL9s8ss4s github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= +github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= @@ -134,6 +136,8 @@ github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/J github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= +github.com/iancoleman/strcase v0.2.0 h1:05I4QRnGpI0m37iZQRuskXh+w77mr6Z41lwQzuHLwW0= +github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= @@ -185,8 +189,9 @@ github.com/jackc/puddle v1.1.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dv github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= @@ -205,9 +210,11 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -222,14 +229,18 @@ github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaO github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= -github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA= @@ -314,10 +325,12 @@ github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThC github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= +github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= @@ -349,6 +362,10 @@ github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYm github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= github.com/uptrace/bun v1.1.7 h1:biOoh5dov69hQPBlaRsXSHoEOIEnCxFzQvUmbscSNJI= github.com/uptrace/bun v1.1.7/go.mod h1:Z2Pd3cRvNKbrYuL6Gp1XGjA9QEYz+rDz5KkEi9MZLnQ= +github.com/uptrace/bun/dialect/mysqldialect v1.1.7 h1:eMDtsuu5BRuh0P2l0/j0Qv5UBmcqJE0u3F8Zy//klNM= +github.com/uptrace/bun/dialect/mysqldialect v1.1.7/go.mod h1:cCSZH3IULSGaG76Z96mAC7O74MeIYGfDX7CWGanGc0s= +github.com/uptrace/bun/extra/bundebug v1.1.7 h1:YbW7i9pUfPJMzclSzdHslIvAAR0WO9dW34ctL1Xh+UM= +github.com/uptrace/bun/extra/bundebug v1.1.7/go.mod h1:WoBnTrBG9CXITZUw+UfF+DYjWi71boo8FKZGuS5qDzA= github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= @@ -398,8 +415,9 @@ golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHl golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.5.1 h1:OJxoQ/rynoF0dcCdI7cLPktw/hR2cueqYfjm43oqK38= +golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -453,10 +471,12 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201126233918-771906719818/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210902050250-f475640dd07b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6 h1:nonptSpoQ4vQjyraW20DXPAglgQfVnM9ZC6MmNLMR60= -golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -512,6 +532,7 @@ google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= @@ -529,6 +550,8 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.3.6 h1:BhX1Y/RyALb+T9bZ3t07wLnPZBukt+IRkMn8UZSNbGM= +gorm.io/driver/mysql v1.3.6/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c= gorm.io/gorm v1.23.8 h1:h8sGJ+biDgBA1AD1Ha9gFCx7h8npU7AsLdlkX0n2TpE= gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/example/gorm/user_model.gen.go b/example/gorm/user_model.gen.go index 1063a06..fe88929 100644 --- a/example/gorm/user_model.gen.go +++ b/example/gorm/user_model.gen.go @@ -4,15 +4,18 @@ package model import ( "context" + "database/sql" "fmt" "time" "gorm.io/gorm" + + "github.com/shopspring/decimal" ) // UserModel represents a user model. type UserModel struct { - db gorm.DB + db *gorm.DB } // User represents a user struct data. @@ -33,16 +36,22 @@ type FindOneWhereParameter struct { IdEqual uint64 } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindOneByNameWhereParameter is a where parameter structure. type FindOneByNameWhereParameter struct { NameEqual string } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindOneGroupByNameWhereParameter is a where parameter structure. type FindOneGroupByNameWhereParameter struct { NameEqual string } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindOneGroupByNameHavingNameWhereParameter is a where parameter structure. type FindOneGroupByNameHavingNameWhereParameter struct { NameEqual string @@ -53,6 +62,10 @@ type FindOneGroupByNameHavingNameHavingParameter struct { NameEqual string } +// TableName returns the table name. it implemented by gorm.Tabler. + +// TableName returns the table name. it implemented by gorm.Tabler. + // FindLimitWhereParameter is a where parameter structure. type FindLimitWhereParameter struct { IdGT uint64 @@ -63,12 +76,16 @@ type FindLimitLimitParameter struct { Count int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindLimitOffsetLimitParameter is a limit parameter structure. type FindLimitOffsetLimitParameter struct { Count int Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindGroupLimitOffsetWhereParameter is a where parameter structure. type FindGroupLimitOffsetWhereParameter struct { IdGT uint64 @@ -80,6 +97,8 @@ type FindGroupLimitOffsetLimitParameter struct { Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindGroupHavingLimitOffsetWhereParameter is a where parameter structure. type FindGroupHavingLimitOffsetWhereParameter struct { IdGT uint64 @@ -96,6 +115,8 @@ type FindGroupHavingLimitOffsetLimitParameter struct { Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindGroupHavingOrderAscLimitOffsetWhereParameter is a where parameter structure. type FindGroupHavingOrderAscLimitOffsetWhereParameter struct { IdGT uint64 @@ -112,6 +133,8 @@ type FindGroupHavingOrderAscLimitOffsetLimitParameter struct { Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindGroupHavingOrderDescLimitOffsetWhereParameter is a where parameter structure. type FindGroupHavingOrderDescLimitOffsetWhereParameter struct { IdGT uint64 @@ -128,14 +151,23 @@ type FindGroupHavingOrderDescLimitOffsetLimitParameter struct { Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindOnePartWhereParameter is a where parameter structure. type FindOnePartWhereParameter struct { IdGT uint64 } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindAllCountResult is a find all count result. type FindAllCountResult struct { - CountID uint64 `gorm:"column:countID" json:"countID"` + CountID sql.NullInt64 `gorm:"column:countID" json:"countID"` +} + +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindAllCountResult) TableName() string { + return "user" } // FindAllCountWhereWhereParameter is a where parameter structure. @@ -145,22 +177,42 @@ type FindAllCountWhereWhereParameter struct { // FindAllCountWhereResult is a find all count where result. type FindAllCountWhereResult struct { - CountID uint64 `gorm:"column:countID" json:"countID"` + CountID sql.NullInt64 `gorm:"column:countID" json:"countID"` +} + +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindAllCountWhereResult) TableName() string { + return "user" } // FindMaxIDResult is a find max id result. type FindMaxIDResult struct { - MaxID uint64 `gorm:"column:maxID" json:"maxID"` + MaxID sql.NullInt64 `gorm:"column:maxID" json:"maxID"` +} + +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindMaxIDResult) TableName() string { + return "user" } // FindMinIDResult is a find min id result. type FindMinIDResult struct { - MinID uint64 `gorm:"column:minID" json:"minID"` + MinID sql.NullInt64 `gorm:"column:minID" json:"minID"` +} + +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindMinIDResult) TableName() string { + return "user" } // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { - AvgID uint64 `gorm:"column:avgID" json:"avgID"` + AvgID decimal.NullDecimal `gorm:"column:avgID" json:"avgID"` +} + +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindAvgIDResult) TableName() string { + return "user" } // UpdateWhereParameter is a where parameter structure. @@ -178,6 +230,11 @@ type UpdateOrderByIdDescLimitCountWhereParameter struct { IdEqual uint64 } +// UpdateOrderByIdDescLimitCountLimitParameter is a limit parameter structure. +type UpdateOrderByIdDescLimitCountLimitParameter struct { + Count int +} + // DeleteOneWhereParameter is a where parameter structure. type DeleteOneWhereParameter struct { IdEqual uint64 @@ -214,7 +271,7 @@ func (User) TableName() string { } // NewUserModel returns a new user model. -func NewUserModel(db gorm.DB) *UserModel { +func NewUserModel(db *gorm.DB) *UserModel { return &UserModel{db: db} } @@ -225,11 +282,7 @@ func (m *UserModel) Create(ctx context.Context, data ...*User) error { } db := m.db.WithContext(ctx) - var list []User - for _, v := range data { - list = append(list, *v) - } - + list := data[:] return db.Create(&list).Error } @@ -238,10 +291,10 @@ func (m *UserModel) Create(ctx context.Context, data ...*User) error { func (m *UserModel) FindOne(ctx context.Context, where FindOneWhereParameter) (*User, error) { var result = new(User) var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`id = ?`, where.IdEqual) - db.Limit(1) - db.Find(result) + db = db.Select(`*`) + db = db.Where(`id = ?`, where.IdEqual) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -250,10 +303,10 @@ func (m *UserModel) FindOne(ctx context.Context, where FindOneWhereParameter) (* func (m *UserModel) FindOneByName(ctx context.Context, where FindOneByNameWhereParameter) (*User, error) { var result = new(User) var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`name = ?`, where.NameEqual) - db.Limit(1) - db.Find(result) + db = db.Select(`*`) + db = db.Where(`name = ?`, where.NameEqual) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -262,11 +315,11 @@ func (m *UserModel) FindOneByName(ctx context.Context, where FindOneByNameWhereP func (m *UserModel) FindOneGroupByName(ctx context.Context, where FindOneGroupByNameWhereParameter) (*User, error) { var result = new(User) var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`name = ?`, where.NameEqual) - db.Group(`name`) - db.Limit(1) - db.Find(result) + db = db.Select(`*`) + db = db.Where(`name = ?`, where.NameEqual) + db = db.Group(`name`) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -275,12 +328,12 @@ func (m *UserModel) FindOneGroupByName(ctx context.Context, where FindOneGroupBy func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where FindOneGroupByNameHavingNameWhereParameter, having FindOneGroupByNameHavingNameHavingParameter) (*User, error) { var result = new(User) var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`name = ?`, where.NameEqual) - db.Group(`name`) - db.Having(`name = ?`, having.NameEqual) - db.Limit(1) - db.Find(result) + db = db.Select(`*`) + db = db.Where(`name = ?`, where.NameEqual) + db = db.Group(`name`) + db = db.Having(`name = ?`, having.NameEqual) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -289,8 +342,8 @@ func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where Find func (m *UserModel) FindAll(ctx context.Context) ([]*User, error) { var result []*User var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Find(&result) + db = db.Select(`*`) + db = db.Find(&result) return result, db.Error } @@ -299,10 +352,10 @@ func (m *UserModel) FindAll(ctx context.Context) ([]*User, error) { func (m *UserModel) FindLimit(ctx context.Context, where FindLimitWhereParameter, limit FindLimitLimitParameter) ([]*User, error) { var result []*User var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`id > ?`, where.IdGT) - db.Limit(limit.Count) - db.Find(&result) + db = db.Select(`*`) + db = db.Where(`id > ?`, where.IdGT) + db = db.Limit(limit.Count) + db = db.Find(&result) return result, db.Error } @@ -311,9 +364,9 @@ func (m *UserModel) FindLimit(ctx context.Context, where FindLimitWhereParameter func (m *UserModel) FindLimitOffset(ctx context.Context, limit FindLimitOffsetLimitParameter) ([]*User, error) { var result []*User var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Offset(limit.Offset).Limit(limit.Count) - db.Find(&result) + db = db.Select(`*`) + db = db.Offset(limit.Offset).Limit(limit.Count) + db = db.Find(&result) return result, db.Error } @@ -322,11 +375,11 @@ func (m *UserModel) FindLimitOffset(ctx context.Context, limit FindLimitOffsetLi func (m *UserModel) FindGroupLimitOffset(ctx context.Context, where FindGroupLimitOffsetWhereParameter, limit FindGroupLimitOffsetLimitParameter) ([]*User, error) { var result []*User var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`id > ?`, where.IdGT) - db.Group(`name`) - db.Offset(limit.Offset).Limit(limit.Count) - db.Find(&result) + db = db.Select(`*`) + db = db.Where(`id > ?`, where.IdGT) + db = db.Group(`name`) + db = db.Offset(limit.Offset).Limit(limit.Count) + db = db.Find(&result) return result, db.Error } @@ -335,12 +388,12 @@ func (m *UserModel) FindGroupLimitOffset(ctx context.Context, where FindGroupLim func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGroupHavingLimitOffsetWhereParameter, having FindGroupHavingLimitOffsetHavingParameter, limit FindGroupHavingLimitOffsetLimitParameter) ([]*User, error) { var result []*User var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`id > ?`, where.IdGT) - db.Group(`name`) - db.Having(`id > ?`, having.IdGT) - db.Offset(limit.Offset).Limit(limit.Count) - db.Find(&result) + db = db.Select(`*`) + db = db.Where(`id > ?`, where.IdGT) + db = db.Group(`name`) + db = db.Having(`id > ?`, having.IdGT) + db = db.Offset(limit.Offset).Limit(limit.Count) + db = db.Find(&result) return result, db.Error } @@ -349,13 +402,13 @@ func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGr func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, where FindGroupHavingOrderAscLimitOffsetWhereParameter, having FindGroupHavingOrderAscLimitOffsetHavingParameter, limit FindGroupHavingOrderAscLimitOffsetLimitParameter) ([]*User, error) { var result []*User var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`id > ?`, where.IdGT) - db.Group(`name`) - db.Having(`id > ?`, having.IdGT) - db.Order(`id`) - db.Offset(limit.Offset).Limit(limit.Count) - db.Find(&result) + db = db.Select(`*`) + db = db.Where(`id > ?`, where.IdGT) + db = db.Group(`name`) + db = db.Having(`id > ?`, having.IdGT) + db = db.Order(`id`) + db = db.Offset(limit.Offset).Limit(limit.Count) + db = db.Find(&result) return result, db.Error } @@ -364,13 +417,13 @@ func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, wher func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, where FindGroupHavingOrderDescLimitOffsetWhereParameter, having FindGroupHavingOrderDescLimitOffsetHavingParameter, limit FindGroupHavingOrderDescLimitOffsetLimitParameter) ([]*User, error) { var result []*User var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`id > ?`, where.IdGT) - db.Group(`name`) - db.Having(`id > ?`, having.IdGT) - db.Order(`id desc`) - db.Offset(limit.Offset).Limit(limit.Count) - db.Find(&result) + db = db.Select(`*`) + db = db.Where(`id > ?`, where.IdGT) + db = db.Group(`name`) + db = db.Having(`id > ?`, having.IdGT) + db = db.Order(`id desc`) + db = db.Offset(limit.Offset).Limit(limit.Count) + db = db.Find(&result) return result, db.Error } @@ -379,10 +432,10 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe func (m *UserModel) FindOnePart(ctx context.Context, where FindOnePartWhereParameter) (*User, error) { var result = new(User) var db = m.db.WithContext(ctx) - db.Select(`name, password, mobile`) - db.Where(`id > ?`, where.IdGT) - db.Limit(1) - db.Find(result) + db = db.Select(`name, password, mobile`) + db = db.Where(`id > ?`, where.IdGT) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -391,9 +444,9 @@ func (m *UserModel) FindOnePart(ctx context.Context, where FindOnePartWhereParam func (m *UserModel) FindAllCount(ctx context.Context) (*FindAllCountResult, error) { var result = new(FindAllCountResult) var db = m.db.WithContext(ctx) - db.Select(`count(id) AS countID`) - db.Limit(1) - db.Find(result) + db = db.Select(`count(id) AS countID`) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -402,10 +455,10 @@ func (m *UserModel) FindAllCount(ctx context.Context) (*FindAllCountResult, erro func (m *UserModel) FindAllCountWhere(ctx context.Context, where FindAllCountWhereWhereParameter) (*FindAllCountWhereResult, error) { var result = new(FindAllCountWhereResult) var db = m.db.WithContext(ctx) - db.Select(`count(id) AS countID`) - db.Where(`id > ?`, where.IdGT) - db.Limit(1) - db.Find(result) + db = db.Select(`count(id) AS countID`) + db = db.Where(`id > ?`, where.IdGT) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -414,9 +467,9 @@ func (m *UserModel) FindAllCountWhere(ctx context.Context, where FindAllCountWhe func (m *UserModel) FindMaxID(ctx context.Context) (*FindMaxIDResult, error) { var result = new(FindMaxIDResult) var db = m.db.WithContext(ctx) - db.Select(`max(id) AS maxID`) - db.Limit(1) - db.Find(result) + db = db.Select(`max(id) AS maxID`) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -425,9 +478,9 @@ func (m *UserModel) FindMaxID(ctx context.Context) (*FindMaxIDResult, error) { func (m *UserModel) FindMinID(ctx context.Context) (*FindMinIDResult, error) { var result = new(FindMinIDResult) var db = m.db.WithContext(ctx) - db.Select(`min(id) AS minID`) - db.Limit(1) - db.Find(result) + db = db.Select(`min(id) AS minID`) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -436,9 +489,9 @@ func (m *UserModel) FindMinID(ctx context.Context) (*FindMinIDResult, error) { func (m *UserModel) FindAvgID(ctx context.Context) (*FindAvgIDResult, error) { var result = new(FindAvgIDResult) var db = m.db.WithContext(ctx) - db.Select(`avg(id) AS avgID`) - db.Limit(1) - db.Find(result) + db = db.Select(`avg(id) AS avgID`) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -446,9 +499,9 @@ func (m *UserModel) FindAvgID(ctx context.Context) (*FindAvgIDResult, error) { // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ?; func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWhereParameter) error { var db = m.db.WithContext(ctx) - db.Model(&User{}) - db.Where(`id = ?`, where.IdEqual) - db.Updates(map[string]interface{}{ + db = db.Model(&User{}) + db = db.Where(`id = ?`, where.IdEqual) + db = db.Updates(map[string]interface{}{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -465,10 +518,10 @@ func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWherePar // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where UpdateOrderByIdDescWhereParameter) error { var db = m.db.WithContext(ctx) - db.Model(&User{}) - db.Where(`id = ?`, where.IdEqual) - db.Order(`id desc`) - db.Updates(map[string]interface{}{ + db = db.Model(&User{}) + db = db.Where(`id = ?`, where.IdEqual) + db = db.Order(`id desc`) + db = db.Updates(map[string]interface{}{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -482,13 +535,14 @@ func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where U } // UpdateOrderByIdDescLimitCount is generated from sql: -// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; -func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter) error { +// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc limit ?; +func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter, limit UpdateOrderByIdDescLimitCountLimitParameter) error { var db = m.db.WithContext(ctx) - db.Model(&User{}) - db.Where(`id = ?`, where.IdEqual) - db.Order(`id desc`) - db.Updates(map[string]interface{}{ + db = db.Model(&User{}) + db = db.Where(`id = ?`, where.IdEqual) + db = db.Order(`id desc`) + db = db.Limit(limit.Count) + db = db.Updates(map[string]interface{}{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -505,8 +559,8 @@ func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *Use // delete from `user` where `id` = ?; func (m *UserModel) DeleteOne(ctx context.Context, where DeleteOneWhereParameter) error { var db = m.db.WithContext(ctx) - db.Where(`id = ?`, where.IdEqual) - db.Delete(&User{}) + db = db.Where(`id = ?`, where.IdEqual) + db = db.Delete(&User{}) return db.Error } @@ -514,8 +568,8 @@ func (m *UserModel) DeleteOne(ctx context.Context, where DeleteOneWhereParameter // delete from `user` where `name` = ?; func (m *UserModel) DeleteOneByName(ctx context.Context, where DeleteOneByNameWhereParameter) error { var db = m.db.WithContext(ctx) - db.Where(`name = ?`, where.NameEqual) - db.Delete(&User{}) + db = db.Where(`name = ?`, where.NameEqual) + db = db.Delete(&User{}) return db.Error } @@ -523,9 +577,9 @@ func (m *UserModel) DeleteOneByName(ctx context.Context, where DeleteOneByNameWh // delete from `user` where `name` = ? order by id; func (m *UserModel) DeleteOneOrderByIDAsc(ctx context.Context, where DeleteOneOrderByIDAscWhereParameter) error { var db = m.db.WithContext(ctx) - db.Where(`name = ?`, where.NameEqual) - db.Order(`id`) - db.Delete(&User{}) + db = db.Where(`name = ?`, where.NameEqual) + db = db.Order(`id`) + db = db.Delete(&User{}) return db.Error } @@ -533,9 +587,9 @@ func (m *UserModel) DeleteOneOrderByIDAsc(ctx context.Context, where DeleteOneOr // delete from `user` where `name` = ? order by id desc; func (m *UserModel) DeleteOneOrderByIDDesc(ctx context.Context, where DeleteOneOrderByIDDescWhereParameter) error { var db = m.db.WithContext(ctx) - db.Where(`name = ?`, where.NameEqual) - db.Order(`id desc`) - db.Delete(&User{}) + db = db.Where(`name = ?`, where.NameEqual) + db = db.Order(`id desc`) + db = db.Delete(&User{}) return db.Error } @@ -543,9 +597,9 @@ func (m *UserModel) DeleteOneOrderByIDDesc(ctx context.Context, where DeleteOneO // delete from `user` where `name` = ? order by id desc limit ?; func (m *UserModel) DeleteOneOrderByIDDescLimitCount(ctx context.Context, where DeleteOneOrderByIDDescLimitCountWhereParameter, limit DeleteOneOrderByIDDescLimitCountLimitParameter) error { var db = m.db.WithContext(ctx) - db.Where(`name = ?`, where.NameEqual) - db.Order(`id desc`) - db.Limit(limit.Count) - db.Delete(&User{}) + db = db.Where(`name = ?`, where.NameEqual) + db = db.Order(`id desc`) + db = db.Limit(limit.Count) + db = db.Delete(&User{}) return db.Error } diff --git a/example/gorm/user_model.go b/example/gorm/user_model.go index 8528c27..b7ce929 100644 --- a/example/gorm/user_model.go +++ b/example/gorm/user_model.go @@ -3,6 +3,6 @@ package model import "context" // TODO(sqlgen): Add your own customize code here. -func (m *UserModel) Customize(ctx context.Context, args ...any) { +func (m *UserModel) Customize(ctx context.Context, args ...interface{}) { } diff --git a/example/sql/scanner.go b/example/sql/scanner.go index 7306a65..a56452e 100644 --- a/example/sql/scanner.go +++ b/example/sql/scanner.go @@ -3,6 +3,8 @@ package model import "database/sql" type Scanner interface { - ScanRow(row *sql.Row, v interface{}) error + ScanRow(rows *sql.Rows, v interface{}) error ScanRows(rows *sql.Rows, v interface{}) error + ColumnMapper(colName string) string + TagKey() string } diff --git a/example/sql/user_model.gen.go b/example/sql/user_model.gen.go index cdfd925..81cbcc8 100644 --- a/example/sql/user_model.gen.go +++ b/example/sql/user_model.gen.go @@ -8,12 +8,13 @@ import ( "fmt" "time" + "github.com/shopspring/decimal" "xorm.io/builder" ) // UserModel represents a user model. type UserModel struct { - db *sql.Conn + db *sql.DB scanner Scanner } @@ -137,7 +138,7 @@ type FindOnePartWhereParameter struct { // FindAllCountResult is a find all count result. type FindAllCountResult struct { - CountID uint64 `json:"countID"` + CountID sql.NullInt64 `json:"countID"` } // FindAllCountWhereWhereParameter is a where parameter structure. @@ -147,22 +148,22 @@ type FindAllCountWhereWhereParameter struct { // FindAllCountWhereResult is a find all count where result. type FindAllCountWhereResult struct { - CountID uint64 `json:"countID"` + CountID sql.NullInt64 `json:"countID"` } // FindMaxIDResult is a find max id result. type FindMaxIDResult struct { - MaxID uint64 `json:"maxID"` + MaxID sql.NullInt64 `json:"maxID"` } // FindMinIDResult is a find min id result. type FindMinIDResult struct { - MinID uint64 `json:"minID"` + MinID sql.NullInt64 `json:"minID"` } // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { - AvgID uint64 `json:"avgID"` + AvgID decimal.NullDecimal `json:"avgID"` } // UpdateWhereParameter is a where parameter structure. @@ -180,6 +181,11 @@ type UpdateOrderByIdDescLimitCountWhereParameter struct { IdEqual uint64 } +// UpdateOrderByIdDescLimitCountLimitParameter is a limit parameter structure. +type UpdateOrderByIdDescLimitCountLimitParameter struct { + Count int +} + // DeleteOneWhereParameter is a where parameter structure. type DeleteOneWhereParameter struct { IdEqual uint64 @@ -211,7 +217,7 @@ type DeleteOneOrderByIDDescLimitCountLimitParameter struct { } // NewUserModel creates a new user model. -func NewUserModel(db *sql.Conn, scanner Scanner) *UserModel { +func NewUserModel(db *sql.DB, scanner Scanner) *UserModel { return &UserModel{ db: db, scanner: scanner, @@ -219,19 +225,17 @@ func NewUserModel(db *sql.Conn, scanner Scanner) *UserModel { } // Create creates user data. -func (m *UserModel) Create(ctx context.Context, data ...*User) (err error) { +func (m *UserModel) Create(ctx context.Context, data ...*User) error { if len(data) == 0 { return fmt.Errorf("data is empty") } var stmt *sql.Stmt - stmt, err = m.db.PrepareContext(ctx, "INSERT INTO user (`name`, `password`, `mobile`, `gender`, `nickname`, `type`, `create_at`, `update_at`) VALUES (?, ?, ?, ?, ?, ?, ?, ?)") + stmt, err := m.db.PrepareContext(ctx, "INSERT INTO user (`name`, `password`, `mobile`, `gender`, `nickname`, `type`, `create_at`, `update_at`) VALUES (?, ?, ?, ?, ?, ?, ?, ?)") if err != nil { - return + return err } - defer func() { - err = stmt.Close() - }() + defer stmt.Close() for _, v := range data { result, err := stmt.ExecContext(ctx, v.Name, v.Password, v.Mobile, v.Gender, v.Nickname, v.Type, v.CreateAt, v.UpdateAt) if err != nil { @@ -245,262 +249,305 @@ func (m *UserModel) Create(ctx context.Context, data ...*User) (err error) { v.Id = uint64(id) } - return + return nil } // FindOne is generated from sql: // select * from `user` where `id` = ? limit 1; func (m *UserModel) FindOne(ctx context.Context, where FindOneWhereParameter) (result *User, err error) { result = new(User) - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id = ?`, where.IdEqual)) b.Limit(1) query, args, err := b.ToSQL() - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + if err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + if err = m.scanner.ScanRow(rows, &result); err != nil { + return nil, err + } + + return result, nil } // FindOneByName is generated from sql: // select * from `user` where `name` = ? limit 1; func (m *UserModel) FindOneByName(ctx context.Context, where FindOneByNameWhereParameter) (result *User, err error) { result = new(User) - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.Limit(1) query, args, err := b.ToSQL() - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + if err != nil { + return nil, err + } + + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + if err = m.scanner.ScanRow(rows, &result); err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + return result, nil } // FindOneGroupByName is generated from sql: // select * from `user` where `name` = ? group by name limit 1; func (m *UserModel) FindOneGroupByName(ctx context.Context, where FindOneGroupByNameWhereParameter) (result *User, err error) { result = new(User) - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.GroupBy(`name`) b.Limit(1) query, args, err := b.ToSQL() - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + if err != nil { + return nil, err + } + + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + defer rows.Close() + if err = m.scanner.ScanRow(rows, &result); err != nil { + return nil, err + } + + return result, nil } // FindOneGroupByNameHavingName is generated from sql: // select * from `user` where `name` = ? group by name having name = ? limit 1; func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where FindOneGroupByNameHavingNameWhereParameter, having FindOneGroupByNameHavingNameHavingParameter) (result *User, err error) { result = new(User) - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.GroupBy(`name`) - b.Having(fmt.Sprintf(`name = %v`, having.NameEqual)) + b.Having(fmt.Sprintf(`name = '%v'`, having.NameEqual)) b.Limit(1) query, args, err := b.ToSQL() - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + if err != nil { + return nil, err + } + + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + if err = m.scanner.ScanRow(rows, &result); err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + return result, nil } // FindAll is generated from sql: // select * from `user`; func (m *UserModel) FindAll(ctx context.Context) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") query, args, err := b.ToSQL() - var rows *sql.Rows - rows, err = m.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() - if err = m.scanner.ScanRows(rows, result); err != nil { + + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { return nil, err } + defer rows.Close() + + if err = m.scanner.ScanRows(rows, &result); err != nil { + return nil, err + } + return result, nil } // FindLimit is generated from sql: // select * from `user` where id > ? limit ?; func (m *UserModel) FindLimit(ctx context.Context, where FindLimitWhereParameter, limit FindLimitLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.Limit(limit.Count) query, args, err := b.ToSQL() - var rows *sql.Rows - rows, err = m.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() - if err = m.scanner.ScanRows(rows, result); err != nil { + + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { return nil, err } + defer rows.Close() + + if err = m.scanner.ScanRows(rows, &result); err != nil { + return nil, err + } + return result, nil } // FindLimitOffset is generated from sql: // select * from `user` limit ?, ?; func (m *UserModel) FindLimitOffset(ctx context.Context, limit FindLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() - var rows *sql.Rows - rows, err = m.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() - if err = m.scanner.ScanRows(rows, result); err != nil { + + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } + return result, nil } // FindGroupLimitOffset is generated from sql: // select * from `user` where id > ? group by name limit ?, ?; func (m *UserModel) FindGroupLimitOffset(ctx context.Context, where FindGroupLimitOffsetWhereParameter, limit FindGroupLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() - var rows *sql.Rows - rows, err = m.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() - if err = m.scanner.ScanRows(rows, result); err != nil { + + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { return nil, err } + defer rows.Close() + + if err = m.scanner.ScanRows(rows, &result); err != nil { + return nil, err + } + return result, nil } // FindGroupHavingLimitOffset is generated from sql: // select * from `user` where id > ? group by name having id > ? limit ?, ?; func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGroupHavingLimitOffsetWhereParameter, having FindGroupHavingLimitOffsetHavingParameter, limit FindGroupHavingLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) - b.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + b.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() - var rows *sql.Rows - rows, err = m.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() - if err = m.scanner.ScanRows(rows, result); err != nil { + + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { return nil, err } + defer rows.Close() + + if err = m.scanner.ScanRows(rows, &result); err != nil { + return nil, err + } + return result, nil } // FindGroupHavingOrderAscLimitOffset is generated from sql: // select * from `user` where id > ? group by name having id > ? order by id limit ?, ?; func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, where FindGroupHavingOrderAscLimitOffsetWhereParameter, having FindGroupHavingOrderAscLimitOffsetHavingParameter, limit FindGroupHavingOrderAscLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) - b.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + b.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) b.OrderBy(`id`) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() - var rows *sql.Rows - rows, err = m.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() - if err = m.scanner.ScanRows(rows, result); err != nil { + + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { return nil, err } + defer rows.Close() + + if err = m.scanner.ScanRows(rows, &result); err != nil { + return nil, err + } + return result, nil } // FindGroupHavingOrderDescLimitOffset is generated from sql: // select * from `user` where id > ? group by name having id > ? order by id desc limit ?, ?; func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, where FindGroupHavingOrderDescLimitOffsetWhereParameter, having FindGroupHavingOrderDescLimitOffsetHavingParameter, limit FindGroupHavingOrderDescLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) - b.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + b.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) b.OrderBy(`id desc`) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() - var rows *sql.Rows - rows, err = m.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() - if err = m.scanner.ScanRows(rows, result); err != nil { + + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } + return result, nil } @@ -508,110 +555,165 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe // select `name`, `password`, `mobile` from `user` where id > ? limit 1; func (m *UserModel) FindOnePart(ctx context.Context, where FindOnePartWhereParameter) (result *User, err error) { result = new(User) - b := builder.Select(`name, password, mobile`) + b := builder.MySQL() + b.Select(`name, password, mobile`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.Limit(1) query, args, err := b.ToSQL() - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + if err != nil { + return nil, err + } + + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + defer rows.Close() + if err = m.scanner.ScanRow(rows, &result); err != nil { + return nil, err + } + + return result, nil } // FindAllCount is generated from sql: // select count(id) AS countID from `user`; func (m *UserModel) FindAllCount(ctx context.Context) (result *FindAllCountResult, err error) { result = new(FindAllCountResult) - b := builder.Select(`count(id) AS countID`) + b := builder.MySQL() + b.Select(`count(id) AS countID`) b.From("`user`") b.Limit(1) query, args, err := b.ToSQL() - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + if err != nil { + return nil, err + } + + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + defer rows.Close() + if err = m.scanner.ScanRow(rows, &result); err != nil { + return nil, err + } + + return result, nil } // FindAllCountWhere is generated from sql: // select count(id) AS countID from `user` where id > ?; func (m *UserModel) FindAllCountWhere(ctx context.Context, where FindAllCountWhereWhereParameter) (result *FindAllCountWhereResult, err error) { result = new(FindAllCountWhereResult) - b := builder.Select(`count(id) AS countID`) + b := builder.MySQL() + b.Select(`count(id) AS countID`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.Limit(1) query, args, err := b.ToSQL() - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + if err != nil { + return nil, err + } + + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + defer rows.Close() + if err = m.scanner.ScanRow(rows, &result); err != nil { + return nil, err + } + + return result, nil } // FindMaxID is generated from sql: // select max(id) AS maxID from `user`; func (m *UserModel) FindMaxID(ctx context.Context) (result *FindMaxIDResult, err error) { result = new(FindMaxIDResult) - b := builder.Select(`max(id) AS maxID`) + b := builder.MySQL() + b.Select(`max(id) AS maxID`) b.From("`user`") b.Limit(1) query, args, err := b.ToSQL() - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + if err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + if err = m.scanner.ScanRow(rows, &result); err != nil { + return nil, err + } + + return result, nil } // FindMinID is generated from sql: // select min(id) AS minID from `user`; func (m *UserModel) FindMinID(ctx context.Context) (result *FindMinIDResult, err error) { result = new(FindMinIDResult) - b := builder.Select(`min(id) AS minID`) + b := builder.MySQL() + b.Select(`min(id) AS minID`) b.From("`user`") b.Limit(1) query, args, err := b.ToSQL() - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + if err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + if err = m.scanner.ScanRow(rows, &result); err != nil { + return nil, err + } + + return result, nil } // FindAvgID is generated from sql: // select avg(id) AS avgID from `user`; func (m *UserModel) FindAvgID(ctx context.Context) (result *FindAvgIDResult, err error) { result = new(FindAvgIDResult) - b := builder.Select(`avg(id) AS avgID`) + b := builder.MySQL() + b.Select(`avg(id) AS avgID`) b.From("`user`") b.Limit(1) query, args, err := b.ToSQL() - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + if err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + if err = m.scanner.ScanRow(rows, &result); err != nil { + return nil, err + } + + return result, nil } // Update is generated from sql: // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ?; func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWhereParameter) error { - b := builder.Update(builder.Eq{ + b := builder.MySQL() + b.Update(builder.Eq{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -634,7 +736,8 @@ func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWherePar // UpdateOrderByIdDesc is generated from sql: // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where UpdateOrderByIdDescWhereParameter) error { - b := builder.Update(builder.Eq{ + b := builder.MySQL() + b.Update(builder.Eq{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -656,9 +759,10 @@ func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where U } // UpdateOrderByIdDescLimitCount is generated from sql: -// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; -func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter) error { - b := builder.Update(builder.Eq{ +// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc limit ?; +func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter, limit UpdateOrderByIdDescLimitCountLimitParameter) error { + b := builder.MySQL() + b.Update(builder.Eq{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -671,6 +775,7 @@ func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *Use b.From("`user`") b.Where(builder.Expr(`id = ?`, where.IdEqual)) b.OrderBy(`id desc`) + b.Limit(limit.Count) query, args, err := b.ToSQL() if err != nil { return err @@ -682,7 +787,8 @@ func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *Use // DeleteOne is generated from sql: // delete from `user` where `id` = ?; func (m *UserModel) DeleteOne(ctx context.Context, where DeleteOneWhereParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`id = ?`, where.IdEqual)) query, args, err := b.ToSQL() @@ -696,7 +802,8 @@ func (m *UserModel) DeleteOne(ctx context.Context, where DeleteOneWhereParameter // DeleteOneByName is generated from sql: // delete from `user` where `name` = ?; func (m *UserModel) DeleteOneByName(ctx context.Context, where DeleteOneByNameWhereParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) query, args, err := b.ToSQL() @@ -710,7 +817,8 @@ func (m *UserModel) DeleteOneByName(ctx context.Context, where DeleteOneByNameWh // DeleteOneOrderByIDAsc is generated from sql: // delete from `user` where `name` = ? order by id; func (m *UserModel) DeleteOneOrderByIDAsc(ctx context.Context, where DeleteOneOrderByIDAscWhereParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.OrderBy(`id`) @@ -725,7 +833,8 @@ func (m *UserModel) DeleteOneOrderByIDAsc(ctx context.Context, where DeleteOneOr // DeleteOneOrderByIDDesc is generated from sql: // delete from `user` where `name` = ? order by id desc; func (m *UserModel) DeleteOneOrderByIDDesc(ctx context.Context, where DeleteOneOrderByIDDescWhereParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.OrderBy(`id desc`) @@ -740,7 +849,8 @@ func (m *UserModel) DeleteOneOrderByIDDesc(ctx context.Context, where DeleteOneO // DeleteOneOrderByIDDescLimitCount is generated from sql: // delete from `user` where `name` = ? order by id desc limit ?; func (m *UserModel) DeleteOneOrderByIDDescLimitCount(ctx context.Context, where DeleteOneOrderByIDDescLimitCountWhereParameter, limit DeleteOneOrderByIDDescLimitCountLimitParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.OrderBy(`id desc`) diff --git a/example/sql/user_model.go b/example/sql/user_model.go index 8528c27..b7ce929 100644 --- a/example/sql/user_model.go +++ b/example/sql/user_model.go @@ -3,6 +3,6 @@ package model import "context" // TODO(sqlgen): Add your own customize code here. -func (m *UserModel) Customize(ctx context.Context, args ...any) { +func (m *UserModel) Customize(ctx context.Context, args ...interface{}) { } diff --git a/example/sqlx/user_model.gen.go b/example/sqlx/user_model.gen.go index 76b2bcb..b3a5331 100644 --- a/example/sqlx/user_model.gen.go +++ b/example/sqlx/user_model.gen.go @@ -9,6 +9,7 @@ import ( "time" "github.com/jmoiron/sqlx" + "github.com/shopspring/decimal" "xorm.io/builder" ) @@ -137,7 +138,7 @@ type FindOnePartWhereParameter struct { // FindAllCountResult is a find all count result. type FindAllCountResult struct { - CountID uint64 `db:"countID" json:"countID"` + CountID sql.NullInt64 `db:"countID" json:"countID"` } // FindAllCountWhereWhereParameter is a where parameter structure. @@ -147,22 +148,22 @@ type FindAllCountWhereWhereParameter struct { // FindAllCountWhereResult is a find all count where result. type FindAllCountWhereResult struct { - CountID uint64 `db:"countID" json:"countID"` + CountID sql.NullInt64 `db:"countID" json:"countID"` } // FindMaxIDResult is a find max id result. type FindMaxIDResult struct { - MaxID uint64 `db:"maxID" json:"maxID"` + MaxID sql.NullInt64 `db:"maxID" json:"maxID"` } // FindMinIDResult is a find min id result. type FindMinIDResult struct { - MinID uint64 `db:"minID" json:"minID"` + MinID sql.NullInt64 `db:"minID" json:"minID"` } // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { - AvgID uint64 `db:"avgID" json:"avgID"` + AvgID decimal.NullDecimal `db:"avgID" json:"avgID"` } // UpdateWhereParameter is a where parameter structure. @@ -180,6 +181,11 @@ type UpdateOrderByIdDescLimitCountWhereParameter struct { IdEqual uint64 } +// UpdateOrderByIdDescLimitCountLimitParameter is a limit parameter structure. +type UpdateOrderByIdDescLimitCountLimitParameter struct { + Count int +} + // DeleteOneWhereParameter is a where parameter structure. type DeleteOneWhereParameter struct { IdEqual uint64 @@ -218,19 +224,17 @@ func NewUserModel(db *sqlx.DB) *UserModel { } // Create creates user data. -func (m *UserModel) Create(ctx context.Context, data ...*User) (err error) { +func (m *UserModel) Create(ctx context.Context, data ...*User) error { if len(data) == 0 { return fmt.Errorf("data is empty") } var stmt *sql.Stmt - stmt, err = m.db.PrepareContext(ctx, "INSERT INTO user (`name`, `password`, `mobile`, `gender`, `nickname`, `type`, `create_at`, `update_at`) VALUES (?, ?, ?, ?, ?, ?, ?, ?)") + stmt, err := m.db.PrepareContext(ctx, "INSERT INTO user (`name`, `password`, `mobile`, `gender`, `nickname`, `type`, `create_at`, `update_at`) VALUES (?, ?, ?, ?, ?, ?, ?, ?)") if err != nil { - return + return err } - defer func() { - err = stmt.Close() - }() + defer stmt.Close() for _, v := range data { result, err := stmt.ExecContext(ctx, v.Name, v.Password, v.Mobile, v.Gender, v.Nickname, v.Type, v.CreateAt, v.UpdateAt) if err != nil { @@ -244,165 +248,165 @@ func (m *UserModel) Create(ctx context.Context, data ...*User) (err error) { v.Id = uint64(id) } - return + return nil } // FindOne is generated from sql: // select * from `user` where `id` = ? limit 1; func (m *UserModel) FindOne(ctx context.Context, where FindOneWhereParameter) (result *User, err error) { result = new(User) - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id = ?`, where.IdEqual)) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() - for rows.Next() { - var v User - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindOneByName is generated from sql: // select * from `user` where `name` = ? limit 1; func (m *UserModel) FindOneByName(ctx context.Context, where FindOneByNameWhereParameter) (result *User, err error) { result = new(User) - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() - for rows.Next() { - var v User - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindOneGroupByName is generated from sql: // select * from `user` where `name` = ? group by name limit 1; func (m *UserModel) FindOneGroupByName(ctx context.Context, where FindOneGroupByNameWhereParameter) (result *User, err error) { result = new(User) - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.GroupBy(`name`) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() - for rows.Next() { - var v User - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindOneGroupByNameHavingName is generated from sql: // select * from `user` where `name` = ? group by name having name = ? limit 1; func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where FindOneGroupByNameHavingNameWhereParameter, having FindOneGroupByNameHavingNameHavingParameter) (result *User, err error) { result = new(User) - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.GroupBy(`name`) - b.Having(fmt.Sprintf(`name = %v`, having.NameEqual)) + b.Having(fmt.Sprintf(`name = '%v'`, having.NameEqual)) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() - for rows.Next() { - var v User - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindAll is generated from sql: // select * from `user`; func (m *UserModel) FindAll(ctx context.Context) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -414,27 +418,28 @@ func (m *UserModel) FindAll(ctx context.Context) (result []*User, err error) { } return result, nil + } // FindLimit is generated from sql: // select * from `user` where id > ? limit ?; func (m *UserModel) FindLimit(ctx context.Context, where FindLimitWhereParameter, limit FindLimitLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.Limit(limit.Count) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -446,26 +451,27 @@ func (m *UserModel) FindLimit(ctx context.Context, where FindLimitWhereParameter } return result, nil + } // FindLimitOffset is generated from sql: // select * from `user` limit ?, ?; func (m *UserModel) FindLimitOffset(ctx context.Context, limit FindLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -477,28 +483,29 @@ func (m *UserModel) FindLimitOffset(ctx context.Context, limit FindLimitOffsetLi } return result, nil + } // FindGroupLimitOffset is generated from sql: // select * from `user` where id > ? group by name limit ?, ?; func (m *UserModel) FindGroupLimitOffset(ctx context.Context, where FindGroupLimitOffsetWhereParameter, limit FindGroupLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -510,29 +517,30 @@ func (m *UserModel) FindGroupLimitOffset(ctx context.Context, where FindGroupLim } return result, nil + } // FindGroupHavingLimitOffset is generated from sql: // select * from `user` where id > ? group by name having id > ? limit ?, ?; func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGroupHavingLimitOffsetWhereParameter, having FindGroupHavingLimitOffsetHavingParameter, limit FindGroupHavingLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) - b.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + b.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -544,30 +552,31 @@ func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGr } return result, nil + } // FindGroupHavingOrderAscLimitOffset is generated from sql: // select * from `user` where id > ? group by name having id > ? order by id limit ?, ?; func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, where FindGroupHavingOrderAscLimitOffsetWhereParameter, having FindGroupHavingOrderAscLimitOffsetHavingParameter, limit FindGroupHavingOrderAscLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) - b.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + b.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) b.OrderBy(`id`) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -579,30 +588,31 @@ func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, wher } return result, nil + } // FindGroupHavingOrderDescLimitOffset is generated from sql: // select * from `user` where id > ? group by name having id > ? order by id desc limit ?, ?; func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, where FindGroupHavingOrderDescLimitOffsetWhereParameter, having FindGroupHavingOrderDescLimitOffsetHavingParameter, limit FindGroupHavingOrderDescLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) - b.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + b.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) b.OrderBy(`id desc`) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -614,212 +624,214 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe } return result, nil + } // FindOnePart is generated from sql: // select `name`, `password`, `mobile` from `user` where id > ? limit 1; func (m *UserModel) FindOnePart(ctx context.Context, where FindOnePartWhereParameter) (result *User, err error) { result = new(User) - b := builder.Select(`name, password, mobile`) + b := builder.MySQL() + b.Select(`name, password, mobile`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() - for rows.Next() { - var v User - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindAllCount is generated from sql: // select count(id) AS countID from `user`; func (m *UserModel) FindAllCount(ctx context.Context) (result *FindAllCountResult, err error) { result = new(FindAllCountResult) - b := builder.Select(`count(id) AS countID`) + b := builder.MySQL() + b.Select(`count(id) AS countID`) b.From("`user`") b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() - for rows.Next() { - var v FindAllCountResult - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindAllCountWhere is generated from sql: // select count(id) AS countID from `user` where id > ?; func (m *UserModel) FindAllCountWhere(ctx context.Context, where FindAllCountWhereWhereParameter) (result *FindAllCountWhereResult, err error) { result = new(FindAllCountWhereResult) - b := builder.Select(`count(id) AS countID`) + b := builder.MySQL() + b.Select(`count(id) AS countID`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() - for rows.Next() { - var v FindAllCountWhereResult - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindMaxID is generated from sql: // select max(id) AS maxID from `user`; func (m *UserModel) FindMaxID(ctx context.Context) (result *FindMaxIDResult, err error) { result = new(FindMaxIDResult) - b := builder.Select(`max(id) AS maxID`) + b := builder.MySQL() + b.Select(`max(id) AS maxID`) b.From("`user`") b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() - for rows.Next() { - var v FindMaxIDResult - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindMinID is generated from sql: // select min(id) AS minID from `user`; func (m *UserModel) FindMinID(ctx context.Context) (result *FindMinIDResult, err error) { result = new(FindMinIDResult) - b := builder.Select(`min(id) AS minID`) + b := builder.MySQL() + b.Select(`min(id) AS minID`) b.From("`user`") b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() - for rows.Next() { - var v FindMinIDResult - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindAvgID is generated from sql: // select avg(id) AS avgID from `user`; func (m *UserModel) FindAvgID(ctx context.Context) (result *FindAvgIDResult, err error) { result = new(FindAvgIDResult) - b := builder.Select(`avg(id) AS avgID`) + b := builder.MySQL() + b.Select(`avg(id) AS avgID`) b.From("`user`") b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() - for rows.Next() { - var v FindAvgIDResult - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // Update is generated from sql: // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ?; func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWhereParameter) error { - b := builder.Update(builder.Eq{ + b := builder.MySQL() + b.Update(builder.Eq{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -842,7 +854,8 @@ func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWherePar // UpdateOrderByIdDesc is generated from sql: // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where UpdateOrderByIdDescWhereParameter) error { - b := builder.Update(builder.Eq{ + b := builder.MySQL() + b.Update(builder.Eq{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -864,9 +877,10 @@ func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where U } // UpdateOrderByIdDescLimitCount is generated from sql: -// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; -func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter) error { - b := builder.Update(builder.Eq{ +// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc limit ?; +func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter, limit UpdateOrderByIdDescLimitCountLimitParameter) error { + b := builder.MySQL() + b.Update(builder.Eq{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -879,6 +893,7 @@ func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *Use b.From("`user`") b.Where(builder.Expr(`id = ?`, where.IdEqual)) b.OrderBy(`id desc`) + b.Limit(limit.Count) query, args, err := b.ToSQL() if err != nil { return err @@ -890,7 +905,8 @@ func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *Use // DeleteOne is generated from sql: // delete from `user` where `id` = ?; func (m *UserModel) DeleteOne(ctx context.Context, where DeleteOneWhereParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`id = ?`, where.IdEqual)) query, args, err := b.ToSQL() @@ -904,7 +920,8 @@ func (m *UserModel) DeleteOne(ctx context.Context, where DeleteOneWhereParameter // DeleteOneByName is generated from sql: // delete from `user` where `name` = ?; func (m *UserModel) DeleteOneByName(ctx context.Context, where DeleteOneByNameWhereParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) query, args, err := b.ToSQL() @@ -918,7 +935,8 @@ func (m *UserModel) DeleteOneByName(ctx context.Context, where DeleteOneByNameWh // DeleteOneOrderByIDAsc is generated from sql: // delete from `user` where `name` = ? order by id; func (m *UserModel) DeleteOneOrderByIDAsc(ctx context.Context, where DeleteOneOrderByIDAscWhereParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.OrderBy(`id`) @@ -933,7 +951,8 @@ func (m *UserModel) DeleteOneOrderByIDAsc(ctx context.Context, where DeleteOneOr // DeleteOneOrderByIDDesc is generated from sql: // delete from `user` where `name` = ? order by id desc; func (m *UserModel) DeleteOneOrderByIDDesc(ctx context.Context, where DeleteOneOrderByIDDescWhereParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.OrderBy(`id desc`) @@ -948,7 +967,8 @@ func (m *UserModel) DeleteOneOrderByIDDesc(ctx context.Context, where DeleteOneO // DeleteOneOrderByIDDescLimitCount is generated from sql: // delete from `user` where `name` = ? order by id desc limit ?; func (m *UserModel) DeleteOneOrderByIDDescLimitCount(ctx context.Context, where DeleteOneOrderByIDDescLimitCountWhereParameter, limit DeleteOneOrderByIDDescLimitCountLimitParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.OrderBy(`id desc`) diff --git a/example/xorm/user_model.gen.go b/example/xorm/user_model.gen.go index 257a312..63d1692 100644 --- a/example/xorm/user_model.gen.go +++ b/example/xorm/user_model.gen.go @@ -4,10 +4,13 @@ package model import ( "context" + "database/sql" "fmt" "time" "xorm.io/xorm" + + "github.com/shopspring/decimal" ) // UserModel represents a user model. @@ -33,16 +36,22 @@ type FindOneWhereParameter struct { IdEqual uint64 } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindOneByNameWhereParameter is a where parameter structure. type FindOneByNameWhereParameter struct { NameEqual string } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindOneGroupByNameWhereParameter is a where parameter structure. type FindOneGroupByNameWhereParameter struct { NameEqual string } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindOneGroupByNameHavingNameWhereParameter is a where parameter structure. type FindOneGroupByNameHavingNameWhereParameter struct { NameEqual string @@ -53,6 +62,10 @@ type FindOneGroupByNameHavingNameHavingParameter struct { NameEqual string } +// TableName returns the table name. it implemented by gorm.Tabler. + +// TableName returns the table name. it implemented by gorm.Tabler. + // FindLimitWhereParameter is a where parameter structure. type FindLimitWhereParameter struct { IdGT uint64 @@ -63,12 +76,16 @@ type FindLimitLimitParameter struct { Count int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindLimitOffsetLimitParameter is a limit parameter structure. type FindLimitOffsetLimitParameter struct { Count int Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindGroupLimitOffsetWhereParameter is a where parameter structure. type FindGroupLimitOffsetWhereParameter struct { IdGT uint64 @@ -80,6 +97,8 @@ type FindGroupLimitOffsetLimitParameter struct { Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindGroupHavingLimitOffsetWhereParameter is a where parameter structure. type FindGroupHavingLimitOffsetWhereParameter struct { IdGT uint64 @@ -96,6 +115,8 @@ type FindGroupHavingLimitOffsetLimitParameter struct { Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindGroupHavingOrderAscLimitOffsetWhereParameter is a where parameter structure. type FindGroupHavingOrderAscLimitOffsetWhereParameter struct { IdGT uint64 @@ -112,6 +133,8 @@ type FindGroupHavingOrderAscLimitOffsetLimitParameter struct { Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindGroupHavingOrderDescLimitOffsetWhereParameter is a where parameter structure. type FindGroupHavingOrderDescLimitOffsetWhereParameter struct { IdGT uint64 @@ -128,14 +151,23 @@ type FindGroupHavingOrderDescLimitOffsetLimitParameter struct { Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindOnePartWhereParameter is a where parameter structure. type FindOnePartWhereParameter struct { IdGT uint64 } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindAllCountResult is a find all count result. type FindAllCountResult struct { - CountID uint64 `xorm:"'countID'" json:"countID"` + CountID sql.NullInt64 `xorm:"'countID'" json:"countID"` +} + +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindAllCountResult) TableName() string { + return "user" } // FindAllCountWhereWhereParameter is a where parameter structure. @@ -145,22 +177,42 @@ type FindAllCountWhereWhereParameter struct { // FindAllCountWhereResult is a find all count where result. type FindAllCountWhereResult struct { - CountID uint64 `xorm:"'countID'" json:"countID"` + CountID sql.NullInt64 `xorm:"'countID'" json:"countID"` +} + +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindAllCountWhereResult) TableName() string { + return "user" } // FindMaxIDResult is a find max id result. type FindMaxIDResult struct { - MaxID uint64 `xorm:"'maxID'" json:"maxID"` + MaxID sql.NullInt64 `xorm:"'maxID'" json:"maxID"` +} + +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindMaxIDResult) TableName() string { + return "user" } // FindMinIDResult is a find min id result. type FindMinIDResult struct { - MinID uint64 `xorm:"'minID'" json:"minID"` + MinID sql.NullInt64 `xorm:"'minID'" json:"minID"` +} + +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindMinIDResult) TableName() string { + return "user" } // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { - AvgID uint64 `xorm:"'avgID'" json:"avgID"` + AvgID decimal.NullDecimal `xorm:"'avgID'" json:"avgID"` +} + +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindAvgIDResult) TableName() string { + return "user" } // UpdateWhereParameter is a where parameter structure. @@ -178,6 +230,11 @@ type UpdateOrderByIdDescLimitCountWhereParameter struct { IdEqual uint64 } +// UpdateOrderByIdDescLimitCountLimitParameter is a limit parameter structure. +type UpdateOrderByIdDescLimitCountLimitParameter struct { + Count int +} + // DeleteOneWhereParameter is a where parameter structure. type DeleteOneWhereParameter struct { IdEqual uint64 @@ -217,19 +274,19 @@ func NewUserModel(engine xorm.EngineInterface) *UserModel { return &UserModel{engine: engine} } -// Insert creates user data. -func (m *UserModel) Insert(ctx context.Context, data ...*User) error { +// Create creates user data. +func (m *UserModel) Create(ctx context.Context, data ...*User) error { if len(data) == 0 { return fmt.Errorf("data is empty") } var session = m.engine.Context(ctx) - var list []User + var list []interface{} for _, v := range data { - list = append(list, *v) + list = append(list, v) } - _, err := session.Insert(&list) + _, err := session.Insert(list...) return err } @@ -241,7 +298,11 @@ func (m *UserModel) FindOne(ctx context.Context, where FindOneWhereParameter) (* session.Select(`*`) session.Where(`id = ?`, where.IdEqual) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -253,7 +314,11 @@ func (m *UserModel) FindOneByName(ctx context.Context, where FindOneByNameWhereP session.Select(`*`) session.Where(`name = ?`, where.NameEqual) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -266,7 +331,11 @@ func (m *UserModel) FindOneGroupByName(ctx context.Context, where FindOneGroupBy session.Where(`name = ?`, where.NameEqual) session.GroupBy(`name`) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -278,9 +347,13 @@ func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where Find session.Select(`*`) session.Where(`name = ?`, where.NameEqual) session.GroupBy(`name`) - session.Having(fmt.Sprintf(`name = %v`, having.NameEqual)) + session.Having(fmt.Sprintf(`name = '%v'`, having.NameEqual)) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -338,7 +411,7 @@ func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGr session.Select(`*`) session.Where(`id > ?`, where.IdGT) session.GroupBy(`name`) - session.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + session.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) session.Limit(limit.Count, limit.Offset) err := session.Find(&result) return result, err @@ -352,7 +425,7 @@ func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, wher session.Select(`*`) session.Where(`id > ?`, where.IdGT) session.GroupBy(`name`) - session.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + session.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) session.OrderBy(`id`) session.Limit(limit.Count, limit.Offset) err := session.Find(&result) @@ -367,7 +440,7 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe session.Select(`*`) session.Where(`id > ?`, where.IdGT) session.GroupBy(`name`) - session.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + session.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) session.OrderBy(`id desc`) session.Limit(limit.Count, limit.Offset) err := session.Find(&result) @@ -382,7 +455,11 @@ func (m *UserModel) FindOnePart(ctx context.Context, where FindOnePartWhereParam session.Select(`name, password, mobile`) session.Where(`id > ?`, where.IdGT) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -393,7 +470,11 @@ func (m *UserModel) FindAllCount(ctx context.Context) (*FindAllCountResult, erro var session = m.engine.Context(ctx) session.Select(`count(id) AS countID`) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -405,7 +486,11 @@ func (m *UserModel) FindAllCountWhere(ctx context.Context, where FindAllCountWhe session.Select(`count(id) AS countID`) session.Where(`id > ?`, where.IdGT) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -416,7 +501,11 @@ func (m *UserModel) FindMaxID(ctx context.Context) (*FindMaxIDResult, error) { var session = m.engine.Context(ctx) session.Select(`max(id) AS maxID`) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -427,7 +516,11 @@ func (m *UserModel) FindMinID(ctx context.Context) (*FindMinIDResult, error) { var session = m.engine.Context(ctx) session.Select(`min(id) AS minID`) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -438,7 +531,11 @@ func (m *UserModel) FindAvgID(ctx context.Context) (*FindAvgIDResult, error) { var session = m.engine.Context(ctx) session.Select(`avg(id) AS avgID`) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -446,6 +543,7 @@ func (m *UserModel) FindAvgID(ctx context.Context) (*FindAvgIDResult, error) { // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ?; func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWhereParameter) error { var session = m.engine.Context(ctx) + session.Table(&User{}) session.Where(`id = ?`, where.IdEqual) _, err := session.Update(map[string]interface{}{ "name": data.Name, @@ -464,6 +562,7 @@ func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWherePar // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where UpdateOrderByIdDescWhereParameter) error { var session = m.engine.Context(ctx) + session.Table(&User{}) session.Where(`id = ?`, where.IdEqual) session.OrderBy(`id desc`) _, err := session.Update(map[string]interface{}{ @@ -480,11 +579,13 @@ func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where U } // UpdateOrderByIdDescLimitCount is generated from sql: -// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; -func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter) error { +// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc limit ?; +func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter, limit UpdateOrderByIdDescLimitCountLimitParameter) error { var session = m.engine.Context(ctx) + session.Table(&User{}) session.Where(`id = ?`, where.IdEqual) session.OrderBy(`id desc`) + session.Limit(limit.Count) _, err := session.Update(map[string]interface{}{ "name": data.Name, "password": data.Password, diff --git a/example/xorm/user_model.go b/example/xorm/user_model.go index 8528c27..b7ce929 100644 --- a/example/xorm/user_model.go +++ b/example/xorm/user_model.go @@ -3,6 +3,6 @@ package model import "context" // TODO(sqlgen): Add your own customize code here. -func (m *UserModel) Customize(ctx context.Context, args ...any) { +func (m *UserModel) Customize(ctx context.Context, args ...interface{}) { } diff --git a/internal/gen/bun/bun_custom.tpl b/internal/gen/bun/bun_custom.tpl index 978b982..27abb75 100644 --- a/internal/gen/bun/bun_custom.tpl +++ b/internal/gen/bun/bun_custom.tpl @@ -3,6 +3,6 @@ package model import "context" // TODO(sqlgen): Add your own customize code here. -func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...any) { +func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...interface{}) { } \ No newline at end of file diff --git a/internal/gen/bun/bun_gen.tpl b/internal/gen/bun/bun_gen.tpl index 0f7634d..b8aa810 100644 --- a/internal/gen/bun/bun_gen.tpl +++ b/internal/gen/bun/bun_gen.tpl @@ -20,7 +20,7 @@ type {{UpperCamel $.Table.Name}}Model struct { // {{UpperCamel $.Table.Name}} represents a {{$.Table.Name}} struct data. type {{UpperCamel $.Table.Name}} struct { bun.BaseModel `bun:"table:{{$.Table.Name}}"`{{range $.Table.Columns}} - {{UpperCamel .Name}} {{.GoType}} `bun:"{{.Name}}{{if IsPrimary .Name}},pk{{end}}{{if .AutoIncrement}},autoincrement;{{end}}" json:"{{LowerCamel .Name}}"`{{end}} + {{UpperCamel .Name}} {{.GoType}} `bun:"{{.Name}}{{if IsPrimary .Name}},pk{{end}}{{if .AutoIncrement}},autoincrement{{end}}" json:"{{LowerCamel .Name}}"`{{end}} } {{range $stmt := .SelectStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} {{end}}{{if $stmt.Having.IsValid}}{{$stmt.Having.ParameterStructure "Having"}} @@ -50,11 +50,7 @@ func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...* return fmt.Errorf("data is empty") } - var list []{{UpperCamel $.Table.Name}} - for _,v:=range data{ - list = append(list,*v) - } - + list := data[:] _,err := m.db.NewInsert().Model(&list).Exec(ctx) return err } @@ -81,7 +77,8 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, data *{{UpperCamel $.Table.Name}}{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}) error { var db = m.db.NewUpdate() - db.Model(map[string]interface{}{ + db.Table("{{$.Table.Name}}") + db.Model(&map[string]interface{}{ {{range $name := $stmt.Columns}}"{{$name}}": data.{{UpperCamel $name}}, {{end}} }) diff --git a/internal/gen/gorm/gorm.go b/internal/gen/gorm/gorm.go index e618ec4..568325e 100644 --- a/internal/gen/gorm/gorm.go +++ b/internal/gen/gorm/gorm.go @@ -26,6 +26,9 @@ func Run(list []spec.Context, output string) error { "IsPrimary": func(name string) bool { return ctx.Table.IsPrimary(name) }, + "IsExtraResult": func(name string) bool { + return name != templatex.UpperCamel(ctx.Table.Name) + }, }) gen.MustParse(gormGenTpl) gen.MustExecute(ctx) diff --git a/internal/gen/gorm/gorm_custom.tpl b/internal/gen/gorm/gorm_custom.tpl index 978b982..27abb75 100644 --- a/internal/gen/gorm/gorm_custom.tpl +++ b/internal/gen/gorm/gorm_custom.tpl @@ -3,6 +3,6 @@ package model import "context" // TODO(sqlgen): Add your own customize code here. -func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...any) { +func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...interface{}) { } \ No newline at end of file diff --git a/internal/gen/gorm/gorm_gen.tpl b/internal/gen/gorm/gorm_gen.tpl index 8e84c97..a49bb4e 100644 --- a/internal/gen/gorm/gorm_gen.tpl +++ b/internal/gen/gorm/gorm_gen.tpl @@ -14,7 +14,7 @@ import ( // {{UpperCamel $.Table.Name}}Model represents a {{$.Table.Name}} model. type {{UpperCamel $.Table.Name}}Model struct { - db gorm.DB + db *gorm.DB } // {{UpperCamel $.Table.Name}} represents a {{$.Table.Name}} struct data. @@ -25,6 +25,10 @@ type {{UpperCamel $.Table.Name}} struct { {{range $.Table.Columns}} {{end}}{{if $stmt.Having.IsValid}}{{$stmt.Having.ParameterStructure "Having"}} {{end}}{{if $stmt.Limit.Multiple}}{{$stmt.Limit.ParameterStructure}} {{end}}{{$stmt.ReceiverStructure "gorm"}} +// TableName returns the table name. it implemented by gorm.Tabler. +{{if IsExtraResult $stmt.ReceiverName}}func ({{$stmt.ReceiverName}}) TableName() string { + return "{{$.Table.Name}}" +}{{end}} {{end}} {{range $stmt := .UpdateStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} @@ -43,7 +47,7 @@ func ({{UpperCamel $.Table.Name}}) TableName() string { } // New{{UpperCamel $.Table.Name}}Model returns a new {{$.Table.Name}} model. -func New{{UpperCamel $.Table.Name}}Model (db gorm.DB) *{{UpperCamel $.Table.Name}}Model { +func New{{UpperCamel $.Table.Name}}Model (db *gorm.DB) *{{UpperCamel $.Table.Name}}Model { return &{{UpperCamel $.Table.Name}}Model{db: db} } @@ -54,11 +58,7 @@ func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...* } db:=m.db.WithContext(ctx) - var list []{{UpperCamel $.Table.Name}} - for _,v:=range data{ - list = append(list,*v) - } - + list := data[:] return db.Create(&list).Error } {{range $stmt := .SelectStmt}} @@ -67,13 +67,13 @@ func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...* func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Having.IsValid}}, having {{$stmt.Having.ParameterStructureName "Having"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}})({{if $stmt.Limit.One}}*{{$stmt.ReceiverName}}, {{else}}[]*{{$stmt.ReceiverName}}, {{end}} error){ var result {{if $stmt.Limit.One}} = new({{$stmt.ReceiverName}}){{else}}[]*{{$stmt.ReceiverName}}{{end}} var db = m.db.WithContext(ctx) - db.Select(`{{$stmt.SelectSQL}}`) - {{if $stmt.Where.IsValid}}db.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) - {{end }}{{if $stmt.GroupBy.IsValid}}db.Group({{$stmt.GroupBy.SQL}}) - {{end}}{{if $stmt.Having.IsValid}}db.Having({{$stmt.Having.SQL}}, {{$stmt.Having.Parameters "having"}}) - {{end}}{{if $stmt.OrderBy.IsValid}}db.Order({{$stmt.OrderBy.SQL}}) - {{end}}{{if $stmt.Limit.IsValid}}db{{if gt $stmt.Limit.Offset 0}}.Offset({{$stmt.Limit.OffsetParameter "limit"}}){{end}}.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}) - {{end}}db.Find({{if $stmt.Limit.One}}result{{else}}&result{{end}}) + db=db.Select(`{{$stmt.SelectSQL}}`) + {{if $stmt.Where.IsValid}}db=db.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) + {{end }}{{if $stmt.GroupBy.IsValid}}db=db.Group({{$stmt.GroupBy.SQL}}) + {{end}}{{if $stmt.Having.IsValid}}db=db.Having({{$stmt.Having.SQL}}, {{$stmt.Having.Parameters "having"}}) + {{end}}{{if $stmt.OrderBy.IsValid}}db=db.Order({{$stmt.OrderBy.SQL}}) + {{end}}{{if $stmt.Limit.IsValid}}db=db{{if gt $stmt.Limit.Offset 0}}.Offset({{$stmt.Limit.OffsetParameter "limit"}}){{end}}.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}) + {{end}}db={{if $stmt.Limit.One}}db.Take(result){{else}}db.Find(&result){{end}} return result, db.Error } {{end}} @@ -83,11 +83,11 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, data *{{UpperCamel $.Table.Name}}{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { var db = m.db.WithContext(ctx) - db.Model(&{{UpperCamel $.Table.Name}}{}) - {{if $stmt.Where.IsValid}}db.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) - {{end}}{{if $stmt.OrderBy.IsValid}}db.Order({{$stmt.OrderBy.SQL}}) - {{end}}{{if $stmt.Limit.IsValid}}db{{if gt $stmt.Limit.Offset 0}}.Offset({{$stmt.Limit.OffsetParameter "limit"}}){{end}}.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}) - {{end}}db.Updates(map[string]interface{}{ + db=db.Model(&{{UpperCamel $.Table.Name}}{}) + {{if $stmt.Where.IsValid}}db=db.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) + {{end}}{{if $stmt.OrderBy.IsValid}}db=db.Order({{$stmt.OrderBy.SQL}}) + {{end}}{{if $stmt.Limit.IsValid}}db=db{{if gt $stmt.Limit.Offset 0}}.Offset({{$stmt.Limit.OffsetParameter "limit"}}){{end}}.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}) + {{end}}db=db.Updates(map[string]interface{}{ {{range $name := $stmt.Columns}}"{{$name}}": data.{{UpperCamel $name}}, {{end}} }) @@ -100,10 +100,10 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, dat // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { var db = m.db.WithContext(ctx) - {{if $stmt.Where.IsValid}}db.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) - {{end}}{{if $stmt.OrderBy.IsValid}}db.Order({{$stmt.OrderBy.SQL}}) - {{end}}{{if $stmt.Limit.IsValid}}db{{if gt $stmt.Limit.Offset 0}}.Offset({{$stmt.Limit.OffsetParameter "limit"}}){{end}}.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}) - {{end}}db.Delete(&{{UpperCamel $.Table.Name}}{}) + {{if $stmt.Where.IsValid}}db=db.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) + {{end}}{{if $stmt.OrderBy.IsValid}}db=db.Order({{$stmt.OrderBy.SQL}}) + {{end}}{{if $stmt.Limit.IsValid}}db=db{{if gt $stmt.Limit.Offset 0}}.Offset({{$stmt.Limit.OffsetParameter "limit"}}){{end}}.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}) + {{end}}db=db.Delete(&{{UpperCamel $.Table.Name}}{}) return db.Error } {{end}} \ No newline at end of file diff --git a/internal/gen/sql/scanner.tpl b/internal/gen/sql/scanner.tpl index 858a4c3..a1ec44d 100644 --- a/internal/gen/sql/scanner.tpl +++ b/internal/gen/sql/scanner.tpl @@ -3,6 +3,8 @@ package model import "database/sql" type Scanner interface { - ScanRow(row *sql.Row, v interface{}) error + ScanRow(rows *sql.Rows, v interface{}) error ScanRows(rows *sql.Rows, v interface{}) error -} \ No newline at end of file + ColumnMapper(colName string) string + TagKey() string +} diff --git a/internal/gen/sql/sql.go b/internal/gen/sql/sql.go index 8816ac8..6bc70e3 100644 --- a/internal/gen/sql/sql.go +++ b/internal/gen/sql/sql.go @@ -62,7 +62,7 @@ func Run(list []spec.Context, output string) error { return strings.Join(values, ", ") }, "HavingSprintf": func(format string) string { - format = strings.ReplaceAll(format, "?", "%v") + format = strings.ReplaceAll(format, "?", "'%v'") return format }, }) diff --git a/internal/gen/sql/sql_custom.tpl b/internal/gen/sql/sql_custom.tpl index 978b982..27abb75 100644 --- a/internal/gen/sql/sql_custom.tpl +++ b/internal/gen/sql/sql_custom.tpl @@ -3,6 +3,6 @@ package model import "context" // TODO(sqlgen): Add your own customize code here. -func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...any) { +func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...interface{}) { } \ No newline at end of file diff --git a/internal/gen/sql/sql_gen.tpl b/internal/gen/sql/sql_gen.tpl index cd293f4..72867f9 100644 --- a/internal/gen/sql/sql_gen.tpl +++ b/internal/gen/sql/sql_gen.tpl @@ -14,7 +14,7 @@ import ( // {{UpperCamel $.Table.Name}}Model represents a {{$.Table.Name}} model. type {{UpperCamel $.Table.Name}}Model struct { - db *sql.Conn + db *sql.DB scanner Scanner } @@ -40,7 +40,7 @@ type {{UpperCamel $.Table.Name}} struct { {{range $.Table.Columns}} // New{{UpperCamel $.Table.Name}}Model creates a new {{$.Table.Name}} model. -func New{{UpperCamel $.Table.Name}}Model(db *sql.Conn, scanner Scanner) *{{UpperCamel $.Table.Name}}Model { +func New{{UpperCamel $.Table.Name}}Model(db *sql.DB, scanner Scanner) *{{UpperCamel $.Table.Name}}Model { return &{{UpperCamel $.Table.Name}}Model{ db: db, scanner: scanner, @@ -48,19 +48,17 @@ func New{{UpperCamel $.Table.Name}}Model(db *sql.Conn, scanner Scanner) *{{Upper } // Create creates {{$.Table.Name}} data. -func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...*{{UpperCamel $.Table.Name}}) (err error) { +func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...*{{UpperCamel $.Table.Name}}) error { if len(data) == 0 { return fmt.Errorf("data is empty") } var stmt *sql.Stmt - stmt, err = m.db.PrepareContext(ctx, "INSERT INTO {{$.Table.Name}} ({{InsertSQL}}) VALUES ({{InsertQuotes}})") + stmt, err := m.db.PrepareContext(ctx, "INSERT INTO {{$.Table.Name}} ({{InsertSQL}}) VALUES ({{InsertQuotes}})") if err != nil { - return + return err } - defer func() { - err = stmt.Close() - }() + defer stmt.Close() for _, v := range data { result, err := stmt.ExecContext(ctx, {{InsertValues "v"}}) if err != nil { @@ -74,14 +72,15 @@ func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...* {{range $.Table.Columns}}{{if IsPrimary .Name}}{{if .AutoIncrement}}v.{{UpperCamel .Name}} = {{.GoType}}(id){{end}}{{end}}{{end}} } - return + return nil } {{range $stmt := .SelectStmt}} // {{.FuncName}} is generated from sql: // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Having.IsValid}}, having {{$stmt.Having.ParameterStructureName "Having"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}})(result {{if $stmt.Limit.One}}*{{$stmt.ReceiverName}}, {{else}}[]*{{$stmt.ReceiverName}}, {{end}} err error){ {{if $stmt.Limit.One}} result = new({{$stmt.ReceiverName}}){{end}} - b := builder.Select(`{{$stmt.SelectSQL}}`) + b := builder.MySQL() + b.Select(`{{$stmt.SelectSQL}}`) b.From("`{{$.Table.Name}}`") {{if $stmt.Where.IsValid}}b.Where(builder.Expr({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}})) {{end }}{{if $stmt.GroupBy.IsValid}}b.GroupBy({{$stmt.GroupBy.SQL}}) @@ -89,27 +88,21 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if {{end}}{{if $stmt.OrderBy.IsValid}}b.OrderBy({{$stmt.OrderBy.SQL}}) {{end}}{{if $stmt.Limit.IsValid}}b.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}}) {{end}}query, args, err := b.ToSQL() - {{if $stmt.Limit.One}}row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + if err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return - {{else}}var rows *sql.Rows - rows, err = m.db.QueryContext(ctx, query, args...) + + rows, err := m.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() - if err = m.scanner.ScanRows(rows, result); err != nil{ + defer rows.Close() + + if err = m.scanner. {{if $stmt.Limit.One}}ScanRow{{else}}ScanRows{{end}}(rows, &result); err != nil{ return nil, err } - return result, nil{{end}} + + return result, nil } {{end}} @@ -117,7 +110,8 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if // {{.FuncName}} is generated from sql: // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, data *{{UpperCamel $.Table.Name}}{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { - b := builder.Update(builder.Eq{ + b := builder.MySQL() + b.Update(builder.Eq{ {{range $name := $stmt.Columns}}"{{$name}}": data.{{UpperCamel $name}}, {{end}} }) @@ -138,7 +132,8 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, dat // {{.FuncName}} is generated from sql: // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`{{$.Table.Name}}`") {{if $stmt.Where.IsValid}}b.Where(builder.Expr({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}})) {{end}}{{if $stmt.OrderBy.IsValid}}b.OrderBy({{$stmt.OrderBy.SQL}}) diff --git a/internal/gen/sqlx/sqlx.go b/internal/gen/sqlx/sqlx.go index 1c87628..61b3dda 100644 --- a/internal/gen/sqlx/sqlx.go +++ b/internal/gen/sqlx/sqlx.go @@ -52,7 +52,7 @@ func Run(list []spec.Context, output string) error { return strings.Join(values, ", ") }, "HavingSprintf": func(format string) string { - format = strings.ReplaceAll(format, "?", "%v") + format = strings.ReplaceAll(format, "?", "'%v'") return format }, }) diff --git a/internal/gen/sqlx/sqlx_gen.tpl b/internal/gen/sqlx/sqlx_gen.tpl index ea86f0a..84d75bf 100644 --- a/internal/gen/sqlx/sqlx_gen.tpl +++ b/internal/gen/sqlx/sqlx_gen.tpl @@ -47,19 +47,17 @@ func New{{UpperCamel $.Table.Name}}Model(db *sqlx.DB) *{{UpperCamel $.Table.Name } // Create creates {{$.Table.Name}} data. -func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...*{{UpperCamel $.Table.Name}}) (err error) { +func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...*{{UpperCamel $.Table.Name}}) error { if len(data) == 0 { return fmt.Errorf("data is empty") } var stmt *sql.Stmt - stmt, err = m.db.PrepareContext(ctx, "INSERT INTO {{$.Table.Name}} ({{InsertSQL}}) VALUES ({{InsertQuotes}})") + stmt, err := m.db.PrepareContext(ctx, "INSERT INTO {{$.Table.Name}} ({{InsertSQL}}) VALUES ({{InsertQuotes}})") if err != nil { - return + return err } - defer func() { - err = stmt.Close() - }() + defer stmt.Close() for _, v := range data { result, err := stmt.ExecContext(ctx, {{InsertValues "v"}}) if err != nil { @@ -73,14 +71,15 @@ func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...* {{range $.Table.Columns}}{{if IsPrimary .Name}}{{if .AutoIncrement}}v.{{UpperCamel .Name}} = {{.GoType}}(id){{end}}{{end}}{{end}} } - return + return nil } {{range $stmt := .SelectStmt}} // {{.FuncName}} is generated from sql: // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Having.IsValid}}, having {{$stmt.Having.ParameterStructureName "Having"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}})(result {{if $stmt.Limit.One}}*{{$stmt.ReceiverName}}, {{else}}[]*{{$stmt.ReceiverName}}, {{end}} err error){ {{if $stmt.Limit.One}} result = new({{$stmt.ReceiverName}}){{end}} - b := builder.Select(`{{$stmt.SelectSQL}}`) + b := builder.MySQL() + b.Select(`{{$stmt.SelectSQL}}`) b.From("`{{$.Table.Name}}`") {{if $stmt.Where.IsValid}}b.Where(builder.Expr({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}})) {{end }}{{if $stmt.GroupBy.IsValid}}b.GroupBy({{$stmt.GroupBy.SQL}}) @@ -88,29 +87,40 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if {{end}}{{if $stmt.OrderBy.IsValid}}b.OrderBy({{$stmt.OrderBy.SQL}}) {{end}}{{if $stmt.Limit.IsValid}}b.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}}) {{end}}query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() + + {{if $stmt.Limit.One}}if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err + } + + return result, nil + {{else}} for rows.Next() { var v {{$stmt.ReceiverName}} err = rows.StructScan(&v) if err != nil { return nil, err } - {{if $stmt.Limit.One}}result=&v - break{{else}} result = append(result, &v){{end}} + result = append(result, &v) } return result, nil + {{end}} } {{end}} @@ -118,7 +128,8 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if // {{.FuncName}} is generated from sql: // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, data *{{UpperCamel $.Table.Name}}{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { - b := builder.Update(builder.Eq{ + b := builder.MySQL() + b.Update(builder.Eq{ {{range $name := $stmt.Columns}}"{{$name}}": data.{{UpperCamel $name}}, {{end}} }) @@ -139,7 +150,8 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, dat // {{.FuncName}} is generated from sql: // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`{{$.Table.Name}}`") {{if $stmt.Where.IsValid}}b.Where(builder.Expr({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}})) {{end}}{{if $stmt.OrderBy.IsValid}}b.OrderBy({{$stmt.OrderBy.SQL}}) diff --git a/internal/gen/xorm/xorm.go b/internal/gen/xorm/xorm.go index caa8156..946588f 100644 --- a/internal/gen/xorm/xorm.go +++ b/internal/gen/xorm/xorm.go @@ -27,9 +27,12 @@ func Run(list []spec.Context, output string) error { return ctx.Table.IsPrimary(name) }, "HavingSprintf": func(format string) string { - format = strings.ReplaceAll(format, "?", "%v") + format = strings.ReplaceAll(format, "?", "'%v'") return format }, + "IsExtraResult": func(name string) bool { + return name != templatex.UpperCamel(ctx.Table.Name) + }, }) gen.MustParse(xormGenTpl) gen.MustExecute(ctx) diff --git a/internal/gen/xorm/xorm_custom.tpl b/internal/gen/xorm/xorm_custom.tpl index 978b982..27abb75 100644 --- a/internal/gen/xorm/xorm_custom.tpl +++ b/internal/gen/xorm/xorm_custom.tpl @@ -3,6 +3,6 @@ package model import "context" // TODO(sqlgen): Add your own customize code here. -func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...any) { +func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...interface{}) { } \ No newline at end of file diff --git a/internal/gen/xorm/xorm_gen.tpl b/internal/gen/xorm/xorm_gen.tpl index 7f4b903..a2bea0b 100644 --- a/internal/gen/xorm/xorm_gen.tpl +++ b/internal/gen/xorm/xorm_gen.tpl @@ -26,6 +26,10 @@ type {{UpperCamel $.Table.Name}} struct { {{range $.Table.Columns}} {{end}}{{if $stmt.Having.IsValid}}{{$stmt.Having.ParameterStructure "Having"}} {{end}}{{if $stmt.Limit.Multiple}}{{$stmt.Limit.ParameterStructure}} {{end}}{{$stmt.ReceiverStructure "xorm"}} +// TableName returns the table name. it implemented by gorm.Tabler. +{{if IsExtraResult $stmt.ReceiverName}}func ({{$stmt.ReceiverName}}) TableName() string { + return "{{$.Table.Name}}" +}{{end}} {{end}} {{range $stmt := .UpdateStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} @@ -47,19 +51,19 @@ func New{{UpperCamel $.Table.Name}}Model (engine xorm.EngineInterface) *{{UpperC return &{{UpperCamel $.Table.Name}}Model{engine: engine} } -// Insert creates {{$.Table.Name}} data. -func (m *{{UpperCamel $.Table.Name}}Model) Insert(ctx context.Context, data ...*{{UpperCamel $.Table.Name}}) error { +// Create creates {{$.Table.Name}} data. +func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...*{{UpperCamel $.Table.Name}}) error { if len(data)==0{ return fmt.Errorf("data is empty") } var session = m.engine.Context(ctx) - var list []{{UpperCamel $.Table.Name}} - for _,v := range data{ - list = append(list,*v) + var list []interface{} + for _, v := range data { + list = append(list, v) } - _,err := session.Insert(&list) + _,err := session.Insert(list...) return err } {{range $stmt := .SelectStmt}} @@ -74,7 +78,11 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if {{end}}{{if $stmt.Having.IsValid}}session.Having(fmt.Sprintf({{HavingSprintf $stmt.Having.SQL}}, {{$stmt.Having.Parameters "having"}})) {{end}}{{if $stmt.OrderBy.IsValid}}session.OrderBy({{$stmt.OrderBy.SQL}}) {{end}}{{if $stmt.Limit.IsValid}}session.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}}) - {{end}}{{if $stmt.Limit.One}}_,{{end}} err := session{{if $stmt.Limit.One}}.Get({{if $stmt.Limit.One}}result{{end}}){{else}}.Find(&result){{end}} + {{end}}{{if $stmt.Limit.One}}has, err := session.Get(result) + if !has{ + return nil, sql.ErrNoRows + } + {{else}}err :=session.Find(&result){{end}} return result, err } {{end}} @@ -84,6 +92,7 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, data *{{UpperCamel $.Table.Name}}{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { var session = m.engine.Context(ctx) + session.Table(&{{UpperCamel $.Table.Name}}{}) {{if $stmt.Where.IsValid}}session.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) {{end}}{{if $stmt.OrderBy.IsValid}}session.OrderBy({{$stmt.OrderBy.SQL}}) {{end}}{{if $stmt.Limit.IsValid}}session.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}}) diff --git a/internal/parser/funcmap.go b/internal/parser/funcmap.go index cfd5ea9..7ce4c6f 100644 --- a/internal/parser/funcmap.go +++ b/internal/parser/funcmap.go @@ -1,92 +1,95 @@ package parser -import "github.com/pingcap/parser/mysql" +import ( + "github.com/anqiansong/sqlgen/internal/spec" + "github.com/pingcap/parser/mysql" +) var funcMap = map[string]byte{ - // mysql.TypeLonglong - "count": mysql.TypeLonglong, - "length": mysql.TypeLonglong, - "char_length": mysql.TypeLonglong, - "locate": mysql.TypeLonglong, - "position": mysql.TypeLonglong, - "instr": mysql.TypeLonglong, - "field": mysql.TypeLonglong, - "sign": mysql.TypeLonglong, - "mod": mysql.TypeLonglong, - "unix_timestamp": mysql.TypeLonglong, - "sysdate": mysql.TypeLonglong, - "utc_date": mysql.TypeLonglong, - "utc_time": mysql.TypeLonglong, - "month": mysql.TypeLonglong, - "day": mysql.TypeLonglong, - "dayofmonth": mysql.TypeLonglong, - "dayofweek": mysql.TypeLonglong, - "dayofyear": mysql.TypeLonglong, - "week": mysql.TypeLonglong, - "weekday": mysql.TypeLonglong, - "weekofyear": mysql.TypeLonglong, - "quarter": mysql.TypeLonglong, - "hour": mysql.TypeLonglong, - "minute": mysql.TypeLonglong, - "second": mysql.TypeLonglong, - "extract": mysql.TypeLonglong, - "time_to_sec": mysql.TypeLonglong, - "to_days": mysql.TypeLonglong, - "to_seconds": mysql.TypeLonglong, - "datadiff": mysql.TypeLonglong, + // spec.TypeNullLongLong + "count": spec.TypeNullLongLong, + "length": spec.TypeNullLongLong, + "char_length": spec.TypeNullLongLong, + "locate": spec.TypeNullLongLong, + "position": spec.TypeNullLongLong, + "instr": spec.TypeNullLongLong, + "field": spec.TypeNullLongLong, + "sign": spec.TypeNullLongLong, + "mod": spec.TypeNullLongLong, + "unix_timestamp": spec.TypeNullLongLong, + "sysdate": spec.TypeNullLongLong, + "utc_date": spec.TypeNullLongLong, + "utc_time": spec.TypeNullLongLong, + "month": spec.TypeNullLongLong, + "day": spec.TypeNullLongLong, + "dayofmonth": spec.TypeNullLongLong, + "dayofweek": spec.TypeNullLongLong, + "dayofyear": spec.TypeNullLongLong, + "week": spec.TypeNullLongLong, + "weekday": spec.TypeNullLongLong, + "weekofyear": spec.TypeNullLongLong, + "quarter": spec.TypeNullLongLong, + "hour": spec.TypeNullLongLong, + "minute": spec.TypeNullLongLong, + "second": spec.TypeNullLongLong, + "extract": spec.TypeNullLongLong, + "time_to_sec": spec.TypeNullLongLong, + "to_days": spec.TypeNullLongLong, + "to_seconds": spec.TypeNullLongLong, + "datadiff": spec.TypeNullLongLong, - // mysql.TypeNewDecimal - "avg": mysql.TypeNewDecimal, - "abs": mysql.TypeNewDecimal, - "ceil": mysql.TypeNewDecimal, - "floor": mysql.TypeNewDecimal, - "round": mysql.TypeNewDecimal, - "rand": mysql.TypeNewDecimal, - "pi": mysql.TypeNewDecimal, - "truncate": mysql.TypeNewDecimal, - "pow": mysql.TypeNewDecimal, - "sqrt": mysql.TypeNewDecimal, - "exp": mysql.TypeNewDecimal, - "log": mysql.TypeNewDecimal, - "log10": mysql.TypeNewDecimal, - "radians": mysql.TypeNewDecimal, - "degrees": mysql.TypeNewDecimal, - "sin": mysql.TypeNewDecimal, - "cos": mysql.TypeNewDecimal, - "tan": mysql.TypeNewDecimal, - "cot": mysql.TypeNewDecimal, - "asin": mysql.TypeNewDecimal, - "acos": mysql.TypeNewDecimal, - "atan": mysql.TypeNewDecimal, + // spec.TypeNullDecimal + "avg": spec.TypeNullDecimal, + "abs": spec.TypeNullDecimal, + "ceil": spec.TypeNullDecimal, + "floor": spec.TypeNullDecimal, + "round": spec.TypeNullDecimal, + "rand": spec.TypeNullDecimal, + "pi": spec.TypeNullDecimal, + "truncate": spec.TypeNullDecimal, + "pow": spec.TypeNullDecimal, + "sqrt": spec.TypeNullDecimal, + "exp": spec.TypeNullDecimal, + "log": spec.TypeNullDecimal, + "log10": spec.TypeNullDecimal, + "radians": spec.TypeNullDecimal, + "degrees": spec.TypeNullDecimal, + "sin": spec.TypeNullDecimal, + "cos": spec.TypeNullDecimal, + "tan": spec.TypeNullDecimal, + "cot": spec.TypeNullDecimal, + "asin": spec.TypeNullDecimal, + "acos": spec.TypeNullDecimal, + "atan": spec.TypeNullDecimal, - // mysql.TypeString - "concat_ws": mysql.TypeString, - "concat": mysql.TypeString, - "insert": mysql.TypeString, - "upper": mysql.TypeString, - "ucaase": mysql.TypeString, - "lower": mysql.TypeString, - "lcase": mysql.TypeString, - "left": mysql.TypeString, - "right": mysql.TypeString, - "lpad": mysql.TypeString, - "rpad": mysql.TypeString, - "replace": mysql.TypeString, - "substring": mysql.TypeString, - "substr": mysql.TypeString, - "trim": mysql.TypeString, - "ltrim": mysql.TypeString, - "rtrim": mysql.TypeString, - "reverse": mysql.TypeString, - "repeat": mysql.TypeString, - "space": mysql.TypeString, - "strcmp": mysql.TypeString, - "mid": mysql.TypeString, - "from_unixtime": mysql.TypeString, - "month_name": mysql.TypeString, - "day_name": mysql.TypeString, - "date_format": mysql.TypeString, - "time_format": mysql.TypeString, + // spec.TypeNullString + "concat_ws": spec.TypeNullString, + "concat": spec.TypeNullString, + "insert": spec.TypeNullString, + "upper": spec.TypeNullString, + "ucaase": spec.TypeNullString, + "lower": spec.TypeNullString, + "lcase": spec.TypeNullString, + "left": spec.TypeNullString, + "right": spec.TypeNullString, + "lpad": spec.TypeNullString, + "rpad": spec.TypeNullString, + "replace": spec.TypeNullString, + "substring": spec.TypeNullString, + "substr": spec.TypeNullString, + "trim": spec.TypeNullString, + "ltrim": spec.TypeNullString, + "rtrim": spec.TypeNullString, + "reverse": spec.TypeNullString, + "repeat": spec.TypeNullString, + "space": spec.TypeNullString, + "strcmp": spec.TypeNullString, + "mid": spec.TypeNullString, + "from_unixtime": spec.TypeNullString, + "month_name": spec.TypeNullString, + "day_name": spec.TypeNullString, + "date_format": spec.TypeNullString, + "time_format": spec.TypeNullString, // mysql.TypeDate "curdate": mysql.TypeDate, diff --git a/internal/parser/parser_test.go b/internal/parser/parser_test.go index 79224a8..20b25b4 100644 --- a/internal/parser/parser_test.go +++ b/internal/parser/parser_test.go @@ -27,7 +27,11 @@ func TestParse(t *testing.T) { ctxOne := ctx[0] selectOne := ctxOne.SelectStmt[0] - selectOne.Where.ParameterStructure("test") + p, err := selectOne.ColumnInfo[0].DataType() + if err != nil { + log.Fatal(err) + } + fmt.Println(p) } func TestFrom(t *testing.T) { diff --git a/internal/parser/select.go b/internal/parser/select.go index 56a8d92..aae4351 100644 --- a/internal/parser/select.go +++ b/internal/parser/select.go @@ -4,13 +4,12 @@ import ( "fmt" "strings" + "github.com/anqiansong/sqlgen/internal/set" + "github.com/anqiansong/sqlgen/internal/spec" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/opcode" "github.com/pingcap/parser/test_driver" - - "github.com/anqiansong/sqlgen/internal/set" - "github.com/anqiansong/sqlgen/internal/spec" ) func parseSelect(stmt *ast.SelectStmt) (*spec.SelectStmt, error) { @@ -416,9 +415,10 @@ func parseFieldList(fieldList *ast.FieldList, from string) (spec.Fields, string, } selectField = append(selectField, funcSql) columnSet.Add(spec.Field{ - ASName: f.AsName.String(), - ColumnName: columnName, - TP: tp, + ASName: f.AsName.String(), + ColumnName: columnName, + TP: tp, + AggregateCall: aggregate, }) } diff --git a/internal/parser/test.sql b/internal/parser/test.sql index d8401f2..1d87114 100644 --- a/internal/parser/test.sql +++ b/internal/parser/test.sql @@ -16,5 +16,5 @@ CREATE TABLE `user` ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT '用户表' COLLATE=utf8mb4_general_ci; -- fn: test -select * from user where id = ? and name in (?) limit 1; +select avg(id) AS acg from user where id > ?; diff --git a/internal/spec/converter.go b/internal/spec/converter.go index a6eee60..83d2d1c 100644 --- a/internal/spec/converter.go +++ b/internal/spec/converter.go @@ -127,6 +127,10 @@ func convertField(table *Table, fields []Field) (Columns, error) { column, ok := table.GetColumnByName(f.ColumnName) if ok { column.Name = name + column.AggregateCall = f.AggregateCall + if f.TP != mysql.TypeUnspecified { + column.TP = f.TP + } list = append(list, column) } else { return nil, fmt.Errorf("column %q no found in table %q", f.ColumnName, table.Name) @@ -136,8 +140,9 @@ func convertField(table *Table, fields []Field) (Columns, error) { return nil, fmt.Errorf("column %q no found in table %q", f.ColumnName, table.Name) } list = append(list, Column{ - Name: name, - TP: f.TP, + Name: name, + TP: f.TP, + AggregateCall: f.AggregateCall, }) } } diff --git a/internal/spec/stmt.go b/internal/spec/stmt.go index 8114aa8..20ae5e4 100644 --- a/internal/spec/stmt.go +++ b/internal/spec/stmt.go @@ -16,9 +16,10 @@ type Fields []Field // Field represents a select filed. type Field struct { - ASName string - ColumnName string - TP byte + ASName string + ColumnName string + TP byte + AggregateCall bool } // Limit represents a limit clause. diff --git a/internal/spec/table.go b/internal/spec/table.go index 33f0cee..32c6fe7 100644 --- a/internal/spec/table.go +++ b/internal/spec/table.go @@ -28,7 +28,8 @@ type Column struct { // Name is the name of the column. Name string // TP is the type of the column. - TP byte + TP byte + AggregateCall bool } // ColumnOption is a column option. diff --git a/internal/spec/type.go b/internal/spec/type.go index 163a079..1a62939 100644 --- a/internal/spec/type.go +++ b/internal/spec/type.go @@ -1,19 +1,30 @@ package spec import ( + "database/sql" "fmt" + "github.com/anqiansong/sqlgen/internal/parameter" "github.com/pingcap/parser/mysql" +) - "github.com/anqiansong/sqlgen/internal/parameter" +const ( + // TypeNullLongLong is a type extension for mysql.TypeLongLong. + TypeNullLongLong byte = 0xf0 + // TypeNullDecimal is a type extension for mysql.TypeDecimal. + TypeNullDecimal byte = 0xf1 + // TypeNullString is a type extension for mysql.TypeString. + TypeNullString byte = 0xf2 ) const defaultThirdDecimalPkg = "github.com/shopspring/decimal" type typeKey struct { - tp byte - signed bool - thirdPkg string + tp byte + signed bool + thirdPkg string + aggregateCall bool + sql.NullFloat64 } var typeMapper = map[typeKey]string{ @@ -40,7 +51,11 @@ var typeMapper = map[typeKey]string{ typeKey{ tp: mysql.TypeNewDecimal, thirdPkg: defaultThirdDecimalPkg, - }: "byte", + }: "decimal.Decimal", + typeKey{ + tp: TypeNullDecimal, + thirdPkg: defaultThirdDecimalPkg, + }: "decimal.NullDecimal", typeKey{tp: mysql.TypeEnum}: "string", typeKey{tp: mysql.TypeSet}: "string", typeKey{tp: mysql.TypeTinyBlob}: "string", @@ -49,6 +64,45 @@ var typeMapper = map[typeKey]string{ typeKey{tp: mysql.TypeBlob}: "string", typeKey{tp: mysql.TypeVarString}: "string", typeKey{tp: mysql.TypeString}: "string", + typeKey{tp: TypeNullString}: "sql.NullString", + + // aggregate functions + typeKey{tp: mysql.TypeTiny, aggregateCall: true}: "sql.NullInt16", + typeKey{tp: mysql.TypeShort, aggregateCall: true}: "sql.NullInt16", + typeKey{tp: mysql.TypeLong, aggregateCall: true}: "sql.NullInt32", + typeKey{tp: mysql.TypeFloat, aggregateCall: true}: "sql.NullInt32", + typeKey{tp: mysql.TypeDouble, aggregateCall: true}: "sql.NullFloat64", + typeKey{tp: mysql.TypeLonglong, aggregateCall: true}: "sql.NullInt64", + typeKey{tp: mysql.TypeInt24, aggregateCall: true}: "sql.NullInt32", + typeKey{tp: mysql.TypeYear, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeVarchar, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeBit, aggregateCall: true}: "sql.NullInt16", + typeKey{tp: mysql.TypeJSON, aggregateCall: true}: "sql.NullString", + typeKey{ + tp: mysql.TypeNewDecimal, + thirdPkg: defaultThirdDecimalPkg, + aggregateCall: true, + }: "decimal.NullDecimal", + typeKey{ + tp: TypeNullDecimal, + thirdPkg: defaultThirdDecimalPkg, + aggregateCall: true, + }: "decimal.NullDecimal", + typeKey{tp: mysql.TypeEnum, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeSet, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeTinyBlob, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeMediumBlob, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeLongBlob, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeBlob, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeVarString, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeString, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeString, aggregateCall: true}: "sql.NullString", + typeKey{tp: TypeNullLongLong}: "sql.NullInt64", + typeKey{tp: TypeNullDecimal}: "decimal.NullDecimal", + typeKey{tp: TypeNullString}: "sql.NullString", + typeKey{tp: TypeNullLongLong, aggregateCall: true}: "sql.NullInt64", + typeKey{tp: TypeNullDecimal, aggregateCall: true}: "decimal.NullDecimal", + typeKey{tp: TypeNullString, aggregateCall: true}: "sql.NullString", } // Type is the type of the column. @@ -56,10 +110,14 @@ type Type byte // DataType returns the Go type, third-package of the column. func (c Column) DataType() (parameter.Parameter, error) { - var key = typeKey{tp: c.TP, signed: c.Unsigned} + var key = typeKey{tp: c.TP, signed: c.Unsigned, aggregateCall: c.AggregateCall} + if c.AggregateCall { + key = typeKey{tp: c.TP, aggregateCall: c.AggregateCall} + } if c.TP == mysql.TypeNewDecimal { key.thirdPkg = defaultThirdDecimalPkg } + goType, ok := typeMapper[key] if !ok { return parameter.Parameter{}, fmt.Errorf("unsupported type: %v", c.TP) @@ -73,3 +131,7 @@ func (c Column) GoType() (string, error) { p, err := c.DataType() return p.Type, err } + +func isNullType(tp byte) bool { + return tp >= TypeNullLongLong && tp <= TypeNullString +}