Skip to content

Commit

Permalink
planner: fix cast(col) = range couldn't build range when cast functio…
Browse files Browse the repository at this point in the history
…n doesn't contain any precision loss in some cases (#46303)

close #45199
  • Loading branch information
AilinKid authored Aug 31, 2023
1 parent ede7ad4 commit 28a9c7f
Show file tree
Hide file tree
Showing 5 changed files with 356 additions and 1 deletion.
155 changes: 155 additions & 0 deletions expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,152 @@ func pushNotAcrossArgs(ctx sessionctx.Context, exprs []Expression, not bool) ([]
return newExprs, flag
}

// todo: consider more no precision-loss downcast cases.
func noPrecisionLossCastCompatible(cast, argCol *types.FieldType) bool {
// now only consider varchar type and integer.
if !(types.IsTypeVarchar(cast.GetType()) && types.IsTypeVarchar(argCol.GetType())) &&
!(mysql.IsIntegerType(cast.GetType()) && mysql.IsIntegerType(argCol.GetType())) {
// varchar type and integer on the storage layer is quite same, while the char type has its padding suffix.
return false
}
if types.IsTypeVarchar(cast.GetType()) {
// cast varchar function only bear the flen extension.
if cast.GetFlen() < argCol.GetFlen() {
return false
}
if !collate.CompatibleCollate(cast.GetCollate(), argCol.GetCollate()) {
return false
}
} else {
// For integers, we should ignore the potential display length represented by flen, using the default flen of the type.
castFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(cast.GetType())
originFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(argCol.GetType())
// cast integer function only bear the flen extension and signed symbol unchanged.
if castFlen < originFlen {
return false
}
if mysql.HasUnsignedFlag(cast.GetFlag()) != mysql.HasUnsignedFlag(argCol.GetFlag()) {
return false
}
}
return true
}

func unwrapCast(sctx sessionctx.Context, parentF *ScalarFunction, castOffset int) (Expression, bool) {
_, collation := parentF.CharsetAndCollation()
cast, ok := parentF.GetArgs()[castOffset].(*ScalarFunction)
if !ok || cast.FuncName.L != ast.Cast {
return parentF, false
}
// eg: if (cast(A) EQ const) with incompatible collation, even if cast is eliminated, the condition still can not be used to build range.
if cast.RetType.EvalType() == types.ETString && !collate.CompatibleCollate(cast.RetType.GetCollate(), collation) {
return parentF, false
}
// 1-castOffset should be constant
if _, ok := parentF.GetArgs()[1-castOffset].(*Constant); !ok {
return parentF, false
}

// the direct args of cast function should be column.
c, ok := cast.GetArgs()[0].(*Column)
if !ok {
return parentF, false
}

// current only consider varchar and integer
if !noPrecisionLossCastCompatible(cast.RetType, c.RetType) {
return parentF, false
}

// the column is covered by indexes, deconstructing it out.
if castOffset == 0 {
return NewFunctionInternal(sctx, parentF.FuncName.L, parentF.RetType, c, parentF.GetArgs()[1]), true
}
return NewFunctionInternal(sctx, parentF.FuncName.L, parentF.RetType, parentF.GetArgs()[0], c), true
}

// eliminateCastFunction will detect the original arg before and the cast type after, once upon
// there is no precision loss between them, current cast wrapper can be eliminated. For string
// type, collation is also taken into consideration. (mainly used to build range or point)
func eliminateCastFunction(sctx sessionctx.Context, expr Expression) (_ Expression, changed bool) {
f, ok := expr.(*ScalarFunction)
if !ok {
return expr, false
}
_, collation := expr.CharsetAndCollation()
switch f.FuncName.L {
case ast.LogicOr:
dnfItems := FlattenDNFConditions(f)
rmCast := false
rmCastItems := make([]Expression, len(dnfItems))
for i, dnfItem := range dnfItems {
newExpr, curDowncast := eliminateCastFunction(sctx, dnfItem)
rmCastItems[i] = newExpr
if curDowncast {
rmCast = true
}
}
if rmCast {
// compose the new DNF expression.
return ComposeDNFCondition(sctx, rmCastItems...), true
}
return expr, false
case ast.LogicAnd:
cnfItems := FlattenCNFConditions(f)
rmCast := false
rmCastItems := make([]Expression, len(cnfItems))
for i, cnfItem := range cnfItems {
newExpr, curDowncast := eliminateCastFunction(sctx, cnfItem)
rmCastItems[i] = newExpr
if curDowncast {
rmCast = true
}
}
if rmCast {
// compose the new CNF expression.
return ComposeCNFCondition(sctx, rmCastItems...), true
}
return expr, false
case ast.EQ, ast.NullEQ, ast.LE, ast.GE, ast.LT, ast.GT:
// for case: eq(cast(test.t2.a, varchar(100), "aaaaa"), once t2.a is covered by index or pk, try deconstructing it out.
if newF, ok := unwrapCast(sctx, f, 0); ok {
return newF, true
}
// for case: eq("aaaaa", cast(test.t2.a, varchar(100)), once t2.a is covered by index or pk, try deconstructing it out.
if newF, ok := unwrapCast(sctx, f, 1); ok {
return newF, true
}
case ast.In:
// case for: cast(a<int> as bigint) in (1,2,3), we could deconstruct column 'a out directly.
cast, ok := f.GetArgs()[0].(*ScalarFunction)
if !ok || cast.FuncName.L != ast.Cast {
return expr, false
}
// eg: if (cast(A) IN {const}) with incompatible collation, even if cast is eliminated, the condition still can not be used to build range.
if cast.RetType.EvalType() == types.ETString && !collate.CompatibleCollate(cast.RetType.GetCollate(), collation) {
return expr, false
}
for _, arg := range f.GetArgs()[1:] {
if _, ok := arg.(*Constant); !ok {
return expr, false
}
}
// the direct args of cast function should be column.
c, ok := cast.GetArgs()[0].(*Column)
if !ok {
return expr, false
}
// current only consider varchar and integer
if !noPrecisionLossCastCompatible(cast.RetType, c.RetType) {
return expr, false
}
newArgs := []Expression{c}
newArgs = append(newArgs, f.GetArgs()[1:]...)
return NewFunctionInternal(sctx, f.FuncName.L, f.RetType, newArgs...), true
}
return expr, false
}

// pushNotAcrossExpr try to eliminate the NOT expr in expression tree.
// Input `not` indicates whether there's a `NOT` be pushed down.
// Output `changed` indicates whether the output expression differs from the
Expand Down Expand Up @@ -793,6 +939,15 @@ func PushDownNot(ctx sessionctx.Context, expr Expression) Expression {
return newExpr
}

// EliminateNoPrecisionLossCast remove the redundant cast function for range build convenience.
// 1: deeper cast embedded in other complicated function will not be considered.
// 2: cast args should be one for original base column and one for constant.
// 3: some collation compatibility and precision loss will be considered when remove this cast func.
func EliminateNoPrecisionLossCast(sctx sessionctx.Context, expr Expression) Expression {
newExpr, _ := eliminateCastFunction(sctx, expr)
return newExpr
}

// ContainOuterNot checks if there is an outer `not`.
func ContainOuterNot(expr Expression) bool {
return containOuterNot(expr, false)
Expand Down
49 changes: 49 additions & 0 deletions planner/core/casetest/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1413,6 +1413,55 @@ func TestTiFlashFineGrainedShuffle(t *testing.T) {
}
}

func TestDowncastPointGetOrRangeScan(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("create table t1 (a bigint key)")
tk.MustExec("create table t2 (a int key)")
tk.MustExec("create definer=`root`@`127.0.0.1` view v1 as (select a from t1) union (select a from t2)")
// select * from v where a = 1 will lead a condition: EQ(cast(t2.a as bigint), 1),
// we should downcast it, utilizing t2.a =1 to walking through the pk point-get. Because cast doesn't contain any precision loss.

tk.MustExec("create table t3 (a varchar(100) key)")
tk.MustExec("create table t4 (a varchar(10) key)")
tk.MustExec("create definer=`root`@`127.0.0.1` view v2 as (select a from t3) union (select a from t4)")
// select * from v2 where a = 'test' will lead a condition: EQ(cast(t2.a as varchar(100) same collation), 1),
// we should downcast it, utilizing t2.a = 'test' to walking through the pk point-get. Because cast doesn't contain any precision loss.

tk.MustExec("create table t5 (a char(100) key)")
tk.MustExec("create table t6 (a char(10) key)")
tk.MustExec("create definer=`root`@`127.0.0.1` view v3 as (select a from t5) union (select a from t6)")
// select * from v3 where a = 'test' will lead a condition: EQ(cast(t2.a as char(100) same collation), 1),
// for char type, it depends, with binary collate, the appended '0' after cast column a from char(10) to char(100) will make some difference
// on comparison on where a = 'test' before and after the UNION operator; so we didn't allow this kind of type downcast currently (precision diff).

tk.MustExec("create table t7 (a varchar(100) key)")
tk.MustExec("create table t8 (a int key)")
tk.MustExec("create definer=`root`@`127.0.0.1` view v4 as (select a from t7) union (select a from t8)")
// since UNION OP will unify the a(int) and a(varchar100) as varchar(100)
// select * from v4 where a = "test" will lead a condition: EQ(cast(t2.a as varchar(100)), "test"), and since
// cast int to varchar(100) may have some precision loss, we couldn't utilize a="test" to get the range directly.

var input []string
var output []struct {
SQL string
Plan []string
Result []string
}
integrationSuiteData := GetIntegrationSuiteData()
integrationSuiteData.LoadTestCases(t, &input, &output)
for i, tt := range input {
testdata.OnRecord(func() {
output[i].SQL = tt
output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery("explain format='brief' " + tt).Rows())
output[i].Result = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Sort().Rows())
})
tk.MustQuery("explain format='brief' " + tt).Check(testkit.Rows(output[i].Plan...))
tk.MustQuery(tt).Sort().Check(testkit.Rows(output[i].Result...))
}
}

func TestNullConditionForPrefixIndex(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
Expand Down
15 changes: 15 additions & 0 deletions planner/core/casetest/testdata/integration_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -594,5 +594,20 @@
"select b from t3 where a = 1 and b is not null",
"select b from t3 where a = 1 and b is null"
]
},
{
"name": "TestDowncastPointGetOrRangeScan",
"cases": [
"select * from v1 where a = 1; -- the condition should be downcast through both side and go get point",
"select * from v1 where a = '1test'; -- the condition should be downcast through both side and go get point too",
"select * from v1 where a > 1; -- the condition should be downcast through both side and go range scan",
"select * from v2 where a = 'test';",
"select * from v2 where a = 1;",
"select * from v2 where a > 'test';",
"select * from v3 where a = 'test' -- the condition shouldn't be downcast through both side and go get point",
"select * from v3 where a > 'test' -- the condition shouldn't be downcast through both side and go get point too",
"select * from v4 where a = 'test' -- diff column union may have precision loss couldn't downcast the condition to get the range",
"select * from v4 where a > 'test' -- diff column union may have precision loss couldn't downcast the condition to get the range"
]
}
]
133 changes: 133 additions & 0 deletions planner/core/casetest/testdata/integration_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -4022,5 +4022,138 @@
"Result": null
}
]
},
{
"Name": "TestDowncastPointGetOrRangeScan",
"Cases": [
{
"SQL": "select * from v1 where a = 1; -- the condition should be downcast through both side and go get point",
"Plan": [
"HashAgg 2.00 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 2.00 root ",
" ├─Point_Get 1.00 root table:t1 handle:1",
" └─Projection 1.00 root cast(test.t2.a, bigint(20) BINARY)->Column#3",
" └─Point_Get 1.00 root table:t2 handle:1"
],
"Result": null
},
{
"SQL": "select * from v1 where a = '1test'; -- the condition should be downcast through both side and go get point too",
"Plan": [
"HashAgg 2.00 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 2.00 root ",
" ├─Point_Get 1.00 root table:t1 handle:1",
" └─Projection 1.00 root cast(test.t2.a, bigint(20) BINARY)->Column#3",
" └─Point_Get 1.00 root table:t2 handle:1"
],
"Result": null
},
{
"SQL": "select * from v1 where a > 1; -- the condition should be downcast through both side and go range scan",
"Plan": [
"HashAgg 5333.33 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 6666.67 root ",
" ├─TableReader 3333.33 root data:TableRangeScan",
" │ └─TableRangeScan 3333.33 cop[tikv] table:t1 range:(1,+inf], keep order:false, stats:pseudo",
" └─Projection 3333.33 root cast(test.t2.a, bigint(20) BINARY)->Column#3",
" └─TableReader 3333.33 root data:TableRangeScan",
" └─TableRangeScan 3333.33 cop[tikv] table:t2 range:(1,+inf], keep order:false, stats:pseudo"
],
"Result": null
},
{
"SQL": "select * from v2 where a = 'test';",
"Plan": [
"HashAgg 16.00 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 20.00 root ",
" ├─Point_Get 1.00 root table:t3, clustered index:PRIMARY(a) ",
" └─Projection 10.00 root cast(test.t4.a, varchar(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#3",
" └─Point_Get 1.00 root table:t4, clustered index:PRIMARY(a) "
],
"Result": null
},
{
"SQL": "select * from v2 where a = 1;",
"Plan": [
"HashAgg 12800.00 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 16000.00 root ",
" ├─TableReader 8000.00 root data:Selection",
" │ └─Selection 8000.00 cop[tikv] eq(cast(test.t3.a, double BINARY), 1)",
" │ └─TableFullScan 10000.00 cop[tikv] table:t3 keep order:false, stats:pseudo",
" └─Projection 8000.00 root cast(test.t4.a, varchar(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#3",
" └─TableReader 8000.00 root data:Selection",
" └─Selection 8000.00 cop[tikv] eq(cast(cast(test.t4.a, varchar(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin), double BINARY), 1)",
" └─TableFullScan 10000.00 cop[tikv] table:t4 keep order:false, stats:pseudo"
],
"Result": null
},
{
"SQL": "select * from v2 where a > 'test';",
"Plan": [
"HashAgg 5333.33 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 6666.67 root ",
" ├─TableReader 3333.33 root data:TableRangeScan",
" │ └─TableRangeScan 3333.33 cop[tikv] table:t3 range:(\"test\",+inf], keep order:false, stats:pseudo",
" └─Projection 3333.33 root cast(test.t4.a, varchar(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#3",
" └─TableReader 3333.33 root data:TableRangeScan",
" └─TableRangeScan 3333.33 cop[tikv] table:t4 range:(\"test\",+inf], keep order:false, stats:pseudo"
],
"Result": null
},
{
"SQL": "select * from v3 where a = 'test' -- the condition shouldn't be downcast through both side and go get point",
"Plan": [
"HashAgg 6408.00 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 8010.00 root ",
" ├─Point_Get 1.00 root table:t5, clustered index:PRIMARY(a) ",
" └─Projection 8000.00 root cast(test.t6.a, char(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#3",
" └─TableReader 8000.00 root data:Selection",
" └─Selection 8000.00 cop[tikv] eq(cast(test.t6.a, char(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin), \"test\")",
" └─TableFullScan 10000.00 cop[tikv] table:t6 keep order:false, stats:pseudo"
],
"Result": null
},
{
"SQL": "select * from v3 where a > 'test' -- the condition shouldn't be downcast through both side and go get point too",
"Plan": [
"HashAgg 9066.67 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 11333.33 root ",
" ├─TableReader 3333.33 root data:TableRangeScan",
" │ └─TableRangeScan 3333.33 cop[tikv] table:t5 range:(\"test\",+inf], keep order:false, stats:pseudo",
" └─Projection 8000.00 root cast(test.t6.a, char(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#3",
" └─TableReader 8000.00 root data:Selection",
" └─Selection 8000.00 cop[tikv] gt(cast(test.t6.a, char(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin), \"test\")",
" └─TableFullScan 10000.00 cop[tikv] table:t6 keep order:false, stats:pseudo"
],
"Result": null
},
{
"SQL": "select * from v4 where a = 'test' -- diff column union may have precision loss couldn't downcast the condition to get the range",
"Plan": [
"HashAgg 6408.00 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 8010.00 root ",
" ├─Point_Get 1.00 root table:t7, clustered index:PRIMARY(a) ",
" └─Projection 8000.00 root cast(test.t8.a, varchar(100) BINARY CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#3",
" └─TableReader 8000.00 root data:Selection",
" └─Selection 8000.00 cop[tikv] eq(cast(test.t8.a, varchar(100) BINARY CHARACTER SET utf8mb4 COLLATE utf8mb4_bin), \"test\")",
" └─TableFullScan 10000.00 cop[tikv] table:t8 keep order:false, stats:pseudo"
],
"Result": null
},
{
"SQL": "select * from v4 where a > 'test' -- diff column union may have precision loss couldn't downcast the condition to get the range",
"Plan": [
"HashAgg 9066.67 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 11333.33 root ",
" ├─TableReader 3333.33 root data:TableRangeScan",
" │ └─TableRangeScan 3333.33 cop[tikv] table:t7 range:(\"test\",+inf], keep order:false, stats:pseudo",
" └─Projection 8000.00 root cast(test.t8.a, varchar(100) BINARY CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#3",
" └─TableReader 8000.00 root data:Selection",
" └─Selection 8000.00 cop[tikv] gt(cast(test.t8.a, varchar(100) BINARY CHARACTER SET utf8mb4 COLLATE utf8mb4_bin), \"test\")",
" └─TableFullScan 10000.00 cop[tikv] table:t8 keep order:false, stats:pseudo"
],
"Result": null
}
]
}
]
5 changes: 4 additions & 1 deletion planner/core/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,12 @@ func (ds *DataSource) DeriveStats(_ []*property.StatsInfo, _ *expression.Schema,
debugtrace.EnterContextCommon(ds.SCtx())
defer debugtrace.LeaveContextCommon(ds.SCtx())
}
// PushDownNot here can convert query 'not (a != 1)' to 'a = 1'.
// two preprocess here.
// 1: PushDownNot here can convert query 'not (a != 1)' to 'a = 1'.
// 2: EliminateNoPrecisionCast here can convert query 'cast(c<int> as bigint) = 1' to 'c = 1' to leverage access range.
for i, expr := range ds.pushedDownConds {
ds.pushedDownConds[i] = expression.PushDownNot(ds.SCtx(), expr)
ds.pushedDownConds[i] = expression.EliminateNoPrecisionLossCast(ds.SCtx(), ds.pushedDownConds[i])
}
for _, path := range ds.possibleAccessPaths {
if path.IsTablePath() {
Expand Down

0 comments on commit 28a9c7f

Please sign in to comment.