From 1d8a4c2fa73fadf9cedddff1d52590678d5f099b Mon Sep 17 00:00:00 2001 From: tiancaiamao Date: Tue, 17 Mar 2020 22:05:43 +0800 Subject: [PATCH 1/2] *: handle signed/unsigned in the partition pruning --- expression/simple_rewriter.go | 33 +++++++++-- planner/core/point_get_plan.go | 2 +- planner/core/rule_partition_processor.go | 61 +++++++++---------- table/tables/partition.go | 75 +++++++++++++++++------- table/tables/partition_test.go | 34 +++++++++++ 5 files changed, 144 insertions(+), 61 deletions(-) diff --git a/expression/simple_rewriter.go b/expression/simple_rewriter.go index 3ac17173ee30a..a8304e844fdcf 100644 --- a/expression/simple_rewriter.go +++ b/expression/simple_rewriter.go @@ -14,6 +14,8 @@ package expression import ( + "context" + "github.com/pingcap/errors" "github.com/pingcap/parser" "github.com/pingcap/parser/ast" @@ -41,10 +43,20 @@ type simpleRewriter struct { // The expression string must only reference the column in table Info. func ParseSimpleExprWithTableInfo(ctx sessionctx.Context, exprStr string, tableInfo *model.TableInfo) (Expression, error) { exprStr = "select " + exprStr - stmts, warns, err := parser.New().Parse(exprStr, "", "") - for _, warn := range warns { - ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) + var stmts []ast.StmtNode + var err error + if p, ok := ctx.(interface { + ParseSQL(context.Context, string, string, string) ([]ast.StmtNode, []error, error) + }); ok { + stmts, _, err = p.ParseSQL(context.Background(), exprStr, "", "") + } else { + var warns []error + stmts, warns, err = parser.New().Parse(exprStr, "", "") + for _, warn := range warns { + ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) + } } + if err != nil { return nil, util.SyntaxError(err) } @@ -102,9 +114,18 @@ func ParseSimpleExprsWithSchema(ctx sessionctx.Context, exprStr string, schema * // The expression string must only reference the column in the given NameSlice. 
func ParseSimpleExprsWithNames(ctx sessionctx.Context, exprStr string, schema *Schema, names types.NameSlice) ([]Expression, error) { exprStr = "select " + exprStr - stmts, warns, err := parser.New().Parse(exprStr, "", "") - for _, warn := range warns { - ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) + var stmts []ast.StmtNode + var err error + if p, ok := ctx.(interface { + ParseSQL(context.Context, string, string, string) ([]ast.StmtNode, []error, error) + }); ok { + stmts, _, err = p.ParseSQL(context.Background(), exprStr, "", "") + } else { + var warns []error + stmts, warns, err = parser.New().Parse(exprStr, "", "") + for _, warn := range warns { + ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) + } } if err != nil { return nil, util.SyntaxWarn(err) diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index f0fd2a802318e..f384a9f8d9ce7 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -1159,7 +1159,7 @@ func getHashPartitionColumnName(ctx sessionctx.Context, tbl *model.TableInfo) *a return nil } // PartitionExpr don't need columns and names for hash partition. 
- partitionExpr, err := table.(partitionTable).PartitionExpr(ctx, nil, nil) + partitionExpr, err := table.(partitionTable).PartitionExpr() if err != nil { return nil } diff --git a/planner/core/rule_partition_processor.go b/planner/core/rule_partition_processor.go index d14cd0c7f70e3..521f4a242eed0 100644 --- a/planner/core/rule_partition_processor.go +++ b/planner/core/rule_partition_processor.go @@ -15,12 +15,10 @@ package core import ( "context" "sort" - "strconv" - "strings" - "github.com/pingcap/errors" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" + "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/sessionctx" @@ -97,7 +95,7 @@ func (s *partitionProcessor) rewriteDataSource(lp LogicalPlan) (LogicalPlan, err // partitionTable is for those tables which implement partition. type partitionTable interface { - PartitionExpr(ctx sessionctx.Context, columns []*expression.Column, names types.NameSlice) (*tables.PartitionExpr, error) + PartitionExpr() (*tables.PartitionExpr, error) } func generateHashPartitionExpr(t table.Table, ctx sessionctx.Context, columns []*expression.Column, names types.NameSlice) (expression.Expression, error) { @@ -184,12 +182,25 @@ func (lt *lessThanData) length() int { return len(lt.data) } -func (lt *lessThanData) compare(ith int, v int64) int { +func compareUnsigned(v1, v2 int64) int { + switch { + case uint64(v1) > uint64(v2): + return 1 + case uint64(v1) == uint64(v2): + return 0 + } + return -1 +} + +func (lt *lessThanData) compare(ith int, v int64, unsigned bool) int { if ith == len(lt.data)-1 { if lt.maxvalue { return 1 } } + if unsigned { + return compareUnsigned(lt.data[ith], v) + } switch { case lt.data[ith] > v: return 1 @@ -329,35 +340,20 @@ func (s *partitionProcessor) pruneRangePartition(ds *DataSource, pi *model.Parti result := fullRange(len(pi.Definitions)) if col != nil { - lessThan, err := makeLessThanData(pi) + 
partExpr, err := ds.table.(partitionTable).PartitionExpr() if err != nil { return nil, err } + lessThan := lessThanData{ + data: partExpr.ForRangePruning.LessThan, + maxvalue: partExpr.ForRangePruning.MaxValue, + } result = partitionRangeForCNFExpr(ds.ctx, ds.allConds, lessThan, col, fn, result) } return s.makeUnionAllChildren(ds, pi, result) } -// makeLessThanData extracts the less than parts from 'partition p0 less than xx ... partitoin p1 less than ...' -func makeLessThanData(pi *model.PartitionInfo) (lessThanData, error) { - var maxValue bool - lessThan := make([]int64, len(pi.Definitions)) - for i := 0; i < len(pi.Definitions); i++ { - if strings.EqualFold(pi.Definitions[i].LessThan[0], "MAXVALUE") { - // Use a bool flag instead of math.MaxInt64 to avoid the corner cases. - maxValue = true - } else { - var err error - lessThan[i], err = strconv.ParseInt(pi.Definitions[i].LessThan[0], 10, 64) - if err != nil { - return lessThanData{}, errors.WithStack(err) - } - } - } - return lessThanData{lessThan, maxValue}, nil -} - // makePartitionByFnCol extracts the column and function information in 'partition by ... fn(col)'. func makePartitionByFnCol(sctx sessionctx.Context, columns []*expression.Column, names types.NameSlice, partitionExpr string) (*expression.Column, *expression.ScalarFunction, error) { schema := expression.NewSchema(columns...) @@ -408,7 +404,8 @@ func partitionRangeForExpr(sctx sessionctx.Context, expr expression.Expression, // Can't prune, return the whole range. 
return result } - start, end := pruneUseBinarySearch(lessThan, dataForPrune) + unsigned := mysql.HasUnsignedFlag(col.RetType.Flag) + start, end := pruneUseBinarySearch(lessThan, dataForPrune, unsigned) return result.intersectionRange(start, end) } @@ -529,7 +526,7 @@ func relaxOP(op string) string { return op } -func pruneUseBinarySearch(lessThan lessThanData, data dataForPrune) (start int, end int) { +func pruneUseBinarySearch(lessThan lessThanData, data dataForPrune, unsigned bool) (start int, end int) { length := lessThan.length() switch data.op { case ast.EQ: @@ -537,21 +534,21 @@ func pruneUseBinarySearch(lessThan lessThanData, data dataForPrune) (start int, // col = 14, lessThan = [4 7 11 14 17] => [4, 5) // col = 10, lessThan = [4 7 11 14 17] => [2, 3) // col = 3, lessThan = [4 7 11 14 17] => [0, 1) - pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c) > 0 }) + pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c, unsigned) > 0 }) start, end = pos, pos+1 case ast.LT: // col < 66, lessThan = [4 7 11 14 17] => [0, 5) // col < 14, lessThan = [4 7 11 14 17] => [0, 4) // col < 10, lessThan = [4 7 11 14 17] => [0, 3) // col < 3, lessThan = [4 7 11 14 17] => [0, 1) - pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c) >= 0 }) + pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c, unsigned) >= 0 }) start, end = 0, pos+1 case ast.GE: // col >= 66, lessThan = [4 7 11 14 17] => [5, 5) // col >= 14, lessThan = [4 7 11 14 17] => [4, 5) // col >= 10, lessThan = [4 7 11 14 17] => [2, 5) // col >= 3, lessThan = [4 7 11 14 17] => [0, 5) - pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c) > 0 }) + pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c, unsigned) > 0 }) start, end = pos, length case ast.GT: // col > 66, lessThan = [4 7 11 14 17] => [5, 5) @@ -559,14 +556,14 @@ func pruneUseBinarySearch(lessThan 
lessThanData, data dataForPrune) (start int, // col > 10, lessThan = [4 7 11 14 17] => [3, 5) // col > 3, lessThan = [4 7 11 14 17] => [1, 5) // col > 2, lessThan = [4 7 11 14 17] => [0, 5) - pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c+1) > 0 }) + pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c+1, unsigned) > 0 }) start, end = pos, length case ast.LE: // col <= 66, lessThan = [4 7 11 14 17] => [0, 6) // col <= 14, lessThan = [4 7 11 14 17] => [0, 5) // col <= 10, lessThan = [4 7 11 14 17] => [0, 3) // col <= 3, lessThan = [4 7 11 14 17] => [0, 1) - pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c) > 0 }) + pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c, unsigned) > 0 }) start, end = 0, pos+1 case ast.IsNull: start, end = 0, 1 diff --git a/table/tables/partition.go b/table/tables/partition.go index 6be609365b19a..66e2cc688a2a1 100644 --- a/table/tables/partition.go +++ b/table/tables/partition.go @@ -15,8 +15,10 @@ package tables import ( "bytes" + stderr "errors" "fmt" "sort" + "strconv" "strings" "github.com/pingcap/errors" @@ -113,6 +115,46 @@ type PartitionExpr struct { OrigExpr ast.ExprNode // Expr is the hash partition expression. Expr expression.Expression + // Used in the range pruning process. + *ForRangePruning +} + +// ForRangePruning is used for range partition pruning. +type ForRangePruning struct { + LessThan []int64 + MaxValue bool + Unsigned bool +} + +// dataForRangePruning extracts the less than parts from 'partition p0 less than xx ... partition p1 less than ...' +func dataForRangePruning(pi *model.PartitionInfo) (*ForRangePruning, error) { + var maxValue bool + var unsigned bool + lessThan := make([]int64, len(pi.Definitions)) + for i := 0; i < len(pi.Definitions); i++ { + if strings.EqualFold(pi.Definitions[i].LessThan[0], "MAXVALUE") { + // Use a bool flag instead of math.MaxInt64 to avoid the corner cases. 
+ maxValue = true + } else { + var err error + lessThan[i], err = strconv.ParseInt(pi.Definitions[i].LessThan[0], 10, 64) + var numErr *strconv.NumError + if stderr.As(err, &numErr) && numErr.Err == strconv.ErrRange { + var tmp uint64 + tmp, err = strconv.ParseUint(pi.Definitions[i].LessThan[0], 10, 64) + lessThan[i] = int64(tmp) + unsigned = true + } + if err != nil { + return nil, errors.WithStack(err) + } + } + } + return &ForRangePruning{ + LessThan: lessThan, + MaxValue: maxValue, + Unsigned: unsigned, + }, nil } // rangePartitionString returns the partition string for a range typed partition. @@ -134,13 +176,11 @@ func rangePartitionString(pi *model.PartitionInfo) string { func generateRangePartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo, columns []*expression.Column, names types.NameSlice) (*PartitionExpr, error) { // The caller should assure partition info is not nil. - partitionPruneExprs := make([]expression.Expression, 0, len(pi.Definitions)) locateExprs := make([]expression.Expression, 0, len(pi.Definitions)) var buf bytes.Buffer schema := expression.NewSchema(columns...) partStr := rangePartitionString(pi) for i := 0; i < len(pi.Definitions); i++ { - if strings.EqualFold(pi.Definitions[i].LessThan[0], "MAXVALUE") { // Expr less than maxvalue is always true. fmt.Fprintf(&buf, "true") @@ -155,28 +195,19 @@ func generateRangePartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo, return nil, errors.Trace(err) } locateExprs = append(locateExprs, exprs[0]) - - if i > 0 { - fmt.Fprintf(&buf, " and ((%s) >= (%s))", partStr, pi.Definitions[i-1].LessThan[0]) - } else { - // NULL will locate in the first partition, so its expression is (expr < value or expr is null). 
- fmt.Fprintf(&buf, " or ((%s) is null)", partStr) - } - - exprs, err = expression.ParseSimpleExprsWithNames(ctx, buf.String(), schema, names) + buf.Reset() + } + ret := &PartitionExpr{ + UpperBounds: locateExprs, + } + if len(pi.Columns) == 0 { + tmp, err := dataForRangePruning(pi) if err != nil { - // If it got an error here, ddl may hang forever, so this error log is important. - logutil.BgLogger().Error("wrong table partition expression", zap.String("expression", buf.String()), zap.Error(err)) return nil, errors.Trace(err) } - // Get a hash code in advance to prevent data race afterwards. - exprs[0].HashCode(ctx.GetSessionVars().StmtCtx) - partitionPruneExprs = append(partitionPruneExprs, exprs[0]) - buf.Reset() + ret.ForRangePruning = tmp } - return &PartitionExpr{ - UpperBounds: locateExprs, - }, nil + return ret, nil } func generateHashPartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo, @@ -201,13 +232,13 @@ func generateHashPartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo, } // PartitionExpr returns the partition expression. 
-func (t *partitionedTable) PartitionExpr(ctx sessionctx.Context, columns []*expression.Column, names types.NameSlice) (*PartitionExpr, error) { +func (t *partitionedTable) PartitionExpr() (*PartitionExpr, error) { pi := t.meta.GetPartitionInfo() switch pi.Type { case model.PartitionTypeHash: return t.partitionExpr, nil case model.PartitionTypeRange: - return generateRangePartitionExpr(ctx, pi, columns, names) + return t.partitionExpr, nil } panic("cannot reach here") } diff --git a/table/tables/partition_test.go b/table/tables/partition_test.go index 8e29548dddf0a..0709b9d4b73eb 100644 --- a/table/tables/partition_test.go +++ b/table/tables/partition_test.go @@ -368,3 +368,37 @@ func (ts *testSuite) TestCreatePartitionTableNotSupport(c *C) { _, err = tk.Exec(`create table t7 (a int) partition by range (-(select * from t)) (partition p1 values less than (1));`) c.Assert(ddl.ErrPartitionFunctionIsNotAllowed.Equal(err), IsTrue) } + +func (ts *testSuite) TestIntUint(c *C) { + tk := testkit.NewTestKitWithInit(c, ts.store) + tk.MustExec("use test") + tk.MustExec(`create table t_uint (id bigint unsigned) partition by range (id) ( +partition p0 values less than (4294967293), +partition p1 values less than (4294967296), +partition p2 values less than (484467440737095), +partition p3 values less than (18446744073709551614))`) + tk.MustExec("insert into t_uint values (1)") + tk.MustExec("insert into t_uint values (4294967294)") + tk.MustExec("insert into t_uint values (4294967295)") + tk.MustExec("insert into t_uint values (18446744073709551613)") + tk.MustQuery("select * from t_uint where id > 484467440737095").Check(testkit.Rows("18446744073709551613")) + tk.MustQuery("select * from t_uint where id = 4294967295").Check(testkit.Rows("4294967295")) + tk.MustQuery("select * from t_uint where id < 4294967294").Check(testkit.Rows("1")) + tk.MustQuery("select * from t_uint where id >= 4294967293 order by id").Check(testkit.Rows("4294967294", "4294967295", 
"18446744073709551613")) + + tk.MustExec(`create table t_int (id bigint signed) partition by range (id) ( +partition p0 values less than (-4294967293), +partition p1 values less than (-12345), +partition p2 values less than (0), +partition p3 values less than (484467440737095), +partition p4 values less than (9223372036854775806))`) + tk.MustExec("insert into t_int values (-9223372036854775803)") + tk.MustExec("insert into t_int values (-429496729312)") + tk.MustExec("insert into t_int values (-1)") + tk.MustExec("insert into t_int values (4294967295)") + tk.MustExec("insert into t_int values (9223372036854775805)") + tk.MustQuery("select * from t_int where id > 484467440737095").Check(testkit.Rows("9223372036854775805")) + tk.MustQuery("select * from t_int where id = 4294967295").Check(testkit.Rows("4294967295")) + tk.MustQuery("select * from t_int where id = -4294967294").Check(testkit.Rows()) + tk.MustQuery("select * from t_int where id < -12345 order by id desc").Check(testkit.Rows("-429496729312", "-9223372036854775803")) +} From ae3503c19e2e45b73dadb03550e151296a2392a0 Mon Sep 17 00:00:00 2001 From: tiancaiamao Date: Tue, 24 Mar 2020 11:38:23 +0800 Subject: [PATCH 2/2] address comment --- expression/simple_rewriter.go | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/expression/simple_rewriter.go b/expression/simple_rewriter.go index a8304e844fdcf..ca3a99284db67 100644 --- a/expression/simple_rewriter.go +++ b/expression/simple_rewriter.go @@ -45,16 +45,16 @@ func ParseSimpleExprWithTableInfo(ctx sessionctx.Context, exprStr string, tableI exprStr = "select " + exprStr var stmts []ast.StmtNode var err error + var warns []error if p, ok := ctx.(interface { ParseSQL(context.Context, string, string, string) ([]ast.StmtNode, []error, error) }); ok { - stmts, _, err = p.ParseSQL(context.Background(), exprStr, "", "") + stmts, warns, err = p.ParseSQL(context.Background(), exprStr, "", "") } else { - var warns []error 
stmts, warns, err = parser.New().Parse(exprStr, "", "") - for _, warn := range warns { - ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) - } + } + for _, warn := range warns { + ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) } if err != nil { @@ -92,12 +92,13 @@ func RewriteSimpleExprWithTableInfo(ctx sessionctx.Context, tbl *model.TableInfo func ParseSimpleExprsWithSchema(ctx sessionctx.Context, exprStr string, schema *Schema) ([]Expression, error) { exprStr = "select " + exprStr stmts, warns, err := parser.New().Parse(exprStr, "", "") - for _, warn := range warns { - ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) - } if err != nil { return nil, util.SyntaxWarn(err) } + for _, warn := range warns { + ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) + } + fields := stmts[0].(*ast.SelectStmt).Fields.Fields exprs := make([]Expression, 0, len(fields)) for _, field := range fields { @@ -116,20 +117,21 @@ func ParseSimpleExprsWithNames(ctx sessionctx.Context, exprStr string, schema *S exprStr = "select " + exprStr var stmts []ast.StmtNode var err error + var warns []error if p, ok := ctx.(interface { ParseSQL(context.Context, string, string, string) ([]ast.StmtNode, []error, error) }); ok { - stmts, _, err = p.ParseSQL(context.Background(), exprStr, "", "") + stmts, warns, err = p.ParseSQL(context.Background(), exprStr, "", "") } else { - var warns []error stmts, warns, err = parser.New().Parse(exprStr, "", "") - for _, warn := range warns { - ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) - } } if err != nil { return nil, util.SyntaxWarn(err) } + for _, warn := range warns { + ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) + } + fields := stmts[0].(*ast.SelectStmt).Fields.Fields exprs := make([]Expression, 0, len(fields)) for _, field := range fields {