Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ToSQL/CatchSQL funcions for package gdb #2137

Merged
merged 3 commits into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions contrib/drivers/mysql/mysql_basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package mysql_test

import (
"context"
"testing"

"github.com/go-sql-driver/mysql"
Expand Down Expand Up @@ -70,3 +71,29 @@ func Test_Func_FormatSqlWithArgs(t *testing.T) {
t.Assert(s, "select * from table where id>=100 and sex=1")
})
}

func Test_Func_ToSQL(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
sql, err := gdb.ToSQL(ctx, func(ctx context.Context) error {
value, err := db.Ctx(ctx).Model(TableName).Fields("nickname").Where("id", 1).Value()
t.Assert(value, nil)
return err
})
t.AssertNil(err)
t.Assert(sql, "SELECT `nickname` FROM `user` WHERE `id`=1 LIMIT 1")
})
}

func Test_Func_CatchSQL(t *testing.T) {
table := createInitTable()
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
array, err := gdb.CatchSQL(ctx, func(ctx context.Context) error {
value, err := db.Ctx(ctx).Model(table).Fields("nickname").Where("id", 1).Value()
t.Assert(value, "name_1")
return err
})
t.AssertNil(err)
t.AssertGE(len(array), 1)
})
}
24 changes: 16 additions & 8 deletions database/gdb/gdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"database/sql"
"time"

"github.com/gogf/gf/v2/container/garray"
"github.com/gogf/gf/v2/container/gmap"
"github.com/gogf/gf/v2/container/gtype"
"github.com/gogf/gf/v2/container/gvar"
Expand Down Expand Up @@ -171,8 +172,8 @@ type DB interface {
GetCtx() context.Context // See Core.GetCtx.
GetCore() *Core // See Core.GetCore
GetChars() (charLeft string, charRight string) // See Core.GetChars.
Tables(ctx context.Context, schema ...string) (tables []string, err error) // See Core.Tables.
TableFields(ctx context.Context, table string, schema ...string) (map[string]*TableField, error) // See Core.TableFields.
Tables(ctx context.Context, schema ...string) (tables []string, err error) // See Core.Tables. The driver must implement this function.
TableFields(ctx context.Context, table string, schema ...string) (map[string]*TableField, error) // See Core.TableFields. The driver must implement this function.
ConvertDataForRecord(ctx context.Context, data interface{}) (map[string]interface{}, error) // See Core.ConvertDataForRecord
ConvertValueForLocal(ctx context.Context, fieldType string, fieldValue interface{}) (interface{}, error) // See Core.ConvertValueForLocal
CheckLocalTypeForField(ctx context.Context, fieldType string, fieldValue interface{}) (string, error) // See Core.CheckLocalTypeForField
Expand Down Expand Up @@ -278,6 +279,11 @@ type (
List = []Map // List is type of map array.
)

type CatchSQLManager struct {
SQLArray *garray.StrArray
DoCommit bool
}

const (
defaultModelSafe = false
defaultCharset = `utf8`
Expand All @@ -292,12 +298,14 @@ const (
ctxTimeoutTypeExec = iota
ctxTimeoutTypeQuery
ctxTimeoutTypePrepare
cachePrefixTableFields = `TableFields:`
cachePrefixSelectCache = `SelectCache:`
commandEnvKeyForDryRun = "gf.gdb.dryrun"
modelForDaoSuffix = `ForDao`
dbRoleSlave = `slave`
contextKeyForDB gctx.StrKey = `DBInContext`
cachePrefixTableFields = `TableFields:`
cachePrefixSelectCache = `SelectCache:`
commandEnvKeyForDryRun = "gf.gdb.dryrun"
modelForDaoSuffix = `ForDao`
dbRoleSlave = `slave`
ctxKeyForDB gctx.StrKey = `CtxKeyForDB`
ctxKeyCatchSQL gctx.StrKey = `CtxKeyCatchSQL`
ctxKeyInternalProducedSQL gctx.StrKey = `CtxKeyInternalProducedSQL`
)

const (
Expand Down
2 changes: 1 addition & 1 deletion database/gdb/gdb_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ func (c *Core) HasTable(name string) (bool, error) {
cacheKey = fmt.Sprintf(`HasTable: %s`, name)
)
result, err := c.GetCache().GetOrSetFuncLock(ctx, cacheKey, func(ctx context.Context) (interface{}, error) {
tableList, err := c.db.Tables(ctx)
tableList, err := c.Tables(ctx)
if err != nil {
return false, err
}
Expand Down
29 changes: 25 additions & 4 deletions database/gdb/gdb_core_underlying.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,15 @@ import (
"context"
"database/sql"

"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"

"github.com/gogf/gf/v2"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/internal/intlog"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/guid"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
gqcn marked this conversation as resolved.
Show resolved Hide resolved
)

// Query commits one query SQL to underlying driver and returns the execution result.
Expand Down Expand Up @@ -58,6 +57,17 @@ func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...inter
if err != nil {
return nil, err
}
// SQL format and retrieve.
if v := ctx.Value(ctxKeyCatchSQL); v != nil {
var (
manager = v.(*CatchSQLManager)
formattedSql = FormatSqlWithArgs(sql, args)
)
manager.SQLArray.Append(formattedSql)
if !manager.DoCommit && ctx.Value(ctxKeyInternalProducedSQL) == nil {
return nil, nil
}
}
// Link execution.
var out DoCommitOutput
out, err = c.db.DoCommit(ctx, DoCommitInput{
Expand Down Expand Up @@ -102,12 +112,23 @@ func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interf
defer cancelFunc()
}

// Sql filtering.
// SQL filtering.
sql, args = formatSql(sql, args)
sql, args, err = c.db.DoFilter(ctx, link, sql, args)
if err != nil {
return nil, err
}
// SQL format and retrieve.
if v := ctx.Value(ctxKeyCatchSQL); v != nil {
var (
manager = v.(*CatchSQLManager)
formattedSql = FormatSqlWithArgs(sql, args)
)
manager.SQLArray.Append(formattedSql)
if !manager.DoCommit && ctx.Value(ctxKeyInternalProducedSQL) == nil {
return new(SqlResult), nil
}
}
// Link execution.
var out DoCommitOutput
out, err = c.db.DoCommit(ctx, DoCommitInput{
Expand Down
29 changes: 3 additions & 26 deletions database/gdb/gdb_core_utility.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,6 @@ import (
"github.com/gogf/gf/v2/util/gutil"
)

// WithDB injects given db object into context and returns a new context.
func WithDB(ctx context.Context, db DB) context.Context {
if db == nil {
return ctx
}
dbCtx := db.GetCtx()
if ctxDb := DBFromCtx(dbCtx); ctxDb != nil {
return dbCtx
}
ctx = context.WithValue(ctx, contextKeyForDB, db)
return ctx
}

// DBFromCtx retrieves and returns DB object from context.
func DBFromCtx(ctx context.Context) DB {
if ctx == nil {
return nil
}
v := ctx.Value(contextKeyForDB)
if v != nil {
return v.(DB)
}
return nil
}

// GetLink creates and returns the underlying database link object with transaction checks.
// The parameter `master` specifies whether using the master node if master-slave configured.
func (c *Core) GetLink(ctx context.Context, master bool, schema string) (Link, error) {
Expand Down Expand Up @@ -140,7 +115,8 @@ func (c *Core) GetChars() (charLeft string, charRight string) {
// Tables retrieves and returns the tables of current schema.
// It's mainly used in cli tool chain for automatically generating the models.
func (c *Core) Tables(ctx context.Context, schema ...string) (tables []string, err error) {
return
ctx = context.WithValue(ctx, ctxKeyInternalProducedSQL, struct{}{})
return c.db.Tables(ctx, schema...)
}

// TableFields retrieves and returns the fields' information of specified table of current
Expand All @@ -165,6 +141,7 @@ func (c *Core) TableFields(ctx context.Context, table string, schema ...string)
table,
)
value = tableFieldsMap.GetOrSetFuncLock(cacheKey, func() interface{} {
ctx = context.WithValue(ctx, ctxKeyInternalProducedSQL, struct{}{})
fields, err = c.db.TableFields(ctx, table, schema...)
if err != nil {
return nil
Expand Down
49 changes: 49 additions & 0 deletions database/gdb/gdb_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"strings"
"time"

"github.com/gogf/gf/v2/container/garray"
"github.com/gogf/gf/v2/internal/empty"
"github.com/gogf/gf/v2/internal/reflection"
"github.com/gogf/gf/v2/internal/utils"
Expand Down Expand Up @@ -64,6 +65,54 @@ var (
structTagPriority = append([]string{OrmTagForStruct}, gconv.StructTagPriority...)
)

// WithDB injects given db object into context and returns a new context.
func WithDB(ctx context.Context, db DB) context.Context {
if db == nil {
return ctx
}
dbCtx := db.GetCtx()
if ctxDb := DBFromCtx(dbCtx); ctxDb != nil {
return dbCtx
}
ctx = context.WithValue(ctx, ctxKeyForDB, db)
return ctx
}

// DBFromCtx retrieves and returns DB object from context.
func DBFromCtx(ctx context.Context) DB {
if ctx == nil {
return nil
}
v := ctx.Value(ctxKeyForDB)
if v != nil {
return v.(DB)
}
return nil
}

// ToSQL formats and returns the last one of sql statements in given closure function.
func ToSQL(ctx context.Context, f func(ctx context.Context) error) (sql string, err error) {
var manager = &CatchSQLManager{
SQLArray: garray.NewStrArray(),
DoCommit: false,
}
ctx = context.WithValue(ctx, ctxKeyCatchSQL, manager)
err = f(ctx)
sql, _ = manager.SQLArray.PopRight()
return
}

// CatchSQL catches and returns all sql statements that are executed in given closure function.
func CatchSQL(ctx context.Context, f func(ctx context.Context) error) (sqlArray []string, err error) {
var manager = &CatchSQLManager{
SQLArray: garray.NewStrArray(),
DoCommit: true,
}
ctx = context.WithValue(ctx, ctxKeyCatchSQL, manager)
err = f(ctx)
return manager.SQLArray.Slice(), err
}

// isDoStruct checks and returns whether given type is a DO struct.
func isDoStruct(object interface{}) bool {
// It checks by struct name like "XxxForDao", to be compatible with old version.
Expand Down