Skip to content

Commit

Permalink
Enrich driver parallelism type, for improved operator logging (#364)
Browse files Browse the repository at this point in the history
As a followup to the previously added operator factories, this PR
enriches the driver parallelism type, for improved operator logging.
  • Loading branch information
ChrisHegarty authored Nov 9, 2022
1 parent 38ce99e commit 330905c
Showing 1 changed file with 36 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public LocalExecutionPlan plan(PhysicalPlan node) {

PhysicalOperation physicalOperation = plan(node, context);

context.addDriverFactory(new DriverFactory(new DriverSupplier(physicalOperation), context.getDriverInstanceCount()));
context.addDriverFactory(new DriverFactory(new DriverSupplier(physicalOperation), context.driverParallelism()));

LocalExecutionPlan localExecutionPlan = new LocalExecutionPlan();
localExecutionPlan.driverFactories.addAll(context.driverFactories);
Expand Down Expand Up @@ -283,15 +283,17 @@ public PhysicalOperation plan(PhysicalPlan node, LocalExecutionPlanContext conte
source
);
} else if (node instanceof ExchangeExec exchangeExec) {
int driverInstances = exchangeExec.getType() == ExchangeExec.Type.GATHER ? 1 : taskConcurrency;
context.setDriverInstanceCount(driverInstances);
Exchange ex = new Exchange(driverInstances, exchangeExec.getPartitioning().toExchange(), bufferMaxPages);
DriverParallelism parallelism = exchangeExec.getType() == ExchangeExec.Type.GATHER
? DriverParallelism.SINGLE
: new DriverParallelism(DriverParallelism.Type.TASK_LEVEL_PARALLELISM, taskConcurrency);
context.driverParallelism(parallelism);
Exchange ex = new Exchange(parallelism.instanceCount(), exchangeExec.getPartitioning().toExchange(), bufferMaxPages);

LocalExecutionPlanContext subContext = context.createSubContext();
PhysicalOperation source = plan(exchangeExec.child(), subContext);
Map<Object, Integer> layout = source.layout;
PhysicalOperation physicalOperation = new PhysicalOperation(new ExchangeSinkOperatorFactory(ex), source.layout, source);
context.addDriverFactory(new DriverFactory(new DriverSupplier(physicalOperation), subContext.getDriverInstanceCount()));
context.addDriverFactory(new DriverFactory(new DriverSupplier(physicalOperation), subContext.driverParallelism()));
return new PhysicalOperation(new ExchangeSourceOperatorFactory(ex), layout);
} else if (node instanceof TopNExec topNExec) {
PhysicalOperation source = plan(topNExec.child(), context);
Expand Down Expand Up @@ -380,7 +382,7 @@ private PhysicalOperation planEsQueryNode(EsQueryExec esQuery, LocalExecutionPla
dataPartitioning,
taskConcurrency
);
context.setDriverInstanceCount(operatorFactory.size());
context.driverParallelism(new DriverParallelism(DriverParallelism.Type.DATA_PARALLELISM, operatorFactory.size()));
Map<Object, Integer> layout = new HashMap<>();
for (int i = 0; i < esQuery.output().size(); i++) {
layout.put(esQuery.output().get(i).id(), i);
Expand Down Expand Up @@ -463,13 +465,28 @@ public String describe() {
}
}

/**
* The count and type of driver parallelism.
*/
record DriverParallelism(Type type, int instanceCount) {

static final DriverParallelism SINGLE = new DriverParallelism(Type.SINGLETON, 1);

enum Type {
SINGLETON,
DATA_PARALLELISM,
TASK_LEVEL_PARALLELISM
}
}

/**
* Context object used while generating a local plan. Currently only collects the driver factories as well as
* maintains information how many driver instances should be created for a given driver.
*/
public static class LocalExecutionPlanContext {
final List<DriverFactory> driverFactories;
int driverInstanceCount = 1;

private DriverParallelism driverParallelism = DriverParallelism.SINGLE;

LocalExecutionPlanContext() {
driverFactories = new ArrayList<>();
Expand All @@ -488,12 +505,12 @@ public LocalExecutionPlanContext createSubContext() {
return subContext;
}

public int getDriverInstanceCount() {
return driverInstanceCount;
public DriverParallelism driverParallelism() {
return driverParallelism;
}

public void setDriverInstanceCount(int driverInstanceCount) {
this.driverInstanceCount = driverInstanceCount;
public void driverParallelism(DriverParallelism driverParallelism) {
this.driverParallelism = driverParallelism;
}
}

Expand All @@ -510,10 +527,15 @@ public String describe() {
}
}

record DriverFactory(DriverSupplier driverSupplier, int driverInstances) implements Describable {
record DriverFactory(DriverSupplier driverSupplier, DriverParallelism driverParallelism) implements Describable {
@Override
public String describe() {
return "DriverFactory(instances=" + driverInstances + ")\n" + driverSupplier.describe();
return "DriverFactory(instances = "
+ driverParallelism.instanceCount()
+ ", type = "
+ driverParallelism.type()
+ ")\n"
+ driverSupplier.describe();
}
}

Expand All @@ -525,7 +547,7 @@ public static class LocalExecutionPlan implements Describable {

public List<Driver> createDrivers() {
return driverFactories.stream()
.flatMap(df -> IntStream.range(0, df.driverInstances).mapToObj(i -> df.driverSupplier.get()))
.flatMap(df -> IntStream.range(0, df.driverParallelism().instanceCount()).mapToObj(i -> df.driverSupplier.get()))
.collect(Collectors.toList());
}

Expand Down

0 comments on commit 330905c

Please sign in to comment.