From 7d8ba949c0a69e49a1b5dd9f4915b3e7fe37326e Mon Sep 17 00:00:00 2001 From: Kenan Yao Date: Mon, 25 Feb 2019 15:42:38 +0800 Subject: [PATCH] plan/executor: make semi joins null and empty aware (#9051) --- cmd/explaintest/r/explain_easy.result | 38 +-- cmd/explaintest/r/select.result | 32 ++ cmd/explaintest/t/select.test | 10 + executor/executor.go | 2 +- executor/index_lookup_join.go | 7 +- executor/join.go | 23 +- executor/join_test.go | 390 ++++++++++++++++++++++ executor/joiner.go | 143 ++++---- executor/merge_join.go | 10 +- executor/merge_join_test.go | 12 +- executor/union_scan.go | 2 +- expression/chunk_executor.go | 2 +- expression/column.go | 4 + expression/constant_propagation.go | 2 +- expression/expression.go | 45 ++- expression/util.go | 21 +- planner/core/cbo_test.go | 2 +- planner/core/expression_rewriter.go | 49 ++- planner/core/find_best_task.go | 2 +- planner/core/logical_plan_builder.go | 9 +- planner/core/rule_eliminate_projection.go | 4 +- planner/core/rule_predicate_push_down.go | 6 + 22 files changed, 687 insertions(+), 128 deletions(-) diff --git a/cmd/explaintest/r/explain_easy.result b/cmd/explaintest/r/explain_easy.result index 0d6b9da8c17f5..e1a7d53d1167b 100644 --- a/cmd/explaintest/r/explain_easy.result +++ b/cmd/explaintest/r/explain_easy.result @@ -184,12 +184,12 @@ set @@session.tidb_opt_insubquery_unfold = 0; explain select sum(t1.c1 in (select c1 from t2)) from t1; id count task operator info StreamAgg_12 1.00 root funcs:sum(col_0) -└─Projection_33 10000.00 root cast(5_aux_0) - └─MergeJoin_26 10000.00 root left outer semi join, left key:test.t1.c1, right key:test.t2.c1 - ├─TableReader_19 10000.00 root data:TableScan_18 - │ └─TableScan_18 10000.00 cop table:t1, range:[-inf,+inf], keep order:true, stats:pseudo - └─IndexReader_21 10000.00 root index:IndexScan_20 - └─IndexScan_20 10000.00 cop table:t2, index:c1, range:[NULL,+inf], keep order:true, stats:pseudo +└─Projection_21 10000.00 root cast(5_aux_0) + └─HashLeftJoin_18 10000.00 
root left outer semi join, inner:TableReader_17, other cond:eq(test.t1.c1, test.t2.c1) + ├─TableReader_20 10000.00 root data:TableScan_19 + │ └─TableScan_19 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo + └─TableReader_17 10000.00 root data:TableScan_16 + └─TableScan_16 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo explain select 1 in (select c2 from t2) from t1; id count task operator info Projection_6 10000.00 root 5_aux_0 @@ -217,25 +217,25 @@ subgraph cluster12{ node [style=filled, color=lightgrey] color=black label = "root" -"StreamAgg_12" -> "Projection_33" -"Projection_33" -> "MergeJoin_26" -"MergeJoin_26" -> "TableReader_19" -"MergeJoin_26" -> "IndexReader_21" +"StreamAgg_12" -> "Projection_21" +"Projection_21" -> "HashLeftJoin_18" +"HashLeftJoin_18" -> "TableReader_20" +"HashLeftJoin_18" -> "TableReader_17" } -subgraph cluster18{ +subgraph cluster19{ node [style=filled, color=lightgrey] color=black label = "cop" -"TableScan_18" +"TableScan_19" } -subgraph cluster20{ +subgraph cluster16{ node [style=filled, color=lightgrey] color=black label = "cop" -"IndexScan_20" +"TableScan_16" } -"TableReader_19" -> "TableScan_18" -"IndexReader_21" -> "IndexScan_20" +"TableReader_20" -> "TableScan_19" +"TableReader_17" -> "TableScan_16" } explain format="dot" select 1 in (select c2 from t2) from t1; @@ -272,7 +272,7 @@ create table t(a int primary key, b int, c int, index idx(b)); explain select t.c in (select count(*) from t s ignore index(idx), t t1 where s.a = t.a and s.a = t1.a) from t; id count task operator info Projection_11 10000.00 root 9_aux_0 -└─Apply_13 10000.00 root left outer semi join, inner:StreamAgg_20, equal:[eq(test.t.c, count(*))] +└─Apply_13 10000.00 root left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, count(*)) ├─TableReader_15 10000.00 root data:TableScan_14 │ └─TableScan_14 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo └─StreamAgg_20 1.00 root 
funcs:count(1) @@ -285,7 +285,7 @@ Projection_11 10000.00 root 9_aux_0 explain select t.c in (select count(*) from t s use index(idx), t t1 where s.b = t.a and s.a = t1.a) from t; id count task operator info Projection_11 10000.00 root 9_aux_0 -└─Apply_13 10000.00 root left outer semi join, inner:StreamAgg_20, equal:[eq(test.t.c, count(*))] +└─Apply_13 10000.00 root left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, count(*)) ├─TableReader_15 10000.00 root data:TableScan_14 │ └─TableScan_14 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo └─StreamAgg_20 1.00 root funcs:count(1) @@ -297,7 +297,7 @@ Projection_11 10000.00 root 9_aux_0 explain select t.c in (select count(*) from t s use index(idx), t t1 where s.b = t.a and s.c = t1.a) from t; id count task operator info Projection_11 10000.00 root 9_aux_0 -└─Apply_13 10000.00 root left outer semi join, inner:StreamAgg_20, equal:[eq(test.t.c, count(*))] +└─Apply_13 10000.00 root left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, count(*)) ├─TableReader_15 10000.00 root data:TableScan_14 │ └─TableScan_14 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo └─StreamAgg_20 1.00 root funcs:count(1) diff --git a/cmd/explaintest/r/select.result b/cmd/explaintest/r/select.result index ef3f1be3fd2b9..88976590bd186 100644 --- a/cmd/explaintest/r/select.result +++ b/cmd/explaintest/r/select.result @@ -354,3 +354,35 @@ Projection_9 10000.00 root or(and(and(le(col_count, 1), eq(t1.a, col_firstrow)), └─Projection_27 10000.00 root t2.a, t2.a, cast(isnull(t2.a)) └─TableReader_24 10000.00 root data:TableScan_23 └─TableScan_23 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo +drop table if exists t; +create table t(a int, b int); +drop table if exists s; +create table s(a varchar(20), b varchar(20)); +explain select a in (select a from s where s.b = t.b) from t; +id count task operator info +Projection_9 10000.00 root 6_aux_0 
+└─HashLeftJoin_10 10000.00 root left outer semi join, inner:Projection_14, equal:[eq(cast(test.t.b), cast(test.s.b))], other cond:eq(cast(test.t.a), cast(test.s.a)) + ├─Projection_11 10000.00 root test.t.a, test.t.b, cast(test.t.b) + │ └─TableReader_13 10000.00 root data:TableScan_12 + │ └─TableScan_12 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo + └─Projection_14 10000.00 root test.s.a, test.s.b, cast(test.s.b) + └─TableReader_16 10000.00 root data:TableScan_15 + └─TableScan_15 10000.00 cop table:s, range:[-inf,+inf], keep order:false, stats:pseudo +explain select a in (select a+b from t t2 where t2.b = t1.b) from t t1; +id count task operator info +Projection_7 10000.00 root 6_aux_0 +└─HashLeftJoin_8 10000.00 root left outer semi join, inner:TableReader_12, equal:[eq(t1.b, t2.b)], other cond:eq(t1.a, plus(t2.a, t2.b)) + ├─TableReader_10 10000.00 root data:TableScan_9 + │ └─TableScan_9 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo + └─TableReader_12 10000.00 root data:TableScan_11 + └─TableScan_11 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo +drop table t; +create table t(a int not null, b int); +explain select a in (select a from t t2 where t2.b = t1.b) from t t1; +id count task operator info +Projection_7 10000.00 root 6_aux_0 +└─HashLeftJoin_8 10000.00 root left outer semi join, inner:TableReader_12, equal:[eq(t1.b, t2.b) eq(t1.a, t2.a)] + ├─TableReader_10 10000.00 root data:TableScan_9 + │ └─TableScan_9 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo + └─TableReader_12 10000.00 root data:TableScan_11 + └─TableScan_11 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo diff --git a/cmd/explaintest/t/select.test b/cmd/explaintest/t/select.test index 68abcfe508f41..878db36e05693 100644 --- a/cmd/explaintest/t/select.test +++ b/cmd/explaintest/t/select.test @@ -171,3 +171,13 @@ drop table if exists t; create table t(a int, b int); 
explain select a != any (select a from t t2) from t t1; explain select a = all (select a from t t2) from t t1; + +drop table if exists t; +create table t(a int, b int); +drop table if exists s; +create table s(a varchar(20), b varchar(20)); +explain select a in (select a from s where s.b = t.b) from t; +explain select a in (select a+b from t t2 where t2.b = t1.b) from t t1; +drop table t; +create table t(a int not null, b int); +explain select a in (select a from t t2 where t2.b = t1.b) from t t1; diff --git a/executor/executor.go b/executor/executor.go index 00e4e692c5eb4..d0d11b6357280 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -917,7 +917,7 @@ func (e *SelectionExec) Next(ctx context.Context, chk *chunk.Chunk) error { func (e *SelectionExec) unBatchedNext(ctx context.Context, chk *chunk.Chunk) error { for { for ; e.inputRow != e.inputIter.End(); e.inputRow = e.inputIter.Next() { - selected, err := expression.EvalBool(e.ctx, e.filters, e.inputRow) + selected, _, err := expression.EvalBool(e.ctx, e.filters, e.inputRow) if err != nil { return errors.Trace(err) } diff --git a/executor/index_lookup_join.go b/executor/index_lookup_join.go index d614477387105..95190b7f8535f 100644 --- a/executor/index_lookup_join.go +++ b/executor/index_lookup_join.go @@ -94,6 +94,7 @@ type lookUpJoinTask struct { doneCh chan error cursor int hasMatch bool + hasNull bool memTracker *memory.Tracker // track memory usage. 
} @@ -234,18 +235,20 @@ func (e *IndexLookUpJoin) Next(ctx context.Context, chk *chunk.Chunk) error { outerRow := task.outerResult.GetRow(task.cursor) if e.innerIter.Current() != e.innerIter.End() { - matched, err := e.joiner.tryToMatch(outerRow, e.innerIter, chk) + matched, isNull, err := e.joiner.tryToMatch(outerRow, e.innerIter, chk) if err != nil { return errors.Trace(err) } task.hasMatch = task.hasMatch || matched + task.hasNull = task.hasNull || isNull } if e.innerIter.Current() == e.innerIter.End() { if !task.hasMatch { - e.joiner.onMissMatch(outerRow, chk) + e.joiner.onMissMatch(task.hasNull, outerRow, chk) } task.cursor++ task.hasMatch = false + task.hasNull = false } if chk.NumRows() == e.maxChunkSize { return nil diff --git a/executor/join.go b/executor/join.go index c539ae68a378b..0338a285d4e69 100644 --- a/executor/join.go +++ b/executor/join.go @@ -176,7 +176,6 @@ func (e *HashJoinExec) getJoinKeyFromChkRow(isOuterKey bool, row chunk.Row, keyB return true, keyBuf, nil } } - keyBuf = keyBuf[:0] keyBuf, err = codec.HashChunkRow(e.ctx.GetSessionVars().StmtCtx, keyBuf, row, allTypes, keyColIdx) if err != nil { @@ -408,13 +407,13 @@ func (e *HashJoinExec) joinMatchedOuterRow2Chunk(workerID uint, outerRow chunk.R return false, joinResult } if hasNull { - e.joiners[workerID].onMissMatch(outerRow, joinResult.chk) + e.joiners[workerID].onMissMatch(false, outerRow, joinResult.chk) return true, joinResult } e.hashTableValBufs[workerID] = e.hashTable.Get(joinKey, e.hashTableValBufs[workerID][:0]) innerPtrs := e.hashTableValBufs[workerID] if len(innerPtrs) == 0 { - e.joiners[workerID].onMissMatch(outerRow, joinResult.chk) + e.joiners[workerID].onMissMatch(false, outerRow, joinResult.chk) return true, joinResult } innerRows := make([]chunk.Row, 0, len(innerPtrs)) @@ -424,14 +423,15 @@ func (e *HashJoinExec) joinMatchedOuterRow2Chunk(workerID uint, outerRow chunk.R innerRows = append(innerRows, matchedInner) } iter := chunk.NewIterator4Slice(innerRows) - hasMatch := 
false + hasMatch, hasNull := false, false for iter.Begin(); iter.Current() != iter.End(); { - matched, err := e.joiners[workerID].tryToMatch(outerRow, iter, joinResult.chk) + matched, isNull, err := e.joiners[workerID].tryToMatch(outerRow, iter, joinResult.chk) if err != nil { joinResult.err = errors.Trace(err) return false, joinResult } hasMatch = hasMatch || matched + hasNull = hasNull || isNull if joinResult.chk.NumRows() == e.maxChunkSize { ok := true @@ -443,7 +443,7 @@ func (e *HashJoinExec) joinMatchedOuterRow2Chunk(workerID uint, outerRow chunk.R } } if !hasMatch { - e.joiners[workerID].onMissMatch(outerRow, joinResult.chk) + e.joiners[workerID].onMissMatch(hasNull, outerRow, joinResult.chk) } return true, joinResult } @@ -471,7 +471,7 @@ func (e *HashJoinExec) join2Chunk(workerID uint, outerChk *chunk.Chunk, joinResu } for i := range selected { if !selected[i] { // process unmatched outer rows - e.joiners[workerID].onMissMatch(outerChk.GetRow(i), joinResult.chk) + e.joiners[workerID].onMissMatch(false, outerChk.GetRow(i), joinResult.chk) } else { // process matched outer rows ok, joinResult = e.joinMatchedOuterRow2Chunk(workerID, outerChk.GetRow(i), joinResult) if !ok { @@ -609,6 +609,7 @@ type NestedLoopApplyExec struct { innerIter chunk.Iterator outerRow *chunk.Row hasMatch bool + hasNull bool memTracker *memory.Tracker // track memory usage. 
} @@ -666,7 +667,7 @@ func (e *NestedLoopApplyExec) fetchSelectedOuterRow(ctx context.Context, chk *ch if selected { return &outerRow, nil } else if e.outer { - e.joiner.onMissMatch(outerRow, chk) + e.joiner.onMissMatch(false, outerRow, chk) if chk.NumRows() == e.maxChunkSize { return nil, nil } @@ -714,13 +715,14 @@ func (e *NestedLoopApplyExec) Next(ctx context.Context, chk *chunk.Chunk) (err e for { if e.innerIter == nil || e.innerIter.Current() == e.innerIter.End() { if e.outerRow != nil && !e.hasMatch { - e.joiner.onMissMatch(*e.outerRow, chk) + e.joiner.onMissMatch(e.hasNull, *e.outerRow, chk) } e.outerRow, err = e.fetchSelectedOuterRow(ctx, chk) if e.outerRow == nil || err != nil { return errors.Trace(err) } e.hasMatch = false + e.hasNull = false for _, col := range e.outerSchema { *col.Data = e.outerRow.GetDatum(col.Index, col.RetType) @@ -733,8 +735,9 @@ func (e *NestedLoopApplyExec) Next(ctx context.Context, chk *chunk.Chunk) (err e e.innerIter.Begin() } - matched, err := e.joiner.tryToMatch(*e.outerRow, e.innerIter, chk) + matched, isNull, err := e.joiner.tryToMatch(*e.outerRow, e.innerIter, chk) e.hasMatch = e.hasMatch || matched + e.hasNull = e.hasNull || isNull if err != nil || chk.NumRows() == e.maxChunkSize { return errors.Trace(err) diff --git a/executor/join_test.go b/executor/join_test.go index f40263b99dd96..c104ffc2df86d 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -14,6 +14,7 @@ package executor_test import ( + "fmt" "time" . 
"github.com/pingcap/check" @@ -1016,3 +1017,392 @@ func (s *testSuite) TestJoinDifferentDecimals(c *C) { c.Assert(len(row), Equals, 3) rst.Check(testkit.Rows("1 1.000", "2 2.000", "3 3.000")) } + +func (s *testSuite) TestNullEmptyAwareSemiJoin(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, c int, index idx_a(a), index idb_b(b), index idx_c(c))") + tk.MustExec("insert into t values(null, 1, 0), (1, 2, 0)") + tests := []struct { + sql string + }{ + { + "a, b from t t1 where a not in (select b from t t2)", + }, + { + "a, b from t t1 where a not in (select b from t t2 where t1.b = t2.a)", + }, + { + "a, b from t t1 where a not in (select a from t t2)", + }, + { + "a, b from t t1 where a not in (select a from t t2 where t1.b = t2.b)", + }, + { + "a, b from t t1 where a != all (select b from t t2)", + }, + { + "a, b from t t1 where a != all (select b from t t2 where t1.b = t2.a)", + }, + { + "a, b from t t1 where a != all (select a from t t2)", + }, + { + "a, b from t t1 where a != all (select a from t t2 where t1.b = t2.b)", + }, + { + "a, b from t t1 where not exists (select * from t t2 where t1.a = t2.b)", + }, + { + "a, b from t t1 where not exists (select * from t t2 where t1.a = t2.a)", + }, + } + results := []struct { + result [][]interface{} + }{ + { + testkit.Rows(), + }, + { + testkit.Rows("1 2"), + }, + { + testkit.Rows(), + }, + { + testkit.Rows(), + }, + { + testkit.Rows(), + }, + { + testkit.Rows("1 2"), + }, + { + testkit.Rows(), + }, + { + testkit.Rows(), + }, + { + testkit.Rows(" 1"), + }, + { + testkit.Rows(" 1"), + }, + } + hints := [3]string{"/*+ TIDB_HJ(t1, t2) */", "/*+ TIDB_INLJ(t1, t2) */", "/*+ TIDB_SMJ(t1, t2) */"} + for i, tt := range tests { + for _, hint := range hints { + sql := fmt.Sprintf("select %s %s", hint, tt.sql) + result := tk.MustQuery(sql) + result.Check(results[i].result) + } + } + + tk.MustExec("truncate table t") + 
tk.MustExec("insert into t values(1, null, 0), (2, 1, 0)") + results = []struct { + result [][]interface{} + }{ + { + testkit.Rows(), + }, + { + testkit.Rows("1 "), + }, + { + testkit.Rows(), + }, + { + testkit.Rows("1 "), + }, + { + testkit.Rows(), + }, + { + testkit.Rows("1 "), + }, + { + testkit.Rows(), + }, + { + testkit.Rows("1 "), + }, + { + testkit.Rows("2 1"), + }, + { + testkit.Rows(), + }, + } + for i, tt := range tests { + for _, hint := range hints { + sql := fmt.Sprintf("select %s %s", hint, tt.sql) + result := tk.MustQuery(sql) + result.Check(results[i].result) + } + } + + tk.MustExec("truncate table t") + tk.MustExec("insert into t values(1, null, 0), (2, 1, 0), (null, 2, 0)") + results = []struct { + result [][]interface{} + }{ + { + testkit.Rows(), + }, + { + testkit.Rows("1 "), + }, + { + testkit.Rows(), + }, + { + testkit.Rows("1 "), + }, + { + testkit.Rows(), + }, + { + testkit.Rows("1 "), + }, + { + testkit.Rows(), + }, + { + testkit.Rows("1 "), + }, + { + testkit.Rows(" 2"), + }, + { + testkit.Rows(" 2"), + }, + } + for i, tt := range tests { + for _, hint := range hints { + sql := fmt.Sprintf("select %s %s", hint, tt.sql) + result := tk.MustQuery(sql) + result.Check(results[i].result) + } + } + + tk.MustExec("truncate table t") + tk.MustExec("insert into t values(1, null, 0), (2, null, 0)") + tests = []struct { + sql string + }{ + { + "a, b from t t1 where b not in (select a from t t2)", + }, + } + results = []struct { + result [][]interface{} + }{ + { + testkit.Rows(), + }, + } + for i, tt := range tests { + for _, hint := range hints { + sql := fmt.Sprintf("select %s %s", hint, tt.sql) + result := tk.MustQuery(sql) + result.Check(results[i].result) + } + } + + tk.MustExec("truncate table t") + tk.MustExec("insert into t values(null, 1, 1), (2, 2, 2), (3, null, 3), (4, 4, 3)") + tests = []struct { + sql string + }{ + { + "a, b, a not in (select b from t t2) from t t1 order by a", + }, + { + "a, c, a not in (select c from t t2) from t t1 
order by a", + }, + { + "a, b, a in (select b from t t2) from t t1 order by a", + }, + { + "a, c, a in (select c from t t2) from t t1 order by a", + }, + } + results = []struct { + result [][]interface{} + }{ + { + testkit.Rows( + " 1 ", + "2 2 0", + "3 ", + "4 4 0", + ), + }, + { + testkit.Rows( + " 1 ", + "2 2 0", + "3 3 0", + "4 3 1", + ), + }, + { + testkit.Rows( + " 1 ", + "2 2 1", + "3 ", + "4 4 1", + ), + }, + { + testkit.Rows( + " 1 ", + "2 2 1", + "3 3 1", + "4 3 0", + ), + }, + } + for i, tt := range tests { + for _, hint := range hints { + sql := fmt.Sprintf("select %s %s", hint, tt.sql) + result := tk.MustQuery(sql) + result.Check(results[i].result) + } + } + + tk.MustExec("drop table if exists s") + tk.MustExec("create table s(a int, b int)") + tk.MustExec("insert into s values(1, 2)") + tk.MustExec("truncate table t") + tk.MustExec("insert into t values(null, null, 0)") + tests = []struct { + sql string + }{ + { + "a in (select b from t t2 where t2.a = t1.b) from s t1", + }, + { + "a in (select b from s t2 where t2.a = t1.b) from t t1", + }, + } + results = []struct { + result [][]interface{} + }{ + { + testkit.Rows("0"), + }, + { + testkit.Rows("0"), + }, + } + for i, tt := range tests { + for _, hint := range hints { + sql := fmt.Sprintf("select %s %s", hint, tt.sql) + result := tk.MustQuery(sql) + result.Check(results[i].result) + } + } + + tk.MustExec("truncate table s") + tk.MustExec("insert into s values(2, 2)") + tk.MustExec("truncate table t") + tk.MustExec("insert into t values(null, 1, 0)") + tests = []struct { + sql string + }{ + { + "a in (select a from s t2 where t2.b = t1.b) from t t1", + }, + { + "a in (select a from s t2 where t2.b < t1.b) from t t1", + }, + } + results = []struct { + result [][]interface{} + }{ + { + testkit.Rows("0"), + }, + { + testkit.Rows("0"), + }, + } + for i, tt := range tests { + for _, hint := range hints { + sql := fmt.Sprintf("select %s %s", hint, tt.sql) + result := tk.MustQuery(sql) + 
result.Check(results[i].result) + } + } + + tk.MustExec("truncate table s") + tk.MustExec("insert into s values(null, 2)") + tk.MustExec("truncate table t") + tk.MustExec("insert into t values(1, 1, 0)") + tests = []struct { + sql string + }{ + { + "a in (select a from s t2 where t2.b = t1.b) from t t1", + }, + { + "b in (select a from s t2) from t t1", + }, + { + "* from t t1 where a not in (select a from s t2 where t2.b = t1.b)", + }, + { + "* from t t1 where a not in (select a from s t2)", + }, + { + "* from s t1 where a not in (select a from t t2)", + }, + } + results = []struct { + result [][]interface{} + }{ + { + testkit.Rows("0"), + }, + { + testkit.Rows(""), + }, + { + testkit.Rows("1 1 0"), + }, + { + testkit.Rows(), + }, + { + testkit.Rows(), + }, + } + for i, tt := range tests { + for _, hint := range hints { + sql := fmt.Sprintf("select %s %s", hint, tt.sql) + result := tk.MustQuery(sql) + result.Check(results[i].result) + } + } +} + +func (s *testSuite) TestScalarFuncNullSemiJoin(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int)") + tk.MustExec("insert into t values(null, 1), (1, 2)") + tk.MustExec("drop table if exists s") + tk.MustExec("create table s(a varchar(20), b varchar(20))") + tk.MustExec("insert into s values(null, '1')") + tk.MustQuery("select a in (select a from s) from t").Check(testkit.Rows("", "")) + tk.MustExec("drop table s") + tk.MustExec("create table s(a int, b int)") + tk.MustExec("insert into s values(null, 1)") + tk.MustQuery("select a in (select a+b from s) from t").Check(testkit.Rows("", "")) +} diff --git a/executor/joiner.go b/executor/joiner.go index 5423acdca86f8..2325009cb8b1a 100644 --- a/executor/joiner.go +++ b/executor/joiner.go @@ -35,14 +35,15 @@ var ( // joiner is used to generate join results according to the join type. 
// A typical instruction flow is: // -// hasMatch := false +// hasMatch, hasNull := false, false // for innerIter.Current() != innerIter.End() { -// matched, err := j.tryToMatch(outer, innerIter, chk) +// matched, isNull, err := j.tryToMatch(outer, innerIter, chk) // // handle err // hasMatch = hasMatch || matched +// hasNull = hasNull || isNull // } // if !hasMatch { -// j.onMissMatch(outer) +// j.onMissMatch(hasNull, outer, chk) // } // // NOTE: This interface is **not** thread-safe. @@ -51,11 +52,15 @@ type joiner interface { // 'inners.Len != 0' but all the joined rows are filtered, the outer row is // considered unmatched. Otherwise, the outer row is matched and some joined // rows are appended to `chk`. The size of `chk` is limited to MaxChunkSize. + // Note that when the outer row is considered unmatched, we need to differentiate + // whether the join conditions return null or false, because that matters for + // AntiSemiJoin/LeftOuterSemiJoin/AntiLeftOuterSemiJoin, and the result is reflected + // by the second return value; for other join types, we always return false. // // NOTE: Callers need to call this function multiple times to consume all // the inner rows for an outer row, and dicide whether the outer row can be // matched with at lease one inner row. - tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (bool, error) + tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (bool, bool, error) // onMissMatch operates on the unmatched outer row according to the join // type. An outer row can be considered miss matched if: @@ -76,7 +81,13 @@ type joiner interface { // 6. 'RightOuterJoin': concats the unmatched outer row with a row of NULLs // and appends it to the result buffer. // 7. 'InnerJoin': ignores the unmatched outer row. 
- onMissMatch(outer chunk.Row, chk *chunk.Chunk) + // + // Note that, for LeftOuterSemiJoin, AntiSemiJoin and AntiLeftOuterSemiJoin, + // we need to know the reason of outer row being treated as unmatched: + // whether the join condition returns false, or returns null, because + // it decides if this outer row should be output, hence we have a `hasNull` + // parameter passed to `onMissMatch`. + onMissMatch(hasNull bool, outer chunk.Row, chk *chunk.Chunk) } func newJoiner(ctx sessionctx.Context, joinType plannercore.JoinType, @@ -176,34 +187,36 @@ type semiJoiner struct { baseJoiner } -func (j *semiJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, err error) { +func (j *semiJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, hasNull bool, err error) { if inners.Len() == 0 { - return false, nil + return false, false, nil } if len(j.conditions) == 0 { chk.AppendPartialRow(0, outer) inners.ReachEnd() - return true, nil + return true, false, nil } for inner := inners.Current(); inner != inners.End(); inner = inners.Next() { j.makeShallowJoinRow(j.outerIsRight, inner, outer) - matched, err = expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) + // For SemiJoin, we can safely treat null result of join conditions as false, + // so we ignore the nullness returned by EvalBool here. 
+ matched, _, err = expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) if err != nil { - return false, errors.Trace(err) + return false, false, errors.Trace(err) } if matched { chk.AppendPartialRow(0, outer) inners.ReachEnd() - return true, nil + return true, false, nil } } - return false, nil + return false, false, nil } -func (j *semiJoiner) onMissMatch(outer chunk.Row, chk *chunk.Chunk) { +func (j *semiJoiner) onMissMatch(_ bool, outer chunk.Row, chk *chunk.Chunk) { } type antiSemiJoiner struct { @@ -211,33 +224,36 @@ type antiSemiJoiner struct { } // tryToMatch implements joiner interface. -func (j *antiSemiJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, err error) { +func (j *antiSemiJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, hasNull bool, err error) { if inners.Len() == 0 { - return false, nil + return false, false, nil } if len(j.conditions) == 0 { inners.ReachEnd() - return true, nil + return true, false, nil } for inner := inners.Current(); inner != inners.End(); inner = inners.Next() { j.makeShallowJoinRow(j.outerIsRight, inner, outer) - matched, err = expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) + matched, isNull, err := expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) if err != nil { - return false, errors.Trace(err) + return false, false, errors.Trace(err) } if matched { inners.ReachEnd() - return true, nil + return true, false, nil } + hasNull = hasNull || isNull } - return false, nil + return false, hasNull, nil } -func (j *antiSemiJoiner) onMissMatch(outer chunk.Row, chk *chunk.Chunk) { - chk.AppendRow(outer) +func (j *antiSemiJoiner) onMissMatch(hasNull bool, outer chunk.Row, chk *chunk.Chunk) { + if !hasNull { + chk.AppendRow(outer) + } } type leftOuterSemiJoiner struct { @@ -245,31 +261,32 @@ type leftOuterSemiJoiner struct { } // tryToMatch implements joiner interface. 
-func (j *leftOuterSemiJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, err error) { +func (j *leftOuterSemiJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, hasNull bool, err error) { if inners.Len() == 0 { - return false, nil + return false, false, nil } if len(j.conditions) == 0 { j.onMatch(outer, chk) inners.ReachEnd() - return true, nil + return true, false, nil } for inner := inners.Current(); inner != inners.End(); inner = inners.Next() { j.makeShallowJoinRow(false, inner, outer) - matched, err = expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) + matched, isNull, err := expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) if err != nil { - return false, errors.Trace(err) + return false, false, errors.Trace(err) } if matched { j.onMatch(outer, chk) inners.ReachEnd() - return true, nil + return true, false, nil } + hasNull = hasNull || isNull } - return false, nil + return false, hasNull, nil } func (j *leftOuterSemiJoiner) onMatch(outer chunk.Row, chk *chunk.Chunk) { @@ -277,9 +294,13 @@ func (j *leftOuterSemiJoiner) onMatch(outer chunk.Row, chk *chunk.Chunk) { chk.AppendInt64(outer.Len(), 1) } -func (j *leftOuterSemiJoiner) onMissMatch(outer chunk.Row, chk *chunk.Chunk) { +func (j *leftOuterSemiJoiner) onMissMatch(hasNull bool, outer chunk.Row, chk *chunk.Chunk) { chk.AppendPartialRow(0, outer) - chk.AppendInt64(outer.Len(), 0) + if hasNull { + chk.AppendNull(outer.Len()) + } else { + chk.AppendInt64(outer.Len(), 0) + } } type antiLeftOuterSemiJoiner struct { @@ -287,31 +308,32 @@ type antiLeftOuterSemiJoiner struct { } // tryToMatch implements joiner interface. 
-func (j *antiLeftOuterSemiJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, err error) { +func (j *antiLeftOuterSemiJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, hasNull bool, err error) { if inners.Len() == 0 { - return false, nil + return false, false, nil } if len(j.conditions) == 0 { j.onMatch(outer, chk) inners.ReachEnd() - return true, nil + return true, false, nil } for inner := inners.Current(); inner != inners.End(); inner = inners.Next() { j.makeShallowJoinRow(false, inner, outer) - matched, err := expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) + matched, isNull, err := expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) if err != nil { - return false, errors.Trace(err) + return false, false, errors.Trace(err) } if matched { j.onMatch(outer, chk) inners.ReachEnd() - return true, nil + return true, false, nil } + hasNull = hasNull || isNull } - return false, nil + return false, hasNull, nil } func (j *antiLeftOuterSemiJoiner) onMatch(outer chunk.Row, chk *chunk.Chunk) { @@ -319,9 +341,13 @@ func (j *antiLeftOuterSemiJoiner) onMatch(outer chunk.Row, chk *chunk.Chunk) { chk.AppendInt64(outer.Len(), 0) } -func (j *antiLeftOuterSemiJoiner) onMissMatch(outer chunk.Row, chk *chunk.Chunk) { +func (j *antiLeftOuterSemiJoiner) onMissMatch(hasNull bool, outer chunk.Row, chk *chunk.Chunk) { chk.AppendPartialRow(0, outer) - chk.AppendInt64(outer.Len(), 1) + if hasNull { + chk.AppendNull(outer.Len()) + } else { + chk.AppendInt64(outer.Len(), 1) + } } type leftOuterJoiner struct { @@ -329,9 +355,9 @@ type leftOuterJoiner struct { } // tryToMatch implements joiner interface. 
-func (j *leftOuterJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (bool, error) { +func (j *leftOuterJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (bool, bool, error) { if inners.Len() == 0 { - return false, nil + return false, false, nil } j.chk.Reset() chkForJoin := j.chk @@ -345,18 +371,18 @@ func (j *leftOuterJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk inners.Next() } if len(j.conditions) == 0 { - return true, nil + return true, false, nil } // reach here, chkForJoin is j.chk matched, err := j.filter(chkForJoin, chk, outer.Len()) if err != nil { - return false, errors.Trace(err) + return false, false, errors.Trace(err) } - return matched, nil + return matched, false, nil } -func (j *leftOuterJoiner) onMissMatch(outer chunk.Row, chk *chunk.Chunk) { +func (j *leftOuterJoiner) onMissMatch(_ bool, outer chunk.Row, chk *chunk.Chunk) { chk.AppendPartialRow(0, outer) chk.AppendPartialRow(outer.Len(), j.defaultInner) } @@ -366,9 +392,9 @@ type rightOuterJoiner struct { } // tryToMatch implements joiner interface. 
-func (j *rightOuterJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (bool, error) { +func (j *rightOuterJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (bool, bool, error) { if inners.Len() == 0 { - return false, nil + return false, false, nil } j.chk.Reset() @@ -383,17 +409,17 @@ func (j *rightOuterJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, ch inners.Next() } if len(j.conditions) == 0 { - return true, nil + return true, false, nil } matched, err := j.filter(chkForJoin, chk, outer.Len()) if err != nil { - return false, errors.Trace(err) + return false, false, errors.Trace(err) } - return matched, nil + return matched, false, nil } -func (j *rightOuterJoiner) onMissMatch(outer chunk.Row, chk *chunk.Chunk) { +func (j *rightOuterJoiner) onMissMatch(_ bool, outer chunk.Row, chk *chunk.Chunk) { chk.AppendPartialRow(0, j.defaultInner) chk.AppendPartialRow(j.defaultInner.Len(), outer) } @@ -403,9 +429,9 @@ type innerJoiner struct { } // tryToMatch implements joiner interface. 
-func (j *innerJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (bool, error) { +func (j *innerJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (bool, bool, error) { if inners.Len() == 0 { - return false, nil + return false, false, nil } j.chk.Reset() chkForJoin := j.chk @@ -421,17 +447,16 @@ func (j *innerJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *ch } } if len(j.conditions) == 0 { - return true, nil + return true, false, nil } // reach here, chkForJoin is j.chk matched, err := j.filter(chkForJoin, chk, outer.Len()) if err != nil { - return false, errors.Trace(err) + return false, false, errors.Trace(err) } - return matched, nil - + return matched, false, nil } -func (j *innerJoiner) onMissMatch(outer chunk.Row, chk *chunk.Chunk) { +func (j *innerJoiner) onMissMatch(_ bool, outer chunk.Row, chk *chunk.Chunk) { } diff --git a/executor/merge_join.go b/executor/merge_join.go index 8e3787ba3b427..e7934d1562967 100644 --- a/executor/merge_join.go +++ b/executor/merge_join.go @@ -63,6 +63,7 @@ type mergeJoinOuterTable struct { iter *chunk.Iterator4Chunk row chunk.Row hasMatch bool + hasNull bool } // mergeJoinInnerTable represents the inner table of merge join. 
@@ -307,13 +308,14 @@ func (e *MergeJoinExec) joinToChunk(ctx context.Context, chk *chunk.Chunk) (hasM } if cmpResult < 0 { - e.joiner.onMissMatch(e.outerTable.row, chk) + e.joiner.onMissMatch(false, e.outerTable.row, chk) if err != nil { return false, errors.Trace(err) } e.outerTable.row = e.outerTable.iter.Next() e.outerTable.hasMatch = false + e.outerTable.hasNull = false if chk.NumRows() == e.maxChunkSize { return true, nil @@ -321,18 +323,20 @@ func (e *MergeJoinExec) joinToChunk(ctx context.Context, chk *chunk.Chunk) (hasM continue } - matched, err := e.joiner.tryToMatch(e.outerTable.row, e.innerIter4Row, chk) + matched, isNull, err := e.joiner.tryToMatch(e.outerTable.row, e.innerIter4Row, chk) if err != nil { return false, errors.Trace(err) } e.outerTable.hasMatch = e.outerTable.hasMatch || matched + e.outerTable.hasNull = e.outerTable.hasNull || isNull if e.innerIter4Row.Current() == e.innerIter4Row.End() { if !e.outerTable.hasMatch { - e.joiner.onMissMatch(e.outerTable.row, chk) + e.joiner.onMissMatch(e.outerTable.hasNull, e.outerTable.row, chk) } e.outerTable.row = e.outerTable.iter.Next() e.outerTable.hasMatch = false + e.outerTable.hasNull = false e.innerIter4Row.Begin() } diff --git a/executor/merge_join_test.go b/executor/merge_join_test.go index 913b9460d4856..ccf21c60ffd78 100644 --- a/executor/merge_join_test.go +++ b/executor/merge_join_test.go @@ -358,13 +358,11 @@ func (s *testSuite) TestMergeJoin(c *C) { tk.MustExec("insert into s values(1,1)") tk.MustQuery("explain select /*+ TIDB_SMJ(t, s) */ a in (select a from s where s.b >= t.b) from t").Check(testkit.Rows( "Projection_7 10000.00 root 6_aux_0", - "└─MergeJoin_8 10000.00 root left outer semi join, left key:test.t.a, right key:test.s.a, other cond:ge(test.s.b, test.t.b)", - " ├─Sort_12 10000.00 root test.t.a:asc", - " │ └─TableReader_11 10000.00 root data:TableScan_10", - " │ └─TableScan_10 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo", - " └─Sort_16 10000.00 root 
test.s.a:asc", - " └─TableReader_15 10000.00 root data:TableScan_14", - " └─TableScan_14 10000.00 cop table:s, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─MergeJoin_8 10000.00 root left outer semi join, other cond:eq(test.t.a, test.s.a), ge(test.s.b, test.t.b)", + " ├─TableReader_10 10000.00 root data:TableScan_9", + " │ └─TableScan_9 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo", + " └─TableReader_12 10000.00 root data:TableScan_11", + " └─TableScan_11 10000.00 cop table:s, range:[-inf,+inf], keep order:false, stats:pseudo", )) tk.MustQuery("select /*+ TIDB_SMJ(t, s) */ a in (select a from s where s.b >= t.b) from t").Check(testkit.Rows( "1", diff --git a/executor/union_scan.go b/executor/union_scan.go index ee12f0925f990..1b5bc98f4dfaf 100644 --- a/executor/union_scan.go +++ b/executor/union_scan.go @@ -286,7 +286,7 @@ func (us *UnionScanExec) buildAndSortAddedRows() error { } } mutableRow.SetDatums(newData...) - matched, err := expression.EvalBool(us.ctx, us.conditions, mutableRow.ToRow()) + matched, _, err := expression.EvalBool(us.ctx, us.conditions, mutableRow.ToRow()) if err != nil { return errors.Trace(err) } diff --git a/expression/chunk_executor.go b/expression/chunk_executor.go index 7434ccda38963..2bea52bb783ea 100644 --- a/expression/chunk_executor.go +++ b/expression/chunk_executor.go @@ -253,7 +253,7 @@ func VectorizedFilter(ctx sessionctx.Context, filters []Expression, iterator *ch selected[row.Idx()] = selected[row.Idx()] && !isNull && (filterResult != 0) } else { // TODO: should rewrite the filter to `cast(expr as SIGNED) != 0` and always use `EvalInt`. 
- bVal, err := EvalBool(ctx, []Expression{filter}, row) + bVal, _, err := EvalBool(ctx, []Expression{filter}, row) if err != nil { return nil, errors.Trace(err) } diff --git a/expression/column.go b/expression/column.go index fdae386632503..af6cfcdb81745 100644 --- a/expression/column.go +++ b/expression/column.go @@ -161,6 +161,10 @@ type Column struct { Index int hashcode []byte + + // InOperand indicates whether this column is the inner operand of column equal condition converted + // from `[not] in (subq)`. + InOperand bool } // Equal implements Expression interface. diff --git a/expression/constant_propagation.go b/expression/constant_propagation.go index 49243fc30ab45..1ddfcd28c44a4 100644 --- a/expression/constant_propagation.go +++ b/expression/constant_propagation.go @@ -215,7 +215,7 @@ func (s *propagateConstantSolver) pickNewEQConds(visited []bool) (retMapper map[ var ok bool if col == nil { if con, ok = cond.(*Constant); ok { - value, err := EvalBool(s.ctx, []Expression{con}, chunk.Row{}) + value, _, err := EvalBool(s.ctx, []Expression{con}, chunk.Row{}) terror.Log(errors.Trace(err)) if !value { s.setConds2ConstFalse() diff --git a/expression/expression.go b/expression/expression.go index 1d00438a34f41..755b424de66bd 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -112,26 +112,57 @@ func (e CNFExprs) Clone() CNFExprs { return cnf } -// EvalBool evaluates expression list to a boolean value. -func EvalBool(ctx sessionctx.Context, exprList CNFExprs, row chunk.Row) (bool, error) { +func isColumnInOperand(c *Column) bool { + return c.InOperand +} + +// IsEQCondFromIn checks if an expression is equal condition converted from `[not] in (subq)`. 
+func IsEQCondFromIn(expr Expression) bool { + sf, ok := expr.(*ScalarFunction) + if !ok || sf.FuncName.L != ast.EQ { + return false + } + cols := make([]*Column, 0, 1) + cols = ExtractColumnsFromExpressions(cols, sf.GetArgs(), isColumnInOperand) + return len(cols) > 0 +} + +// EvalBool evaluates expression list to a boolean value. The first returned value +// indicates bool result of the expression list, the second returned value indicates +// whether the result of the expression list is null, it can only be true when the +// first returned value is false. +func EvalBool(ctx sessionctx.Context, exprList CNFExprs, row chunk.Row) (bool, bool, error) { + hasNull := false for _, expr := range exprList { data, err := expr.Eval(row) if err != nil { - return false, errors.Trace(err) + return false, false, err } if data.IsNull() { - return false, nil + // For queries like `select a in (select a from s where t.b = s.b) from t`, + // if result of `t.a = s.a` is null, we cannot return immediately until + // we have checked if `t.b = s.b` is null or false, because it means + // subquery is empty, and we should return false as the result of the whole + // exprList in that case, instead of null. + if !IsEQCondFromIn(expr) { + return false, false, nil + } + hasNull = true + continue } i, err := data.ToBool(ctx.GetSessionVars().StmtCtx) if err != nil { - return false, errors.Trace(err) + return false, false, err } if i == 0 { - return false, nil + return false, false, nil } } - return true, nil + if hasNull { + return false, true, nil + } + return true, false, nil } // composeConditionWithBinaryOp composes condition with binary operator into a balance deep tree, which benefits a lot for pb decoder/encoder. 
diff --git a/expression/util.go b/expression/util.go index 36ada315874f3..fd447bf9e9156 100644 --- a/expression/util.go +++ b/expression/util.go @@ -113,6 +113,21 @@ func extractColumnSet(expr Expression, set *intsets.Sparse) { } } +func setExprColumnInOperand(expr Expression) Expression { + switch v := expr.(type) { + case *Column: + col := v.Clone().(*Column) + col.InOperand = true + return col + case *ScalarFunction: + args := v.GetArgs() + for i, arg := range args { + args[i] = setExprColumnInOperand(arg) + } + } + return expr +} + // ColumnSubstitute substitutes the columns in filter to expressions in select fields. // e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k. func ColumnSubstitute(expr Expression, schema *Schema, newExprs []Expression) Expression { @@ -122,7 +137,11 @@ func ColumnSubstitute(expr Expression, schema *Schema, newExprs []Expression) Ex if id == -1 { return v } - return newExprs[id] + newExpr := newExprs[id] + if v.InOperand { + newExpr = setExprColumnInOperand(newExpr) + } + return newExpr case *ScalarFunction: if v.FuncName.L == ast.Cast { newFunc := v.Clone().(*ScalarFunction) diff --git a/planner/core/cbo_test.go b/planner/core/cbo_test.go index dd846305fe836..992cd10540095 100644 --- a/planner/core/cbo_test.go +++ b/planner/core/cbo_test.go @@ -657,7 +657,7 @@ func (s *testAnalyzeSuite) TestCorrelatedEstimation(c *C) { tk.MustQuery("explain select t.c in (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t;"). 
Check(testkit.Rows( "Projection_11 10.00 root 9_aux_0", - "└─Apply_13 10.00 root left outer semi join, inner:StreamAgg_20, equal:[eq(test.t.c, count(*))]", + "└─Apply_13 10.00 root left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, count(*))", " ├─TableReader_15 10.00 root data:TableScan_14", " │ └─TableScan_14 10.00 cop table:t, range:[-inf,+inf], keep order:false", " └─StreamAgg_20 1.00 root funcs:count(1)", diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 6d8d24c6284d6..2392ade00e552 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -306,6 +306,23 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { return inNode, false } +func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r expression.Expression, not bool) { + var condition expression.Expression + if rCol, ok := r.(*expression.Column); ok && (er.asScalar || not) { + rCol.InOperand = true + // If both input columns of `!= all / = any` expression are not null, we can treat the expression + // as normal column equal condition. 
+ if lCol, ok := l.(*expression.Column); ok && mysql.HasNotNullFlag(lCol.GetType().Flag) && mysql.HasNotNullFlag(rCol.GetType().Flag) { + rCol.InOperand = false + } + } + condition, er.err = er.constructBinaryOpFunction(l, r, ast.EQ) + if er.err != nil { + return + } + er.p, er.err = er.b.buildSemiApply(er.p, np, []expression.Expression{condition}, er.asScalar, not) +} + func (er *expressionRewriter) handleCompareSubquery(v *ast.CompareSubqueryExpr) (ast.Node, bool) { v.L.Accept(er) if er.err != nil { @@ -333,7 +350,6 @@ func (er *expressionRewriter) handleCompareSubquery(v *ast.CompareSubqueryExpr) er.err = expression.ErrOperandColumns.GenWithStackByArgs(lLen) return v, true } - var condition expression.Expression var rexpr expression.Expression if np.Schema().Len() == 1 { rexpr = np.Schema().Columns[0] @@ -351,20 +367,23 @@ func (er *expressionRewriter) handleCompareSubquery(v *ast.CompareSubqueryExpr) switch v.Op { // Only EQ, NE and NullEQ can be composed with and. case opcode.EQ, opcode.NE, opcode.NullEQ: - condition, er.err = er.constructBinaryOpFunction(lexpr, rexpr, ast.EQ) - if er.err != nil { - er.err = errors.Trace(er.err) - return v, true - } if v.Op == opcode.EQ { if v.All { er.handleEQAll(lexpr, rexpr, np) } else { - er.p, er.err = er.b.buildSemiApply(er.p, np, []expression.Expression{condition}, er.asScalar, false) + // `a = any(subq)` will be rewritten as `a in (subq)`. + er.buildSemiApplyFromEqualSubq(np, lexpr, rexpr, false) + if er.err != nil { + return v, true + } } } else if v.Op == opcode.NE { if v.All { - er.p, er.err = er.b.buildSemiApply(er.p, np, []expression.Expression{condition}, er.asScalar, true) + // `a != all(subq)` will be rewritten as `a not in (subq)`. 
+ er.buildSemiApplyFromEqualSubq(np, lexpr, rexpr, true) + if er.err != nil { + return v, true + } } else { er.handleNEAny(lexpr, rexpr, np) } @@ -658,6 +677,18 @@ func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node, var rexpr expression.Expression if np.Schema().Len() == 1 { rexpr = np.Schema().Columns[0] + rCol := rexpr.(*expression.Column) + // For AntiSemiJoin/LeftOuterSemiJoin/AntiLeftOuterSemiJoin, we cannot treat `in` expression as + // normal column equal condition, so we specially mark the inner operand here. + if v.Not || asScalar { + rCol.InOperand = true + // If both input columns of `in` expression are not null, we can treat the expression + // as normal column equal condition instead. + lCol, ok := lexpr.(*expression.Column) + if ok && mysql.HasNotNullFlag(lCol.GetType().Flag) && mysql.HasNotNullFlag(rCol.GetType().Flag) { + rCol.InOperand = false + } + } } else { args := make([]expression.Expression, 0, np.Schema().Len()) for _, col := range np.Schema().Columns { @@ -669,8 +700,6 @@ func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node, return v, true } } - // a in (subq) will be rewrote as a = any(subq). - // a not in (subq) will be rewrote as a != all(subq). 
checkCondition, err := er.constructBinaryOpFunction(lexpr, rexpr, ast.EQ) if err != nil { er.err = errors.Trace(err) diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index ed8203a0c0026..024ada9e8d1b2 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -181,7 +181,7 @@ func (ds *DataSource) tryToGetMemTask(prop *property.PhysicalProperty) (task tas func (ds *DataSource) tryToGetDualTask() (task, error) { for _, cond := range ds.pushedDownConds { if _, ok := cond.(*expression.Constant); ok { - result, err := expression.EvalBool(ds.ctx, []expression.Expression{cond}, chunk.Row{}) + result, _, err := expression.EvalBool(ds.ctx, []expression.Expression{cond}, chunk.Row{}) if err != nil { return nil, errors.Trace(err) } diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index c47f6380c0fc7..30ebc82b44824 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -187,7 +187,12 @@ func extractOnCondition(conditions []expression.Expression, left LogicalPlan, ri if ok && binop.FuncName.L == ast.EQ { ln, lOK := binop.GetArgs()[0].(*expression.Column) rn, rOK := binop.GetArgs()[1].(*expression.Column) - if lOK && rOK { + // For queries like `select a in (select a from s where s.b = t.b) from t`, + // if subquery is empty caused by `s.b = t.b`, the result should always be + // false even if t.a is null or s.a is null. To make this join "empty aware", + // we should differentiate `t.a = s.a` from other column equal conditions, so + // we put it into OtherConditions instead of EqualConditions of join. 
+ if lOK && rOK && !ln.InOperand && !rn.InOperand { if left.Schema().Contains(ln) && right.Schema().Contains(rn) { eqCond = append(eqCond, binop) continue @@ -472,7 +477,7 @@ func (b *planBuilder) buildSelection(p LogicalPlan, where ast.ExprNode, AggMappe cnfItems := expression.SplitCNFItems(expr) for _, item := range cnfItems { if con, ok := item.(*expression.Constant); ok { - ret, err := expression.EvalBool(b.ctx, expression.CNFExprs{con}, chunk.Row{}) + ret, _, err := expression.EvalBool(b.ctx, expression.CNFExprs{con}, chunk.Row{}) if err != nil || ret { continue } diff --git a/planner/core/rule_eliminate_projection.go b/planner/core/rule_eliminate_projection.go index 2060e7dd64ff0..9a3658e579364 100644 --- a/planner/core/rule_eliminate_projection.go +++ b/planner/core/rule_eliminate_projection.go @@ -51,9 +51,9 @@ func canProjectionBeEliminatedStrict(p *PhysicalProjection) bool { func resolveColumnAndReplace(origin *expression.Column, replace map[string]*expression.Column) { dst := replace[string(origin.HashCode(nil))] if dst != nil { - colName, retType := origin.ColName, origin.RetType + colName, retType, inOperand := origin.ColName, origin.RetType, origin.InOperand *origin = *dst - origin.ColName, origin.RetType = colName, retType + origin.ColName, origin.RetType, origin.InOperand = colName, retType, inOperand } } diff --git a/planner/core/rule_predicate_push_down.go b/planner/core/rule_predicate_push_down.go index c2e93528b65da..5dd7ed040a87a 100644 --- a/planner/core/rule_predicate_push_down.go +++ b/planner/core/rule_predicate_push_down.go @@ -193,6 +193,12 @@ func (p *LogicalJoin) updateEQCond() { for i := len(p.OtherConditions) - 1; i >= 0; i-- { need2Remove := false if eqCond, ok := p.OtherConditions[i].(*expression.ScalarFunction); ok && eqCond.FuncName.L == ast.EQ { + // If it is a column equal condition converted from `[not] in (subq)`, do not move it + // to EqualConditions, and keep it in OtherConditions. 
Reference comments in `extractOnCondition` + // for detailed reasons. + if expression.IsEQCondFromIn(eqCond) { + continue + } lExpr, rExpr := eqCond.GetArgs()[0], eqCond.GetArgs()[1] if expression.ExprFromSchema(lExpr, lChild.Schema()) && expression.ExprFromSchema(rExpr, rChild.Schema()) { lKeys = append(lKeys, lExpr)