diff --git a/br/pkg/lightning/backend/kv/sql2kv.go b/br/pkg/lightning/backend/kv/sql2kv.go index 6cebb1e29e329..9ad552ef5f340 100644 --- a/br/pkg/lightning/backend/kv/sql2kv.go +++ b/br/pkg/lightning/backend/kv/sql2kv.go @@ -169,7 +169,7 @@ func collectGeneratedColumns(se *session, meta *model.TableInfo, cols []*table.C var genCols []genCol for i, col := range cols { if col.GeneratedExpr != nil { - expr, err := expression.RewriteAstExpr(se, col.GeneratedExpr, schema, names) + expr, err := expression.RewriteAstExpr(se, col.GeneratedExpr, schema, names, false) if err != nil { return nil, err } diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 2850a3aa968a5..429d3f13425b2 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -6164,7 +6164,7 @@ func (d *ddl) CreatePrimaryKey(ctx sessionctx.Context, ti ast.Ident, indexName m // After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic. // The recover step causes DDL wait a few seconds, makes the unit test painfully slow. // For same reason, decide whether index is global here. - indexColumns, err := buildIndexColumns(ctx, tblInfo.Columns, indexPartSpecifications) + indexColumns, _, err := buildIndexColumns(ctx, tblInfo.Columns, indexPartSpecifications) if err != nil { return errors.Trace(err) } @@ -6274,7 +6274,7 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as if err != nil { return nil, errors.Trace(err) } - expr, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, idxPart.Expr) + expr, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, idxPart.Expr, true) if err != nil { // TODO: refine the error message. return nil, err @@ -6389,7 +6389,7 @@ func (d *ddl) createIndex(ctx sessionctx.Context, ti ast.Ident, keyType ast.Inde // After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic. // The recover step causes DDL wait a few seconds, makes the unit test painfully slow. // For same reason, decide whether index is global here. - indexColumns, err := buildIndexColumns(ctx, finalColumns, indexPartSpecifications) + indexColumns, _, err := buildIndexColumns(ctx, finalColumns, indexPartSpecifications) if err != nil { return errors.Trace(err) } diff --git a/ddl/generated_column.go b/ddl/generated_column.go index 2f4ceee8b60a9..678d803edf521 100644 --- a/ddl/generated_column.go +++ b/ddl/generated_column.go @@ -268,12 +268,14 @@ func checkModifyGeneratedColumn(sctx sessionctx.Context, tbl table.Table, oldCol } type illegalFunctionChecker struct { - hasIllegalFunc bool - hasAggFunc bool - hasRowVal bool // hasRowVal checks whether the functional index refers to a row value - hasWindowFunc bool - hasNotGAFunc4ExprIdx bool - otherErr error + hasIllegalFunc bool + hasAggFunc bool + hasRowVal bool // hasRowVal checks whether the functional index refers to a row value + hasWindowFunc bool + hasNotGAFunc4ExprIdx bool + hasCastArrayFunc bool + disallowCastArrayFunc bool + otherErr error } func (c *illegalFunctionChecker) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) { @@ -308,7 +310,14 @@ func (c *illegalFunctionChecker) Enter(inNode ast.Node) (outNode ast.Node, skipC case *ast.WindowFuncExpr: c.hasWindowFunc = true return inNode, true + case *ast.FuncCastExpr: + c.hasCastArrayFunc = c.hasCastArrayFunc || node.Tp.IsArray() + if c.disallowCastArrayFunc && node.Tp.IsArray() { + c.otherErr = expression.ErrNotSupportedYet.GenWithStackByArgs("Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions") + return inNode, true + } } + c.disallowCastArrayFunc = true return inNode, false } @@ -355,6 +364,9 @@ func checkIllegalFn4Generated(name string, genType int, expr ast.ExprNode) error if genType == typeIndex && c.hasNotGAFunc4ExprIdx && !config.GetGlobalConfig().Experimental.AllowsExpressionIndex { return dbterror.ErrUnsupportedExpressionIndex } + if genType == typeColumn && c.hasCastArrayFunc { + return expression.ErrNotSupportedYet.GenWithStackByArgs("Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions") + } return nil } diff --git a/ddl/index.go b/ddl/index.go index 0f70b73b61046..273b89e041233 100644 --- a/ddl/index.go +++ b/ddl/index.go @@ -64,26 +64,28 @@ var ( telemetryAddIndexIngestUsage = metrics.TelemetryAddIndexIngestCnt ) -func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, indexPartSpecifications []*ast.IndexPartSpecification) ([]*model.IndexColumn, error) { +func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, indexPartSpecifications []*ast.IndexPartSpecification) ([]*model.IndexColumn, bool, error) { // Build offsets. idxParts := make([]*model.IndexColumn, 0, len(indexPartSpecifications)) var col *model.ColumnInfo + var mvIndex bool maxIndexLength := config.GetGlobalConfig().MaxIndexLength // The sum of length of all index columns. sumLength := 0 for _, ip := range indexPartSpecifications { col = model.FindColumnInfo(columns, ip.Column.Name.L) if col == nil { - return nil, dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ip.Column.Name) + return nil, false, dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ip.Column.Name) } if err := checkIndexColumn(ctx, col, ip.Length); err != nil { - return nil, err + return nil, false, err } + mvIndex = mvIndex || col.FieldType.IsArray() indexColLen := ip.Length indexColumnLength, err := getIndexColumnLength(col, ip.Length) if err != nil { - return nil, err + return nil, false, err } sumLength += indexColumnLength @@ -92,12 +94,12 @@ func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, inde // The multiple column index and the unique index in which the length sum exceeds the maximum size // will return an error instead produce a warning. if ctx == nil || ctx.GetSessionVars().StrictSQLMode || mysql.HasUniKeyFlag(col.GetFlag()) || len(indexPartSpecifications) > 1 { - return nil, dbterror.ErrTooLongKey.GenWithStackByArgs(maxIndexLength) + return nil, false, dbterror.ErrTooLongKey.GenWithStackByArgs(maxIndexLength) } // truncate index length and produce warning message in non-restrict sql mode. colLenPerUint, err := getIndexColumnLength(col, 1) if err != nil { - return nil, err + return nil, false, err } indexColLen = maxIndexLength / colLenPerUint // produce warning message @@ -111,7 +113,7 @@ func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, inde }) } - return idxParts, nil + return idxParts, mvIndex, nil } // CheckPKOnGeneratedColumn checks the specification of PK is valid. @@ -154,7 +156,7 @@ func checkIndexColumn(ctx sessionctx.Context, col *model.ColumnInfo, indexColumn } // JSON column cannot index. - if col.FieldType.GetType() == mysql.TypeJSON { + if col.FieldType.GetType() == mysql.TypeJSON && !col.FieldType.IsArray() { if col.Hidden { return dbterror.ErrFunctionalIndexOnJSONOrGeometryFunction } @@ -263,7 +265,7 @@ func BuildIndexInfo( return nil, errors.Trace(err) } - idxColumns, err := buildIndexColumns(ctx, allTableColumns, indexPartSpecifications) + idxColumns, mvIndex, err := buildIndexColumns(ctx, allTableColumns, indexPartSpecifications) if err != nil { return nil, errors.Trace(err) } @@ -276,6 +278,7 @@ func BuildIndexInfo( Primary: isPrimary, Unique: isUnique, Global: isGlobal, + MVIndex: mvIndex, } if indexOption != nil { diff --git a/ddl/partition.go b/ddl/partition.go index 0a1ea4e6fbe66..2c95f389707f9 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -1375,7 +1375,7 @@ func checkPartitionFuncType(ctx sessionctx.Context, expr ast.ExprNode, tblInfo * return nil } - e, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, expr) + e, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, expr, false) if err != nil { return errors.Trace(err) } diff --git a/expression/BUILD.bazel b/expression/BUILD.bazel index c7304642c544a..5a201d906b5a3 100644 --- a/expression/BUILD.bazel +++ b/expression/BUILD.bazel @@ -177,6 +177,7 @@ go_test( "integration_serial_test.go", "integration_test.go", "main_test.go", + "multi_valued_index_test.go", "scalar_function_test.go", "schema_test.go", "typeinfer_test.go", diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index ee66669e638d6..e6257c4dd058c 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -23,6 +23,7 @@ package expression import ( + "fmt" "math" "strconv" "strings" @@ -407,6 +408,70 @@ func (c *castAsDurationFunctionClass) getFunction(ctx sessionctx.Context, args [ return sig, nil } +type castAsArrayFunctionClass struct { + baseFunctionClass + + tp *types.FieldType +} + +func (c *castAsArrayFunctionClass) verifyArgs(args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + + if args[0].GetType().EvalType() != types.ETJson { + return types.ErrInvalidJSONData.GenWithStackByArgs("1", "cast_as_array") + } + + return nil +} + +func (c *castAsArrayFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (sig builtinFunc, err error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + arrayType := c.tp.ArrayType() + switch arrayType.GetType() { + case mysql.TypeYear, mysql.TypeJSON: + return nil, ErrNotSupportedYet.GenWithStackByArgs(fmt.Sprintf("CAST-ing data to array of %s", arrayType.String())) + } + if arrayType.EvalType() == types.ETString && arrayType.GetCharset() != charset.CharsetUTF8MB4 && arrayType.GetCharset() != charset.CharsetBin { + return nil, ErrNotSupportedYet.GenWithStackByArgs("specifying charset for multi-valued index", arrayType.String()) + } + + bf, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp) + if err != nil { + return nil, err + } + sig = &castJSONAsArrayFunctionSig{bf} + return sig, nil +} + +type castJSONAsArrayFunctionSig struct { + baseBuiltinFunc +} + +func (b *castJSONAsArrayFunctionSig) Clone() builtinFunc { + newSig := &castJSONAsArrayFunctionSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { + val, isNull, err := b.args[0].EvalJSON(b.ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + if val.TypeCode != types.JSONTypeCodeArray { + return types.BinaryJSON{}, false, ErrNotSupportedYet.GenWithStackByArgs("CAST-ing Non-JSON Array type to array") + } + + // TODO: impl the cast(... as ... array) function + + return types.BinaryJSON{}, false, nil +} + type castAsJSONFunctionClass struct { baseFunctionClass @@ -1914,6 +1979,13 @@ func BuildCastCollationFunction(ctx sessionctx.Context, expr Expression, ec *Exp // BuildCastFunction builds a CAST ScalarFunction from the Expression. func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) { + res, err := BuildCastFunctionWithCheck(ctx, expr, tp) + terror.Log(err) + return +} + +// BuildCastFunctionWithCheck builds a CAST ScalarFunction from the Expression and return error if any. +func BuildCastFunctionWithCheck(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression, err error) { argType := expr.GetType() // If source argument's nullable, then target type should be nullable if !mysql.HasNotNullFlag(argType.GetFlag()) { @@ -1933,7 +2005,11 @@ func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldT case types.ETDuration: fc = &castAsDurationFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} case types.ETJson: - fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + if tp.IsArray() { + fc = &castAsArrayFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + } else { + fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + } case types.ETString: fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} if expr.GetType().GetType() == mysql.TypeBit { @@ -1941,7 +2017,6 @@ func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldT } } f, err := fc.getFunction(ctx, []Expression{expr}) - terror.Log(err) res = &ScalarFunction{ FuncName: model.NewCIStr(ast.Cast), RetType: tp, @@ -1950,10 +2025,10 @@ func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldT // We do not fold CAST if the eval type of this scalar function is ETJson // since we may reset the flag of the field type of CastAsJson later which // would affect the evaluation of it. - if tp.EvalType() != types.ETJson { + if tp.EvalType() != types.ETJson && err == nil { res = FoldConstant(res) } - return res + return res, err } // WrapWithCastAsInt wraps `expr` with `cast` if the return type of expr is not diff --git a/expression/errors.go b/expression/errors.go index 0db38645f78d4..c56737ec2fae3 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -37,6 +37,7 @@ var ( ErrInvalidTableSample = dbterror.ClassExpression.NewStd(mysql.ErrInvalidTableSample) ErrInternal = dbterror.ClassOptimizer.NewStd(mysql.ErrInternal) ErrNoDB = dbterror.ClassOptimizer.NewStd(mysql.ErrNoDB) + ErrNotSupportedYet = dbterror.ClassExpression.NewStd(mysql.ErrNotSupportedYet) // All the un-exported errors are defined here: errFunctionNotExists = dbterror.ClassExpression.NewStd(mysql.ErrSpDoesNotExist) diff --git a/expression/expression.go b/expression/expression.go index 024bac00ef960..352f105c52d65 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -59,7 +59,7 @@ var EvalAstExpr func(sctx sessionctx.Context, expr ast.ExprNode) (types.Datum, e // RewriteAstExpr rewrites ast expression directly. // Note: initialized in planner/core // import expression and planner/core together to use EvalAstExpr -var RewriteAstExpr func(sctx sessionctx.Context, expr ast.ExprNode, schema *Schema, names types.NameSlice) (Expression, error) +var RewriteAstExpr func(sctx sessionctx.Context, expr ast.ExprNode, schema *Schema, names types.NameSlice, allowCastArray bool) (Expression, error) // VecExpr contains all vectorized evaluation methods. type VecExpr interface { @@ -998,7 +998,7 @@ func ColumnInfos2ColumnsAndNames(ctx sessionctx.Context, dbName, tblName model.C if err != nil { return nil, nil, errors.Trace(err) } - e, err := RewriteAstExpr(ctx, expr, mockSchema, names) + e, err := RewriteAstExpr(ctx, expr, mockSchema, names, false) if err != nil { return nil, nil, errors.Trace(err) } diff --git a/expression/multi_valued_index_test.go b/expression/multi_valued_index_test.go new file mode 100644 index 0000000000000..058d955faa4fb --- /dev/null +++ b/expression/multi_valued_index_test.go @@ -0,0 +1,47 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression_test + +import ( + "testing" + + "github.com/pingcap/tidb/errno" + "github.com/pingcap/tidb/testkit" +) + +func TestMultiValuedIndexDDL(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + + tk.MustExec("create table t(a json);") + tk.MustGetErrCode("select cast(a as signed array) from t", errno.ErrNotSupportedYet) + tk.MustGetErrCode("select json_extract(cast(a as signed array), '$[0]') from t", errno.ErrNotSupportedYet) + tk.MustGetErrCode("select * from t where cast(a as signed array)", errno.ErrNotSupportedYet) + tk.MustGetErrCode("select cast('[1,2,3]' as unsigned array);", errno.ErrNotSupportedYet) + + tk.MustExec("drop table t") + tk.MustGetErrCode("CREATE TABLE t(x INT, KEY k ((1 AND CAST(JSON_ARRAY(x) AS UNSIGNED ARRAY))));", errno.ErrNotSupportedYet) + tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(cast(f1 as unsigned array) as unsigned array))));", errno.ErrNotSupportedYet) + tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->>'$[*]' as unsigned array))));", errno.ErrInvalidJSONData) + tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->'$[*]' as year array))));", errno.ErrNotSupportedYet) + tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->'$[*]' as json array))));", errno.ErrNotSupportedYet) + tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->'$[*]' as char(10) charset gbk array))));", errno.ErrNotSupportedYet) + tk.MustGetErrCode("create table t(j json, gc json as ((concat(cast(j->'$[*]' as unsigned array),\"x\"))));", errno.ErrNotSupportedYet) + tk.MustGetErrCode("create table t(j json, gc json as (cast(j->'$[*]' as unsigned array)));", errno.ErrNotSupportedYet) + tk.MustGetErrCode("create view v as select cast('[1,2,3]' as unsigned array);", errno.ErrNotSupportedYet) + tk.MustExec("create table t(a json, index idx((cast(a as signed array))));") +} diff --git a/expression/simple_rewriter.go b/expression/simple_rewriter.go index 808db9f69b4cf..3343a0cbaa169 100644 --- a/expression/simple_rewriter.go +++ b/expression/simple_rewriter.go @@ -48,7 +48,7 @@ func ParseSimpleExprWithTableInfo(ctx sessionctx.Context, exprStr string, tableI return nil, errors.Trace(err) } expr := stmts[0].(*ast.SelectStmt).Fields.Fields[0].Expr - return RewriteSimpleExprWithTableInfo(ctx, tableInfo, expr) + return RewriteSimpleExprWithTableInfo(ctx, tableInfo, expr, false) } // ParseSimpleExprCastWithTableInfo parses simple expression string to Expression. @@ -63,13 +63,13 @@ func ParseSimpleExprCastWithTableInfo(ctx sessionctx.Context, exprStr string, ta } // RewriteSimpleExprWithTableInfo rewrites simple ast.ExprNode to expression.Expression. -func RewriteSimpleExprWithTableInfo(ctx sessionctx.Context, tbl *model.TableInfo, expr ast.ExprNode) (Expression, error) { +func RewriteSimpleExprWithTableInfo(ctx sessionctx.Context, tbl *model.TableInfo, expr ast.ExprNode, allowCastArray bool) (Expression, error) { dbName := model.NewCIStr(ctx.GetSessionVars().CurrentDB) columns, names, err := ColumnInfos2ColumnsAndNames(ctx, dbName, tbl.Name, tbl.Cols(), tbl) if err != nil { return nil, err } - e, err := RewriteAstExpr(ctx, expr, NewSchema(columns...), names) + e, err := RewriteAstExpr(ctx, expr, NewSchema(columns...), names, allowCastArray) if err != nil { return nil, err } @@ -111,7 +111,7 @@ func ParseSimpleExprsWithNames(ctx sessionctx.Context, exprStr string, schema *S // RewriteSimpleExprWithNames rewrites simple ast.ExprNode to expression.Expression. func RewriteSimpleExprWithNames(ctx sessionctx.Context, expr ast.ExprNode, schema *Schema, names []*types.FieldName) (Expression, error) { - e, err := RewriteAstExpr(ctx, expr, schema, names) + e, err := RewriteAstExpr(ctx, expr, schema, names, false) if err != nil { return nil, err } diff --git a/parser/model/model.go b/parser/model/model.go index ba7c46bcd6333..19aabf4a06572 100644 --- a/parser/model/model.go +++ b/parser/model/model.go @@ -1419,6 +1419,7 @@ type IndexInfo struct { Primary bool `json:"is_primary"` // Whether the index is primary key. Invisible bool `json:"is_invisible"` // Whether the index is invisible. Global bool `json:"is_global"` // Whether the index is global. + MVIndex bool `json:"mv_index"` // Whether the index is multivalued index. } // Clone clones IndexInfo. diff --git a/parser/parser_test.go b/parser/parser_test.go index c06d2076f085a..7b72117f69d16 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1603,6 +1603,7 @@ func TestBuiltin(t *testing.T) { {"select cast(time '2000' as year);", true, "SELECT CAST(TIME '2000' AS YEAR)"}, {"select cast(b as signed array);", true, "SELECT CAST(`b` AS SIGNED ARRAY)"}, + {"select cast(b as char(10) array);", true, "SELECT CAST(`b` AS CHAR(10) ARRAY)"}, // for last_insert_id {"SELECT last_insert_id();", true, "SELECT LAST_INSERT_ID()"}, diff --git a/parser/types/field_type.go b/parser/types/field_type.go index 369ed59fa7a59..ff0ac9793cf17 100644 --- a/parser/types/field_type.go +++ b/parser/types/field_type.go @@ -72,7 +72,7 @@ func NewFieldType(tp byte) *FieldType { // IsDecimalValid checks whether the decimal is valid. func (ft *FieldType) IsDecimalValid() bool { - if ft.tp == mysql.TypeNewDecimal && (ft.decimal < 0 || ft.decimal > mysql.MaxDecimalScale || ft.flen <= 0 || ft.flen > mysql.MaxDecimalWidth || ft.flen < ft.decimal) { + if ft.GetType() == mysql.TypeNewDecimal && (ft.decimal < 0 || ft.decimal > mysql.MaxDecimalScale || ft.flen <= 0 || ft.flen > mysql.MaxDecimalWidth || ft.flen < ft.decimal) { return false } return true @@ -80,7 +80,7 @@ func (ft *FieldType) IsDecimalValid() bool { // IsVarLengthType Determine whether the column type is a variable-length type func (ft *FieldType) IsVarLengthType() bool { - switch ft.tp { + switch ft.GetType() { case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeJSON, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: return true default: @@ -90,6 +90,9 @@ func (ft *FieldType) IsVarLengthType() bool { // GetType returns the type of the FieldType. func (ft *FieldType) GetType() byte { + if ft.array { + return mysql.TypeJSON + } return ft.tp } @@ -126,6 +129,7 @@ func (ft *FieldType) GetElems() []string { // SetType sets the type of the FieldType. func (ft *FieldType) SetType(tp byte) { ft.tp = tp + ft.array = false } // SetFlag sets the flag of the FieldType. @@ -160,7 +164,7 @@ func (ft *FieldType) SetFlen(flen int) { // SetFlenUnderLimit sets the length of the field to the value of the argument func (ft *FieldType) SetFlenUnderLimit(flen int) { - if ft.tp == mysql.TypeNewDecimal { + if ft.GetType() == mysql.TypeNewDecimal { ft.flen = mathutil.Min(flen, mysql.MaxDecimalWidth) } else { ft.flen = flen @@ -174,7 +178,7 @@ func (ft *FieldType) SetDecimal(decimal int) { // SetDecimalUnderLimit sets the decimal of the field to the value of the argument func (ft *FieldType) SetDecimalUnderLimit(decimal int) { - if ft.tp == mysql.TypeNewDecimal { + if ft.GetType() == mysql.TypeNewDecimal { ft.decimal = mathutil.Min(decimal, mysql.MaxDecimalScale) } else { ft.decimal = decimal @@ -183,7 +187,7 @@ func (ft *FieldType) SetDecimalUnderLimit(decimal int) { // UpdateFlenAndDecimalUnderLimit updates the length and decimal to the value of the argument func (ft *FieldType) UpdateFlenAndDecimalUnderLimit(old *FieldType, deltaDecimal int, deltaFlen int) { - if ft.tp != mysql.TypeNewDecimal { + if ft.GetType() != mysql.TypeNewDecimal { return } if old.decimal < 0 { @@ -229,6 +233,13 @@ func (ft *FieldType) IsArray() bool { return ft.array } +// ArrayType return the type of the array. +func (ft *FieldType) ArrayType() *FieldType { + clone := ft.Clone() + clone.SetArray(false) + return clone +} + // SetElemWithIsBinaryLit sets the element of the FieldType. func (ft *FieldType) SetElemWithIsBinaryLit(idx int, element string, isBinaryLit bool) { ft.elems[idx] = element @@ -274,7 +285,7 @@ func (ft *FieldType) Equal(other *FieldType) bool { // When tp is float or double with decimal unspecified, do not check whether flen is equal, // because flen for them is useless. // The decimal field can be ignored if the type is int or string. - tpEqual := (ft.tp == other.tp) || (ft.tp == mysql.TypeVarchar && other.tp == mysql.TypeVarString) || (ft.tp == mysql.TypeVarString && other.tp == mysql.TypeVarchar) + tpEqual := (ft.GetType() == other.GetType()) || (ft.GetType() == mysql.TypeVarchar && other.GetType() == mysql.TypeVarString) || (ft.GetType() == mysql.TypeVarString && other.GetType() == mysql.TypeVarchar) flenEqual := ft.flen == other.flen || (ft.EvalType() == ETReal && ft.decimal == UnspecifiedLength) ignoreDecimal := ft.EvalType() == ETInt || ft.EvalType() == ETString partialEqual := tpEqual && @@ -316,7 +327,7 @@ func (ft *FieldType) PartialEqual(other *FieldType, unsafe bool) bool { // EvalType gets the type in evaluation. func (ft *FieldType) EvalType() EvalType { - switch ft.tp { + switch ft.GetType() { case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeBit, mysql.TypeYear: return ETInt @@ -342,7 +353,7 @@ func (ft *FieldType) EvalType() EvalType { // Hybrid checks whether a type is a hybrid type, which can represent different types of value in specific context. func (ft *FieldType) Hybrid() bool { - return ft.tp == mysql.TypeEnum || ft.tp == mysql.TypeBit || ft.tp == mysql.TypeSet + return ft.GetType() == mysql.TypeEnum || ft.GetType() == mysql.TypeBit || ft.GetType() == mysql.TypeSet } // Init initializes the FieldType data. @@ -355,10 +366,10 @@ func (ft *FieldType) Init(tp byte) { // CompactStr only considers tp/CharsetBin/flen/Deimal. // This is used for showing column type in infoschema. func (ft *FieldType) CompactStr() string { - ts := TypeToStr(ft.tp, ft.charset) + ts := TypeToStr(ft.GetType(), ft.charset) suffix := "" - defaultFlen, defaultDecimal := mysql.GetDefaultFieldLengthAndDecimal(ft.tp) + defaultFlen, defaultDecimal := mysql.GetDefaultFieldLengthAndDecimal(ft.GetType()) isDecimalNotDefault := ft.decimal != defaultDecimal && ft.decimal != 0 && ft.decimal != UnspecifiedLength // displayFlen and displayDecimal are flen and decimal values with `-1` substituted with default value. @@ -370,7 +381,7 @@ func (ft *FieldType) CompactStr() string { displayDecimal = defaultDecimal } - switch ft.tp { + switch ft.GetType() { case mysql.TypeEnum, mysql.TypeSet: // Format is ENUM ('e1', 'e2') or SET ('e1', 'e2') es := make([]string, 0, len(ft.elems)) @@ -414,8 +425,8 @@ func (ft *FieldType) CompactStr() string { func (ft *FieldType) InfoSchemaStr() string { suffix := "" if mysql.HasUnsignedFlag(ft.flag) && - ft.tp != mysql.TypeBit && - ft.tp != mysql.TypeYear { + ft.GetType() != mysql.TypeBit && + ft.GetType() != mysql.TypeYear { suffix = " unsigned" } return ft.CompactStr() + suffix @@ -431,11 +442,11 @@ func (ft *FieldType) String() string { if mysql.HasZerofillFlag(ft.flag) { strs = append(strs, "ZEROFILL") } - if mysql.HasBinaryFlag(ft.flag) && ft.tp != mysql.TypeString { + if mysql.HasBinaryFlag(ft.flag) && ft.GetType() != mysql.TypeString { strs = append(strs, "BINARY") } - if IsTypeChar(ft.tp) || IsTypeBlob(ft.tp) { + if IsTypeChar(ft.GetType()) || IsTypeBlob(ft.GetType()) { if ft.charset != "" && ft.charset != charset.CharsetBin { strs = append(strs, fmt.Sprintf("CHARACTER SET %s", ft.charset)) } @@ -449,12 +460,12 @@ func (ft *FieldType) String() string { // Restore implements Node interface. func (ft *FieldType) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord(TypeToStr(ft.tp, ft.charset)) + ctx.WriteKeyWord(TypeToStr(ft.GetType(), ft.charset)) precision := UnspecifiedLength scale := UnspecifiedLength - switch ft.tp { + switch ft.GetType() { case mysql.TypeEnum, mysql.TypeSet: ctx.WritePlain("(") for i, e := range ft.elems { @@ -491,7 +502,7 @@ func (ft *FieldType) Restore(ctx *format.RestoreCtx) error { ctx.WriteKeyWord(" BINARY") } - if IsTypeChar(ft.tp) || IsTypeBlob(ft.tp) { + if IsTypeChar(ft.GetType()) || IsTypeBlob(ft.GetType()) { if ft.charset != "" && ft.charset != charset.CharsetBin { ctx.WriteKeyWord(" CHARACTER SET " + ft.charset) } @@ -519,7 +530,7 @@ func (ft *FieldType) RestoreAsCastType(ctx *format.RestoreCtx, explicitCharset b ctx.WritePlainf("(%d)", ft.flen) } if !explicitCharset { - return + break } if !skipWriteBinary && ft.flag&mysql.BinaryFlag != 0 { ctx.WriteKeyWord(" BINARY") @@ -581,7 +592,7 @@ const VarStorageLen = -1 // StorageLength is the length of stored value for the type. func (ft *FieldType) StorageLength() int { - switch ft.tp { + switch ft.GetType() { case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeFloat, mysql.TypeYear, mysql.TypeDuration, mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp, mysql.TypeEnum, mysql.TypeSet, @@ -599,7 +610,7 @@ func (ft *FieldType) StorageLength() int { // HasCharset indicates if a COLUMN has an associated charset. Returning false here prevents some information // statements(like `SHOW CREATE TABLE`) from attaching a CHARACTER SET clause to the column. func HasCharset(ft *FieldType) bool { - switch ft.tp { + switch ft.GetType() { case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: return !mysql.HasBinaryFlag(ft.flag) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index d0ca6e6f8e4cf..ddb905dc5c06b 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -55,7 +55,7 @@ func evalAstExpr(sctx sessionctx.Context, expr ast.ExprNode) (types.Datum, error if val, ok := expr.(*driver.ValueExpr); ok { return val.Datum, nil } - newExpr, err := rewriteAstExpr(sctx, expr, nil, nil) + newExpr, err := rewriteAstExpr(sctx, expr, nil, nil, false) if err != nil { return types.Datum{}, err } @@ -63,13 +63,14 @@ func evalAstExpr(sctx sessionctx.Context, expr ast.ExprNode) (types.Datum, error } // rewriteAstExpr rewrites ast expression directly. -func rewriteAstExpr(sctx sessionctx.Context, expr ast.ExprNode, schema *expression.Schema, names types.NameSlice) (expression.Expression, error) { +func rewriteAstExpr(sctx sessionctx.Context, expr ast.ExprNode, schema *expression.Schema, names types.NameSlice, allowCastArray bool) (expression.Expression, error) { var is infoschema.InfoSchema // in tests, it may be null if s, ok := sctx.GetInfoSchema().(infoschema.InfoSchema); ok { is = s } b, savedBlockNames := NewPlanBuilder().Init(sctx, is, &hint.BlockHintProcessor{}) + b.allowBuildCastArray = allowCastArray fakePlan := LogicalTableDual{}.Init(sctx, 0) if schema != nil { fakePlan.schema = schema @@ -1183,6 +1184,10 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok er.disableFoldCounter-- } case *ast.FuncCastExpr: + if v.Tp.IsArray() && !er.b.allowBuildCastArray { + er.err = expression.ErrNotSupportedYet.GenWithStackByArgs("Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions") + return retNode, false + } arg := er.ctxStack[len(er.ctxStack)-1] er.err = expression.CheckArgsNotMultiColumnRow(arg) if er.err != nil { @@ -1195,7 +1200,11 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok return retNode, false } - castFunction := expression.BuildCastFunction(er.sctx, arg, v.Tp) + castFunction, err := expression.BuildCastFunctionWithCheck(er.sctx, arg, v.Tp) + if err != nil { + er.err = err + return retNode, false + } if v.Tp.EvalType() == types.ETString { castFunction.SetCoercibility(expression.CoercibilityImplicit) if v.Tp.GetCharset() == charset.CharsetASCII { diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 9201f953bdcdc..df7a0d893e4ed 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -577,6 +577,9 @@ type PlanBuilder struct { // disableSubQueryPreprocessing indicates whether to pre-process uncorrelated sub-queries in rewriting stage. disableSubQueryPreprocessing bool + + // allowBuildCastArray indicates whether allow cast(... as ... array). + allowBuildCastArray bool } type handleColHelper struct { @@ -697,6 +700,14 @@ func (p PlanBuilderOptNoExecution) Apply(builder *PlanBuilder) { builder.disableSubQueryPreprocessing = true } +// PlanBuilderOptAllowCastArray means the plan builder should allow build cast(... as ... array). +type PlanBuilderOptAllowCastArray struct{} + +// Apply implements the interface PlanBuilderOpt. +func (p PlanBuilderOptAllowCastArray) Apply(builder *PlanBuilder) { + builder.allowBuildCastArray = true +} + // NewPlanBuilder creates a new PlanBuilder. func NewPlanBuilder(opts ...PlanBuilderOpt) *PlanBuilder { builder := &PlanBuilder{ diff --git a/types/field_type_builder.go b/types/field_type_builder.go index 7c9f3bdc3177d..81554c4585442 100644 --- a/types/field_type_builder.go +++ b/types/field_type_builder.go @@ -114,6 +114,12 @@ func (b *FieldTypeBuilder) SetElems(elems []string) *FieldTypeBuilder { return b } +// SetArray sets array of the ft +func (b *FieldTypeBuilder) SetArray(x bool) *FieldTypeBuilder { + b.ft.SetArray(x) + return b +} + // Build returns the ft func (b *FieldTypeBuilder) Build() FieldType { return b.ft