From c0d4b29c751d5c3003c8acdfe96f005c1caf3bd8 Mon Sep 17 00:00:00 2001 From: Arenatlx <314806019@qq.com> Date: Tue, 20 Sep 2022 00:05:03 +0800 Subject: [PATCH 1/4] cherry pick #37512 to release-6.2 Signed-off-by: ti-srebot --- cmd/explaintest/r/naaj.result | 412 +++++++++++++++++++++ cmd/explaintest/t/naaj.test | 213 +++++++++++ executor/benchmark_test.go | 7 +- executor/builder.go | 54 ++- executor/hash_table.go | 243 +++++++++++- executor/join.go | 446 ++++++++++++++++++++++- executor/joiner.go | 161 +++++++- executor/joiner_test.go | 2 +- executor/pkg_test.go | 2 +- planner/core/exhaust_physical_plans.go | 46 ++- planner/core/explain.go | 21 +- planner/core/logical_plans.go | 14 + planner/core/physical_plans.go | 32 +- planner/core/plan_cost.go | 15 +- planner/core/plan_to_pb.go | 2 + planner/core/resolve_indices.go | 20 + planner/core/rule_column_pruning.go | 3 + planner/core/rule_predicate_push_down.go | 92 +++-- planner/core/stats.go | 15 + sessionctx/variable/session.go | 3 + sessionctx/variable/sysvar.go | 4 + sessionctx/variable/tidb_vars.go | 4 + util/bitmap/concurrent.go | 37 ++ util/bitmap/concurrent_test.go | 14 + 24 files changed, 1765 insertions(+), 97 deletions(-) create mode 100644 cmd/explaintest/r/naaj.result create mode 100644 cmd/explaintest/t/naaj.test diff --git a/cmd/explaintest/r/naaj.result b/cmd/explaintest/r/naaj.result new file mode 100644 index 0000000000000..bc5bda03fbbc3 --- /dev/null +++ b/cmd/explaintest/r/naaj.result @@ -0,0 +1,412 @@ +use test; +set @@session.tidb_enable_null_aware_anti_join=1; +select "***************************************************** PART 1 *****************************************************************" as name; +name +***************************************************** PART 1 ***************************************************************** +drop table if exists naaj_A, naaj_B; +create table naaj_A(a int, b int, c int); +create table naaj_B(a int, b int, c int); +insert into naaj_A values (1,1,1); +insert into naaj_B values (1,2,2); +explain format = 'brief' select (a, b) not in (select a, b from naaj_B) from naaj_A; +id estRows task access object operator info +HashJoin 10000.00 root Null-aware anti left outer semi join, equal:[eq(test.naaj_a.b, test.naaj_b.b) eq(test.naaj_a.a, test.naaj_b.a)] +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select (a, b) not in (select a, b from naaj_B) from naaj_A; +(a, b) not in (select a, b from naaj_B) +1 +explain format = 'brief' select * from naaj_A where (a, b) not in (select a, b from naaj_B); +id estRows task access object operator info +HashJoin 8000.00 root Null-aware anti semi join, equal:[eq(test.naaj_a.b, test.naaj_b.b) eq(test.naaj_a.a, test.naaj_b.a)] +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select * from naaj_A where (a, b) not in (select a, b from naaj_B); +a b c +1 1 1 +insert into naaj_B values(1,1,1); +select (a, b) not in (select a, b from naaj_B) from naaj_A; +(a, b) not in (select a, b from naaj_B) +0 +select * from naaj_A where (a, b) not in (select a, b from naaj_B); +a b c +insert into naaj_B values(1, null, 2); +select (a, b) not in (select a, b from naaj_B) from naaj_A; +(a, b) not in (select a, b from naaj_B) +0 +select * from naaj_A where (a, b) not in (select a, b from naaj_B); +a b c +explain format = 'brief' select (a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c) from naaj_A; +id estRows task access object operator info +HashJoin 10000.00 root Null-aware anti left outer semi join, equal:[eq(test.naaj_a.b, test.naaj_b.b) eq(test.naaj_a.a, test.naaj_b.a)], other cond:gt(test.naaj_a.c, test.naaj_b.c) +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select (a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c) from naaj_A; +(a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c) +1 +explain format = 'brief' select * from naaj_A where (a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c); +id estRows task access object operator info +HashJoin 8000.00 root Null-aware anti semi join, equal:[eq(test.naaj_a.b, test.naaj_b.b) eq(test.naaj_a.a, test.naaj_b.a)], other cond:gt(test.naaj_a.c, test.naaj_b.c) +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select * from naaj_A where (a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c); +a b c +1 1 1 +explain format = 'brief' select (a, b) not in (select a, b from naaj_B where naaj_A.a != naaj_B.a) from naaj_A; +id estRows task access object operator info +HashJoin 10000.00 root Null-aware anti left outer semi join, equal:[eq(test.naaj_a.b, test.naaj_b.b) eq(test.naaj_a.a, test.naaj_b.a)], other cond:ne(test.naaj_a.a, test.naaj_b.a) +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select (a, b) not in (select a, b from naaj_B where naaj_A.a != naaj_B.a) from naaj_A; +(a, b) not in (select a, b from naaj_B where naaj_A.a != naaj_B.a) +1 +explain format = 'brief' select * from naaj_A where (a, b) not in (select a, b from naaj_B where naaj_A.a != naaj_B.a); +id estRows task access object operator info +HashJoin 8000.00 root Null-aware anti semi join, equal:[eq(test.naaj_a.b, test.naaj_b.b) eq(test.naaj_a.a, test.naaj_b.a)], other cond:ne(test.naaj_a.a, test.naaj_b.a) +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select * from naaj_A where (a, b) not in (select a, b from naaj_B where naaj_A.a != naaj_B.a); +a b c +1 1 1 +select * from naaj_A where (a, b) not in (select a, b from naaj_B where false); +a b c +1 1 1 +select (a, b) not in (select a, b from naaj_B where false) from naaj_A; +(a, b) not in (select a, b from naaj_B where false) +1 +insert into naaj_B values(2, null, 2); +select (a, b) not in (select a, b from naaj_B) from naaj_A; +(a, b) not in (select a, b from naaj_B) +0 +select * from naaj_A where (a, b) not in (select a, b from naaj_B); +a b c +delete from naaj_B where a=1 and b=1 and c=1; +select (a, b) not in (select a, b from naaj_B) from naaj_A; +(a, b) not in (select a, b from naaj_B) +NULL +select * from naaj_A where (a, b) not in (select a, b from naaj_B); +a b c +select "***************************************************** PART 2 *****************************************************************" as name; +name +***************************************************** PART 2 ***************************************************************** +delete from naaj_A; +delete from naaj_B; +insert into naaj_A values(1,null,1); +select (a, b) not in (select a, b from naaj_B) from naaj_A; +(a, b) not in (select a, b from naaj_B) +1 +select * from naaj_A where (a, b) not in (select a, b from naaj_B); +a b c +1 NULL 1 +insert into naaj_B values(2, null, 2); +select (a, b) not in (select a, b from naaj_B) from naaj_A; +(a, b) not in (select a, b from naaj_B) +1 +select * from naaj_A where (a, b) not in (select a, b from naaj_B); +a b c +1 NULL 1 +insert into naaj_B values(null, null, 2); +select (a, b) not in (select a, b from naaj_B) from naaj_A; +(a, b) not in (select a, b from naaj_B) +NULL +select * from naaj_A where (a, b) not in (select a, b from naaj_B); +a b c +delete from naaj_B; +insert into naaj_B values(2, 2, 2); +select (a, b) not in (select a, b from naaj_B) from naaj_A; +(a, b) not in (select a, b from naaj_B) +1 +select * from naaj_A where (a, b) not in (select a, b from naaj_B); +a b c +1 NULL 1 +insert into naaj_B values(2, null, 2); +insert into naaj_B values(null, null, 2); +explain format = 'brief' select (a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c) from naaj_A; +id estRows task access object operator info +HashJoin 10000.00 root Null-aware anti left outer semi join, equal:[eq(test.naaj_a.b, test.naaj_b.b) eq(test.naaj_a.a, test.naaj_b.a)], other cond:gt(test.naaj_a.c, test.naaj_b.c) +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select (a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c) from naaj_A; +(a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c) +1 +explain format = 'brief' select * from naaj_A where (a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c); +id estRows task access object operator info +HashJoin 8000.00 root Null-aware anti semi join, equal:[eq(test.naaj_a.b, test.naaj_b.b) eq(test.naaj_a.a, test.naaj_b.a)], other cond:gt(test.naaj_a.c, test.naaj_b.c) +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select * from naaj_A where (a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c); +a b c +1 NULL 1 +explain format = 'brief' select (a, b) not in (select a, b from naaj_B where naaj_A.c = naaj_B.c) from naaj_A; +id estRows task access object operator info +HashJoin 10000.00 root anti left outer semi join, equal:[eq(test.naaj_a.c, test.naaj_b.c)], other cond:eq(test.naaj_a.a, test.naaj_b.a), eq(test.naaj_a.b, test.naaj_b.b) +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select (a, b) not in (select a, b from naaj_B where naaj_A.c = naaj_B.c) from naaj_A; +(a, b) not in (select a, b from naaj_B where naaj_A.c = naaj_B.c) +1 +explain format = 'brief' select * from naaj_A where (a, b) not in (select a, b from naaj_B where naaj_A.c = naaj_B.c); +id estRows task access object operator info +HashJoin 8000.00 root anti semi join, equal:[eq(test.naaj_a.c, test.naaj_b.c)], other cond:eq(test.naaj_a.a, test.naaj_b.a), eq(test.naaj_a.b, test.naaj_b.b) +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select * from naaj_A where (a, b) not in (select a, b from naaj_B where naaj_A.c = naaj_B.c); +a b c +1 NULL 1 +select "***************************************************** PART 3 *****************************************************************" as name; +name +***************************************************** PART 3 ***************************************************************** +drop table if exists naaj_A, naaj_B; +create table naaj_A(a int, b int, c int); +create table naaj_B(a int, b int, c int); +insert into naaj_A values (1,1,1); +insert into naaj_B values (1,2,2); +explain format = 'brief' select (a, b) != all (select a, b from naaj_B) from naaj_A; +id estRows task access object operator info +HashJoin 10000.00 root Null-aware anti left outer semi join, equal:[eq(test.naaj_a.b, test.naaj_b.b) eq(test.naaj_a.a, test.naaj_b.a)] +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select (a, b) != all (select a, b from naaj_B) from naaj_A; +(a, b) != all (select a, b from naaj_B) +1 +explain format = 'brief' select * from naaj_A where (a, b) != all (select a, b from naaj_B); +id estRows task access object operator info +Projection 8000.00 root test.naaj_a.a, test.naaj_a.b, test.naaj_a.c +└─Selection 8000.00 root Column#9 + └─HashJoin 10000.00 root Null-aware anti left outer semi join, equal:[eq(test.naaj_a.b, test.naaj_b.b) eq(test.naaj_a.a, test.naaj_b.a)] + ├─TableReader(Build) 10000.00 root data:TableFullScan + │ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo + └─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select * from naaj_A where (a, b) != all (select a, b from naaj_B); +a b c +1 1 1 +insert into naaj_B values(1,1,1); +select (a, b) != all (select a, b from naaj_B) from naaj_A; +(a, b) != all (select a, b from naaj_B) +0 +select * from naaj_A where (a, b) != all (select a, b from naaj_B); +a b c +insert into naaj_B values(1, null, 2); +select (a, b) != all (select a, b from naaj_B) from naaj_A; +(a, b) != all (select a, b from naaj_B) +0 +select * from naaj_A where (a, b) != all (select a, b from naaj_B); +a b c +explain format = 'brief' select (a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c) from naaj_A; +id estRows task access object operator info +HashJoin 10000.00 root Null-aware anti left outer semi join, equal:[eq(test.naaj_a.b, test.naaj_b.b) eq(test.naaj_a.a, test.naaj_b.a)], other cond:gt(test.naaj_a.c, test.naaj_b.c) +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select (a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c) from naaj_A; +(a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c) +1 +explain format = 'brief' select * from naaj_A where (a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c); +id estRows task access object operator info +Projection 8000.00 root test.naaj_a.a, test.naaj_a.b, test.naaj_a.c +└─Selection 8000.00 root Column#9 + └─HashJoin 10000.00 root Null-aware anti left outer semi join, equal:[eq(test.naaj_a.b, test.naaj_b.b) eq(test.naaj_a.a, test.naaj_b.a)], other cond:gt(test.naaj_a.c, test.naaj_b.c) + ├─TableReader(Build) 10000.00 root data:TableFullScan + │ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo + └─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select * from naaj_A where (a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c); +a b c +1 1 1 +explain format = 'brief' select (a, b) != all (select a, b from naaj_B where naaj_A.a != naaj_B.a) from naaj_A; +id estRows task access object operator info +HashJoin 10000.00 root Null-aware anti left outer semi join, equal:[eq(test.naaj_a.b, test.naaj_b.b) eq(test.naaj_a.a, test.naaj_b.a)], other cond:ne(test.naaj_a.a, test.naaj_b.a) +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select (a, b) != all (select a, b from naaj_B where naaj_A.a != naaj_B.a) from naaj_A; +(a, b) != all (select a, b from naaj_B where naaj_A.a != naaj_B.a) +1 +explain format = 'brief' select * from naaj_A where (a, b) != all (select a, b from naaj_B where naaj_A.a != naaj_B.a); +id estRows task access object operator info +Projection 8000.00 root test.naaj_a.a, test.naaj_a.b, test.naaj_a.c +└─Selection 8000.00 root Column#9 + └─HashJoin 10000.00 root Null-aware anti left outer semi join, equal:[eq(test.naaj_a.b, test.naaj_b.b) eq(test.naaj_a.a, test.naaj_b.a)], other cond:ne(test.naaj_a.a, test.naaj_b.a) + ├─TableReader(Build) 10000.00 root data:TableFullScan + │ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo + └─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select * from naaj_A where (a, b) != all (select a, b from naaj_B where naaj_A.a != naaj_B.a); +a b c +1 1 1 +select * from naaj_A where (a, b) != all (select a, b from naaj_B where false); +a b c +1 1 1 +select (a, b) != all (select a, b from naaj_B where false) from naaj_A; +(a, b) != all (select a, b from naaj_B where false) +1 +insert into naaj_B values(2, null, 2); +select (a, b) != all (select a, b from naaj_B) from naaj_A; +(a, b) != all (select a, b from naaj_B) +0 +select * from naaj_A where (a, b) != all (select a, b from naaj_B); +a b c +delete from naaj_B where a=1 and b=1 and c=1; +select (a, b) != all (select a, b from naaj_B) from naaj_A; +(a, b) != all (select a, b from naaj_B) +NULL +select * from naaj_A where (a, b) != all (select a, b from naaj_B); +a b c +select "***************************************************** PART 4 *****************************************************************" as name; +name +***************************************************** PART 4 ***************************************************************** +delete from naaj_A; +delete from naaj_B; +insert into naaj_A values(1,null,1); +select (a, b) != all (select a, b from naaj_B) from naaj_A; +(a, b) != all (select a, b from naaj_B) +1 +select * from naaj_A where (a, b) != all (select a, b from naaj_B); +a b c +1 NULL 1 +insert into naaj_B values(2, null, 2); +select (a, b) != all (select a, b from naaj_B) from naaj_A; +(a, b) != all (select a, b from naaj_B) +1 +select * from naaj_A where (a, b) != all (select a, b from naaj_B); +a b c +1 NULL 1 +insert into naaj_B values(null, null, 2); +select (a, b) != all (select a, b from naaj_B) from naaj_A; +(a, b) != all (select a, b from naaj_B) +NULL +select * from naaj_A where (a, b) != all (select a, b from naaj_B); +a b c +delete from naaj_B; +insert into naaj_B values(2, 2, 2); +select (a, b) != all (select a, b from naaj_B) from naaj_A; +(a, b) != all (select a, b from naaj_B) +1 +select * from naaj_A where (a, b) != all (select a, b from naaj_B); +a b c +1 NULL 1 +insert into naaj_B values(2, null, 2); +insert into naaj_B values(null, null, 2); +explain format = 'brief' select (a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c) from naaj_A; +id estRows task access object operator info +HashJoin 10000.00 root Null-aware anti left outer semi join, equal:[eq(test.naaj_a.b, test.naaj_b.b) eq(test.naaj_a.a, test.naaj_b.a)], other cond:gt(test.naaj_a.c, test.naaj_b.c) +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select (a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c) from naaj_A; +(a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c) +1 +explain format = 'brief' select * from naaj_A where (a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c); +id estRows task access object operator info +Projection 8000.00 root test.naaj_a.a, test.naaj_a.b, test.naaj_a.c +└─Selection 8000.00 root Column#9 + └─HashJoin 10000.00 root Null-aware anti left outer semi join, equal:[eq(test.naaj_a.b, test.naaj_b.b) eq(test.naaj_a.a, test.naaj_b.a)], other cond:gt(test.naaj_a.c, test.naaj_b.c) + ├─TableReader(Build) 10000.00 root data:TableFullScan + │ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo + └─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select * from naaj_A where (a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c); +a b c +1 NULL 1 +explain format = 'brief' select (a, b) != all (select a, b from naaj_B where naaj_A.c = naaj_B.c) from naaj_A; +id estRows task access object operator info +HashJoin 10000.00 root anti left outer semi join, equal:[eq(test.naaj_a.c, test.naaj_b.c)], other cond:eq(test.naaj_a.a, test.naaj_b.a), eq(test.naaj_a.b, test.naaj_b.b) +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select (a, b) != all (select a, b from naaj_B where naaj_A.c = naaj_B.c) from naaj_A; +(a, b) != all (select a, b from naaj_B where naaj_A.c = naaj_B.c) +1 +explain format = 'brief' select * from naaj_A where (a, b) != all (select a, b from naaj_B where naaj_A.c = naaj_B.c); +id estRows task access object operator info +Projection 8000.00 root test.naaj_a.a, test.naaj_a.b, test.naaj_a.c +└─Selection 8000.00 root Column#9 + └─HashJoin 10000.00 root anti left outer semi join, equal:[eq(test.naaj_a.c, test.naaj_b.c)], other cond:eq(test.naaj_a.a, test.naaj_b.a), eq(test.naaj_a.b, test.naaj_b.b) + ├─TableReader(Build) 10000.00 root data:TableFullScan + │ └─TableFullScan 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo + └─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select * from naaj_A where (a, b) != all (select a, b from naaj_B where naaj_A.c = naaj_B.c); +a b c +1 NULL 1 +select "***************************************************** PART 5 *****************************************************************" as name; +name +***************************************************** PART 5 ***************************************************************** +delete from naaj_A; +delete from naaj_B; +insert into naaj_A values(1,1,1); +insert into naaj_B values(2,null,2); +select (a,b) not in (select a, b from naaj_B) from naaj_A; +(a,b) not in (select a, b from naaj_B) +1 +select * from naaj_A where (a,b) not in (select a, b from naaj_B); +a b c +1 1 1 +explain select (a+1,b*2) not in (select a, b from naaj_B) from naaj_A; +id estRows task access object operator info +HashJoin_9 10000.00 root Null-aware anti left outer semi join, equal:[eq(Column#14, test.naaj_b.b) eq(Column#15, test.naaj_b.a)] +├─TableReader_14(Build) 10000.00 root data:TableFullScan_13 +│ └─TableFullScan_13 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─Projection_10(Probe) 10000.00 root mul(test.naaj_a.b, 2)->Column#14, plus(test.naaj_a.a, 1)->Column#15 + └─TableReader_12 10000.00 root data:TableFullScan_11 + └─TableFullScan_11 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select (a+1,b*2) not in (select a, b from naaj_B) from naaj_A; +(a+1,b*2) not in (select a, b from naaj_B) +NULL +insert into naaj_B values(2,2,2); +select (a+1,b*2) not in (select a, b from naaj_B) from naaj_A; +(a+1,b*2) not in (select a, b from naaj_B) +0 +explain select * from naaj_A where (a+1,b*2) not in (select a+1, b-1 from naaj_B); +id estRows task access object operator info +HashJoin_9 8000.00 root Null-aware anti semi join, equal:[eq(Column#13, Column#10) eq(Column#14, Column#9)] +├─Projection_13(Build) 10000.00 root plus(test.naaj_b.a, 1)->Column#9, minus(test.naaj_b.b, 1)->Column#10 +│ └─TableReader_15 10000.00 root data:TableFullScan_14 +│ └─TableFullScan_14 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─Projection_10(Probe) 10000.00 root test.naaj_a.a, test.naaj_a.b, test.naaj_a.c, mul(test.naaj_a.b, 2)->Column#13, plus(test.naaj_a.a, 1)->Column#14 + └─TableReader_12 10000.00 root data:TableFullScan_11 + └─TableFullScan_11 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +select * from naaj_A where (a+1,b*2) not in (select a, b from naaj_B); +a b c +explain select (a+1,b*2) not in (select a, b=1 from naaj_B where naaj_A.a = naaj_B.a) from naaj_A; +id estRows task access object operator info +HashJoin_9 10000.00 root anti left outer semi join, equal:[eq(test.naaj_a.a, test.naaj_b.a)], other cond:eq(mul(test.naaj_a.b, 2), eq(test.naaj_b.b, 1)), eq(plus(test.naaj_a.a, 1), test.naaj_b.a) +├─TableReader_13(Build) 10000.00 root data:TableFullScan_12 +│ └─TableFullScan_12 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader_11(Probe) 10000.00 root data:TableFullScan_10 + └─TableFullScan_10 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +explain select * from naaj_A where (a+1,b*2) not in (select a, b=1 from naaj_B where naaj_A.a = naaj_B.a); +id estRows task access object operator info +HashJoin_9 8000.00 root anti semi join, equal:[eq(test.naaj_a.a, test.naaj_b.a)], other cond:eq(mul(test.naaj_a.b, 2), eq(test.naaj_b.b, 1)), eq(plus(test.naaj_a.a, 1), test.naaj_b.a) +├─TableReader_13(Build) 10000.00 root data:TableFullScan_12 +│ └─TableFullScan_12 10000.00 cop[tikv] table:naaj_B keep order:false, stats:pseudo +└─TableReader_11(Probe) 10000.00 root data:TableFullScan_10 + └─TableFullScan_10 10000.00 cop[tikv] table:naaj_A keep order:false, stats:pseudo +set @@session.tidb_enable_null_aware_anti_join=0; diff --git a/cmd/explaintest/t/naaj.test b/cmd/explaintest/t/naaj.test new file mode 100644 index 0000000000000..eedada4c29202 --- /dev/null +++ b/cmd/explaintest/t/naaj.test @@ -0,0 +1,213 @@ +# naaj.test file is for null-aware anti join +use test; +set @@session.tidb_enable_null_aware_anti_join=1; +# assert the cases for the left side without null. +select "***************************************************** PART 1 *****************************************************************" as name; +drop table if exists naaj_A, naaj_B; +create table naaj_A(a int, b int, c int); +create table naaj_B(a int, b int, c int); +insert into naaj_A values (1,1,1); +insert into naaj_B values (1,2,2); + +# assert 1: both side don't have null values. +# AntiLeftOuterSemiJoin +explain format = 'brief' select (a, b) not in (select a, b from naaj_B) from naaj_A; +select (a, b) not in (select a, b from naaj_B) from naaj_A; + +# AntiSemiJoin +explain format = 'brief' select * from naaj_A where (a, b) not in (select a, b from naaj_B); +select * from naaj_A where (a, b) not in (select a, b from naaj_B); + +# assert 2: right side has same key bucket. +insert into naaj_B values(1,1,1); +select (a, b) not in (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a, b) not in (select a, b from naaj_B); + +# assert 3: right side has null values. +insert into naaj_B values(1, null, 2); +select (a, b) not in (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a, b) not in (select a, b from naaj_B); + +# assert 4: right side have null values, but it can't pass the inner(join key related or not) filter. +explain format = 'brief' select (a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c) from naaj_A; +select (a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c) from naaj_A; + +explain format = 'brief' select * from naaj_A where (a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c); +select * from naaj_A where (a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c); + +explain format = 'brief' select (a, b) not in (select a, b from naaj_B where naaj_A.a != naaj_B.a) from naaj_A; +select (a, b) not in (select a, b from naaj_B where naaj_A.a != naaj_B.a) from naaj_A; + +explain format = 'brief' select * from naaj_A where (a, b) not in (select a, b from naaj_B where naaj_A.a != naaj_B.a); +select * from naaj_A where (a, b) not in (select a, b from naaj_B where naaj_A.a != naaj_B.a); + +# assert 5: right side is empty. +select * from naaj_A where (a, b) not in (select a, b from naaj_B where false); +select (a, b) not in (select a, b from naaj_B where false) from naaj_A; + +# assert 6: right side null bucket filter (not-null join key should match with each other). +insert into naaj_B values(2, null, 2); +select (a, b) not in (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a, b) not in (select a, b from naaj_B); + +delete from naaj_B where a=1 and b=1 and c=1; +select (a, b) not in (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a, b) not in (select a, b from naaj_B); + +# case 2: assert the cases for the left side has null. +select "***************************************************** PART 2 *****************************************************************" as name; +delete from naaj_A; +delete from naaj_B; +insert into naaj_A values(1,null,1); + +# assert 1: left side has null, while the right is empty. +select (a, b) not in (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a, b) not in (select a, b from naaj_B); + +# assert 2: left side has null, while the right has a invalid null row (can't pass the nullBit filter). +insert into naaj_B values(2, null, 2); +select (a, b) not in (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a, b) not in (select a, b from naaj_B); + +# left side has null, while the right has a valid null row. (passed the nullBit filter). +insert into naaj_B values(null, null, 2); +select (a, b) not in (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a, b) not in (select a, b from naaj_B); + +# assert 3: left side has null, while the right has a valid non-null row. +delete from naaj_B; +insert into naaj_B values(2, 2, 2); +select (a, b) not in (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a, b) not in (select a, b from naaj_B); + +# assert 4: left side has null, while the right has no valid rows (equivalent to ). +insert into naaj_B values(2, null, 2); +insert into naaj_B values(null, null, 2); +explain format = 'brief' select (a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c) from naaj_A; +select (a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c) from naaj_A; +explain format = 'brief' select * from naaj_A where (a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c); +select * from naaj_A where (a, b) not in (select a, b from naaj_B where naaj_A.c > naaj_B.c); + +# assert 5: When the inner subq has a correlated EQ condition, we won't built the NA-EQ connecting condition here. +explain format = 'brief' select (a, b) not in (select a, b from naaj_B where naaj_A.c = naaj_B.c) from naaj_A; +select (a, b) not in (select a, b from naaj_B where naaj_A.c = naaj_B.c) from naaj_A; +explain format = 'brief' select * from naaj_A where (a, b) not in (select a, b from naaj_B where naaj_A.c = naaj_B.c); +select * from naaj_A where (a, b) not in (select a, b from naaj_B where naaj_A.c = naaj_B.c); + +# case 3: assert the cases for the equivalent semantic predicate of != ALL +select "***************************************************** PART 3 *****************************************************************" as name; +drop table if exists naaj_A, naaj_B; +create table naaj_A(a int, b int, c int); +create table naaj_B(a int, b int, c int); +insert into naaj_A values (1,1,1); +insert into naaj_B values (1,2,2); + +# assert 1: both side don't have null values. +# AntiLeftOuterSemiJoin +explain format = 'brief' select (a, b) != all (select a, b from naaj_B) from naaj_A; +select (a, b) != all (select a, b from naaj_B) from naaj_A; + +# AntiSemiJoin +explain format = 'brief' select * from naaj_A where (a, b) != all (select a, b from naaj_B); +select * from naaj_A where (a, b) != all (select a, b from naaj_B); + +# assert 2: right side has same key bucket. +insert into naaj_B values(1,1,1); +select (a, b) != all (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a, b) != all (select a, b from naaj_B); + +# assert 3: right side has null values. +insert into naaj_B values(1, null, 2); +select (a, b) != all (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a, b) != all (select a, b from naaj_B); + +# assert 4: right side have null values, but it can't pass the inner(join key related or not) filter. +explain format = 'brief' select (a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c) from naaj_A; +select (a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c) from naaj_A; + +explain format = 'brief' select * from naaj_A where (a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c); +select * from naaj_A where (a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c); + +explain format = 'brief' select (a, b) != all (select a, b from naaj_B where naaj_A.a != naaj_B.a) from naaj_A; +select (a, b) != all (select a, b from naaj_B where naaj_A.a != naaj_B.a) from naaj_A; + +explain format = 'brief' select * from naaj_A where (a, b) != all (select a, b from naaj_B where naaj_A.a != naaj_B.a); +select * from naaj_A where (a, b) != all (select a, b from naaj_B where naaj_A.a != naaj_B.a); + +# assert 5: right side is empty. +select * from naaj_A where (a, b) != all (select a, b from naaj_B where false); +select (a, b) != all (select a, b from naaj_B where false) from naaj_A; + +# assert 6: right side null bucket filter (not-null join key should match with each other). +insert into naaj_B values(2, null, 2); +select (a, b) != all (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a, b) != all (select a, b from naaj_B); + +delete from naaj_B where a=1 and b=1 and c=1; +select (a, b) != all (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a, b) != all (select a, b from naaj_B); + +# case 4: assert the cases for the equivalent semantic predicate of != ALL +select "***************************************************** PART 4 *****************************************************************" as name; +delete from naaj_A; +delete from naaj_B; +insert into naaj_A values(1,null,1); + +# assert 1: left side has null, while the right is empty. +select (a, b) != all (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a, b) != all (select a, b from naaj_B); + +# assert 2: left side has null, while the right has a invalid null row (can't pass the nullBit filter). +insert into naaj_B values(2, null, 2); +select (a, b) != all (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a, b) != all (select a, b from naaj_B); + +# left side has null, while the right has a valid null row. (passed the nullBit filter). +insert into naaj_B values(null, null, 2); +select (a, b) != all (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a, b) != all (select a, b from naaj_B); + +# assert 3: left side has null, while the right has a valid non-null row. +delete from naaj_B; +insert into naaj_B values(2, 2, 2); +select (a, b) != all (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a, b) != all (select a, b from naaj_B); + +# assert 4: left side has null, while the right has no valid rows (equivalent to ). +insert into naaj_B values(2, null, 2); +insert into naaj_B values(null, null, 2); +explain format = 'brief' select (a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c) from naaj_A; +select (a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c) from naaj_A; +explain format = 'brief' select * from naaj_A where (a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c); +select * from naaj_A where (a, b) != all (select a, b from naaj_B where naaj_A.c > naaj_B.c); + +# assert 5: When the inner subq has a correlated EQ condition, we won't built the NA-EQ connecting condition here. +explain format = 'brief' select (a, b) != all (select a, b from naaj_B where naaj_A.c = naaj_B.c) from naaj_A; +select (a, b) != all (select a, b from naaj_B where naaj_A.c = naaj_B.c) from naaj_A; +explain format = 'brief' select * from naaj_A where (a, b) != all (select a, b from naaj_B where naaj_A.c = naaj_B.c); +select * from naaj_A where (a, b) != all (select a, b from naaj_B where naaj_A.c = naaj_B.c); + +# case 5: assert some bugs. +select "***************************************************** PART 5 *****************************************************************" as name; +delete from naaj_A; +delete from naaj_B; +insert into naaj_A values(1,1,1); +insert into naaj_B values(2,null,2); + +# assert 1: although the probe key doesn't have null values, we still need to use buildNullBits to guarantee the non-null position has the exactly the same value. +select (a,b) not in (select a, b from naaj_B) from naaj_A; +select * from naaj_A where (a,b) not in (select a, b from naaj_B); + +# assert 2: should inject the projection under join. +explain select (a+1,b*2) not in (select a, b from naaj_B) from naaj_A; +select (a+1,b*2) not in (select a, b from naaj_B) from naaj_A; +insert into naaj_B values(2,2,2); +select (a+1,b*2) not in (select a, b from naaj_B) from naaj_A; + +explain select * from naaj_A where (a+1,b*2) not in (select a+1, b-1 from naaj_B); +select * from naaj_A where (a+1,b*2) not in (select a, b from naaj_B); + +# assert 3: NA-EQ and EQ can't co-exist at the same time. +explain select (a+1,b*2) not in (select a, b=1 from naaj_B where naaj_A.a = naaj_B.a) from naaj_A; +explain select * from naaj_A where (a+1,b*2) not in (select a, b=1 from naaj_B where naaj_A.a = naaj_B.a); +set @@session.tidb_enable_null_aware_anti_join=0; diff --git a/executor/benchmark_test.go b/executor/benchmark_test.go index 06282390677e4..ee09477062232 100644 --- a/executor/benchmark_test.go +++ b/executor/benchmark_test.go @@ -929,7 +929,7 @@ func prepare4HashJoin(testCase *hashJoinTestCase, innerExec, outerExec Executor) e.joiners = make([]joiner, e.concurrency) for i := uint(0); i < e.concurrency; i++ { e.joiners[i] = newJoiner(testCase.ctx, e.joinType, true, defaultValues, - nil, lhsTypes, rhsTypes, childrenUsedSchema) + nil, lhsTypes, rhsTypes, childrenUsedSchema, false) } memLimit := int64(-1) if testCase.disk { @@ -1336,7 +1336,7 @@ func prepare4IndexInnerHashJoin(tc *indexJoinTestCase, outerDS *mockDataSource, hashCols: tc.innerHashKeyIdx, }, workerWg: new(sync.WaitGroup), - joiner: newJoiner(tc.ctx, 0, false, defaultValues, nil, leftTypes, rightTypes, nil), + joiner: newJoiner(tc.ctx, 0, false, defaultValues, nil, leftTypes, rightTypes, nil, false), isOuterJoin: false, keyOff2IdxOff: keyOff2IdxOff, lastColHelper: nil, @@ -1420,7 +1420,7 @@ func prepare4IndexMergeJoin(tc *indexJoinTestCase, outerDS *mockDataSource, inne concurrency := e.ctx.GetSessionVars().IndexLookupJoinConcurrency() joiners := make([]joiner, concurrency) for i := 0; i < concurrency; i++ { - joiners[i] = newJoiner(tc.ctx, 0, false, defaultValues, nil, leftTypes, rightTypes, nil) + joiners[i] = newJoiner(tc.ctx, 0, false, defaultValues, nil, leftTypes, rightTypes, nil, false) } e.joiners = joiners return e, nil @@ -1539,6 +1539,7 @@ func prepareMergeJoinExec(tc *mergeJoinTestCase, joinSchema *expression.Schema, retTypes(leftExec), retTypes(rightExec), tc.childrenUsedSchema, + false, ) mergeJoinExec.innerTable = &mergeJoinTable{ diff --git a/executor/builder.go b/executor/builder.go index bdf055d95cf17..41d34e40bbaac 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -1210,6 +1210,7 @@ func (b *executorBuilder) buildMergeJoin(v *plannercore.PhysicalMergeJoin) Execu retTypes(leftExec), retTypes(rightExec), markChildrenUsedCols(v.Schema(), v.Children()[0].Schema(), v.Children()[1].Schema()), + false, ), isOuterJoin: v.JoinType.IsOuterJoin(), desc: v.Desc, @@ -1295,12 +1296,12 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo if v.UseOuterToBuild { // update the buildSideEstCount due to changing the build side if v.InnerChildIdx == 1 { - e.buildSideExec, e.buildKeys = leftExec, v.LeftJoinKeys - e.probeSideExec, e.probeKeys = rightExec, v.RightJoinKeys + e.buildSideExec, e.buildKeys, e.buildNAKeys = leftExec, v.LeftJoinKeys, v.LeftNAJoinKeys + e.probeSideExec, e.probeKeys, e.probeNAKeys = rightExec, v.RightJoinKeys, v.RightNAJoinKeys e.outerFilter = v.LeftConditions } else { - e.buildSideExec, e.buildKeys = rightExec, v.RightJoinKeys - e.probeSideExec, e.probeKeys = leftExec, v.LeftJoinKeys + e.buildSideExec, e.buildKeys, e.buildNAKeys = rightExec, v.RightJoinKeys, v.RightNAJoinKeys + e.probeSideExec, e.probeKeys, e.probeNAKeys = leftExec, v.LeftJoinKeys, v.LeftNAJoinKeys e.outerFilter = v.RightConditions leftIsBuildSide = false } @@ -1309,12 +1310,12 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo } } else { if v.InnerChildIdx == 0 { - e.buildSideExec, e.buildKeys = leftExec, v.LeftJoinKeys - e.probeSideExec, e.probeKeys = rightExec, v.RightJoinKeys + e.buildSideExec, e.buildKeys, e.buildNAKeys = leftExec, v.LeftJoinKeys, v.LeftNAJoinKeys + e.probeSideExec, e.probeKeys, e.probeNAKeys = rightExec, v.RightJoinKeys, v.RightNAJoinKeys e.outerFilter = v.RightConditions } else { - e.buildSideExec, e.buildKeys = rightExec, v.RightJoinKeys - e.probeSideExec, e.probeKeys = leftExec, v.LeftJoinKeys + e.buildSideExec, e.buildKeys, e.buildNAKeys = rightExec, v.RightJoinKeys, v.RightNAJoinKeys + e.probeSideExec, e.probeKeys, e.probeNAKeys = leftExec, v.LeftJoinKeys, v.LeftNAJoinKeys e.outerFilter = v.LeftConditions leftIsBuildSide = false } @@ -1322,12 +1323,13 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo defaultValues = make([]types.Datum, e.buildSideExec.Schema().Len()) } } + isNAJoin := len(v.LeftNAJoinKeys) > 0 e.buildSideEstCount = b.buildSideEstCount(v) childrenUsedSchema := markChildrenUsedCols(v.Schema(), v.Children()[0].Schema(), v.Children()[1].Schema()) e.joiners = make([]joiner, e.concurrency) for i := uint(0); i < e.concurrency; i++ { e.joiners[i] = newJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, defaultValues, - v.OtherConditions, lhsTypes, rhsTypes, childrenUsedSchema) + v.OtherConditions, lhsTypes, rhsTypes, childrenUsedSchema, isNAJoin) } executorCountHashJoinExec.Inc() @@ -1336,15 +1338,26 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo // For example, the condition `enum = int and enum = string`, we should use ETInt to hash the first column, // and use ETString to hash the second column, although they may be the same column. leftExecTypes, rightExecTypes := retTypes(leftExec), retTypes(rightExec) - leftTypes, rightTypes := make([]*types.FieldType, 0, len(v.LeftJoinKeys)), make([]*types.FieldType, 0, len(v.RightJoinKeys)) + leftTypes, rightTypes := make([]*types.FieldType, 0, len(v.LeftJoinKeys)+len(v.LeftNAJoinKeys)), make([]*types.FieldType, 0, len(v.RightJoinKeys)+len(v.RightNAJoinKeys)) + // set left types and right types for joiner. for i, col := range v.LeftJoinKeys { leftTypes = append(leftTypes, leftExecTypes[col.Index].Clone()) leftTypes[i].SetFlag(col.RetType.GetFlag()) } + offset := len(v.LeftJoinKeys) + for i, col := range v.LeftNAJoinKeys { + leftTypes = append(leftTypes, leftExecTypes[col.Index].Clone()) + leftTypes[i+offset].SetFlag(col.RetType.GetFlag()) + } for i, col := range v.RightJoinKeys { rightTypes = append(rightTypes, rightExecTypes[col.Index].Clone()) rightTypes[i].SetFlag(col.RetType.GetFlag()) } + offset = len(v.RightJoinKeys) + for i, col := range v.RightNAJoinKeys { + rightTypes = append(rightTypes, rightExecTypes[col.Index].Clone()) + rightTypes[i+offset].SetFlag(col.RetType.GetFlag()) + } // consider collations for i := range v.EqualConditions { @@ -1354,6 +1367,14 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo rightTypes[i].SetCharset(chs) rightTypes[i].SetCollate(coll) } + offset = len(v.EqualConditions) + for i := range v.NAEqualConditions { + chs, coll := v.NAEqualConditions[i].CharsetAndCollation() + leftTypes[i+offset].SetCharset(chs) + leftTypes[i+offset].SetCollate(coll) + rightTypes[i+offset].SetCharset(chs) + rightTypes[i+offset].SetCollate(coll) + } if leftIsBuildSide { e.buildTypes, e.probeTypes = leftTypes, rightTypes } else { @@ -1900,7 +1921,10 @@ func (b *executorBuilder) buildApply(v *plannercore.PhysicalApply) Executor { if b.err != nil { return nil } - otherConditions := append(expression.ScalarFuncs2Exprs(v.EqualConditions), v.OtherConditions...) + // test is in the explain/naaj.test#part5. + // although we prepared the NAEqualConditions, but for Apply mode, we still need move it to other conditions like eq condition did here. + otherConditions := append(expression.ScalarFuncs2Exprs(v.EqualConditions), expression.ScalarFuncs2Exprs(v.NAEqualConditions)...) + otherConditions = append(otherConditions, v.OtherConditions...) defaultValues := v.DefaultValues if defaultValues == nil { defaultValues = make([]types.Datum, v.Children()[v.InnerChildIdx].Schema().Len()) @@ -1912,7 +1936,7 @@ func (b *executorBuilder) buildApply(v *plannercore.PhysicalApply) Executor { outerFilter, innerFilter = v.RightConditions, v.LeftConditions } tupleJoiner := newJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, - defaultValues, otherConditions, retTypes(leftChild), retTypes(rightChild), nil) + defaultValues, otherConditions, retTypes(leftChild), retTypes(rightChild), nil, false) serialExec := &NestedLoopApplyExec{ baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ID(), outerExec, innerExec), innerExec: innerExec, @@ -1949,7 +1973,7 @@ func (b *executorBuilder) buildApply(v *plannercore.PhysicalApply) Executor { corCols = append(corCols, corCol) innerFilters = append(innerFilters, innerFilter.Clone()) joiners = append(joiners, newJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, - defaultValues, otherConditions, retTypes(leftChild), retTypes(rightChild), nil)) + defaultValues, otherConditions, retTypes(leftChild), retTypes(rightChild), nil, false)) } allExecs := append([]Executor{outerExec}, innerExecs...) @@ -2936,7 +2960,7 @@ func (b *executorBuilder) buildIndexLookUpJoin(v *plannercore.PhysicalIndexJoin) finished: &atomic.Value{}, } childrenUsedSchema := markChildrenUsedCols(v.Schema(), v.Children()[0].Schema(), v.Children()[1].Schema()) - e.joiner = newJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, defaultValues, v.OtherConditions, leftTypes, rightTypes, childrenUsedSchema) + e.joiner = newJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, defaultValues, v.OtherConditions, leftTypes, rightTypes, childrenUsedSchema, false) outerKeyCols := make([]int, len(v.OuterJoinKeys)) for i := 0; i < len(v.OuterJoinKeys); i++ { outerKeyCols[i] = v.OuterJoinKeys[i].Index @@ -3060,7 +3084,7 @@ func (b *executorBuilder) buildIndexLookUpMergeJoin(v *plannercore.PhysicalIndex childrenUsedSchema := markChildrenUsedCols(v.Schema(), v.Children()[0].Schema(), v.Children()[1].Schema()) joiners := make([]joiner, e.ctx.GetSessionVars().IndexLookupJoinConcurrency()) for i := 0; i < len(joiners); i++ { - joiners[i] = newJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, defaultValues, v.OtherConditions, leftTypes, rightTypes, childrenUsedSchema) + joiners[i] = newJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, defaultValues, v.OtherConditions, leftTypes, rightTypes, childrenUsedSchema, false) } e.joiners = joiners return e diff --git a/executor/hash_table.go b/executor/hash_table.go index 8b39573d3c5b4..d2b294f52d9ad 100644 --- a/executor/hash_table.go +++ b/executor/hash_table.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/bitmap" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/disk" @@ -37,11 +38,14 @@ import ( // hashContext keeps the needed hash context of a db table in hash join. type hashContext struct { // allTypes one-to-one correspondence with keyColIdx - allTypes []*types.FieldType - keyColIdx []int - buf []byte - hashVals []hash.Hash64 - hasNull []bool + allTypes []*types.FieldType + keyColIdx []int + naKeyColIdx []int + buf []byte + hashVals []hash.Hash64 + hasNull []bool + naHasNull []bool + naColNullBitMap []*bitmap.ConcurrentBitmap } func (hc *hashContext) initHash(rows int) { @@ -61,6 +65,21 @@ func (hc *hashContext) initHash(rows int) { hc.hashVals[i].Reset() } } + if len(hc.naKeyColIdx) > 0 { + // isNAAJ + if len(hc.naColNullBitMap) < rows { + hc.naHasNull = make([]bool, rows) + hc.naColNullBitMap = make([]*bitmap.ConcurrentBitmap, rows) + for i := 0; i < rows; i++ { + hc.naColNullBitMap[i] = bitmap.NewConcurrentBitmap(len(hc.naKeyColIdx)) + } + } else { + for i := 0; i < rows; i++ { + hc.naHasNull[i] = false + hc.naColNullBitMap[i].Reset(len(hc.naKeyColIdx)) + } + } + } } type hashStatistic struct { @@ -83,6 +102,9 @@ type hashRowContainer struct { // hashTable stores the map of hashKey and RowPtr hashTable baseHashTable + // hashNANullBucket stores the rows with any null value in NAAJ join key columns. + // After build process, NANUllBucket is read only here for multi probe worker. + hashNANullBucket []*naEntry rowContainer *chunk.RowContainer memTracker *memory.Tracker @@ -109,6 +131,8 @@ func newHashRowContainer(sCtx sessionctx.Context, estCount int, hCtx *hashContex func (c *hashRowContainer) ShallowCopy() *hashRowContainer { newHRC := *c newHRC.rowContainer = c.rowContainer.ShallowCopyWithNewMutex() + // multi hashRowContainer ref to one single NA-NULL bucket slice. + // newHRC.hashNANullBucket = c.hashNANullBucket return &newHRC } @@ -120,6 +144,68 @@ func (c *hashRowContainer) GetMatchedRows(probeKey uint64, probeRow chunk.Row, h return matchedRows, err } +func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRow chunk.Row, + probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildRowPos, needCheckProbeRowPos []int) ([]chunk.Row, error) { + // for NAAJ probe row with null, we should match them with all build rows. + var ( + ok bool + err error + innerPtrs []chunk.RowPtr + ) + c.hashTable.Iter( + func(_ uint64, e *entry) { + entryAddr := e + for entryAddr != nil { + innerPtrs = append(innerPtrs, entryAddr.ptr) + entryAddr = entryAddr.next + } + }) + matched = matched[:0] + if len(innerPtrs) == 0 { + return matched, nil + } + // all built bucket rows come from hash table, their bitmap are all nil (doesn't contain any null). so + // we could only use the probe null bits to filter valid rows. + if probeKeyNullBits != nil && len(probeHCtx.naKeyColIdx) > 1 { + // if len(probeHCtx.naKeyColIdx)=1 + // that means the NA-Join probe key is directly a (null) <-> (fetch all buckets), nothing to do. + // else like + // (null, 1, 2), we should use the not-null probe bit to filter rows. Only fetch rows like + // ( ? , 1, 2), that exactly with value as 1 and 2 in the second and third join key column. + needCheckProbeRowPos = needCheckProbeRowPos[:0] + needCheckBuildRowPos = needCheckBuildRowPos[:0] + keyColLen := len(c.hCtx.naKeyColIdx) + for i := 0; i < keyColLen; i++ { + // since all bucket is from hash table (Not Null), so the buildSideNullBits check is eliminated. + if probeKeyNullBits.UnsafeIsSet(i) { + continue + } + needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i]) + needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i]) + } + } + var mayMatchedRow chunk.Row + for _, ptr := range innerPtrs { + mayMatchedRow, c.chkBuf, err = c.rowContainer.GetRowAndAppendToChunk(ptr, c.chkBuf) + if err != nil { + return nil, err + } + if probeKeyNullBits != nil && len(probeHCtx.naKeyColIdx) > 1 { + // check the idxs-th value of the join columns. + ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos) + if err != nil { + return nil, err + } + if !ok { + continue + } + // once ok. just append the (maybe) valid build row for latter other conditions check if any. + } + matched = append(matched, mayMatchedRow) + } + return matched, nil +} + // GetMatchedRowsAndPtrs get matched rows and Ptrs from probeRow. It can be called // in multiple goroutines while each goroutine should keep its own // h and buf. @@ -154,6 +240,84 @@ func (c *hashRowContainer) GetMatchedRowsAndPtrs(probeKey uint64, probeRow chunk return matched, matchedPtrs, err } +func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRow chunk.Row, + probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildRowPos, needCheckProbeRowPos []int) ([]chunk.Row, error) { + var ( + ok bool + err error + mayMatchedRow chunk.Row + ) + matched = matched[:0] + for _, nullEntry := range c.hashNANullBucket { + mayMatchedRow, c.chkBuf, err = c.rowContainer.GetRowAndAppendToChunk(nullEntry.ptr, c.chkBuf) + if err != nil { + return nil, err + } + // since null bucket is a unified bucket. cases like below: + // case1: left side (probe side) has null + // left side key <1,null>, actually we can fetch all bucket <1, ?> and filter 1 at the first join key, once + // got a valid right row after other condition, then we can just return. + // case2: left side (probe side) don't have null + // left side key <1, 2>, actually we should fetch <1,null>, , from the null bucket because + // case like <3,null> is obviously not matched with the probe key. + needCheckProbeRowPos = needCheckProbeRowPos[:0] + needCheckBuildRowPos = needCheckBuildRowPos[:0] + keyColLen := len(c.hCtx.naKeyColIdx) + if probeKeyNullBits != nil { + // when the probeKeyNullBits is not nil, it means the probe key has null values, where we should distinguish + // whether is empty set or not. In other words, we should fetch at least a valid from the null bucket here. + // for values at the same index of the join key in which they are both not null, the values should be exactly the same. + // + // step: probeKeyNullBits & buildKeyNullBits, for those bits with 0, we should check if both values are the same. + // we can just use the UnsafeIsSet here, because insert action of the build side has all finished. + // + // 1 0 1 0 means left join key : null ? null ? + // 1 0 0 0 means right join key : null ? ? ? + // --------------------------------------------- + // left & right: 1 0 1 0: just do the explicit column value check for whose bit is 0. (means no null from both side) + for i := 0; i < keyColLen; i++ { + if probeKeyNullBits.UnsafeIsSet(i) || nullEntry.nullBitMap.UnsafeIsSet(i) { + continue + } + needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i]) + needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i]) + } + // check the idxs-th value of the join columns. + ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos) + if err != nil { + return nil, err + } + if !ok { + continue + } + } else { + // when the probeKeyNullBits is nil, it means the probe key is not null. But in the process of matching the null bucket, + // we still need to do the non-null (explicit) value check. + // + // eg: the probe key is <1,2>, we only get <2, null> in the null bucket, even we can take the null as a wildcard symbol, + // the first value of this two tuple is obviously not a match. So we need filter it here. + for i := 0; i < keyColLen; i++ { + if nullEntry.nullBitMap.UnsafeIsSet(i) { + continue + } + needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i]) + needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i]) + } + // check the idxs-th value of the join columns. + ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos) + if err != nil { + return nil, err + } + if !ok { + continue + } + } + // once ok. just append the (maybe) valid build row for latter other conditions check if any. + matched = append(matched, mayMatchedRow) + } + return matched, err +} + // matchJoinKey checks if join keys of buildRow and probeRow are logically equal. func (c *hashRowContainer) matchJoinKey(buildRow, probeRow chunk.Row, probeHCtx *hashContext) (ok bool, err error) { return codec.EqualChunkRow(c.sc, @@ -190,6 +354,8 @@ func (c *hashRowContainer) PutChunkSelected(chk *chunk.Chunk, selected, ignoreNu c.hCtx.initHash(numRows) hCtx := c.hCtx + // By now, the combination of 1 and 2 can't take a run at same time. + // 1: write the row data of join key to hashVals. (normal EQ key should ignore the null values.) null-EQ for Except statement is an exception. for keyIdx, colIdx := range c.hCtx.keyColIdx { ignoreNull := len(ignoreNulls) > keyIdx && ignoreNulls[keyIdx] err := codec.HashChunkSelected(c.sc, hCtx.hashVals, chk, hCtx.allTypes[keyIdx], colIdx, hCtx.buf, hCtx.hasNull, selected, ignoreNull) @@ -197,13 +363,52 @@ func (c *hashRowContainer) PutChunkSelected(chk *chunk.Chunk, selected, ignoreNu return errors.Trace(err) } } + // 2: write the row data of NA join key to hashVals. (NA EQ key should collect all rows including null value as one bucket.) + isNAAJ := len(c.hCtx.naKeyColIdx) > 0 + hasNullMark := make([]bool, len(hCtx.hasNull)) + for keyIdx, colIdx := range c.hCtx.naKeyColIdx { + // NAAJ won't ignore any null values, but collect them as one hash bucket. + err := codec.HashChunkSelected(c.sc, hCtx.hashVals, chk, hCtx.allTypes[keyIdx], colIdx, hCtx.buf, hCtx.hasNull, selected, false) + if err != nil { + return errors.Trace(err) + } + // todo: we can collect the bitmap in codec.HashChunkSelected to avoid loop here, but the params modification is quite big. + // after fetch one NA column, collect the null value to null bitmap for every row. (use hasNull flag to accelerate) + // eg: if a NA Join cols is (a, b, c), for every build row here we maintained a 3-bit map to mark which column are null for them. + for rowIdx := 0; rowIdx < numRows; rowIdx++ { + if hCtx.hasNull[rowIdx] { + hCtx.naColNullBitMap[rowIdx].UnsafeSet(keyIdx) + // clean and try fetch next NA join col. + hCtx.hasNull[rowIdx] = false + // just a mark variable for whether there is a null in at least one NA join column. + hasNullMark[rowIdx] = true + } + } + } for i := 0; i < numRows; i++ { - if (selected != nil && !selected[i]) || c.hCtx.hasNull[i] { - continue + if isNAAJ { + if selected != nil && !selected[i] { + continue + } + if hasNullMark[i] { + // collect the null rows to slice. + rowPtr := chunk.RowPtr{ChkIdx: chkIdx, RowIdx: uint32(i)} + // do not directly ref the null bits map here, because the bit map will be reset and reused in next batch of chunk data. + c.hashNANullBucket = append(c.hashNANullBucket, &naEntry{rowPtr, c.hCtx.naColNullBitMap[i].Clone()}) + } else { + // insert the not-null rows to hash table. + key := c.hCtx.hashVals[i].Sum64() + rowPtr := chunk.RowPtr{ChkIdx: chkIdx, RowIdx: uint32(i)} + c.hashTable.Put(key, rowPtr) + } + } else { + if (selected != nil && !selected[i]) || c.hCtx.hasNull[i] { + continue + } + key := c.hCtx.hashVals[i].Sum64() + rowPtr := chunk.RowPtr{ChkIdx: chkIdx, RowIdx: uint32(i)} + c.hashTable.Put(key, rowPtr) } - key := c.hCtx.hashVals[i].Sum64() - rowPtr := chunk.RowPtr{ChkIdx: chkIdx, RowIdx: uint32(i)} - c.hashTable.Put(key, rowPtr) } c.GetMemTracker().Consume(c.hashTable.GetAndCleanMemoryDelta()) return nil @@ -261,6 +466,11 @@ type entry struct { next *entry } +type naEntry struct { + ptr chunk.RowPtr + nullBitMap *bitmap.ConcurrentBitmap +} + type entryStore struct { slices [][]entry cursor int @@ -299,6 +509,7 @@ type baseHashTable interface { // GetAndCleanMemoryDelta gets and cleans the memDelta of the baseHashTable. Memory delta will be cleared after each fetch. // It indicates the memory delta of the baseHashTable since the last calling GetAndCleanMemoryDelta(). GetAndCleanMemoryDelta() int64 + Iter(func(uint64, *entry)) } // TODO (fangzhuhe) remove unsafeHashTable later if it not used anymore @@ -359,6 +570,13 @@ func (ht *unsafeHashTable) GetAndCleanMemoryDelta() int64 { return memDelta } +func (ht *unsafeHashTable) Iter(traverse func(key uint64, e *entry)) { + for k := range ht.hashMap { + entryAddr := ht.hashMap[k] + traverse(k, entryAddr) + } +} + // concurrentMapHashTable is a concurrent hash table built on concurrentMap type concurrentMapHashTable struct { hashMap concurrentMap @@ -404,6 +622,11 @@ func (ht *concurrentMapHashTable) Get(hashKey uint64) (rowPtrs []chunk.RowPtr) { return } +// Iter gets the every value of the hash table. +func (ht *concurrentMapHashTable) Iter(traverse func(key uint64, e *entry)) { + ht.hashMap.IterCb(traverse) +} + // GetAndCleanMemoryDelta gets and cleans the memDelta of the concurrentMapHashTable. Memory delta will be cleared after each fetch. func (ht *concurrentMapHashTable) GetAndCleanMemoryDelta() int64 { var memDelta int64 diff --git a/executor/join.go b/executor/join.go index d01b34f37a3e6..97176a8deab37 100644 --- a/executor/join.go +++ b/executor/join.go @@ -55,7 +55,9 @@ type HashJoinExec struct { buildSideEstCount float64 outerFilter expression.CNFExprs probeKeys []*expression.Column + probeNAKeys []*expression.Column buildKeys []*expression.Column + buildNAKeys []*expression.Column isNullEQ []bool probeTypes []*types.FieldType buildTypes []*types.FieldType @@ -98,6 +100,10 @@ type HashJoinExec struct { // We pre-alloc and reuse the Rows and RowPtrs for each probe goroutine, to avoid allocation frequently buildSideRows [][]chunk.Row buildSideRowPtrs [][]chunk.RowPtr + + // for every naaj probe worker, pre-allocate the int slice for store the join column index to check. + needCheckBuildRowPos [][]int + needCheckProbeRowPos [][]int } // probeChkResource stores the result of the join probe side fetch worker, @@ -154,6 +160,8 @@ func (e *HashJoinExec) Close() error { e.outerMatchedStatus = e.outerMatchedStatus[:0] e.buildSideRows = nil e.buildSideRowPtrs = nil + e.needCheckBuildRowPos = nil + e.needCheckProbeRowPos = nil if e.stats != nil && e.rowContainer != nil { e.stats.hashStat = *e.rowContainer.stat } @@ -241,6 +249,11 @@ func (e *HashJoinExec) fetchProbeSideChunks(ctx context.Context) { } else if emptyBuild { return } + // after building is finished. the hash null bucket slice is allocated and determined. + // copy it for multi probe worker. + for i := range e.rowContainerForProbe { + e.rowContainerForProbe[i].hashNANullBucket = e.rowContainer.hashNANullBucket + } hasWaitedForBuild = true } @@ -336,6 +349,8 @@ func (e *HashJoinExec) initializeForProbe() { e.buildSideRows = make([][]chunk.Row, e.concurrency) e.buildSideRowPtrs = make([][]chunk.RowPtr, e.concurrency) + e.needCheckBuildRowPos = make([][]int, e.concurrency) + e.needCheckProbeRowPos = make([][]int, e.concurrency) } func (e *HashJoinExec) fetchAndProbeHashTable(ctx context.Context) { @@ -347,16 +362,20 @@ func (e *HashJoinExec) fetchAndProbeHashTable(ctx context.Context) { }, e.handleProbeSideFetcherPanic) probeKeyColIdx := make([]int, len(e.probeKeys)) + probeNAKeColIdx := make([]int, len(e.probeNAKeys)) for i := range e.probeKeys { probeKeyColIdx[i] = e.probeKeys[i].Index } + for i := range e.probeNAKeys { + probeNAKeColIdx[i] = e.probeNAKeys[i].Index + } for i := uint(0); i < e.concurrency; i++ { e.joinWorkerWaitGroup.Add(1) workID := i go util.WithRecovery(func() { defer trace.StartRegion(ctx, "HashJoinWorker").End() - e.runJoinWorker(workID, probeKeyColIdx) + e.runJoinWorker(workID, probeKeyColIdx, probeNAKeColIdx) }, e.handleJoinWorkerPanic) } go util.WithRecovery(e.waitJoinWorkersAndCloseResultChan, nil) @@ -429,7 +448,7 @@ func (e *HashJoinExec) waitJoinWorkersAndCloseResultChan() { close(e.joinResultCh) } -func (e *HashJoinExec) runJoinWorker(workerID uint, probeKeyColIdx []int) { +func (e *HashJoinExec) runJoinWorker(workerID uint, probeKeyColIdx, probeNAKeyColIdx []int) { probeTime := int64(0) if e.stats != nil { start := time.Now() @@ -455,8 +474,9 @@ func (e *HashJoinExec) runJoinWorker(workerID uint, probeKeyColIdx []int) { dest: e.probeResultChs[workerID], } hCtx := &hashContext{ - allTypes: e.probeTypes, - keyColIdx: probeKeyColIdx, + allTypes: e.probeTypes, + keyColIdx: probeKeyColIdx, + naKeyColIdx: probeNAKeyColIdx, } for ok := true; ok; { if e.finished.Load().(bool) { @@ -533,6 +553,351 @@ func (e *HashJoinExec) joinMatchedProbeSideRow2ChunkForOuterHashJoin(workerID ui return true, joinResult } +// joinNAALOSJMatchProbeSideRow2Chunk implement the matching logic for NA-AntiLeftOuterSemiJoin +func (e *HashJoinExec) joinNAALOSJMatchProbeSideRow2Chunk(workerID uint, probeKey uint64, probeKeyNullBits *bitmap.ConcurrentBitmap, probeSideRow chunk.Row, hCtx *hashContext, + rowContainer *hashRowContainer, joinResult *hashjoinWorkerResult) (bool, *hashjoinWorkerResult) { + var ( + err error + ok bool + ) + if probeKeyNullBits == nil { + // step1: match the same key bucket first. + // because AntiLeftOuterSemiJoin cares about the scalar value. If we both have a match from null + // bucket and same key bucket, we should return the result as from same-key bucket + // rather than from null bucket. + e.buildSideRows[workerID], err = rowContainer.GetMatchedRows(probeKey, probeSideRow, hCtx, e.buildSideRows[workerID]) + buildSideRows := e.buildSideRows[workerID] + if err != nil { + joinResult.err = err + return false, joinResult + } + if len(buildSideRows) != 0 { + iter1 := chunk.NewIterator4Slice(buildSideRows) + defer chunk.FreeIterator(iter1) + for iter1.Begin(); iter1.Current() != iter1.End(); { + matched, _, err := e.joiners[workerID].tryToMatchInners(probeSideRow, iter1, joinResult.chk, LeftNotNullRightNotNull) + if err != nil { + joinResult.err = err + return false, joinResult + } + // here matched means: there is a valid same-key bucket row from right side. + // as said in the comment, once we meet a same key (NOT IN semantic) in CNF, we can determine the result as . + if matched { + return true, joinResult + } + if joinResult.chk.IsFull() { + e.joinResultCh <- joinResult + ok, joinResult = e.getNewJoinResult(workerID) + if !ok { + return false, joinResult + } + } + } + } + // step2: match the null bucket secondly. + e.buildSideRows[workerID], err = rowContainer.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, e.buildSideRows[workerID], e.needCheckBuildRowPos[workerID], e.needCheckProbeRowPos[workerID]) + buildSideRows = e.buildSideRows[workerID] + if err != nil { + joinResult.err = err + return false, joinResult + } + if len(buildSideRows) == 0 { + // when reach here, it means we couldn't find a valid same key match from same-key bucket yet + // and the null bucket is empty. so the result should be . + e.joiners[workerID].onMissMatch(false, probeSideRow, joinResult.chk) + return true, joinResult + } + iter2 := chunk.NewIterator4Slice(buildSideRows) + defer chunk.FreeIterator(iter2) + for iter2.Begin(); iter2.Current() != iter2.End(); { + matched, _, err := e.joiners[workerID].tryToMatchInners(probeSideRow, iter2, joinResult.chk, LeftNotNullRightHasNull) + if err != nil { + joinResult.err = err + return false, joinResult + } + // here matched means: there is a valid null bucket row from right side. + // as said in the comment, once we meet a null in CNF, we can determine the result as . + if matched { + return true, joinResult + } + if joinResult.chk.IsFull() { + e.joinResultCh <- joinResult + ok, joinResult = e.getNewJoinResult(workerID) + if !ok { + return false, joinResult + } + } + } + // step3: if we couldn't return it quickly in null bucket and same key bucket, here means two cases: + // case1: x NOT IN (empty set): if other key bucket don't have the valid rows yet. + // case2: x NOT IN (l,m,n...): if other key bucket do have the valid rows. + // both cases mean the result should be + e.joiners[workerID].onMissMatch(false, probeSideRow, joinResult.chk) + return true, joinResult + } + // when left side has null values, all we want is to find a valid build side rows (past other condition) + // so we can return it as soon as possible. here means two cases: + // case1: NOT IN (empty set): ----------------------> result is . + // case2: NOT IN (at least a valid inner row) ------------------> result is . + // Step1: match null bucket (assumption that null bucket is quite smaller than all hash table bucket rows) + e.buildSideRows[workerID], err = rowContainer.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, e.buildSideRows[workerID], e.needCheckBuildRowPos[workerID], e.needCheckProbeRowPos[workerID]) + buildSideRows := e.buildSideRows[workerID] + if err != nil { + joinResult.err = err + return false, joinResult + } + if len(buildSideRows) != 0 { + iter1 := chunk.NewIterator4Slice(buildSideRows) + defer chunk.FreeIterator(iter1) + for iter1.Begin(); iter1.Current() != iter1.End(); { + matched, _, err := e.joiners[workerID].tryToMatchInners(probeSideRow, iter1, joinResult.chk, LeftHasNullRightHasNull) + if err != nil { + joinResult.err = err + return false, joinResult + } + // here matched means: there is a valid null bucket row from right side. (not empty) + // as said in the comment, once we found at least a valid row, we can determine the result as . + if matched { + return true, joinResult + } + if joinResult.chk.IsFull() { + e.joinResultCh <- joinResult + ok, joinResult = e.getNewJoinResult(workerID) + if !ok { + return false, joinResult + } + } + } + } + // Step2: match all hash table bucket build rows (use probeKeyNullBits to filter if any). + e.buildSideRows[workerID], err = rowContainer.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, e.buildSideRows[workerID], e.needCheckBuildRowPos[workerID], e.needCheckProbeRowPos[workerID]) + buildSideRows = e.buildSideRows[workerID] + if err != nil { + joinResult.err = err + return false, joinResult + } + if len(buildSideRows) == 0 { + // when reach here, it means we couldn't return it quickly in null bucket, and same-bucket is empty, + // which means x NOT IN (empty set) or x NOT IN (l,m,n), the result should be + e.joiners[workerID].onMissMatch(false, probeSideRow, joinResult.chk) + return true, joinResult + } + iter2 := chunk.NewIterator4Slice(buildSideRows) + defer chunk.FreeIterator(iter2) + for iter2.Begin(); iter2.Current() != iter2.End(); { + matched, _, err := e.joiners[workerID].tryToMatchInners(probeSideRow, iter2, joinResult.chk, LeftHasNullRightNotNull) + if err != nil { + joinResult.err = err + return false, joinResult + } + // here matched means: there is a valid same key bucket row from right side. (not empty) + // as said in the comment, once we found at least a valid row, we can determine the result as . + if matched { + return true, joinResult + } + if joinResult.chk.IsFull() { + e.joinResultCh <- joinResult + ok, joinResult = e.getNewJoinResult(workerID) + if !ok { + return false, joinResult + } + } + } + // step3: if we couldn't return it quickly in null bucket and all hash bucket, here means only one cases: + // case1: NOT IN (empty set): + // empty set comes from no rows from all bucket can pass other condition. the result should be + e.joiners[workerID].onMissMatch(false, probeSideRow, joinResult.chk) + return true, joinResult +} + +// joinNAASJMatchProbeSideRow2Chunk implement the matching logic for NA-AntiSemiJoin +func (e *HashJoinExec) joinNAASJMatchProbeSideRow2Chunk(workerID uint, probeKey uint64, probeKeyNullBits *bitmap.ConcurrentBitmap, probeSideRow chunk.Row, hCtx *hashContext, + rowContainer *hashRowContainer, joinResult *hashjoinWorkerResult) (bool, *hashjoinWorkerResult) { + var ( + err error + ok bool + ) + if probeKeyNullBits == nil { + // step1: match null bucket first. + // need fetch the "valid" rows every time. (nullBits map check is necessary) + e.buildSideRows[workerID], err = rowContainer.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, e.buildSideRows[workerID], e.needCheckBuildRowPos[workerID], e.needCheckProbeRowPos[workerID]) + buildSideRows := e.buildSideRows[workerID] + if err != nil { + joinResult.err = err + return false, joinResult + } + if len(buildSideRows) != 0 { + iter1 := chunk.NewIterator4Slice(buildSideRows) + defer chunk.FreeIterator(iter1) + for iter1.Begin(); iter1.Current() != iter1.End(); { + matched, _, err := e.joiners[workerID].tryToMatchInners(probeSideRow, iter1, joinResult.chk) + if err != nil { + joinResult.err = err + return false, joinResult + } + // here matched means: there is a valid null bucket row from right side. + // as said in the comment, once we meet a rhs null in CNF, we can determine the reject of lhs row. + if matched { + return true, joinResult + } + if joinResult.chk.IsFull() { + e.joinResultCh <- joinResult + ok, joinResult = e.getNewJoinResult(workerID) + if !ok { + return false, joinResult + } + } + } + } + // step2: then same key bucket. + e.buildSideRows[workerID], err = rowContainer.GetMatchedRows(probeKey, probeSideRow, hCtx, e.buildSideRows[workerID]) + buildSideRows = e.buildSideRows[workerID] + if err != nil { + joinResult.err = err + return false, joinResult + } + if len(buildSideRows) == 0 { + // when reach here, it means we couldn't return it quickly in null bucket, and same-bucket is empty, + // which means x NOT IN (empty set), accept the rhs row. + e.joiners[workerID].onMissMatch(false, probeSideRow, joinResult.chk) + return true, joinResult + } + iter2 := chunk.NewIterator4Slice(buildSideRows) + defer chunk.FreeIterator(iter2) + for iter2.Begin(); iter2.Current() != iter2.End(); { + matched, _, err := e.joiners[workerID].tryToMatchInners(probeSideRow, iter2, joinResult.chk) + if err != nil { + joinResult.err = err + return false, joinResult + } + // here matched means: there is a valid same key bucket row from right side. + // as said in the comment, once we meet a false in CNF, we can determine the reject of lhs row. + if matched { + return true, joinResult + } + if joinResult.chk.IsFull() { + e.joinResultCh <- joinResult + ok, joinResult = e.getNewJoinResult(workerID) + if !ok { + return false, joinResult + } + } + } + // step3: if we couldn't return it quickly in null bucket and same key bucket, here means two cases: + // case1: x NOT IN (empty set): if other key bucket don't have the valid rows yet. + // case2: x NOT IN (l,m,n...): if other key bucket do have the valid rows. + // both cases should accept the rhs row. + e.joiners[workerID].onMissMatch(false, probeSideRow, joinResult.chk) + return true, joinResult + } + // when left side has null values, all we want is to find a valid build side rows (passed from other condition) + // so we can return it as soon as possible. here means two cases: + // case1: NOT IN (empty set): ----------------------> accept rhs row. + // case2: NOT IN (at least a valid inner row) ------------------> unknown result, refuse rhs row. + // Step1: match null bucket (assumption that null bucket is quite smaller than all hash table bucket rows) + e.buildSideRows[workerID], err = rowContainer.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, e.buildSideRows[workerID], e.needCheckBuildRowPos[workerID], e.needCheckProbeRowPos[workerID]) + buildSideRows := e.buildSideRows[workerID] + if err != nil { + joinResult.err = err + return false, joinResult + } + if len(buildSideRows) != 0 { + iter1 := chunk.NewIterator4Slice(buildSideRows) + defer chunk.FreeIterator(iter1) + for iter1.Begin(); iter1.Current() != iter1.End(); { + matched, _, err := e.joiners[workerID].tryToMatchInners(probeSideRow, iter1, joinResult.chk) + if err != nil { + joinResult.err = err + return false, joinResult + } + // here matched means: there is a valid null bucket row from right side. (not empty) + // as said in the comment, once we found at least a valid row, we can determine the reject of lhs row. + if matched { + return true, joinResult + } + if joinResult.chk.IsFull() { + e.joinResultCh <- joinResult + ok, joinResult = e.getNewJoinResult(workerID) + if !ok { + return false, joinResult + } + } + } + } + // Step2: match all hash table bucket build rows. + e.buildSideRows[workerID], err = rowContainer.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, e.buildSideRows[workerID], e.needCheckBuildRowPos[workerID], e.needCheckProbeRowPos[workerID]) + buildSideRows = e.buildSideRows[workerID] + if err != nil { + joinResult.err = err + return false, joinResult + } + if len(buildSideRows) == 0 { + // when reach here, it means we couldn't return it quickly in null bucket, and same-bucket is empty, + // which means NOT IN (empty set) or NOT IN (no valid rows) accept the rhs row. + e.joiners[workerID].onMissMatch(false, probeSideRow, joinResult.chk) + return true, joinResult + } + iter2 := chunk.NewIterator4Slice(buildSideRows) + defer chunk.FreeIterator(iter2) + for iter2.Begin(); iter2.Current() != iter2.End(); { + matched, _, err := e.joiners[workerID].tryToMatchInners(probeSideRow, iter2, joinResult.chk) + if err != nil { + joinResult.err = err + return false, joinResult + } + // here matched means: there is a valid key row from right side. (not empty) + // as said in the comment, once we found at least a valid row, we can determine the reject of lhs row. + if matched { + return true, joinResult + } + if joinResult.chk.IsFull() { + e.joinResultCh <- joinResult + ok, joinResult = e.getNewJoinResult(workerID) + if !ok { + return false, joinResult + } + } + } + // step3: if we couldn't return it quickly in null bucket and all hash bucket, here means only one cases: + // case1: NOT IN (empty set): + // empty set comes from no rows from all bucket can pass other condition. we should accept the rhs row. + e.joiners[workerID].onMissMatch(false, probeSideRow, joinResult.chk) + return true, joinResult +} + +// joinNAAJMatchProbeSideRow2Chunk implement the matching priority logic for NA-AntiSemiJoin and NA-AntiLeftOuterSemiJoin +// there are some bucket-matching priority difference between them. +// +// Since NA-AntiSemiJoin don't need to append the scalar value with the left side row, there is a quick matching path. +// 1: lhs row has null: +// lhs row has null can't determine its result in advance, we should judge whether the right valid set is empty +// or not. For semantic like x NOT IN(y set), If y set is empty, the scalar result is 1; Otherwise, the result +// is 0. Since NA-AntiSemiJoin don't care about the scalar value, we just try to find a valid row from right side, +// once we found it then just return the left side row instantly. (same as NA-AntiLeftOuterSemiJoin) +// +// 2: lhs row without null: +// same-key bucket and null-bucket which should be the first to match? For semantic like x NOT IN(y set), once y +// set has a same key x, the scalar value is 0; else if y set has a null key, then the scalar value is null. Both +// of them lead the refuse of the lhs row without any difference. Since NA-AntiSemiJoin don't care about the scalar +// value, we can just match the null bucket first and refuse the lhs row as quickly as possible, because a null of +// yi in the CNF (x NA-EQ yi) can always determine a negative value (refuse lhs row) in advance here. +// +// For NA-AntiLeftOuterSemiJoin, we couldn't match null-bucket first, because once y set has a same key x and null +// key, we should return the result as left side row appended with a scalar value 0 which is from same key matching failure. +func (e *HashJoinExec) joinNAAJMatchProbeSideRow2Chunk(workerID uint, probeKey uint64, probeKeyNullBits *bitmap.ConcurrentBitmap, probeSideRow chunk.Row, hCtx *hashContext, + rowContainer *hashRowContainer, joinResult *hashjoinWorkerResult) (bool, *hashjoinWorkerResult) { + NAAntiSemiJoin := e.joinType == plannercore.AntiSemiJoin && len(e.buildNAKeys) > 0 + NAAntiLeftOuterSemiJoin := e.joinType == plannercore.AntiLeftOuterSemiJoin && len(e.buildNAKeys) > 0 + if NAAntiSemiJoin { + return e.joinNAASJMatchProbeSideRow2Chunk(workerID, probeKey, probeKeyNullBits, probeSideRow, hCtx, rowContainer, joinResult) + } + if NAAntiLeftOuterSemiJoin { + return e.joinNAALOSJMatchProbeSideRow2Chunk(workerID, probeKey, probeKeyNullBits, probeSideRow, hCtx, rowContainer, joinResult) + } + // shouldn't be here, not a valid NAAJ. + return false, joinResult +} + func (e *HashJoinExec) joinMatchedProbeSideRow2Chunk(workerID uint, probeKey uint64, probeSideRow chunk.Row, hCtx *hashContext, rowContainer *hashRowContainer, joinResult *hashjoinWorkerResult) (bool, *hashjoinWorkerResult) { var err error @@ -594,7 +959,10 @@ func (e *HashJoinExec) join2Chunk(workerID uint, probeSideChk *chunk.Chunk, hCtx return false, joinResult } - hCtx.initHash(probeSideChk.NumRows()) + numRows := probeSideChk.NumRows() + hCtx.initHash(numRows) + // By now, path 1 and 2 won't be conducted at the same time. + // 1: write the row data of join key to hashVals. (normal EQ key should ignore the null values.) null-EQ for Except statement is an exception. for keyIdx, i := range hCtx.keyColIdx { ignoreNull := len(e.isNullEQ) > keyIdx && e.isNullEQ[keyIdx] err = codec.HashChunkSelected(rowContainer.sc, hCtx.hashVals, probeSideChk, hCtx.allTypes[keyIdx], i, hCtx.buf, hCtx.hasNull, selected, ignoreNull) @@ -603,6 +971,26 @@ func (e *HashJoinExec) join2Chunk(workerID uint, probeSideChk *chunk.Chunk, hCtx return false, joinResult } } + // 2: write the row data of NA join key to hashVals. (NA EQ key should collect all row including null value, store null value in a special position) + isNAAJ := len(hCtx.naKeyColIdx) > 0 + for keyIdx, i := range hCtx.naKeyColIdx { + // NAAJ won't ignore any null values, but collect them up to probe. + err = codec.HashChunkSelected(rowContainer.sc, hCtx.hashVals, probeSideChk, hCtx.allTypes[keyIdx], i, hCtx.buf, hCtx.hasNull, selected, false) + if err != nil { + joinResult.err = err + return false, joinResult + } + // after fetch one NA column, collect the null value to null bitmap for every row. (use hasNull flag to accelerate) + // eg: if a NA Join cols is (a, b, c), for every build row here we maintained a 3-bit map to mark which column is null for them. + for rowIdx := 0; rowIdx < numRows; rowIdx++ { + if hCtx.hasNull[rowIdx] { + hCtx.naColNullBitMap[rowIdx].UnsafeSet(keyIdx) + // clean and try fetch next NA join col. + hCtx.hasNull[rowIdx] = false + hCtx.naHasNull[rowIdx] = true + } + } + } for i := range selected { killed := atomic.LoadUint32(&e.ctx.GetSessionVars().Killed) == 1 @@ -615,13 +1003,38 @@ func (e *HashJoinExec) join2Chunk(workerID uint, probeSideChk *chunk.Chunk, hCtx joinResult.err = ErrQueryInterrupted return false, joinResult } - if !selected[i] || hCtx.hasNull[i] { // process unmatched probe side rows - e.joiners[workerID].onMissMatch(false, probeSideChk.GetRow(i), joinResult.chk) - } else { // process matched probe side rows - probeKey, probeRow := hCtx.hashVals[i].Sum64(), probeSideChk.GetRow(i) - ok, joinResult = e.joinMatchedProbeSideRow2Chunk(workerID, probeKey, probeRow, hCtx, rowContainer, joinResult) - if !ok { - return false, joinResult + if isNAAJ { + if !selected[i] { + // since this is the case of using inner to build, so for an outer row unselected, we should fill the result when it's outer join. + e.joiners[workerID].onMissMatch(false, probeSideChk.GetRow(i), joinResult.chk) + } + if hCtx.naHasNull[i] { + // here means the probe join connecting column has null value in it and this is special for matching all the hash buckets + // for it. (probeKey is not necessary here) + probeRow := probeSideChk.GetRow(i) + ok, joinResult = e.joinNAAJMatchProbeSideRow2Chunk(workerID, 0, hCtx.naColNullBitMap[i].Clone(), probeRow, hCtx, rowContainer, joinResult) + if !ok { + return false, joinResult + } + } else { + // here means the probe join connecting column without null values, where we should match same key bucket and null bucket for it at its order. + // step1: process same key matched probe side rows + probeKey, probeRow := hCtx.hashVals[i].Sum64(), probeSideChk.GetRow(i) + ok, joinResult = e.joinNAAJMatchProbeSideRow2Chunk(workerID, probeKey, nil, probeRow, hCtx, rowContainer, joinResult) + if !ok { + return false, joinResult + } + } + } else { + // since this is the case of using inner to build, so for an outer row unselected, we should fill the result when it's outer join. + if !selected[i] || hCtx.hasNull[i] { // process unmatched probe side rows + e.joiners[workerID].onMissMatch(false, probeSideChk.GetRow(i), joinResult.chk) + } else { // process matched probe side rows + probeKey, probeRow := hCtx.hashVals[i].Sum64(), probeSideChk.GetRow(i) + ok, joinResult = e.joinMatchedProbeSideRow2Chunk(workerID, probeKey, probeRow, hCtx, rowContainer, joinResult) + if !ok { + return false, joinResult + } } } if joinResult.chk.IsFull() { @@ -683,9 +1096,14 @@ func (e *HashJoinExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { for i := range e.buildKeys { buildKeyColIdx[i] = e.buildKeys[i].Index } + buildNAKeyColIdx := make([]int, len(e.buildNAKeys)) + for i := range e.buildNAKeys { + buildNAKeyColIdx[i] = e.buildNAKeys[i].Index + } hCtx := &hashContext{ - allTypes: e.buildTypes, - keyColIdx: buildKeyColIdx, + allTypes: e.buildTypes, + keyColIdx: buildKeyColIdx, + naKeyColIdx: buildNAKeyColIdx, } e.rowContainer = newHashRowContainer(e.ctx, int(e.buildSideEstCount), hCtx, retTypes(e.buildSideExec)) // we shallow copies rowContainer for each probe worker to avoid lock contention diff --git a/executor/joiner.go b/executor/joiner.go index ecfab11f66822..280eff9c15e2c 100644 --- a/executor/joiner.go +++ b/executor/joiner.go @@ -27,8 +27,10 @@ import ( var ( _ joiner = &semiJoiner{} _ joiner = &antiSemiJoiner{} + _ joiner = &nullAwareAntiSemiJoiner{} _ joiner = &leftOuterSemiJoiner{} _ joiner = &antiLeftOuterSemiJoiner{} + _ joiner = &nullAwareAntiLeftOuterSemiJoiner{} _ joiner = &leftOuterJoiner{} _ joiner = &rightOuterJoiner{} _ joiner = &innerJoiner{} @@ -70,7 +72,7 @@ type joiner interface { // NOTE: Callers need to call this function multiple times to consume all // the inner rows for an outer row, and decide whether the outer row can be // matched with at lease one inner row. - tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, isNull bool, err error) + tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk, opt ...NAAJType) (matched bool, isNull bool, err error) // tryToMatchOuters tries to join a batch of outer rows with one inner row. // It's used when the join is an outer join and the hash table is built @@ -130,7 +132,7 @@ func JoinerType(j joiner) plannercore.JoinType { func newJoiner(ctx sessionctx.Context, joinType plannercore.JoinType, outerIsRight bool, defaultInner []types.Datum, filter []expression.Expression, - lhsColTypes, rhsColTypes []*types.FieldType, childrenUsed [][]bool) joiner { + lhsColTypes, rhsColTypes []*types.FieldType, childrenUsed [][]bool, isNA bool) joiner { base := baseJoiner{ ctx: ctx, conditions: filter, @@ -175,12 +177,18 @@ func newJoiner(ctx sessionctx.Context, joinType plannercore.JoinType, return &semiJoiner{base} case plannercore.AntiSemiJoin: base.shallowRow = chunk.MutRowFromTypes(shallowRowType) + if isNA { + return &nullAwareAntiSemiJoiner{baseJoiner: base} + } return &antiSemiJoiner{base} case plannercore.LeftOuterSemiJoin: base.shallowRow = chunk.MutRowFromTypes(shallowRowType) return &leftOuterSemiJoiner{base} case plannercore.AntiLeftOuterSemiJoin: base.shallowRow = chunk.MutRowFromTypes(shallowRowType) + if isNA { + return &nullAwareAntiLeftOuterSemiJoiner{baseJoiner: base} + } return &antiLeftOuterSemiJoiner{base} case plannercore.LeftOuterJoin, plannercore.RightOuterJoin, plannercore.InnerJoin: if len(base.conditions) > 0 { @@ -362,7 +370,7 @@ type semiJoiner struct { baseJoiner } -func (j *semiJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, hasNull bool, err error) { +func (j *semiJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk, _ ...NAAJType) (matched bool, hasNull bool, err error) { if inners.Len() == 0 { return false, false, nil } @@ -429,12 +437,75 @@ func (j *semiJoiner) Clone() joiner { return &semiJoiner{baseJoiner: j.baseJoiner.Clone()} } +// NAAJType is join detail type only used by null-aware AntiLeftOuterSemiJoin. +type NAAJType byte + +const ( + // Unknown for those default value. + Unknown NAAJType = 0 + // LeftHasNullRightNotNull means lhs is a null key, and rhs is not a null key. + LeftHasNullRightNotNull NAAJType = 1 + // LeftHasNullRightHasNull means lhs is a null key, and rhs is a null key. + LeftHasNullRightHasNull NAAJType = 2 + // LeftNotNullRightNotNull means lhs is in not a null key, and rhs is not a null key. + LeftNotNullRightNotNull NAAJType = 3 + // LeftNotNullRightHasNull means lhs is in not a null key, and rhs is a null key. + LeftNotNullRightHasNull NAAJType = 4 +) + +type nullAwareAntiSemiJoiner struct { + baseJoiner +} + +// tryToMatchInners implements joiner interface. +func (naaj *nullAwareAntiSemiJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk, _ ...NAAJType) (matched bool, hasNull bool, err error) { + // Step1: inner rows come from NULL-bucket OR Same-Key bucket. (no rows mean not matched) + if inners.Len() == 0 { + return false, false, nil + } + // Step2: conditions come from other condition. + if len(naaj.conditions) == 0 { + // once there is no other condition, that means right ride has non-empty valid rows. (all matched) + inners.ReachEnd() + return true, false, nil + } + for inner := inners.Current(); inner != inners.End(); inner = inners.Next() { + naaj.makeShallowJoinRow(naaj.outerIsRight, inner, outer) + valid, _, err := expression.EvalBool(naaj.ctx, naaj.conditions, naaj.shallowRow.ToRow()) + if err != nil { + return false, false, err + } + // since other condition is only from inner where clause, here we can say: + // for x NOT IN (y set) semantics, once we found an x in y set, it's determined already. (refuse probe row, append nothing) + if valid { + inners.ReachEnd() + return true, false, nil + } + // false or null means that this merged row can't pass the other condition, not a valid right side row. (continue) + } + err = inners.Error() + return false, false, err +} + +func (naaj *nullAwareAntiSemiJoiner) tryToMatchOuters(outers chunk.Iterator, inner chunk.Row, chk *chunk.Chunk, outerRowStatus []outerRowStatusFlag) (_ []outerRowStatusFlag, err error) { + // todo: use the outer build. + return outerRowStatus, err +} + +func (naaj *nullAwareAntiSemiJoiner) onMissMatch(_ bool, outer chunk.Row, chk *chunk.Chunk) { + chk.AppendRowByColIdxs(outer, naaj.lUsed) +} + +func (naaj *nullAwareAntiSemiJoiner) Clone() joiner { + return &nullAwareAntiSemiJoiner{baseJoiner: naaj.baseJoiner.Clone()} +} + type antiSemiJoiner struct { baseJoiner } // tryToMatchInners implements joiner interface. -func (j *antiSemiJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, hasNull bool, err error) { +func (j *antiSemiJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk, _ ...NAAJType) (matched bool, hasNull bool, err error) { if inners.Len() == 0 { return false, false, nil } @@ -503,7 +574,7 @@ type leftOuterSemiJoiner struct { } // tryToMatchInners implements joiner interface. -func (j *leftOuterSemiJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, hasNull bool, err error) { +func (j *leftOuterSemiJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk, _ ...NAAJType) (matched bool, hasNull bool, err error) { if inners.Len() == 0 { return false, false, nil } @@ -580,12 +651,84 @@ func (j *leftOuterSemiJoiner) Clone() joiner { return &leftOuterSemiJoiner{baseJoiner: j.baseJoiner.Clone()} } +type nullAwareAntiLeftOuterSemiJoiner struct { + baseJoiner +} + +func (naal *nullAwareAntiLeftOuterSemiJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk, opt ...NAAJType) (matched bool, _ bool, err error) { + if inners.Len() == 0 { + return false, false, nil + } + // Difference between nullAwareAntiLeftOuterSemiJoiner and AntiLeftOuterSemiJoiner. + // AntiLeftOuterSemiJoiner conditions contain NA-EQ and inner filters. In EvalBool, once either side has a null value in NA-EQ + // column operand, it will lead a false matched, and a true value of isNull. (which only admit not-null same key match) + // nullAwareAntiLeftOuterSemiJoiner conditions only contain inner filters. in EvalBool, any filter null or false will contribute + // to false matched, in other words, the isNull is permanently false. + if len(naal.conditions) == 0 { + // no inner filter other condition means all matched. (inners are valid source) + naal.onMatch(outer, chk, opt...) + inners.ReachEnd() + return true, false, nil + } + for inner := inners.Current(); inner != inners.End(); inner = inners.Next() { + naal.makeShallowJoinRow(false, inner, outer) + + valid, _, err := expression.EvalBool(naal.ctx, naal.conditions, naal.shallowRow.ToRow()) + if err != nil { + return false, false, err + } + if valid { + // once find a valid inner row, we can determine the result already. + naal.onMatch(outer, chk, opt...) + inners.ReachEnd() + return true, false, nil + } + } + err = inners.Error() + return false, false, err +} + +func (naal *nullAwareAntiLeftOuterSemiJoiner) onMatch(outer chunk.Row, chk *chunk.Chunk, opt ...NAAJType) { + switch opt[0] { + case LeftNotNullRightNotNull: + // either side are not null. (x NOT IN (x...)) --> (rhs, 0) + lWide := chk.AppendRowByColIdxs(outer, naal.lUsed) + chk.AppendInt64(lWide, 0) + case LeftNotNullRightHasNull: + // right side has a null NA-EQ key. (x NOT IN (null...)) --> (rhs, null) + lWide := chk.AppendRowByColIdxs(outer, naal.lUsed) + chk.AppendNull(lWide) + case LeftHasNullRightHasNull, LeftHasNullRightNotNull: + // left side has a null NA-EQ key. (null NOT IN (what ever valid inner)) --(rhs, null) + lWide := chk.AppendRowByColIdxs(outer, naal.lUsed) + chk.AppendNull(lWide) + } +} + +func (naal *nullAwareAntiLeftOuterSemiJoiner) onMissMatch(_ bool, outer chunk.Row, chk *chunk.Chunk) { + // once come to here, it means we couldn't make it in previous short paths. + // cases like: + // 1: null/x NOT IN (empty set) + // 2: x NOT IN (non-empty set without x and null) + lWide := chk.AppendRowByColIdxs(outer, naal.lUsed) + chk.AppendInt64(lWide, 1) +} + +func (naal *nullAwareAntiLeftOuterSemiJoiner) tryToMatchOuters(outers chunk.Iterator, inner chunk.Row, chk *chunk.Chunk, outerRowStatus []outerRowStatusFlag) (_ []outerRowStatusFlag, err error) { + // todo: + return nil, err +} + +func (naal *nullAwareAntiLeftOuterSemiJoiner) Clone() joiner { + return &antiLeftOuterSemiJoiner{baseJoiner: naal.baseJoiner.Clone()} +} + type antiLeftOuterSemiJoiner struct { baseJoiner } // tryToMatchInners implements joiner interface. -func (j *antiLeftOuterSemiJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, hasNull bool, err error) { +func (j *antiLeftOuterSemiJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk, _ ...NAAJType) (matched bool, hasNull bool, err error) { if inners.Len() == 0 { return false, false, nil } @@ -670,7 +813,7 @@ type leftOuterJoiner struct { } // tryToMatchInners implements joiner interface. -func (j *leftOuterJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, hasNull bool, err error) { +func (j *leftOuterJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk, _ ...NAAJType) (matched bool, hasNull bool, err error) { if inners.Len() == 0 { return false, false, nil } @@ -749,7 +892,7 @@ type rightOuterJoiner struct { } // tryToMatchInners implements joiner interface. -func (j *rightOuterJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, hasNull bool, err error) { +func (j *rightOuterJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk, _ ...NAAJType) (matched bool, hasNull bool, err error) { if inners.Len() == 0 { return false, false, nil } @@ -824,7 +967,7 @@ type innerJoiner struct { } // tryToMatchInners implements joiner interface. -func (j *innerJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, hasNull bool, err error) { +func (j *innerJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk, _ ...NAAJType) (matched bool, hasNull bool, err error) { if inners.Len() == 0 { return false, false, nil } diff --git a/executor/joiner_test.go b/executor/joiner_test.go index 77777d89c966c..ea408c869b22a 100644 --- a/executor/joiner_test.go +++ b/executor/joiner_test.go @@ -54,7 +54,7 @@ func TestRequiredRows(t *testing.T) { for i, f := range rfields { defaultInner = append(defaultInner, innerChk.GetRow(0).GetDatum(i, f)) } - joiner := newJoiner(defaultCtx(), joinType, false, defaultInner, nil, lfields, rfields, nil) + joiner := newJoiner(defaultCtx(), joinType, false, defaultInner, nil, lfields, rfields, nil, false) fields := make([]*types.FieldType, 0, len(lfields)+len(rfields)) fields = append(fields, rfields...) diff --git a/executor/pkg_test.go b/executor/pkg_test.go index 48c9678991d9b..f91197250be7b 100644 --- a/executor/pkg_test.go +++ b/executor/pkg_test.go @@ -62,7 +62,7 @@ func TestNestedLoopApply(t *testing.T) { otherFilter := expression.NewFunctionInternal(sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), col0, col1) joiner := newJoiner(sctx, plannercore.InnerJoin, false, make([]types.Datum, innerExec.Schema().Len()), []expression.Expression{otherFilter}, - retTypes(outerExec), retTypes(innerExec), nil) + retTypes(outerExec), retTypes(innerExec), nil, false) joinSchema := expression.NewSchema(col0, col1) join := &NestedLoopApplyExec{ baseExecutor: newBaseExecutor(sctx, joinSchema, 0), diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index ff5265e2f8da7..e39c19c41e5ef 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -1779,6 +1779,11 @@ func (p *LogicalJoin) shouldUseMPPBCJ() bool { return checkChildFitBC(p.children[0]) || checkChildFitBC(p.children[1]) } +// canPushToCop checks if it can be pushed to some stores. +func (p *LogicalJoin) canPushToCop(storeTp kv.StoreType) bool { + return len(p.NAEQConditions) == 0 && p.baseLogicalPlan.canPushToCop(storeTp) +} + // LogicalJoin can generates hash join, index join and sort merge join. // Firstly we check the hint, if hint is figured by user, we force to choose the corresponding physical plan. // If the hint is not matched, it will get other candidates. @@ -1818,17 +1823,20 @@ func (p *LogicalJoin) exhaustPhysicalPlans(prop *property.PhysicalProperty) ([]P return joins, true, nil } - mergeJoins := p.GetMergeJoin(prop, p.schema, p.Stats(), p.children[0].statsInfo(), p.children[1].statsInfo()) - if (p.preferJoinType&preferMergeJoin) > 0 && len(mergeJoins) > 0 { - return mergeJoins, true, nil - } - joins = append(joins, mergeJoins...) + if !p.isNAAJ() { + // naaj refuse merge join and index join. + mergeJoins := p.GetMergeJoin(prop, p.schema, p.Stats(), p.children[0].statsInfo(), p.children[1].statsInfo()) + if (p.preferJoinType&preferMergeJoin) > 0 && len(mergeJoins) > 0 { + return mergeJoins, true, nil + } + joins = append(joins, mergeJoins...) - indexJoins, forced := p.tryToGetIndexJoin(prop) - if forced { - return indexJoins, true, nil + indexJoins, forced := p.tryToGetIndexJoin(prop) + if forced { + return indexJoins, true, nil + } + joins = append(joins, indexJoins...) } - joins = append(joins, indexJoins...) hashJoins := p.getHashJoins(prop) if (p.preferJoinType&preferHashJoin) > 0 && len(hashJoins) > 0 { @@ -1913,6 +1921,11 @@ func (p *LogicalJoin) tryToGetMppHashJoin(prop *property.PhysicalProperty, useBC return nil } lkeys, rkeys, _, _ := p.GetJoinKeys() + lNAkeys, rNAKeys := p.GetNAJoinKeys() + if len(lNAkeys) > 0 || len(rNAKeys) > 0 { + return nil + } + // todo: mpp na-keys. // check match property baseJoin := basePhysicalJoin{ JoinType: p.JoinType, @@ -1922,6 +1935,8 @@ func (p *LogicalJoin) tryToGetMppHashJoin(prop *property.PhysicalProperty, useBC DefaultValues: p.DefaultValues, LeftJoinKeys: lkeys, RightJoinKeys: rkeys, + LeftNAJoinKeys: lNAkeys, + RightNAJoinKeys: rNAKeys, } // It indicates which side is the build side. preferredBuildIndex := 0 @@ -2000,12 +2015,13 @@ func (p *LogicalJoin) tryToGetMppHashJoin(prop *property.PhysicalProperty, useBC childrenProps[1] = &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: rPartitionKeys, CanAddEnforcer: true, RejectSort: true} } join := PhysicalHashJoin{ - basePhysicalJoin: baseJoin, - Concurrency: uint(p.ctx.GetSessionVars().CopTiFlashConcurrencyFactor), - EqualConditions: p.EqualConditions, - storeTp: kv.TiFlash, - mppShuffleJoin: !useBCJ, - // Mpp Join has quite heavy cost. Even limit might not suspend it in time, so we dont scale the count. + basePhysicalJoin: baseJoin, + Concurrency: uint(p.ctx.GetSessionVars().CopTiFlashConcurrencyFactor), + EqualConditions: p.EqualConditions, + NAEqualConditions: p.NAEQConditions, + storeTp: kv.TiFlash, + mppShuffleJoin: !useBCJ, + // Mpp Join has quite heavy cost. Even limit might not suspend it in time, so we don't scale the count. }.Init(p.ctx, p.stats, p.blockOffset, childrenProps...) join.SetSchema(p.schema) return []PhysicalPlan{join} diff --git a/planner/core/explain.go b/planner/core/explain.go index e03bd43ceee2a..788a9e19d43c5 100644 --- a/planner/core/explain.go +++ b/planner/core/explain.go @@ -479,7 +479,11 @@ func (p *PhysicalHashJoin) explainInfo(normalized bool) string { buffer := new(strings.Builder) if len(p.EqualConditions) == 0 { - buffer.WriteString("CARTESIAN ") + if len(p.NAEqualConditions) == 0 { + buffer.WriteString("CARTESIAN ") + } else { + buffer.WriteString("Null-aware ") + } } buffer.WriteString(p.JoinType.String()) @@ -499,6 +503,21 @@ func (p *PhysicalHashJoin) explainInfo(normalized bool) string { buffer.WriteString("]") } } + if len(p.NAEqualConditions) > 0 { + if normalized { + buffer.WriteString(", equal:") + buffer.Write(expression.SortedExplainNormalizedScalarFuncList(p.NAEqualConditions)) + } else { + buffer.WriteString(", equal:[") + for i, NAEqualCondition := range p.NAEqualConditions { + if i != 0 { + buffer.WriteString(" ") + } + buffer.WriteString(NAEqualCondition.String()) + } + buffer.WriteString("]") + } + } if len(p.LeftConditions) > 0 { if normalized { buffer.WriteString(", left cond:") diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index 27315d316f45f..29b25165aa113 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -142,6 +142,7 @@ type LogicalJoin struct { preferJoinOrder bool EqualConditions []*expression.ScalarFunction + NAEQConditions []*expression.ScalarFunction LeftConditions expression.CNFExprs RightConditions expression.CNFExprs OtherConditions expression.CNFExprs @@ -176,6 +177,10 @@ type LogicalJoin struct { equalCondOutCnt float64 } +func (p *LogicalJoin) isNAAJ() bool { + return len(p.NAEQConditions) > 0 +} + // Shallow shallow copies a LogicalJoin struct. func (p *LogicalJoin) Shallow() *LogicalJoin { join := *p @@ -363,6 +368,15 @@ func (p *LogicalJoin) GetJoinKeys() (leftKeys, rightKeys []*expression.Column, i return } +// GetNAJoinKeys extracts join keys(columns) from NAEqualCondition. +func (p *LogicalJoin) GetNAJoinKeys() (leftKeys, rightKeys []*expression.Column) { + for _, expr := range p.NAEQConditions { + leftKeys = append(leftKeys, expr.GetArgs()[0].(*expression.Column)) + rightKeys = append(rightKeys, expr.GetArgs()[1].(*expression.Column)) + } + return +} + // GetPotentialPartitionKeys return potential partition keys for join, the potential partition keys are // the join keys of EqualConditions func (p *LogicalJoin) GetPotentialPartitionKeys() (leftKeys, rightKeys []*property.MPPPartitionColumn) { diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 0bbb8ef285657..d77fe7c33a15f 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -818,8 +818,15 @@ type basePhysicalJoin struct { InnerJoinKeys []*expression.Column LeftJoinKeys []*expression.Column RightJoinKeys []*expression.Column + // IsNullEQ is used for cases like Except statement where null key should be matched with null key. + // <1,null> is exactly matched with <1,null>, where the null value should not be filtered and + // the null is exactly matched with null only. (while in NAAJ null value should also be matched + // with other non-null item as well) IsNullEQ []bool DefaultValues []types.Datum + + LeftNAJoinKeys []*expression.Column + RightNAJoinKeys []*expression.Column } func (p *basePhysicalJoin) cloneWithSelf(newSelf PhysicalPlan) (*basePhysicalJoin, error) { @@ -838,6 +845,8 @@ func (p *basePhysicalJoin) cloneWithSelf(newSelf PhysicalPlan) (*basePhysicalJoi cloned.InnerJoinKeys = cloneCols(p.InnerJoinKeys) cloned.LeftJoinKeys = cloneCols(p.LeftJoinKeys) cloned.RightJoinKeys = cloneCols(p.RightJoinKeys) + cloned.LeftNAJoinKeys = cloneCols(p.LeftNAJoinKeys) + cloned.RightNAJoinKeys = cloneCols(p.RightNAJoinKeys) for _, d := range p.DefaultValues { cloned.DefaultValues = append(cloned.DefaultValues, *d.Clone()) } @@ -866,6 +875,8 @@ type PhysicalHashJoin struct { Concurrency uint EqualConditions []*expression.ScalarFunction + NAEqualConditions []*expression.ScalarFunction + // use the outer table to build a hash table when the outer table is smaller. UseOuterToBuild bool @@ -887,15 +898,21 @@ func (p *PhysicalHashJoin) Clone() (PhysicalPlan, error) { for _, c := range p.EqualConditions { cloned.EqualConditions = append(cloned.EqualConditions, c.Clone().(*expression.ScalarFunction)) } + for _, c := range p.NAEqualConditions { + cloned.NAEqualConditions = append(cloned.NAEqualConditions, c.Clone().(*expression.ScalarFunction)) + } return cloned, nil } // ExtractCorrelatedCols implements PhysicalPlan interface. func (p *PhysicalHashJoin) ExtractCorrelatedCols() []*expression.CorrelatedColumn { - corCols := make([]*expression.CorrelatedColumn, 0, len(p.EqualConditions)+len(p.LeftConditions)+len(p.RightConditions)+len(p.OtherConditions)) + corCols := make([]*expression.CorrelatedColumn, 0, len(p.EqualConditions)+len(p.NAEqualConditions)+len(p.LeftConditions)+len(p.RightConditions)+len(p.OtherConditions)) for _, fun := range p.EqualConditions { corCols = append(corCols, expression.ExtractCorColumns(fun)...) } + for _, fun := range p.NAEqualConditions { + corCols = append(corCols, expression.ExtractCorColumns(fun)...) + } for _, fun := range p.LeftConditions { corCols = append(corCols, expression.ExtractCorColumns(fun)...) } @@ -911,22 +928,27 @@ func (p *PhysicalHashJoin) ExtractCorrelatedCols() []*expression.CorrelatedColum // NewPhysicalHashJoin creates a new PhysicalHashJoin from LogicalJoin. func NewPhysicalHashJoin(p *LogicalJoin, innerIdx int, useOuterToBuild bool, newStats *property.StatsInfo, prop ...*property.PhysicalProperty) *PhysicalHashJoin { leftJoinKeys, rightJoinKeys, isNullEQ, _ := p.GetJoinKeys() + leftNAJoinKeys, rightNAJoinKeys := p.GetNAJoinKeys() baseJoin := basePhysicalJoin{ LeftConditions: p.LeftConditions, RightConditions: p.RightConditions, OtherConditions: p.OtherConditions, LeftJoinKeys: leftJoinKeys, RightJoinKeys: rightJoinKeys, + // NA join keys + LeftNAJoinKeys: leftNAJoinKeys, + RightNAJoinKeys: rightNAJoinKeys, IsNullEQ: isNullEQ, JoinType: p.JoinType, DefaultValues: p.DefaultValues, InnerChildIdx: innerIdx, } hashJoin := PhysicalHashJoin{ - basePhysicalJoin: baseJoin, - EqualConditions: p.EqualConditions, - Concurrency: uint(p.ctx.GetSessionVars().HashJoinConcurrency()), - UseOuterToBuild: useOuterToBuild, + basePhysicalJoin: baseJoin, + EqualConditions: p.EqualConditions, + NAEqualConditions: p.NAEQConditions, + Concurrency: uint(p.ctx.GetSessionVars().HashJoinConcurrency()), + UseOuterToBuild: useOuterToBuild, }.Init(p.ctx, newStats, p.blockOffset, prop...) return hashJoin } diff --git a/planner/core/plan_cost.go b/planner/core/plan_cost.go index 33ea0e64b2776..3b79b17b1f48d 100644 --- a/planner/core/plan_cost.go +++ b/planner/core/plan_cost.go @@ -775,7 +775,7 @@ func (p *PhysicalApply) GetCost(lCount, rCount, lCost, rCost float64) float64 { cpuCost += lCount * rCount * sessVars.GetCPUFactor() rCount *= SelectionFactor } - if len(p.EqualConditions)+len(p.OtherConditions) > 0 { + if len(p.EqualConditions)+len(p.OtherConditions)+len(p.NAEqualConditions) > 0 { if p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin || p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { cpuCost += lCount * rCount * sessVars.GetCPUFactor() * 0.5 @@ -904,6 +904,7 @@ func (p *PhysicalHashJoin) GetCost(lCnt, rCnt float64, isMPP bool, costFlag uint diskCost := buildCnt * sessVars.GetDiskFactor() * rowSize // Number of matched row pairs regarding the equal join conditions. helper := &fullJoinRowCountHelper{ +<<<<<<< HEAD cartesian: false, leftProfile: p.children[0].statsInfo(), rightProfile: p.children[1].statsInfo(), @@ -911,6 +912,18 @@ func (p *PhysicalHashJoin) GetCost(lCnt, rCnt float64, isMPP bool, costFlag uint rightJoinKeys: p.RightJoinKeys, leftSchema: p.children[0].Schema(), rightSchema: p.children[1].Schema(), +======= + sctx: p.SCtx(), + cartesian: false, + leftProfile: p.children[0].statsInfo(), + rightProfile: p.children[1].statsInfo(), + leftJoinKeys: p.LeftJoinKeys, + rightJoinKeys: p.RightJoinKeys, + leftSchema: p.children[0].Schema(), + rightSchema: p.children[1].Schema(), + leftNAJoinKeys: p.LeftNAJoinKeys, + rightNAJoinKeys: p.RightNAJoinKeys, +>>>>>>> 0823fdb6b... planner, executor: implement the null-aware antiSemiJoin and null-aware antiLeftOuterSemiJoin (hash join with inner build) (#37512) } numPairs := helper.estimate() // For semi-join class, if `OtherConditions` is empty, we already know diff --git a/planner/core/plan_to_pb.go b/planner/core/plan_to_pb.go index fbff431562545..a59b7dfe62062 100644 --- a/planner/core/plan_to_pb.go +++ b/planner/core/plan_to_pb.go @@ -377,6 +377,7 @@ func (p *PhysicalIndexScan) ToPB(ctx sessionctx.Context, _ kv.StoreType) (*tipb. func (p *PhysicalHashJoin) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { sc := ctx.GetSessionVars().StmtCtx client := ctx.GetClient() + // todo: mpp na-key toPB. leftJoinKeys := make([]expression.Expression, 0, len(p.LeftJoinKeys)) rightJoinKeys := make([]expression.Expression, 0, len(p.RightJoinKeys)) for _, leftKey := range p.LeftJoinKeys { @@ -465,6 +466,7 @@ func (p *PhysicalHashJoin) ToPB(ctx sessionctx.Context, storeType kv.StoreType) probeFiledTypes = append(probeFiledTypes, ty) buildFiledTypes = append(buildFiledTypes, ty) } + // todo: arenatlx, push down hash join join := &tipb.Join{ JoinType: pbJoinType, JoinExecType: tipb.JoinExecType_TypeHashJoin, diff --git a/planner/core/resolve_indices.go b/planner/core/resolve_indices.go index 7b5b4261a281c..483d0b9f92299 100644 --- a/planner/core/resolve_indices.go +++ b/planner/core/resolve_indices.go @@ -95,6 +95,19 @@ func (p *PhysicalHashJoin) ResolveIndices() (err error) { p.RightJoinKeys[i] = rArg.(*expression.Column) p.EqualConditions[i] = expression.NewFunctionInternal(fun.GetCtx(), fun.FuncName.L, fun.GetType(), lArg, rArg).(*expression.ScalarFunction) } + for i, fun := range p.NAEqualConditions { + lArg, err := fun.GetArgs()[0].ResolveIndices(lSchema) + if err != nil { + return err + } + p.LeftNAJoinKeys[i] = lArg.(*expression.Column) + rArg, err := fun.GetArgs()[1].ResolveIndices(rSchema) + if err != nil { + return err + } + p.RightNAJoinKeys[i] = rArg.(*expression.Column) + p.NAEqualConditions[i] = expression.NewFunctionInternal(fun.GetCtx(), fun.FuncName.L, fun.GetType(), lArg, rArg).(*expression.ScalarFunction) + } for i, expr := range p.LeftConditions { p.LeftConditions[i], err = expr.ResolveIndices(lSchema) if err != nil { @@ -567,6 +580,13 @@ func (p *PhysicalApply) ResolveIndices() (err error) { } p.PhysicalHashJoin.EqualConditions[i] = newSf.(*expression.ScalarFunction) } + for i, cond := range p.PhysicalHashJoin.NAEqualConditions { + newSf, err := cond.ResolveIndices(joinedSchema) + if err != nil { + return err + } + p.PhysicalHashJoin.NAEqualConditions[i] = newSf.(*expression.ScalarFunction) + } return } diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index b8df243bd2299..e477ba41fd170 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -413,6 +413,9 @@ func (p *LogicalJoin) extractUsedCols(parentUsedCols []*expression.Column) (left for _, otherCond := range p.OtherConditions { parentUsedCols = append(parentUsedCols, expression.ExtractColumns(otherCond)...) } + for _, naeqCond := range p.NAEQConditions { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(naeqCond)...) + } lChild := p.children[0] rChild := p.children[1] for _, col := range parentUsedCols { diff --git a/planner/core/rule_predicate_push_down.go b/planner/core/rule_predicate_push_down.go index 59a29d788d65a..e4b74a865728d 100644 --- a/planner/core/rule_predicate_push_down.go +++ b/planner/core/rule_predicate_push_down.go @@ -249,6 +249,9 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression, opt func (p *LogicalJoin) updateEQCond() { lChild, rChild := p.children[0], p.children[1] var lKeys, rKeys []expression.Expression + var lNAKeys, rNAKeys []expression.Expression + // We need two steps here: + // step1: try best to extract normal EQ condition from OtherCondition to join EqualConditions. for i := len(p.OtherConditions) - 1; i >= 0; i-- { need2Remove := false if eqCond, ok := p.OtherConditions[i].(*expression.ScalarFunction); ok && eqCond.FuncName.L == ast.EQ { @@ -273,33 +276,78 @@ func (p *LogicalJoin) updateEQCond() { p.OtherConditions = append(p.OtherConditions[:i], p.OtherConditions[i+1:]...) } } - if len(lKeys) > 0 { - needLProj, needRProj := false, false - for i := range lKeys { - _, lOk := lKeys[i].(*expression.Column) - _, rOk := rKeys[i].(*expression.Column) - needLProj = needLProj || !lOk - needRProj = needRProj || !rOk - } + // eg: explain select * from t1, t3 where t1.a+1 = t3.a; + // tidb only accept the join key in EqualCondition as a normal column (join OP take granted for that) + // so once we found the left and right children's schema can supply the all columns in complicated EQ condition that used by left/right key. + // we will add a layer of projection here to convert the complicated expression of EQ's left or right side to be a normal column. + adjustKeyForm := func(leftKeys, rightKeys []expression.Expression, isNA bool) { + if len(leftKeys) > 0 { + needLProj, needRProj := false, false + for i := range leftKeys { + _, lOk := leftKeys[i].(*expression.Column) + _, rOk := rightKeys[i].(*expression.Column) + needLProj = needLProj || !lOk + needRProj = needRProj || !rOk + } - var lProj, rProj *LogicalProjection - if needLProj { - lProj = p.getProj(0) - } - if needRProj { - rProj = p.getProj(1) + var lProj, rProj *LogicalProjection + if needLProj { + lProj = p.getProj(0) + } + if needRProj { + rProj = p.getProj(1) + } + for i := range leftKeys { + lKey, rKey := leftKeys[i], rightKeys[i] + if lProj != nil { + lKey = lProj.appendExpr(lKey) + } + if rProj != nil { + rKey = rProj.appendExpr(rKey) + } + eqCond := expression.NewFunctionInternal(p.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lKey, rKey) + if isNA { + p.NAEQConditions = append(p.NAEQConditions, eqCond.(*expression.ScalarFunction)) + } else { + p.EqualConditions = append(p.EqualConditions, eqCond.(*expression.ScalarFunction)) + } + } } - for i := range lKeys { - lKey, rKey := lKeys[i], rKeys[i] - if lProj != nil { - lKey = lProj.appendExpr(lKey) + } + adjustKeyForm(lKeys, rKeys, false) + + // Step2: when step1 is finished, then we can determine whether we need to extract NA-EQ from OtherCondition to NAEQConditions. + // when there are still no EqualConditions, let's try to be a NAAJ. + // todo: by now, when there is already a normal EQ condition, just keep NA-EQ as other-condition filters above it. + // eg: select * from stu where stu.name not in (select name from exam where exam.stu_id = stu.id); + // combination of and for join key is little complicated for now. + canBeNAAJ := (p.JoinType == AntiSemiJoin || p.JoinType == AntiLeftOuterSemiJoin) && len(p.EqualConditions) == 0 + if canBeNAAJ && p.SCtx().GetSessionVars().OptimizerEnableNAAJ { + for i := len(p.OtherConditions) - 1; i >= 0; i-- { + need2Remove := false + if eqCond, ok := p.OtherConditions[i].(*expression.ScalarFunction); ok && eqCond.FuncName.L == ast.EQ { + // not a naaj operator, continue. + if !expression.IsEQCondFromIn(eqCond) { + continue + } + // here must be a EQCondFromIn. + lExpr, rExpr := eqCond.GetArgs()[0], eqCond.GetArgs()[1] + if expression.ExprFromSchema(lExpr, lChild.Schema()) && expression.ExprFromSchema(rExpr, rChild.Schema()) { + lNAKeys = append(lNAKeys, lExpr) + rNAKeys = append(rNAKeys, rExpr) + need2Remove = true + } else if expression.ExprFromSchema(lExpr, rChild.Schema()) && expression.ExprFromSchema(rExpr, lChild.Schema()) { + lNAKeys = append(lNAKeys, rExpr) + rNAKeys = append(rNAKeys, lExpr) + need2Remove = true + } } - if rProj != nil { - rKey = rProj.appendExpr(rKey) + if need2Remove { + p.OtherConditions = append(p.OtherConditions[:i], p.OtherConditions[i+1:]...) } - eqCond := expression.NewFunctionInternal(p.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lKey, rKey) - p.EqualConditions = append(p.EqualConditions, eqCond.(*expression.ScalarFunction)) } + // here is for cases like: select (a+1, b*3) not in (select a,b from t2) from t1. + adjustKeyForm(lNAKeys, rNAKeys, true) } } diff --git a/planner/core/stats.go b/planner/core/stats.go index 216ec5112166a..3d89ce0ea29da 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -1111,14 +1111,29 @@ type fullJoinRowCountHelper struct { rightJoinKeys []*expression.Column leftSchema *expression.Schema rightSchema *expression.Schema + + leftNAJoinKeys []*expression.Column + rightNAJoinKeys []*expression.Column } func (h *fullJoinRowCountHelper) estimate() float64 { if h.cartesian { return h.leftProfile.RowCount * h.rightProfile.RowCount } +<<<<<<< HEAD leftKeyNDV := getColsNDV(h.leftJoinKeys, h.leftSchema, h.leftProfile) rightKeyNDV := getColsNDV(h.rightJoinKeys, h.rightSchema, h.rightProfile) +======= + var leftKeyNDV, rightKeyNDV float64 + var leftColCnt, rightColCnt int + if len(h.leftJoinKeys) > 0 || len(h.rightJoinKeys) > 0 { + leftKeyNDV, leftColCnt = getColsNDVWithMatchedLen(h.leftJoinKeys, h.leftSchema, h.leftProfile) + rightKeyNDV, rightColCnt = getColsNDVWithMatchedLen(h.rightJoinKeys, h.rightSchema, h.rightProfile) + } else { + leftKeyNDV, leftColCnt = getColsNDVWithMatchedLen(h.leftNAJoinKeys, h.leftSchema, h.leftProfile) + rightKeyNDV, rightColCnt = getColsNDVWithMatchedLen(h.rightNAJoinKeys, h.rightSchema, h.rightProfile) + } +>>>>>>> 0823fdb6b... planner, executor: implement the null-aware antiSemiJoin and null-aware antiLeftOuterSemiJoin (hash join with inner build) (#37512) count := h.leftProfile.RowCount * h.rightProfile.RowCount / math.Max(leftKeyNDV, rightKeyNDV) return count } diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 3e07917a4d650..1e1bdd324db29 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -840,6 +840,9 @@ type SessionVars struct { // EnableOuterJoinWithJoinReorder enables TiDB to involve the outer join into the join reorder. EnableOuterJoinReorder bool + // OptimizerEnableNAAJ enables TiDB to use null-aware anti join. + OptimizerEnableNAAJ bool + // EnableTablePartition enables table partition feature. EnableTablePartition string diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index d286e40a82c33..6b23921c3a58e 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -220,6 +220,10 @@ var defaultSysVars = []*SysVar{ s.EnableOuterJoinReorder = TiDBOptOn(val) return nil }}, + {Scope: ScopeGlobal | ScopeSession, Name: TiDBOptimizerEnableNAAJ, Value: BoolToOnOff(DefTiDBEnableNAAJ), Type: TypeBool, SetSession: func(s *SessionVars, val string) error { + s.OptimizerEnableNAAJ = TiDBOptOn(val) + return nil + }}, {Scope: ScopeSession, Name: TiDBDDLReorgPriority, Value: "PRIORITY_LOW", Type: TypeEnum, skipInit: true, PossibleValues: []string{"PRIORITY_LOW", "PRIORITY_NORMAL", "PRIORITY_HIGH"}, SetSession: func(s *SessionVars, val string) error { s.setDDLReorgPriority(val) return nil diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 49d4da41dc7d7..2cdf4b817298e 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -144,6 +144,9 @@ const ( TiDBOptimizerEnableOuterJoinReorder = "tidb_enable_outer_join_reorder" + // TiDBOptimizerEnableNAAJ is used to open the newly null-aware anti join + TiDBOptimizerEnableNAAJ = "tidb_enable_null_aware_anti_join" + // TiDBTxnMode is used to control the transaction behavior. TiDBTxnMode = "tidb_txn_mode" @@ -868,6 +871,7 @@ const ( DefTiDBOptimizerSelectivityLevel = 0 DefTiDBOptimizerEnableNewOFGB = false DefTiDBEnableOuterJoinReorder = false + DefTiDBEnableNAAJ = false DefTiDBAllowBatchCop = 1 DefTiDBAllowMPPExecution = true DefTiDBHashExchangeWithNewCollation = true diff --git a/util/bitmap/concurrent.go b/util/bitmap/concurrent.go index 1ca214588cf78..7bcb9d83cc657 100644 --- a/util/bitmap/concurrent.go +++ b/util/bitmap/concurrent.go @@ -37,6 +37,16 @@ type ConcurrentBitmap struct { bitLen int } +// Clone clones a new bitmap with the old bit set. +func (cb *ConcurrentBitmap) Clone() *ConcurrentBitmap { + cp := NewConcurrentBitmap(cb.bitLen) + needLen := len(cp.segments) + for i := 0; i < needLen; i++ { + cp.segments[i] = cb.segments[i] + } + return cp +} + // NewConcurrentBitmap initializes a ConcurrentBitmap which can store // bitLen of bits. func NewConcurrentBitmap(bitLen int) *ConcurrentBitmap { @@ -47,6 +57,20 @@ func NewConcurrentBitmap(bitLen int) *ConcurrentBitmap { } } +// Reset clean the bitmap if the length is suitable, otherwise renewing one. +func (cb *ConcurrentBitmap) Reset(bitLen int) { + segmentLen := (bitLen + segmentWidth - 1) >> segmentWidthPower + if segmentLen <= len(cb.segments) { + for i := range cb.segments { + cb.segments[i] = 0 + } + cb.bitLen = bitLen + } else { + cb.segments = make([]uint32, segmentLen) + cb.bitLen = bitLen + } +} + // BytesConsumed returns size of this bitmap in bytes. func (cb *ConcurrentBitmap) BytesConsumed() int64 { return bytesConcurrentBitmap + int64(segmentWidth/8*cap(cb.segments)) @@ -81,6 +105,19 @@ func (cb *ConcurrentBitmap) Set(bitIndex int) (isSetter bool) { } } +// UnsafeSet sets the bit on bitIndex to be 1 (bitIndex starts from 0). +// isSetter indicates whether the function call this time triggers the bit from 0 to 1. +// bitIndex bigger than bitLen initialized will be ignored. +// (this version is concurrent unsafe if the caller can make sure write is in single thread) +func (cb *ConcurrentBitmap) UnsafeSet(bitIndex int) { + if bitIndex < 0 || bitIndex >= cb.bitLen { + return + } + + mask := bitMask >> uint32(bitIndex%segmentWidth) + cb.segments[bitIndex>>segmentWidthPower] = cb.segments[bitIndex>>segmentWidthPower] | mask +} + // UnsafeIsSet returns if a bit on bitIndex is set (bitIndex starts from 0). // bitIndex bigger than bitLen initialized will return false. // This method is not thread-safe as it does not use atomic load. diff --git a/util/bitmap/concurrent_test.go b/util/bitmap/concurrent_test.go index 958fe57b20d23..f7d2c8acba949 100644 --- a/util/bitmap/concurrent_test.go +++ b/util/bitmap/concurrent_test.go @@ -77,3 +77,17 @@ func TestConcurrentBitmapUniqueSetter(t *testing.T) { assert.Less(t, clearCounter, uint64(loopCount)) assert.Equal(t, setterCounter, clearCounter+1) } + +// TestResetConcurrentBitmap test the reset of concurrentBitmap. +func TestResetConcurrentBitmap(t *testing.T) { + bm := NewConcurrentBitmap(32) + bm.Set(1) + bm.Set(3) + bm.Set(7) + bm.Set(16) + bm.Reset(8) + assert.Equal(t, bm.bitLen, 8) + assert.Equal(t, bm.UnsafeIsSet(1), false) + assert.Equal(t, bm.UnsafeIsSet(3), false) + assert.Equal(t, bm.UnsafeIsSet(7), false) +} From d8a8e0f77dbe9a32980178cd378d0c72baa813fb Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Tue, 20 Sep 2022 14:52:35 +0800 Subject: [PATCH 2/4] resolve conflict Signed-off-by: AilinKid <314806019@qq.com> --- planner/core/plan_cost.go | 11 ----------- planner/core/stats.go | 14 ++++---------- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/planner/core/plan_cost.go b/planner/core/plan_cost.go index 3b79b17b1f48d..d5f497af1d4f7 100644 --- a/planner/core/plan_cost.go +++ b/planner/core/plan_cost.go @@ -904,16 +904,6 @@ func (p *PhysicalHashJoin) GetCost(lCnt, rCnt float64, isMPP bool, costFlag uint diskCost := buildCnt * sessVars.GetDiskFactor() * rowSize // Number of matched row pairs regarding the equal join conditions. helper := &fullJoinRowCountHelper{ -<<<<<<< HEAD - cartesian: false, - leftProfile: p.children[0].statsInfo(), - rightProfile: p.children[1].statsInfo(), - leftJoinKeys: p.LeftJoinKeys, - rightJoinKeys: p.RightJoinKeys, - leftSchema: p.children[0].Schema(), - rightSchema: p.children[1].Schema(), -======= - sctx: p.SCtx(), cartesian: false, leftProfile: p.children[0].statsInfo(), rightProfile: p.children[1].statsInfo(), @@ -923,7 +913,6 @@ func (p *PhysicalHashJoin) GetCost(lCnt, rCnt float64, isMPP bool, costFlag uint rightSchema: p.children[1].Schema(), leftNAJoinKeys: p.LeftNAJoinKeys, rightNAJoinKeys: p.RightNAJoinKeys, ->>>>>>> 0823fdb6b... planner, executor: implement the null-aware antiSemiJoin and null-aware antiLeftOuterSemiJoin (hash join with inner build) (#37512) } numPairs := helper.estimate() // For semi-join class, if `OtherConditions` is empty, we already know diff --git a/planner/core/stats.go b/planner/core/stats.go index 3d89ce0ea29da..2d216fd6c8690 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -1120,20 +1120,14 @@ func (h *fullJoinRowCountHelper) estimate() float64 { if h.cartesian { return h.leftProfile.RowCount * h.rightProfile.RowCount } -<<<<<<< HEAD - leftKeyNDV := getColsNDV(h.leftJoinKeys, h.leftSchema, h.leftProfile) - rightKeyNDV := getColsNDV(h.rightJoinKeys, h.rightSchema, h.rightProfile) -======= var leftKeyNDV, rightKeyNDV float64 - var leftColCnt, rightColCnt int if len(h.leftJoinKeys) > 0 || len(h.rightJoinKeys) > 0 { - leftKeyNDV, leftColCnt = getColsNDVWithMatchedLen(h.leftJoinKeys, h.leftSchema, h.leftProfile) - rightKeyNDV, rightColCnt = getColsNDVWithMatchedLen(h.rightJoinKeys, h.rightSchema, h.rightProfile) + leftKeyNDV = getColsNDV(h.leftJoinKeys, h.leftSchema, h.leftProfile) + rightKeyNDV = getColsNDV(h.rightJoinKeys, h.rightSchema, h.rightProfile) } else { - leftKeyNDV, leftColCnt = getColsNDVWithMatchedLen(h.leftNAJoinKeys, h.leftSchema, h.leftProfile) - rightKeyNDV, rightColCnt = getColsNDVWithMatchedLen(h.rightNAJoinKeys, h.rightSchema, h.rightProfile) + leftKeyNDV = getColsNDV(h.leftNAJoinKeys, h.leftSchema, h.leftProfile) + rightKeyNDV = getColsNDV(h.rightNAJoinKeys, h.rightSchema, h.rightProfile) } ->>>>>>> 0823fdb6b... planner, executor: implement the null-aware antiSemiJoin and null-aware antiLeftOuterSemiJoin (hash join with inner build) (#37512) count := h.leftProfile.RowCount * h.rightProfile.RowCount / math.Max(leftKeyNDV, rightKeyNDV) return count } From 5e8e1d9105e916985ec854825ff92dd0a456ea3f Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Thu, 22 Sep 2022 15:03:03 +0800 Subject: [PATCH 3/4] fix linter Signed-off-by: AilinKid <314806019@qq.com> --- .golangci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.golangci.yml b/.golangci.yml index 91ca980280857..9b8e3507d88a1 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -10,7 +10,6 @@ linters: - structcheck - deadcode - gosimple - - goimports - errcheck - staticcheck - stylecheck From 3f4b6859fbc54acd602376d8e0ef1932c6c6e860 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Thu, 22 Sep 2022 15:27:19 +0800 Subject: [PATCH 4/4] parser fmt Signed-off-by: AilinKid <314806019@qq.com> --- parser/ast/ddl.go | 10 ++++--- parser/ast/dml.go | 32 ++++++++++++--------- parser/ast/misc.go | 11 ++++--- parser/auth/mysql_native_password.go | 27 ++++++++--------- parser/format/format.go | 32 ++++++++++++++------- parser/goyacc/main.go | 20 ++++++------- parser/model/model.go | 5 ++-- parser/test_driver/test_driver_mydecimal.go | 13 ++++----- parser/types/etc.go | 4 ++- 9 files changed, 88 insertions(+), 66 deletions(-) diff --git a/parser/ast/ddl.go b/parser/ast/ddl.go index 52d489a89cf85..a55a6b0dedcb1 100644 --- a/parser/ast/ddl.go +++ b/parser/ast/ddl.go @@ -670,10 +670,12 @@ const ( ) // IndexOption is the index options. -// KEY_BLOCK_SIZE [=] value -// | index_type -// | WITH PARSER parser_name -// | COMMENT 'string' +// +// KEY_BLOCK_SIZE [=] value +// | index_type +// | WITH PARSER parser_name +// | COMMENT 'string' +// // See http://dev.mysql.com/doc/refman/5.7/en/create-table.html type IndexOption struct { node diff --git a/parser/ast/dml.go b/parser/ast/dml.go index 93ee525d30bf5..55921995cb0f7 100644 --- a/parser/ast/dml.go +++ b/parser/ast/dml.go @@ -93,25 +93,31 @@ func (*Join) resultSet() {} // NewCrossJoin builds a cross join without `on` or `using` clause. // If the right child is a join tree, we need to handle it differently to make the precedence get right. // Here is the example: t1 join t2 join t3 -// JOIN ON t2.a = t3.a -// t1 join / \ -// t2 t3 +// +// JOIN ON t2.a = t3.a +// t1 join / \ +// t2 t3 +// // (left) (right) // // We can not build it directly to: -// JOIN -// / \ -// t1 JOIN ON t2.a = t3.a -// / \ -// t2 t3 +// +// JOIN +// / \ +// t1 JOIN ON t2.a = t3.a +// / \ +// t2 t3 +// // The precedence would be t1 join (t2 join t3 on t2.a=t3.a), not (t1 join t2) join t3 on t2.a=t3.a // We need to find the left-most child of the right child, and build a cross join of the left-hand side // of the left child(t1), and the right hand side with the original left-most child of the right child(t2). -// JOIN t2.a = t3.a -// / \ -// JOIN t3 -// / \ -// t1 t2 +// +// JOIN t2.a = t3.a +// / \ +// JOIN t3 +// / \ +// t1 t2 +// // Besides, if the right handle side join tree's join type is right join and has explicit parentheses, we need to rewrite it to left join. // So t1 join t2 right join t3 would be rewrite to t1 join t3 left join t2. // If not, t1 join (t2 right join t3) would be (t1 join t2) right join t3. After rewrite the right join to left join. diff --git a/parser/ast/misc.go b/parser/ast/misc.go index 307242273df7d..2cef9869575e9 100644 --- a/parser/ast/misc.go +++ b/parser/ast/misc.go @@ -1885,9 +1885,10 @@ type StatisticsSpec struct { // CreateStatisticsStmt is a statement to create extended statistics. // Examples: -// CREATE STATISTICS stats1 (cardinality) ON t(a, b, c); -// CREATE STATISTICS stats2 (dependency) ON t(a, b); -// CREATE STATISTICS stats3 (correlation) ON t(a, b); +// +// CREATE STATISTICS stats1 (cardinality) ON t(a, b, c); +// CREATE STATISTICS stats2 (dependency) ON t(a, b); +// CREATE STATISTICS stats3 (correlation) ON t(a, b); type CreateStatisticsStmt struct { stmtNode @@ -1955,7 +1956,8 @@ func (n *CreateStatisticsStmt) Accept(v Visitor) (Node, bool) { // DropStatisticsStmt is a statement to drop extended statistics. // Examples: -// DROP STATISTICS stats1; +// +// DROP STATISTICS stats1; type DropStatisticsStmt struct { stmtNode @@ -2086,6 +2088,7 @@ const ( ) // ShowSlow is used for the following command: +// // admin show slow top [ internal | all] N // admin show slow recent N type ShowSlow struct { diff --git a/parser/auth/mysql_native_password.go b/parser/auth/mysql_native_password.go index 05c6127c21991..2bfc1a8190667 100644 --- a/parser/auth/mysql_native_password.go +++ b/parser/auth/mysql_native_password.go @@ -25,19 +25,20 @@ import ( // CheckScrambledPassword check scrambled password received from client. // The new authentication is performed in following manner: -// SERVER: public_seed=create_random_string() -// send(public_seed) -// CLIENT: recv(public_seed) -// hash_stage1=sha1("password") -// hash_stage2=sha1(hash_stage1) -// reply=xor(hash_stage1, sha1(public_seed,hash_stage2) -// // this three steps are done in scramble() -// send(reply) -// SERVER: recv(reply) -// hash_stage1=xor(reply, sha1(public_seed,hash_stage2)) -// candidate_hash2=sha1(hash_stage1) -// check(candidate_hash2==hash_stage2) -// // this three steps are done in check_scramble() +// +// SERVER: public_seed=create_random_string() +// send(public_seed) +// CLIENT: recv(public_seed) +// hash_stage1=sha1("password") +// hash_stage2=sha1(hash_stage1) +// reply=xor(hash_stage1, sha1(public_seed,hash_stage2) +// // this three steps are done in scramble() +// send(reply) +// SERVER: recv(reply) +// hash_stage1=xor(reply, sha1(public_seed,hash_stage2)) +// candidate_hash2=sha1(hash_stage1) +// check(candidate_hash2==hash_stage2) +// // this three steps are done in check_scramble() func CheckScrambledPassword(salt, hpwd, auth []byte) bool { //nolint: gosec crypt := sha1.New() diff --git a/parser/format/format.go b/parser/format/format.go index 5c9c137c8fa27..adada122e255e 100644 --- a/parser/format/format.go +++ b/parser/format/format.go @@ -56,21 +56,28 @@ var replace = map[rune]string{ // nest. The Formatter writes to io.Writer 'w' and inserts one 'indent' // string per current indent level value. // Behaviour of commands reaching negative indent levels is undefined. -// IndentFormatter(os.Stdout, "\t").Format("abc%d%%e%i\nx\ny\n%uz\n", 3) +// +// IndentFormatter(os.Stdout, "\t").Format("abc%d%%e%i\nx\ny\n%uz\n", 3) +// // output: -// abc3%e -// x -// y -// z +// +// abc3%e +// x +// y +// z +// // The Go quoted string literal form of the above is: -// "abc%%e\n\tx\n\tx\nz\n" +// +// "abc%%e\n\tx\n\tx\nz\n" +// // The commands can be scattered between separate invocations of Format(), // i.e. the formatter keeps track of the indent level and knows if it is // positioned on start of a line and should emit indentation(s). // The same output as above can be produced by e.g.: -// f := IndentFormatter(os.Stdout, " ") -// f.Format("abc%d%%e%i\nx\n", 3) -// f.Format("y\n%uz\n") +// +// f := IndentFormatter(os.Stdout, " ") +// f.Format("abc%d%%e%i\nx\n", 3) +// f.Format("y\n%uz\n") func IndentFormatter(w io.Writer, indent string) Formatter { return &indentFormatter{w, []byte(indent), 0, stBOL} } @@ -169,9 +176,12 @@ type flatFormatter indentFormatter // // The FlatFormatter is intended for flattening of normally nested structure textual representation to // a one top level structure per line form. -// FlatFormatter(os.Stdout, " ").Format("abc%d%%e%i\nx\ny\n%uz\n", 3) +// +// FlatFormatter(os.Stdout, " ").Format("abc%d%%e%i\nx\ny\n%uz\n", 3) +// // output in the form of a Go quoted string literal: -// "abc3%%e x y z\n" +// +// "abc3%%e x y z\n" func FlatFormatter(w io.Writer) Formatter { return (*flatFormatter)(IndentFormatter(w, "").(*indentFormatter)) } diff --git a/parser/goyacc/main.go b/parser/goyacc/main.go index 93fc90efb3afe..cc7589773baa1 100644 --- a/parser/goyacc/main.go +++ b/parser/goyacc/main.go @@ -21,7 +21,7 @@ // Goyacc is a version of yacc generating Go parsers. // -// Usage +// # Usage // // Note: If no non flag arguments are given, goyacc reads standard input. // @@ -42,9 +42,7 @@ // -xegen examplesFile Generate a file suitable for -xe automatically from the grammar. // The file must not exist. ("") // -// -// -// Changelog +// # Changelog // // 2015-03-24: The search for a custom error message is now extended to include // also the last state that was shifted into, if any. This change resolves a @@ -70,7 +68,7 @@ // by parsing code fragments. If it returns true the parser exits immediately // with return value -1. // -// Overview +// # Overview // // The generated parser is reentrant and mostly backwards compatible with // parsers generated by go tool yacc[0]. yyParse expects to be given an @@ -104,7 +102,7 @@ // generated code. Setting it to distinct values allows multiple grammars to be // placed in a single package. // -// Differences wrt go tool yacc +// # Differences wrt go tool yacc // // - goyacc implements ideas from "Generating LR Syntax Error Messages from // Examples"[1]. Use the -xe flag to pass a name of the example file. For more @@ -115,14 +113,14 @@ // // - Minor changes in parser debug output. // -// Links +// # Links // // Referenced from elsewhere: // -// [0]: http://golang.org/cmd/yacc/ -// [1]: http://people.via.ecp.fr/~stilgar/doc/compilo/parser/Generating%20LR%20Syntax%20Error%20Messages.pdf -// [2]: http://godoc.org/github.com/cznic/y#hdr-Error_Examples -// [3]: http://www.gnu.org/software/bison/manual/html_node/Precedence-Only.html#Precedence-Only +// [0]: http://golang.org/cmd/yacc/ +// [1]: http://people.via.ecp.fr/~stilgar/doc/compilo/parser/Generating%20LR%20Syntax%20Error%20Messages.pdf +// [2]: http://godoc.org/github.com/cznic/y#hdr-Error_Examples +// [3]: http://www.gnu.org/software/bison/manual/html_node/Precedence-Only.html#Precedence-Only package main import ( diff --git a/parser/model/model.go b/parser/model/model.go index a5b61033e6716..4ff1d70b6fec7 100644 --- a/parser/model/model.go +++ b/parser/model/model.go @@ -1019,8 +1019,9 @@ func (v *ViewCheckOption) String() string { } } -//revive:disable:exported // ViewInfo provides meta data describing a DB view. +// +//revive:disable:exported type ViewInfo struct { Algorithm ViewAlgorithm `json:"view_algorithm"` Definer *auth.UserIdentity `json:"view_definer"` @@ -1250,7 +1251,7 @@ func (i *IndexColumn) Clone() *IndexColumn { } // PrimaryKeyType is the type of primary key. -// Available values are 'clustered', 'nonclustered', and ''(default). +// Available values are 'clustered', 'nonclustered', and ”(default). type PrimaryKeyType int8 func (p PrimaryKeyType) String() string { diff --git a/parser/test_driver/test_driver_mydecimal.go b/parser/test_driver/test_driver_mydecimal.go index 9632cf6db5134..91bd04486689e 100644 --- a/parser/test_driver/test_driver_mydecimal.go +++ b/parser/test_driver/test_driver_mydecimal.go @@ -40,10 +40,10 @@ func fixWordCntError(wordsInt, wordsFrac int) (newWordsInt int, newWordsFrac int } /* - countLeadingZeroes returns the number of leading zeroes that can be removed from fraction. +countLeadingZeroes returns the number of leading zeroes that can be removed from fraction. - @param i start index - @param word value to compare against list of powers of 10 +@param i start index +@param word value to compare against list of powers of 10 */ func countLeadingZeroes(i int, word int32) int { leading := 0 @@ -102,11 +102,10 @@ func (d *MyDecimal) removeLeadingZeros() (wordIdx int, digitsInt int) { // ToString converts decimal to its printable string representation without rounding. // -// RETURN VALUE -// -// str - result string -// errCode - eDecOK/eDecTruncate/eDecOverflow +// RETURN VALUE // +// str - result string +// errCode - eDecOK/eDecTruncate/eDecOverflow func (d *MyDecimal) ToString() (str []byte) { str = make([]byte, d.stringSize()) digitsFrac := int(d.digitsFrac) diff --git a/parser/types/etc.go b/parser/types/etc.go index 1fdfeaf05367f..2c07f57f35876 100644 --- a/parser/types/etc.go +++ b/parser/types/etc.go @@ -109,6 +109,7 @@ func TypeStr(tp byte) (r string) { // It is used for converting Text to Blob, // or converting Char to Binary. // Args: +// // tp: type enum // cs: charset func TypeToStr(tp byte, cs string) (r string) { @@ -126,7 +127,8 @@ func TypeToStr(tp byte, cs string) (r string) { // StrToType convert a string to type enum. // Args: -// ts: type string +// +// ts: type string func StrToType(ts string) (tp byte) { ts = strings.Replace(ts, "blob", "text", 1) ts = strings.Replace(ts, "binary", "char", 1)