diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/DistributedPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/planner/DistributedPlanner.java index d886677e21c727..94c429d81ea888 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/DistributedPlanner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/DistributedPlanner.java @@ -361,13 +361,18 @@ private PlanFragment createHashJoinFragment(HashJoinNode node, PlanFragment righ // side to be partitioned for correctness) // - and the expected size of the hash tbl doesn't exceed perNodeMemLimit // we set partition join as default when broadcast join cost equals partition join cost - if (node.getJoinOp() != JoinOperator.RIGHT_OUTER_JOIN - && node.getJoinOp() != JoinOperator.FULL_OUTER_JOIN - && (perNodeMemLimit == 0 || Math.round( - (double) rhsDataSize * PlannerContext.HASH_TBL_SPACE_OVERHEAD) <= perNodeMemLimit) - && (node.getInnerRef().isBroadcastJoin() || (!node.getInnerRef().isPartitionJoin() - && isBroadcastCostSmaller(broadcastCost, partitionCost)))) { - doBroadcast = true; + if (node.getJoinOp() != JoinOperator.RIGHT_OUTER_JOIN && node.getJoinOp() != JoinOperator.FULL_OUTER_JOIN) { + if (node.getInnerRef().isBroadcastJoin()) { + // respect user join hint + doBroadcast = true; + } else if (!node.getInnerRef().isPartitionJoin() + && isBroadcastCostSmaller(broadcastCost, partitionCost) + && (perNodeMemLimit == 0 + || Math.round((double) rhsDataSize * PlannerContext.HASH_TBL_SPACE_OVERHEAD) <= perNodeMemLimit)) { + doBroadcast = true; + } else { + doBroadcast = false; + } } else { doBroadcast = false; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/planner/DistributedPlannerTest.java b/fe/fe-core/src/test/java/org/apache/doris/planner/DistributedPlannerTest.java index 4c83d31dbbb404..6bb104381bc9c2 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/planner/DistributedPlannerTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/planner/DistributedPlannerTest.java @@ -17,26 +17,60 @@ package org.apache.doris.planner; +import org.apache.doris.analysis.CreateDbStmt; +import org.apache.doris.analysis.CreateTableStmt; import org.apache.doris.analysis.TupleId; +import org.apache.doris.catalog.Catalog; +import org.apache.doris.common.jmockit.Deencapsulation; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.StmtExecutor; +import org.apache.doris.thrift.TExplainLevel; +import org.apache.doris.utframe.UtFrameUtils; import com.google.common.collect.Lists; import com.google.common.collect.Sets; - -import org.apache.doris.common.jmockit.Deencapsulation; +import mockit.Expectations; +import mockit.Injectable; +import mockit.Mocked; +import org.apache.commons.io.FileUtils; +import org.apache.commons.lang3.StringUtils; +import org.junit.After; import org.junit.Assert; +import org.junit.BeforeClass; import org.junit.Test; +import java.io.File; import java.util.List; import java.util.Set; - -import mockit.Expectations; -import mockit.Injectable; -import mockit.Mocked; +import java.util.UUID; public class DistributedPlannerTest { + private static String runningDir = "fe/mocked/DemoTest/" + UUID.randomUUID().toString() + "/"; + private static ConnectContext ctx; + + @BeforeClass + public static void setUp() throws Exception { + UtFrameUtils.createMinDorisCluster(runningDir); + ctx = UtFrameUtils.createDefaultCtx(); + String createDbStmtStr = "create database db1;"; + CreateDbStmt createDbStmt = (CreateDbStmt) UtFrameUtils.parseAndAnalyzeStmt(createDbStmtStr, ctx); + Catalog.getCurrentCatalog().createDb(createDbStmt); + // create table tbl1 + String createTblStmtStr = "create table db1.tbl1(k1 int, k2 varchar(32), v bigint sum) " + + "AGGREGATE KEY(k1,k2) distributed by hash(k1) buckets 1 properties('replication_num' = '1');"; + CreateTableStmt createTableStmt = (CreateTableStmt) UtFrameUtils.parseAndAnalyzeStmt(createTblStmtStr, ctx); + Catalog.getCurrentCatalog().createTable(createTableStmt); + // create table tbl2 + createTblStmtStr = "create table db1.tbl2(k3 int, k4 varchar(32)) " + + "DUPLICATE KEY(k3) distributed by hash(k3) buckets 1 properties('replication_num' = '1');"; + createTableStmt = (CreateTableStmt) UtFrameUtils.parseAndAnalyzeStmt(createTblStmtStr, ctx); + Catalog.getCurrentCatalog().createTable(createTableStmt); + } - @Mocked - PlannerContext plannerContext; + @After + public void tearDown() throws Exception { + FileUtils.deleteDirectory(new File(runningDir)); + } @Test public void testAssertFragmentWithDistributedInput(@Injectable AssertNumRowsNode assertNumRowsNode, @@ -44,7 +78,8 @@ public void testAssertFragmentWithDistributedInput(@Injectable AssertNumRowsNode @Injectable PlanNodeId planNodeId, @Injectable PlanFragmentId planFragmentId, @Injectable PlanNode inputPlanRoot, - @Injectable TupleId tupleId) { + @Injectable TupleId tupleId, + @Mocked PlannerContext plannerContext) { DistributedPlanner distributedPlanner = new DistributedPlanner(plannerContext); List tupleIdList = Lists.newArrayList(tupleId); @@ -82,7 +117,8 @@ public void testAssertFragmentWithDistributedInput(@Injectable AssertNumRowsNode @Test public void testAssertFragmentWithUnpartitionInput(@Injectable AssertNumRowsNode assertNumRowsNode, - @Injectable PlanFragment inputFragment){ + @Injectable PlanFragment inputFragment, + @Mocked PlannerContext plannerContext){ DistributedPlanner distributedPlanner = new DistributedPlanner(plannerContext); PlanFragment assertFragment = Deencapsulation.invoke(distributedPlanner, "createAssertFragment", @@ -91,4 +127,22 @@ public void testAssertFragmentWithUnpartitionInput(@Injectable AssertNumRowsNode Assert.assertTrue(assertFragment.getPlanRoot() instanceof AssertNumRowsNode); } + @Test + public void testExplicitlyBroadcastJoin() throws Exception { + String sql = "explain select * from db1.tbl1 join [BROADCAST] db1.tbl2 on tbl1.k1 = tbl2.k3"; + StmtExecutor stmtExecutor = new StmtExecutor(ctx, sql); + stmtExecutor.execute(); + Planner planner = stmtExecutor.planner(); + List fragments = planner.getFragments(); + String plan = planner.getExplainString(fragments, TExplainLevel.NORMAL); + Assert.assertEquals(1, StringUtils.countMatches(plan, "INNER JOIN (BROADCAST)")); + + sql = "explain select * from db1.tbl1 join [SHUFFLE] db1.tbl2 on tbl1.k1 = tbl2.k3"; + stmtExecutor = new StmtExecutor(ctx, sql); + stmtExecutor.execute(); + planner = stmtExecutor.planner(); + fragments = planner.getFragments(); + plan = planner.getExplainString(fragments, TExplainLevel.NORMAL); + Assert.assertEquals(1, StringUtils.countMatches(plan, "INNER JOIN (PARTITIONED)")); + } }