Skip to content

Commit

Permalink
Enable async processing for SDF on Spark runner apache#23852
Browse files Browse the repository at this point in the history
  • Loading branch information
Jozef Vilcek committed Dec 30, 2022
1 parent 5e3604f commit ba83018
Show file tree
Hide file tree
Showing 7 changed files with 430 additions and 179 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ public PCollectionTuple expand(PCollection<KV<byte[], KV<InputT, RestrictionT>>>
}
}

static class NaiveProcessFn<InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
public static class NaiveProcessFn<
InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
extends DoFn<KV<InputT, RestrictionT>, OutputT> {
private final DoFn<InputT, OutputT> fn;
private final Map<String, PCollectionView<?>> sideInputMapping;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,10 @@ static void prepareFilesToStage(SparkCommonPipelineOptions options) {
PipelineResources.prepareFilesForStaging(options);
}
}

@Description("Enable/disable async output for operators with possibly large output ( such as splittable DoFn )")
@Default.Boolean(true)
Boolean getEnableAsyncOperatorOutput();

void setEnableAsyncOperatorOutput(Boolean value);
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,7 @@
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Function;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.LinkedListMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.util.AccumulatorV2;
import scala.Tuple2;
Expand Down Expand Up @@ -80,6 +77,7 @@ public class MultiDoFnFunction<InputT, OutputT>
private final boolean stateful;
private final DoFnSchemaInformation doFnSchemaInformation;
private final Map<String, PCollectionView<?>> sideInputMapping;
private final boolean useAsyncProcessing;

/**
* @param metricsAccum The Spark {@link AccumulatorV2} that backs the Beam metrics.
Expand All @@ -92,21 +90,23 @@ public class MultiDoFnFunction<InputT, OutputT>
* @param sideInputs Side inputs used in this {@link DoFn}.
* @param windowingStrategy Input {@link WindowingStrategy}.
* @param stateful Stateful {@link DoFn}.
* @param useAsyncProcessing If it should use asynchronous processing.
*/
public MultiDoFnFunction(
MetricsContainerStepMapAccumulator metricsAccum,
String stepName,
DoFn<InputT, OutputT> doFn,
SerializablePipelineOptions options,
TupleTag<OutputT> mainOutputTag,
List<TupleTag<?>> additionalOutputTags,
Coder<InputT> inputCoder,
Map<TupleTag<?>, Coder<?>> outputCoders,
Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs,
WindowingStrategy<?, ?> windowingStrategy,
boolean stateful,
DoFnSchemaInformation doFnSchemaInformation,
Map<String, PCollectionView<?>> sideInputMapping) {
MetricsContainerStepMapAccumulator metricsAccum,
String stepName,
DoFn<InputT, OutputT> doFn,
SerializablePipelineOptions options,
TupleTag<OutputT> mainOutputTag,
List<TupleTag<?>> additionalOutputTags,
Coder<InputT> inputCoder,
Map<TupleTag<?>, Coder<?>> outputCoders,
Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs,
WindowingStrategy<?, ?> windowingStrategy,
boolean stateful,
DoFnSchemaInformation doFnSchemaInformation,
Map<String, PCollectionView<?>> sideInputMapping,
boolean useAsyncProcessing) {
this.metricsAccum = metricsAccum;
this.stepName = stepName;
this.doFn = SerializableUtils.clone(doFn);
Expand All @@ -120,17 +120,24 @@ public MultiDoFnFunction(
this.stateful = stateful;
this.doFnSchemaInformation = doFnSchemaInformation;
this.sideInputMapping = sideInputMapping;
this.useAsyncProcessing = useAsyncProcessing;
}

@Override
public Iterator<Tuple2<TupleTag<?>, WindowedValue<?>>> call(Iterator<WindowedValue<InputT>> iter)
throws Exception {
throws Exception {
if (!wasSetupCalled && iter.hasNext()) {
DoFnInvokers.tryInvokeSetupFor(doFn, options.get());
wasSetupCalled = true;
}

DoFnOutputManager outputManager = new DoFnOutputManager();

SparkInputDataProcessor<InputT, OutputT, Tuple2<TupleTag<?>, WindowedValue<?>>> processor;
if (useAsyncProcessing) {
processor = SparkInputDataProcessor.createAsync();
} else {
processor = SparkInputDataProcessor.createSync();
}

final InMemoryTimerInternals timerInternals;
final StepContext context;
Expand Down Expand Up @@ -159,15 +166,16 @@ public TimerInternals timerInternals() {
};
} else {
timerInternals = null;
context = new SparkProcessContext.NoOpStepContext();
context = new SparkNoOpStepContext();
}


final DoFnRunner<InputT, OutputT> doFnRunner =
DoFnRunners.simpleRunner(
options.get(),
doFn,
CachedSideInputReader.of(new SparkSideInputReader(sideInputs)),
outputManager,
processor.getOutputManager(),
mainOutputTag,
additionalOutputTags,
context,
Expand All @@ -180,14 +188,13 @@ public TimerInternals timerInternals() {
DoFnRunnerWithMetrics<InputT, OutputT> doFnRunnerWithMetrics =
new DoFnRunnerWithMetrics<>(stepName, doFnRunner, metricsAccum);

return new SparkProcessContext<>(
SparkProcessContext<Object, InputT, OutputT> ctx = new SparkProcessContext<>(
doFn,
doFnRunnerWithMetrics,
outputManager,
key,
stateful ? new TimerDataIterator(timerInternals) : Collections.emptyIterator())
.processPartition(iter)
.iterator();
stateful ? new TimerDataIterator(timerInternals) : Collections.emptyIterator());

return processor.process(iter, ctx).iterator();
}

private static class TimerDataIterator implements Iterator<TimerInternals.TimerData> {
Expand Down Expand Up @@ -238,29 +245,17 @@ public void remove() {
}
}

private class DoFnOutputManager
implements SparkProcessContext.SparkOutputManager<Tuple2<TupleTag<?>, WindowedValue<?>>> {

private final Multimap<TupleTag<?>, WindowedValue<?>> outputs = LinkedListMultimap.create();
private static class SparkNoOpStepContext implements StepContext {

@Override
public void clear() {
outputs.clear();
}

@Override
public Iterator<Tuple2<TupleTag<?>, WindowedValue<?>>> iterator() {
Iterator<Map.Entry<TupleTag<?>, WindowedValue<?>>> entryIter = outputs.entries().iterator();
return Iterators.transform(entryIter, this.entryToTupleFn());
}

private <K, V> Function<Map.Entry<K, V>, Tuple2<K, V>> entryToTupleFn() {
return en -> new Tuple2<>(en.getKey(), en.getValue());
public StateInternals stateInternals() {
throw new UnsupportedOperationException("stateInternals not supported");
}

@Override
public synchronized <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
outputs.put(tag, output);
public TimerInternals timerInternals() {
throw new UnsupportedOperationException("timerInternals not supported");
}
}
}
Loading

0 comments on commit ba83018

Please sign in to comment.