Skip to content

Commit

Permalink
Fix/mysqlr raw query (#276)
Browse files Browse the repository at this point in the history
* internal/parser,e2e: fix mysqlr raw query parser

* internal/parser: deprecate io/ioutil package

* e2e/mysqlr: remove tests

* internal/parser/x: AggregateFuncExpr also should support alias

* internal/parser/x: scan order should keep in the fixed order

* e2e/mysqlr: remove tests

* internal/parser: fix pk query
  • Loading branch information
scbizu committed Jul 3, 2023
1 parent 7ec857a commit c46406e
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 64 deletions.
8 changes: 5 additions & 3 deletions e2e/mysqlr/gen_methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ func WithDB(db *sql_driver.DB) RawQueryOptionHandler {
}

type BlogResp struct {
SUMTitle any
Id int64 `sql:"id"`
TitleCount any `sql:"title_count"`
Status int32 `sql:"status"`
}

type BlogReq struct {
Expand All @@ -55,7 +57,7 @@ func (req *BlogReq) Params() []any {
return params
}

const _BlogSQL = "SELECT SUM(`title`) AS `title_count` FROM `blogs` WHERE `id`=?"
const _BlogSQL = "SELECT `Id`,SUM(`title`) AS `title_count`,`status` FROM `blogs` WHERE `id`=?"

// Blog is a raw query handler generated function for `e2e/mysqlr/sqls/blog.sql`.
func (m *sqlMethods) Blog(ctx context.Context, req *BlogReq, opts ...RawQueryOptionHandler) ([]*BlogResp, error) {
Expand All @@ -77,7 +79,7 @@ func (m *sqlMethods) Blog(ctx context.Context, req *BlogReq, opts ...RawQueryOpt
var results []*BlogResp
for rows.Next() {
var o BlogResp
err = rows.Scan(&o.SUMTitle)
err = rows.Scan(&o.Id, &o.TitleCount, &o.Status)
if err != nil {
return nil, err
}
Expand Down
4 changes: 3 additions & 1 deletion e2e/mysqlr/sqls/blog.sql
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
SELECT
SUM(`title`) AS title_count
Id ,
SUM(`title`) AS title_count ,
status
FROM
blogs
WHERE
Expand Down
35 changes: 17 additions & 18 deletions internal/parser/mysqlr/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,23 @@ import (

"github.com/ezbuy/ezorm/v2/internal/generator"
"github.com/ezbuy/utils/container/set"
"github.com/iancoleman/strcase"
)

var (
nullablePrimitiveSet = map[string]bool{
"uint8": true,
"uint16": true,
"uint32": true,
"uint64": true,
"int8": true,
"int16": true,
"int32": true,
"int64": true,
"float32": true,
"float64": true,
"bool": true,
"string": true,
}
)
var nullablePrimitiveSet = map[string]bool{
"uint8": true,
"uint16": true,
"uint32": true,
"uint64": true,
"int8": true,
"int16": true,
"int32": true,
"int64": true,
"float32": true,
"float64": true,
"bool": true,
"string": true,
}

type Field struct {
Name string
Expand Down Expand Up @@ -71,7 +70,7 @@ var SupportedFieldTypes = map[string]string{
}

func (f *Field) GetName() string {
return CamelName(f.Name)
return strcase.ToLowerCamel(f.Name)
}

func (f *Field) GetUnderlineName() string {
Expand Down Expand Up @@ -408,7 +407,7 @@ func (f *Field) Read(data generator.Schema) error {
return nil
}

//! field SQL script functions
// ! field SQL script functions
func (f *Field) SQLColumn() string {
columns := make([]string, 0, 6)
columns = append(columns, f.SQLName())
Expand Down
33 changes: 16 additions & 17 deletions internal/parser/shared/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,27 @@ import (

"github.com/ezbuy/ezorm/v2/internal/generator"
"github.com/ezbuy/utils/container/set"
"github.com/iancoleman/strcase"
)

const (
flagNullable = "nullable"
)

var (
nullablePrimitiveSet = map[string]bool{
"uint8": true,
"uint16": true,
"uint32": true,
"uint64": true,
"int8": true,
"int16": true,
"int32": true,
"int64": true,
"float32": true,
"float64": true,
"bool": true,
"string": true,
}
)
var nullablePrimitiveSet = map[string]bool{
"uint8": true,
"uint16": true,
"uint32": true,
"uint64": true,
"int8": true,
"int16": true,
"int32": true,
"int64": true,
"float32": true,
"float64": true,
"bool": true,
"string": true,
}

var _ generator.IField = (*Field)(nil)

Expand Down Expand Up @@ -162,7 +161,7 @@ func (f *Field) DbName() string {
}

func (f *Field) GetName() string {
return camel2name(f.Name)
return strcase.ToLowerCamel(f.Name)
}

func (f *Field) GetTag() string {
Expand Down
8 changes: 6 additions & 2 deletions internal/parser/x/query/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"

"github.com/ezbuy/ezorm/v2/internal/generator"
"github.com/iancoleman/strcase"
)

type T uint8
Expand Down Expand Up @@ -158,6 +159,7 @@ func (tm TableMetadata) Validate(tableRef map[string]map[string]generator.IField
}
for _, p := range f.params {
pName := uglify(p.Name)
pName = strcase.ToLowerCamel(pName)
col, ok := ff[pName]
if !ok && p.Type != T_ANY {
return fmt.Errorf("metadata: param %s not found in table %s", pName, name)
Expand All @@ -168,8 +170,9 @@ func (tm TableMetadata) Validate(tableRef map[string]map[string]generator.IField
}
for _, r := range f.result {
rName := uglify(r.Name)
rName = strcase.ToLowerCamel(rName)
if _, ok := ff[rName]; !ok && r.Type != T_ANY {
return fmt.Errorf("metadata: result %s not found in table %s", r.Name, name)
return fmt.Errorf("metadata: result %s not found in table %s", rName, name)
}
}
}
Expand All @@ -189,7 +192,8 @@ func (qm *QueryMetadata) String() string {

type QueryBuilder struct {
*bytes.Buffer
raw *Raw
raw *Raw
resultFields []*QueryField
}

func (qb *QueryBuilder) IsQueryIn() bool {
Expand Down
27 changes: 20 additions & 7 deletions internal/parser/x/query/sql_method.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"context"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
"unicode"
Expand Down Expand Up @@ -62,6 +62,7 @@ func (p *SQL) retypeResult(table string, col string) (string, error) {
if !ok {
return "", fmt.Errorf("res: retype: table: %s not found", table)
}
col = strcase.ToLowerCamel(col)
f, ok := t[col]
if !ok {
return "", fmt.Errorf("res: retype: field: %s not found", col)
Expand All @@ -70,7 +71,7 @@ func (p *SQL) retypeResult(table string, col string) (string, error) {
}

func (p *SQL) Read(path string) (*SQLMethod, error) {
data, err := ioutil.ReadFile(path)
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -101,23 +102,30 @@ func (p *SQL) Read(path string) (*SQLMethod, error) {
}
for t, f := range meta {
for _, c := range f.params {
name := uglify(c.Name)
name := c.Alias
if name == "" {
name = uglify(c.Name)
}
result.Fields = append(result.Fields, &SQLMethodField{
Name: strcase.ToCamel(name),
Raw: name,
Type: c.Type.String(),
})
}
for _, c := range f.result {
name := uglify(c.Name)
name := c.Alias
if name == "" {
name = uglify(c.Name)
}
if c.Type == T_ANY {
result.Result = append(result.Result, &SQLMethodField{
Name: strcase.ToCamel(name),
Type: c.Type.String(),
Raw: name,
})
continue
}
tp, err := p.retypeResult(t.Name, name)
tp, err := p.retypeResult(t.Name, uglify(c.Name))
if err != nil {
return nil, err
}
Expand All @@ -131,8 +139,13 @@ func (p *SQL) Read(path string) (*SQLMethod, error) {
}

var scan bytes.Buffer
for _, r := range result.Result {
scan.WriteString(fmt.Sprintf("&o.%s, ", r.Name))
for _, r := range builder.resultFields {
name := r.Alias
if name == "" {
name = uglify(r.Name)
}
name = strcase.ToCamel(name)
scan.WriteString(fmt.Sprintf("&o.%s, ", name))
}

result.Assign = scan.String()
Expand Down
17 changes: 13 additions & 4 deletions internal/parser/x/query/tidb_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ func (tp *TiDBParser) parse(node ast.Node, n int) error {
if x.Fields != nil {
for _, f := range x.Fields.Fields {
if expr, ok := f.Expr.(*ast.AggregateFuncExpr); ok {
field := &QueryField{}
field := &QueryField{
Alias: f.AsName.String(),
}
var txt bytes.Buffer
txt.WriteString(expr.F)
for _, args := range expr.Args {
Expand All @@ -106,16 +108,20 @@ func (tp *TiDBParser) parse(node ast.Node, n int) error {
if len(expr.Args) > 0 {
if col, ok := expr.Args[0].(*ast.ColumnNameExpr); ok {
tp.meta.AppendResult(col.Name.Table.String(), field)
tp.b.resultFields = append(tp.b.resultFields, field)
}
}
}
if expr, ok := f.Expr.(*ast.ColumnNameExpr); ok {
field := &QueryField{}
field := &QueryField{
Alias: f.AsName.String(),
}
ff := &strings.Builder{}
expr.Format(ff)
field.Name = ff.String()
field.Type = T_PLACEHOLDER
tp.meta.AppendResult(expr.Name.Table.String(), field)
tp.b.resultFields = append(tp.b.resultFields, field)
}
}
}
Expand Down Expand Up @@ -241,7 +247,8 @@ func (tp *TiDBParser) parse(node ast.Node, n int) error {
}

func (tp *TiDBParser) Parse(ctx context.Context,
query string) (TableMetadata, *QueryBuilder, error) {
query string,
) (TableMetadata, *QueryBuilder, error) {
queries := strings.Split(query, ";")
for _, q := range queries {
if len(strings.TrimSpace(q)) == 0 {
Expand All @@ -268,12 +275,14 @@ func (tp *TiDBParser) Flush() {
ins: map[string]struct{}{},
limit: &LimitOption{},
}
tp.b.resultFields = []*QueryField{}
tp.b.Reset()
tp.meta = make(map[Table]*QueryMetadata)
}

func (tp *TiDBParser) parseOne(ctx context.Context,
query string) error {
query string,
) error {
node, err := parser.New().ParseOneStmt(query, "", "")
if err != nil {
return fmt.Errorf("raw query parser: %w(query: %s)", err, query)
Expand Down
45 changes: 33 additions & 12 deletions internal/parser/x/query/tidb_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ WHERE
u.name = 'me'
`

const queryWithColAs = `
SELECT
u.id as uid
FROM
user u
WHERE
u.name = 'me'
`

const queryWithSubquery = `
SELECT
id
Expand Down Expand Up @@ -98,19 +107,21 @@ func TestTiDBParserParseMetadata(t *testing.T) {
query string
metadata TableMetadata
}{
{"query", query, map[Table]*QueryMetadata{
{Name: "user"}: {
params: []*QueryField{
{Name: "col:`name`", Type: T_ARRAY_STRING},
{Name: "col:`id`", Type: T_INT},
{Name: "col:`phone`", Type: T_STRING},
{Name: "limit:count", Type: T_INT},
{Name: "limit:offset", Type: T_INT},
{
"query", query, map[Table]*QueryMetadata{
{Name: "user"}: {
params: []*QueryField{
{Name: "col:`name`", Type: T_ARRAY_STRING},
{Name: "col:`id`", Type: T_INT},
{Name: "col:`phone`", Type: T_STRING},
{Name: "limit:count", Type: T_INT},
{Name: "limit:offset", Type: T_INT},
},
result: []*QueryField{
{Name: "`id`", Type: T_PLACEHOLDER},
},
},
result: []*QueryField{
{Name: "`id`", Type: T_PLACEHOLDER},
},
}},
},
},
{"queryIn", queryIn, map[Table]*QueryMetadata{
{Name: "user"}: {
Expand Down Expand Up @@ -173,6 +184,16 @@ func TestTiDBParserParseMetadata(t *testing.T) {
},
},
}},
{"queryWithColAs", queryWithColAs, map[Table]*QueryMetadata{
{Name: "user", Alias: "u"}: {
params: []*QueryField{
{Name: "col:`u`.`name`", Type: T_STRING},
},
result: []*QueryField{
{Name: "`u`.`id`", Type: T_PLACEHOLDER, Alias: "uid"},
},
},
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down

0 comments on commit c46406e

Please sign in to comment.