org/apache/doris/planner/DistributedPlanner.java
@@ -361,13 +361,18 @@ private PlanFragment createHashJoinFragment(HashJoinNode node, PlanFragment righ
         // side to be partitioned for correctness)
         // - and the expected size of the hash tbl doesn't exceed perNodeMemLimit
         // we set partition join as default when broadcast join cost equals partition join cost
-        if (node.getJoinOp() != JoinOperator.RIGHT_OUTER_JOIN
-                && node.getJoinOp() != JoinOperator.FULL_OUTER_JOIN
-                && (perNodeMemLimit == 0 || Math.round(
-                (double) rhsDataSize * PlannerContext.HASH_TBL_SPACE_OVERHEAD) <= perNodeMemLimit)
-                && (node.getInnerRef().isBroadcastJoin() || (!node.getInnerRef().isPartitionJoin()
-                && isBroadcastCostSmaller(broadcastCost, partitionCost)))) {
-            doBroadcast = true;
+        if (node.getJoinOp() != JoinOperator.RIGHT_OUTER_JOIN && node.getJoinOp() != JoinOperator.FULL_OUTER_JOIN) {
+            if (node.getInnerRef().isBroadcastJoin()) {
+                // respect user join hint
+                doBroadcast = true;
+            } else if (!node.getInnerRef().isPartitionJoin()
+                    && isBroadcastCostSmaller(broadcastCost, partitionCost)
+                    && (perNodeMemLimit == 0
+                    || Math.round((double) rhsDataSize * PlannerContext.HASH_TBL_SPACE_OVERHEAD) <= perNodeMemLimit)) {
+                doBroadcast = true;
+            } else {
+                doBroadcast = false;
+            }
         } else {
             doBroadcast = false;
         }
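To make the tie-break comment in the hunk above concrete, here is a small hedged sketch. The cost formulas and the numbers are illustrative assumptions (the patch only shows that isBroadcastCostSmaller, broadcastCost, and partitionCost exist); what it demonstrates is the stated rule that an exact cost tie resolves to the partitioned join.

// Hypothetical numbers and cost model for illustration only; the real
// broadcastCost/partitionCost are computed elsewhere in DistributedPlanner.
public class JoinCostTieBreakDemo {
    public static void main(String[] args) {
        long rhsDataSize = 1_000L;   // bytes shipped per copy of the right (build) side
        long lhsDataSize = 2_000L;   // bytes of the left (probe) side
        int numNodes = 3;

        // Assumed classic model: broadcast copies the right side to every node;
        // partitioning ships each side across the network once.
        long broadcastCost = rhsDataSize * numNodes;      // 3000
        long partitionCost = lhsDataSize + rhsDataSize;   // 3000

        // Per the comment "we set partition join as default when broadcast join
        // cost equals partition join cost", the comparison is evidently strict,
        // so an exact tie falls through to the partitioned join.
        boolean doBroadcast = broadcastCost < partitionCost;
        System.out.println("doBroadcast = " + doBroadcast); // false on a tie
    }
}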
org/apache/doris/planner/DistributedPlannerTest.java
@@ -17,34 +17,69 @@
 
 package org.apache.doris.planner;
 
+import org.apache.doris.analysis.CreateDbStmt;
+import org.apache.doris.analysis.CreateTableStmt;
 import org.apache.doris.analysis.TupleId;
+import org.apache.doris.catalog.Catalog;
+import org.apache.doris.common.jmockit.Deencapsulation;
+import org.apache.doris.qe.ConnectContext;
+import org.apache.doris.qe.StmtExecutor;
+import org.apache.doris.thrift.TExplainLevel;
+import org.apache.doris.utframe.UtFrameUtils;
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 
-import org.apache.doris.common.jmockit.Deencapsulation;
+import mockit.Expectations;
+import mockit.Injectable;
+import mockit.Mocked;
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.junit.After;
 import org.junit.Assert;
+import org.junit.BeforeClass;
 import org.junit.Test;
 
+import java.io.File;
 import java.util.List;
 import java.util.Set;
-
-import mockit.Expectations;
-import mockit.Injectable;
-import mockit.Mocked;
+import java.util.UUID;
 
 public class DistributedPlannerTest {
+    private static String runningDir = "fe/mocked/DemoTest/" + UUID.randomUUID().toString() + "/";
+    private static ConnectContext ctx;
 
-    @Mocked
-    PlannerContext plannerContext;
+    @BeforeClass
+    public static void setUp() throws Exception {
+        UtFrameUtils.createMinDorisCluster(runningDir);
+        ctx = UtFrameUtils.createDefaultCtx();
+        String createDbStmtStr = "create database db1;";
+        CreateDbStmt createDbStmt = (CreateDbStmt) UtFrameUtils.parseAndAnalyzeStmt(createDbStmtStr, ctx);
+        Catalog.getCurrentCatalog().createDb(createDbStmt);
+        // create table tbl1
+        String createTblStmtStr = "create table db1.tbl1(k1 int, k2 varchar(32), v bigint sum) "
+                + "AGGREGATE KEY(k1,k2) distributed by hash(k1) buckets 1 properties('replication_num' = '1');";
+        CreateTableStmt createTableStmt = (CreateTableStmt) UtFrameUtils.parseAndAnalyzeStmt(createTblStmtStr, ctx);
+        Catalog.getCurrentCatalog().createTable(createTableStmt);
+        // create table tbl2
+        createTblStmtStr = "create table db1.tbl2(k3 int, k4 varchar(32)) "
+                + "DUPLICATE KEY(k3) distributed by hash(k3) buckets 1 properties('replication_num' = '1');";
+        createTableStmt = (CreateTableStmt) UtFrameUtils.parseAndAnalyzeStmt(createTblStmtStr, ctx);
+        Catalog.getCurrentCatalog().createTable(createTableStmt);
+    }
+
+    @After
+    public void tearDown() throws Exception {
+        FileUtils.deleteDirectory(new File(runningDir));
+    }
 
     @Test
     public void testAssertFragmentWithDistributedInput(@Injectable AssertNumRowsNode assertNumRowsNode,
                                                        @Injectable PlanFragment inputFragment,
                                                        @Injectable PlanNodeId planNodeId,
                                                        @Injectable PlanFragmentId planFragmentId,
                                                        @Injectable PlanNode inputPlanRoot,
-                                                       @Injectable TupleId tupleId) {
+                                                       @Injectable TupleId tupleId,
+                                                       @Mocked PlannerContext plannerContext) {
         DistributedPlanner distributedPlanner = new DistributedPlanner(plannerContext);
 
         List<TupleId> tupleIdList = Lists.newArrayList(tupleId);
@@ -82,7 +117,8 @@ public void testAssertFragmentWithDistributedInput(@Injectable AssertNumRowsNode
 
     @Test
     public void testAssertFragmentWithUnpartitionInput(@Injectable AssertNumRowsNode assertNumRowsNode,
-                                                       @Injectable PlanFragment inputFragment){
+                                                       @Injectable PlanFragment inputFragment,
+                                                       @Mocked PlannerContext plannerContext){
         DistributedPlanner distributedPlanner = new DistributedPlanner(plannerContext);
 
         PlanFragment assertFragment = Deencapsulation.invoke(distributedPlanner, "createAssertFragment",
@@ -91,4 +127,22 @@ public void testAssertFragmentWithUnpartitionInput(@Injectable AssertNumRowsNode
         Assert.assertTrue(assertFragment.getPlanRoot() instanceof AssertNumRowsNode);
     }
 
+    @Test
+    public void testExplicitlyBroadcastJoin() throws Exception {
+        String sql = "explain select * from db1.tbl1 join [BROADCAST] db1.tbl2 on tbl1.k1 = tbl2.k3";
+        StmtExecutor stmtExecutor = new StmtExecutor(ctx, sql);
+        stmtExecutor.execute();
+        Planner planner = stmtExecutor.planner();
+        List<PlanFragment> fragments = planner.getFragments();
+        String plan = planner.getExplainString(fragments, TExplainLevel.NORMAL);
+        Assert.assertEquals(1, StringUtils.countMatches(plan, "INNER JOIN (BROADCAST)"));
+
+        sql = "explain select * from db1.tbl1 join [SHUFFLE] db1.tbl2 on tbl1.k1 = tbl2.k3";
+        stmtExecutor = new StmtExecutor(ctx, sql);
+        stmtExecutor.execute();
+        planner = stmtExecutor.planner();
+        fragments = planner.getFragments();
+        plan = planner.getExplainString(fragments, TExplainLevel.NORMAL);
+        Assert.assertEquals(1, StringUtils.countMatches(plan, "INNER JOIN (PARTITIONED)"));
+    }
 }
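Taken together, the planner change and these tests pin down a precedence order: an explicit [BROADCAST] hint always wins for join types that allow it, an explicit [SHUFFLE] hint always forces the partitioned plan, and only in the unhinted case does the cost-and-memory comparison run. The sketch below restates that order as a standalone function; every name in it is hypothetical (expectedHashTblBytes stands in for Math.round(rhsDataSize * HASH_TBL_SPACE_OVERHEAD)), and only the decision logic is taken from the patch.

// Standalone paraphrase of the reworked createHashJoinFragment decision.
// All names here are hypothetical; the branch order mirrors the patch above.
enum JoinType { INNER, LEFT_OUTER, RIGHT_OUTER, FULL_OUTER }

class BroadcastDecisionSketch {
    static boolean chooseBroadcast(JoinType joinType,
                                   boolean broadcastHint,   // "join [BROADCAST]"
                                   boolean shuffleHint,     // "join [SHUFFLE]"
                                   long broadcastCost,
                                   long partitionCost,
                                   long expectedHashTblBytes,
                                   long perNodeMemLimit) {
        // Right/full outer joins need the right side partitioned for
        // correctness, so even an explicit broadcast hint is ignored.
        if (joinType == JoinType.RIGHT_OUTER || joinType == JoinType.FULL_OUTER) {
            return false;
        }
        // Respect the user's [BROADCAST] hint unconditionally; note that in the
        // patch this branch skips the perNodeMemLimit check entirely.
        if (broadcastHint) {
            return true;
        }
        // No hint: broadcast only when it is strictly cheaper and the expected
        // hash table fits under the per-node memory limit (0 means unlimited).
        return !shuffleHint
                && broadcastCost < partitionCost
                && (perNodeMemLimit == 0 || expectedHashTblBytes <= perNodeMemLimit);
    }

    public static void main(String[] args) {
        // [BROADCAST] hint on an inner join: broadcast wins even when costlier.
        System.out.println(chooseBroadcast(JoinType.INNER, true, false, 9, 1, 8, 4));  // true
        // [SHUFFLE] hint: partitioned plan regardless of cost.
        System.out.println(chooseBroadcast(JoinType.INNER, false, true, 1, 9, 1, 0));  // false
    }
}

On this reading, the two EXPLAIN assertions in testExplicitlyBroadcastJoin exercise exactly the two hint branches: [BROADCAST] yields "INNER JOIN (BROADCAST)" and [SHUFFLE] yields "INNER JOIN (PARTITIONED)".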