From c46406e8d8fedebb9f212fd1cd8b8eb7c858c33a Mon Sep 17 00:00:00 2001 From: Nace Sc Date: Mon, 3 Jul 2023 15:09:25 +0800 Subject: [PATCH] Fix/mysqlr raw query (#276) * internal/parser,e2e: fix mysqlr raw query parser * internal/parser: deprecate io/ioutil package * e2e/mysqlr: remove tests * internal/parser/x: AggregateFuncExpr also should support alias * internal/parser/x: scan order should keep in the fixed order * e2e/mysqlr: remove tests * internal/parser: fix pk query --- e2e/mysqlr/gen_methods.go | 8 ++-- e2e/mysqlr/sqls/blog.sql | 4 +- internal/parser/mysqlr/field.go | 35 ++++++++-------- internal/parser/shared/field.go | 33 ++++++++------- internal/parser/x/query/sql.go | 8 +++- internal/parser/x/query/sql_method.go | 27 +++++++++---- internal/parser/x/query/tidb_parser.go | 17 ++++++-- internal/parser/x/query/tidb_parser_test.go | 45 +++++++++++++++------ 8 files changed, 113 insertions(+), 64 deletions(-) diff --git a/e2e/mysqlr/gen_methods.go b/e2e/mysqlr/gen_methods.go index ff12469a..5bb3dcf5 100644 --- a/e2e/mysqlr/gen_methods.go +++ b/e2e/mysqlr/gen_methods.go @@ -40,7 +40,9 @@ func WithDB(db *sql_driver.DB) RawQueryOptionHandler { } type BlogResp struct { - SUMTitle any + Id int64 `sql:"id"` + TitleCount any `sql:"title_count"` + Status int32 `sql:"status"` } type BlogReq struct { @@ -55,7 +57,7 @@ func (req *BlogReq) Params() []any { return params } -const _BlogSQL = "SELECT SUM(`title`) AS `title_count` FROM `blogs` WHERE `id`=?" +const _BlogSQL = "SELECT `Id`,SUM(`title`) AS `title_count`,`status` FROM `blogs` WHERE `id`=?" // Blog is a raw query handler generated function for `e2e/mysqlr/sqls/blog.sql`. func (m *sqlMethods) Blog(ctx context.Context, req *BlogReq, opts ...RawQueryOptionHandler) ([]*BlogResp, error) { @@ -77,7 +79,7 @@ func (m *sqlMethods) Blog(ctx context.Context, req *BlogReq, opts ...RawQueryOpt var results []*BlogResp for rows.Next() { var o BlogResp - err = rows.Scan(&o.SUMTitle) + err = rows.Scan(&o.Id, &o.TitleCount, &o.Status) if err != nil { return nil, err } diff --git a/e2e/mysqlr/sqls/blog.sql b/e2e/mysqlr/sqls/blog.sql index 4ec485ea..84995352 100644 --- a/e2e/mysqlr/sqls/blog.sql +++ b/e2e/mysqlr/sqls/blog.sql @@ -1,5 +1,7 @@ SELECT - SUM(`title`) AS title_count + Id , + SUM(`title`) AS title_count , + status FROM blogs WHERE diff --git a/internal/parser/mysqlr/field.go b/internal/parser/mysqlr/field.go index 1298841c..03a7af34 100644 --- a/internal/parser/mysqlr/field.go +++ b/internal/parser/mysqlr/field.go @@ -11,24 +11,23 @@ import ( "github.com/ezbuy/ezorm/v2/internal/generator" "github.com/ezbuy/utils/container/set" + "github.com/iancoleman/strcase" ) -var ( - nullablePrimitiveSet = map[string]bool{ - "uint8": true, - "uint16": true, - "uint32": true, - "uint64": true, - "int8": true, - "int16": true, - "int32": true, - "int64": true, - "float32": true, - "float64": true, - "bool": true, - "string": true, - } -) +var nullablePrimitiveSet = map[string]bool{ + "uint8": true, + "uint16": true, + "uint32": true, + "uint64": true, + "int8": true, + "int16": true, + "int32": true, + "int64": true, + "float32": true, + "float64": true, + "bool": true, + "string": true, +} type Field struct { Name string @@ -71,7 +70,7 @@ var SupportedFieldTypes = map[string]string{ } func (f *Field) GetName() string { - return CamelName(f.Name) + return strcase.ToLowerCamel(f.Name) } func (f *Field) GetUnderlineName() string { @@ -408,7 +407,7 @@ func (f *Field) Read(data generator.Schema) error { return nil } -//! field SQL script functions +// ! field SQL script functions func (f *Field) SQLColumn() string { columns := make([]string, 0, 6) columns = append(columns, f.SQLName()) diff --git a/internal/parser/shared/field.go b/internal/parser/shared/field.go index 4404d808..90a53e31 100644 --- a/internal/parser/shared/field.go +++ b/internal/parser/shared/field.go @@ -10,28 +10,27 @@ import ( "github.com/ezbuy/ezorm/v2/internal/generator" "github.com/ezbuy/utils/container/set" + "github.com/iancoleman/strcase" ) const ( flagNullable = "nullable" ) -var ( - nullablePrimitiveSet = map[string]bool{ - "uint8": true, - "uint16": true, - "uint32": true, - "uint64": true, - "int8": true, - "int16": true, - "int32": true, - "int64": true, - "float32": true, - "float64": true, - "bool": true, - "string": true, - } -) +var nullablePrimitiveSet = map[string]bool{ + "uint8": true, + "uint16": true, + "uint32": true, + "uint64": true, + "int8": true, + "int16": true, + "int32": true, + "int64": true, + "float32": true, + "float64": true, + "bool": true, + "string": true, +} var _ generator.IField = (*Field)(nil) @@ -162,7 +161,7 @@ func (f *Field) DbName() string { } func (f *Field) GetName() string { - return camel2name(f.Name) + return strcase.ToLowerCamel(f.Name) } func (f *Field) GetTag() string { diff --git a/internal/parser/x/query/sql.go b/internal/parser/x/query/sql.go index 901ba3e8..3e352ffe 100644 --- a/internal/parser/x/query/sql.go +++ b/internal/parser/x/query/sql.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/ezbuy/ezorm/v2/internal/generator" + "github.com/iancoleman/strcase" ) type T uint8 @@ -158,6 +159,7 @@ func (tm TableMetadata) Validate(tableRef map[string]map[string]generator.IField } for _, p := range f.params { pName := uglify(p.Name) + pName = strcase.ToLowerCamel(pName) col, ok := ff[pName] if !ok && p.Type != T_ANY { return fmt.Errorf("metadata: param %s not found in table %s", pName, name) @@ -168,8 +170,9 @@ func (tm TableMetadata) Validate(tableRef map[string]map[string]generator.IField } for _, r := range f.result { rName := uglify(r.Name) + rName = strcase.ToLowerCamel(rName) if _, ok := ff[rName]; !ok && r.Type != T_ANY { - return fmt.Errorf("metadata: result %s not found in table %s", r.Name, name) + return fmt.Errorf("metadata: result %s not found in table %s", rName, name) } } } @@ -189,7 +192,8 @@ func (qm *QueryMetadata) String() string { type QueryBuilder struct { *bytes.Buffer - raw *Raw + raw *Raw + resultFields []*QueryField } func (qb *QueryBuilder) IsQueryIn() bool { diff --git a/internal/parser/x/query/sql_method.go b/internal/parser/x/query/sql_method.go index 7f5bf6a7..691e50d8 100644 --- a/internal/parser/x/query/sql_method.go +++ b/internal/parser/x/query/sql_method.go @@ -5,7 +5,7 @@ import ( "context" "errors" "fmt" - "io/ioutil" + "os" "path/filepath" "strings" "unicode" @@ -62,6 +62,7 @@ func (p *SQL) retypeResult(table string, col string) (string, error) { if !ok { return "", fmt.Errorf("res: retype: table: %s not found", table) } + col = strcase.ToLowerCamel(col) f, ok := t[col] if !ok { return "", fmt.Errorf("res: retype: field: %s not found", col) @@ -70,7 +71,7 @@ func (p *SQL) retypeResult(table string, col string) (string, error) { } func (p *SQL) Read(path string) (*SQLMethod, error) { - data, err := ioutil.ReadFile(path) + data, err := os.ReadFile(path) if err != nil { return nil, err } @@ -101,7 +102,10 @@ func (p *SQL) Read(path string) (*SQLMethod, error) { } for t, f := range meta { for _, c := range f.params { - name := uglify(c.Name) + name := c.Alias + if name == "" { + name = uglify(c.Name) + } result.Fields = append(result.Fields, &SQLMethodField{ Name: strcase.ToCamel(name), Raw: name, @@ -109,15 +113,19 @@ func (p *SQL) Read(path string) (*SQLMethod, error) { }) } for _, c := range f.result { - name := uglify(c.Name) + name := c.Alias + if name == "" { + name = uglify(c.Name) + } if c.Type == T_ANY { result.Result = append(result.Result, &SQLMethodField{ Name: strcase.ToCamel(name), Type: c.Type.String(), + Raw: name, }) continue } - tp, err := p.retypeResult(t.Name, name) + tp, err := p.retypeResult(t.Name, uglify(c.Name)) if err != nil { return nil, err } @@ -131,8 +139,13 @@ func (p *SQL) Read(path string) (*SQLMethod, error) { } var scan bytes.Buffer - for _, r := range result.Result { - scan.WriteString(fmt.Sprintf("&o.%s, ", r.Name)) + for _, r := range builder.resultFields { + name := r.Alias + if name == "" { + name = uglify(r.Name) + } + name = strcase.ToCamel(name) + scan.WriteString(fmt.Sprintf("&o.%s, ", name)) } result.Assign = scan.String() diff --git a/internal/parser/x/query/tidb_parser.go b/internal/parser/x/query/tidb_parser.go index 89d0a1c9..801ce76b 100644 --- a/internal/parser/x/query/tidb_parser.go +++ b/internal/parser/x/query/tidb_parser.go @@ -92,7 +92,9 @@ func (tp *TiDBParser) parse(node ast.Node, n int) error { if x.Fields != nil { for _, f := range x.Fields.Fields { if expr, ok := f.Expr.(*ast.AggregateFuncExpr); ok { - field := &QueryField{} + field := &QueryField{ + Alias: f.AsName.String(), + } var txt bytes.Buffer txt.WriteString(expr.F) for _, args := range expr.Args { @@ -106,16 +108,20 @@ func (tp *TiDBParser) parse(node ast.Node, n int) error { if len(expr.Args) > 0 { if col, ok := expr.Args[0].(*ast.ColumnNameExpr); ok { tp.meta.AppendResult(col.Name.Table.String(), field) + tp.b.resultFields = append(tp.b.resultFields, field) } } } if expr, ok := f.Expr.(*ast.ColumnNameExpr); ok { - field := &QueryField{} + field := &QueryField{ + Alias: f.AsName.String(), + } ff := &strings.Builder{} expr.Format(ff) field.Name = ff.String() field.Type = T_PLACEHOLDER tp.meta.AppendResult(expr.Name.Table.String(), field) + tp.b.resultFields = append(tp.b.resultFields, field) } } } @@ -241,7 +247,8 @@ func (tp *TiDBParser) parse(node ast.Node, n int) error { } func (tp *TiDBParser) Parse(ctx context.Context, - query string) (TableMetadata, *QueryBuilder, error) { + query string, +) (TableMetadata, *QueryBuilder, error) { queries := strings.Split(query, ";") for _, q := range queries { if len(strings.TrimSpace(q)) == 0 { @@ -268,12 +275,14 @@ func (tp *TiDBParser) Flush() { ins: map[string]struct{}{}, limit: &LimitOption{}, } + tp.b.resultFields = []*QueryField{} tp.b.Reset() tp.meta = make(map[Table]*QueryMetadata) } func (tp *TiDBParser) parseOne(ctx context.Context, - query string) error { + query string, +) error { node, err := parser.New().ParseOneStmt(query, "", "") if err != nil { return fmt.Errorf("raw query parser: %w(query: %s)", err, query) diff --git a/internal/parser/x/query/tidb_parser_test.go b/internal/parser/x/query/tidb_parser_test.go index 1375d324..d43ee3bd 100644 --- a/internal/parser/x/query/tidb_parser_test.go +++ b/internal/parser/x/query/tidb_parser_test.go @@ -55,6 +55,15 @@ WHERE u.name = 'me' ` +const queryWithColAs = ` +SELECT + u.id as uid +FROM + user u +WHERE + u.name = 'me' +` + const queryWithSubquery = ` SELECT id @@ -98,19 +107,21 @@ func TestTiDBParserParseMetadata(t *testing.T) { query string metadata TableMetadata }{ - {"query", query, map[Table]*QueryMetadata{ - {Name: "user"}: { - params: []*QueryField{ - {Name: "col:`name`", Type: T_ARRAY_STRING}, - {Name: "col:`id`", Type: T_INT}, - {Name: "col:`phone`", Type: T_STRING}, - {Name: "limit:count", Type: T_INT}, - {Name: "limit:offset", Type: T_INT}, + { + "query", query, map[Table]*QueryMetadata{ + {Name: "user"}: { + params: []*QueryField{ + {Name: "col:`name`", Type: T_ARRAY_STRING}, + {Name: "col:`id`", Type: T_INT}, + {Name: "col:`phone`", Type: T_STRING}, + {Name: "limit:count", Type: T_INT}, + {Name: "limit:offset", Type: T_INT}, + }, + result: []*QueryField{ + {Name: "`id`", Type: T_PLACEHOLDER}, + }, }, - result: []*QueryField{ - {Name: "`id`", Type: T_PLACEHOLDER}, - }, - }}, + }, }, {"queryIn", queryIn, map[Table]*QueryMetadata{ {Name: "user"}: { @@ -173,6 +184,16 @@ func TestTiDBParserParseMetadata(t *testing.T) { }, }, }}, + {"queryWithColAs", queryWithColAs, map[Table]*QueryMetadata{ + {Name: "user", Alias: "u"}: { + params: []*QueryField{ + {Name: "col:`u`.`name`", Type: T_STRING}, + }, + result: []*QueryField{ + {Name: "`u`.`id`", Type: T_PLACEHOLDER, Alias: "uid"}, + }, + }, + }}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {