Skip to content

Commit

Permalink
cherry pick #19620 to release-4.0 (#21019)
Browse files Browse the repository at this point in the history
Signed-off-by: ti-srebot <ti-srebot@pingcap.com>

Co-authored-by: Kenan Yao <cauchy1992@gmail.com>
  • Loading branch information
ti-srebot and eurekaka authored Nov 20, 2020
1 parent 5f07350 commit cc61a9f
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 48 deletions.
24 changes: 11 additions & 13 deletions cmd/explaintest/r/explain_easy.result
Original file line number Diff line number Diff line change
Expand Up @@ -231,20 +231,18 @@ StreamAgg_12 1.00 root funcs:sum(Column#10)->Column#8
└─TableFullScan_14 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
explain select 1 in (select c2 from t2) from t1;
id estRows task access object operator info
HashJoin_7 10000.00 root CARTESIAN left outer semi join
├─TableReader_14(Build) 10.00 root data:Selection_13
│ └─Selection_13 10.00 cop[tikv] eq(1, test.t2.c2)
│ └─TableFullScan_12 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo
HashJoin_7 10000.00 root CARTESIAN left outer semi join, other cond:eq(1, test.t2.c2)
├─TableReader_13(Build) 10000.00 root data:TableFullScan_12
│ └─TableFullScan_12 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo
└─TableReader_9(Probe) 10000.00 root data:TableFullScan_8
└─TableFullScan_8 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
explain select sum(6 in (select c2 from t2)) from t1;
id estRows task access object operator info
StreamAgg_12 1.00 root funcs:sum(Column#10)->Column#8
└─Projection_22 10000.00 root cast(Column#7, decimal(65,0) BINARY)->Column#10
└─HashJoin_21 10000.00 root CARTESIAN left outer semi join
├─TableReader_20(Build) 10.00 root data:Selection_19
│ └─Selection_19 10.00 cop[tikv] eq(6, test.t2.c2)
│ └─TableFullScan_18 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo
└─Projection_21 10000.00 root cast(Column#7, decimal(65,0) BINARY)->Column#10
└─HashJoin_20 10000.00 root CARTESIAN left outer semi join, other cond:eq(6, test.t2.c2)
├─TableReader_19(Build) 10000.00 root data:TableFullScan_18
│ └─TableFullScan_18 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo
└─TableReader_15(Probe) 10000.00 root data:TableFullScan_14
└─TableFullScan_14 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
explain format="dot" select sum(t1.c1 in (select c1 from t2)) from t1;
Expand Down Expand Up @@ -285,22 +283,22 @@ node [style=filled, color=lightgrey]
color=black
label = "root"
"HashJoin_7" -> "TableReader_9"
"HashJoin_7" -> "TableReader_14"
"HashJoin_7" -> "TableReader_13"
}
subgraph cluster8{
node [style=filled, color=lightgrey]
color=black
label = "cop"
"TableFullScan_8"
}
subgraph cluster13{
subgraph cluster12{
node [style=filled, color=lightgrey]
color=black
label = "cop"
"Selection_13" -> "TableFullScan_12"
"TableFullScan_12"
}
"TableReader_9" -> "TableFullScan_8"
"TableReader_14" -> "Selection_13"
"TableReader_13" -> "TableFullScan_12"
}

drop table if exists t1, t2, t3, t4;
Expand Down
15 changes: 7 additions & 8 deletions cmd/explaintest/r/explain_easy_stats.result
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,9 @@ Limit_10 1.00 root offset:0, count:1
set @@session.tidb_opt_insubq_to_join_and_agg=0;
explain select 1 in (select c2 from t2) from t1;
id estRows task access object operator info
HashJoin_7 1999.00 root CARTESIAN left outer semi join
├─TableReader_14(Build) 0.00 root data:Selection_13
│ └─Selection_13 0.00 cop[tikv] eq(1, test.t2.c2)
│ └─TableFullScan_12 1985.00 cop[tikv] table:t2 keep order:false
HashJoin_7 1999.00 root CARTESIAN left outer semi join, other cond:eq(1, test.t2.c2)
├─TableReader_13(Build) 1985.00 root data:TableFullScan_12
│ └─TableFullScan_12 1985.00 cop[tikv] table:t2 keep order:false
└─TableReader_9(Probe) 1999.00 root data:TableFullScan_8
└─TableFullScan_8 1999.00 cop[tikv] table:t1 keep order:false
explain format="dot" select 1 in (select c2 from t2) from t1;
Expand All @@ -135,22 +134,22 @@ node [style=filled, color=lightgrey]
color=black
label = "root"
"HashJoin_7" -> "TableReader_9"
"HashJoin_7" -> "TableReader_14"
"HashJoin_7" -> "TableReader_13"
}
subgraph cluster8{
node [style=filled, color=lightgrey]
color=black
label = "cop"
"TableFullScan_8"
}
subgraph cluster13{
subgraph cluster12{
node [style=filled, color=lightgrey]
color=black
label = "cop"
"Selection_13" -> "TableFullScan_12"
"TableFullScan_12"
}
"TableReader_9" -> "TableFullScan_8"
"TableReader_14" -> "Selection_13"
"TableReader_13" -> "TableFullScan_12"
}

explain select * from index_prune WHERE a = 1010010404050976781 AND b = 26467085526790 LIMIT 1;
Expand Down
26 changes: 26 additions & 0 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2311,6 +2311,32 @@ func (s *testSuiteP2) TestRow(c *C) {
result.Check(testkit.Rows("0"))
result = tk.MustQuery("select (select 1)")
result.Check(testkit.Rows("1"))

tk.MustExec("drop table if exists t1")
tk.MustExec("create table t1 (a int, b int)")
tk.MustExec("insert t1 values (1,2),(1,null)")
tk.MustExec("drop table if exists t2")
tk.MustExec("create table t2 (c int, d int)")
tk.MustExec("insert t2 values (0,0)")

tk.MustQuery("select * from t2 where (1,2) in (select * from t1)").Check(testkit.Rows("0 0"))
tk.MustQuery("select * from t2 where (1,2) not in (select * from t1)").Check(testkit.Rows())
tk.MustQuery("select * from t2 where (1,1) not in (select * from t1)").Check(testkit.Rows())
tk.MustQuery("select * from t2 where (1,null) in (select * from t1)").Check(testkit.Rows())
tk.MustQuery("select * from t2 where (null,null) in (select * from t1)").Check(testkit.Rows())

tk.MustExec("delete from t1 where a=1 and b=2")
tk.MustQuery("select (1,1) in (select * from t2) from t1").Check(testkit.Rows("0"))
tk.MustQuery("select (1,1) not in (select * from t2) from t1").Check(testkit.Rows("1"))
tk.MustQuery("select (1,1) in (select 1,1 from t2) from t1").Check(testkit.Rows("1"))
tk.MustQuery("select (1,1) not in (select 1,1 from t2) from t1").Check(testkit.Rows("0"))

// MySQL 5.7 returns 1 for these 2 queries, which is wrong.
tk.MustQuery("select (1,null) not in (select 1,1 from t2) from t1").Check(testkit.Rows("<nil>"))
tk.MustQuery("select (t1.a,null) not in (select 1,1 from t2) from t1").Check(testkit.Rows("<nil>"))

tk.MustQuery("select (1,null) in (select * from t1)").Check(testkit.Rows("<nil>"))
tk.MustQuery("select (1,null) not in (select * from t1)").Check(testkit.Rows("<nil>"))
}

func (s *testSuiteP2) TestColumnName(c *C) {
Expand Down
59 changes: 59 additions & 0 deletions executor/join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,65 @@ func (s *testSuiteJoin3) TestSubquery(c *C) {
tk.MustExec("insert into t2 values(1)")
tk.MustQuery("select * from t1 where a in (select a from t2)").Check(testkit.Rows("1"))

tk.MustExec("insert into t2 value(null)")
tk.MustQuery("select * from t1 where 1 in (select b from t2)").Check(testkit.Rows("1"))
tk.MustQuery("select * from t1 where 1 not in (select b from t2)").Check(testkit.Rows())
tk.MustQuery("select * from t1 where 2 not in (select b from t2)").Check(testkit.Rows())
tk.MustQuery("select * from t1 where 2 in (select b from t2)").Check(testkit.Rows())
tk.MustQuery("select 1 in (select b from t2) from t1").Check(testkit.Rows("1"))
tk.MustQuery("select 1 in (select 1 from t2) from t1").Check(testkit.Rows("1"))
tk.MustQuery("select 1 not in (select b from t2) from t1").Check(testkit.Rows("0"))
tk.MustQuery("select 1 not in (select 1 from t2) from t1").Check(testkit.Rows("0"))

tk.MustExec("delete from t2 where b=1")
tk.MustQuery("select 1 in (select b from t2) from t1").Check(testkit.Rows("<nil>"))
tk.MustQuery("select 1 not in (select b from t2) from t1").Check(testkit.Rows("<nil>"))
tk.MustQuery("select 1 not in (select 1 from t2) from t1").Check(testkit.Rows("0"))
tk.MustQuery("select 1 in (select 1 from t2) from t1").Check(testkit.Rows("1"))
tk.MustQuery("select 1 not in (select null from t1) from t2").Check(testkit.Rows("<nil>"))
tk.MustQuery("select 1 in (select null from t1) from t2").Check(testkit.Rows("<nil>"))

tk.MustExec("drop table if exists s")
tk.MustExec("create table s(a int not null, b int)")
tk.MustExec("set sql_mode = ''")
tk.MustQuery("select (2,0) in (select s.a, min(s.b) from s) as f").Check(testkit.Rows("<nil>"))
tk.MustQuery("select (2,0) not in (select s.a, min(s.b) from s) as f").Check(testkit.Rows("<nil>"))
tk.MustQuery("select (2,0) = any (select s.a, min(s.b) from s) as f").Check(testkit.Rows("<nil>"))
tk.MustQuery("select (2,0) != all (select s.a, min(s.b) from s) as f").Check(testkit.Rows("<nil>"))
tk.MustQuery("select (2,0) in (select s.b, min(s.b) from s) as f").Check(testkit.Rows("<nil>"))
tk.MustQuery("select (2,0) not in (select s.b, min(s.b) from s) as f").Check(testkit.Rows("<nil>"))
tk.MustQuery("select (2,0) = any (select s.b, min(s.b) from s) as f").Check(testkit.Rows("<nil>"))
tk.MustQuery("select (2,0) != all (select s.b, min(s.b) from s) as f").Check(testkit.Rows("<nil>"))
tk.MustExec("insert into s values(1,null)")
tk.MustQuery("select 1 in (select b from s)").Check(testkit.Rows("<nil>"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int)")
tk.MustExec("insert into t values(1),(null)")
tk.MustQuery("select a not in (select 1) from t").Sort().Check(testkit.Rows(
"0",
"<nil>",
))
tk.MustQuery("select 1 not in (select null from t t1) from t").Check(testkit.Rows(
"<nil>",
"<nil>",
))
tk.MustQuery("select 1 in (select null from t t1) from t").Check(testkit.Rows(
"<nil>",
"<nil>",
))
tk.MustQuery("select a in (select 0) xx from (select null as a) x").Check(testkit.Rows("<nil>"))

tk.MustExec("drop table t")
tk.MustExec("create table t(a int, b int)")
tk.MustExec("insert into t values(1,null),(null, null),(null, 2)")
tk.MustQuery("select * from t t1 where (2 in (select a from t t2 where (t2.b=t1.b) is null))").Check(testkit.Rows())
tk.MustQuery("select (t2.a in (select t1.a from t t1)) is true from t t2").Sort().Check(testkit.Rows(
"0",
"0",
"1",
))

tk.MustExec("set @@tidb_hash_join_concurrency=5")
}

Expand Down
11 changes: 0 additions & 11 deletions expression/aggregation/base_func_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,3 @@ func (s *testBaseFuncSuite) TestClone(c *check.C) {
c.Assert(desc.Args[0], check.Equals, col)
c.Assert(desc.equal(s.ctx, cloned), check.IsFalse)
}

func (s *testBaseFuncSuite) TestMaxMin(c *check.C) {
col := &expression.Column{
UniqueID: 0,
RetType: types.NewFieldType(mysql.TypeLonglong),
}
col.RetType.Flag |= mysql.NotNullFlag
desc, err := newBaseFuncDesc(s.ctx, ast.AggFuncMax, []expression.Expression{col})
c.Assert(err, check.IsNil)
c.Assert(mysql.HasNotNullFlag(desc.RetTp.Flag), check.IsFalse)
}
41 changes: 41 additions & 0 deletions expression/aggregation/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"math"
"strconv"

"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/expression"
Expand Down Expand Up @@ -272,3 +273,43 @@ func (a *AggFuncDesc) evalNullValueInOuterJoin4BitOr(ctx sessionctx.Context, sch
}
return con.Value, true
}

// UpdateNotNullFlag4RetType checks if we should remove the NotNull flag for the return type of the agg.
func (a *AggFuncDesc) UpdateNotNullFlag4RetType(hasGroupBy, allAggsFirstRow bool) error {
var removeNotNull bool
switch a.Name {
case ast.AggFuncCount, ast.AggFuncApproxCountDistinct, ast.AggFuncApproxPercentile,
ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor,
ast.WindowFuncFirstValue, ast.WindowFuncLastValue, ast.WindowFuncNthValue, ast.WindowFuncRowNumber,
ast.WindowFuncRank, ast.WindowFuncDenseRank, ast.WindowFuncCumeDist, ast.WindowFuncNtile, ast.WindowFuncPercentRank,
ast.WindowFuncLead, ast.WindowFuncLag, ast.AggFuncJsonObjectAgg,
ast.AggFuncVarSamp, ast.AggFuncVarPop, ast.AggFuncStddevPop, ast.AggFuncStddevSamp:
removeNotNull = false
case ast.AggFuncSum, ast.AggFuncAvg, ast.AggFuncGroupConcat:
if !hasGroupBy {
removeNotNull = true
}
// `select max(a) from empty_tbl` returns `null`, while `select max(a) from empty_tbl group by b` returns empty.
case ast.AggFuncMax, ast.AggFuncMin:
if !hasGroupBy && a.RetTp.Tp != mysql.TypeBit {
removeNotNull = true
}
// `select distinct a from empty_tbl` returns empty
// `select a from empty_tbl group by b` returns empty
// `select a, max(a) from empty_tbl` returns `(null, null)`
// `select a, max(a) from empty_tbl group by b` returns empty
// `select a, count(a) from empty_tbl` returns `(null, 0)`
// `select a, count(a) from empty_tbl group by b` returns empty
case ast.AggFuncFirstRow:
if !allAggsFirstRow && !hasGroupBy {
removeNotNull = true
}
default:
return errors.Errorf("unsupported agg function: %s", a.Name)
}
if removeNotNull {
a.RetTp = a.RetTp.Clone()
a.RetTp.Flag &^= mysql.NotNullFlag
}
return nil
}
10 changes: 10 additions & 0 deletions expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,16 @@ func IsEQCondFromIn(expr Expression) bool {
return len(cols) > 0
}

// ExprNotNull checks if an expression is possible to be null.
func ExprNotNull(expr Expression) bool {
if c, ok := expr.(*Constant); ok {
return !c.Value.IsNull()
}
// For ScalarFunction, the result would not be correct until we support maintaining
// NotNull flag for it.
return mysql.HasNotNullFlag(expr.GetType().Flag)
}

// HandleOverflowOnSelection handles Overflow errors when evaluating selection filters.
// We should ignore overflow errors when evaluating selection conditions:
// INSERT INTO t VALUES ("999999999999999999");
Expand Down
51 changes: 41 additions & 10 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,16 +429,41 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
}

func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r expression.Expression, not bool) {
var condition expression.Expression
if rCol, ok := r.(*expression.Column); ok && (er.asScalar || not) {
// If both input columns of `!= all / = any` expression are not null, we can treat the expression
// as normal column equal condition.
if lCol, ok := l.(*expression.Column); !ok || !mysql.HasNotNullFlag(lCol.GetType().Flag) || !mysql.HasNotNullFlag(rCol.GetType().Flag) {
rColCopy := *rCol
rColCopy.InOperand = true
r = &rColCopy
if er.asScalar || not {
if expression.GetRowLen(r) == 1 {
rCol := r.(*expression.Column)
// If both input columns of `!= all / = any` expression are not null, we can treat the expression
// as normal column equal condition.
if !expression.ExprNotNull(l) || !expression.ExprNotNull(rCol) {
rColCopy := *rCol
rColCopy.InOperand = true
r = &rColCopy
}
} else {
rowFunc := r.(*expression.ScalarFunction)
rargs := rowFunc.GetArgs()
args := make([]expression.Expression, 0, len(rargs))
modified := false
for i, rarg := range rargs {
larg := expression.GetFuncArg(l, i)
if !expression.ExprNotNull(larg) || !expression.ExprNotNull(rarg) {
rCol := rarg.(*expression.Column)
rColCopy := *rCol
rColCopy.InOperand = true
rarg = &rColCopy
modified = true
}
args = append(args, rarg)
}
if modified {
r, er.err = er.newFunction(ast.RowFunc, args[0].GetType(), args...)
if er.err != nil {
return
}
}
}
}
var condition expression.Expression
condition, er.err = er.constructBinaryOpFunction(l, r, ast.EQ)
if er.err != nil {
return
Expand Down Expand Up @@ -811,15 +836,21 @@ func (er *expressionRewriter) handleInSubquery(ctx context.Context, v *ast.Patte
if v.Not || asScalar {
// If both input columns of `in` expression are not null, we can treat the expression
// as normal column equal condition instead.
if !mysql.HasNotNullFlag(lexpr.GetType().Flag) || !mysql.HasNotNullFlag(rCol.GetType().Flag) {
if !expression.ExprNotNull(lexpr) || !expression.ExprNotNull(rCol) {
rColCopy := *rCol
rColCopy.InOperand = true
rexpr = &rColCopy
}
}
} else {
args := make([]expression.Expression, 0, np.Schema().Len())
for _, col := range np.Schema().Columns {
for i, col := range np.Schema().Columns {
larg := expression.GetFuncArg(lexpr, i)
if !expression.ExprNotNull(larg) || !expression.ExprNotNull(col) {
rarg := *col
rarg.InOperand = true
col = &rarg
}
args = append(args, col)
}
rexpr, er.err = er.newFunction(ast.RowFunc, args[0].GetType(), args...)
Expand Down
Loading

0 comments on commit cc61a9f

Please sign in to comment.