From 385fd0fd8a7aff1fb0ef3ac06949f819aae998c2 Mon Sep 17 00:00:00 2001 From: anqiansong Date: Mon, 15 Aug 2022 21:52:04 +0800 Subject: [PATCH 01/10] Update bun's template --- example/bun/user_model.gen.go | 2 +- internal/gen/bun/bun_gen.tpl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/example/bun/user_model.gen.go b/example/bun/user_model.gen.go index d79d380..48617ef 100644 --- a/example/bun/user_model.gen.go +++ b/example/bun/user_model.gen.go @@ -18,7 +18,7 @@ type UserModel struct { // User represents a user struct data. type User struct { bun.BaseModel `bun:"table:user"` - Id uint64 `bun:"id,pk,autoincrement;" json:"id"` + Id uint64 `bun:"id,pk,autoincrement" json:"id"` Name string `bun:"name" json:"name"` Password string `bun:"password" json:"password"` Mobile string `bun:"mobile" json:"mobile"` diff --git a/internal/gen/bun/bun_gen.tpl b/internal/gen/bun/bun_gen.tpl index 0f7634d..222f7eb 100644 --- a/internal/gen/bun/bun_gen.tpl +++ b/internal/gen/bun/bun_gen.tpl @@ -20,7 +20,7 @@ type {{UpperCamel $.Table.Name}}Model struct { // {{UpperCamel $.Table.Name}} represents a {{$.Table.Name}} struct data. type {{UpperCamel $.Table.Name}} struct { bun.BaseModel `bun:"table:{{$.Table.Name}}"`{{range $.Table.Columns}} - {{UpperCamel .Name}} {{.GoType}} `bun:"{{.Name}}{{if IsPrimary .Name}},pk{{end}}{{if .AutoIncrement}},autoincrement;{{end}}" json:"{{LowerCamel .Name}}"`{{end}} + {{UpperCamel .Name}} {{.GoType}} `bun:"{{.Name}}{{if IsPrimary .Name}},pk{{end}}{{if .AutoIncrement}},autoincrement{{end}}" json:"{{LowerCamel .Name}}"`{{end}} } {{range $stmt := .SelectStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} {{end}}{{if $stmt.Having.IsValid}}{{$stmt.Having.ParameterStructure "Having"}} From aa5b52a412263bcc5254787f41f5ae0b2e374eb3 Mon Sep 17 00:00:00 2001 From: anqiansong Date: Mon, 15 Aug 2022 21:53:39 +0800 Subject: [PATCH 02/10] Update custom template --- internal/gen/bun/bun_custom.tpl | 2 +- internal/gen/gorm/gorm_custom.tpl | 2 +- internal/gen/sql/sql_custom.tpl | 2 +- internal/gen/xorm/xorm_custom.tpl | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/gen/bun/bun_custom.tpl b/internal/gen/bun/bun_custom.tpl index 978b982..27abb75 100644 --- a/internal/gen/bun/bun_custom.tpl +++ b/internal/gen/bun/bun_custom.tpl @@ -3,6 +3,6 @@ package model import "context" // TODO(sqlgen): Add your own customize code here. -func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...any) { +func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...interface{}) { } \ No newline at end of file diff --git a/internal/gen/gorm/gorm_custom.tpl b/internal/gen/gorm/gorm_custom.tpl index 978b982..27abb75 100644 --- a/internal/gen/gorm/gorm_custom.tpl +++ b/internal/gen/gorm/gorm_custom.tpl @@ -3,6 +3,6 @@ package model import "context" // TODO(sqlgen): Add your own customize code here. -func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...any) { +func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...interface{}) { } \ No newline at end of file diff --git a/internal/gen/sql/sql_custom.tpl b/internal/gen/sql/sql_custom.tpl index 978b982..27abb75 100644 --- a/internal/gen/sql/sql_custom.tpl +++ b/internal/gen/sql/sql_custom.tpl @@ -3,6 +3,6 @@ package model import "context" // TODO(sqlgen): Add your own customize code here. -func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...any) { +func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...interface{}) { } \ No newline at end of file diff --git a/internal/gen/xorm/xorm_custom.tpl b/internal/gen/xorm/xorm_custom.tpl index 978b982..27abb75 100644 --- a/internal/gen/xorm/xorm_custom.tpl +++ b/internal/gen/xorm/xorm_custom.tpl @@ -3,6 +3,6 @@ package model import "context" // TODO(sqlgen): Add your own customize code here. -func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...any) { +func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...interface{}) { } \ No newline at end of file From ba7aa522c2dcb4df32fc7b005c78c5f2d31daebd Mon Sep 17 00:00:00 2001 From: anqiansong Date: Tue, 16 Aug 2022 00:13:04 +0800 Subject: [PATCH 03/10] Add example test --- example/bun/user_model.go | 2 +- example/example.sql | 1 - .../example_test/bun/user_model.gen_test.go | 1 + .../example_test/gorm/user_model.gen_test.go | 1 + example/example_test/readme.md | 28 +++ example/example_test/sql/mock.go | 23 +++ example/example_test/sql/scanner.go | 86 ++++++++++ .../example_test/sql/user_model.gen_test.go | 110 ++++++++++++ .../example_test/sqlx/user_model.gen_test.go | 1 + .../example_test/xorm/user_model.gen_test.go | 1 + example/go.mod | 6 + example/go.sum | 4 + example/gorm/user_model.go | 2 +- example/sql/user_model.gen.go | 159 ++++++++++++++---- example/sql/user_model.go | 2 +- example/sqlx/user_model.gen.go | 155 +++++++++++++---- example/xorm/user_model.go | 2 +- internal/gen/sql/sql_gen.tpl | 29 ++-- internal/gen/sqlx/sqlx_gen.tpl | 25 +-- 19 files changed, 545 insertions(+), 93 deletions(-) create mode 100644 example/example_test/bun/user_model.gen_test.go create mode 100644 example/example_test/gorm/user_model.gen_test.go create mode 100644 example/example_test/readme.md create mode 100644 example/example_test/sql/mock.go create mode 100644 example/example_test/sql/scanner.go create mode 100644 example/example_test/sql/user_model.gen_test.go create mode 100644 example/example_test/sqlx/user_model.gen_test.go create mode 100644 example/example_test/xorm/user_model.gen_test.go diff --git a/example/bun/user_model.go b/example/bun/user_model.go index 8528c27..b7ce929 100644 --- a/example/bun/user_model.go +++ b/example/bun/user_model.go @@ -3,6 +3,6 @@ package model import "context" // TODO(sqlgen): Add your own customize code here. -func (m *UserModel) Customize(ctx context.Context, args ...any) { +func (m *UserModel) Customize(ctx context.Context, args ...interface{}) { } diff --git a/example/example.sql b/example/example.sql index 03488ed..71eb623 100644 --- a/example/example.sql +++ b/example/example.sql @@ -10,7 +10,6 @@ CREATE TABLE `user` `create_at` timestamp NULL, `update_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, UNIQUE KEY `name_index` (`name`), - UNIQUE KEY `type_index` (`type`), UNIQUE KEY `mobile_index` (`mobile`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT 'user table' COLLATE=utf8mb4_general_ci; diff --git a/example/example_test/bun/user_model.gen_test.go b/example/example_test/bun/user_model.gen_test.go new file mode 100644 index 0000000..c0ac838 --- /dev/null +++ b/example/example_test/bun/user_model.gen_test.go @@ -0,0 +1 @@ +package bun diff --git a/example/example_test/gorm/user_model.gen_test.go b/example/example_test/gorm/user_model.gen_test.go new file mode 100644 index 0000000..4506bcf --- /dev/null +++ b/example/example_test/gorm/user_model.gen_test.go @@ -0,0 +1 @@ +package gorm diff --git a/example/example_test/readme.md b/example/example_test/readme.md new file mode 100644 index 0000000..4ff62db --- /dev/null +++ b/example/example_test/readme.md @@ -0,0 +1,28 @@ +# example_test + +## before test +1. started docker +2. run a mysql container which dsn is `root:mysqlpw@(localhost:55000)` +3. new a schema `test` +4. create a table use the following sql +```sql +CREATE TABLE `user` +( + `id` bigint(10) unsigned NOT NULL AUTO_INCREMENT primary key, + `name` varchar(255) COLLATE utf8mb4_general_ci NULL COMMENT 'The username', + `password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT 'The \n user password', + `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT 'The mobile phone number', + `gender` char(10) COLLATE utf8mb4_general_ci NOT NULL COMMENT 'gender,male|female|unknown', + `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT 'The nickname', + `type` tinyint(1) COLLATE utf8mb4_general_ci DEFAULT 0 COMMENT 'The user type, 0:normal,1:vip, for test golang keyword', + `create_at` timestamp NULL, + `update_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + UNIQUE KEY `name_index` (`name`), + UNIQUE KEY `type_index` (`type`), + UNIQUE KEY `mobile_index` (`mobile`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT 'user table' COLLATE=utf8mb4_general_ci; +``` +5. clean the test data and set auto_increment to `1` +6. run the test + + diff --git a/example/example_test/sql/mock.go b/example/example_test/sql/mock.go new file mode 100644 index 0000000..6e7a213 --- /dev/null +++ b/example/example_test/sql/mock.go @@ -0,0 +1,23 @@ +package sql + +import ( + "time" + + model "github.com/anqiansong/sqlgen/example/sql" + uuid "github.com/satori/go.uuid" +) + +func mustMockUser() *model.User { + uid := uuid.NewV4().String() + now := time.Now() + return &model.User{ + Name: uid, + Password: "bar", + Mobile: uid, + Gender: "male", + Nickname: "test", + Type: 1, + CreateAt: now, + UpdateAt: now, + } +} diff --git a/example/example_test/sql/scanner.go b/example/example_test/sql/scanner.go new file mode 100644 index 0000000..018f31d --- /dev/null +++ b/example/example_test/sql/scanner.go @@ -0,0 +1,86 @@ +package sql + +import ( + "database/sql" + "errors" + "reflect" + + model "github.com/anqiansong/sqlgen/example/sql" +) + +type customScanner struct { +} + +func (c customScanner) getRowElem(v interface{}) ([]interface{}, error) { + value, ok := v.(reflect.Value) + if !ok { + value = reflect.ValueOf(v) + } + elem := value.Elem() + switch elem.Kind() { + case reflect.Pointer: + return c.getRowElem(elem.Elem()) + case reflect.Struct: + var list []interface{} + for i := 0; i < elem.NumField(); i++ { + f := elem.Field(i) + list = append(list, f.Addr().Interface()) + } + return list, nil + default: + return nil, errors.New("expect a struct") + } +} + +func (c customScanner) getRowsElem(v interface{}) ([][]interface{}, error) { + value := reflect.ValueOf(v) + elem := value.Elem() + switch elem.Kind() { + case reflect.Pointer: + return c.getRowsElem(elem.Elem()) + case reflect.Slice: + var list [][]interface{} + for i := 0; i < elem.NumField(); i++ { + f := elem.Field(i) + item := f.Elem() + rowElem, err := c.getRowsElem(item) + if err != nil { + return nil, err + } + + list = append(list, rowElem...) + } + return list, nil + default: + return nil, errors.New("expect a struct") + } +} + +func (c customScanner) ScanRow(row *sql.Row, v interface{}) error { + dest, err := c.getRowElem(v) + if err != nil { + return err + } + return row.Scan(dest...) +} + +func (c customScanner) ScanRows(rows *sql.Rows, v interface{}) error { + dests, err := c.getRowsElem(v) + if err != nil { + return err + } + var i int + for rows.Next() && i < len(dests) { + dest := dests[i] + err = rows.Scan(dest) + if err != nil { + return err + } + i++ + } + return nil +} + +func getScanner() model.Scanner { + return customScanner{} +} diff --git a/example/example_test/sql/user_model.gen_test.go b/example/example_test/sql/user_model.gen_test.go new file mode 100644 index 0000000..5a934e3 --- /dev/null +++ b/example/example_test/sql/user_model.gen_test.go @@ -0,0 +1,110 @@ +package sql + +import ( + "context" + "database/sql" + "log" + "testing" + "time" + + model "github.com/anqiansong/sqlgen/example/sql" + _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" +) + +var ( + um *model.UserModel + ctx = context.TODO() + db *sql.DB +) + +func TestMain(m *testing.M) { + var err error + db, err = sql.Open("mysql", "root:mysqlpw@tcp(127.0.0.1:55000)/test?charset=utf8mb4&parseTime=true&loc=Local") + if err != nil { + log.Fatalln(err) + } + + um = model.NewUserModel(db, getScanner()) + m.Run() +} + +func mustInitDB(db *sql.DB) { + tx, err := db.Begin() + if err != nil { + log.Fatalln(err) + } + + _, err = tx.ExecContext(ctx, `truncate table user`) + if err != nil { + log.Fatalln(err) + } + + _, err = tx.ExecContext(ctx, `alter table user auto_increment=1`) + if err != nil { + log.Fatalln(err) + } + + err = tx.Commit() + if err != nil { + log.Fatalln(err) + } +} + +func TestCreate(t *testing.T) { + t.Run("emptyData", initAndRun(func(t *testing.T) { + err := um.Create(ctx) + assert.Contains(t, err.Error(), "empty") + })) + + t.Run("createOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + assert.Equal(t, uint64(1), mockUser.Id) + })) + t.Run("createMultiple", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 1; i <= 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + for idx, item := range list { + assert.Equal(t, uint64(idx+1), item.Id) + } + })) +} + +func TestFindOne(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: 1}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("findOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) + +} + +func assertUserEqual(t *testing.T, expected, actual *model.User) { + now := time.Now() + expected.CreateAt = now + expected.UpdateAt = now + actual.CreateAt = now + actual.UpdateAt = now + assert.Equal(t, *expected, *actual) +} + +func initAndRun(f func(t *testing.T)) func(t *testing.T) { + mustInitDB(db) + return func(t *testing.T) { + f(t) + } +} diff --git a/example/example_test/sqlx/user_model.gen_test.go b/example/example_test/sqlx/user_model.gen_test.go new file mode 100644 index 0000000..ed8ee00 --- /dev/null +++ b/example/example_test/sqlx/user_model.gen_test.go @@ -0,0 +1 @@ +package sqlx diff --git a/example/example_test/xorm/user_model.gen_test.go b/example/example_test/xorm/user_model.gen_test.go new file mode 100644 index 0000000..0baf644 --- /dev/null +++ b/example/example_test/xorm/user_model.gen_test.go @@ -0,0 +1 @@ +package xorm diff --git a/example/go.mod b/example/go.mod index 9445122..46ff15e 100644 --- a/example/go.mod +++ b/example/go.mod @@ -3,7 +3,10 @@ module github.com/anqiansong/sqlgen/example go 1.18 require ( + github.com/go-sql-driver/mysql v1.6.0 github.com/jmoiron/sqlx v1.3.5 + github.com/satori/go.uuid v1.2.0 + github.com/stretchr/testify v1.7.0 github.com/uptrace/bun v1.1.7 gorm.io/gorm v1.23.8 xorm.io/builder v0.3.12 @@ -11,6 +14,7 @@ require ( ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/goccy/go-json v0.8.1 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/jinzhu/inflection v1.0.0 // indirect @@ -18,9 +22,11 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/syndtr/goleveldb v1.0.0 // indirect github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect ) diff --git a/example/go.sum b/example/go.sum index 8a1e25d..4061a5a 100644 --- a/example/go.sum +++ b/example/go.sum @@ -205,9 +205,11 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -314,6 +316,7 @@ github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThC github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= +github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= @@ -512,6 +515,7 @@ google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= diff --git a/example/gorm/user_model.go b/example/gorm/user_model.go index 8528c27..b7ce929 100644 --- a/example/gorm/user_model.go +++ b/example/gorm/user_model.go @@ -3,6 +3,6 @@ package model import "context" // TODO(sqlgen): Add your own customize code here. -func (m *UserModel) Customize(ctx context.Context, args ...any) { +func (m *UserModel) Customize(ctx context.Context, args ...interface{}) { } diff --git a/example/sql/user_model.gen.go b/example/sql/user_model.gen.go index cdfd925..774f762 100644 --- a/example/sql/user_model.gen.go +++ b/example/sql/user_model.gen.go @@ -13,7 +13,7 @@ import ( // UserModel represents a user model. type UserModel struct { - db *sql.Conn + db *sql.DB scanner Scanner } @@ -211,7 +211,7 @@ type DeleteOneOrderByIDDescLimitCountLimitParameter struct { } // NewUserModel creates a new user model. -func NewUserModel(db *sql.Conn, scanner Scanner) *UserModel { +func NewUserModel(db *sql.DB, scanner Scanner) *UserModel { return &UserModel{ db: db, scanner: scanner, @@ -219,19 +219,17 @@ func NewUserModel(db *sql.Conn, scanner Scanner) *UserModel { } // Create creates user data. -func (m *UserModel) Create(ctx context.Context, data ...*User) (err error) { +func (m *UserModel) Create(ctx context.Context, data ...*User) error { if len(data) == 0 { return fmt.Errorf("data is empty") } var stmt *sql.Stmt - stmt, err = m.db.PrepareContext(ctx, "INSERT INTO user (`name`, `password`, `mobile`, `gender`, `nickname`, `type`, `create_at`, `update_at`) VALUES (?, ?, ?, ?, ?, ?, ?, ?)") + stmt, err := m.db.PrepareContext(ctx, "INSERT INTO user (`name`, `password`, `mobile`, `gender`, `nickname`, `type`, `create_at`, `update_at`) VALUES (?, ?, ?, ?, ?, ?, ?, ?)") if err != nil { - return + return err } - defer func() { - err = stmt.Close() - }() + defer stmt.Close() for _, v := range data { result, err := stmt.ExecContext(ctx, v.Name, v.Password, v.Mobile, v.Gender, v.Nickname, v.Type, v.CreateAt, v.UpdateAt) if err != nil { @@ -245,18 +243,23 @@ func (m *UserModel) Create(ctx context.Context, data ...*User) (err error) { v.Id = uint64(id) } - return + return nil } // FindOne is generated from sql: // select * from `user` where `id` = ? limit 1; func (m *UserModel) FindOne(ctx context.Context, where FindOneWhereParameter) (result *User, err error) { result = new(User) - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id = ?`, where.IdEqual)) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + row := m.db.QueryRowContext(ctx, query, args...) if err = row.Err(); err != nil { return nil, err @@ -270,11 +273,16 @@ func (m *UserModel) FindOne(ctx context.Context, where FindOneWhereParameter) (r // select * from `user` where `name` = ? limit 1; func (m *UserModel) FindOneByName(ctx context.Context, where FindOneByNameWhereParameter) (result *User, err error) { result = new(User) - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + row := m.db.QueryRowContext(ctx, query, args...) if err = row.Err(); err != nil { return nil, err @@ -288,12 +296,17 @@ func (m *UserModel) FindOneByName(ctx context.Context, where FindOneByNameWhereP // select * from `user` where `name` = ? group by name limit 1; func (m *UserModel) FindOneGroupByName(ctx context.Context, where FindOneGroupByNameWhereParameter) (result *User, err error) { result = new(User) - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.GroupBy(`name`) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + row := m.db.QueryRowContext(ctx, query, args...) if err = row.Err(); err != nil { return nil, err @@ -307,13 +320,18 @@ func (m *UserModel) FindOneGroupByName(ctx context.Context, where FindOneGroupBy // select * from `user` where `name` = ? group by name having name = ? limit 1; func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where FindOneGroupByNameHavingNameWhereParameter, having FindOneGroupByNameHavingNameHavingParameter) (result *User, err error) { result = new(User) - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.GroupBy(`name`) b.Having(fmt.Sprintf(`name = %v`, having.NameEqual)) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + row := m.db.QueryRowContext(ctx, query, args...) if err = row.Err(); err != nil { return nil, err @@ -326,9 +344,14 @@ func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where Find // FindAll is generated from sql: // select * from `user`; func (m *UserModel) FindAll(ctx context.Context) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sql.Rows rows, err = m.db.QueryContext(ctx, query, args...) if err != nil { @@ -349,11 +372,16 @@ func (m *UserModel) FindAll(ctx context.Context) (result []*User, err error) { // FindLimit is generated from sql: // select * from `user` where id > ? limit ?; func (m *UserModel) FindLimit(ctx context.Context, where FindLimitWhereParameter, limit FindLimitLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.Limit(limit.Count) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sql.Rows rows, err = m.db.QueryContext(ctx, query, args...) if err != nil { @@ -374,10 +402,15 @@ func (m *UserModel) FindLimit(ctx context.Context, where FindLimitWhereParameter // FindLimitOffset is generated from sql: // select * from `user` limit ?, ?; func (m *UserModel) FindLimitOffset(ctx context.Context, limit FindLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sql.Rows rows, err = m.db.QueryContext(ctx, query, args...) if err != nil { @@ -398,12 +431,17 @@ func (m *UserModel) FindLimitOffset(ctx context.Context, limit FindLimitOffsetLi // FindGroupLimitOffset is generated from sql: // select * from `user` where id > ? group by name limit ?, ?; func (m *UserModel) FindGroupLimitOffset(ctx context.Context, where FindGroupLimitOffsetWhereParameter, limit FindGroupLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sql.Rows rows, err = m.db.QueryContext(ctx, query, args...) if err != nil { @@ -424,13 +462,18 @@ func (m *UserModel) FindGroupLimitOffset(ctx context.Context, where FindGroupLim // FindGroupHavingLimitOffset is generated from sql: // select * from `user` where id > ? group by name having id > ? limit ?, ?; func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGroupHavingLimitOffsetWhereParameter, having FindGroupHavingLimitOffsetHavingParameter, limit FindGroupHavingLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) b.Having(fmt.Sprintf(`id > %v`, having.IdGT)) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sql.Rows rows, err = m.db.QueryContext(ctx, query, args...) if err != nil { @@ -451,7 +494,8 @@ func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGr // FindGroupHavingOrderAscLimitOffset is generated from sql: // select * from `user` where id > ? group by name having id > ? order by id limit ?, ?; func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, where FindGroupHavingOrderAscLimitOffsetWhereParameter, having FindGroupHavingOrderAscLimitOffsetHavingParameter, limit FindGroupHavingOrderAscLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) @@ -459,6 +503,10 @@ func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, wher b.OrderBy(`id`) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sql.Rows rows, err = m.db.QueryContext(ctx, query, args...) if err != nil { @@ -479,7 +527,8 @@ func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, wher // FindGroupHavingOrderDescLimitOffset is generated from sql: // select * from `user` where id > ? group by name having id > ? order by id desc limit ?, ?; func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, where FindGroupHavingOrderDescLimitOffsetWhereParameter, having FindGroupHavingOrderDescLimitOffsetHavingParameter, limit FindGroupHavingOrderDescLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) @@ -487,6 +536,10 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe b.OrderBy(`id desc`) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sql.Rows rows, err = m.db.QueryContext(ctx, query, args...) if err != nil { @@ -508,11 +561,16 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe // select `name`, `password`, `mobile` from `user` where id > ? limit 1; func (m *UserModel) FindOnePart(ctx context.Context, where FindOnePartWhereParameter) (result *User, err error) { result = new(User) - b := builder.Select(`name, password, mobile`) + b := builder.MySQL() + b.Select(`name, password, mobile`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + row := m.db.QueryRowContext(ctx, query, args...) if err = row.Err(); err != nil { return nil, err @@ -526,10 +584,15 @@ func (m *UserModel) FindOnePart(ctx context.Context, where FindOnePartWhereParam // select count(id) AS countID from `user`; func (m *UserModel) FindAllCount(ctx context.Context) (result *FindAllCountResult, err error) { result = new(FindAllCountResult) - b := builder.Select(`count(id) AS countID`) + b := builder.MySQL() + b.Select(`count(id) AS countID`) b.From("`user`") b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + row := m.db.QueryRowContext(ctx, query, args...) if err = row.Err(); err != nil { return nil, err @@ -543,11 +606,16 @@ func (m *UserModel) FindAllCount(ctx context.Context) (result *FindAllCountResul // select count(id) AS countID from `user` where id > ?; func (m *UserModel) FindAllCountWhere(ctx context.Context, where FindAllCountWhereWhereParameter) (result *FindAllCountWhereResult, err error) { result = new(FindAllCountWhereResult) - b := builder.Select(`count(id) AS countID`) + b := builder.MySQL() + b.Select(`count(id) AS countID`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + row := m.db.QueryRowContext(ctx, query, args...) if err = row.Err(); err != nil { return nil, err @@ -561,10 +629,15 @@ func (m *UserModel) FindAllCountWhere(ctx context.Context, where FindAllCountWhe // select max(id) AS maxID from `user`; func (m *UserModel) FindMaxID(ctx context.Context) (result *FindMaxIDResult, err error) { result = new(FindMaxIDResult) - b := builder.Select(`max(id) AS maxID`) + b := builder.MySQL() + b.Select(`max(id) AS maxID`) b.From("`user`") b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + row := m.db.QueryRowContext(ctx, query, args...) if err = row.Err(); err != nil { return nil, err @@ -578,10 +651,15 @@ func (m *UserModel) FindMaxID(ctx context.Context) (result *FindMaxIDResult, err // select min(id) AS minID from `user`; func (m *UserModel) FindMinID(ctx context.Context) (result *FindMinIDResult, err error) { result = new(FindMinIDResult) - b := builder.Select(`min(id) AS minID`) + b := builder.MySQL() + b.Select(`min(id) AS minID`) b.From("`user`") b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + row := m.db.QueryRowContext(ctx, query, args...) if err = row.Err(); err != nil { return nil, err @@ -595,10 +673,15 @@ func (m *UserModel) FindMinID(ctx context.Context) (result *FindMinIDResult, err // select avg(id) AS avgID from `user`; func (m *UserModel) FindAvgID(ctx context.Context) (result *FindAvgIDResult, err error) { result = new(FindAvgIDResult) - b := builder.Select(`avg(id) AS avgID`) + b := builder.MySQL() + b.Select(`avg(id) AS avgID`) b.From("`user`") b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + row := m.db.QueryRowContext(ctx, query, args...) if err = row.Err(); err != nil { return nil, err @@ -611,7 +694,8 @@ func (m *UserModel) FindAvgID(ctx context.Context) (result *FindAvgIDResult, err // Update is generated from sql: // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ?; func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWhereParameter) error { - b := builder.Update(builder.Eq{ + b := builder.MySQL() + b.Update(builder.Eq{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -634,7 +718,8 @@ func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWherePar // UpdateOrderByIdDesc is generated from sql: // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where UpdateOrderByIdDescWhereParameter) error { - b := builder.Update(builder.Eq{ + b := builder.MySQL() + b.Update(builder.Eq{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -658,7 +743,8 @@ func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where U // UpdateOrderByIdDescLimitCount is generated from sql: // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter) error { - b := builder.Update(builder.Eq{ + b := builder.MySQL() + b.Update(builder.Eq{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -682,7 +768,8 @@ func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *Use // DeleteOne is generated from sql: // delete from `user` where `id` = ?; func (m *UserModel) DeleteOne(ctx context.Context, where DeleteOneWhereParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`id = ?`, where.IdEqual)) query, args, err := b.ToSQL() @@ -696,7 +783,8 @@ func (m *UserModel) DeleteOne(ctx context.Context, where DeleteOneWhereParameter // DeleteOneByName is generated from sql: // delete from `user` where `name` = ?; func (m *UserModel) DeleteOneByName(ctx context.Context, where DeleteOneByNameWhereParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) query, args, err := b.ToSQL() @@ -710,7 +798,8 @@ func (m *UserModel) DeleteOneByName(ctx context.Context, where DeleteOneByNameWh // DeleteOneOrderByIDAsc is generated from sql: // delete from `user` where `name` = ? order by id; func (m *UserModel) DeleteOneOrderByIDAsc(ctx context.Context, where DeleteOneOrderByIDAscWhereParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.OrderBy(`id`) @@ -725,7 +814,8 @@ func (m *UserModel) DeleteOneOrderByIDAsc(ctx context.Context, where DeleteOneOr // DeleteOneOrderByIDDesc is generated from sql: // delete from `user` where `name` = ? order by id desc; func (m *UserModel) DeleteOneOrderByIDDesc(ctx context.Context, where DeleteOneOrderByIDDescWhereParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.OrderBy(`id desc`) @@ -740,7 +830,8 @@ func (m *UserModel) DeleteOneOrderByIDDesc(ctx context.Context, where DeleteOneO // DeleteOneOrderByIDDescLimitCount is generated from sql: // delete from `user` where `name` = ? order by id desc limit ?; func (m *UserModel) DeleteOneOrderByIDDescLimitCount(ctx context.Context, where DeleteOneOrderByIDDescLimitCountWhereParameter, limit DeleteOneOrderByIDDescLimitCountLimitParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.OrderBy(`id desc`) diff --git a/example/sql/user_model.go b/example/sql/user_model.go index 8528c27..b7ce929 100644 --- a/example/sql/user_model.go +++ b/example/sql/user_model.go @@ -3,6 +3,6 @@ package model import "context" // TODO(sqlgen): Add your own customize code here. -func (m *UserModel) Customize(ctx context.Context, args ...any) { +func (m *UserModel) Customize(ctx context.Context, args ...interface{}) { } diff --git a/example/sqlx/user_model.gen.go b/example/sqlx/user_model.gen.go index 76b2bcb..455fd33 100644 --- a/example/sqlx/user_model.gen.go +++ b/example/sqlx/user_model.gen.go @@ -218,19 +218,17 @@ func NewUserModel(db *sqlx.DB) *UserModel { } // Create creates user data. -func (m *UserModel) Create(ctx context.Context, data ...*User) (err error) { +func (m *UserModel) Create(ctx context.Context, data ...*User) error { if len(data) == 0 { return fmt.Errorf("data is empty") } var stmt *sql.Stmt - stmt, err = m.db.PrepareContext(ctx, "INSERT INTO user (`name`, `password`, `mobile`, `gender`, `nickname`, `type`, `create_at`, `update_at`) VALUES (?, ?, ?, ?, ?, ?, ?, ?)") + stmt, err := m.db.PrepareContext(ctx, "INSERT INTO user (`name`, `password`, `mobile`, `gender`, `nickname`, `type`, `create_at`, `update_at`) VALUES (?, ?, ?, ?, ?, ?, ?, ?)") if err != nil { - return + return err } - defer func() { - err = stmt.Close() - }() + defer stmt.Close() for _, v := range data { result, err := stmt.ExecContext(ctx, v.Name, v.Password, v.Mobile, v.Gender, v.Nickname, v.Type, v.CreateAt, v.UpdateAt) if err != nil { @@ -244,18 +242,23 @@ func (m *UserModel) Create(ctx context.Context, data ...*User) (err error) { v.Id = uint64(id) } - return + return nil } // FindOne is generated from sql: // select * from `user` where `id` = ? limit 1; func (m *UserModel) FindOne(ctx context.Context, where FindOneWhereParameter) (result *User, err error) { result = new(User) - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id = ?`, where.IdEqual)) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -285,11 +288,16 @@ func (m *UserModel) FindOne(ctx context.Context, where FindOneWhereParameter) (r // select * from `user` where `name` = ? limit 1; func (m *UserModel) FindOneByName(ctx context.Context, where FindOneByNameWhereParameter) (result *User, err error) { result = new(User) - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -319,12 +327,17 @@ func (m *UserModel) FindOneByName(ctx context.Context, where FindOneByNameWhereP // select * from `user` where `name` = ? group by name limit 1; func (m *UserModel) FindOneGroupByName(ctx context.Context, where FindOneGroupByNameWhereParameter) (result *User, err error) { result = new(User) - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.GroupBy(`name`) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -354,13 +367,18 @@ func (m *UserModel) FindOneGroupByName(ctx context.Context, where FindOneGroupBy // select * from `user` where `name` = ? group by name having name = ? limit 1; func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where FindOneGroupByNameHavingNameWhereParameter, having FindOneGroupByNameHavingNameHavingParameter) (result *User, err error) { result = new(User) - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.GroupBy(`name`) b.Having(fmt.Sprintf(`name = %v`, having.NameEqual)) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -389,9 +407,14 @@ func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where Find // FindAll is generated from sql: // select * from `user`; func (m *UserModel) FindAll(ctx context.Context) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -419,11 +442,16 @@ func (m *UserModel) FindAll(ctx context.Context) (result []*User, err error) { // FindLimit is generated from sql: // select * from `user` where id > ? limit ?; func (m *UserModel) FindLimit(ctx context.Context, where FindLimitWhereParameter, limit FindLimitLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.Limit(limit.Count) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -451,10 +479,15 @@ func (m *UserModel) FindLimit(ctx context.Context, where FindLimitWhereParameter // FindLimitOffset is generated from sql: // select * from `user` limit ?, ?; func (m *UserModel) FindLimitOffset(ctx context.Context, limit FindLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -482,12 +515,17 @@ func (m *UserModel) FindLimitOffset(ctx context.Context, limit FindLimitOffsetLi // FindGroupLimitOffset is generated from sql: // select * from `user` where id > ? group by name limit ?, ?; func (m *UserModel) FindGroupLimitOffset(ctx context.Context, where FindGroupLimitOffsetWhereParameter, limit FindGroupLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -515,13 +553,18 @@ func (m *UserModel) FindGroupLimitOffset(ctx context.Context, where FindGroupLim // FindGroupHavingLimitOffset is generated from sql: // select * from `user` where id > ? group by name having id > ? limit ?, ?; func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGroupHavingLimitOffsetWhereParameter, having FindGroupHavingLimitOffsetHavingParameter, limit FindGroupHavingLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) b.Having(fmt.Sprintf(`id > %v`, having.IdGT)) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -549,7 +592,8 @@ func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGr // FindGroupHavingOrderAscLimitOffset is generated from sql: // select * from `user` where id > ? group by name having id > ? order by id limit ?, ?; func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, where FindGroupHavingOrderAscLimitOffsetWhereParameter, having FindGroupHavingOrderAscLimitOffsetHavingParameter, limit FindGroupHavingOrderAscLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) @@ -557,6 +601,10 @@ func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, wher b.OrderBy(`id`) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -584,7 +632,8 @@ func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, wher // FindGroupHavingOrderDescLimitOffset is generated from sql: // select * from `user` where id > ? group by name having id > ? order by id desc limit ?, ?; func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, where FindGroupHavingOrderDescLimitOffsetWhereParameter, having FindGroupHavingOrderDescLimitOffsetHavingParameter, limit FindGroupHavingOrderDescLimitOffsetLimitParameter) (result []*User, err error) { - b := builder.Select(`*`) + b := builder.MySQL() + b.Select(`*`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) @@ -592,6 +641,10 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe b.OrderBy(`id desc`) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -620,11 +673,16 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe // select `name`, `password`, `mobile` from `user` where id > ? limit 1; func (m *UserModel) FindOnePart(ctx context.Context, where FindOnePartWhereParameter) (result *User, err error) { result = new(User) - b := builder.Select(`name, password, mobile`) + b := builder.MySQL() + b.Select(`name, password, mobile`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -654,10 +712,15 @@ func (m *UserModel) FindOnePart(ctx context.Context, where FindOnePartWhereParam // select count(id) AS countID from `user`; func (m *UserModel) FindAllCount(ctx context.Context) (result *FindAllCountResult, err error) { result = new(FindAllCountResult) - b := builder.Select(`count(id) AS countID`) + b := builder.MySQL() + b.Select(`count(id) AS countID`) b.From("`user`") b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -687,11 +750,16 @@ func (m *UserModel) FindAllCount(ctx context.Context) (result *FindAllCountResul // select count(id) AS countID from `user` where id > ?; func (m *UserModel) FindAllCountWhere(ctx context.Context, where FindAllCountWhereWhereParameter) (result *FindAllCountWhereResult, err error) { result = new(FindAllCountWhereResult) - b := builder.Select(`count(id) AS countID`) + b := builder.MySQL() + b.Select(`count(id) AS countID`) b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -721,10 +789,15 @@ func (m *UserModel) FindAllCountWhere(ctx context.Context, where FindAllCountWhe // select max(id) AS maxID from `user`; func (m *UserModel) FindMaxID(ctx context.Context) (result *FindMaxIDResult, err error) { result = new(FindMaxIDResult) - b := builder.Select(`max(id) AS maxID`) + b := builder.MySQL() + b.Select(`max(id) AS maxID`) b.From("`user`") b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -754,10 +827,15 @@ func (m *UserModel) FindMaxID(ctx context.Context) (result *FindMaxIDResult, err // select min(id) AS minID from `user`; func (m *UserModel) FindMinID(ctx context.Context) (result *FindMinIDResult, err error) { result = new(FindMinIDResult) - b := builder.Select(`min(id) AS minID`) + b := builder.MySQL() + b.Select(`min(id) AS minID`) b.From("`user`") b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -787,10 +865,15 @@ func (m *UserModel) FindMinID(ctx context.Context) (result *FindMinIDResult, err // select avg(id) AS avgID from `user`; func (m *UserModel) FindAvgID(ctx context.Context) (result *FindAvgIDResult, err error) { result = new(FindAvgIDResult) - b := builder.Select(`avg(id) AS avgID`) + b := builder.MySQL() + b.Select(`avg(id) AS avgID`) b.From("`user`") b.Limit(1) query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -819,7 +902,8 @@ func (m *UserModel) FindAvgID(ctx context.Context) (result *FindAvgIDResult, err // Update is generated from sql: // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ?; func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWhereParameter) error { - b := builder.Update(builder.Eq{ + b := builder.MySQL() + b.Update(builder.Eq{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -842,7 +926,8 @@ func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWherePar // UpdateOrderByIdDesc is generated from sql: // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where UpdateOrderByIdDescWhereParameter) error { - b := builder.Update(builder.Eq{ + b := builder.MySQL() + b.Update(builder.Eq{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -866,7 +951,8 @@ func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where U // UpdateOrderByIdDescLimitCount is generated from sql: // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter) error { - b := builder.Update(builder.Eq{ + b := builder.MySQL() + b.Update(builder.Eq{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -890,7 +976,8 @@ func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *Use // DeleteOne is generated from sql: // delete from `user` where `id` = ?; func (m *UserModel) DeleteOne(ctx context.Context, where DeleteOneWhereParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`id = ?`, where.IdEqual)) query, args, err := b.ToSQL() @@ -904,7 +991,8 @@ func (m *UserModel) DeleteOne(ctx context.Context, where DeleteOneWhereParameter // DeleteOneByName is generated from sql: // delete from `user` where `name` = ?; func (m *UserModel) DeleteOneByName(ctx context.Context, where DeleteOneByNameWhereParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) query, args, err := b.ToSQL() @@ -918,7 +1006,8 @@ func (m *UserModel) DeleteOneByName(ctx context.Context, where DeleteOneByNameWh // DeleteOneOrderByIDAsc is generated from sql: // delete from `user` where `name` = ? order by id; func (m *UserModel) DeleteOneOrderByIDAsc(ctx context.Context, where DeleteOneOrderByIDAscWhereParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.OrderBy(`id`) @@ -933,7 +1022,8 @@ func (m *UserModel) DeleteOneOrderByIDAsc(ctx context.Context, where DeleteOneOr // DeleteOneOrderByIDDesc is generated from sql: // delete from `user` where `name` = ? order by id desc; func (m *UserModel) DeleteOneOrderByIDDesc(ctx context.Context, where DeleteOneOrderByIDDescWhereParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.OrderBy(`id desc`) @@ -948,7 +1038,8 @@ func (m *UserModel) DeleteOneOrderByIDDesc(ctx context.Context, where DeleteOneO // DeleteOneOrderByIDDescLimitCount is generated from sql: // delete from `user` where `name` = ? order by id desc limit ?; func (m *UserModel) DeleteOneOrderByIDDescLimitCount(ctx context.Context, where DeleteOneOrderByIDDescLimitCountWhereParameter, limit DeleteOneOrderByIDDescLimitCountLimitParameter) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.OrderBy(`id desc`) diff --git a/example/xorm/user_model.go b/example/xorm/user_model.go index 8528c27..b7ce929 100644 --- a/example/xorm/user_model.go +++ b/example/xorm/user_model.go @@ -3,6 +3,6 @@ package model import "context" // TODO(sqlgen): Add your own customize code here. -func (m *UserModel) Customize(ctx context.Context, args ...any) { +func (m *UserModel) Customize(ctx context.Context, args ...interface{}) { } diff --git a/internal/gen/sql/sql_gen.tpl b/internal/gen/sql/sql_gen.tpl index cd293f4..69552c6 100644 --- a/internal/gen/sql/sql_gen.tpl +++ b/internal/gen/sql/sql_gen.tpl @@ -14,7 +14,7 @@ import ( // {{UpperCamel $.Table.Name}}Model represents a {{$.Table.Name}} model. type {{UpperCamel $.Table.Name}}Model struct { - db *sql.Conn + db *sql.DB scanner Scanner } @@ -40,7 +40,7 @@ type {{UpperCamel $.Table.Name}} struct { {{range $.Table.Columns}} // New{{UpperCamel $.Table.Name}}Model creates a new {{$.Table.Name}} model. -func New{{UpperCamel $.Table.Name}}Model(db *sql.Conn, scanner Scanner) *{{UpperCamel $.Table.Name}}Model { +func New{{UpperCamel $.Table.Name}}Model(db *sql.DB, scanner Scanner) *{{UpperCamel $.Table.Name}}Model { return &{{UpperCamel $.Table.Name}}Model{ db: db, scanner: scanner, @@ -48,19 +48,17 @@ func New{{UpperCamel $.Table.Name}}Model(db *sql.Conn, scanner Scanner) *{{Upper } // Create creates {{$.Table.Name}} data. -func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...*{{UpperCamel $.Table.Name}}) (err error) { +func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...*{{UpperCamel $.Table.Name}}) error { if len(data) == 0 { return fmt.Errorf("data is empty") } var stmt *sql.Stmt - stmt, err = m.db.PrepareContext(ctx, "INSERT INTO {{$.Table.Name}} ({{InsertSQL}}) VALUES ({{InsertQuotes}})") + stmt, err := m.db.PrepareContext(ctx, "INSERT INTO {{$.Table.Name}} ({{InsertSQL}}) VALUES ({{InsertQuotes}})") if err != nil { - return + return err } - defer func() { - err = stmt.Close() - }() + defer stmt.Close() for _, v := range data { result, err := stmt.ExecContext(ctx, {{InsertValues "v"}}) if err != nil { @@ -74,14 +72,15 @@ func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...* {{range $.Table.Columns}}{{if IsPrimary .Name}}{{if .AutoIncrement}}v.{{UpperCamel .Name}} = {{.GoType}}(id){{end}}{{end}}{{end}} } - return + return nil } {{range $stmt := .SelectStmt}} // {{.FuncName}} is generated from sql: // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Having.IsValid}}, having {{$stmt.Having.ParameterStructureName "Having"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}})(result {{if $stmt.Limit.One}}*{{$stmt.ReceiverName}}, {{else}}[]*{{$stmt.ReceiverName}}, {{end}} err error){ {{if $stmt.Limit.One}} result = new({{$stmt.ReceiverName}}){{end}} - b := builder.Select(`{{$stmt.SelectSQL}}`) + b := builder.MySQL() + b.Select(`{{$stmt.SelectSQL}}`) b.From("`{{$.Table.Name}}`") {{if $stmt.Where.IsValid}}b.Where(builder.Expr({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}})) {{end }}{{if $stmt.GroupBy.IsValid}}b.GroupBy({{$stmt.GroupBy.SQL}}) @@ -89,6 +88,10 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if {{end}}{{if $stmt.OrderBy.IsValid}}b.OrderBy({{$stmt.OrderBy.SQL}}) {{end}}{{if $stmt.Limit.IsValid}}b.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}}) {{end}}query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + {{if $stmt.Limit.One}}row := m.db.QueryRowContext(ctx, query, args...) if err = row.Err(); err != nil { return nil, err @@ -117,7 +120,8 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if // {{.FuncName}} is generated from sql: // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, data *{{UpperCamel $.Table.Name}}{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { - b := builder.Update(builder.Eq{ + b := builder.MySQL() + b.Update(builder.Eq{ {{range $name := $stmt.Columns}}"{{$name}}": data.{{UpperCamel $name}}, {{end}} }) @@ -138,7 +142,8 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, dat // {{.FuncName}} is generated from sql: // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`{{$.Table.Name}}`") {{if $stmt.Where.IsValid}}b.Where(builder.Expr({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}})) {{end}}{{if $stmt.OrderBy.IsValid}}b.OrderBy({{$stmt.OrderBy.SQL}}) diff --git a/internal/gen/sqlx/sqlx_gen.tpl b/internal/gen/sqlx/sqlx_gen.tpl index ea86f0a..20020de 100644 --- a/internal/gen/sqlx/sqlx_gen.tpl +++ b/internal/gen/sqlx/sqlx_gen.tpl @@ -47,19 +47,17 @@ func New{{UpperCamel $.Table.Name}}Model(db *sqlx.DB) *{{UpperCamel $.Table.Name } // Create creates {{$.Table.Name}} data. -func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...*{{UpperCamel $.Table.Name}}) (err error) { +func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...*{{UpperCamel $.Table.Name}}) error { if len(data) == 0 { return fmt.Errorf("data is empty") } var stmt *sql.Stmt - stmt, err = m.db.PrepareContext(ctx, "INSERT INTO {{$.Table.Name}} ({{InsertSQL}}) VALUES ({{InsertQuotes}})") + stmt, err := m.db.PrepareContext(ctx, "INSERT INTO {{$.Table.Name}} ({{InsertSQL}}) VALUES ({{InsertQuotes}})") if err != nil { - return + return err } - defer func() { - err = stmt.Close() - }() + defer stmt.Close() for _, v := range data { result, err := stmt.ExecContext(ctx, {{InsertValues "v"}}) if err != nil { @@ -73,14 +71,15 @@ func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...* {{range $.Table.Columns}}{{if IsPrimary .Name}}{{if .AutoIncrement}}v.{{UpperCamel .Name}} = {{.GoType}}(id){{end}}{{end}}{{end}} } - return + return nil } {{range $stmt := .SelectStmt}} // {{.FuncName}} is generated from sql: // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Having.IsValid}}, having {{$stmt.Having.ParameterStructureName "Having"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}})(result {{if $stmt.Limit.One}}*{{$stmt.ReceiverName}}, {{else}}[]*{{$stmt.ReceiverName}}, {{end}} err error){ {{if $stmt.Limit.One}} result = new({{$stmt.ReceiverName}}){{end}} - b := builder.Select(`{{$stmt.SelectSQL}}`) + b := builder.MySQL() + b.Select(`{{$stmt.SelectSQL}}`) b.From("`{{$.Table.Name}}`") {{if $stmt.Where.IsValid}}b.Where(builder.Expr({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}})) {{end }}{{if $stmt.GroupBy.IsValid}}b.GroupBy({{$stmt.GroupBy.SQL}}) @@ -88,6 +87,10 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if {{end}}{{if $stmt.OrderBy.IsValid}}b.OrderBy({{$stmt.OrderBy.SQL}}) {{end}}{{if $stmt.Limit.IsValid}}b.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}}) {{end}}query, args, err := b.ToSQL() + if err != nil { + return nil, err + } + var rows *sqlx.Rows rows, err = m.db.QueryxContext(ctx, query, args...) if err != nil { @@ -118,7 +121,8 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if // {{.FuncName}} is generated from sql: // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, data *{{UpperCamel $.Table.Name}}{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { - b := builder.Update(builder.Eq{ + b := builder.MySQL() + b.Update(builder.Eq{ {{range $name := $stmt.Columns}}"{{$name}}": data.{{UpperCamel $name}}, {{end}} }) @@ -139,7 +143,8 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, dat // {{.FuncName}} is generated from sql: // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { - b := builder.Delete() + b := builder.MySQL() + b.Delete() b.From("`{{$.Table.Name}}`") {{if $stmt.Where.IsValid}}b.Where(builder.Expr({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}})) {{end}}{{if $stmt.OrderBy.IsValid}}b.OrderBy({{$stmt.OrderBy.SQL}}) From a53d18adb67404f779a24041b86aa926181a234b Mon Sep 17 00:00:00 2001 From: anqiansong Date: Tue, 16 Aug 2022 00:14:13 +0800 Subject: [PATCH 04/10] Update readme --- example/example_test/readme.md | 1 - 1 file changed, 1 deletion(-) diff --git a/example/example_test/readme.md b/example/example_test/readme.md index 4ff62db..930af29 100644 --- a/example/example_test/readme.md +++ b/example/example_test/readme.md @@ -18,7 +18,6 @@ CREATE TABLE `user` `create_at` timestamp NULL, `update_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, UNIQUE KEY `name_index` (`name`), - UNIQUE KEY `type_index` (`type`), UNIQUE KEY `mobile_index` (`mobile`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT 'user table' COLLATE=utf8mb4_general_ci; ``` From 2b98b49a06856e22e7a371e094e4cce1a33f55d3 Mon Sep 17 00:00:00 2001 From: anqiansong Date: Tue, 16 Aug 2022 14:23:35 +0800 Subject: [PATCH 05/10] Add sql example_test --- example/example_test/sql/scanner.go | 90 ++++++++++++------- .../example_test/sql/user_model.gen_test.go | 88 ++++++++++++++++++ example/sql/user_model.gen.go | 22 ++--- example/sqlx/user_model.gen.go | 8 +- example/xorm/user_model.gen.go | 8 +- internal/gen/sql/sql.go | 2 +- internal/gen/sql/sql_gen.tpl | 2 +- internal/gen/sqlx/sqlx.go | 2 +- internal/gen/xorm/xorm.go | 2 +- 9 files changed, 168 insertions(+), 56 deletions(-) diff --git a/example/example_test/sql/scanner.go b/example/example_test/sql/scanner.go index 018f31d..c2477f7 100644 --- a/example/example_test/sql/scanner.go +++ b/example/example_test/sql/scanner.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "reflect" + "strings" model "github.com/anqiansong/sqlgen/example/sql" ) @@ -32,28 +33,63 @@ func (c customScanner) getRowElem(v interface{}) ([]interface{}, error) { } } -func (c customScanner) getRowsElem(v interface{}) ([][]interface{}, error) { - value := reflect.ValueOf(v) - elem := value.Elem() - switch elem.Kind() { - case reflect.Pointer: - return c.getRowsElem(elem.Elem()) - case reflect.Slice: - var list [][]interface{} - for i := 0; i < elem.NumField(); i++ { - f := elem.Field(i) - item := f.Elem() - rowElem, err := c.getRowsElem(item) - if err != nil { - return nil, err - } +func (c customScanner) getRowsElem(rows *sql.Rows, v interface{}) error { + tp := reflect.TypeOf(v) + if tp.Kind() != reflect.Pointer { + return errors.New("expected a pointer") + } + sliceTp := tp.Elem() + if sliceTp.Kind() != reflect.Slice { + return errors.New("expected a slice") + } + + sliceValue := reflect.Indirect(reflect.ValueOf(v)) + itemType := sliceTp.Elem() + cols, err := rows.Columns() + if err != nil { + return err + } + + for rows.Next() { + item := reflect.New(itemType.Elem()).Elem() + dest := structPointers(item.Elem(), cols) - list = append(list, rowElem...) + err := rows.Scan(dest...) + if err != nil { + return err } - return list, nil - default: - return nil, errors.New("expect a struct") + sliceValue.Set(reflect.Append(sliceValue, item)) + } + + return rows.Err() +} + +func fieldByName(v reflect.Value, name string) reflect.Value { + typ := v.Type() + + for i := 0; i < v.NumField(); i++ { + tag, ok := typ.Field(i).Tag.Lookup("db") + if ok && tag == name { + return v.Field(i) + } + } + + return v.FieldByName(strings.Title(name)) +} + +func structPointers(stct reflect.Value, cols []string) []interface{} { + pointers := make([]interface{}, 0, len(cols)) + for _, colName := range cols { + fieldVal := fieldByName(stct, colName) + if !fieldVal.IsValid() || !fieldVal.CanSet() { + var nothing interface{} + pointers = append(pointers, ¬hing) + continue + } + + pointers = append(pointers, fieldVal.Addr().Interface()) } + return pointers } func (c customScanner) ScanRow(row *sql.Row, v interface{}) error { @@ -65,20 +101,8 @@ func (c customScanner) ScanRow(row *sql.Row, v interface{}) error { } func (c customScanner) ScanRows(rows *sql.Rows, v interface{}) error { - dests, err := c.getRowsElem(v) - if err != nil { - return err - } - var i int - for rows.Next() && i < len(dests) { - dest := dests[i] - err = rows.Scan(dest) - if err != nil { - return err - } - i++ - } - return nil + //return scan.Rows(v, rows) + return c.getRowsElem(rows, v) } func getScanner() model.Scanner { diff --git a/example/example_test/sql/user_model.gen_test.go b/example/example_test/sql/user_model.gen_test.go index 5a934e3..1f883f7 100644 --- a/example/example_test/sql/user_model.gen_test.go +++ b/example/example_test/sql/user_model.gen_test.go @@ -90,7 +90,87 @@ func TestFindOne(t *testing.T) { assert.NoError(t, err) assertUserEqual(t, mockUser, actual) })) +} + +func TestFindOneByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOneGroupByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByNameHavingName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: "foo"}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, sql.ErrNoRows) + mockUser := mustMockUser() + err = um.Create(ctx, mockUser) + assert.NoError(t, err) + _, err = um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, sql.ErrNoRows) + + })) + + t.Run("FindOneGroupByNameHavingName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: mockUser.Name, + }) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindAll(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindAll", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list, actual) + })) } func assertUserEqual(t *testing.T, expected, actual *model.User) { @@ -102,6 +182,14 @@ func assertUserEqual(t *testing.T, expected, actual *model.User) { assert.Equal(t, *expected, *actual) } +func assertUsersEqual(t *testing.T, expected, actual []*model.User) { + assert.Equal(t, len(expected), len(actual)) + for idx, expectedItem := range expected { + actual := actual[idx] + assertUserEqual(t, expectedItem, actual) + } +} + func initAndRun(f func(t *testing.T)) func(t *testing.T) { mustInitDB(db) return func(t *testing.T) { diff --git a/example/sql/user_model.gen.go b/example/sql/user_model.gen.go index 774f762..992aba5 100644 --- a/example/sql/user_model.gen.go +++ b/example/sql/user_model.gen.go @@ -325,7 +325,7 @@ func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where Find b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.GroupBy(`name`) - b.Having(fmt.Sprintf(`name = %v`, having.NameEqual)) + b.Having(fmt.Sprintf(`name = '%v'`, having.NameEqual)) b.Limit(1) query, args, err := b.ToSQL() if err != nil { @@ -363,7 +363,7 @@ func (m *UserModel) FindAll(ctx context.Context) (result []*User, err error) { result = nil } }() - if err = m.scanner.ScanRows(rows, result); err != nil { + if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } return result, nil @@ -393,7 +393,7 @@ func (m *UserModel) FindLimit(ctx context.Context, where FindLimitWhereParameter result = nil } }() - if err = m.scanner.ScanRows(rows, result); err != nil { + if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } return result, nil @@ -422,7 +422,7 @@ func (m *UserModel) FindLimitOffset(ctx context.Context, limit FindLimitOffsetLi result = nil } }() - if err = m.scanner.ScanRows(rows, result); err != nil { + if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } return result, nil @@ -453,7 +453,7 @@ func (m *UserModel) FindGroupLimitOffset(ctx context.Context, where FindGroupLim result = nil } }() - if err = m.scanner.ScanRows(rows, result); err != nil { + if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } return result, nil @@ -467,7 +467,7 @@ func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGr b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) - b.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + b.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() if err != nil { @@ -485,7 +485,7 @@ func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGr result = nil } }() - if err = m.scanner.ScanRows(rows, result); err != nil { + if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } return result, nil @@ -499,7 +499,7 @@ func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, wher b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) - b.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + b.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) b.OrderBy(`id`) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() @@ -518,7 +518,7 @@ func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, wher result = nil } }() - if err = m.scanner.ScanRows(rows, result); err != nil { + if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } return result, nil @@ -532,7 +532,7 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) - b.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + b.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) b.OrderBy(`id desc`) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() @@ -551,7 +551,7 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe result = nil } }() - if err = m.scanner.ScanRows(rows, result); err != nil { + if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } return result, nil diff --git a/example/sqlx/user_model.gen.go b/example/sqlx/user_model.gen.go index 455fd33..858ee9f 100644 --- a/example/sqlx/user_model.gen.go +++ b/example/sqlx/user_model.gen.go @@ -372,7 +372,7 @@ func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where Find b.From("`user`") b.Where(builder.Expr(`name = ?`, where.NameEqual)) b.GroupBy(`name`) - b.Having(fmt.Sprintf(`name = %v`, having.NameEqual)) + b.Having(fmt.Sprintf(`name = '%v'`, having.NameEqual)) b.Limit(1) query, args, err := b.ToSQL() if err != nil { @@ -558,7 +558,7 @@ func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGr b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) - b.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + b.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() if err != nil { @@ -597,7 +597,7 @@ func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, wher b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) - b.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + b.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) b.OrderBy(`id`) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() @@ -637,7 +637,7 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe b.From("`user`") b.Where(builder.Expr(`id > ?`, where.IdGT)) b.GroupBy(`name`) - b.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + b.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) b.OrderBy(`id desc`) b.Limit(limit.Count, limit.Offset) query, args, err := b.ToSQL() diff --git a/example/xorm/user_model.gen.go b/example/xorm/user_model.gen.go index 257a312..7ee7d5e 100644 --- a/example/xorm/user_model.gen.go +++ b/example/xorm/user_model.gen.go @@ -278,7 +278,7 @@ func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where Find session.Select(`*`) session.Where(`name = ?`, where.NameEqual) session.GroupBy(`name`) - session.Having(fmt.Sprintf(`name = %v`, having.NameEqual)) + session.Having(fmt.Sprintf(`name = '%v'`, having.NameEqual)) session.Limit(1) _, err := session.Get(result) return result, err @@ -338,7 +338,7 @@ func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGr session.Select(`*`) session.Where(`id > ?`, where.IdGT) session.GroupBy(`name`) - session.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + session.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) session.Limit(limit.Count, limit.Offset) err := session.Find(&result) return result, err @@ -352,7 +352,7 @@ func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, wher session.Select(`*`) session.Where(`id > ?`, where.IdGT) session.GroupBy(`name`) - session.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + session.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) session.OrderBy(`id`) session.Limit(limit.Count, limit.Offset) err := session.Find(&result) @@ -367,7 +367,7 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe session.Select(`*`) session.Where(`id > ?`, where.IdGT) session.GroupBy(`name`) - session.Having(fmt.Sprintf(`id > %v`, having.IdGT)) + session.Having(fmt.Sprintf(`id > '%v'`, having.IdGT)) session.OrderBy(`id desc`) session.Limit(limit.Count, limit.Offset) err := session.Find(&result) diff --git a/internal/gen/sql/sql.go b/internal/gen/sql/sql.go index 8816ac8..6bc70e3 100644 --- a/internal/gen/sql/sql.go +++ b/internal/gen/sql/sql.go @@ -62,7 +62,7 @@ func Run(list []spec.Context, output string) error { return strings.Join(values, ", ") }, "HavingSprintf": func(format string) string { - format = strings.ReplaceAll(format, "?", "%v") + format = strings.ReplaceAll(format, "?", "'%v'") return format }, }) diff --git a/internal/gen/sql/sql_gen.tpl b/internal/gen/sql/sql_gen.tpl index 69552c6..c46f8a8 100644 --- a/internal/gen/sql/sql_gen.tpl +++ b/internal/gen/sql/sql_gen.tpl @@ -109,7 +109,7 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if result = nil } }() - if err = m.scanner.ScanRows(rows, result); err != nil{ + if err = m.scanner.ScanRows(rows, &result); err != nil{ return nil, err } return result, nil{{end}} diff --git a/internal/gen/sqlx/sqlx.go b/internal/gen/sqlx/sqlx.go index 1c87628..61b3dda 100644 --- a/internal/gen/sqlx/sqlx.go +++ b/internal/gen/sqlx/sqlx.go @@ -52,7 +52,7 @@ func Run(list []spec.Context, output string) error { return strings.Join(values, ", ") }, "HavingSprintf": func(format string) string { - format = strings.ReplaceAll(format, "?", "%v") + format = strings.ReplaceAll(format, "?", "'%v'") return format }, }) diff --git a/internal/gen/xorm/xorm.go b/internal/gen/xorm/xorm.go index caa8156..bc950c3 100644 --- a/internal/gen/xorm/xorm.go +++ b/internal/gen/xorm/xorm.go @@ -27,7 +27,7 @@ func Run(list []spec.Context, output string) error { return ctx.Table.IsPrimary(name) }, "HavingSprintf": func(format string) string { - format = strings.ReplaceAll(format, "?", "%v") + format = strings.ReplaceAll(format, "?", "'%v'") return format }, }) From 50a6f35d8330f6aeb0bcbbb3b6f30b42421759fc Mon Sep 17 00:00:00 2001 From: anqiansong Date: Tue, 16 Aug 2022 21:21:57 +0800 Subject: [PATCH 06/10] Add example_text --- example/example_test/NOTES | 21 ++++ example/example_test/sql/scanner.go | 86 ++++++------- .../example_test/sql/user_model.gen_test.go | 114 +++++++++++++++++ example/sql/user_model.gen.go | 49 ++------ example/sqlx/user_model.gen.go | 119 +++--------------- internal/gen/sql/sql_gen.tpl | 7 +- internal/gen/sqlx/sqlx_gen.tpl | 7 +- 7 files changed, 202 insertions(+), 201 deletions(-) create mode 100644 example/example_test/NOTES diff --git a/example/example_test/NOTES b/example/example_test/NOTES new file mode 100644 index 0000000..c214b87 --- /dev/null +++ b/example/example_test/NOTES @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 zeromicro + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/example/example_test/sql/scanner.go b/example/example_test/sql/scanner.go index c2477f7..32b1c4f 100644 --- a/example/example_test/sql/scanner.go +++ b/example/example_test/sql/scanner.go @@ -4,7 +4,6 @@ import ( "database/sql" "errors" "reflect" - "strings" model "github.com/anqiansong/sqlgen/example/sql" ) @@ -13,11 +12,15 @@ type customScanner struct { } func (c customScanner) getRowElem(v interface{}) ([]interface{}, error) { + var elem reflect.Value value, ok := v.(reflect.Value) if !ok { + elem = value.Elem() value = reflect.ValueOf(v) + } else { + elem = value } - elem := value.Elem() + switch elem.Kind() { case reflect.Pointer: return c.getRowElem(elem.Elem()) @@ -33,63 +36,56 @@ func (c customScanner) getRowElem(v interface{}) ([]interface{}, error) { } } +// getRowsElem is inspired by https://github.com/zeromicro/go-zero/blob/8ed22eafdda04c4526164450d7c13c2f4b0f076c/core/stores/sqlx/orm.go#L163 func (c customScanner) getRowsElem(rows *sql.Rows, v interface{}) error { - tp := reflect.TypeOf(v) - if tp.Kind() != reflect.Pointer { - return errors.New("expected a pointer") - } - sliceTp := tp.Elem() - if sliceTp.Kind() != reflect.Slice { - return errors.New("expected a slice") + valueOf := reflect.ValueOf(v) + if valueOf.Kind() != reflect.Ptr { + return errors.New("expect a pointer") } - sliceValue := reflect.Indirect(reflect.ValueOf(v)) - itemType := sliceTp.Elem() - cols, err := rows.Columns() - if err != nil { - return err + typeOf := reflect.TypeOf(v) + sliceTypeOf := typeOf.Elem() + sliceValueOf := valueOf.Elem() + + if sliceTypeOf.Kind() != reflect.Slice { + return errors.New("expect a slice") + } + if !sliceValueOf.CanSet() { + return errors.New("expect a settable slice") + } + isASlicePointer := sliceTypeOf.Elem().Kind() == reflect.Ptr + + var itemReceiver reflect.Type + itemType := sliceTypeOf.Elem() + if itemType.Kind() == reflect.Ptr { + itemReceiver = itemType.Elem() + } else { + itemReceiver = itemType + } + if itemReceiver.Kind() != reflect.Struct { + return errors.New("expect a struct") } for rows.Next() { - item := reflect.New(itemType.Elem()).Elem() - dest := structPointers(item.Elem(), cols) - - err := rows.Scan(dest...) + value := reflect.New(itemReceiver) + dest, err := c.getRowElem(value) if err != nil { return err } - sliceValue.Set(reflect.Append(sliceValue, item)) - } - return rows.Err() -} - -func fieldByName(v reflect.Value, name string) reflect.Value { - typ := v.Type() - - for i := 0; i < v.NumField(); i++ { - tag, ok := typ.Field(i).Tag.Lookup("db") - if ok && tag == name { - return v.Field(i) + err = rows.Scan(dest...) + if err != nil { + return err } - } - - return v.FieldByName(strings.Title(name)) -} -func structPointers(stct reflect.Value, cols []string) []interface{} { - pointers := make([]interface{}, 0, len(cols)) - for _, colName := range cols { - fieldVal := fieldByName(stct, colName) - if !fieldVal.IsValid() || !fieldVal.CanSet() { - var nothing interface{} - pointers = append(pointers, ¬hing) - continue + if isASlicePointer { + sliceValueOf.Set(reflect.Append(sliceValueOf, value)) + } else { + sliceValueOf.Set(reflect.Append(sliceValueOf, reflect.Indirect(sliceValueOf))) } - - pointers = append(pointers, fieldVal.Addr().Interface()) } - return pointers + + return nil } func (c customScanner) ScanRow(row *sql.Row, v interface{}) error { diff --git a/example/example_test/sql/user_model.gen_test.go b/example/example_test/sql/user_model.gen_test.go index 1f883f7..e2f0496 100644 --- a/example/example_test/sql/user_model.gen_test.go +++ b/example/example_test/sql/user_model.gen_test.go @@ -173,6 +173,120 @@ func TestFindAll(t *testing.T) { })) } +func TestFindLimit(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 1, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimit", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) + + t.Run("FindLimitOffset1", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindGroupHavingLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) + + t.Run("FindGroupHavingLimitOffset1", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[1:3], actual) + })) +} + func assertUserEqual(t *testing.T, expected, actual *model.User) { now := time.Now() expected.CreateAt = now diff --git a/example/sql/user_model.gen.go b/example/sql/user_model.gen.go index 992aba5..a12140e 100644 --- a/example/sql/user_model.gen.go +++ b/example/sql/user_model.gen.go @@ -357,12 +357,7 @@ func (m *UserModel) FindAll(ctx context.Context) (result []*User, err error) { if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } @@ -387,12 +382,7 @@ func (m *UserModel) FindLimit(ctx context.Context, where FindLimitWhereParameter if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } @@ -416,12 +406,7 @@ func (m *UserModel) FindLimitOffset(ctx context.Context, limit FindLimitOffsetLi if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } @@ -447,12 +432,7 @@ func (m *UserModel) FindGroupLimitOffset(ctx context.Context, where FindGroupLim if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } @@ -479,12 +459,7 @@ func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGr if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } @@ -512,12 +487,7 @@ func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, wher if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } @@ -545,12 +515,7 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } diff --git a/example/sqlx/user_model.gen.go b/example/sqlx/user_model.gen.go index 858ee9f..90f2576 100644 --- a/example/sqlx/user_model.gen.go +++ b/example/sqlx/user_model.gen.go @@ -264,12 +264,7 @@ func (m *UserModel) FindOne(ctx context.Context, where FindOneWhereParameter) (r if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -303,12 +298,7 @@ func (m *UserModel) FindOneByName(ctx context.Context, where FindOneByNameWhereP if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -343,12 +333,7 @@ func (m *UserModel) FindOneGroupByName(ctx context.Context, where FindOneGroupBy if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -384,12 +369,7 @@ func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where Find if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -420,12 +400,7 @@ func (m *UserModel) FindAll(ctx context.Context) (result []*User, err error) { if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -457,12 +432,7 @@ func (m *UserModel) FindLimit(ctx context.Context, where FindLimitWhereParameter if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -493,12 +463,7 @@ func (m *UserModel) FindLimitOffset(ctx context.Context, limit FindLimitOffsetLi if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -531,12 +496,7 @@ func (m *UserModel) FindGroupLimitOffset(ctx context.Context, where FindGroupLim if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -570,12 +530,7 @@ func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGr if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -610,12 +565,7 @@ func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, wher if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -650,12 +600,7 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -688,12 +633,7 @@ func (m *UserModel) FindOnePart(ctx context.Context, where FindOnePartWhereParam if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v User @@ -726,12 +666,7 @@ func (m *UserModel) FindAllCount(ctx context.Context) (result *FindAllCountResul if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v FindAllCountResult @@ -765,12 +700,7 @@ func (m *UserModel) FindAllCountWhere(ctx context.Context, where FindAllCountWhe if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v FindAllCountWhereResult @@ -803,12 +733,7 @@ func (m *UserModel) FindMaxID(ctx context.Context) (result *FindMaxIDResult, err if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v FindMaxIDResult @@ -841,12 +766,7 @@ func (m *UserModel) FindMinID(ctx context.Context) (result *FindMinIDResult, err if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v FindMinIDResult @@ -879,12 +799,7 @@ func (m *UserModel) FindAvgID(ctx context.Context) (result *FindAvgIDResult, err if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v FindAvgIDResult diff --git a/internal/gen/sql/sql_gen.tpl b/internal/gen/sql/sql_gen.tpl index c46f8a8..f2d8837 100644 --- a/internal/gen/sql/sql_gen.tpl +++ b/internal/gen/sql/sql_gen.tpl @@ -103,12 +103,7 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() if err = m.scanner.ScanRows(rows, &result); err != nil{ return nil, err } diff --git a/internal/gen/sqlx/sqlx_gen.tpl b/internal/gen/sqlx/sqlx_gen.tpl index 20020de..82c95c6 100644 --- a/internal/gen/sqlx/sqlx_gen.tpl +++ b/internal/gen/sqlx/sqlx_gen.tpl @@ -96,12 +96,7 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if if err != nil { return nil, err } - defer func() { - err = rows.Close() - if err != nil { - result = nil - } - }() + defer rows.Close() for rows.Next() { var v {{$stmt.ReceiverName}} From ea8bdc7106a1443c067c38e4938da863b788ed80 Mon Sep 17 00:00:00 2001 From: anqiansong Date: Tue, 16 Aug 2022 23:44:09 +0800 Subject: [PATCH 07/10] Add example test --- example/bun/user_model.gen.go | 11 +- example/example_test/sql/scanner.go | 52 ++++-- .../example_test/sql/user_model.gen_test.go | 129 ++++++++++++- example/go.mod | 1 + example/go.sum | 2 + example/gorm/user_model.gen.go | 11 +- example/sql/scanner.go | 4 +- example/sql/user_model.gen.go | 165 +++++++++++------ example/sqlx/user_model.gen.go | 10 +- example/xorm/user_model.gen.go | 11 +- internal/gen/sql/scanner.tpl | 6 +- internal/gen/sql/sql_gen.tpl | 15 +- internal/parser/funcmap.go | 169 +++++++++--------- internal/parser/select.go | 12 +- internal/spec/converter.go | 6 +- internal/spec/stmt.go | 7 +- internal/spec/table.go | 3 +- internal/spec/type.go | 71 +++++++- 18 files changed, 479 insertions(+), 206 deletions(-) diff --git a/example/bun/user_model.gen.go b/example/bun/user_model.gen.go index 48617ef..0d167dd 100644 --- a/example/bun/user_model.gen.go +++ b/example/bun/user_model.gen.go @@ -4,6 +4,7 @@ package model import ( "context" + "database/sql" "fmt" "time" @@ -137,7 +138,7 @@ type FindOnePartWhereParameter struct { // FindAllCountResult is a find all count result. type FindAllCountResult struct { bun.BaseModel `bun:"table:user"` - CountID uint64 `bun:"countID" json:"countID"` + CountID sql.NullInt64 `bun:"countID" json:"countID"` } // FindAllCountWhereWhereParameter is a where parameter structure. @@ -148,25 +149,25 @@ type FindAllCountWhereWhereParameter struct { // FindAllCountWhereResult is a find all count where result. type FindAllCountWhereResult struct { bun.BaseModel `bun:"table:user"` - CountID uint64 `bun:"countID" json:"countID"` + CountID sql.NullInt64 `bun:"countID" json:"countID"` } // FindMaxIDResult is a find max id result. type FindMaxIDResult struct { bun.BaseModel `bun:"table:user"` - MaxID uint64 `bun:"maxID" json:"maxID"` + MaxID sql.NullInt64 `bun:"maxID" json:"maxID"` } // FindMinIDResult is a find min id result. type FindMinIDResult struct { bun.BaseModel `bun:"table:user"` - MinID uint64 `bun:"minID" json:"minID"` + MinID sql.NullInt64 `bun:"minID" json:"minID"` } // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { bun.BaseModel `bun:"table:user"` - AvgID uint64 `bun:"avgID" json:"avgID"` + AvgID sql.NullInt64 `bun:"avgID" json:"avgID"` } // UpdateWhereParameter is a where parameter structure. diff --git a/example/example_test/sql/scanner.go b/example/example_test/sql/scanner.go index 32b1c4f..253653c 100644 --- a/example/example_test/sql/scanner.go +++ b/example/example_test/sql/scanner.go @@ -6,29 +6,57 @@ import ( "reflect" model "github.com/anqiansong/sqlgen/example/sql" + "github.com/iancoleman/strcase" ) type customScanner struct { } -func (c customScanner) getRowElem(v interface{}) ([]interface{}, error) { +func (c customScanner) ColumnMapper(colName string) string { + return strcase.ToCamel(colName) +} + +func (c customScanner) TagKey() string { + return `db` +} + +func (c customScanner) getRowElem(rows *sql.Rows, v interface{}) ([]interface{}, error) { var elem reflect.Value value, ok := v.(reflect.Value) if !ok { - elem = value.Elem() - value = reflect.ValueOf(v) + elem = reflect.ValueOf(v) } else { elem = value } switch elem.Kind() { case reflect.Pointer: - return c.getRowElem(elem.Elem()) + return c.getRowElem(rows, elem.Elem()) case reflect.Struct: var list []interface{} + cols, err := rows.Columns() + if err != nil { + return nil, err + } + + targetField := make(map[string]reflect.Value) for i := 0; i < elem.NumField(); i++ { f := elem.Field(i) - list = append(list, f.Addr().Interface()) + t := elem.Type().Field(i) + tag, ok := t.Tag.Lookup(c.TagKey()) + if ok { + targetField[tag] = f + } + } + + for _, name := range cols { + f, ok := targetField[name] + if !ok { + f = elem.FieldByName(c.ColumnMapper(name)) + } + if f.CanAddr() { + list = append(list, f.Addr().Interface()) + } } return list, nil default: @@ -68,7 +96,7 @@ func (c customScanner) getRowsElem(rows *sql.Rows, v interface{}) error { for rows.Next() { value := reflect.New(itemReceiver) - dest, err := c.getRowElem(value) + dest, err := c.getRowElem(rows, value) if err != nil { return err } @@ -88,16 +116,20 @@ func (c customScanner) getRowsElem(rows *sql.Rows, v interface{}) error { return nil } -func (c customScanner) ScanRow(row *sql.Row, v interface{}) error { - dest, err := c.getRowElem(v) +func (c customScanner) ScanRow(rows *sql.Rows, v interface{}) error { + if !rows.Next() { + return sql.ErrNoRows + } + + dest, err := c.getRowElem(rows, v) if err != nil { return err } - return row.Scan(dest...) + + return rows.Scan(dest...) } func (c customScanner) ScanRows(rows *sql.Rows, v interface{}) error { - //return scan.Rows(v, rows) return c.getRowsElem(rows, v) } diff --git a/example/example_test/sql/user_model.gen_test.go b/example/example_test/sql/user_model.gen_test.go index e2f0496..ad4dcd7 100644 --- a/example/example_test/sql/user_model.gen_test.go +++ b/example/example_test/sql/user_model.gen_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "log" + "sort" "testing" "time" @@ -35,6 +36,11 @@ func mustInitDB(db *sql.DB) { log.Fatalln(err) } + _, err = tx.ExecContext(ctx, `SET SESSION sql_mode=(SELECT REPLACE(@@sql_mode,'ONLY_FULL_GROUP_BY,',''))`) + if err != nil { + log.Fatalln(err) + } + _, err = tx.ExecContext(ctx, `truncate table user`) if err != nil { log.Fatalln(err) @@ -263,30 +269,147 @@ func TestFindGroupHavingLimitOffset(t *testing.T) { } err := um.Create(ctx, list...) assert.NoError(t, err) - actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ Count: 2, Offset: 0, }) assert.NoError(t, err) assertUsersEqual(t, list[:2], actual) })) +} + +func TestFindGroupHavingOrderDescLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) - t.Run("FindGroupHavingLimitOffset1", initAndRun(func(t *testing.T) { + t.Run("FindGroupHavingOrderDescLimitOffset", initAndRun(func(t *testing.T) { var list []*model.User for i := 0; i < 5; i++ { list = append(list, mustMockUser()) } err := um.Create(ctx, list...) assert.NoError(t, err) - actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ Count: 2, Offset: 1, }) assert.NoError(t, err) + sort.Slice(list, func(i, j int) bool { + return list[i].Id > list[j].Id + }) assertUsersEqual(t, list[1:3], actual) })) } +func TestFindOnePart(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 1}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOnePart", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, mockUser.Name, actual.Name) + assert.Equal(t, mockUser.Password, actual.Password) + assert.Equal(t, mockUser.Mobile, actual.Mobile) + })) +} + +func TestFindAllCount(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, uint64(0), countID.CountID) + })) + + t.Run("FindAllCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, uint64(1), actual.CountID) + })) +} + +func TestFindAllCountWhere(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCountWhere", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindMaxID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + maxID, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), maxID.MaxID.Int64) + })) + + t.Run("FindMaxID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(5), actual.MaxID.Int64) + })) +} + +func TestFindMinID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), minID.MinID.Int64) + })) + + t.Run("FindMinID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.MinID.Int64) + })) +} + func assertUserEqual(t *testing.T, expected, actual *model.User) { now := time.Now() expected.CreateAt = now diff --git a/example/go.mod b/example/go.mod index 46ff15e..715e1e3 100644 --- a/example/go.mod +++ b/example/go.mod @@ -4,6 +4,7 @@ go 1.18 require ( github.com/go-sql-driver/mysql v1.6.0 + github.com/iancoleman/strcase v0.2.0 github.com/jmoiron/sqlx v1.3.5 github.com/satori/go.uuid v1.2.0 github.com/stretchr/testify v1.7.0 diff --git a/example/go.sum b/example/go.sum index 4061a5a..9f84aab 100644 --- a/example/go.sum +++ b/example/go.sum @@ -134,6 +134,8 @@ github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/J github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= +github.com/iancoleman/strcase v0.2.0 h1:05I4QRnGpI0m37iZQRuskXh+w77mr6Z41lwQzuHLwW0= +github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= diff --git a/example/gorm/user_model.gen.go b/example/gorm/user_model.gen.go index 1063a06..69a88c1 100644 --- a/example/gorm/user_model.gen.go +++ b/example/gorm/user_model.gen.go @@ -4,6 +4,7 @@ package model import ( "context" + "database/sql" "fmt" "time" @@ -135,7 +136,7 @@ type FindOnePartWhereParameter struct { // FindAllCountResult is a find all count result. type FindAllCountResult struct { - CountID uint64 `gorm:"column:countID" json:"countID"` + CountID sql.NullInt64 `gorm:"column:countID" json:"countID"` } // FindAllCountWhereWhereParameter is a where parameter structure. @@ -145,22 +146,22 @@ type FindAllCountWhereWhereParameter struct { // FindAllCountWhereResult is a find all count where result. type FindAllCountWhereResult struct { - CountID uint64 `gorm:"column:countID" json:"countID"` + CountID sql.NullInt64 `gorm:"column:countID" json:"countID"` } // FindMaxIDResult is a find max id result. type FindMaxIDResult struct { - MaxID uint64 `gorm:"column:maxID" json:"maxID"` + MaxID sql.NullInt64 `gorm:"column:maxID" json:"maxID"` } // FindMinIDResult is a find min id result. type FindMinIDResult struct { - MinID uint64 `gorm:"column:minID" json:"minID"` + MinID sql.NullInt64 `gorm:"column:minID" json:"minID"` } // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { - AvgID uint64 `gorm:"column:avgID" json:"avgID"` + AvgID sql.NullInt64 `gorm:"column:avgID" json:"avgID"` } // UpdateWhereParameter is a where parameter structure. diff --git a/example/sql/scanner.go b/example/sql/scanner.go index 7306a65..a56452e 100644 --- a/example/sql/scanner.go +++ b/example/sql/scanner.go @@ -3,6 +3,8 @@ package model import "database/sql" type Scanner interface { - ScanRow(row *sql.Row, v interface{}) error + ScanRow(rows *sql.Rows, v interface{}) error ScanRows(rows *sql.Rows, v interface{}) error + ColumnMapper(colName string) string + TagKey() string } diff --git a/example/sql/user_model.gen.go b/example/sql/user_model.gen.go index a12140e..990ecec 100644 --- a/example/sql/user_model.gen.go +++ b/example/sql/user_model.gen.go @@ -137,7 +137,7 @@ type FindOnePartWhereParameter struct { // FindAllCountResult is a find all count result. type FindAllCountResult struct { - CountID uint64 `json:"countID"` + CountID sql.NullInt64 `json:"countID"` } // FindAllCountWhereWhereParameter is a where parameter structure. @@ -147,22 +147,22 @@ type FindAllCountWhereWhereParameter struct { // FindAllCountWhereResult is a find all count where result. type FindAllCountWhereResult struct { - CountID uint64 `json:"countID"` + CountID sql.NullInt64 `json:"countID"` } // FindMaxIDResult is a find max id result. type FindMaxIDResult struct { - MaxID uint64 `json:"maxID"` + MaxID sql.NullInt64 `json:"maxID"` } // FindMinIDResult is a find min id result. type FindMinIDResult struct { - MinID uint64 `json:"minID"` + MinID sql.NullInt64 `json:"minID"` } // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { - AvgID uint64 `json:"avgID"` + AvgID sql.NullInt64 `json:"avgID"` } // UpdateWhereParameter is a where parameter structure. @@ -260,13 +260,17 @@ func (m *UserModel) FindOne(ctx context.Context, where FindOneWhereParameter) (r return nil, err } - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + if err = m.scanner.ScanRow(rows, &result); err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + return result, nil } // FindOneByName is generated from sql: @@ -283,13 +287,17 @@ func (m *UserModel) FindOneByName(ctx context.Context, where FindOneByNameWhereP return nil, err } - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + defer rows.Close() + if err = m.scanner.ScanRow(rows, &result); err != nil { + return nil, err + } + + return result, nil } // FindOneGroupByName is generated from sql: @@ -307,13 +315,17 @@ func (m *UserModel) FindOneGroupByName(ctx context.Context, where FindOneGroupBy return nil, err } - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + defer rows.Close() + if err = m.scanner.ScanRow(rows, &result); err != nil { + return nil, err + } + + return result, nil } // FindOneGroupByNameHavingName is generated from sql: @@ -332,13 +344,17 @@ func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where Find return nil, err } - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + defer rows.Close() + if err = m.scanner.ScanRow(rows, &result); err != nil { + return nil, err + } + + return result, nil } // FindAll is generated from sql: @@ -352,15 +368,16 @@ func (m *UserModel) FindAll(ctx context.Context) (result []*User, err error) { return nil, err } - var rows *sql.Rows - rows, err = m.db.QueryContext(ctx, query, args...) + rows, err := m.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() + if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } + return result, nil } @@ -377,15 +394,16 @@ func (m *UserModel) FindLimit(ctx context.Context, where FindLimitWhereParameter return nil, err } - var rows *sql.Rows - rows, err = m.db.QueryContext(ctx, query, args...) + rows, err := m.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() + if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } + return result, nil } @@ -401,15 +419,16 @@ func (m *UserModel) FindLimitOffset(ctx context.Context, limit FindLimitOffsetLi return nil, err } - var rows *sql.Rows - rows, err = m.db.QueryContext(ctx, query, args...) + rows, err := m.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() + if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } + return result, nil } @@ -427,15 +446,16 @@ func (m *UserModel) FindGroupLimitOffset(ctx context.Context, where FindGroupLim return nil, err } - var rows *sql.Rows - rows, err = m.db.QueryContext(ctx, query, args...) + rows, err := m.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() + if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } + return result, nil } @@ -454,15 +474,16 @@ func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGr return nil, err } - var rows *sql.Rows - rows, err = m.db.QueryContext(ctx, query, args...) + rows, err := m.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() + if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } + return result, nil } @@ -482,15 +503,16 @@ func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, wher return nil, err } - var rows *sql.Rows - rows, err = m.db.QueryContext(ctx, query, args...) + rows, err := m.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() + if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } + return result, nil } @@ -510,15 +532,16 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe return nil, err } - var rows *sql.Rows - rows, err = m.db.QueryContext(ctx, query, args...) + rows, err := m.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() + if err = m.scanner.ScanRows(rows, &result); err != nil { return nil, err } + return result, nil } @@ -536,13 +559,17 @@ func (m *UserModel) FindOnePart(ctx context.Context, where FindOnePartWhereParam return nil, err } - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + if err = m.scanner.ScanRow(rows, &result); err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + return result, nil } // FindAllCount is generated from sql: @@ -558,13 +585,17 @@ func (m *UserModel) FindAllCount(ctx context.Context) (result *FindAllCountResul return nil, err } - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + defer rows.Close() + if err = m.scanner.ScanRow(rows, &result); err != nil { + return nil, err + } + + return result, nil } // FindAllCountWhere is generated from sql: @@ -581,13 +612,17 @@ func (m *UserModel) FindAllCountWhere(ctx context.Context, where FindAllCountWhe return nil, err } - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + defer rows.Close() + if err = m.scanner.ScanRow(rows, &result); err != nil { + return nil, err + } + + return result, nil } // FindMaxID is generated from sql: @@ -603,13 +638,17 @@ func (m *UserModel) FindMaxID(ctx context.Context) (result *FindMaxIDResult, err return nil, err } - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + if err = m.scanner.ScanRow(rows, &result); err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + return result, nil } // FindMinID is generated from sql: @@ -625,13 +664,17 @@ func (m *UserModel) FindMinID(ctx context.Context) (result *FindMinIDResult, err return nil, err } - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + if err = m.scanner.ScanRow(rows, &result); err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + return result, nil } // FindAvgID is generated from sql: @@ -647,13 +690,17 @@ func (m *UserModel) FindAvgID(ctx context.Context) (result *FindAvgIDResult, err return nil, err } - row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { + rows, err := m.db.QueryContext(ctx, query, args...) + if err != nil { return nil, err } - err = m.scanner.ScanRow(row, result) - return + defer rows.Close() + if err = m.scanner.ScanRow(rows, &result); err != nil { + return nil, err + } + + return result, nil } // Update is generated from sql: diff --git a/example/sqlx/user_model.gen.go b/example/sqlx/user_model.gen.go index 90f2576..dbf1327 100644 --- a/example/sqlx/user_model.gen.go +++ b/example/sqlx/user_model.gen.go @@ -137,7 +137,7 @@ type FindOnePartWhereParameter struct { // FindAllCountResult is a find all count result. type FindAllCountResult struct { - CountID uint64 `db:"countID" json:"countID"` + CountID sql.NullInt64 `db:"countID" json:"countID"` } // FindAllCountWhereWhereParameter is a where parameter structure. @@ -147,22 +147,22 @@ type FindAllCountWhereWhereParameter struct { // FindAllCountWhereResult is a find all count where result. type FindAllCountWhereResult struct { - CountID uint64 `db:"countID" json:"countID"` + CountID sql.NullInt64 `db:"countID" json:"countID"` } // FindMaxIDResult is a find max id result. type FindMaxIDResult struct { - MaxID uint64 `db:"maxID" json:"maxID"` + MaxID sql.NullInt64 `db:"maxID" json:"maxID"` } // FindMinIDResult is a find min id result. type FindMinIDResult struct { - MinID uint64 `db:"minID" json:"minID"` + MinID sql.NullInt64 `db:"minID" json:"minID"` } // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { - AvgID uint64 `db:"avgID" json:"avgID"` + AvgID sql.NullInt64 `db:"avgID" json:"avgID"` } // UpdateWhereParameter is a where parameter structure. diff --git a/example/xorm/user_model.gen.go b/example/xorm/user_model.gen.go index 7ee7d5e..cf8766d 100644 --- a/example/xorm/user_model.gen.go +++ b/example/xorm/user_model.gen.go @@ -4,6 +4,7 @@ package model import ( "context" + "database/sql" "fmt" "time" @@ -135,7 +136,7 @@ type FindOnePartWhereParameter struct { // FindAllCountResult is a find all count result. type FindAllCountResult struct { - CountID uint64 `xorm:"'countID'" json:"countID"` + CountID sql.NullInt64 `xorm:"'countID'" json:"countID"` } // FindAllCountWhereWhereParameter is a where parameter structure. @@ -145,22 +146,22 @@ type FindAllCountWhereWhereParameter struct { // FindAllCountWhereResult is a find all count where result. type FindAllCountWhereResult struct { - CountID uint64 `xorm:"'countID'" json:"countID"` + CountID sql.NullInt64 `xorm:"'countID'" json:"countID"` } // FindMaxIDResult is a find max id result. type FindMaxIDResult struct { - MaxID uint64 `xorm:"'maxID'" json:"maxID"` + MaxID sql.NullInt64 `xorm:"'maxID'" json:"maxID"` } // FindMinIDResult is a find min id result. type FindMinIDResult struct { - MinID uint64 `xorm:"'minID'" json:"minID"` + MinID sql.NullInt64 `xorm:"'minID'" json:"minID"` } // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { - AvgID uint64 `xorm:"'avgID'" json:"avgID"` + AvgID sql.NullInt64 `xorm:"'avgID'" json:"avgID"` } // UpdateWhereParameter is a where parameter structure. diff --git a/internal/gen/sql/scanner.tpl b/internal/gen/sql/scanner.tpl index 858a4c3..a1ec44d 100644 --- a/internal/gen/sql/scanner.tpl +++ b/internal/gen/sql/scanner.tpl @@ -3,6 +3,8 @@ package model import "database/sql" type Scanner interface { - ScanRow(row *sql.Row, v interface{}) error + ScanRow(rows *sql.Rows, v interface{}) error ScanRows(rows *sql.Rows, v interface{}) error -} \ No newline at end of file + ColumnMapper(colName string) string + TagKey() string +} diff --git a/internal/gen/sql/sql_gen.tpl b/internal/gen/sql/sql_gen.tpl index f2d8837..72867f9 100644 --- a/internal/gen/sql/sql_gen.tpl +++ b/internal/gen/sql/sql_gen.tpl @@ -92,22 +92,17 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if return nil, err } - {{if $stmt.Limit.One}}row := m.db.QueryRowContext(ctx, query, args...) - if err = row.Err(); err != nil { - return nil, err - } - err = m.scanner.ScanRow(row, result) - return - {{else}}var rows *sql.Rows - rows, err = m.db.QueryContext(ctx, query, args...) + rows, err := m.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() - if err = m.scanner.ScanRows(rows, &result); err != nil{ + + if err = m.scanner. {{if $stmt.Limit.One}}ScanRow{{else}}ScanRows{{end}}(rows, &result); err != nil{ return nil, err } - return result, nil{{end}} + + return result, nil } {{end}} diff --git a/internal/parser/funcmap.go b/internal/parser/funcmap.go index cfd5ea9..7ce4c6f 100644 --- a/internal/parser/funcmap.go +++ b/internal/parser/funcmap.go @@ -1,92 +1,95 @@ package parser -import "github.com/pingcap/parser/mysql" +import ( + "github.com/anqiansong/sqlgen/internal/spec" + "github.com/pingcap/parser/mysql" +) var funcMap = map[string]byte{ - // mysql.TypeLonglong - "count": mysql.TypeLonglong, - "length": mysql.TypeLonglong, - "char_length": mysql.TypeLonglong, - "locate": mysql.TypeLonglong, - "position": mysql.TypeLonglong, - "instr": mysql.TypeLonglong, - "field": mysql.TypeLonglong, - "sign": mysql.TypeLonglong, - "mod": mysql.TypeLonglong, - "unix_timestamp": mysql.TypeLonglong, - "sysdate": mysql.TypeLonglong, - "utc_date": mysql.TypeLonglong, - "utc_time": mysql.TypeLonglong, - "month": mysql.TypeLonglong, - "day": mysql.TypeLonglong, - "dayofmonth": mysql.TypeLonglong, - "dayofweek": mysql.TypeLonglong, - "dayofyear": mysql.TypeLonglong, - "week": mysql.TypeLonglong, - "weekday": mysql.TypeLonglong, - "weekofyear": mysql.TypeLonglong, - "quarter": mysql.TypeLonglong, - "hour": mysql.TypeLonglong, - "minute": mysql.TypeLonglong, - "second": mysql.TypeLonglong, - "extract": mysql.TypeLonglong, - "time_to_sec": mysql.TypeLonglong, - "to_days": mysql.TypeLonglong, - "to_seconds": mysql.TypeLonglong, - "datadiff": mysql.TypeLonglong, + // spec.TypeNullLongLong + "count": spec.TypeNullLongLong, + "length": spec.TypeNullLongLong, + "char_length": spec.TypeNullLongLong, + "locate": spec.TypeNullLongLong, + "position": spec.TypeNullLongLong, + "instr": spec.TypeNullLongLong, + "field": spec.TypeNullLongLong, + "sign": spec.TypeNullLongLong, + "mod": spec.TypeNullLongLong, + "unix_timestamp": spec.TypeNullLongLong, + "sysdate": spec.TypeNullLongLong, + "utc_date": spec.TypeNullLongLong, + "utc_time": spec.TypeNullLongLong, + "month": spec.TypeNullLongLong, + "day": spec.TypeNullLongLong, + "dayofmonth": spec.TypeNullLongLong, + "dayofweek": spec.TypeNullLongLong, + "dayofyear": spec.TypeNullLongLong, + "week": spec.TypeNullLongLong, + "weekday": spec.TypeNullLongLong, + "weekofyear": spec.TypeNullLongLong, + "quarter": spec.TypeNullLongLong, + "hour": spec.TypeNullLongLong, + "minute": spec.TypeNullLongLong, + "second": spec.TypeNullLongLong, + "extract": spec.TypeNullLongLong, + "time_to_sec": spec.TypeNullLongLong, + "to_days": spec.TypeNullLongLong, + "to_seconds": spec.TypeNullLongLong, + "datadiff": spec.TypeNullLongLong, - // mysql.TypeNewDecimal - "avg": mysql.TypeNewDecimal, - "abs": mysql.TypeNewDecimal, - "ceil": mysql.TypeNewDecimal, - "floor": mysql.TypeNewDecimal, - "round": mysql.TypeNewDecimal, - "rand": mysql.TypeNewDecimal, - "pi": mysql.TypeNewDecimal, - "truncate": mysql.TypeNewDecimal, - "pow": mysql.TypeNewDecimal, - "sqrt": mysql.TypeNewDecimal, - "exp": mysql.TypeNewDecimal, - "log": mysql.TypeNewDecimal, - "log10": mysql.TypeNewDecimal, - "radians": mysql.TypeNewDecimal, - "degrees": mysql.TypeNewDecimal, - "sin": mysql.TypeNewDecimal, - "cos": mysql.TypeNewDecimal, - "tan": mysql.TypeNewDecimal, - "cot": mysql.TypeNewDecimal, - "asin": mysql.TypeNewDecimal, - "acos": mysql.TypeNewDecimal, - "atan": mysql.TypeNewDecimal, + // spec.TypeNullDecimal + "avg": spec.TypeNullDecimal, + "abs": spec.TypeNullDecimal, + "ceil": spec.TypeNullDecimal, + "floor": spec.TypeNullDecimal, + "round": spec.TypeNullDecimal, + "rand": spec.TypeNullDecimal, + "pi": spec.TypeNullDecimal, + "truncate": spec.TypeNullDecimal, + "pow": spec.TypeNullDecimal, + "sqrt": spec.TypeNullDecimal, + "exp": spec.TypeNullDecimal, + "log": spec.TypeNullDecimal, + "log10": spec.TypeNullDecimal, + "radians": spec.TypeNullDecimal, + "degrees": spec.TypeNullDecimal, + "sin": spec.TypeNullDecimal, + "cos": spec.TypeNullDecimal, + "tan": spec.TypeNullDecimal, + "cot": spec.TypeNullDecimal, + "asin": spec.TypeNullDecimal, + "acos": spec.TypeNullDecimal, + "atan": spec.TypeNullDecimal, - // mysql.TypeString - "concat_ws": mysql.TypeString, - "concat": mysql.TypeString, - "insert": mysql.TypeString, - "upper": mysql.TypeString, - "ucaase": mysql.TypeString, - "lower": mysql.TypeString, - "lcase": mysql.TypeString, - "left": mysql.TypeString, - "right": mysql.TypeString, - "lpad": mysql.TypeString, - "rpad": mysql.TypeString, - "replace": mysql.TypeString, - "substring": mysql.TypeString, - "substr": mysql.TypeString, - "trim": mysql.TypeString, - "ltrim": mysql.TypeString, - "rtrim": mysql.TypeString, - "reverse": mysql.TypeString, - "repeat": mysql.TypeString, - "space": mysql.TypeString, - "strcmp": mysql.TypeString, - "mid": mysql.TypeString, - "from_unixtime": mysql.TypeString, - "month_name": mysql.TypeString, - "day_name": mysql.TypeString, - "date_format": mysql.TypeString, - "time_format": mysql.TypeString, + // spec.TypeNullString + "concat_ws": spec.TypeNullString, + "concat": spec.TypeNullString, + "insert": spec.TypeNullString, + "upper": spec.TypeNullString, + "ucaase": spec.TypeNullString, + "lower": spec.TypeNullString, + "lcase": spec.TypeNullString, + "left": spec.TypeNullString, + "right": spec.TypeNullString, + "lpad": spec.TypeNullString, + "rpad": spec.TypeNullString, + "replace": spec.TypeNullString, + "substring": spec.TypeNullString, + "substr": spec.TypeNullString, + "trim": spec.TypeNullString, + "ltrim": spec.TypeNullString, + "rtrim": spec.TypeNullString, + "reverse": spec.TypeNullString, + "repeat": spec.TypeNullString, + "space": spec.TypeNullString, + "strcmp": spec.TypeNullString, + "mid": spec.TypeNullString, + "from_unixtime": spec.TypeNullString, + "month_name": spec.TypeNullString, + "day_name": spec.TypeNullString, + "date_format": spec.TypeNullString, + "time_format": spec.TypeNullString, // mysql.TypeDate "curdate": mysql.TypeDate, diff --git a/internal/parser/select.go b/internal/parser/select.go index 56a8d92..aae4351 100644 --- a/internal/parser/select.go +++ b/internal/parser/select.go @@ -4,13 +4,12 @@ import ( "fmt" "strings" + "github.com/anqiansong/sqlgen/internal/set" + "github.com/anqiansong/sqlgen/internal/spec" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/opcode" "github.com/pingcap/parser/test_driver" - - "github.com/anqiansong/sqlgen/internal/set" - "github.com/anqiansong/sqlgen/internal/spec" ) func parseSelect(stmt *ast.SelectStmt) (*spec.SelectStmt, error) { @@ -416,9 +415,10 @@ func parseFieldList(fieldList *ast.FieldList, from string) (spec.Fields, string, } selectField = append(selectField, funcSql) columnSet.Add(spec.Field{ - ASName: f.AsName.String(), - ColumnName: columnName, - TP: tp, + ASName: f.AsName.String(), + ColumnName: columnName, + TP: tp, + AggregateCall: aggregate, }) } diff --git a/internal/spec/converter.go b/internal/spec/converter.go index a6eee60..989407d 100644 --- a/internal/spec/converter.go +++ b/internal/spec/converter.go @@ -127,6 +127,7 @@ func convertField(table *Table, fields []Field) (Columns, error) { column, ok := table.GetColumnByName(f.ColumnName) if ok { column.Name = name + column.AggregateCall = f.AggregateCall list = append(list, column) } else { return nil, fmt.Errorf("column %q no found in table %q", f.ColumnName, table.Name) @@ -136,8 +137,9 @@ func convertField(table *Table, fields []Field) (Columns, error) { return nil, fmt.Errorf("column %q no found in table %q", f.ColumnName, table.Name) } list = append(list, Column{ - Name: name, - TP: f.TP, + Name: name, + TP: f.TP, + AggregateCall: f.AggregateCall, }) } } diff --git a/internal/spec/stmt.go b/internal/spec/stmt.go index 8114aa8..20ae5e4 100644 --- a/internal/spec/stmt.go +++ b/internal/spec/stmt.go @@ -16,9 +16,10 @@ type Fields []Field // Field represents a select filed. type Field struct { - ASName string - ColumnName string - TP byte + ASName string + ColumnName string + TP byte + AggregateCall bool } // Limit represents a limit clause. diff --git a/internal/spec/table.go b/internal/spec/table.go index 33f0cee..32c6fe7 100644 --- a/internal/spec/table.go +++ b/internal/spec/table.go @@ -28,7 +28,8 @@ type Column struct { // Name is the name of the column. Name string // TP is the type of the column. - TP byte + TP byte + AggregateCall bool } // ColumnOption is a column option. diff --git a/internal/spec/type.go b/internal/spec/type.go index 163a079..754d2cc 100644 --- a/internal/spec/type.go +++ b/internal/spec/type.go @@ -1,19 +1,30 @@ package spec import ( + "database/sql" "fmt" + "github.com/anqiansong/sqlgen/internal/parameter" "github.com/pingcap/parser/mysql" +) - "github.com/anqiansong/sqlgen/internal/parameter" +const ( + // TypeNullLongLong is a type extension for mysql.TypeLongLong. + TypeNullLongLong byte = 0xf0 + // TypeNullDecimal is a type extension for mysql.TypeDecimal. + TypeNullDecimal byte = 0xf1 + // TypeNullString is a type extension for mysql.TypeString. + TypeNullString byte = 0xf2 ) const defaultThirdDecimalPkg = "github.com/shopspring/decimal" type typeKey struct { - tp byte - signed bool - thirdPkg string + tp byte + signed bool + thirdPkg string + aggregateCall bool + sql.NullFloat64 } var typeMapper = map[typeKey]string{ @@ -22,6 +33,8 @@ var typeMapper = map[typeKey]string{ typeKey{tp: mysql.TypeShort}: "int16", typeKey{tp: mysql.TypeShort, signed: true}: "uint16", typeKey{tp: mysql.TypeLong}: "int32", + typeKey{tp: TypeNullLongLong}: "sql.NullInt64", + typeKey{tp: TypeNullLongLong, signed: true}: "sql.NullInt64", typeKey{tp: mysql.TypeLong, signed: true}: "uint32", typeKey{tp: mysql.TypeFloat}: "float64", typeKey{tp: mysql.TypeDouble}: "float64", @@ -40,7 +53,11 @@ var typeMapper = map[typeKey]string{ typeKey{ tp: mysql.TypeNewDecimal, thirdPkg: defaultThirdDecimalPkg, - }: "byte", + }: "decimal.Decimal", + typeKey{ + tp: TypeNullDecimal, + thirdPkg: defaultThirdDecimalPkg, + }: "decimal.NullDecimal", typeKey{tp: mysql.TypeEnum}: "string", typeKey{tp: mysql.TypeSet}: "string", typeKey{tp: mysql.TypeTinyBlob}: "string", @@ -49,6 +66,43 @@ var typeMapper = map[typeKey]string{ typeKey{tp: mysql.TypeBlob}: "string", typeKey{tp: mysql.TypeVarString}: "string", typeKey{tp: mysql.TypeString}: "string", + typeKey{tp: TypeNullString}: "sql.NullString", + + // aggregate functions + typeKey{tp: mysql.TypeTiny, aggregateCall: true}: "sql.NullInt16", + typeKey{tp: mysql.TypeTiny, signed: true, aggregateCall: true}: "sql.NullInt16", + typeKey{tp: mysql.TypeShort, aggregateCall: true}: "sql.NullInt16", + typeKey{tp: mysql.TypeShort, signed: true, aggregateCall: true}: "sql.NullInt16", + typeKey{tp: mysql.TypeLong, aggregateCall: true}: "sql.NullInt32", + typeKey{tp: mysql.TypeLong, signed: true, aggregateCall: true}: "sql.NullInt32", + typeKey{tp: mysql.TypeFloat, aggregateCall: true}: "sql.NullInt32", + typeKey{tp: mysql.TypeDouble, aggregateCall: true}: "sql.NullFloat64", + typeKey{tp: mysql.TypeLonglong, aggregateCall: true}: "sql.NullInt64", + typeKey{tp: mysql.TypeLonglong, signed: true, aggregateCall: true}: "sql.NullInt64", + typeKey{tp: mysql.TypeInt24, aggregateCall: true}: "sql.NullInt32", + typeKey{tp: mysql.TypeInt24, signed: true, aggregateCall: true}: "sql.NullInt32", + typeKey{tp: mysql.TypeYear, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeVarchar, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeBit, aggregateCall: true}: "sql.NullInt16", + typeKey{tp: mysql.TypeJSON, aggregateCall: true}: "sql.NullString", + typeKey{ + tp: mysql.TypeNewDecimal, + thirdPkg: defaultThirdDecimalPkg, + aggregateCall: true, + }: "decimal.NullDecimal", + typeKey{ + tp: TypeNullDecimal, + thirdPkg: defaultThirdDecimalPkg, + aggregateCall: true, + }: "decimal.NullDecimal", + typeKey{tp: mysql.TypeEnum, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeSet, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeTinyBlob, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeMediumBlob, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeLongBlob, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeBlob, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeVarString, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeString, aggregateCall: true}: "sql.NullString", } // Type is the type of the column. @@ -56,10 +110,11 @@ type Type byte // DataType returns the Go type, third-package of the column. func (c Column) DataType() (parameter.Parameter, error) { - var key = typeKey{tp: c.TP, signed: c.Unsigned} + var key = typeKey{tp: c.TP, signed: c.Unsigned, aggregateCall: c.AggregateCall} if c.TP == mysql.TypeNewDecimal { key.thirdPkg = defaultThirdDecimalPkg } + goType, ok := typeMapper[key] if !ok { return parameter.Parameter{}, fmt.Errorf("unsupported type: %v", c.TP) @@ -73,3 +128,7 @@ func (c Column) GoType() (string, error) { p, err := c.DataType() return p.Type, err } + +func isNullType(tp byte) bool { + return tp >= TypeNullLongLong && tp <= TypeNullString +} From a6d3af1f9a385505cd6f0f6fa1974376561e0594 Mon Sep 17 00:00:00 2001 From: anqiansong Date: Wed, 17 Aug 2022 00:04:48 +0800 Subject: [PATCH 08/10] Add example test --- example/bun/user_model.gen.go | 3 +- .../example_test/sql/user_model.gen_test.go | 20 ++++++++++ example/go.mod | 1 + example/go.sum | 2 + example/gorm/user_model.gen.go | 4 +- example/sql/user_model.gen.go | 3 +- example/sqlx/user_model.gen.go | 3 +- example/xorm/user_model.gen.go | 4 +- internal/parser/parser_test.go | 6 ++- internal/parser/test.sql | 2 +- internal/spec/converter.go | 3 ++ internal/spec/type.go | 39 ++++++++++--------- 12 files changed, 65 insertions(+), 25 deletions(-) diff --git a/example/bun/user_model.gen.go b/example/bun/user_model.gen.go index 0d167dd..8ee0788 100644 --- a/example/bun/user_model.gen.go +++ b/example/bun/user_model.gen.go @@ -8,6 +8,7 @@ import ( "fmt" "time" + "github.com/shopspring/decimal" "github.com/uptrace/bun" ) @@ -167,7 +168,7 @@ type FindMinIDResult struct { // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { bun.BaseModel `bun:"table:user"` - AvgID sql.NullInt64 `bun:"avgID" json:"avgID"` + AvgID decimal.NullDecimal `bun:"avgID" json:"avgID"` } // UpdateWhereParameter is a where parameter structure. diff --git a/example/example_test/sql/user_model.gen_test.go b/example/example_test/sql/user_model.gen_test.go index ad4dcd7..573be02 100644 --- a/example/example_test/sql/user_model.gen_test.go +++ b/example/example_test/sql/user_model.gen_test.go @@ -410,6 +410,26 @@ func TestFindMinID(t *testing.T) { })) } +func TestFindAvgID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "0", minID.AvgID.Decimal.String()) + })) + + t.Run("FindAvgID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "3", actual.AvgID.Decimal.String()) + })) +} + func assertUserEqual(t *testing.T, expected, actual *model.User) { now := time.Now() expected.CreateAt = now diff --git a/example/go.mod b/example/go.mod index 715e1e3..8e40bee 100644 --- a/example/go.mod +++ b/example/go.mod @@ -7,6 +7,7 @@ require ( github.com/iancoleman/strcase v0.2.0 github.com/jmoiron/sqlx v1.3.5 github.com/satori/go.uuid v1.2.0 + github.com/shopspring/decimal v1.3.1 github.com/stretchr/testify v1.7.0 github.com/uptrace/bun v1.1.7 gorm.io/gorm v1.23.8 diff --git a/example/go.sum b/example/go.sum index 9f84aab..db2c5c9 100644 --- a/example/go.sum +++ b/example/go.sum @@ -324,6 +324,8 @@ github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= +github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= diff --git a/example/gorm/user_model.gen.go b/example/gorm/user_model.gen.go index 69a88c1..32c3ea8 100644 --- a/example/gorm/user_model.gen.go +++ b/example/gorm/user_model.gen.go @@ -9,6 +9,8 @@ import ( "time" "gorm.io/gorm" + + "github.com/shopspring/decimal" ) // UserModel represents a user model. @@ -161,7 +163,7 @@ type FindMinIDResult struct { // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { - AvgID sql.NullInt64 `gorm:"column:avgID" json:"avgID"` + AvgID decimal.NullDecimal `gorm:"column:avgID" json:"avgID"` } // UpdateWhereParameter is a where parameter structure. diff --git a/example/sql/user_model.gen.go b/example/sql/user_model.gen.go index 990ecec..c454cab 100644 --- a/example/sql/user_model.gen.go +++ b/example/sql/user_model.gen.go @@ -8,6 +8,7 @@ import ( "fmt" "time" + "github.com/shopspring/decimal" "xorm.io/builder" ) @@ -162,7 +163,7 @@ type FindMinIDResult struct { // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { - AvgID sql.NullInt64 `json:"avgID"` + AvgID decimal.NullDecimal `json:"avgID"` } // UpdateWhereParameter is a where parameter structure. diff --git a/example/sqlx/user_model.gen.go b/example/sqlx/user_model.gen.go index dbf1327..ae6e2f7 100644 --- a/example/sqlx/user_model.gen.go +++ b/example/sqlx/user_model.gen.go @@ -9,6 +9,7 @@ import ( "time" "github.com/jmoiron/sqlx" + "github.com/shopspring/decimal" "xorm.io/builder" ) @@ -162,7 +163,7 @@ type FindMinIDResult struct { // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { - AvgID sql.NullInt64 `db:"avgID" json:"avgID"` + AvgID decimal.NullDecimal `db:"avgID" json:"avgID"` } // UpdateWhereParameter is a where parameter structure. diff --git a/example/xorm/user_model.gen.go b/example/xorm/user_model.gen.go index cf8766d..3210717 100644 --- a/example/xorm/user_model.gen.go +++ b/example/xorm/user_model.gen.go @@ -9,6 +9,8 @@ import ( "time" "xorm.io/xorm" + + "github.com/shopspring/decimal" ) // UserModel represents a user model. @@ -161,7 +163,7 @@ type FindMinIDResult struct { // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { - AvgID sql.NullInt64 `xorm:"'avgID'" json:"avgID"` + AvgID decimal.NullDecimal `xorm:"'avgID'" json:"avgID"` } // UpdateWhereParameter is a where parameter structure. diff --git a/internal/parser/parser_test.go b/internal/parser/parser_test.go index 79224a8..20b25b4 100644 --- a/internal/parser/parser_test.go +++ b/internal/parser/parser_test.go @@ -27,7 +27,11 @@ func TestParse(t *testing.T) { ctxOne := ctx[0] selectOne := ctxOne.SelectStmt[0] - selectOne.Where.ParameterStructure("test") + p, err := selectOne.ColumnInfo[0].DataType() + if err != nil { + log.Fatal(err) + } + fmt.Println(p) } func TestFrom(t *testing.T) { diff --git a/internal/parser/test.sql b/internal/parser/test.sql index d8401f2..1d87114 100644 --- a/internal/parser/test.sql +++ b/internal/parser/test.sql @@ -16,5 +16,5 @@ CREATE TABLE `user` ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT '用户表' COLLATE=utf8mb4_general_ci; -- fn: test -select * from user where id = ? and name in (?) limit 1; +select avg(id) AS acg from user where id > ?; diff --git a/internal/spec/converter.go b/internal/spec/converter.go index 989407d..83d2d1c 100644 --- a/internal/spec/converter.go +++ b/internal/spec/converter.go @@ -128,6 +128,9 @@ func convertField(table *Table, fields []Field) (Columns, error) { if ok { column.Name = name column.AggregateCall = f.AggregateCall + if f.TP != mysql.TypeUnspecified { + column.TP = f.TP + } list = append(list, column) } else { return nil, fmt.Errorf("column %q no found in table %q", f.ColumnName, table.Name) diff --git a/internal/spec/type.go b/internal/spec/type.go index 754d2cc..1a62939 100644 --- a/internal/spec/type.go +++ b/internal/spec/type.go @@ -33,8 +33,6 @@ var typeMapper = map[typeKey]string{ typeKey{tp: mysql.TypeShort}: "int16", typeKey{tp: mysql.TypeShort, signed: true}: "uint16", typeKey{tp: mysql.TypeLong}: "int32", - typeKey{tp: TypeNullLongLong}: "sql.NullInt64", - typeKey{tp: TypeNullLongLong, signed: true}: "sql.NullInt64", typeKey{tp: mysql.TypeLong, signed: true}: "uint32", typeKey{tp: mysql.TypeFloat}: "float64", typeKey{tp: mysql.TypeDouble}: "float64", @@ -69,22 +67,17 @@ var typeMapper = map[typeKey]string{ typeKey{tp: TypeNullString}: "sql.NullString", // aggregate functions - typeKey{tp: mysql.TypeTiny, aggregateCall: true}: "sql.NullInt16", - typeKey{tp: mysql.TypeTiny, signed: true, aggregateCall: true}: "sql.NullInt16", - typeKey{tp: mysql.TypeShort, aggregateCall: true}: "sql.NullInt16", - typeKey{tp: mysql.TypeShort, signed: true, aggregateCall: true}: "sql.NullInt16", - typeKey{tp: mysql.TypeLong, aggregateCall: true}: "sql.NullInt32", - typeKey{tp: mysql.TypeLong, signed: true, aggregateCall: true}: "sql.NullInt32", - typeKey{tp: mysql.TypeFloat, aggregateCall: true}: "sql.NullInt32", - typeKey{tp: mysql.TypeDouble, aggregateCall: true}: "sql.NullFloat64", - typeKey{tp: mysql.TypeLonglong, aggregateCall: true}: "sql.NullInt64", - typeKey{tp: mysql.TypeLonglong, signed: true, aggregateCall: true}: "sql.NullInt64", - typeKey{tp: mysql.TypeInt24, aggregateCall: true}: "sql.NullInt32", - typeKey{tp: mysql.TypeInt24, signed: true, aggregateCall: true}: "sql.NullInt32", - typeKey{tp: mysql.TypeYear, aggregateCall: true}: "sql.NullString", - typeKey{tp: mysql.TypeVarchar, aggregateCall: true}: "sql.NullString", - typeKey{tp: mysql.TypeBit, aggregateCall: true}: "sql.NullInt16", - typeKey{tp: mysql.TypeJSON, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeTiny, aggregateCall: true}: "sql.NullInt16", + typeKey{tp: mysql.TypeShort, aggregateCall: true}: "sql.NullInt16", + typeKey{tp: mysql.TypeLong, aggregateCall: true}: "sql.NullInt32", + typeKey{tp: mysql.TypeFloat, aggregateCall: true}: "sql.NullInt32", + typeKey{tp: mysql.TypeDouble, aggregateCall: true}: "sql.NullFloat64", + typeKey{tp: mysql.TypeLonglong, aggregateCall: true}: "sql.NullInt64", + typeKey{tp: mysql.TypeInt24, aggregateCall: true}: "sql.NullInt32", + typeKey{tp: mysql.TypeYear, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeVarchar, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeBit, aggregateCall: true}: "sql.NullInt16", + typeKey{tp: mysql.TypeJSON, aggregateCall: true}: "sql.NullString", typeKey{ tp: mysql.TypeNewDecimal, thirdPkg: defaultThirdDecimalPkg, @@ -103,6 +96,13 @@ var typeMapper = map[typeKey]string{ typeKey{tp: mysql.TypeBlob, aggregateCall: true}: "sql.NullString", typeKey{tp: mysql.TypeVarString, aggregateCall: true}: "sql.NullString", typeKey{tp: mysql.TypeString, aggregateCall: true}: "sql.NullString", + typeKey{tp: mysql.TypeString, aggregateCall: true}: "sql.NullString", + typeKey{tp: TypeNullLongLong}: "sql.NullInt64", + typeKey{tp: TypeNullDecimal}: "decimal.NullDecimal", + typeKey{tp: TypeNullString}: "sql.NullString", + typeKey{tp: TypeNullLongLong, aggregateCall: true}: "sql.NullInt64", + typeKey{tp: TypeNullDecimal, aggregateCall: true}: "decimal.NullDecimal", + typeKey{tp: TypeNullString, aggregateCall: true}: "sql.NullString", } // Type is the type of the column. @@ -111,6 +111,9 @@ type Type byte // DataType returns the Go type, third-package of the column. func (c Column) DataType() (parameter.Parameter, error) { var key = typeKey{tp: c.TP, signed: c.Unsigned, aggregateCall: c.AggregateCall} + if c.AggregateCall { + key = typeKey{tp: c.TP, aggregateCall: c.AggregateCall} + } if c.TP == mysql.TypeNewDecimal { key.thirdPkg = defaultThirdDecimalPkg } From 36c19bf61f38cefb7d833586bd526abe8e2ce586 Mon Sep 17 00:00:00 2001 From: anqiansong Date: Wed, 17 Aug 2022 14:40:54 +0800 Subject: [PATCH 09/10] Add example_test --- example/bun/user_model.gen.go | 17 +- example/example.sql | 2 +- example/example_test/bun/mock.go | 23 + .../example_test/bun/user_model.gen_test.go | 561 +++++++++++++++++- example/example_test/gorm/mock.go | 23 + .../example_test/gorm/user_model.gen_test.go | 546 ++++++++++++++++- .../example_test/sql/user_model.gen_test.go | 112 +++- example/go.mod | 13 +- example/go.sum | 29 +- example/gorm/user_model.gen.go | 265 +++++---- example/sql/user_model.gen.go | 10 +- example/sqlx/user_model.gen.go | 10 +- example/xorm/user_model.gen.go | 16 +- internal/gen/bun/bun_gen.tpl | 9 +- internal/gen/gorm/gorm.go | 3 + internal/gen/gorm/gorm_gen.tpl | 46 +- internal/gen/xorm/xorm_gen.tpl | 6 +- 17 files changed, 1515 insertions(+), 176 deletions(-) create mode 100644 example/example_test/bun/mock.go create mode 100644 example/example_test/gorm/mock.go diff --git a/example/bun/user_model.gen.go b/example/bun/user_model.gen.go index 8ee0788..9b7b212 100644 --- a/example/bun/user_model.gen.go +++ b/example/bun/user_model.gen.go @@ -224,11 +224,7 @@ func (m *UserModel) Create(ctx context.Context, data ...*User) error { return fmt.Errorf("data is empty") } - var list []User - for _, v := range data { - list = append(list, *v) - } - + list := data[:] _, err := m.db.NewInsert().Model(&list).Exec(ctx) return err } @@ -463,7 +459,8 @@ func (m *UserModel) FindAvgID(ctx context.Context) (*FindAvgIDResult, error) { // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ?; func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWhereParameter) error { var db = m.db.NewUpdate() - db.Model(map[string]interface{}{ + db.Table("user") + db.Model(&map[string]interface{}{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -482,7 +479,8 @@ func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWherePar // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where UpdateOrderByIdDescWhereParameter) error { var db = m.db.NewUpdate() - db.Model(map[string]interface{}{ + db.Table("user") + db.Model(&map[string]interface{}{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -498,10 +496,11 @@ func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where U } // UpdateOrderByIdDescLimitCount is generated from sql: -// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; +// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc limit ?; func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter) error { var db = m.db.NewUpdate() - db.Model(map[string]interface{}{ + db.Table("user") + db.Model(&map[string]interface{}{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, diff --git a/example/example.sql b/example/example.sql index 71eb623..79c0d7a 100644 --- a/example/example.sql +++ b/example/example.sql @@ -42,7 +42,7 @@ update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickn -- test case: update one with order by desc, limit count clause. -- fn: UpdateOrderByIdDescLimitCount -update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; +update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc limit ?; -- operation: read diff --git a/example/example_test/bun/mock.go b/example/example_test/bun/mock.go new file mode 100644 index 0000000..3687a30 --- /dev/null +++ b/example/example_test/bun/mock.go @@ -0,0 +1,23 @@ +package sql + +import ( + "time" + + model "github.com/anqiansong/sqlgen/example/bun" + uuid "github.com/satori/go.uuid" +) + +func mustMockUser() *model.User { + uid := uuid.NewV4().String() + now := time.Now() + return &model.User{ + Name: uid, + Password: "bar", + Mobile: uid, + Gender: "male", + Nickname: "test", + Type: 1, + CreateAt: now, + UpdateAt: now, + } +} diff --git a/example/example_test/bun/user_model.gen_test.go b/example/example_test/bun/user_model.gen_test.go index c0ac838..a390b56 100644 --- a/example/example_test/bun/user_model.gen_test.go +++ b/example/example_test/bun/user_model.gen_test.go @@ -1 +1,560 @@ -package bun +package sql + +import ( + "context" + "database/sql" + "fmt" + "log" + "sort" + "testing" + "time" + + model "github.com/anqiansong/sqlgen/example/bun" + _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/mysqldialect" + "github.com/uptrace/bun/extra/bundebug" +) + +var ( + um *model.UserModel + ctx = context.TODO() + db *bun.DB +) + +func TestMain(m *testing.M) { + conn, err := sql.Open("mysql", "root:mysqlpw@tcp(127.0.0.1:55000)/test?charset=utf8mb4&parseTime=true&loc=Local") + if err != nil { + log.Fatalln(err) + } + + db = bun.NewDB(conn, mysqldialect.New()) + err = db.Ping() + if err != nil { + fmt.Println("ping error") + return + } + + db.AddQueryHook(bundebug.NewQueryHook( + bundebug.WithVerbose(true), + )) + um = model.NewUserModel(db) + m.Run() +} + +func mustInitDB(db *bun.DB) { + err := db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.ExecContext(ctx, `SET SESSION sql_mode=(SELECT REPLACE(@@sql_mode,'ONLY_FULL_GROUP_BY,',''))`) + if err != nil { + return err + } + _, err = tx.ExecContext(ctx, `truncate table user`) + if err != nil { + log.Fatalln(err) + } + _, err = tx.ExecContext(ctx, `alter table user auto_increment=1`) + return err + }) + if err != nil { + log.Fatalln(err) + } +} + +func TestCreate(t *testing.T) { + t.Run("emptyData", initAndRun(func(t *testing.T) { + err := um.Create(ctx) + assert.Contains(t, err.Error(), "empty") + })) + + t.Run("createOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + assert.Equal(t, uint64(1), mockUser.Id) + })) + t.Run("createMultiple", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 1; i <= 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + for idx, item := range list { + assert.Equal(t, uint64(idx+1), item.Id) + } + })) +} + +func TestFindOne(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: 1}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("findOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOneGroupByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByNameHavingName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: "foo"}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, sql.ErrNoRows) + + mockUser := mustMockUser() + err = um.Create(ctx, mockUser) + assert.NoError(t, err) + _, err = um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, sql.ErrNoRows) + + })) + + t.Run("FindOneGroupByNameHavingName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: mockUser.Name, + }) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindAll(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindAll", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list, actual) + })) +} + +func TestFindLimit(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 1, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimit", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) + + t.Run("FindLimitOffset1", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindGroupHavingLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindGroupHavingOrderDescLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingOrderDescLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + sort.Slice(list, func(i, j int) bool { + return list[i].Id > list[j].Id + }) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindOnePart(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 1}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOnePart", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, mockUser.Name, actual.Name) + assert.Equal(t, mockUser.Password, actual.Password) + assert.Equal(t, mockUser.Mobile, actual.Mobile) + })) +} + +func TestFindAllCount(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindAllCountWhere(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCountWhere", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindMaxID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + maxID, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), maxID.MaxID.Int64) + })) + + t.Run("FindMaxID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(5), actual.MaxID.Int64) + })) +} + +func TestFindMinID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), minID.MinID.Int64) + })) + + t.Run("FindMinID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.MinID.Int64) + })) +} + +func TestFindAvgID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "0", minID.AvgID.Decimal.String()) + })) + + t.Run("FindAvgID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "3", actual.AvgID.Decimal.String()) + })) +} + +func TestUpdate(t *testing.T) { + t.Run("Update", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.Update(ctx, newUser, model.UpdateWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDesc(ctx, newUser, model.UpdateOrderByIdDescWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDescLimitCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDescLimitCount(ctx, newUser, model.UpdateOrderByIdDescLimitCountWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) +} + +func TestDelete(t *testing.T) { + t.Run("DeleteOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOne(ctx, model.DeleteOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneByName(ctx, model.DeleteOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDAsc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDAsc(ctx, model.DeleteOneOrderByIDAscWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDDesc(ctx, model.DeleteOneOrderByIDDescWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDDescLimitCount", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + + err = um.DeleteOneOrderByIDDescLimitCount(ctx, model.DeleteOneOrderByIDDescLimitCountWhereParameter{NameEqual: list[0].Name}) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list[1:], actual) + })) +} + +func assertUserEqual(t *testing.T, expected, actual *model.User) { + now := time.Now() + expected.CreateAt = now + expected.UpdateAt = now + actual.CreateAt = now + actual.UpdateAt = now + assert.Equal(t, *expected, *actual) +} + +func assertUsersEqual(t *testing.T, expected, actual []*model.User) { + assert.Equal(t, len(expected), len(actual)) + for idx, expectedItem := range expected { + actual := actual[idx] + assertUserEqual(t, expectedItem, actual) + } +} + +func initAndRun(f func(t *testing.T)) func(t *testing.T) { + mustInitDB(db) + return func(t *testing.T) { + f(t) + } +} diff --git a/example/example_test/gorm/mock.go b/example/example_test/gorm/mock.go new file mode 100644 index 0000000..8bd62b8 --- /dev/null +++ b/example/example_test/gorm/mock.go @@ -0,0 +1,23 @@ +package sql + +import ( + "time" + + model "github.com/anqiansong/sqlgen/example/gorm" + uuid "github.com/satori/go.uuid" +) + +func mustMockUser() *model.User { + uid := uuid.NewV4().String() + now := time.Now() + return &model.User{ + Name: uid, + Password: "bar", + Mobile: uid, + Gender: "male", + Nickname: "test", + Type: 1, + CreateAt: now, + UpdateAt: now, + } +} diff --git a/example/example_test/gorm/user_model.gen_test.go b/example/example_test/gorm/user_model.gen_test.go index 4506bcf..96aa3d1 100644 --- a/example/example_test/gorm/user_model.gen_test.go +++ b/example/example_test/gorm/user_model.gen_test.go @@ -1 +1,545 @@ -package gorm +package sql + +import ( + "context" + "fmt" + "log" + "sort" + "testing" + "time" + + model "github.com/anqiansong/sqlgen/example/gorm" + "github.com/stretchr/testify/assert" + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +var ( + um *model.UserModel + ctx = context.TODO() + db *gorm.DB +) + +func TestMain(m *testing.M) { + var err error + conn := mysql.Open("root:mysqlpw@tcp(127.0.0.1:55000)/test?charset=utf8mb4&parseTime=true&loc=Local") + db, err = gorm.Open(conn, &gorm.Config{}) + if err != nil { + fmt.Println("gorm open error:", err) + return + } + + um = model.NewUserModel(db) + m.Run() +} + +func mustInitDB(db *gorm.DB) { + err := db.Transaction(func(tx *gorm.DB) error { + tx.Exec(`SET SESSION sql_mode=(SELECT REPLACE(@@sql_mode,'ONLY_FULL_GROUP_BY,',''))`) + tx.Exec(`truncate table user`) + tx.Exec(`alter table user auto_increment=1`) + return nil + }) + if err != nil { + log.Fatal(err) + } +} + +func TestCreate(t *testing.T) { + t.Run("emptyData", initAndRun(func(t *testing.T) { + err := um.Create(ctx) + assert.Contains(t, err.Error(), "empty") + })) + + t.Run("createOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + assert.Equal(t, uint64(1), mockUser.Id) + })) + + t.Run("createMultiple", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 1; i <= 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + for idx, item := range list { + assert.Equal(t, uint64(idx+1), item.Id) + } + })) +} + +func TestFindOne(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: 1}) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + })) + + t.Run("findOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + })) + + t.Run("FindOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + })) + + t.Run("FindOneGroupByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByNameHavingName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: "foo"}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + + mockUser := mustMockUser() + err = um.Create(ctx, mockUser) + assert.NoError(t, err) + _, err = um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + + })) + + t.Run("FindOneGroupByNameHavingName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: mockUser.Name, + }) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindAll(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindAll", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list, actual) + })) +} + +func TestFindLimit(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 1, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimit", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) + + t.Run("FindLimitOffset1", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindGroupHavingLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindGroupHavingOrderDescLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingOrderDescLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + sort.Slice(list, func(i, j int) bool { + return list[i].Id > list[j].Id + }) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindOnePart(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 1}) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + })) + + t.Run("FindOnePart", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, mockUser.Name, actual.Name) + assert.Equal(t, mockUser.Password, actual.Password) + assert.Equal(t, mockUser.Mobile, actual.Mobile) + })) +} + +func TestFindAllCount(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindAllCountWhere(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCountWhere", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindMaxID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + maxID, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), maxID.MaxID.Int64) + })) + + t.Run("FindMaxID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(5), actual.MaxID.Int64) + })) +} + +func TestFindMinID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), minID.MinID.Int64) + })) + + t.Run("FindMinID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.MinID.Int64) + })) +} + +func TestFindAvgID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "0", minID.AvgID.Decimal.String()) + })) + + t.Run("FindAvgID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "3", actual.AvgID.Decimal.String()) + })) +} + +func TestUpdate(t *testing.T) { + t.Run("Update", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.Update(ctx, newUser, model.UpdateWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDesc(ctx, newUser, model.UpdateOrderByIdDescWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDescLimitCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDescLimitCount(ctx, newUser, model.UpdateOrderByIdDescLimitCountWhereParameter{IdEqual: mockUser.Id}, model.UpdateOrderByIdDescLimitCountLimitParameter{Count: 1}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) +} + +func TestDelete(t *testing.T) { + t.Run("DeleteOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOne(ctx, model.DeleteOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + })) + + t.Run("DeleteOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneByName(ctx, model.DeleteOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + })) + + t.Run("DeleteOneOrderByIDAsc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDAsc(ctx, model.DeleteOneOrderByIDAscWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + })) + + t.Run("DeleteOneOrderByIDDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDDesc(ctx, model.DeleteOneOrderByIDDescWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + })) + + t.Run("DeleteOneOrderByIDDescLimitCount", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + + err = um.DeleteOneOrderByIDDescLimitCount(ctx, model.DeleteOneOrderByIDDescLimitCountWhereParameter{NameEqual: list[0].Name}, model.DeleteOneOrderByIDDescLimitCountLimitParameter{Count: 1}) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list[1:], actual) + })) +} + +func assertUserEqual(t *testing.T, expected, actual *model.User) { + now := time.Now() + expected.CreateAt = now + expected.UpdateAt = now + actual.CreateAt = now + actual.UpdateAt = now + assert.Equal(t, *expected, *actual) +} + +func assertUsersEqual(t *testing.T, expected, actual []*model.User) { + assert.Equal(t, len(expected), len(actual)) + for idx, expectedItem := range expected { + actual := actual[idx] + assertUserEqual(t, expectedItem, actual) + } +} + +func initAndRun(f func(t *testing.T)) func(t *testing.T) { + mustInitDB(db) + return func(t *testing.T) { + f(t) + } +} diff --git a/example/example_test/sql/user_model.gen_test.go b/example/example_test/sql/user_model.gen_test.go index 573be02..d793a0c 100644 --- a/example/example_test/sql/user_model.gen_test.go +++ b/example/example_test/sql/user_model.gen_test.go @@ -3,6 +3,7 @@ package sql import ( "context" "database/sql" + "fmt" "log" "sort" "testing" @@ -26,6 +27,12 @@ func TestMain(m *testing.M) { log.Fatalln(err) } + err = db.Ping() + if err != nil { + fmt.Println("ping error") + return + } + um = model.NewUserModel(db, getScanner()) m.Run() } @@ -340,7 +347,7 @@ func TestFindAllCount(t *testing.T) { t.Run("noRows", initAndRun(func(t *testing.T) { countID, err := um.FindAllCount(ctx) assert.NoError(t, err) - assert.Equal(t, uint64(0), countID.CountID) + assert.Equal(t, int64(0), countID.CountID.Int64) })) t.Run("FindAllCount", initAndRun(func(t *testing.T) { @@ -349,7 +356,7 @@ func TestFindAllCount(t *testing.T) { assert.NoError(t, err) actual, err := um.FindAllCount(ctx) assert.NoError(t, err) - assert.Equal(t, uint64(1), actual.CountID) + assert.Equal(t, int64(1), actual.CountID.Int64) })) } @@ -430,6 +437,107 @@ func TestFindAvgID(t *testing.T) { })) } +func TestUpdate(t *testing.T) { + t.Run("Update", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.Update(ctx, newUser, model.UpdateWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDesc(ctx, newUser, model.UpdateOrderByIdDescWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDescLimitCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDescLimitCount(ctx, newUser, model.UpdateOrderByIdDescLimitCountWhereParameter{IdEqual: mockUser.Id}, model.UpdateOrderByIdDescLimitCountLimitParameter{Count: 1}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) +} + +func TestDelete(t *testing.T) { + t.Run("DeleteOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOne(ctx, model.DeleteOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneByName(ctx, model.DeleteOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDAsc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDAsc(ctx, model.DeleteOneOrderByIDAscWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDDesc(ctx, model.DeleteOneOrderByIDDescWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDDescLimitCount", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + + err = um.DeleteOneOrderByIDDescLimitCount(ctx, model.DeleteOneOrderByIDDescLimitCountWhereParameter{NameEqual: list[0].Name}, model.DeleteOneOrderByIDDescLimitCountLimitParameter{Count: 1}) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list[1:], actual) + })) +} + func assertUserEqual(t *testing.T, expected, actual *model.User) { now := time.Now() expected.CreateAt = now diff --git a/example/go.mod b/example/go.mod index 8e40bee..937ae7c 100644 --- a/example/go.mod +++ b/example/go.mod @@ -7,9 +7,12 @@ require ( github.com/iancoleman/strcase v0.2.0 github.com/jmoiron/sqlx v1.3.5 github.com/satori/go.uuid v1.2.0 - github.com/shopspring/decimal v1.3.1 + github.com/shopspring/decimal v1.2.0 github.com/stretchr/testify v1.7.0 github.com/uptrace/bun v1.1.7 + github.com/uptrace/bun/dialect/mysqldialect v1.1.7 + github.com/uptrace/bun/extra/bundebug v1.1.7 + gorm.io/driver/mysql v1.3.6 gorm.io/gorm v1.23.8 xorm.io/builder v0.3.12 xorm.io/xorm v1.3.1 @@ -17,11 +20,14 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fatih/color v1.13.0 // indirect github.com/goccy/go-json v0.8.1 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/jinzhu/inflection v1.0.0 // indirect - github.com/jinzhu/now v1.1.4 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/mattn/go-colorable v0.1.12 // indirect + github.com/mattn/go-isatty v0.0.14 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect @@ -29,6 +35,7 @@ require ( github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect - golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6 // indirect + golang.org/x/mod v0.5.1 // indirect + golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect ) diff --git a/example/go.sum b/example/go.sum index db2c5c9..79fe62d 100644 --- a/example/go.sum +++ b/example/go.sum @@ -58,6 +58,8 @@ github.com/envoyproxy/go-control-plane v0.6.9/go.mod h1:SBwIajubJHhxtWwsL9s8ss4s github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= +github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= @@ -187,8 +189,9 @@ github.com/jackc/puddle v1.1.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dv github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= @@ -226,14 +229,18 @@ github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaO github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= -github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA= @@ -323,9 +330,8 @@ github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdh github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= -github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= -github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= @@ -356,6 +362,10 @@ github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYm github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= github.com/uptrace/bun v1.1.7 h1:biOoh5dov69hQPBlaRsXSHoEOIEnCxFzQvUmbscSNJI= github.com/uptrace/bun v1.1.7/go.mod h1:Z2Pd3cRvNKbrYuL6Gp1XGjA9QEYz+rDz5KkEi9MZLnQ= +github.com/uptrace/bun/dialect/mysqldialect v1.1.7 h1:eMDtsuu5BRuh0P2l0/j0Qv5UBmcqJE0u3F8Zy//klNM= +github.com/uptrace/bun/dialect/mysqldialect v1.1.7/go.mod h1:cCSZH3IULSGaG76Z96mAC7O74MeIYGfDX7CWGanGc0s= +github.com/uptrace/bun/extra/bundebug v1.1.7 h1:YbW7i9pUfPJMzclSzdHslIvAAR0WO9dW34ctL1Xh+UM= +github.com/uptrace/bun/extra/bundebug v1.1.7/go.mod h1:WoBnTrBG9CXITZUw+UfF+DYjWi71boo8FKZGuS5qDzA= github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= @@ -405,8 +415,9 @@ golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHl golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.5.1 h1:OJxoQ/rynoF0dcCdI7cLPktw/hR2cueqYfjm43oqK38= +golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -460,10 +471,12 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201126233918-771906719818/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210902050250-f475640dd07b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6 h1:nonptSpoQ4vQjyraW20DXPAglgQfVnM9ZC6MmNLMR60= -golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -537,6 +550,8 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.3.6 h1:BhX1Y/RyALb+T9bZ3t07wLnPZBukt+IRkMn8UZSNbGM= +gorm.io/driver/mysql v1.3.6/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c= gorm.io/gorm v1.23.8 h1:h8sGJ+biDgBA1AD1Ha9gFCx7h8npU7AsLdlkX0n2TpE= gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/example/gorm/user_model.gen.go b/example/gorm/user_model.gen.go index 32c3ea8..fe88929 100644 --- a/example/gorm/user_model.gen.go +++ b/example/gorm/user_model.gen.go @@ -15,7 +15,7 @@ import ( // UserModel represents a user model. type UserModel struct { - db gorm.DB + db *gorm.DB } // User represents a user struct data. @@ -36,16 +36,22 @@ type FindOneWhereParameter struct { IdEqual uint64 } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindOneByNameWhereParameter is a where parameter structure. type FindOneByNameWhereParameter struct { NameEqual string } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindOneGroupByNameWhereParameter is a where parameter structure. type FindOneGroupByNameWhereParameter struct { NameEqual string } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindOneGroupByNameHavingNameWhereParameter is a where parameter structure. type FindOneGroupByNameHavingNameWhereParameter struct { NameEqual string @@ -56,6 +62,10 @@ type FindOneGroupByNameHavingNameHavingParameter struct { NameEqual string } +// TableName returns the table name. it implemented by gorm.Tabler. + +// TableName returns the table name. it implemented by gorm.Tabler. + // FindLimitWhereParameter is a where parameter structure. type FindLimitWhereParameter struct { IdGT uint64 @@ -66,12 +76,16 @@ type FindLimitLimitParameter struct { Count int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindLimitOffsetLimitParameter is a limit parameter structure. type FindLimitOffsetLimitParameter struct { Count int Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindGroupLimitOffsetWhereParameter is a where parameter structure. type FindGroupLimitOffsetWhereParameter struct { IdGT uint64 @@ -83,6 +97,8 @@ type FindGroupLimitOffsetLimitParameter struct { Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindGroupHavingLimitOffsetWhereParameter is a where parameter structure. type FindGroupHavingLimitOffsetWhereParameter struct { IdGT uint64 @@ -99,6 +115,8 @@ type FindGroupHavingLimitOffsetLimitParameter struct { Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindGroupHavingOrderAscLimitOffsetWhereParameter is a where parameter structure. type FindGroupHavingOrderAscLimitOffsetWhereParameter struct { IdGT uint64 @@ -115,6 +133,8 @@ type FindGroupHavingOrderAscLimitOffsetLimitParameter struct { Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindGroupHavingOrderDescLimitOffsetWhereParameter is a where parameter structure. type FindGroupHavingOrderDescLimitOffsetWhereParameter struct { IdGT uint64 @@ -131,16 +151,25 @@ type FindGroupHavingOrderDescLimitOffsetLimitParameter struct { Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindOnePartWhereParameter is a where parameter structure. type FindOnePartWhereParameter struct { IdGT uint64 } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindAllCountResult is a find all count result. type FindAllCountResult struct { CountID sql.NullInt64 `gorm:"column:countID" json:"countID"` } +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindAllCountResult) TableName() string { + return "user" +} + // FindAllCountWhereWhereParameter is a where parameter structure. type FindAllCountWhereWhereParameter struct { IdGT uint64 @@ -151,21 +180,41 @@ type FindAllCountWhereResult struct { CountID sql.NullInt64 `gorm:"column:countID" json:"countID"` } +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindAllCountWhereResult) TableName() string { + return "user" +} + // FindMaxIDResult is a find max id result. type FindMaxIDResult struct { MaxID sql.NullInt64 `gorm:"column:maxID" json:"maxID"` } +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindMaxIDResult) TableName() string { + return "user" +} + // FindMinIDResult is a find min id result. type FindMinIDResult struct { MinID sql.NullInt64 `gorm:"column:minID" json:"minID"` } +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindMinIDResult) TableName() string { + return "user" +} + // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { AvgID decimal.NullDecimal `gorm:"column:avgID" json:"avgID"` } +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindAvgIDResult) TableName() string { + return "user" +} + // UpdateWhereParameter is a where parameter structure. type UpdateWhereParameter struct { IdEqual uint64 @@ -181,6 +230,11 @@ type UpdateOrderByIdDescLimitCountWhereParameter struct { IdEqual uint64 } +// UpdateOrderByIdDescLimitCountLimitParameter is a limit parameter structure. +type UpdateOrderByIdDescLimitCountLimitParameter struct { + Count int +} + // DeleteOneWhereParameter is a where parameter structure. type DeleteOneWhereParameter struct { IdEqual uint64 @@ -217,7 +271,7 @@ func (User) TableName() string { } // NewUserModel returns a new user model. -func NewUserModel(db gorm.DB) *UserModel { +func NewUserModel(db *gorm.DB) *UserModel { return &UserModel{db: db} } @@ -228,11 +282,7 @@ func (m *UserModel) Create(ctx context.Context, data ...*User) error { } db := m.db.WithContext(ctx) - var list []User - for _, v := range data { - list = append(list, *v) - } - + list := data[:] return db.Create(&list).Error } @@ -241,10 +291,10 @@ func (m *UserModel) Create(ctx context.Context, data ...*User) error { func (m *UserModel) FindOne(ctx context.Context, where FindOneWhereParameter) (*User, error) { var result = new(User) var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`id = ?`, where.IdEqual) - db.Limit(1) - db.Find(result) + db = db.Select(`*`) + db = db.Where(`id = ?`, where.IdEqual) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -253,10 +303,10 @@ func (m *UserModel) FindOne(ctx context.Context, where FindOneWhereParameter) (* func (m *UserModel) FindOneByName(ctx context.Context, where FindOneByNameWhereParameter) (*User, error) { var result = new(User) var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`name = ?`, where.NameEqual) - db.Limit(1) - db.Find(result) + db = db.Select(`*`) + db = db.Where(`name = ?`, where.NameEqual) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -265,11 +315,11 @@ func (m *UserModel) FindOneByName(ctx context.Context, where FindOneByNameWhereP func (m *UserModel) FindOneGroupByName(ctx context.Context, where FindOneGroupByNameWhereParameter) (*User, error) { var result = new(User) var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`name = ?`, where.NameEqual) - db.Group(`name`) - db.Limit(1) - db.Find(result) + db = db.Select(`*`) + db = db.Where(`name = ?`, where.NameEqual) + db = db.Group(`name`) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -278,12 +328,12 @@ func (m *UserModel) FindOneGroupByName(ctx context.Context, where FindOneGroupBy func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where FindOneGroupByNameHavingNameWhereParameter, having FindOneGroupByNameHavingNameHavingParameter) (*User, error) { var result = new(User) var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`name = ?`, where.NameEqual) - db.Group(`name`) - db.Having(`name = ?`, having.NameEqual) - db.Limit(1) - db.Find(result) + db = db.Select(`*`) + db = db.Where(`name = ?`, where.NameEqual) + db = db.Group(`name`) + db = db.Having(`name = ?`, having.NameEqual) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -292,8 +342,8 @@ func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where Find func (m *UserModel) FindAll(ctx context.Context) ([]*User, error) { var result []*User var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Find(&result) + db = db.Select(`*`) + db = db.Find(&result) return result, db.Error } @@ -302,10 +352,10 @@ func (m *UserModel) FindAll(ctx context.Context) ([]*User, error) { func (m *UserModel) FindLimit(ctx context.Context, where FindLimitWhereParameter, limit FindLimitLimitParameter) ([]*User, error) { var result []*User var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`id > ?`, where.IdGT) - db.Limit(limit.Count) - db.Find(&result) + db = db.Select(`*`) + db = db.Where(`id > ?`, where.IdGT) + db = db.Limit(limit.Count) + db = db.Find(&result) return result, db.Error } @@ -314,9 +364,9 @@ func (m *UserModel) FindLimit(ctx context.Context, where FindLimitWhereParameter func (m *UserModel) FindLimitOffset(ctx context.Context, limit FindLimitOffsetLimitParameter) ([]*User, error) { var result []*User var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Offset(limit.Offset).Limit(limit.Count) - db.Find(&result) + db = db.Select(`*`) + db = db.Offset(limit.Offset).Limit(limit.Count) + db = db.Find(&result) return result, db.Error } @@ -325,11 +375,11 @@ func (m *UserModel) FindLimitOffset(ctx context.Context, limit FindLimitOffsetLi func (m *UserModel) FindGroupLimitOffset(ctx context.Context, where FindGroupLimitOffsetWhereParameter, limit FindGroupLimitOffsetLimitParameter) ([]*User, error) { var result []*User var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`id > ?`, where.IdGT) - db.Group(`name`) - db.Offset(limit.Offset).Limit(limit.Count) - db.Find(&result) + db = db.Select(`*`) + db = db.Where(`id > ?`, where.IdGT) + db = db.Group(`name`) + db = db.Offset(limit.Offset).Limit(limit.Count) + db = db.Find(&result) return result, db.Error } @@ -338,12 +388,12 @@ func (m *UserModel) FindGroupLimitOffset(ctx context.Context, where FindGroupLim func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGroupHavingLimitOffsetWhereParameter, having FindGroupHavingLimitOffsetHavingParameter, limit FindGroupHavingLimitOffsetLimitParameter) ([]*User, error) { var result []*User var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`id > ?`, where.IdGT) - db.Group(`name`) - db.Having(`id > ?`, having.IdGT) - db.Offset(limit.Offset).Limit(limit.Count) - db.Find(&result) + db = db.Select(`*`) + db = db.Where(`id > ?`, where.IdGT) + db = db.Group(`name`) + db = db.Having(`id > ?`, having.IdGT) + db = db.Offset(limit.Offset).Limit(limit.Count) + db = db.Find(&result) return result, db.Error } @@ -352,13 +402,13 @@ func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGr func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, where FindGroupHavingOrderAscLimitOffsetWhereParameter, having FindGroupHavingOrderAscLimitOffsetHavingParameter, limit FindGroupHavingOrderAscLimitOffsetLimitParameter) ([]*User, error) { var result []*User var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`id > ?`, where.IdGT) - db.Group(`name`) - db.Having(`id > ?`, having.IdGT) - db.Order(`id`) - db.Offset(limit.Offset).Limit(limit.Count) - db.Find(&result) + db = db.Select(`*`) + db = db.Where(`id > ?`, where.IdGT) + db = db.Group(`name`) + db = db.Having(`id > ?`, having.IdGT) + db = db.Order(`id`) + db = db.Offset(limit.Offset).Limit(limit.Count) + db = db.Find(&result) return result, db.Error } @@ -367,13 +417,13 @@ func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, wher func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, where FindGroupHavingOrderDescLimitOffsetWhereParameter, having FindGroupHavingOrderDescLimitOffsetHavingParameter, limit FindGroupHavingOrderDescLimitOffsetLimitParameter) ([]*User, error) { var result []*User var db = m.db.WithContext(ctx) - db.Select(`*`) - db.Where(`id > ?`, where.IdGT) - db.Group(`name`) - db.Having(`id > ?`, having.IdGT) - db.Order(`id desc`) - db.Offset(limit.Offset).Limit(limit.Count) - db.Find(&result) + db = db.Select(`*`) + db = db.Where(`id > ?`, where.IdGT) + db = db.Group(`name`) + db = db.Having(`id > ?`, having.IdGT) + db = db.Order(`id desc`) + db = db.Offset(limit.Offset).Limit(limit.Count) + db = db.Find(&result) return result, db.Error } @@ -382,10 +432,10 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe func (m *UserModel) FindOnePart(ctx context.Context, where FindOnePartWhereParameter) (*User, error) { var result = new(User) var db = m.db.WithContext(ctx) - db.Select(`name, password, mobile`) - db.Where(`id > ?`, where.IdGT) - db.Limit(1) - db.Find(result) + db = db.Select(`name, password, mobile`) + db = db.Where(`id > ?`, where.IdGT) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -394,9 +444,9 @@ func (m *UserModel) FindOnePart(ctx context.Context, where FindOnePartWhereParam func (m *UserModel) FindAllCount(ctx context.Context) (*FindAllCountResult, error) { var result = new(FindAllCountResult) var db = m.db.WithContext(ctx) - db.Select(`count(id) AS countID`) - db.Limit(1) - db.Find(result) + db = db.Select(`count(id) AS countID`) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -405,10 +455,10 @@ func (m *UserModel) FindAllCount(ctx context.Context) (*FindAllCountResult, erro func (m *UserModel) FindAllCountWhere(ctx context.Context, where FindAllCountWhereWhereParameter) (*FindAllCountWhereResult, error) { var result = new(FindAllCountWhereResult) var db = m.db.WithContext(ctx) - db.Select(`count(id) AS countID`) - db.Where(`id > ?`, where.IdGT) - db.Limit(1) - db.Find(result) + db = db.Select(`count(id) AS countID`) + db = db.Where(`id > ?`, where.IdGT) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -417,9 +467,9 @@ func (m *UserModel) FindAllCountWhere(ctx context.Context, where FindAllCountWhe func (m *UserModel) FindMaxID(ctx context.Context) (*FindMaxIDResult, error) { var result = new(FindMaxIDResult) var db = m.db.WithContext(ctx) - db.Select(`max(id) AS maxID`) - db.Limit(1) - db.Find(result) + db = db.Select(`max(id) AS maxID`) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -428,9 +478,9 @@ func (m *UserModel) FindMaxID(ctx context.Context) (*FindMaxIDResult, error) { func (m *UserModel) FindMinID(ctx context.Context) (*FindMinIDResult, error) { var result = new(FindMinIDResult) var db = m.db.WithContext(ctx) - db.Select(`min(id) AS minID`) - db.Limit(1) - db.Find(result) + db = db.Select(`min(id) AS minID`) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -439,9 +489,9 @@ func (m *UserModel) FindMinID(ctx context.Context) (*FindMinIDResult, error) { func (m *UserModel) FindAvgID(ctx context.Context) (*FindAvgIDResult, error) { var result = new(FindAvgIDResult) var db = m.db.WithContext(ctx) - db.Select(`avg(id) AS avgID`) - db.Limit(1) - db.Find(result) + db = db.Select(`avg(id) AS avgID`) + db = db.Limit(1) + db = db.Take(result) return result, db.Error } @@ -449,9 +499,9 @@ func (m *UserModel) FindAvgID(ctx context.Context) (*FindAvgIDResult, error) { // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ?; func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWhereParameter) error { var db = m.db.WithContext(ctx) - db.Model(&User{}) - db.Where(`id = ?`, where.IdEqual) - db.Updates(map[string]interface{}{ + db = db.Model(&User{}) + db = db.Where(`id = ?`, where.IdEqual) + db = db.Updates(map[string]interface{}{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -468,10 +518,10 @@ func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWherePar // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where UpdateOrderByIdDescWhereParameter) error { var db = m.db.WithContext(ctx) - db.Model(&User{}) - db.Where(`id = ?`, where.IdEqual) - db.Order(`id desc`) - db.Updates(map[string]interface{}{ + db = db.Model(&User{}) + db = db.Where(`id = ?`, where.IdEqual) + db = db.Order(`id desc`) + db = db.Updates(map[string]interface{}{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -485,13 +535,14 @@ func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where U } // UpdateOrderByIdDescLimitCount is generated from sql: -// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; -func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter) error { +// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc limit ?; +func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter, limit UpdateOrderByIdDescLimitCountLimitParameter) error { var db = m.db.WithContext(ctx) - db.Model(&User{}) - db.Where(`id = ?`, where.IdEqual) - db.Order(`id desc`) - db.Updates(map[string]interface{}{ + db = db.Model(&User{}) + db = db.Where(`id = ?`, where.IdEqual) + db = db.Order(`id desc`) + db = db.Limit(limit.Count) + db = db.Updates(map[string]interface{}{ "name": data.Name, "password": data.Password, "mobile": data.Mobile, @@ -508,8 +559,8 @@ func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *Use // delete from `user` where `id` = ?; func (m *UserModel) DeleteOne(ctx context.Context, where DeleteOneWhereParameter) error { var db = m.db.WithContext(ctx) - db.Where(`id = ?`, where.IdEqual) - db.Delete(&User{}) + db = db.Where(`id = ?`, where.IdEqual) + db = db.Delete(&User{}) return db.Error } @@ -517,8 +568,8 @@ func (m *UserModel) DeleteOne(ctx context.Context, where DeleteOneWhereParameter // delete from `user` where `name` = ?; func (m *UserModel) DeleteOneByName(ctx context.Context, where DeleteOneByNameWhereParameter) error { var db = m.db.WithContext(ctx) - db.Where(`name = ?`, where.NameEqual) - db.Delete(&User{}) + db = db.Where(`name = ?`, where.NameEqual) + db = db.Delete(&User{}) return db.Error } @@ -526,9 +577,9 @@ func (m *UserModel) DeleteOneByName(ctx context.Context, where DeleteOneByNameWh // delete from `user` where `name` = ? order by id; func (m *UserModel) DeleteOneOrderByIDAsc(ctx context.Context, where DeleteOneOrderByIDAscWhereParameter) error { var db = m.db.WithContext(ctx) - db.Where(`name = ?`, where.NameEqual) - db.Order(`id`) - db.Delete(&User{}) + db = db.Where(`name = ?`, where.NameEqual) + db = db.Order(`id`) + db = db.Delete(&User{}) return db.Error } @@ -536,9 +587,9 @@ func (m *UserModel) DeleteOneOrderByIDAsc(ctx context.Context, where DeleteOneOr // delete from `user` where `name` = ? order by id desc; func (m *UserModel) DeleteOneOrderByIDDesc(ctx context.Context, where DeleteOneOrderByIDDescWhereParameter) error { var db = m.db.WithContext(ctx) - db.Where(`name = ?`, where.NameEqual) - db.Order(`id desc`) - db.Delete(&User{}) + db = db.Where(`name = ?`, where.NameEqual) + db = db.Order(`id desc`) + db = db.Delete(&User{}) return db.Error } @@ -546,9 +597,9 @@ func (m *UserModel) DeleteOneOrderByIDDesc(ctx context.Context, where DeleteOneO // delete from `user` where `name` = ? order by id desc limit ?; func (m *UserModel) DeleteOneOrderByIDDescLimitCount(ctx context.Context, where DeleteOneOrderByIDDescLimitCountWhereParameter, limit DeleteOneOrderByIDDescLimitCountLimitParameter) error { var db = m.db.WithContext(ctx) - db.Where(`name = ?`, where.NameEqual) - db.Order(`id desc`) - db.Limit(limit.Count) - db.Delete(&User{}) + db = db.Where(`name = ?`, where.NameEqual) + db = db.Order(`id desc`) + db = db.Limit(limit.Count) + db = db.Delete(&User{}) return db.Error } diff --git a/example/sql/user_model.gen.go b/example/sql/user_model.gen.go index c454cab..81cbcc8 100644 --- a/example/sql/user_model.gen.go +++ b/example/sql/user_model.gen.go @@ -181,6 +181,11 @@ type UpdateOrderByIdDescLimitCountWhereParameter struct { IdEqual uint64 } +// UpdateOrderByIdDescLimitCountLimitParameter is a limit parameter structure. +type UpdateOrderByIdDescLimitCountLimitParameter struct { + Count int +} + // DeleteOneWhereParameter is a where parameter structure. type DeleteOneWhereParameter struct { IdEqual uint64 @@ -754,8 +759,8 @@ func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where U } // UpdateOrderByIdDescLimitCount is generated from sql: -// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; -func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter) error { +// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc limit ?; +func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter, limit UpdateOrderByIdDescLimitCountLimitParameter) error { b := builder.MySQL() b.Update(builder.Eq{ "name": data.Name, @@ -770,6 +775,7 @@ func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *Use b.From("`user`") b.Where(builder.Expr(`id = ?`, where.IdEqual)) b.OrderBy(`id desc`) + b.Limit(limit.Count) query, args, err := b.ToSQL() if err != nil { return err diff --git a/example/sqlx/user_model.gen.go b/example/sqlx/user_model.gen.go index ae6e2f7..dc8d0a5 100644 --- a/example/sqlx/user_model.gen.go +++ b/example/sqlx/user_model.gen.go @@ -181,6 +181,11 @@ type UpdateOrderByIdDescLimitCountWhereParameter struct { IdEqual uint64 } +// UpdateOrderByIdDescLimitCountLimitParameter is a limit parameter structure. +type UpdateOrderByIdDescLimitCountLimitParameter struct { + Count int +} + // DeleteOneWhereParameter is a where parameter structure. type DeleteOneWhereParameter struct { IdEqual uint64 @@ -865,8 +870,8 @@ func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where U } // UpdateOrderByIdDescLimitCount is generated from sql: -// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; -func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter) error { +// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc limit ?; +func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter, limit UpdateOrderByIdDescLimitCountLimitParameter) error { b := builder.MySQL() b.Update(builder.Eq{ "name": data.Name, @@ -881,6 +886,7 @@ func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *Use b.From("`user`") b.Where(builder.Expr(`id = ?`, where.IdEqual)) b.OrderBy(`id desc`) + b.Limit(limit.Count) query, args, err := b.ToSQL() if err != nil { return err diff --git a/example/xorm/user_model.gen.go b/example/xorm/user_model.gen.go index 3210717..5b35dc5 100644 --- a/example/xorm/user_model.gen.go +++ b/example/xorm/user_model.gen.go @@ -181,6 +181,11 @@ type UpdateOrderByIdDescLimitCountWhereParameter struct { IdEqual uint64 } +// UpdateOrderByIdDescLimitCountLimitParameter is a limit parameter structure. +type UpdateOrderByIdDescLimitCountLimitParameter struct { + Count int +} + // DeleteOneWhereParameter is a where parameter structure. type DeleteOneWhereParameter struct { IdEqual uint64 @@ -227,11 +232,7 @@ func (m *UserModel) Insert(ctx context.Context, data ...*User) error { } var session = m.engine.Context(ctx) - var list []User - for _, v := range data { - list = append(list, *v) - } - + list := data[:] _, err := session.Insert(&list) return err } @@ -483,11 +484,12 @@ func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where U } // UpdateOrderByIdDescLimitCount is generated from sql: -// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; -func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter) error { +// update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc limit ?; +func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter, limit UpdateOrderByIdDescLimitCountLimitParameter) error { var session = m.engine.Context(ctx) session.Where(`id = ?`, where.IdEqual) session.OrderBy(`id desc`) + session.Limit(limit.Count) _, err := session.Update(map[string]interface{}{ "name": data.Name, "password": data.Password, diff --git a/internal/gen/bun/bun_gen.tpl b/internal/gen/bun/bun_gen.tpl index 222f7eb..b8aa810 100644 --- a/internal/gen/bun/bun_gen.tpl +++ b/internal/gen/bun/bun_gen.tpl @@ -50,11 +50,7 @@ func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...* return fmt.Errorf("data is empty") } - var list []{{UpperCamel $.Table.Name}} - for _,v:=range data{ - list = append(list,*v) - } - + list := data[:] _,err := m.db.NewInsert().Model(&list).Exec(ctx) return err } @@ -81,7 +77,8 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, data *{{UpperCamel $.Table.Name}}{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}) error { var db = m.db.NewUpdate() - db.Model(map[string]interface{}{ + db.Table("{{$.Table.Name}}") + db.Model(&map[string]interface{}{ {{range $name := $stmt.Columns}}"{{$name}}": data.{{UpperCamel $name}}, {{end}} }) diff --git a/internal/gen/gorm/gorm.go b/internal/gen/gorm/gorm.go index e618ec4..568325e 100644 --- a/internal/gen/gorm/gorm.go +++ b/internal/gen/gorm/gorm.go @@ -26,6 +26,9 @@ func Run(list []spec.Context, output string) error { "IsPrimary": func(name string) bool { return ctx.Table.IsPrimary(name) }, + "IsExtraResult": func(name string) bool { + return name != templatex.UpperCamel(ctx.Table.Name) + }, }) gen.MustParse(gormGenTpl) gen.MustExecute(ctx) diff --git a/internal/gen/gorm/gorm_gen.tpl b/internal/gen/gorm/gorm_gen.tpl index 8e84c97..a49bb4e 100644 --- a/internal/gen/gorm/gorm_gen.tpl +++ b/internal/gen/gorm/gorm_gen.tpl @@ -14,7 +14,7 @@ import ( // {{UpperCamel $.Table.Name}}Model represents a {{$.Table.Name}} model. type {{UpperCamel $.Table.Name}}Model struct { - db gorm.DB + db *gorm.DB } // {{UpperCamel $.Table.Name}} represents a {{$.Table.Name}} struct data. @@ -25,6 +25,10 @@ type {{UpperCamel $.Table.Name}} struct { {{range $.Table.Columns}} {{end}}{{if $stmt.Having.IsValid}}{{$stmt.Having.ParameterStructure "Having"}} {{end}}{{if $stmt.Limit.Multiple}}{{$stmt.Limit.ParameterStructure}} {{end}}{{$stmt.ReceiverStructure "gorm"}} +// TableName returns the table name. it implemented by gorm.Tabler. +{{if IsExtraResult $stmt.ReceiverName}}func ({{$stmt.ReceiverName}}) TableName() string { + return "{{$.Table.Name}}" +}{{end}} {{end}} {{range $stmt := .UpdateStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} @@ -43,7 +47,7 @@ func ({{UpperCamel $.Table.Name}}) TableName() string { } // New{{UpperCamel $.Table.Name}}Model returns a new {{$.Table.Name}} model. -func New{{UpperCamel $.Table.Name}}Model (db gorm.DB) *{{UpperCamel $.Table.Name}}Model { +func New{{UpperCamel $.Table.Name}}Model (db *gorm.DB) *{{UpperCamel $.Table.Name}}Model { return &{{UpperCamel $.Table.Name}}Model{db: db} } @@ -54,11 +58,7 @@ func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...* } db:=m.db.WithContext(ctx) - var list []{{UpperCamel $.Table.Name}} - for _,v:=range data{ - list = append(list,*v) - } - + list := data[:] return db.Create(&list).Error } {{range $stmt := .SelectStmt}} @@ -67,13 +67,13 @@ func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...* func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Having.IsValid}}, having {{$stmt.Having.ParameterStructureName "Having"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}})({{if $stmt.Limit.One}}*{{$stmt.ReceiverName}}, {{else}}[]*{{$stmt.ReceiverName}}, {{end}} error){ var result {{if $stmt.Limit.One}} = new({{$stmt.ReceiverName}}){{else}}[]*{{$stmt.ReceiverName}}{{end}} var db = m.db.WithContext(ctx) - db.Select(`{{$stmt.SelectSQL}}`) - {{if $stmt.Where.IsValid}}db.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) - {{end }}{{if $stmt.GroupBy.IsValid}}db.Group({{$stmt.GroupBy.SQL}}) - {{end}}{{if $stmt.Having.IsValid}}db.Having({{$stmt.Having.SQL}}, {{$stmt.Having.Parameters "having"}}) - {{end}}{{if $stmt.OrderBy.IsValid}}db.Order({{$stmt.OrderBy.SQL}}) - {{end}}{{if $stmt.Limit.IsValid}}db{{if gt $stmt.Limit.Offset 0}}.Offset({{$stmt.Limit.OffsetParameter "limit"}}){{end}}.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}) - {{end}}db.Find({{if $stmt.Limit.One}}result{{else}}&result{{end}}) + db=db.Select(`{{$stmt.SelectSQL}}`) + {{if $stmt.Where.IsValid}}db=db.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) + {{end }}{{if $stmt.GroupBy.IsValid}}db=db.Group({{$stmt.GroupBy.SQL}}) + {{end}}{{if $stmt.Having.IsValid}}db=db.Having({{$stmt.Having.SQL}}, {{$stmt.Having.Parameters "having"}}) + {{end}}{{if $stmt.OrderBy.IsValid}}db=db.Order({{$stmt.OrderBy.SQL}}) + {{end}}{{if $stmt.Limit.IsValid}}db=db{{if gt $stmt.Limit.Offset 0}}.Offset({{$stmt.Limit.OffsetParameter "limit"}}){{end}}.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}) + {{end}}db={{if $stmt.Limit.One}}db.Take(result){{else}}db.Find(&result){{end}} return result, db.Error } {{end}} @@ -83,11 +83,11 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, data *{{UpperCamel $.Table.Name}}{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { var db = m.db.WithContext(ctx) - db.Model(&{{UpperCamel $.Table.Name}}{}) - {{if $stmt.Where.IsValid}}db.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) - {{end}}{{if $stmt.OrderBy.IsValid}}db.Order({{$stmt.OrderBy.SQL}}) - {{end}}{{if $stmt.Limit.IsValid}}db{{if gt $stmt.Limit.Offset 0}}.Offset({{$stmt.Limit.OffsetParameter "limit"}}){{end}}.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}) - {{end}}db.Updates(map[string]interface{}{ + db=db.Model(&{{UpperCamel $.Table.Name}}{}) + {{if $stmt.Where.IsValid}}db=db.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) + {{end}}{{if $stmt.OrderBy.IsValid}}db=db.Order({{$stmt.OrderBy.SQL}}) + {{end}}{{if $stmt.Limit.IsValid}}db=db{{if gt $stmt.Limit.Offset 0}}.Offset({{$stmt.Limit.OffsetParameter "limit"}}){{end}}.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}) + {{end}}db=db.Updates(map[string]interface{}{ {{range $name := $stmt.Columns}}"{{$name}}": data.{{UpperCamel $name}}, {{end}} }) @@ -100,10 +100,10 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, dat // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { var db = m.db.WithContext(ctx) - {{if $stmt.Where.IsValid}}db.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) - {{end}}{{if $stmt.OrderBy.IsValid}}db.Order({{$stmt.OrderBy.SQL}}) - {{end}}{{if $stmt.Limit.IsValid}}db{{if gt $stmt.Limit.Offset 0}}.Offset({{$stmt.Limit.OffsetParameter "limit"}}){{end}}.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}) - {{end}}db.Delete(&{{UpperCamel $.Table.Name}}{}) + {{if $stmt.Where.IsValid}}db=db.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) + {{end}}{{if $stmt.OrderBy.IsValid}}db=db.Order({{$stmt.OrderBy.SQL}}) + {{end}}{{if $stmt.Limit.IsValid}}db=db{{if gt $stmt.Limit.Offset 0}}.Offset({{$stmt.Limit.OffsetParameter "limit"}}){{end}}.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}) + {{end}}db=db.Delete(&{{UpperCamel $.Table.Name}}{}) return db.Error } {{end}} \ No newline at end of file diff --git a/internal/gen/xorm/xorm_gen.tpl b/internal/gen/xorm/xorm_gen.tpl index 7f4b903..247dfda 100644 --- a/internal/gen/xorm/xorm_gen.tpl +++ b/internal/gen/xorm/xorm_gen.tpl @@ -54,11 +54,7 @@ func (m *{{UpperCamel $.Table.Name}}Model) Insert(ctx context.Context, data ...* } var session = m.engine.Context(ctx) - var list []{{UpperCamel $.Table.Name}} - for _,v := range data{ - list = append(list,*v) - } - + list := data[:] _,err := session.Insert(&list) return err } From 142c91880c993f51cd54672cc7db3ee00ac0bac0 Mon Sep 17 00:00:00 2001 From: anqiansong Date: Wed, 17 Aug 2022 15:44:35 +0800 Subject: [PATCH 10/10] Add example_test --- example/example_test/sqlx/mock.go | 23 + .../example_test/sqlx/user_model.gen_test.go | 560 ++++++++++++++++- example/example_test/xorm/mock.go | 23 + .../example_test/xorm/user_model.gen_test.go | 570 +++++++++++++++++- example/sqlx/user_model.gen.go | 167 ++--- example/xorm/user_model.gen.go | 124 +++- internal/gen/sqlx/sqlx_gen.tpl | 16 +- internal/gen/xorm/xorm.go | 3 + internal/gen/xorm/xorm_gen.tpl | 23 +- 9 files changed, 1406 insertions(+), 103 deletions(-) create mode 100644 example/example_test/sqlx/mock.go create mode 100644 example/example_test/xorm/mock.go diff --git a/example/example_test/sqlx/mock.go b/example/example_test/sqlx/mock.go new file mode 100644 index 0000000..a438320 --- /dev/null +++ b/example/example_test/sqlx/mock.go @@ -0,0 +1,23 @@ +package sql + +import ( + "time" + + model "github.com/anqiansong/sqlgen/example/sqlx" + uuid "github.com/satori/go.uuid" +) + +func mustMockUser() *model.User { + uid := uuid.NewV4().String() + now := time.Now() + return &model.User{ + Name: uid, + Password: "bar", + Mobile: uid, + Gender: "male", + Nickname: "test", + Type: 1, + CreateAt: now, + UpdateAt: now, + } +} diff --git a/example/example_test/sqlx/user_model.gen_test.go b/example/example_test/sqlx/user_model.gen_test.go index ed8ee00..d1a8f11 100644 --- a/example/example_test/sqlx/user_model.gen_test.go +++ b/example/example_test/sqlx/user_model.gen_test.go @@ -1 +1,559 @@ -package sqlx +package sql + +import ( + "context" + "database/sql" + "fmt" + "log" + "sort" + "testing" + "time" + + model "github.com/anqiansong/sqlgen/example/sqlx" + _ "github.com/go-sql-driver/mysql" + "github.com/jmoiron/sqlx" + "github.com/stretchr/testify/assert" +) + +var ( + um *model.UserModel + ctx = context.TODO() + db *sqlx.DB +) + +func TestMain(m *testing.M) { + db = sqlx.MustOpen("mysql", "root:mysqlpw@tcp(127.0.0.1:55000)/test?charset=utf8mb4&parseTime=true&loc=Local") + err := db.Ping() + if err != nil { + fmt.Println("ping error") + return + } + + um = model.NewUserModel(db) + m.Run() +} + +func mustInitDB(db *sqlx.DB) { + tx, err := db.Begin() + if err != nil { + log.Fatalln(err) + } + + _, err = tx.ExecContext(ctx, `SET SESSION sql_mode=(SELECT REPLACE(@@sql_mode,'ONLY_FULL_GROUP_BY,',''))`) + if err != nil { + log.Fatalln(err) + } + + _, err = tx.ExecContext(ctx, `truncate table user`) + if err != nil { + log.Fatalln(err) + } + + _, err = tx.ExecContext(ctx, `alter table user auto_increment=1`) + if err != nil { + log.Fatalln(err) + } + + err = tx.Commit() + if err != nil { + log.Fatalln(err) + } +} + +func TestCreate(t *testing.T) { + t.Run("emptyData", initAndRun(func(t *testing.T) { + err := um.Create(ctx) + assert.Contains(t, err.Error(), "empty") + })) + + t.Run("createOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + assert.Equal(t, uint64(1), mockUser.Id) + })) + t.Run("createMultiple", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 1; i <= 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + for idx, item := range list { + assert.Equal(t, uint64(idx+1), item.Id) + } + })) +} + +func TestFindOne(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: 1}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("findOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOneGroupByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByNameHavingName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: "foo"}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, sql.ErrNoRows) + + mockUser := mustMockUser() + err = um.Create(ctx, mockUser) + assert.NoError(t, err) + _, err = um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, sql.ErrNoRows) + + })) + + t.Run("FindOneGroupByNameHavingName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: mockUser.Name, + }) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindAll(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindAll", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list, actual) + })) +} + +func TestFindLimit(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 1, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimit", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) + + t.Run("FindLimitOffset1", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindGroupHavingLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindGroupHavingOrderDescLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingOrderDescLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + sort.Slice(list, func(i, j int) bool { + return list[i].Id > list[j].Id + }) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindOnePart(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 1}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOnePart", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, mockUser.Name, actual.Name) + assert.Equal(t, mockUser.Password, actual.Password) + assert.Equal(t, mockUser.Mobile, actual.Mobile) + })) +} + +func TestFindAllCount(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindAllCountWhere(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCountWhere", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindMaxID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + maxID, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), maxID.MaxID.Int64) + })) + + t.Run("FindMaxID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(5), actual.MaxID.Int64) + })) +} + +func TestFindMinID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), minID.MinID.Int64) + })) + + t.Run("FindMinID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.MinID.Int64) + })) +} + +func TestFindAvgID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "0", minID.AvgID.Decimal.String()) + })) + + t.Run("FindAvgID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "3", actual.AvgID.Decimal.String()) + })) +} + +func TestUpdate(t *testing.T) { + t.Run("Update", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.Update(ctx, newUser, model.UpdateWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDesc(ctx, newUser, model.UpdateOrderByIdDescWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDescLimitCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDescLimitCount(ctx, newUser, model.UpdateOrderByIdDescLimitCountWhereParameter{IdEqual: mockUser.Id}, model.UpdateOrderByIdDescLimitCountLimitParameter{Count: 1}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) +} + +func TestDelete(t *testing.T) { + t.Run("DeleteOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOne(ctx, model.DeleteOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneByName(ctx, model.DeleteOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDAsc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDAsc(ctx, model.DeleteOneOrderByIDAscWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDDesc(ctx, model.DeleteOneOrderByIDDescWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDDescLimitCount", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + + err = um.DeleteOneOrderByIDDescLimitCount(ctx, model.DeleteOneOrderByIDDescLimitCountWhereParameter{NameEqual: list[0].Name}, model.DeleteOneOrderByIDDescLimitCountLimitParameter{Count: 1}) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list[1:], actual) + })) +} + +func assertUserEqual(t *testing.T, expected, actual *model.User) { + now := time.Now() + expected.CreateAt = now + expected.UpdateAt = now + actual.CreateAt = now + actual.UpdateAt = now + assert.Equal(t, *expected, *actual) +} + +func assertUsersEqual(t *testing.T, expected, actual []*model.User) { + assert.Equal(t, len(expected), len(actual)) + for idx, expectedItem := range expected { + actual := actual[idx] + assertUserEqual(t, expectedItem, actual) + } +} + +func initAndRun(f func(t *testing.T)) func(t *testing.T) { + mustInitDB(db) + return func(t *testing.T) { + f(t) + } +} diff --git a/example/example_test/xorm/mock.go b/example/example_test/xorm/mock.go new file mode 100644 index 0000000..8dee030 --- /dev/null +++ b/example/example_test/xorm/mock.go @@ -0,0 +1,23 @@ +package sql + +import ( + "time" + + model "github.com/anqiansong/sqlgen/example/xorm" + uuid "github.com/satori/go.uuid" +) + +func mustMockUser() *model.User { + uid := uuid.NewV4().String() + now := time.Now() + return &model.User{ + Name: uid, + Password: "bar", + Mobile: uid, + Gender: "male", + Nickname: "test", + Type: 1, + CreateAt: now, + UpdateAt: now, + } +} diff --git a/example/example_test/xorm/user_model.gen_test.go b/example/example_test/xorm/user_model.gen_test.go index 0baf644..e0264e7 100644 --- a/example/example_test/xorm/user_model.gen_test.go +++ b/example/example_test/xorm/user_model.gen_test.go @@ -1 +1,569 @@ -package xorm +package sql + +import ( + "context" + "database/sql" + "fmt" + "log" + "sort" + "testing" + "time" + + model "github.com/anqiansong/sqlgen/example/xorm" + _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" + "xorm.io/xorm" +) + +var ( + um *model.UserModel + ctx = context.TODO() + db xorm.EngineInterface +) + +func TestMain(m *testing.M) { + var err error + db, err = xorm.NewEngine("mysql", "root:mysqlpw@tcp(127.0.0.1:55000)/test?charset=utf8mb4&parseTime=true&loc=Local") + if err != nil { + log.Fatalln(err) + } + + err = db.Ping() + if err != nil { + fmt.Println("ping error") + return + } + + um = model.NewUserModel(db) + m.Run() +} + +func mustInitDB(t *testing.T, db xorm.EngineInterface) { + session := db.NewSession() + t.Cleanup(func() { + session.Close() + }) + + err := session.Begin() + if err != nil { + log.Fatalln(err) + } + + _, err = session.Exec(`SET SESSION sql_mode=(SELECT REPLACE(@@sql_mode,'ONLY_FULL_GROUP_BY,',''))`) + if err != nil { + log.Fatalln(err) + } + + _, err = session.Exec(`truncate table user`) + if err != nil { + log.Fatalln(err) + } + + _, err = session.Exec(`alter table user auto_increment=1`) + if err != nil { + log.Fatalln(err) + } + + err = session.Commit() + if err != nil { + log.Fatalln(err) + } +} + +func TestCreate(t *testing.T) { + t.Run("emptyData", initAndRun(func(t *testing.T) { + err := um.Create(ctx) + assert.Contains(t, err.Error(), "empty") + })) + + t.Run("createOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + assert.Equal(t, uint64(1), mockUser.Id) + })) + t.Run("createMultiple", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 1; i <= 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + for idx, item := range list { + assert.Equal(t, uint64(idx+1), item.Id) + } + })) +} + +func TestFindOne(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: 1}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("findOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneByName(ctx, model.FindOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: "foo"}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOneGroupByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByName(ctx, model.FindOneGroupByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindOneGroupByNameHavingName(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: "foo"}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, sql.ErrNoRows) + + mockUser := mustMockUser() + err = um.Create(ctx, mockUser) + assert.NoError(t, err) + _, err = um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: "foo", + }) + assert.ErrorIs(t, err, sql.ErrNoRows) + + })) + + t.Run("FindOneGroupByNameHavingName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOneGroupByNameHavingName(ctx, model.FindOneGroupByNameHavingNameWhereParameter{NameEqual: mockUser.Name}, model.FindOneGroupByNameHavingNameHavingParameter{ + NameEqual: mockUser.Name, + }) + assert.NoError(t, err) + assertUserEqual(t, mockUser, actual) + })) +} + +func TestFindAll(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindAll", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list, actual) + })) +} + +func TestFindLimit(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 1, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimit", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimit(ctx, model.FindLimitWhereParameter{ + IdGT: 0, + }, model.FindLimitLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) + + t.Run("FindLimitOffset1", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindLimitOffset(ctx, model.FindLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindGroupHavingLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 1, + Offset: 0, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingLimitOffset(ctx, model.FindGroupHavingLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingLimitOffsetLimitParameter{ + Count: 2, + Offset: 0, + }) + assert.NoError(t, err) + assertUsersEqual(t, list[:2], actual) + })) +} + +func TestFindGroupHavingOrderDescLimitOffset(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + }) + assert.NoError(t, err) + assert.Equal(t, 0, len(actual)) + })) + + t.Run("FindGroupHavingOrderDescLimitOffset", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindGroupHavingOrderDescLimitOffset(ctx, model.FindGroupHavingOrderDescLimitOffsetWhereParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetHavingParameter{ + IdGT: 0, + }, model.FindGroupHavingOrderDescLimitOffsetLimitParameter{ + Count: 2, + Offset: 1, + }) + assert.NoError(t, err) + sort.Slice(list, func(i, j int) bool { + return list[i].Id > list[j].Id + }) + assertUsersEqual(t, list[1:3], actual) + })) +} + +func TestFindOnePart(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + _, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 1}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("FindOnePart", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindOnePart(ctx, model.FindOnePartWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, mockUser.Name, actual.Name) + assert.Equal(t, mockUser.Password, actual.Password) + assert.Equal(t, mockUser.Mobile, actual.Mobile) + })) +} + +func TestFindAllCount(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindAllCountWhere(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + countID, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(0), countID.CountID.Int64) + })) + + t.Run("FindAllCountWhere", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + actual, err := um.FindAllCountWhere(ctx, model.FindAllCountWhereWhereParameter{IdGT: 0}) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.CountID.Int64) + })) +} + +func TestFindMaxID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + maxID, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), maxID.MaxID.Int64) + })) + + t.Run("FindMaxID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMaxID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(5), actual.MaxID.Int64) + })) +} + +func TestFindMinID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(0), minID.MinID.Int64) + })) + + t.Run("FindMinID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindMinID(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), actual.MinID.Int64) + })) +} + +func TestFindAvgID(t *testing.T) { + t.Run("noRows", initAndRun(func(t *testing.T) { + minID, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "0", minID.AvgID.Decimal.String()) + })) + + t.Run("FindAvgID", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + actual, err := um.FindAvgID(ctx) + assert.NoError(t, err) + assert.Equal(t, "3", actual.AvgID.Decimal.String()) + })) +} + +func TestUpdate(t *testing.T) { + t.Run("Update", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.Update(ctx, newUser, model.UpdateWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDesc(ctx, newUser, model.UpdateOrderByIdDescWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) + + t.Run("UpdateOrderByIdDescLimitCount", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + mockUser.Name = "new name" + newUser := mustMockUser() + newUser.Id = mockUser.Id + err = um.UpdateOrderByIdDescLimitCount(ctx, newUser, model.UpdateOrderByIdDescLimitCountWhereParameter{IdEqual: mockUser.Id}, model.UpdateOrderByIdDescLimitCountLimitParameter{Count: 1}) + assert.NoError(t, err) + actual, err := um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + assertUserEqual(t, newUser, actual) + })) +} + +func TestDelete(t *testing.T) { + t.Run("DeleteOne", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOne(ctx, model.DeleteOneWhereParameter{IdEqual: mockUser.Id}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneByName", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneByName(ctx, model.DeleteOneByNameWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDAsc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDAsc(ctx, model.DeleteOneOrderByIDAscWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDDesc", initAndRun(func(t *testing.T) { + mockUser := mustMockUser() + err := um.Create(ctx, mockUser) + assert.NoError(t, err) + err = um.DeleteOneOrderByIDDesc(ctx, model.DeleteOneOrderByIDDescWhereParameter{NameEqual: mockUser.Name}) + assert.NoError(t, err) + _, err = um.FindOne(ctx, model.FindOneWhereParameter{IdEqual: mockUser.Id}) + assert.ErrorIs(t, err, sql.ErrNoRows) + })) + + t.Run("DeleteOneOrderByIDDescLimitCount", initAndRun(func(t *testing.T) { + var list []*model.User + for i := 0; i < 5; i++ { + list = append(list, mustMockUser()) + } + err := um.Create(ctx, list...) + assert.NoError(t, err) + + err = um.DeleteOneOrderByIDDescLimitCount(ctx, model.DeleteOneOrderByIDDescLimitCountWhereParameter{NameEqual: list[0].Name}, model.DeleteOneOrderByIDDescLimitCountLimitParameter{Count: 1}) + assert.NoError(t, err) + actual, err := um.FindAll(ctx) + assert.NoError(t, err) + assertUsersEqual(t, list[1:], actual) + })) +} + +func assertUserEqual(t *testing.T, expected, actual *model.User) { + now := time.Now() + expected.CreateAt = now + expected.UpdateAt = now + actual.CreateAt = now + actual.UpdateAt = now + assert.Equal(t, *expected, *actual) +} + +func assertUsersEqual(t *testing.T, expected, actual []*model.User) { + assert.Equal(t, len(expected), len(actual)) + for idx, expectedItem := range expected { + actual := actual[idx] + assertUserEqual(t, expectedItem, actual) + } +} + +func initAndRun(f func(t *testing.T)) func(t *testing.T) { + return func(t *testing.T) { + mustInitDB(t, db) + f(t) + } +} diff --git a/example/sqlx/user_model.gen.go b/example/sqlx/user_model.gen.go index dc8d0a5..b3a5331 100644 --- a/example/sqlx/user_model.gen.go +++ b/example/sqlx/user_model.gen.go @@ -272,17 +272,17 @@ func (m *UserModel) FindOne(ctx context.Context, where FindOneWhereParameter) (r } defer rows.Close() - for rows.Next() { - var v User - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindOneByName is generated from sql: @@ -306,17 +306,17 @@ func (m *UserModel) FindOneByName(ctx context.Context, where FindOneByNameWhereP } defer rows.Close() - for rows.Next() { - var v User - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindOneGroupByName is generated from sql: @@ -341,17 +341,17 @@ func (m *UserModel) FindOneGroupByName(ctx context.Context, where FindOneGroupBy } defer rows.Close() - for rows.Next() { - var v User - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindOneGroupByNameHavingName is generated from sql: @@ -377,17 +377,17 @@ func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where Find } defer rows.Close() - for rows.Next() { - var v User - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindAll is generated from sql: @@ -418,6 +418,7 @@ func (m *UserModel) FindAll(ctx context.Context) (result []*User, err error) { } return result, nil + } // FindLimit is generated from sql: @@ -450,6 +451,7 @@ func (m *UserModel) FindLimit(ctx context.Context, where FindLimitWhereParameter } return result, nil + } // FindLimitOffset is generated from sql: @@ -481,6 +483,7 @@ func (m *UserModel) FindLimitOffset(ctx context.Context, limit FindLimitOffsetLi } return result, nil + } // FindGroupLimitOffset is generated from sql: @@ -514,6 +517,7 @@ func (m *UserModel) FindGroupLimitOffset(ctx context.Context, where FindGroupLim } return result, nil + } // FindGroupHavingLimitOffset is generated from sql: @@ -548,6 +552,7 @@ func (m *UserModel) FindGroupHavingLimitOffset(ctx context.Context, where FindGr } return result, nil + } // FindGroupHavingOrderAscLimitOffset is generated from sql: @@ -583,6 +588,7 @@ func (m *UserModel) FindGroupHavingOrderAscLimitOffset(ctx context.Context, wher } return result, nil + } // FindGroupHavingOrderDescLimitOffset is generated from sql: @@ -618,6 +624,7 @@ func (m *UserModel) FindGroupHavingOrderDescLimitOffset(ctx context.Context, whe } return result, nil + } // FindOnePart is generated from sql: @@ -641,17 +648,17 @@ func (m *UserModel) FindOnePart(ctx context.Context, where FindOnePartWhereParam } defer rows.Close() - for rows.Next() { - var v User - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindAllCount is generated from sql: @@ -674,17 +681,17 @@ func (m *UserModel) FindAllCount(ctx context.Context) (result *FindAllCountResul } defer rows.Close() - for rows.Next() { - var v FindAllCountResult - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindAllCountWhere is generated from sql: @@ -708,17 +715,17 @@ func (m *UserModel) FindAllCountWhere(ctx context.Context, where FindAllCountWhe } defer rows.Close() - for rows.Next() { - var v FindAllCountWhereResult - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindMaxID is generated from sql: @@ -741,17 +748,17 @@ func (m *UserModel) FindMaxID(ctx context.Context) (result *FindMaxIDResult, err } defer rows.Close() - for rows.Next() { - var v FindMaxIDResult - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindMinID is generated from sql: @@ -774,17 +781,17 @@ func (m *UserModel) FindMinID(ctx context.Context) (result *FindMinIDResult, err } defer rows.Close() - for rows.Next() { - var v FindMinIDResult - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // FindAvgID is generated from sql: @@ -807,17 +814,17 @@ func (m *UserModel) FindAvgID(ctx context.Context) (result *FindAvgIDResult, err } defer rows.Close() - for rows.Next() { - var v FindAvgIDResult - err = rows.StructScan(&v) - if err != nil { - return nil, err - } - result = &v - break + if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err } return result, nil + } // Update is generated from sql: diff --git a/example/xorm/user_model.gen.go b/example/xorm/user_model.gen.go index 5b35dc5..63d1692 100644 --- a/example/xorm/user_model.gen.go +++ b/example/xorm/user_model.gen.go @@ -36,16 +36,22 @@ type FindOneWhereParameter struct { IdEqual uint64 } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindOneByNameWhereParameter is a where parameter structure. type FindOneByNameWhereParameter struct { NameEqual string } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindOneGroupByNameWhereParameter is a where parameter structure. type FindOneGroupByNameWhereParameter struct { NameEqual string } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindOneGroupByNameHavingNameWhereParameter is a where parameter structure. type FindOneGroupByNameHavingNameWhereParameter struct { NameEqual string @@ -56,6 +62,10 @@ type FindOneGroupByNameHavingNameHavingParameter struct { NameEqual string } +// TableName returns the table name. it implemented by gorm.Tabler. + +// TableName returns the table name. it implemented by gorm.Tabler. + // FindLimitWhereParameter is a where parameter structure. type FindLimitWhereParameter struct { IdGT uint64 @@ -66,12 +76,16 @@ type FindLimitLimitParameter struct { Count int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindLimitOffsetLimitParameter is a limit parameter structure. type FindLimitOffsetLimitParameter struct { Count int Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindGroupLimitOffsetWhereParameter is a where parameter structure. type FindGroupLimitOffsetWhereParameter struct { IdGT uint64 @@ -83,6 +97,8 @@ type FindGroupLimitOffsetLimitParameter struct { Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindGroupHavingLimitOffsetWhereParameter is a where parameter structure. type FindGroupHavingLimitOffsetWhereParameter struct { IdGT uint64 @@ -99,6 +115,8 @@ type FindGroupHavingLimitOffsetLimitParameter struct { Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindGroupHavingOrderAscLimitOffsetWhereParameter is a where parameter structure. type FindGroupHavingOrderAscLimitOffsetWhereParameter struct { IdGT uint64 @@ -115,6 +133,8 @@ type FindGroupHavingOrderAscLimitOffsetLimitParameter struct { Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindGroupHavingOrderDescLimitOffsetWhereParameter is a where parameter structure. type FindGroupHavingOrderDescLimitOffsetWhereParameter struct { IdGT uint64 @@ -131,16 +151,25 @@ type FindGroupHavingOrderDescLimitOffsetLimitParameter struct { Offset int } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindOnePartWhereParameter is a where parameter structure. type FindOnePartWhereParameter struct { IdGT uint64 } +// TableName returns the table name. it implemented by gorm.Tabler. + // FindAllCountResult is a find all count result. type FindAllCountResult struct { CountID sql.NullInt64 `xorm:"'countID'" json:"countID"` } +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindAllCountResult) TableName() string { + return "user" +} + // FindAllCountWhereWhereParameter is a where parameter structure. type FindAllCountWhereWhereParameter struct { IdGT uint64 @@ -151,21 +180,41 @@ type FindAllCountWhereResult struct { CountID sql.NullInt64 `xorm:"'countID'" json:"countID"` } +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindAllCountWhereResult) TableName() string { + return "user" +} + // FindMaxIDResult is a find max id result. type FindMaxIDResult struct { MaxID sql.NullInt64 `xorm:"'maxID'" json:"maxID"` } +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindMaxIDResult) TableName() string { + return "user" +} + // FindMinIDResult is a find min id result. type FindMinIDResult struct { MinID sql.NullInt64 `xorm:"'minID'" json:"minID"` } +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindMinIDResult) TableName() string { + return "user" +} + // FindAvgIDResult is a find avg id result. type FindAvgIDResult struct { AvgID decimal.NullDecimal `xorm:"'avgID'" json:"avgID"` } +// TableName returns the table name. it implemented by gorm.Tabler. +func (FindAvgIDResult) TableName() string { + return "user" +} + // UpdateWhereParameter is a where parameter structure. type UpdateWhereParameter struct { IdEqual uint64 @@ -225,15 +274,19 @@ func NewUserModel(engine xorm.EngineInterface) *UserModel { return &UserModel{engine: engine} } -// Insert creates user data. -func (m *UserModel) Insert(ctx context.Context, data ...*User) error { +// Create creates user data. +func (m *UserModel) Create(ctx context.Context, data ...*User) error { if len(data) == 0 { return fmt.Errorf("data is empty") } var session = m.engine.Context(ctx) - list := data[:] - _, err := session.Insert(&list) + var list []interface{} + for _, v := range data { + list = append(list, v) + } + + _, err := session.Insert(list...) return err } @@ -245,7 +298,11 @@ func (m *UserModel) FindOne(ctx context.Context, where FindOneWhereParameter) (* session.Select(`*`) session.Where(`id = ?`, where.IdEqual) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -257,7 +314,11 @@ func (m *UserModel) FindOneByName(ctx context.Context, where FindOneByNameWhereP session.Select(`*`) session.Where(`name = ?`, where.NameEqual) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -270,7 +331,11 @@ func (m *UserModel) FindOneGroupByName(ctx context.Context, where FindOneGroupBy session.Where(`name = ?`, where.NameEqual) session.GroupBy(`name`) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -284,7 +349,11 @@ func (m *UserModel) FindOneGroupByNameHavingName(ctx context.Context, where Find session.GroupBy(`name`) session.Having(fmt.Sprintf(`name = '%v'`, having.NameEqual)) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -386,7 +455,11 @@ func (m *UserModel) FindOnePart(ctx context.Context, where FindOnePartWhereParam session.Select(`name, password, mobile`) session.Where(`id > ?`, where.IdGT) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -397,7 +470,11 @@ func (m *UserModel) FindAllCount(ctx context.Context) (*FindAllCountResult, erro var session = m.engine.Context(ctx) session.Select(`count(id) AS countID`) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -409,7 +486,11 @@ func (m *UserModel) FindAllCountWhere(ctx context.Context, where FindAllCountWhe session.Select(`count(id) AS countID`) session.Where(`id > ?`, where.IdGT) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -420,7 +501,11 @@ func (m *UserModel) FindMaxID(ctx context.Context) (*FindMaxIDResult, error) { var session = m.engine.Context(ctx) session.Select(`max(id) AS maxID`) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -431,7 +516,11 @@ func (m *UserModel) FindMinID(ctx context.Context) (*FindMinIDResult, error) { var session = m.engine.Context(ctx) session.Select(`min(id) AS minID`) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -442,7 +531,11 @@ func (m *UserModel) FindAvgID(ctx context.Context) (*FindAvgIDResult, error) { var session = m.engine.Context(ctx) session.Select(`avg(id) AS avgID`) session.Limit(1) - _, err := session.Get(result) + has, err := session.Get(result) + if !has { + return nil, sql.ErrNoRows + } + return result, err } @@ -450,6 +543,7 @@ func (m *UserModel) FindAvgID(ctx context.Context) (*FindAvgIDResult, error) { // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ?; func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWhereParameter) error { var session = m.engine.Context(ctx) + session.Table(&User{}) session.Where(`id = ?`, where.IdEqual) _, err := session.Update(map[string]interface{}{ "name": data.Name, @@ -468,6 +562,7 @@ func (m *UserModel) Update(ctx context.Context, data *User, where UpdateWherePar // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc; func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where UpdateOrderByIdDescWhereParameter) error { var session = m.engine.Context(ctx) + session.Table(&User{}) session.Where(`id = ?`, where.IdEqual) session.OrderBy(`id desc`) _, err := session.Update(map[string]interface{}{ @@ -487,6 +582,7 @@ func (m *UserModel) UpdateOrderByIdDesc(ctx context.Context, data *User, where U // update `user` set `name` = ?, `password` = ?, `mobile` = ?, `gender` = ?, `nickname` = ?, `type` = ?, `create_at` = ?, `update_at` = ? where `id` = ? order by id desc limit ?; func (m *UserModel) UpdateOrderByIdDescLimitCount(ctx context.Context, data *User, where UpdateOrderByIdDescLimitCountWhereParameter, limit UpdateOrderByIdDescLimitCountLimitParameter) error { var session = m.engine.Context(ctx) + session.Table(&User{}) session.Where(`id = ?`, where.IdEqual) session.OrderBy(`id desc`) session.Limit(limit.Count) diff --git a/internal/gen/sqlx/sqlx_gen.tpl b/internal/gen/sqlx/sqlx_gen.tpl index 82c95c6..84d75bf 100644 --- a/internal/gen/sqlx/sqlx_gen.tpl +++ b/internal/gen/sqlx/sqlx_gen.tpl @@ -98,17 +98,29 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if } defer rows.Close() + {{if $stmt.Limit.One}}if !rows.Next() { + return nil, sql.ErrNoRows + } + + err = rows.StructScan(result) + if err != nil { + return nil, err + } + + return result, nil + + {{else}} for rows.Next() { var v {{$stmt.ReceiverName}} err = rows.StructScan(&v) if err != nil { return nil, err } - {{if $stmt.Limit.One}}result=&v - break{{else}} result = append(result, &v){{end}} + result = append(result, &v) } return result, nil + {{end}} } {{end}} diff --git a/internal/gen/xorm/xorm.go b/internal/gen/xorm/xorm.go index bc950c3..946588f 100644 --- a/internal/gen/xorm/xorm.go +++ b/internal/gen/xorm/xorm.go @@ -30,6 +30,9 @@ func Run(list []spec.Context, output string) error { format = strings.ReplaceAll(format, "?", "'%v'") return format }, + "IsExtraResult": func(name string) bool { + return name != templatex.UpperCamel(ctx.Table.Name) + }, }) gen.MustParse(xormGenTpl) gen.MustExecute(ctx) diff --git a/internal/gen/xorm/xorm_gen.tpl b/internal/gen/xorm/xorm_gen.tpl index 247dfda..a2bea0b 100644 --- a/internal/gen/xorm/xorm_gen.tpl +++ b/internal/gen/xorm/xorm_gen.tpl @@ -26,6 +26,10 @@ type {{UpperCamel $.Table.Name}} struct { {{range $.Table.Columns}} {{end}}{{if $stmt.Having.IsValid}}{{$stmt.Having.ParameterStructure "Having"}} {{end}}{{if $stmt.Limit.Multiple}}{{$stmt.Limit.ParameterStructure}} {{end}}{{$stmt.ReceiverStructure "xorm"}} +// TableName returns the table name. it implemented by gorm.Tabler. +{{if IsExtraResult $stmt.ReceiverName}}func ({{$stmt.ReceiverName}}) TableName() string { + return "{{$.Table.Name}}" +}{{end}} {{end}} {{range $stmt := .UpdateStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} @@ -47,15 +51,19 @@ func New{{UpperCamel $.Table.Name}}Model (engine xorm.EngineInterface) *{{UpperC return &{{UpperCamel $.Table.Name}}Model{engine: engine} } -// Insert creates {{$.Table.Name}} data. -func (m *{{UpperCamel $.Table.Name}}Model) Insert(ctx context.Context, data ...*{{UpperCamel $.Table.Name}}) error { +// Create creates {{$.Table.Name}} data. +func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...*{{UpperCamel $.Table.Name}}) error { if len(data)==0{ return fmt.Errorf("data is empty") } var session = m.engine.Context(ctx) - list := data[:] - _,err := session.Insert(&list) + var list []interface{} + for _, v := range data { + list = append(list, v) + } + + _,err := session.Insert(list...) return err } {{range $stmt := .SelectStmt}} @@ -70,7 +78,11 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if {{end}}{{if $stmt.Having.IsValid}}session.Having(fmt.Sprintf({{HavingSprintf $stmt.Having.SQL}}, {{$stmt.Having.Parameters "having"}})) {{end}}{{if $stmt.OrderBy.IsValid}}session.OrderBy({{$stmt.OrderBy.SQL}}) {{end}}{{if $stmt.Limit.IsValid}}session.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}}) - {{end}}{{if $stmt.Limit.One}}_,{{end}} err := session{{if $stmt.Limit.One}}.Get({{if $stmt.Limit.One}}result{{end}}){{else}}.Find(&result){{end}} + {{end}}{{if $stmt.Limit.One}}has, err := session.Get(result) + if !has{ + return nil, sql.ErrNoRows + } + {{else}}err :=session.Find(&result){{end}} return result, err } {{end}} @@ -80,6 +92,7 @@ func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if // {{$stmt.SQL}} func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, data *{{UpperCamel $.Table.Name}}{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { var session = m.engine.Context(ctx) + session.Table(&{{UpperCamel $.Table.Name}}{}) {{if $stmt.Where.IsValid}}session.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) {{end}}{{if $stmt.OrderBy.IsValid}}session.OrderBy({{$stmt.OrderBy.SQL}}) {{end}}{{if $stmt.Limit.IsValid}}session.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}})