diff --git a/fe/src/main/java/org/apache/doris/analysis/Analyzer.java b/fe/src/main/java/org/apache/doris/analysis/Analyzer.java index f59f48dbc3aad4..1f565ceae98841 100644 --- a/fe/src/main/java/org/apache/doris/analysis/Analyzer.java +++ b/fe/src/main/java/org/apache/doris/analysis/Analyzer.java @@ -433,6 +433,10 @@ public TupleDescriptor registerTableRef(TableRef ref) throws AnalysisException { return result; } + /** Returns a new list containing the keys of tableRefMap_. */ public List getAllTupleIds() { + return new ArrayList<>(tableRefMap_.keySet()); + } + /** * Resolves the given TableRef into a concrete BaseTableRef, ViewRef or * CollectionTableRef. Returns the new resolved table ref or the given table @@ -950,6 +954,29 @@ public List getUnassignedConjuncts( return result; } + + /** + * Returns all registered conjuncts that are fully bound by the given list of tuple ids + * and accepted by canEvalPredicate; auxiliary exprs, eqJoinConjuncts and conjuncts found in ojClauseByConjunct are excluded. + */ + public List getConjuncts(List tupleIds) { + List result = Lists.newArrayList(); + List eqJoinConjunctIds = Lists.newArrayList(); + for (List conjuncts : globalState.eqJoinConjuncts.values()) { + eqJoinConjunctIds.addAll(conjuncts); + } + for (Expr e : globalState.conjuncts.values()) { + if (e.isBoundByTupleIds(tupleIds) + && !e.isAuxExpr() + && !eqJoinConjunctIds.contains(e.getId()) + && !globalState.ojClauseByConjunct.containsKey(e.getId()) + && canEvalPredicate(tupleIds, e)) { + result.add(e); + } + } + return result; + } + /** * Return all unassigned registered conjuncts that are fully bound by given * list of tuple ids diff --git a/fe/src/main/java/org/apache/doris/analysis/InPredicate.java b/fe/src/main/java/org/apache/doris/analysis/InPredicate.java index 73ff516890635f..50ecfd0ec75823 100644 --- a/fe/src/main/java/org/apache/doris/analysis/InPredicate.java +++ b/fe/src/main/java/org/apache/doris/analysis/InPredicate.java @@ -125,6 +125,10 @@ public Expr negate() { !isNotIn); } + public List getListChildren() { + return children.subList(1, children.size()); + } + public boolean isNotIn() 
{ return isNotIn; } diff --git a/fe/src/main/java/org/apache/doris/analysis/Predicate.java b/fe/src/main/java/org/apache/doris/analysis/Predicate.java index 06b43fb29e9260..d345cdb3176bf1 100644 --- a/fe/src/main/java/org/apache/doris/analysis/Predicate.java +++ b/fe/src/main/java/org/apache/doris/analysis/Predicate.java @@ -17,6 +17,7 @@ package org.apache.doris.analysis; +import com.google.common.base.Preconditions; import org.apache.doris.catalog.Type; import org.apache.doris.common.AnalysisException; import org.apache.doris.common.Pair; @@ -96,6 +97,33 @@ public static boolean isEquivalencePredicate(Expr expr) { && ((BinaryPredicate) expr).getOp().isEquivalence(); } + public static boolean canPushDownPredicate(Expr expr) { + if (!(expr instanceof Predicate)) { + return false; + } + + if (((Predicate) expr).isSingleColumnPredicate(null, null)) { + if (expr instanceof BinaryPredicate) { + BinaryPredicate binPredicate = (BinaryPredicate) expr; + Expr right = binPredicate.getChild(1); + + // these checks are expected to hold because isSingleColumnPredicate returned true + Preconditions.checkState(right != null); + Preconditions.checkState(right.isConstant()); + + return right instanceof LiteralExpr; + } + + if (expr instanceof InPredicate) { + InPredicate inPredicate = (InPredicate) expr; + return inPredicate.isLiteralChildren(); + } + } + + return false; + } + + /** * If predicate is of the form " = ", returns both SlotRefs, * otherwise returns null. 
diff --git a/fe/src/main/java/org/apache/doris/planner/SingleNodePlanner.java b/fe/src/main/java/org/apache/doris/planner/SingleNodePlanner.java index f4dcf68928f1e5..8c0727d3eb22de 100644 --- a/fe/src/main/java/org/apache/doris/planner/SingleNodePlanner.java +++ b/fe/src/main/java/org/apache/doris/planner/SingleNodePlanner.java @@ -1347,6 +1347,45 @@ private PlanNode createScanNode(Analyzer analyzer, TableRef tblRef, SelectStmt s if (scanNode instanceof OlapScanNode || scanNode instanceof EsScanNode) { Map columnFilters = Maps.newHashMap(); List conjuncts = analyzer.getUnassignedConjuncts(scanNode); + + // derive extra scan predicates for this table from its equi-join conjuncts (join predicate transitivity) + List pushDownConjuncts = Lists.newArrayList(); + TupleId tupleId = tblRef.getId(); + List eqJoinPredicates = analyzer.getEqJoinConjuncts(tupleId); + if (eqJoinPredicates != null) { + // only when this table ref's join op is an inner join or a left outer join + if ((tblRef.getJoinOp().isInnerJoin() || tblRef.getJoinOp().isLeftOuterJoin())) { + List allConjuncts = analyzer.getConjuncts(analyzer.getAllTupleIds()); + allConjuncts.removeAll(conjuncts); + for (Expr conjunct: allConjuncts) { + if (org.apache.doris.analysis.Predicate.canPushDownPredicate(conjunct)) { + for (Expr eqJoinPredicate : eqJoinPredicates) { + // the slot is assumed to be the left child (child 0), relying on NormalizeBinaryPredicatesRule + SlotRef otherSlot = conjunct.getChild(0).unwrapSlotRef(); + + // skip this eqJoinPredicate unless both of its children unwrap to a SlotRef + if (eqJoinPredicate.getChild(0).unwrapSlotRef() == null || eqJoinPredicate.getChild(1).unwrapSlotRef() == null) { + continue; + } + + SlotRef leftSlot = eqJoinPredicate.getChild(0).unwrapSlotRef(); + SlotRef rightSlot = eqJoinPredicate.getChild(1).unwrapSlotRef(); + + // example: t1.id = t2.id and t1.id = 1 => t2.id = 1 + if (otherSlot.isBound(leftSlot.getSlotId()) && rightSlot.isBound(tupleId)) { + pushDownConjuncts.add(rewritePredicate(analyzer, conjunct, rightSlot)); + } else if (otherSlot.isBound(rightSlot.getSlotId()) && leftSlot.isBound(tupleId)) { + 
pushDownConjuncts.add(rewritePredicate(analyzer, conjunct, leftSlot)); + } + } + } + } + } + + LOG.debug("pushDownConjuncts: {}", pushDownConjuncts); + conjuncts.addAll(pushDownConjuncts); + } + for (Column column : tblRef.getTable().getBaseSchema()) { SlotDescriptor slotDesc = tblRef.getDesc().getColumnSlot(column.getName()); if (null == slotDesc) { @@ -1359,6 +1398,7 @@ private PlanNode createScanNode(Analyzer analyzer, TableRef tblRef, SelectStmt s } scanNode.setColumnFilters(columnFilters); scanNode.setSortColumn(tblRef.getSortColumn()); + scanNode.addConjuncts(pushDownConjuncts); } // assignConjuncts(scanNode, analyzer); scanNode.init(analyzer); @@ -1372,6 +1412,26 @@ private PlanNode createScanNode(Analyzer analyzer, TableRef tblRef, SelectStmt s return scanNode; } + // Builds and analyzes a new predicate like oldPredicate but with leftChild as its first child; anything other than a BinaryPredicate or InPredicate is returned unchanged. + // For example: if oldPredicate is t1.id = 1 and leftChild is t2.id, the result is t2.id = 1. + private Expr rewritePredicate(Analyzer analyzer, Expr oldPredicate, Expr leftChild) { + if (oldPredicate instanceof BinaryPredicate) { + BinaryPredicate oldBP = (BinaryPredicate) oldPredicate; + BinaryPredicate bp = new BinaryPredicate(oldBP.getOp(), leftChild, oldBP.getChild(1)); + bp.analyzeNoThrow(analyzer); + return bp; + } + + if (oldPredicate instanceof InPredicate) { + InPredicate oldIP = (InPredicate) oldPredicate; + InPredicate ip = new InPredicate(leftChild, oldIP.getListChildren(), oldIP.isNotIn()); + ip.analyzeNoThrow(analyzer); + return ip; + } + + return oldPredicate; + } + /** * Return join conjuncts that can be used for hash table lookups. 
- for inner joins, those are equi-join predicates * in which one side is fully bound by lhsIds and the other by rhs' id; - for outer joins: same type of conjuncts as diff --git a/fe/src/test/java/org/apache/doris/planner/QueryPlanTest.java b/fe/src/test/java/org/apache/doris/planner/QueryPlanTest.java index 003637631bf52f..53a3c15bfb6972 100644 --- a/fe/src/test/java/org/apache/doris/planner/QueryPlanTest.java +++ b/fe/src/test/java/org/apache/doris/planner/QueryPlanTest.java @@ -98,6 +98,32 @@ public static void beforeClass() throws Exception { " \"replication_num\" = \"1\"\n" + ");"); + createTable("CREATE TABLE test.join1 (\n" + + " `dt` int(11) COMMENT \"\",\n" + + " `id` int(11) COMMENT \"\",\n" + + " `value` varchar(8) COMMENT \"\"\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(`dt`, `id`)\n" + + "PARTITION BY RANGE(`dt`)\n" + + "(PARTITION p1 VALUES LESS THAN (\"10\"))\n" + + "DISTRIBUTED BY HASH(`id`) BUCKETS 10\n" + + "PROPERTIES (\n" + + " \"replication_num\" = \"1\"\n" + + ");"); + + createTable("CREATE TABLE test.join2 (\n" + + " `dt` int(11) COMMENT \"\",\n" + + " `id` int(11) COMMENT \"\",\n" + + " `value` varchar(8) COMMENT \"\"\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(`dt`, `id`)\n" + + "PARTITION BY RANGE(`dt`)\n" + + "(PARTITION p1 VALUES LESS THAN (\"10\"))\n" + + "DISTRIBUTED BY HASH(`id`) BUCKETS 10\n" + + "PROPERTIES (\n" + + " \"replication_num\" = \"1\"\n" + + ");"); + createTable("CREATE TABLE test.bitmap_table_2 (\n" + " `id` int(11) NULL COMMENT \"\",\n" + " `id2` bitmap bitmap_union NULL\n" + @@ -504,4 +530,104 @@ public void testDateTypeEquality() throws Exception { Catalog.getCurrentCatalog().getLoadManager().createLoadJobV1FromStmt(loadStmt, EtlJobType.HADOOP, System.currentTimeMillis()); } + + @Test + public void testJoinPredicateTransitivity() throws Exception { + connectContext.setDatabase("default_cluster:test"); + + // test left join: binary predicate on the left table in the where clause + String sql = "select join1.id\n" + + "from join1\n" + + 
"left join join2 on join1.id = join2.id\n" + + "where join1.id > 1;"; + String explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql); + System.out.println(explainString); + Assert.assertTrue(explainString.contains("PREDICATES: `join2`.`id` > 1")); + Assert.assertTrue(explainString.contains("PREDICATES: `join1`.`id` > 1")); + + // test left join: in predicate on the left table in the where clause + sql = "select join1.id\n" + + "from join1\n" + + "left join join2 on join1.id = join2.id\n" + + "where join1.id in (2);"; + explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql); + System.out.println(explainString); + Assert.assertTrue(explainString.contains("PREDICATES: `join2`.`id` IN (2)")); + Assert.assertTrue(explainString.contains("PREDICATES: `join1`.`id` IN (2)")); + + // test left join: between predicate on the left table in the where clause + sql = "select join1.id\n" + + "from join1\n" + + "left join join2 on join1.id = join2.id\n" + + "where join1.id BETWEEN 1 AND 2;"; + explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql); + System.out.println(explainString); + Assert.assertTrue(explainString.contains("PREDICATES: `join1`.`id` >= 1, `join1`.`id` <= 2")); + Assert.assertTrue(explainString.contains("PREDICATES: `join2`.`id` >= 1, `join2`.`id` <= 2")); + + // test left join: left table predicate in the join clause, it must not be pushed down to the left table + sql = "select *\n from join1\n" + + "left join join2 on join1.id = join2.id\n" + + "and join1.id > 1;"; + explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql); + System.out.println(explainString); + Assert.assertTrue(explainString.contains("other join predicates: `join1`.`id` > 1")); + Assert.assertFalse(explainString.contains("PREDICATES: `join1`.`id` > 1")); + + // test left join: right table predicate in the where clause. + // If the outer join were eliminated, the predicate could be pushed down to both join1 and join2. 
+ // Currently, the predicate is pushed to join1 and the join predicate is kept for join2 + sql = "select *\n from join1\n" + + "left join join2 on join1.id = join2.id\n" + + "where join2.id > 1;"; + explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql); + System.out.println(explainString); + Assert.assertTrue(explainString.contains("PREDICATES: `join1`.`id` > 1")); + Assert.assertFalse(explainString.contains("other join predicates: `join2`.`id` > 1")); + + // test left join: right table predicate in the join clause, pushed down to the right table only + sql = "select *\n from join1\n" + + "left join join2 on join1.id = join2.id\n" + + "and join2.id > 1;"; + explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql); + System.out.println(explainString); + Assert.assertTrue(explainString.contains("PREDICATES: `join2`.`id` > 1")); + Assert.assertFalse(explainString.contains("PREDICATES: `join1`.`id` > 1")); + + // test inner join: left table predicate in the where clause, pushed down to both the left and the right table + sql = "select *\n from join1\n" + + "join join2 on join1.id = join2.id\n" + + "where join1.id > 1;"; + explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql); + System.out.println(explainString); + Assert.assertTrue(explainString.contains("PREDICATES: `join1`.`id` > 1")); + Assert.assertTrue(explainString.contains("PREDICATES: `join2`.`id` > 1")); + + // test inner join: left table predicate in the join clause, pushed down to both the left and the right table + sql = "select *\n from join1\n" + + "join join2 on join1.id = join2.id\n" + + "and join1.id > 1;"; + explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql); + System.out.println(explainString); + Assert.assertTrue(explainString.contains("PREDICATES: `join1`.`id` > 1")); + Assert.assertTrue(explainString.contains("PREDICATES: `join2`.`id` > 1")); + + // test inner join: right table predicate in the where clause, pushed down to both the left and the right table + sql = "select *\n 
from join1\n" + + "join join2 on join1.id = join2.id\n" + + "where join2.id > 1;"; + explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql); + System.out.println(explainString); + Assert.assertTrue(explainString.contains("PREDICATES: `join1`.`id` > 1")); + Assert.assertTrue(explainString.contains("PREDICATES: `join2`.`id` > 1")); + + // test inner join: right table predicate in the join clause, pushed down to both the left and the right table + sql = "select *\n from join1\n" + + "join join2 on join1.id = join2.id\n" + + "and join2.id > 1;"; + explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql); + System.out.println(explainString); + Assert.assertTrue(explainString.contains("PREDICATES: `join1`.`id` > 1")); + Assert.assertTrue(explainString.contains("PREDICATES: `join2`.`id` > 1")); + } }