From 4d5d89ba23efd501a029da5775de7d75df9a42f3 Mon Sep 17 00:00:00 2001
From: Ti Chi Robot
Date: Tue, 31 Oct 2023 14:26:06 +0800
Subject: [PATCH] planner: fix cast(col) = range couldn't build range when cast
 function doesn't contain any precision loss in some cases (#46303) (#46546)

close pingcap/tidb#45199
---
 expression/util.go                            | 155 ++++++++++++++++++
 planner/core/integration_test.go              |  50 ++++++
 planner/core/stats.go                         |   7 +-
 .../core/testdata/integration_suite_in.json   |  15 ++
 .../core/testdata/integration_suite_out.json  | 133 +++++++++++++++
 5 files changed, 358 insertions(+), 2 deletions(-)

diff --git a/expression/util.go b/expression/util.go
index 8c8db64c75136..28ae372a9e20a 100644
--- a/expression/util.go
+++ b/expression/util.go
@@ -670,6 +670,152 @@ func pushNotAcrossArgs(ctx sessionctx.Context, exprs []Expression, not bool) ([]
 	return newExprs, flag
 }
 
+// todo: consider more downcast cases that have no precision loss.
+func noPrecisionLossCastCompatible(cast, argCol *types.FieldType) bool {
+	// currently only varchar and integer types are considered.
+	if !(types.IsTypeVarchar(cast.GetType()) && types.IsTypeVarchar(argCol.GetType())) &&
+		!(mysql.IsIntegerType(cast.GetType()) && mysql.IsIntegerType(argCol.GetType())) {
+		// varchar and integer are stored in much the same way at the storage layer, while the char type has its padding suffix.
+		return false
+	}
+	if types.IsTypeVarchar(cast.GetType()) {
+		// a varchar cast may only extend the flen.
+		if cast.GetFlen() < argCol.GetFlen() {
+			return false
+		}
+		if !collate.CompatibleCollate(cast.GetCollate(), argCol.GetCollate()) {
+			return false
+		}
+	} else {
+		// For integers, we should ignore the potential display length represented by flen, using the default flen of the type.
+		castFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(cast.GetType())
+		originFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(argCol.GetType())
+		// an integer cast may only extend the flen and must keep the signedness unchanged.
+		if castFlen < originFlen {
+			return false
+		}
+		if mysql.HasUnsignedFlag(cast.GetFlag()) != mysql.HasUnsignedFlag(argCol.GetFlag()) {
+			return false
+		}
+	}
+	return true
+}
+
+func unwrapCast(sctx sessionctx.Context, parentF *ScalarFunction, castOffset int) (Expression, bool) {
+	_, collation := parentF.CharsetAndCollation()
+	cast, ok := parentF.GetArgs()[castOffset].(*ScalarFunction)
+	if !ok || cast.FuncName.L != ast.Cast {
+		return parentF, false
+	}
+	// e.g.: if (cast(A) EQ const) has an incompatible collation, even if the cast is eliminated, the condition still cannot be used to build a range.
+	if cast.RetType.EvalType() == types.ETString && !collate.CompatibleCollate(cast.RetType.GetCollate(), collation) {
+		return parentF, false
+	}
+	// the arg at 1-castOffset should be a constant.
+	if _, ok := parentF.GetArgs()[1-castOffset].(*Constant); !ok {
+		return parentF, false
+	}
+
+	// the direct arg of the cast function should be a column.
+	c, ok := cast.GetArgs()[0].(*Column)
+	if !ok {
+		return parentF, false
+	}
+
+	// currently only varchar and integer are considered.
+	if !noPrecisionLossCastCompatible(cast.RetType, c.RetType) {
+		return parentF, false
+	}
+
+	// the column is covered by indexes; deconstruct it out.
+	if castOffset == 0 {
+		return NewFunctionInternal(sctx, parentF.FuncName.L, parentF.RetType, c, parentF.GetArgs()[1]), true
+	}
+	return NewFunctionInternal(sctx, parentF.FuncName.L, parentF.RetType, parentF.GetArgs()[0], c), true
+}
+
+// eliminateCastFunction detects the original arg type and the type it is cast to; once there
+// is no precision loss between them, the current cast wrapper can be eliminated. For string
+// types, collation is also taken into consideration. (mainly used to build ranges or points)
+func eliminateCastFunction(sctx sessionctx.Context, expr Expression) (_ Expression, changed bool) {
+	f, ok := expr.(*ScalarFunction)
+	if !ok {
+		return expr, false
+	}
+	_, collation := expr.CharsetAndCollation()
+	switch f.FuncName.L {
+	case ast.LogicOr:
+		dnfItems := FlattenDNFConditions(f)
+		rmCast := false
+		rmCastItems := make([]Expression, len(dnfItems))
+		for i, dnfItem := range dnfItems {
+			newExpr, curDowncast := eliminateCastFunction(sctx, dnfItem)
+			rmCastItems[i] = newExpr
+			if curDowncast {
+				rmCast = true
+			}
+		}
+		if rmCast {
+			// compose the new DNF expression.
+			return ComposeDNFCondition(sctx, rmCastItems...), true
+		}
+		return expr, false
+	case ast.LogicAnd:
+		cnfItems := FlattenCNFConditions(f)
+		rmCast := false
+		rmCastItems := make([]Expression, len(cnfItems))
+		for i, cnfItem := range cnfItems {
+			newExpr, curDowncast := eliminateCastFunction(sctx, cnfItem)
+			rmCastItems[i] = newExpr
+			if curDowncast {
+				rmCast = true
+			}
+		}
+		if rmCast {
+			// compose the new CNF expression.
+			return ComposeCNFCondition(sctx, rmCastItems...), true
+		}
+		return expr, false
+	case ast.EQ, ast.NullEQ, ast.LE, ast.GE, ast.LT, ast.GT:
+		// for the case eq(cast(test.t2.a, varchar(100)), "aaaaa"): once t2.a is covered by an index or pk, try to deconstruct it out.
+		if newF, ok := unwrapCast(sctx, f, 0); ok {
+			return newF, true
+		}
+		// for the case eq("aaaaa", cast(test.t2.a, varchar(100))): once t2.a is covered by an index or pk, try to deconstruct it out.
+		if newF, ok := unwrapCast(sctx, f, 1); ok {
+			return newF, true
+		}
+	case ast.In:
+		// for the case cast(a as bigint) in (1,2,3): we can deconstruct column 'a' out directly.
+		cast, ok := f.GetArgs()[0].(*ScalarFunction)
+		if !ok || cast.FuncName.L != ast.Cast {
+			return expr, false
+		}
+		// e.g.: if (cast(A) IN {const}) has an incompatible collation, even if the cast is eliminated, the condition still cannot be used to build a range.
+		if cast.RetType.EvalType() == types.ETString && !collate.CompatibleCollate(cast.RetType.GetCollate(), collation) {
+			return expr, false
+		}
+		for _, arg := range f.GetArgs()[1:] {
+			if _, ok := arg.(*Constant); !ok {
+				return expr, false
+			}
+		}
+		// the direct arg of the cast function should be a column.
+		c, ok := cast.GetArgs()[0].(*Column)
+		if !ok {
+			return expr, false
+		}
+		// currently only varchar and integer are considered.
+		if !noPrecisionLossCastCompatible(cast.RetType, c.RetType) {
+			return expr, false
+		}
+		newArgs := []Expression{c}
+		newArgs = append(newArgs, f.GetArgs()[1:]...)
+		return NewFunctionInternal(sctx, f.FuncName.L, f.RetType, newArgs...), true
+	}
+	return expr, false
+}
+
 // pushNotAcrossExpr try to eliminate the NOT expr in expression tree.
 // Input `not` indicates whether there's a `NOT` be pushed down.
 // Output `changed` indicates whether the output expression differs from the
@@ -728,6 +874,15 @@ func PushDownNot(ctx sessionctx.Context, expr Expression) Expression {
 	return newExpr
 }
 
+// EliminateNoPrecisionLossCast removes redundant cast functions so that ranges can be built conveniently.
+// 1: casts embedded deeper in other complicated functions are not considered.
+// 2: the cast's arg should be an original base column and the other comparison arg a constant.
+// 3: collation compatibility and precision loss are checked before removing the cast func.
+func EliminateNoPrecisionLossCast(sctx sessionctx.Context, expr Expression) Expression {
+	newExpr, _ := eliminateCastFunction(sctx, expr)
+	return newExpr
+}
+
 // ContainOuterNot checks if there is an outer `not`.
 func ContainOuterNot(expr Expression) bool {
 	return containOuterNot(expr, false)
diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go
index 9d67c7f841601..4994679f32c8e 100644
--- a/planner/core/integration_test.go
+++ b/planner/core/integration_test.go
@@ -6864,3 +6864,53 @@ func TestIssueXXX(t *testing.T) {
 	rs := tk.MustQuery("WITH tmp AS (SELECT t2.* FROM t2) SELECT * FROM t1 WHERE t1.id = (select id from tmp where id = 1) or t1.id = (select id from tmp where id = 2) or t1.id = (select id from tmp where id = 3)")
 	rs.Sort().Check(testkit.Rows("1 ", "2 ", "3 "))
 }
+
+func TestDowncastPointGetOrRangeScan(t *testing.T) {
+	store, clean := testkit.CreateMockStore(t)
+	defer clean()
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+	tk.MustExec("create table t1 (a bigint key)")
+	tk.MustExec("create table t2 (a int key)")
+	tk.MustExec("create definer=`root`@`127.0.0.1` view v1 as (select a from t1) union (select a from t2)")
+	// select * from v1 where a = 1 will lead to a condition: EQ(cast(t2.a as bigint), 1);
+	// we should downcast it, utilizing t2.a = 1 to walk through the pk point-get, because the cast doesn't contain any precision loss.
+
+	tk.MustExec("create table t3 (a varchar(100) key)")
+	tk.MustExec("create table t4 (a varchar(10) key)")
+	tk.MustExec("create definer=`root`@`127.0.0.1` view v2 as (select a from t3) union (select a from t4)")
+	// select * from v2 where a = 'test' will lead to a condition: EQ(cast(t4.a as varchar(100), same collation), 'test');
+	// we should downcast it, utilizing t4.a = 'test' to walk through the pk point-get, because the cast doesn't contain any precision loss.
+
+	tk.MustExec("create table t5 (a char(100) key)")
+	tk.MustExec("create table t6 (a char(10) key)")
+	tk.MustExec("create definer=`root`@`127.0.0.1` view v3 as (select a from t5) union (select a from t6)")
+	// select * from v3 where a = 'test' will lead to a condition: EQ(cast(t6.a as char(100), same collation), 'test');
+	// for the char type it depends: with a binary collation, the padding appended when casting column a from char(10) to char(100) makes a difference
+	// to the comparison a = 'test' before and after the UNION operator, so we don't allow this kind of type downcast currently (precision diff).
+
+	tk.MustExec("create table t7 (a varchar(100) key)")
+	tk.MustExec("create table t8 (a int key)")
+	tk.MustExec("create definer=`root`@`127.0.0.1` view v4 as (select a from t7) union (select a from t8)")
+	// since the UNION OP will unify a(int) and a(varchar(100)) as varchar(100),
+	// select * from v4 where a = "test" will lead to a condition: EQ(cast(t8.a as varchar(100)), "test"), and since
+	// casting int to varchar(100) may have some precision loss, we can't utilize a = "test" to get the range directly.
+
+	var input []string
+	var output []struct {
+		SQL    string
+		Plan   []string
+		Result []string
+	}
+	integrationSuiteData := core.GetIntegrationSuiteData()
+	integrationSuiteData.GetTestCases(t, &input, &output)
+	for i, tt := range input {
+		testdata.OnRecord(func() {
+			output[i].SQL = tt
+			output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery("explain format='brief' " + tt).Rows())
+			output[i].Result = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Sort().Rows())
+		})
+		tk.MustQuery("explain format='brief' " + tt).Check(testkit.Rows(output[i].Plan...))
+		tk.MustQuery(tt).Sort().Check(testkit.Rows(output[i].Result...))
+	}
+}
diff --git a/planner/core/stats.go b/planner/core/stats.go
index ded25c37bd930..73cee139269c6 100644
--- a/planner/core/stats.go
+++ b/planner/core/stats.go
@@ -394,9 +394,12 @@ func (ds *DataSource) DeriveStats(childStats []*property.StatsInfo, selfSchema *
 		ds.stats = ds.tableStats.Scale(selectivity)
 		return ds.stats, nil
 	}
-	// PushDownNot here can convert query 'not (a != 1)' to 'a = 1'.
+	// two preprocessing steps here.
+	// 1: PushDownNot here can convert the query condition 'not (a != 1)' to 'a = 1'.
+	// 2: EliminateNoPrecisionLossCast here can convert the condition 'cast(c as bigint) = 1' to 'c = 1' to leverage the access range.
 	for i, expr := range ds.pushedDownConds {
-		ds.pushedDownConds[i] = expression.PushDownNot(ds.ctx, expr)
+		ds.pushedDownConds[i] = expression.PushDownNot(ds.SCtx(), expr)
+		ds.pushedDownConds[i] = expression.EliminateNoPrecisionLossCast(ds.SCtx(), ds.pushedDownConds[i])
 	}
 	for _, path := range ds.possibleAccessPaths {
 		if path.IsTablePath() {
diff --git a/planner/core/testdata/integration_suite_in.json b/planner/core/testdata/integration_suite_in.json
index e3aff32273794..6d0689c215c8c 100644
--- a/planner/core/testdata/integration_suite_in.json
+++ b/planner/core/testdata/integration_suite_in.json
@@ -932,5 +932,20 @@
       "explain format = 'brief' select count(*) from rp_t where a = 1 or a = 20",
       "explain format = 'brief' select count(*) from hp_t where a = 1 or a = 20"
     ]
+  },
+  {
+    "name": "TestDowncastPointGetOrRangeScan",
+    "cases": [
+      "select * from v1 where a = 1; -- the condition should be downcast through both side and go get point",
+      "select * from v1 where a = '1test'; -- the condition should be downcast through both side and go get point too",
+      "select * from v1 where a > 1; -- the condition should be downcast through both side and go range scan",
+      "select * from v2 where a = 'test';",
+      "select * from v2 where a = 1;",
+      "select * from v2 where a > 'test';",
+      "select * from v3 where a = 'test' -- the condition shouldn't be downcast through both side and go get point",
+      "select * from v3 where a > 'test' -- the condition shouldn't be downcast through both side and go get point too",
+      "select * from v4 where a = 'test' -- diff column union may have precision loss couldn't downcast the condition to get the range",
+      "select * from v4 where a > 'test' -- diff column union may have precision loss couldn't downcast the condition to get the range"
+    ]
   }
 ]
diff --git a/planner/core/testdata/integration_suite_out.json b/planner/core/testdata/integration_suite_out.json
index 10a8475503219..1c342e263f974 100644
--- a/planner/core/testdata/integration_suite_out.json
+++ b/planner/core/testdata/integration_suite_out.json
@@ -7034,5 +7034,138 @@
         ]
       }
     ]
+  },
+  {
+    "Name": "TestDowncastPointGetOrRangeScan",
+    "Cases": [
+      {
+        "SQL": "select * from v1 where a = 1; -- the condition should be downcast through both side and go get point",
+        "Plan": [
+          "HashAgg 2.00 root 
group by:Column#3, funcs:firstrow(Column#3)->Column#3", + "└─Union 2.00 root ", + " ├─Point_Get 1.00 root table:t1 handle:1", + " └─Projection 1.00 root cast(test.t2.a, bigint(20) BINARY)->Column#3", + " └─Point_Get 1.00 root table:t2 handle:1" + ], + "Result": null + }, + { + "SQL": "select * from v1 where a = '1test'; -- the condition should be downcast through both side and go get point too", + "Plan": [ + "HashAgg 2.00 root group by:Column#3, funcs:firstrow(Column#3)->Column#3", + "└─Union 2.00 root ", + " ├─Point_Get 1.00 root table:t1 handle:1", + " └─Projection 1.00 root cast(test.t2.a, bigint(20) BINARY)->Column#3", + " └─Point_Get 1.00 root table:t2 handle:1" + ], + "Result": null + }, + { + "SQL": "select * from v1 where a > 1; -- the condition should be downcast through both side and go range scan", + "Plan": [ + "HashAgg 5333.33 root group by:Column#3, funcs:firstrow(Column#3)->Column#3", + "└─Union 6666.67 root ", + " ├─TableReader 3333.33 root data:TableRangeScan", + " │ └─TableRangeScan 3333.33 cop[tikv] table:t1 range:(1,+inf], keep order:false, stats:pseudo", + " └─Projection 3333.33 root cast(test.t2.a, bigint(20) BINARY)->Column#3", + " └─TableReader 3333.33 root data:TableRangeScan", + " └─TableRangeScan 3333.33 cop[tikv] table:t2 range:(1,+inf], keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select * from v2 where a = 'test';", + "Plan": [ + "HashAgg 2.00 root group by:Column#5, funcs:firstrow(Column#5)->Column#5", + "└─Union 2.00 root ", + " ├─Point_Get 1.00 root table:t3, index:PRIMARY(a) ", + " └─Projection 1.00 root cast(test.t4.a, varchar(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#5", + " └─Point_Get 1.00 root table:t4, index:PRIMARY(a) " + ], + "Result": null + }, + { + "SQL": "select * from v2 where a = 1;", + "Plan": [ + "HashAgg 12800.00 root group by:Column#5, funcs:firstrow(Column#5)->Column#5", + "└─Union 16000.00 root ", + " ├─IndexReader 8000.00 root index:Selection", + " │ └─Selection 8000.00 cop[tikv] eq(cast(test.t3.a, double BINARY), 1)", + " │ └─IndexFullScan 10000.00 cop[tikv] table:t3, index:PRIMARY(a) keep order:false, stats:pseudo", + " └─Projection 8000.00 root cast(test.t4.a, varchar(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#5", + " └─IndexReader 8000.00 root index:Selection", + " └─Selection 8000.00 cop[tikv] eq(cast(cast(test.t4.a, varchar(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin), double BINARY), 1)", + " └─IndexFullScan 10000.00 cop[tikv] table:t4, index:PRIMARY(a) keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select * from v2 where a > 'test';", + "Plan": [ + "HashAgg 5333.33 root group by:Column#5, funcs:firstrow(Column#5)->Column#5", + "└─Union 6666.67 root ", + " ├─IndexReader 3333.33 root index:IndexRangeScan", + " │ └─IndexRangeScan 3333.33 cop[tikv] table:t3, index:PRIMARY(a) range:(\"test\",+inf], keep order:false, stats:pseudo", + " └─Projection 3333.33 root cast(test.t4.a, varchar(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#5", + " └─IndexReader 3333.33 root index:IndexRangeScan", + " └─IndexRangeScan 3333.33 cop[tikv] table:t4, index:PRIMARY(a) range:(\"test\",+inf], keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select * from v3 where a = 'test' -- the condition shouldn't be downcast through both side and go get point", + "Plan": [ + "HashAgg 6401.00 root group by:Column#5, funcs:firstrow(Column#5)->Column#5", + "└─Union 8001.00 root ", + " ├─Point_Get 1.00 root table:t5, index:PRIMARY(a) ", + " 
└─Projection 8000.00 root cast(test.t6.a, char(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#5", + " └─IndexReader 8000.00 root index:Selection", + " └─Selection 8000.00 cop[tikv] eq(cast(test.t6.a, char(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin), \"test\")", + " └─IndexFullScan 10000.00 cop[tikv] table:t6, index:PRIMARY(a) keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select * from v3 where a > 'test' -- the condition shouldn't be downcast through both side and go get point too", + "Plan": [ + "HashAgg 9066.67 root group by:Column#5, funcs:firstrow(Column#5)->Column#5", + "└─Union 11333.33 root ", + " ├─IndexReader 3333.33 root index:IndexRangeScan", + " │ └─IndexRangeScan 3333.33 cop[tikv] table:t5, index:PRIMARY(a) range:(\"test\",+inf], keep order:false, stats:pseudo", + " └─Projection 8000.00 root cast(test.t6.a, char(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#5", + " └─IndexReader 8000.00 root index:Selection", + " └─Selection 8000.00 cop[tikv] gt(cast(test.t6.a, char(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin), \"test\")", + " └─IndexFullScan 10000.00 cop[tikv] table:t6, index:PRIMARY(a) keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select * from v4 where a = 'test' -- diff column union may have precision loss couldn't downcast the condition to get the range", + "Plan": [ + "HashAgg 6401.00 root group by:Column#4, funcs:firstrow(Column#4)->Column#4", + "└─Union 8001.00 root ", + " ├─Point_Get 1.00 root table:t7, index:PRIMARY(a) ", + " └─Projection 8000.00 root cast(test.t8.a, varchar(100) BINARY CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#4", + " └─TableReader 8000.00 root data:Selection", + " └─Selection 8000.00 cop[tikv] eq(cast(test.t8.a, varchar(100) BINARY CHARACTER SET utf8mb4 COLLATE utf8mb4_bin), \"test\")", + " └─TableFullScan 10000.00 cop[tikv] table:t8 keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select * from v4 where a > 'test' -- diff column union may have precision loss couldn't downcast the condition to get the range", + "Plan": [ + "HashAgg 9066.67 root group by:Column#4, funcs:firstrow(Column#4)->Column#4", + "└─Union 11333.33 root ", + " ├─IndexReader 3333.33 root index:IndexRangeScan", + " │ └─IndexRangeScan 3333.33 cop[tikv] table:t7, index:PRIMARY(a) range:(\"test\",+inf], keep order:false, stats:pseudo", + " └─Projection 8000.00 root cast(test.t8.a, varchar(100) BINARY CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#4", + " └─TableReader 8000.00 root data:Selection", + " └─Selection 8000.00 cop[tikv] gt(cast(test.t8.a, varchar(100) BINARY CHARACTER SET utf8mb4 COLLATE utf8mb4_bin), \"test\")", + " └─TableFullScan 10000.00 cop[tikv] table:t8 keep order:false, stats:pseudo" + ], + "Result": null + } + ] } ]
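
A minimal, standalone sketch of the compatibility rule that noPrecisionLossCastCompatible in this patch enforces. It does not use TiDB's types or mysql packages; the FieldType struct, the defaultIntFlen table, and the equality-based collation comparison below are stand-ins invented for illustration only. It shows the two accepted shapes: varchar widened to a varchar with a compatible collation, and an integer widened to an integer with unchanged signedness.

// Hypothetical, simplified re-statement of the no-precision-loss rule; not TiDB code.
package main

import "fmt"

// FieldType is a stand-in for types.FieldType with only the fields the rule needs.
type FieldType struct {
	kind     string // "varchar", "tinyint", "int", "bigint", ...
	flen     int    // declared length
	unsigned bool   // mirrors the UNSIGNED flag
	collate  string // collation name, compared by simple equality in this sketch
}

// defaultIntFlen mimics the idea behind mysql.GetDefaultFieldLengthAndDecimal:
// integers are compared by the default display width of their type, not a user flen.
var defaultIntFlen = map[string]int{"tinyint": 4, "smallint": 6, "mediumint": 9, "int": 11, "bigint": 20}

// noPrecisionLossCastCompatible reports whether casting a column of type argCol to
// type cast can be removed without changing comparison results.
func noPrecisionLossCastCompatible(cast, argCol FieldType) bool {
	castIsInt := defaultIntFlen[cast.kind] > 0
	colIsInt := defaultIntFlen[argCol.kind] > 0
	switch {
	case cast.kind == "varchar" && argCol.kind == "varchar":
		// the target flen must not shrink, and the collations must be compatible.
		return cast.flen >= argCol.flen && cast.collate == argCol.collate
	case castIsInt && colIsInt:
		// widths are the per-type defaults; signedness must match exactly.
		return defaultIntFlen[cast.kind] >= defaultIntFlen[argCol.kind] && cast.unsigned == argCol.unsigned
	default:
		// char (padding suffix), decimal, etc. are out of scope.
		return false
	}
}

func main() {
	// int column cast to bigint: no precision loss, the cast can be unwrapped.
	fmt.Println(noPrecisionLossCastCompatible(
		FieldType{kind: "bigint", flen: 20}, FieldType{kind: "int", flen: 11}))
	// varchar(10) cast to varchar(100) with the same collation: also fine.
	fmt.Println(noPrecisionLossCastCompatible(
		FieldType{kind: "varchar", flen: 100, collate: "utf8mb4_bin"},
		FieldType{kind: "varchar", flen: 10, collate: "utf8mb4_bin"}))
	// bigint cast down to int: precision may be lost, keep the cast.
	fmt.Println(noPrecisionLossCastCompatible(
		FieldType{kind: "int", flen: 11}, FieldType{kind: "bigint", flen: 20}))
}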
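
A second, equally hypothetical sketch of how the elimination pass in eliminateCastFunction walks a pushed-down condition: AND/OR nodes recurse into their children, and a comparison of the form cmp(cast(col), const) drops the cast only when the cast is known to lose no precision. Expr, Col, Const, Cast, Cmp, and And below are toy stand-ins, not TiDB expression types; the real code also handles IN lists, the constant on either side, and collation checks that are omitted here.

// Toy expression tree illustrating the recursion structure; not TiDB code.
package main

import "fmt"

type Expr interface{ String() string }

type Col struct{ Name string }
type Const struct{ Val string }
type Cast struct {
	Arg             Expr
	NoPrecisionLoss bool // stands in for the real compatibility check
}
type Cmp struct {
	Op   string
	L, R Expr
}
type And struct{ L, R Expr }

func (c Col) String() string   { return c.Name }
func (c Const) String() string { return c.Val }
func (c Cast) String() string  { return "cast(" + c.Arg.String() + ")" }
func (c Cmp) String() string   { return c.L.String() + " " + c.Op + " " + c.R.String() }
func (a And) String() string   { return "(" + a.L.String() + " and " + a.R.String() + ")" }

// eliminate removes a no-precision-loss cast around a column when it is compared to a constant.
func eliminate(e Expr) Expr {
	switch v := e.(type) {
	case And:
		return And{eliminate(v.L), eliminate(v.R)}
	case Cmp:
		if cast, ok := v.L.(Cast); ok {
			if _, isCol := cast.Arg.(Col); isCol {
				if _, isConst := v.R.(Const); isConst && cast.NoPrecisionLoss {
					return Cmp{v.Op, cast.Arg, v.R}
				}
			}
		}
		return v
	default:
		return e
	}
}

func main() {
	cond := And{
		Cmp{"=", Cast{Col{"t2.a"}, true}, Const{"1"}},    // cast(int -> bigint): removable
		Cmp{">", Cast{Col{"t6.a"}, false}, Const{"'x'"}}, // cast(char(10) -> char(100)): kept
	}
	fmt.Println(eliminate(cond)) // (t2.a = 1 and cast(t6.a) > 'x')
}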