[opt](coordinator) optimize parallel degree of shuffle when using Nereids (apache#44754)

Optimize the parallel degree of shuffle when using Nereids. This PR fixes some performance regressions seen when upgrading Doris from 1.2 to 2.x/3.x.
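
The core of the change, as a minimal standalone sketch (hypothetical names and numbers, not Doris code): previously the instance count of a scan fragment was capped by the scan's maximum parallelism (roughly the tablet count); with local shuffle, one instance scans and the remaining instances consume its data through a local exchange, so the fragment can follow the configured parallel degree again.

    // Standalone sketch with assumed names; see degreeOfParallelism(...) in the diff below.
    public class DegreeOfParallelismSketch {

        // Old rule: instances never exceed the scan's max parallelism (e.g. tablet count).
        static int oldDegree(int scanMaxParallel, int parallelExecNum) {
            return Math.min(scanMaxParallel, Math.max(parallelExecNum, 1));
        }

        // New rule: when local shuffle adds parallelism, one instance scans all data and
        // local-shuffles it to the others, so the full configured parallelism is usable.
        static int newDegree(int scanMaxParallel, int parallelExecNum, boolean useLocalShuffle) {
            if (useLocalShuffle) {
                return Math.max(parallelExecNum, 1);
            }
            return Math.min(scanMaxParallel, Math.max(parallelExecNum, 1));
        }

        public static void main(String[] args) {
            // e.g. a table with 3 tablets on a backend configured with parallel degree 16:
            System.out.println(oldDegree(3, 16));          // 3  -> under-parallelized
            System.out.println(newDegree(3, 16, true));    // 16 -> full parallel degree
        }
    }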
924060929 authored Jan 10, 2025
1 parent 7a6c125 commit 584a256
Showing 12 changed files with 127 additions and 121 deletions.
8 changes: 4 additions & 4 deletions be/src/vec/aggregate_functions/aggregate_function_collect.h
@@ -263,10 +263,10 @@ struct AggregateFunctionCollectListData<StringRef, HasLimit> {
}
max_size = rhs.max_size;

data->insert_range_from(
*rhs.data, 0,
std::min(assert_cast<size_t, TypeCheckOnRelease::DISABLE>(max_size - size()),
rhs.size()));
data->insert_range_from(*rhs.data, 0,
std::min(assert_cast<size_t, TypeCheckOnRelease::DISABLE>(
static_cast<size_t>(max_size - size())),
rhs.size()));
} else {
data->insert_range_from(*rhs.data, 0, rhs.size());
}
@@ -46,7 +46,6 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.SetMultimap;
import org.apache.logging.log4j.LogManager;
@@ -136,6 +135,16 @@ private FragmentIdMapping<DistributedPlan> linkPlans(FragmentIdMapping<Distribut
link.getKey(),
enableShareHashTableForBroadcastJoin
);
for (Entry<DataSink, List<AssignedJob>> kv :
((PipelineDistributedPlan) link.getValue()).getDestinations().entrySet()) {
if (kv.getValue().isEmpty()) {
int sourceFragmentId = link.getValue().getFragmentJob().getFragment().getFragmentId().asInt();
String msg = "Invalid plan which exchange not contains receiver, "
+ "exchange id: " + kv.getKey().getExchNodeId().asInt()
+ ", source fragmentId: " + sourceFragmentId;
throw new IllegalStateException(msg);
}
}
}
}
return plans;
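
The loop added above is a sanity check: after the plans are linked, every exchange sink must have at least one receiver instance, otherwise the plan is malformed. A minimal sketch of the same invariant with simplified types (the real code walks the destinations of a PipelineDistributedPlan):

    import java.util.List;
    import java.util.Map;

    // Sketch: each exchange id maps to the receiver instances of its data sink;
    // an empty receiver list means the planner produced an invalid plan.
    class ExchangeReceiverCheck {
        static void validate(int sourceFragmentId, Map<Integer, List<String>> exchangeToReceivers) {
            for (Map.Entry<Integer, List<String>> kv : exchangeToReceivers.entrySet()) {
                if (kv.getValue().isEmpty()) {
                    throw new IllegalStateException(
                            "Invalid plan which exchange not contains receiver, exchange id: "
                                    + kv.getKey() + ", source fragmentId: " + sourceFragmentId);
                }
            }
        }
    }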
@@ -184,7 +193,7 @@ private List<AssignedJob> filterInstancesWhichCanReceiveDataFromRemote(
boolean useLocalShuffle = receiverPlan.getInstanceJobs().stream()
.anyMatch(LocalShuffleAssignedJob.class::isInstance);
if (useLocalShuffle) {
return getFirstInstancePerShareScan(receiverPlan);
return getFirstInstancePerWorker(receiverPlan.getInstanceJobs());
} else if (enableShareHashTableForBroadcastJoin && linkNode.isRightChildOfBroadcastHashJoin()) {
return getFirstInstancePerWorker(receiverPlan.getInstanceJobs());
} else {
@@ -221,17 +230,6 @@ private List<AssignedJob> sortDestinationInstancesByBuckets(
return Arrays.asList(instances);
}

private List<AssignedJob> getFirstInstancePerShareScan(PipelineDistributedPlan plan) {
List<AssignedJob> canReceiveDataFromRemote = Lists.newArrayListWithCapacity(plan.getInstanceJobs().size());
for (AssignedJob instanceJob : plan.getInstanceJobs()) {
LocalShuffleAssignedJob localShuffleJob = (LocalShuffleAssignedJob) instanceJob;
if (!localShuffleJob.receiveDataFromLocal) {
canReceiveDataFromRemote.add(localShuffleJob);
}
}
return canReceiveDataFromRemote;
}

private List<AssignedJob> getFirstInstancePerWorker(List<AssignedJob> instances) {
Map<DistributedPlanWorker, AssignedJob> firstInstancePerWorker = Maps.newLinkedHashMap();
for (AssignedJob instance : instances) {
@@ -91,7 +91,7 @@ protected List<AssignedJob> insideMachineParallelization(

// now we should compute how many instances to use to process the data,
// for example: two instances
int instanceNum = degreeOfParallelism(scanSourceMaxParallel);
int instanceNum = degreeOfParallelism(scanSourceMaxParallel, useLocalShuffleToAddParallel);

if (useLocalShuffleToAddParallel) {
assignLocalShuffleJobs(scanSource, instanceNum, instances, context, worker);
@@ -129,7 +129,7 @@ protected void assignedDefaultJobs(ScanSource scanSource, int instanceNum, List<
protected void assignLocalShuffleJobs(ScanSource scanSource, int instanceNum, List<AssignedJob> instances,
ConnectContext context, DistributedPlanWorker worker) {
// only generate one instance to scan all data, in this step
List<ScanSource> instanceToScanRanges = scanSource.parallelize(scanNodes, 1);
List<ScanSource> assignedJoinBuckets = scanSource.parallelize(scanNodes, instanceNum);

// when the data is not big but aggregation is too slow, we will use 1 instance to scan the data,
// and use more instances (to ***add parallelism***) to process the aggregation.
@@ -144,23 +144,23 @@ protected void assignLocalShuffleJobs(ScanSource scanSource, int instanceNum, Li
// |(share scan node, instance1 will scan all data and local shuffle to other local instances |
// | to parallel compute this data) |
// +------------------------------------------------------------------------------------------------+
ScanSource shareScanSource = instanceToScanRanges.get(0);
ScanSource shareScanSource = assignedJoinBuckets.get(0);

// one scan range generates multiple instances,
// and the different instances reference the same scan source
int shareScanId = shareScanIdGenerator.getAndIncrement();
ScanSource emptyShareScanSource = shareScanSource.newEmpty();
for (int i = 0; i < instanceNum; i++) {
LocalShuffleAssignedJob instance = new LocalShuffleAssignedJob(
instances.size(), shareScanId, i > 0,
context.nextInstanceId(), this, worker,
i == 0 ? shareScanSource : emptyShareScanSource
instances.size(), shareScanId, context.nextInstanceId(), this, worker,
// only the first instance needs to scan data
i == 0 ? scanSource : emptyShareScanSource
);
instances.add(instance);
}
}

protected int degreeOfParallelism(int maxParallel) {
protected int degreeOfParallelism(int maxParallel, boolean useLocalShuffleToAddParallel) {
Preconditions.checkArgument(maxParallel > 0, "maxParallel must be positive");
if (!fragment.getDataPartition().isPartitioned()) {
return 1;
@@ -179,6 +193,10 @@ protected int degreeOfParallelism(int maxParallel) {
}
}

if (useLocalShuffleToAddParallel) {
return Math.max(fragment.getParallelExecNum(), 1);
}

// the scan instance num should not be larger than the tablet num
return Math.min(maxParallel, Math.max(fragment.getParallelExecNum(), 1));
}
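
Combined with assignLocalShuffleJobs above, the layout per worker is: instanceNum instances share one shareScanId, the first instance holds the real scan source, and the rest hold an empty one and read the scanned data via local shuffle. A rough sketch of that layout with assumed, simplified types:

    import java.util.ArrayList;
    import java.util.List;

    // Sketch: build N local-shuffle instances for one worker; only the first scans.
    class LocalShuffleLayoutSketch {
        record Instance(int index, int shareScanId, boolean scansData) {}

        static List<Instance> buildInstances(int instanceNum, int shareScanId) {
            List<Instance> instances = new ArrayList<>(instanceNum);
            for (int i = 0; i < instanceNum; i++) {
                // i == 0 gets the real scan source; the others get an empty share-scan
                // source and consume the first instance's output via local exchange.
                instances.add(new Instance(i, shareScanId, i == 0));
            }
            return instances;
        }
    }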
@@ -31,28 +31,17 @@
*/
public class LocalShuffleAssignedJob extends StaticAssignedJob {
public final int shareScanId;
public final boolean receiveDataFromLocal;

public LocalShuffleAssignedJob(
int indexInUnassignedJob, int shareScanId, boolean receiveDataFromLocal, TUniqueId instanceId,
int indexInUnassignedJob, int shareScanId, TUniqueId instanceId,
UnassignedJob unassignedJob,
DistributedPlanWorker worker, ScanSource scanSource) {
super(indexInUnassignedJob, instanceId, unassignedJob, worker, scanSource);
this.shareScanId = shareScanId;
this.receiveDataFromLocal = receiveDataFromLocal;
}

@Override
protected Map<String, String> extraInfo() {
return ImmutableMap.of("shareScanIndex", String.valueOf(shareScanId));
}

@Override
protected String formatScanSourceString() {
if (receiveDataFromLocal) {
return "read data from first instance of " + getAssignedWorker();
} else {
return super.formatScanSourceString();
}
}
}
@@ -30,19 +30,19 @@ public class LocalShuffleBucketJoinAssignedJob extends LocalShuffleAssignedJob {
private volatile Set<Integer> assignedJoinBucketIndexes;

public LocalShuffleBucketJoinAssignedJob(
int indexInUnassignedJob, int shareScanId, boolean receiveDataFromLocal,
int indexInUnassignedJob, int shareScanId,
TUniqueId instanceId, UnassignedJob unassignedJob,
DistributedPlanWorker worker, ScanSource scanSource,
Set<Integer> assignedJoinBucketIndexes) {
super(indexInUnassignedJob, shareScanId, receiveDataFromLocal, instanceId, unassignedJob, worker, scanSource);
super(indexInUnassignedJob, shareScanId, instanceId, unassignedJob, worker, scanSource);
this.assignedJoinBucketIndexes = Utils.fastToImmutableSet(assignedJoinBucketIndexes);
}

public Set<Integer> getAssignedJoinBucketIndexes() {
return assignedJoinBucketIndexes;
}

public void addAssignedJoinBucketIndexes(Set<Integer> joinBucketIndexes) {
public synchronized void addAssignedJoinBucketIndexes(Set<Integer> joinBucketIndexes) {
this.assignedJoinBucketIndexes = ImmutableSet.<Integer>builder()
.addAll(assignedJoinBucketIndexes)
.addAll(joinBucketIndexes)
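
Note the concurrency pattern here: assignedJoinBucketIndexes is a volatile field that always points at an immutable set, and making addAssignedJoinBucketIndexes synchronized serializes concurrent writers so no update is lost, while readers stay lock-free. A generic sketch of this copy-on-write pattern, using plain JDK types in place of Guava:

    import java.util.Collections;
    import java.util.LinkedHashSet;
    import java.util.Set;

    // Copy-on-write set: readers see an immutable snapshot through the volatile
    // field; writers are serialized by synchronized so updates cannot be lost.
    class CopyOnWriteIndexes {
        private volatile Set<Integer> indexes = Collections.emptySet();

        Set<Integer> get() {
            return indexes; // lock-free read of the latest snapshot
        }

        synchronized void addAll(Set<Integer> more) {
            Set<Integer> merged = new LinkedHashSet<>(indexes);
            merged.addAll(more);
            indexes = Collections.unmodifiableSet(merged);
        }
    }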
@@ -32,7 +32,7 @@

/** UnassignedGatherJob */
public class UnassignedGatherJob extends AbstractUnassignedJob {
private boolean useLocalShuffleToAddParallel;
private boolean useSerialSource;

public UnassignedGatherJob(
StatementContext statementContext, PlanFragment fragment,
@@ -44,24 +44,24 @@ public UnassignedGatherJob(
public List<AssignedJob> computeAssignedJobs(
DistributeContext distributeContext, ListMultimap<ExchangeNode, AssignedJob> inputJobs) {
ConnectContext connectContext = statementContext.getConnectContext();
useLocalShuffleToAddParallel = fragment.useSerialSource(connectContext);
useSerialSource = fragment.useSerialSource(connectContext);

int expectInstanceNum = degreeOfParallelism();

DistributedPlanWorker selectedWorker = distributeContext.selectedWorkers.tryToSelectRandomUsedWorker();
if (useLocalShuffleToAddParallel) {
if (useSerialSource) {
// Using a serial source means a serial source operator will be used in this fragment (e.g. data will be
// shuffled to only 1 exchange operator) and then split by the following local exchanger
ImmutableList.Builder<AssignedJob> instances = ImmutableList.builder();

DefaultScanSource shareScan = new DefaultScanSource(ImmutableMap.of());
LocalShuffleAssignedJob receiveDataFromRemote = new LocalShuffleAssignedJob(
0, 0, false,
0, 0,
connectContext.nextInstanceId(), this, selectedWorker, shareScan);

instances.add(receiveDataFromRemote);
for (int i = 1; i < expectInstanceNum; ++i) {
LocalShuffleAssignedJob receiveDataFromLocal = new LocalShuffleAssignedJob(
i, 0, true,
connectContext.nextInstanceId(), this, selectedWorker, shareScan);
i, 0, connectContext.nextInstanceId(), this, selectedWorker, shareScan);
instances.add(receiveDataFromLocal);
}
return instances.build();
@@ -76,6 +76,6 @@ selectedWorker, new DefaultScanSource(ImmutableMap.of())
}

protected int degreeOfParallelism() {
return useLocalShuffleToAddParallel ? fragment.getParallelExecNum() : 1;
return useSerialSource ? Math.max(1, fragment.getParallelExecNum()) : 1;
}
}
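
So for a gather fragment with a serial source, exactly one instance receives the shuffled data and the remaining parallelExecNum - 1 instances are fed by the local exchanger. A condensed sketch of the instance list it builds (assumed, simplified types):

    import java.util.ArrayList;
    import java.util.List;

    // Sketch: a serial-source gather fragment on a single worker. Instance 0 is
    // the only exchange receiver; the rest get data through the local exchanger.
    class GatherJobSketch {
        static List<String> computeInstances(boolean useSerialSource, int parallelExecNum) {
            int expectInstanceNum = useSerialSource ? Math.max(1, parallelExecNum) : 1;
            List<String> instances = new ArrayList<>(expectInstanceNum);
            instances.add("instance-0: receives from remote exchange");
            for (int i = 1; i < expectInstanceNum; i++) {
                instances.add("instance-" + i + ": fed by local exchange");
            }
            return instances;
        }
    }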
@@ -38,12 +38,16 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
@@ -184,13 +188,23 @@ protected void assignLocalShuffleJobs(ScanSource scanSource, int instanceNum, Li
Set<Integer> assignedJoinBuckets
= ((BucketScanSource) assignJoinBuckets.get(i)).bucketIndexToScanNodeToTablets.keySet();
LocalShuffleBucketJoinAssignedJob instance = new LocalShuffleBucketJoinAssignedJob(
instances.size(), shareScanId, i > 0,
context.nextInstanceId(), this, worker,
instances.size(), shareScanId, context.nextInstanceId(),
this, worker,
i == 0 ? shareScanSource : emptyShareScanSource,
Utils.fastToImmutableSet(assignedJoinBuckets)
);
instances.add(instance);
}

for (int i = assignJoinBuckets.size(); i < instanceNum; ++i) {
LocalShuffleBucketJoinAssignedJob instance = new LocalShuffleBucketJoinAssignedJob(
instances.size(), shareScanId, context.nextInstanceId(),
this, worker, emptyShareScanSource,
// these instances do not need to join, because no bucket is assigned to them
ImmutableSet.of()
);
instances.add(instance);
}
}
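
The second loop pads the worker up to instanceNum instances: any instance beyond the parallelized bucket list gets an empty share-scan source and an empty bucket set, so every worker ends up with the same number of local-shuffle instances even when it has fewer buckets than the desired parallelism. A small sketch of the padding step (assumed, simplified types):

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Set;

    // Sketch: pad one worker's instance list to the requested parallelism.
    // Instances past the end of the bucket assignment carry no join buckets.
    class PadInstancesSketch {
        record Instance(int index, Set<Integer> joinBuckets) {}

        static List<Instance> pad(List<Set<Integer>> bucketsPerInstance, int instanceNum) {
            List<Instance> instances = new ArrayList<>(instanceNum);
            for (Set<Integer> buckets : bucketsPerInstance) {
                instances.add(new Instance(instances.size(), buckets));
            }
            for (int i = bucketsPerInstance.size(); i < instanceNum; i++) {
                instances.add(new Instance(instances.size(), Set.of())); // no join work
            }
            return instances;
        }
    }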

private boolean shouldFillUpInstances(List<HashJoinNode> hashJoinNodes) {
@@ -224,10 +238,21 @@ private List<AssignedJob> fillUpInstances(List<AssignedJob> instances) {
olapScanNode, randomPartition, missingBucketIndexes);

boolean useLocalShuffle = instances.stream().anyMatch(LocalShuffleAssignedJob.class::isInstance);
Multimap<DistributedPlanWorker, AssignedJob> workerToAssignedJobs = ArrayListMultimap.create();
int maxNumInstancePerWorker = 1;
if (useLocalShuffle) {
for (AssignedJob instance : instances) {
workerToAssignedJobs.put(instance.getAssignedWorker(), instance);
}
for (Collection<AssignedJob> instanceList : workerToAssignedJobs.asMap().values()) {
maxNumInstancePerWorker = Math.max(maxNumInstancePerWorker, instanceList.size());
}
}

List<AssignedJob> newInstances = new ArrayList<>(instances);

for (Entry<DistributedPlanWorker, Collection<Integer>> workerToBuckets : missingBuckets.asMap().entrySet()) {
Map<Integer, Map<ScanNode, ScanRanges>> scanEmptyBuckets = Maps.newLinkedHashMap();
Set<Integer> assignedJoinBuckets = Utils.fastToImmutableSet(workerToBuckets.getValue());
for (Integer bucketIndex : workerToBuckets.getValue()) {
Map<ScanNode, ScanRanges> scanTableWithEmptyData = Maps.newLinkedHashMap();
for (ScanNode scanNode : scanNodes) {
Expand All @@ -236,42 +261,62 @@ private List<AssignedJob> fillUpInstances(List<AssignedJob> instances) {
scanEmptyBuckets.put(bucketIndex, scanTableWithEmptyData);
}

AssignedJob fillUpInstance = null;
DistributedPlanWorker worker = workerToBuckets.getKey();
BucketScanSource scanSource = new BucketScanSource(scanEmptyBuckets);
if (useLocalShuffle) {
// when use local shuffle, we should ensure every backend only process one instance!
// so here we should try to merge the missing buckets into exist instances
boolean mergedBucketsInSameWorkerInstance = false;
for (AssignedJob newInstance : newInstances) {
if (newInstance.getAssignedWorker().equals(worker)) {
BucketScanSource bucketScanSource = (BucketScanSource) newInstance.getScanSource();
bucketScanSource.bucketIndexToScanNodeToTablets.putAll(scanEmptyBuckets);
mergedBucketsInSameWorkerInstance = true;

LocalShuffleBucketJoinAssignedJob instance = (LocalShuffleBucketJoinAssignedJob) newInstance;
instance.addAssignedJoinBucketIndexes(assignedJoinBuckets);
}
List<AssignedJob> sameWorkerInstances = (List) workerToAssignedJobs.get(worker);
if (sameWorkerInstances.isEmpty()) {
sameWorkerInstances = fillUpEmptyInstances(
maxNumInstancePerWorker, scanSource, worker, newInstances, context);
}
if (!mergedBucketsInSameWorkerInstance) {
fillUpInstance = new LocalShuffleBucketJoinAssignedJob(
newInstances.size(), shareScanIdGenerator.getAndIncrement(),
false, context.nextInstanceId(), this, worker, scanSource,
assignedJoinBuckets
);

LocalShuffleBucketJoinAssignedJob firstInstance
= (LocalShuffleBucketJoinAssignedJob) sameWorkerInstances.get(0);
BucketScanSource firstInstanceScanSource
= (BucketScanSource) firstInstance.getScanSource();
firstInstanceScanSource.bucketIndexToScanNodeToTablets.putAll(scanEmptyBuckets);

Iterator<Integer> assignedJoinBuckets = new LinkedHashSet<>(workerToBuckets.getValue()).iterator();
// make sure the first instance is assigned some buckets:
// if the first instance already has some buckets, we start assigning empty
// buckets from the second instance for balance; otherwise we start from the first instance
int index = firstInstance.getAssignedJoinBucketIndexes().isEmpty() ? -1 : 0;
while (assignedJoinBuckets.hasNext()) {
Integer bucketIndex = assignedJoinBuckets.next();
assignedJoinBuckets.remove();

index = (index + 1) % sameWorkerInstances.size();
LocalShuffleBucketJoinAssignedJob instance
= (LocalShuffleBucketJoinAssignedJob) sameWorkerInstances.get(index);
instance.addAssignedJoinBucketIndexes(ImmutableSet.of(bucketIndex));
}
} else {
fillUpInstance = assignWorkerAndDataSources(
newInstances.add(assignWorkerAndDataSources(
newInstances.size(), context.nextInstanceId(), worker, scanSource
);
}
if (fillUpInstance != null) {
newInstances.add(fillUpInstance);
));
}
}
return newInstances;
}
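
The while loop above spreads the missing (empty) buckets round-robin across the worker's existing instances, starting from the second instance when the first one already holds real buckets. The rotation in isolation, as a sketch with plain indexes:

    import java.util.ArrayList;
    import java.util.List;

    // Sketch: distribute bucket indexes round-robin over a worker's instances.
    // If the first instance already has buckets, start at the second for balance.
    class RoundRobinBucketsSketch {
        static List<List<Integer>> distribute(List<Integer> buckets, int instanceCount,
                                              boolean firstInstanceHasBuckets) {
            List<List<Integer>> perInstance = new ArrayList<>();
            for (int i = 0; i < instanceCount; i++) {
                perInstance.add(new ArrayList<>());
            }
            int index = firstInstanceHasBuckets ? 0 : -1;
            for (Integer bucket : buckets) {
                index = (index + 1) % instanceCount;
                perInstance.get(index).add(bucket);
            }
            return perInstance;
        }
    }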

private List<AssignedJob> fillUpEmptyInstances(
int maxNumInstancePerWorker, BucketScanSource scanSource, DistributedPlanWorker worker,
List<AssignedJob> existsInstances, ConnectContext context) {
int shareScanId = shareScanIdGenerator.getAndIncrement();
List<AssignedJob> newInstances = new ArrayList<>(maxNumInstancePerWorker);
for (int i = 0; i < maxNumInstancePerWorker; i++) {
LocalShuffleBucketJoinAssignedJob newInstance = new LocalShuffleBucketJoinAssignedJob(
existsInstances.size(), shareScanId,
context.nextInstanceId(), this, worker,
scanSource.newEmpty(),
ImmutableSet.of()
);
existsInstances.add(newInstance);
newInstances.add(newInstance);
}
return newInstances;
}

private int fullBucketNum() {
for (ScanNode scanNode : scanNodes) {
if (scanNode instanceof OlapScanNode) {