diff --git a/expression/simple_rewriter.go b/expression/simple_rewriter.go index e1b92dc4023e4..7146bdff628ab 100644 --- a/expression/simple_rewriter.go +++ b/expression/simple_rewriter.go @@ -14,6 +14,8 @@ package expression import ( + "context" + "github.com/pingcap/errors" "github.com/pingcap/parser" "github.com/pingcap/parser/ast" @@ -41,10 +43,20 @@ type simpleRewriter struct { // The expression string must only reference the column in table Info. func ParseSimpleExprWithTableInfo(ctx sessionctx.Context, exprStr string, tableInfo *model.TableInfo) (Expression, error) { exprStr = "select " + exprStr - stmts, warns, err := parser.New().Parse(exprStr, "", "") + var stmts []ast.StmtNode + var err error + var warns []error + if p, ok := ctx.(interface { + ParseSQL(context.Context, string, string, string) ([]ast.StmtNode, []error, error) + }); ok { + stmts, warns, err = p.ParseSQL(context.Background(), exprStr, "", "") + } else { + stmts, warns, err = parser.New().Parse(exprStr, "", "") + } for _, warn := range warns { ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) } + if err != nil { return nil, util.SyntaxError(err) } @@ -80,12 +92,13 @@ func RewriteSimpleExprWithTableInfo(ctx sessionctx.Context, tbl *model.TableInfo func ParseSimpleExprsWithSchema(ctx sessionctx.Context, exprStr string, schema *Schema) ([]Expression, error) { exprStr = "select " + exprStr stmts, warns, err := parser.New().Parse(exprStr, "", "") - for _, warn := range warns { - ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) - } if err != nil { return nil, util.SyntaxWarn(err) } + for _, warn := range warns { + ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) + } + fields := stmts[0].(*ast.SelectStmt).Fields.Fields exprs := make([]Expression, 0, len(fields)) for _, field := range fields { @@ -102,13 +115,23 @@ func ParseSimpleExprsWithSchema(ctx sessionctx.Context, exprStr string, schema * // The expression string must only reference the column in the given NameSlice. func ParseSimpleExprsWithNames(ctx sessionctx.Context, exprStr string, schema *Schema, names types.NameSlice) ([]Expression, error) { exprStr = "select " + exprStr - stmts, warns, err := parser.New().Parse(exprStr, "", "") - for _, warn := range warns { - ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) + var stmts []ast.StmtNode + var err error + var warns []error + if p, ok := ctx.(interface { + ParseSQL(context.Context, string, string, string) ([]ast.StmtNode, []error, error) + }); ok { + stmts, warns, err = p.ParseSQL(context.Background(), exprStr, "", "") + } else { + stmts, warns, err = parser.New().Parse(exprStr, "", "") } if err != nil { return nil, util.SyntaxWarn(err) } + for _, warn := range warns { + ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) + } + fields := stmts[0].(*ast.SelectStmt).Fields.Fields exprs := make([]Expression, 0, len(fields)) for _, field := range fields { diff --git a/planner/core/partition_pruning_test.go b/planner/core/partition_pruning_test.go index 0e4568ffea43c..70ef8a5413f17 100644 --- a/planner/core/partition_pruning_test.go +++ b/planner/core/partition_pruning_test.go @@ -118,7 +118,7 @@ func (s *testPartitionPruningSuite) TestPruneUseBinarySearch(c *C) { } for i, ca := range cases { - start, end := pruneUseBinarySearch(lessThan, ca.input) + start, end := pruneUseBinarySearch(lessThan, ca.input, false) c.Assert(ca.result.start, Equals, start, Commentf("fail = %d", i)) c.Assert(ca.result.end, Equals, end, Commentf("fail = %d", i)) } diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 3c01372b193ec..4703a17d2ab07 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -1193,7 +1193,7 @@ func getHashPartitionColumnName(ctx sessionctx.Context, tbl *model.TableInfo) *a return nil } // PartitionExpr don't need columns and names for hash partition. - partitionExpr, err := table.(partitionTable).PartitionExpr(ctx, nil, nil) + partitionExpr, err := table.(partitionTable).PartitionExpr() if err != nil { return nil } diff --git a/planner/core/rule_partition_processor.go b/planner/core/rule_partition_processor.go index 5d42630c6f326..b0933425af6e0 100644 --- a/planner/core/rule_partition_processor.go +++ b/planner/core/rule_partition_processor.go @@ -15,10 +15,8 @@ package core import ( "context" "sort" - "strconv" "strings" - "github.com/pingcap/errors" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" @@ -98,7 +96,7 @@ func (s *partitionProcessor) rewriteDataSource(lp LogicalPlan) (LogicalPlan, err // partitionTable is for those tables which implement partition. type partitionTable interface { - PartitionExpr(ctx sessionctx.Context, columns []*expression.Column, names types.NameSlice) (*tables.PartitionExpr, error) + PartitionExpr() (*tables.PartitionExpr, error) } func generateHashPartitionExpr(t table.Table, ctx sessionctx.Context, columns []*expression.Column, names types.NameSlice) (expression.Expression, error) { @@ -185,12 +183,25 @@ func (lt *lessThanDataInt) length() int { return len(lt.data) } -func (lt *lessThanDataInt) compare(ith int, v int64) int { +func compareUnsigned(v1, v2 int64) int { + switch { + case uint64(v1) > uint64(v2): + return 1 + case uint64(v1) == uint64(v2): + return 0 + } + return -1 +} + +func (lt *lessThanDataInt) compare(ith int, v int64, unsigned bool) int { if ith == len(lt.data)-1 { if lt.maxvalue { return 1 } } + if unsigned { + return compareUnsigned(lt.data[ith], v) + } switch { case lt.data[ith] > v: return 1 @@ -328,37 +339,24 @@ func (s *partitionProcessor) pruneRangePartition(ds *DataSource, pi *model.Parti result := fullRange(len(pi.Definitions)) // Extract the partition column, if the column is not null, it's possible to prune. if col != nil { - // TODO: Store LessThanData in the partitionExpr, avoid allocating here. - lessThan, err := makeLessThanData(pi) + partExpr, err := ds.table.(partitionTable).PartitionExpr() if err != nil { return nil, err } - pruner := rangePruner{lessThan, col, fn} + pruner := rangePruner{ + lessThan: lessThanDataInt{ + data: partExpr.ForRangePruning.LessThan, + maxvalue: partExpr.ForRangePruning.MaxValue, + }, + col: col, + partFn: fn, + } result = partitionRangeForCNFExpr(ds.ctx, ds.allConds, &pruner, result) } return s.makeUnionAllChildren(ds, pi, result) } -// makeLessThanData extracts the less than parts from 'partition p0 less than xx ... partitoin p1 less than ...' -func makeLessThanData(pi *model.PartitionInfo) (lessThanDataInt, error) { - var maxValue bool - lessThan := make([]int64, len(pi.Definitions)) - for i := 0; i < len(pi.Definitions); i++ { - if strings.EqualFold(pi.Definitions[i].LessThan[0], "MAXVALUE") { - // Use a bool flag instead of math.MaxInt64 to avoid the corner cases. - maxValue = true - } else { - var err error - lessThan[i], err = strconv.ParseInt(pi.Definitions[i].LessThan[0], 10, 64) - if err != nil { - return lessThanDataInt{}, errors.WithStack(err) - } - } - } - return lessThanDataInt{lessThan, maxValue}, nil -} - // makePartitionByFnCol extracts the column and function information in 'partition by ... fn(col)'. func makePartitionByFnCol(sctx sessionctx.Context, columns []*expression.Column, names types.NameSlice, partitionExpr string) (*expression.Column, *expression.ScalarFunction, error) { schema := expression.NewSchema(columns...) @@ -431,7 +429,9 @@ func (p *rangePruner) partitionRangeForExpr(sctx sessionctx.Context, expr expres if !ok { return 0, 0, false } - start, end := pruneUseBinarySearch(p.lessThan, dataForPrune) + + unsigned := mysql.HasUnsignedFlag(p.col.RetType.Flag) + start, end := pruneUseBinarySearch(p.lessThan, dataForPrune, unsigned) return start, end, true } @@ -556,7 +556,7 @@ func relaxOP(op string) string { return op } -func pruneUseBinarySearch(lessThan lessThanDataInt, data dataForPrune) (start int, end int) { +func pruneUseBinarySearch(lessThan lessThanDataInt, data dataForPrune, unsigned bool) (start int, end int) { length := lessThan.length() switch data.op { case ast.EQ: @@ -564,21 +564,21 @@ func pruneUseBinarySearch(lessThan lessThanDataInt, data dataForPrune) (start in // col = 14, lessThan = [4 7 11 14 17] => [4, 5) // col = 10, lessThan = [4 7 11 14 17] => [2, 3) // col = 3, lessThan = [4 7 11 14 17] => [0, 1) - pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c) > 0 }) + pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c, unsigned) > 0 }) start, end = pos, pos+1 case ast.LT: // col < 66, lessThan = [4 7 11 14 17] => [0, 5) // col < 14, lessThan = [4 7 11 14 17] => [0, 4) // col < 10, lessThan = [4 7 11 14 17] => [0, 3) // col < 3, lessThan = [4 7 11 14 17] => [0, 1) - pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c) >= 0 }) + pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c, unsigned) >= 0 }) start, end = 0, pos+1 case ast.GE: // col >= 66, lessThan = [4 7 11 14 17] => [5, 5) // col >= 14, lessThan = [4 7 11 14 17] => [4, 5) // col >= 10, lessThan = [4 7 11 14 17] => [2, 5) // col >= 3, lessThan = [4 7 11 14 17] => [0, 5) - pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c) > 0 }) + pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c, unsigned) > 0 }) start, end = pos, length case ast.GT: // col > 66, lessThan = [4 7 11 14 17] => [5, 5) @@ -586,14 +586,14 @@ func pruneUseBinarySearch(lessThan lessThanDataInt, data dataForPrune) (start in // col > 10, lessThan = [4 7 11 14 17] => [3, 5) // col > 3, lessThan = [4 7 11 14 17] => [1, 5) // col > 2, lessThan = [4 7 11 14 17] => [0, 5) - pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c+1) > 0 }) + pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c+1, unsigned) > 0 }) start, end = pos, length case ast.LE: // col <= 66, lessThan = [4 7 11 14 17] => [0, 6) // col <= 14, lessThan = [4 7 11 14 17] => [0, 5) // col <= 10, lessThan = [4 7 11 14 17] => [0, 3) // col <= 3, lessThan = [4 7 11 14 17] => [0, 1) - pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c) > 0 }) + pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c, unsigned) > 0 }) start, end = 0, pos+1 case ast.IsNull: start, end = 0, 1 diff --git a/table/tables/partition.go b/table/tables/partition.go index 6be609365b19a..66e2cc688a2a1 100644 --- a/table/tables/partition.go +++ b/table/tables/partition.go @@ -15,8 +15,10 @@ package tables import ( "bytes" + stderr "errors" "fmt" "sort" + "strconv" "strings" "github.com/pingcap/errors" @@ -113,6 +115,46 @@ type PartitionExpr struct { OrigExpr ast.ExprNode // Expr is the hash partition expression. Expr expression.Expression + // Used in the range pruning process. + *ForRangePruning +} + +// ForRangePruning is used for range partition pruning. +type ForRangePruning struct { + LessThan []int64 + MaxValue bool + Unsigned bool +} + +// dataForRangePruning extracts the less than parts from 'partition p0 less than xx ... partitoin p1 less than ...' +func dataForRangePruning(pi *model.PartitionInfo) (*ForRangePruning, error) { + var maxValue bool + var unsigned bool + lessThan := make([]int64, len(pi.Definitions)) + for i := 0; i < len(pi.Definitions); i++ { + if strings.EqualFold(pi.Definitions[i].LessThan[0], "MAXVALUE") { + // Use a bool flag instead of math.MaxInt64 to avoid the corner cases. + maxValue = true + } else { + var err error + lessThan[i], err = strconv.ParseInt(pi.Definitions[i].LessThan[0], 10, 64) + var numErr *strconv.NumError + if stderr.As(err, &numErr) && numErr.Err == strconv.ErrRange { + var tmp uint64 + tmp, err = strconv.ParseUint(pi.Definitions[i].LessThan[0], 10, 64) + lessThan[i] = int64(tmp) + unsigned = true + } + if err != nil { + return nil, errors.WithStack(err) + } + } + } + return &ForRangePruning{ + LessThan: lessThan, + MaxValue: maxValue, + Unsigned: unsigned, + }, nil } // rangePartitionString returns the partition string for a range typed partition. @@ -134,13 +176,11 @@ func rangePartitionString(pi *model.PartitionInfo) string { func generateRangePartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo, columns []*expression.Column, names types.NameSlice) (*PartitionExpr, error) { // The caller should assure partition info is not nil. - partitionPruneExprs := make([]expression.Expression, 0, len(pi.Definitions)) locateExprs := make([]expression.Expression, 0, len(pi.Definitions)) var buf bytes.Buffer schema := expression.NewSchema(columns...) partStr := rangePartitionString(pi) for i := 0; i < len(pi.Definitions); i++ { - if strings.EqualFold(pi.Definitions[i].LessThan[0], "MAXVALUE") { // Expr less than maxvalue is always true. fmt.Fprintf(&buf, "true") @@ -155,28 +195,19 @@ func generateRangePartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo, return nil, errors.Trace(err) } locateExprs = append(locateExprs, exprs[0]) - - if i > 0 { - fmt.Fprintf(&buf, " and ((%s) >= (%s))", partStr, pi.Definitions[i-1].LessThan[0]) - } else { - // NULL will locate in the first partition, so its expression is (expr < value or expr is null). - fmt.Fprintf(&buf, " or ((%s) is null)", partStr) - } - - exprs, err = expression.ParseSimpleExprsWithNames(ctx, buf.String(), schema, names) + buf.Reset() + } + ret := &PartitionExpr{ + UpperBounds: locateExprs, + } + if len(pi.Columns) == 0 { + tmp, err := dataForRangePruning(pi) if err != nil { - // If it got an error here, ddl may hang forever, so this error log is important. - logutil.BgLogger().Error("wrong table partition expression", zap.String("expression", buf.String()), zap.Error(err)) return nil, errors.Trace(err) } - // Get a hash code in advance to prevent data race afterwards. - exprs[0].HashCode(ctx.GetSessionVars().StmtCtx) - partitionPruneExprs = append(partitionPruneExprs, exprs[0]) - buf.Reset() + ret.ForRangePruning = tmp } - return &PartitionExpr{ - UpperBounds: locateExprs, - }, nil + return ret, nil } func generateHashPartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo, @@ -201,13 +232,13 @@ func generateHashPartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo, } // PartitionExpr returns the partition expression. -func (t *partitionedTable) PartitionExpr(ctx sessionctx.Context, columns []*expression.Column, names types.NameSlice) (*PartitionExpr, error) { +func (t *partitionedTable) PartitionExpr() (*PartitionExpr, error) { pi := t.meta.GetPartitionInfo() switch pi.Type { case model.PartitionTypeHash: return t.partitionExpr, nil case model.PartitionTypeRange: - return generateRangePartitionExpr(ctx, pi, columns, names) + return t.partitionExpr, nil } panic("cannot reach here") } diff --git a/table/tables/partition_test.go b/table/tables/partition_test.go index 8e29548dddf0a..13e1e3ae83b87 100644 --- a/table/tables/partition_test.go +++ b/table/tables/partition_test.go @@ -19,14 +19,11 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/parser/model" "github.com/pingcap/tidb/ddl" - "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/binloginfo" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/testkit" ) @@ -266,16 +263,14 @@ func (ts *testSuite) TestGeneratePartitionExpr(c *C) { tbl, err := ts.dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t1")) c.Assert(err, IsNil) type partitionExpr interface { - PartitionExpr(ctx sessionctx.Context, columns []*expression.Column, names types.NameSlice) (*tables.PartitionExpr, error) + PartitionExpr() (*tables.PartitionExpr, error) } - ctx := mock.NewContext() - columns, names := expression.ColumnInfos2ColumnsAndNames(ctx, model.NewCIStr("test"), tbl.Meta().Name, tbl.Meta().Columns) - pe, err := tbl.(partitionExpr).PartitionExpr(ctx, columns, names) + pe, err := tbl.(partitionExpr).PartitionExpr() c.Assert(err, IsNil) upperBounds := []string{ - "lt(test.t1.id, 4)", - "lt(test.t1.id, 7)", + "lt(t1.id, 4)", + "lt(t1.id, 7)", "1", } for i, expr := range pe.UpperBounds { @@ -368,3 +363,37 @@ func (ts *testSuite) TestCreatePartitionTableNotSupport(c *C) { _, err = tk.Exec(`create table t7 (a int) partition by range (-(select * from t)) (partition p1 values less than (1));`) c.Assert(ddl.ErrPartitionFunctionIsNotAllowed.Equal(err), IsTrue) } + +func (ts *testSuite) TestIntUint(c *C) { + tk := testkit.NewTestKitWithInit(c, ts.store) + tk.MustExec("use test") + tk.MustExec(`create table t_uint (id bigint unsigned) partition by range (id) ( +partition p0 values less than (4294967293), +partition p1 values less than (4294967296), +partition p2 values less than (484467440737095), +partition p3 values less than (18446744073709551614))`) + tk.MustExec("insert into t_uint values (1)") + tk.MustExec("insert into t_uint values (4294967294)") + tk.MustExec("insert into t_uint values (4294967295)") + tk.MustExec("insert into t_uint values (18446744073709551613)") + tk.MustQuery("select * from t_uint where id > 484467440737095").Check(testkit.Rows("18446744073709551613")) + tk.MustQuery("select * from t_uint where id = 4294967295").Check(testkit.Rows("4294967295")) + tk.MustQuery("select * from t_uint where id < 4294967294").Check(testkit.Rows("1")) + tk.MustQuery("select * from t_uint where id >= 4294967293 order by id").Check(testkit.Rows("4294967294", "4294967295", "18446744073709551613")) + + tk.MustExec(`create table t_int (id bigint signed) partition by range (id) ( +partition p0 values less than (-4294967293), +partition p1 values less than (-12345), +partition p2 values less than (0), +partition p3 values less than (484467440737095), +partition p4 values less than (9223372036854775806))`) + tk.MustExec("insert into t_int values (-9223372036854775803)") + tk.MustExec("insert into t_int values (-429496729312)") + tk.MustExec("insert into t_int values (-1)") + tk.MustExec("insert into t_int values (4294967295)") + tk.MustExec("insert into t_int values (9223372036854775805)") + tk.MustQuery("select * from t_int where id > 484467440737095").Check(testkit.Rows("9223372036854775805")) + tk.MustQuery("select * from t_int where id = 4294967295").Check(testkit.Rows("4294967295")) + tk.MustQuery("select * from t_int where id = -4294967294").Check(testkit.Rows()) + tk.MustQuery("select * from t_int where id < -12345 order by id desc").Check(testkit.Rows("-429496729312", "-9223372036854775803")) +}