Skip to content

Commit

Permalink
planner, expression: support builtin function NAME_CONST (#9261)
Browse files Browse the repository at this point in the history
  • Loading branch information
spongedu authored and zz-jason committed Feb 19, 2019
1 parent 7c0a9a7 commit ed7bb00
Show file tree
Hide file tree
Showing 5 changed files with 266 additions and 14 deletions.
138 changes: 136 additions & 2 deletions expression/builtin_miscellaneous.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ var (
_ builtinFunc = &builtinIsIPv4MappedSig{}
_ builtinFunc = &builtinIsIPv6Sig{}
_ builtinFunc = &builtinUUIDSig{}

_ builtinFunc = &builtinNameConstIntSig{}
_ builtinFunc = &builtinNameConstRealSig{}
_ builtinFunc = &builtinNameConstDecimalSig{}
_ builtinFunc = &builtinNameConstTimeSig{}
_ builtinFunc = &builtinNameConstDurationSig{}
_ builtinFunc = &builtinNameConstStringSig{}
_ builtinFunc = &builtinNameConstJSONSig{}
)

type sleepFunctionClass struct {
Expand Down Expand Up @@ -228,7 +236,7 @@ func (c *anyValueFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
bf.tp.Charset, bf.tp.Collate, bf.tp.Flag = mysql.DefaultCharset, mysql.DefaultCollationName, 0
sig = &builtinTimeAnyValueSig{bf}
default:
panic("unexpected types.EvalType of builtin function ANY_VALUE")
return nil, errIncorrectArgs.GenWithStackByArgs("ANY_VALUE")
}
return sig, nil
}
Expand Down Expand Up @@ -808,7 +816,133 @@ type nameConstFunctionClass struct {
}

func (c *nameConstFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", "NAME_CONST")
if err := c.verifyArgs(args); err != nil {
return nil, err
}
argTp := args[1].GetType().EvalType()
bf := newBaseBuiltinFuncWithTp(ctx, args, argTp, types.ETString, argTp)
*bf.tp = *args[1].GetType()
var sig builtinFunc
switch argTp {
case types.ETDecimal:
sig = &builtinNameConstDecimalSig{bf}
case types.ETDuration:
sig = &builtinNameConstDurationSig{bf}
case types.ETInt:
bf.tp.Decimal = 0
sig = &builtinNameConstIntSig{bf}
case types.ETJson:
sig = &builtinNameConstJSONSig{bf}
case types.ETReal:
sig = &builtinNameConstRealSig{bf}
case types.ETString:
bf.tp.Decimal = types.UnspecifiedLength
sig = &builtinNameConstStringSig{bf}
case types.ETDatetime, types.ETTimestamp:
bf.tp.Charset, bf.tp.Collate, bf.tp.Flag = mysql.DefaultCharset, mysql.DefaultCollationName, 0
sig = &builtinNameConstTimeSig{bf}
default:
return nil, errIncorrectArgs.GenWithStackByArgs("NAME_CONST")
}
return sig, nil
}

type builtinNameConstDecimalSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstDecimalSig) Clone() builtinFunc {
newSig := &builtinNameConstDecimalSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstDecimalSig) evalDecimal(row chunk.Row) (*types.MyDecimal, bool, error) {
return b.args[1].EvalDecimal(b.ctx, row)
}

type builtinNameConstIntSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstIntSig) Clone() builtinFunc {
newSig := &builtinNameConstIntSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstIntSig) evalInt(row chunk.Row) (int64, bool, error) {
return b.args[1].EvalInt(b.ctx, row)
}

type builtinNameConstRealSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstRealSig) Clone() builtinFunc {
newSig := &builtinNameConstRealSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstRealSig) evalReal(row chunk.Row) (float64, bool, error) {
return b.args[1].EvalReal(b.ctx, row)
}

type builtinNameConstStringSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstStringSig) Clone() builtinFunc {
newSig := &builtinNameConstStringSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstStringSig) evalString(row chunk.Row) (string, bool, error) {
return b.args[1].EvalString(b.ctx, row)
}

type builtinNameConstJSONSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstJSONSig) Clone() builtinFunc {
newSig := &builtinNameConstJSONSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstJSONSig) evalJSON(row chunk.Row) (json.BinaryJSON, bool, error) {
return b.args[1].EvalJSON(b.ctx, row)
}

type builtinNameConstDurationSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstDurationSig) Clone() builtinFunc {
newSig := &builtinNameConstDurationSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstDurationSig) evalDuration(row chunk.Row) (types.Duration, bool, error) {
return b.args[1].EvalDuration(b.ctx, row)
}

type builtinNameConstTimeSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstTimeSig) Clone() builtinFunc {
newSig := &builtinNameConstTimeSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstTimeSig) evalTime(row chunk.Row) (types.Time, bool, error) {
return b.args[1].EvalTime(b.ctx, row)
}

type releaseAllLocksFunctionClass struct {
Expand Down
45 changes: 45 additions & 0 deletions expression/builtin_miscellaneous_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ package expression
import (
"math"
"strings"
"time"

. "github.com/pingcap/check"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/testleak"
Expand Down Expand Up @@ -320,3 +322,46 @@ func (s *testEvaluatorSuite) TestIsIPv4Compat(c *C) {
c.Assert(err, IsNil)
c.Assert(r, testutil.DatumEquals, types.NewDatum(0))
}

func (s *testEvaluatorSuite) TestNameConst(c *C) {
defer testleak.AfterTest(c)()
dec := types.NewDecFromFloatForTest(123.123)
tm := types.Time{Time: types.FromGoTime(time.Now()), Fsp: 6, Type: mysql.TypeDatetime}
du := types.Duration{Duration: time.Duration(12*time.Hour + 1*time.Minute + 1*time.Second), Fsp: types.DefaultFsp}
cases := []struct {
colName string
arg interface{}
isNil bool
asserts func(d types.Datum)
}{
{"test_int", 3, false, func(d types.Datum) {
c.Assert(d.GetInt64(), Equals, int64(3))
}},
{"test_float", 3.14159, false, func(d types.Datum) {
c.Assert(d.GetFloat64(), Equals, 3.14159)
}},
{"test_string", "TiDB", false, func(d types.Datum) {
c.Assert(d.GetString(), Equals, "TiDB")
}},
{"test_null", nil, true, func(d types.Datum) {
c.Assert(d.Kind(), Equals, types.KindNull)
}},
{"test_decimal", dec, false, func(d types.Datum) {
c.Assert(d.GetMysqlDecimal().String(), Equals, dec.String())
}},
{"test_time", tm, false, func(d types.Datum) {
c.Assert(d.GetMysqlTime().String(), Equals, tm.String())
}},
{"test_duration", du, false, func(d types.Datum) {
c.Assert(d.GetMysqlDuration().String(), Equals, du.String())
}},
}

for _, t := range cases {
f, err := newFunctionForTest(s.ctx, ast.NameConst, s.primitiveValsToConstants([]interface{}{t.colName, t.arg})...)
c.Assert(err, IsNil)
d, err := f.Eval(chunk.Row{})
c.Assert(err, IsNil)
t.asserts(d)
}
}
37 changes: 37 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/mock"
"github.com/pingcap/tidb/util/sqlexec"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/testutil"
Expand Down Expand Up @@ -3935,6 +3936,42 @@ func (s *testIntegrationSuite) TestValuesFloat32(c *C) {
tk.MustQuery(`select * from t;`).Check(testkit.Rows(`1 0.02`))
}

func (s *testIntegrationSuite) TestFuncNameConst(c *C) {
tk := testkit.NewTestKit(c, s.store)
defer s.cleanEnv(c)
tk.MustExec("USE test;")
tk.MustExec("DROP TABLE IF EXISTS t;")
tk.MustExec("CREATE TABLE t(a CHAR(20), b VARCHAR(20), c BIGINT);")
tk.MustExec("INSERT INTO t (b, c) values('hello', 1);")

r := tk.MustQuery("SELECT name_const('test_int', 1), name_const('test_float', 3.1415);")
r.Check(testkit.Rows("1 3.1415"))
r = tk.MustQuery("SELECT name_const('test_string', 'hello'), name_const('test_nil', null);")
r.Check(testkit.Rows("hello <nil>"))
r = tk.MustQuery("SELECT name_const('test_string', 1) + c FROM t;")
r.Check(testkit.Rows("2"))
r = tk.MustQuery("SELECT concat('hello', name_const('test_string', 'world')) FROM t;")
r.Check(testkit.Rows("helloworld"))
err := tk.ExecToErr(`select name_const(a,b) from t;`)
c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST")
err = tk.ExecToErr(`select name_const(a,"hello") from t;`)
c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST")
err = tk.ExecToErr(`select name_const("hello", b) from t;`)
c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST")
err = tk.ExecToErr(`select name_const("hello", 1+1) from t;`)
c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST")
err = tk.ExecToErr(`select name_const(concat('a', 'b'), 555) from t;`)
c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST")
err = tk.ExecToErr(`select name_const(555) from t;`)
c.Assert(err.Error(), Equals, "[expression:1582]Incorrect parameter count in the call to native function 'name_const'")

var rs sqlexec.RecordSet
rs, err = tk.Exec(`select name_const("hello", 1);`)
c.Assert(err, IsNil)
c.Assert(len(rs.Fields()), Equals, 1)
c.Assert(rs.Fields()[0].Column.Name.L, Equals, "hello")
}

func (s *testIntegrationSuite) TestValuesEnum(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
44 changes: 32 additions & 12 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,18 +558,29 @@ func (b *PlanBuilder) buildProjectionFieldNameFromColumns(field *ast.SelectField
}

// buildProjectionFieldNameFromExpressions builds the field name when field expression is a normal expression.
func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(field *ast.SelectField) model.CIStr {
func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(field *ast.SelectField) (model.CIStr, error) {
if agg, ok := field.Expr.(*ast.AggregateFuncExpr); ok && agg.F == ast.AggFuncFirstRow {
// When the query is select t.a from t group by a; The Column Name should be a but not t.a;
return agg.Args[0].(*ast.ColumnNameExpr).Name.Name
return agg.Args[0].(*ast.ColumnNameExpr).Name.Name, nil
}

innerExpr := getInnerFromParenthesesAndUnaryPlus(field.Expr)
funcCall, isFuncCall := innerExpr.(*ast.FuncCallExpr)
// When used to produce a result set column, NAME_CONST() causes the column to have the given name.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_name-const for details
if isFuncCall && funcCall.FnName.L == ast.NameConst {
if v, err := evalAstExpr(b.ctx, funcCall.Args[0]); err == nil {
if s, err := v.ToString(); err == nil {
return model.NewCIStr(s), nil
}
}
return model.NewCIStr(""), ErrWrongArguments.GenWithStackByArgs("NAME_CONST")
}
valueExpr, isValueExpr := innerExpr.(*driver.ValueExpr)

// Non-literal: Output as inputed, except that comments need to be removed.
if !isValueExpr {
return model.NewCIStr(parser.SpecFieldPattern.ReplaceAllStringFunc(field.Text(), parser.TrimComment))
return model.NewCIStr(parser.SpecFieldPattern.ReplaceAllStringFunc(field.Text(), parser.TrimComment)), nil
}

// Literal: Need special processing
Expand All @@ -585,21 +596,21 @@ func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(field *ast.SelectF
fieldName := strings.TrimLeftFunc(projName, func(r rune) bool {
return !unicode.IsOneOf(mysql.RangeGraph, r)
})
return model.NewCIStr(fieldName)
return model.NewCIStr(fieldName), nil
case types.KindNull:
// See #4053, #3685
return model.NewCIStr("NULL")
return model.NewCIStr("NULL"), nil
default:
// Keep as it is.
if innerExpr.Text() != "" {
return model.NewCIStr(innerExpr.Text())
return model.NewCIStr(innerExpr.Text()), nil
}
return model.NewCIStr(field.Text())
return model.NewCIStr(field.Text()), nil
}
}

// buildProjectionField builds the field object according to SelectField in projection.
func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectField, expr expression.Expression) *expression.Column {
func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectField, expr expression.Expression) (*expression.Column, error) {
var origTblName, tblName, origColName, colName, dbName model.CIStr
if c, ok := expr.(*expression.Column); ok && !c.IsReferenced {
// Field is a column reference.
Expand All @@ -609,7 +620,10 @@ func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectFi
colName = field.AsName
} else {
// Other: field is an expression.
colName = b.buildProjectionFieldNameFromExpressions(field)
var err error
if colName, err = b.buildProjectionFieldNameFromExpressions(field); err != nil {
return nil, errors.Trace(err)
}
}
return &expression.Column{
UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(),
Expand All @@ -619,7 +633,7 @@ func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectFi
OrigColName: origColName,
DBName: dbName,
RetType: expr.GetType(),
}
}, nil
}

// buildProjection returns a Projection plan and non-aux columns length.
Expand Down Expand Up @@ -648,7 +662,10 @@ func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField,
expr = p.Schema().Columns[i]
}
proj.Exprs = append(proj.Exprs, expr)
col := b.buildProjectionField(proj.id, schema.Len()+1, field, expr)
col, err := b.buildProjectionField(proj.id, schema.Len()+1, field, expr)
if err != nil {
return nil, 0, errors.Trace(err)
}
schema.Append(col)
continue
}
Expand All @@ -660,7 +677,10 @@ func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField,
p = np
proj.Exprs = append(proj.Exprs, newExpr)

col := b.buildProjectionField(proj.id, schema.Len()+1, field, newExpr)
col, err := b.buildProjectionField(proj.id, schema.Len()+1, field, newExpr)
if err != nil {
return nil, 0, errors.Trace(err)
}
schema.Append(col)
}
proj.SetSchema(schema)
Expand Down
Loading

0 comments on commit ed7bb00

Please sign in to comment.