[Spark Dataset runner] Skip unconsumed additional outputs of ParDo.MultiOutput to avoid caching if not necessary (resolves #24710) (#24711)
Moritz Mack authored Dec 28, 2022
1 parent 04694d8 commit 645bf35
Showing 6 changed files with 106 additions and 42 deletions.
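For orientation (not part of the commit itself): the change targets pipelines in which a ParDo.MultiOutput declares additional tagged outputs that no downstream transform ever consumes. The sketch below is a minimal, hypothetical Beam pipeline of that shape; the class name UnconsumedOutputExample and the tags MAIN and REJECTED are illustrative only and do not appear in the commit.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;

public class UnconsumedOutputExample {
  // Tags are anonymous subclasses so Beam can infer their type arguments.
  static final TupleTag<Integer> MAIN = new TupleTag<Integer>() {};
  static final TupleTag<String> REJECTED = new TupleTag<String>() {};

  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create();

    PCollectionTuple outputs =
        pipeline
            .apply(Create.of(1, 2, 3, 4, 5, 6))
            .apply(
                ParDo.of(
                        new DoFn<Integer, Integer>() {
                          @ProcessElement
                          public void processElement(@Element Integer i, MultiOutputReceiver out) {
                            if (i % 2 == 0) {
                              out.get(MAIN).output(i);
                            } else {
                              out.get(REJECTED).output(String.valueOf(i));
                            }
                          }
                        })
                    .withOutputTags(MAIN, TupleTagList.of(REJECTED)));

    // Only the main output is consumed; REJECTED has no downstream transforms, i.e. it is a
    // leaf of the pipeline DAG. Before this commit the Spark Dataset runner treated this as a
    // full multi-output ParDo and cached the combined result so every tag could be served;
    // with this commit, unconsumed additional outputs are filtered out up front and the
    // single-output path (no caching) is used instead.
    PCollection<Integer> evens = outputs.get(MAIN);

    pipeline.run().waitUntilFinish();
  }
}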
Changed file 1 of 6
@@ -152,6 +152,8 @@ public String name() {
public interface TranslationState extends EncoderProvider {
<T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);

boolean isLeave(PCollection<?> pCollection);

<T> void putDataset(
PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean cache);

@@ -256,6 +258,11 @@ public <T> void putDataset(
}
}

@Override
public boolean isLeave(PCollection<?> pCollection) {
return getResult(pCollection).dependentTransforms.isEmpty();
}

@Override
public <T> Broadcast<SideInputValues<T>> getSideInputBroadcast(
PCollection<T> pCollection, SideInputValues.Loader<T> loader) {
Changed file 2 of 6
@@ -149,6 +149,11 @@ public <T> void putDataset(
state.putDataset(pCollection, dataset, cache);
}

@Override
public boolean isLeave(PCollection<?> pCollection) {
return state.isLeave(pCollection);
}

@Override
public Supplier<PipelineOptions> getOptionsSupplier() {
return state.getOptionsSupplier();
Changed file 3 of 6
@@ -19,7 +19,6 @@

import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.scalaIterator;
import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;

import java.io.Serializable;
import java.util.ArrayDeque;
@@ -61,17 +60,17 @@
*/
abstract class DoFnPartitionIteratorFactory<InT, FnOutT, OutT extends @NonNull Object>
implements Function1<Iterator<WindowedValue<InT>>, Iterator<OutT>>, Serializable {
private final String stepName;
private final DoFn<InT, FnOutT> doFn;
private final DoFnSchemaInformation doFnSchema;
private final Supplier<PipelineOptions> options;
private final Coder<InT> coder;
private final WindowingStrategy<?, ?> windowingStrategy;
private final TupleTag<FnOutT> mainOutput;
private final List<TupleTag<?>> additionalOutputs;
private final Map<TupleTag<?>, Coder<?>> outputCoders;
private final Map<String, PCollectionView<?>> sideInputs;
private final SideInputReader sideInputReader;
protected final String stepName;
protected final DoFn<InT, FnOutT> doFn;
protected final DoFnSchemaInformation doFnSchema;
protected final Supplier<PipelineOptions> options;
protected final Coder<InT> coder;
protected final WindowingStrategy<?, ?> windowingStrategy;
protected final TupleTag<FnOutT> mainOutput;
protected final List<TupleTag<?>> additionalOutputs;
protected final Map<TupleTag<?>, Coder<?>> outputCoders;
protected final Map<String, PCollectionView<?>> sideInputs;
protected final SideInputReader sideInputReader;

private DoFnPartitionIteratorFactory(
AppliedPTransform<PCollection<? extends InT>, ?, MultiOutput<InT, FnOutT>> appliedPT,
@@ -147,7 +146,11 @@ DoFnRunners.OutputManager outputManager(Deque<WindowedValue<OutT>> buffer) {
return new DoFnRunners.OutputManager() {
@Override
public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
buffer.add((WindowedValue<OutT>) output);
// SingleOut only ever emits the main output. The DoFn may still declare additional
// outputs; since they are unused, they are dropped here to avoid caching them.
if (mainOutput.equals(tag)) {
buffer.add((WindowedValue<OutT>) output);
}
}
};
}
@@ -177,8 +180,11 @@ DoFnRunners.OutputManager outputManager(Deque<Tuple2<Integer, WindowedValue<OutT
return new DoFnRunners.OutputManager() {
@Override
public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
Integer columnIdx = checkStateNotNull(tagColIdx.get(tag.getId()), "Unknown tag %s", tag);
buffer.add(tuple(columnIdx, (WindowedValue<OutT>) output));
// Unused additional outputs are skipped here; no column index was assigned for them, so columnIdx is null.
Integer columnIdx = tagColIdx.get(tag.getId());
if (columnIdx != null) {
buffer.add(tuple(columnIdx, (WindowedValue<OutT>) output));
}
}
};
}
Changed file 4 of 6
@@ -110,12 +110,19 @@ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt)
throws IOException {

PCollection<InputT> input = (PCollection<InputT>) cxt.getInput();
Map<TupleTag<?>, PCollection<?>> outputs = cxt.getOutputs();

Dataset<WindowedValue<InputT>> inputDs = cxt.getDataset(input);
SideInputReader sideInputReader =
createSideInputReader(transform.getSideInputs().values(), cxt);

TupleTag<OutputT> mainOut = transform.getMainOutputTag();
// Filter out unconsumed output PCollections (except mainOut) to avoid the cost of caching
// where it provides no benefit.
Map<TupleTag<?>, PCollection<?>> outputs =
Maps.filterEntries(
cxt.getOutputs(),
e -> e != null && (e.getKey().equals(mainOut) || !cxt.isLeave(e.getValue())));

if (outputs.size() > 1) {
// In case of multiple outputs / tags, map each tag to a column by index.
// At the end split the result into multiple datasets selecting one column each.
@@ -176,7 +183,7 @@ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt)
}
}
} else {
PCollection<OutputT> output = cxt.getOutput(transform.getMainOutputTag());
PCollection<OutputT> output = cxt.getOutput(mainOut);
DoFnPartitionIteratorFactory<InputT, ?, WindowedValue<OutputT>> doFnMapper =
DoFnPartitionIteratorFactory.singleOutput(
cxt.getCurrentTransform(), cxt.getOptionsSupplier(), input, sideInputReader);
Changed file 5 of 6
@@ -18,6 +18,7 @@
package org.apache.beam.runners.spark.structuredstreaming;

import static java.util.stream.Collectors.toMap;
import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;

import java.io.Serializable;
import java.util.Arrays;
@@ -29,6 +30,7 @@
import org.apache.beam.sdk.values.KV;
import org.apache.spark.sql.SparkSession;
import org.junit.rules.ExternalResource;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;

@@ -69,6 +71,24 @@ public PipelineOptions configure(PipelineOptions options) {
return opts;
}

/** {@code true} if the session contains cached Datasets or RDDs. */
public boolean hasCachedData() {
return !session.sharedState().cacheManager().isEmpty()
|| !session.sparkContext().getPersistentRDDs().isEmpty();
}

public TestRule clearCache() {
return new ExternalResource() {
@Override
protected void after() {
// clear cached datasets
session.sharedState().cacheManager().clearCache();
// clear cached RDDs
session.sparkContext().getPersistentRDDs().foreach(fun1(t -> t._2.unpersist(true)));
}
};
}

@Override
public Statement apply(Statement base, Description description) {
builder.appName(description.getDisplayName());
Changed file 6 of 6
@@ -17,14 +17,13 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation.batch;

import static org.junit.Assert.assertTrue;

import java.io.Serializable;
import java.util.List;
import java.util.Map;
import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
@@ -37,67 +36,79 @@
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestRule;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Test class for Beam to Spark {@link ParDo} translation. */
@RunWith(JUnit4.class)
public class ParDoTest implements Serializable {
@Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());

private static PipelineOptions testOptions() {
SparkStructuredStreamingPipelineOptions options =
PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
options.setRunner(SparkStructuredStreamingRunner.class);
options.setTestMode(true);
return options;
}
@ClassRule public static final SparkSessionRule SESSION = new SparkSessionRule();

@Rule
public transient TestPipeline pipeline =
TestPipeline.fromOptions(SESSION.createPipelineOptions());

@Rule public transient TestRule clearCache = SESSION.clearCache();

@Test
public void testPardo() {
PCollection<Integer> input =
pipeline.apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)).apply(ParDo.of(PLUS_ONE_DOFN));
PAssert.that(input).containsInAnyOrder(2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
pipeline.run();

assertTrue("No usage of cache expected", !SESSION.hasCachedData());
}

@Test
public void testPardoWithOutputTagsCachedRDD() {
pardoWithOutputTags("MEMORY_ONLY");
pardoWithOutputTags("MEMORY_ONLY", true);
assertTrue("Expected cached data", SESSION.hasCachedData());
}

@Test
public void testPardoWithOutputTagsCachedDataset() {
pardoWithOutputTags("MEMORY_AND_DISK");
pardoWithOutputTags("MEMORY_AND_DISK", true);
assertTrue("Expected cached data", SESSION.hasCachedData());
}

@Test
public void testPardoWithUnusedOutputTags() {
pardoWithOutputTags("MEMORY_AND_DISK", false);
assertTrue("No usage of cache expected", !SESSION.hasCachedData());
}

private void pardoWithOutputTags(String storageLevel) {
private void pardoWithOutputTags(String storageLevel, boolean evaluateAdditionalOutputs) {
pipeline.getOptions().as(SparkCommonPipelineOptions.class).setStorageLevel(storageLevel);

TupleTag<Integer> even = new TupleTag<Integer>() {};
TupleTag<String> unevenAsString = new TupleTag<String>() {};
TupleTag<Integer> mainTag = new TupleTag<Integer>() {};
TupleTag<String> additionalUnevenTag = new TupleTag<String>() {};

DoFn<Integer, Integer> doFn =
new DoFn<Integer, Integer>() {
@ProcessElement
public void processElement(@Element Integer i, MultiOutputReceiver out) {
if (i % 2 == 0) {
out.get(even).output(i);
out.get(mainTag).output(i);
} else {
out.get(unevenAsString).output(i.toString());
out.get(additionalUnevenTag).output(i.toString());
}
}
};

PCollectionTuple outputs =
pipeline
.apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
.apply(ParDo.of(doFn).withOutputTags(even, TupleTagList.of(unevenAsString)));
.apply(ParDo.of(doFn).withOutputTags(mainTag, TupleTagList.of(additionalUnevenTag)));

PAssert.that(outputs.get(even)).containsInAnyOrder(2, 4, 6, 8, 10);
PAssert.that(outputs.get(unevenAsString)).containsInAnyOrder("1", "3", "5", "7", "9");
PAssert.that(outputs.get(mainTag)).containsInAnyOrder(2, 4, 6, 8, 10);
if (evaluateAdditionalOutputs) {
PAssert.that(outputs.get(additionalUnevenTag)).containsInAnyOrder("1", "3", "5", "7", "9");
}
pipeline.run();
}

@@ -106,10 +117,12 @@ public void testTwoPardoInRow() {
PCollection<Integer> input =
pipeline
.apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
.apply(ParDo.of(PLUS_ONE_DOFN))
.apply(ParDo.of(PLUS_ONE_DOFN));
.apply("Plus 1 (1st)", ParDo.of(PLUS_ONE_DOFN))
.apply("Plus 1 (2nd)", ParDo.of(PLUS_ONE_DOFN));
PAssert.that(input).containsInAnyOrder(3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
pipeline.run();

assertTrue("No usage of cache expected", !SESSION.hasCachedData());
}

@Test
Expand All @@ -133,6 +146,8 @@ public void processElement(ProcessContext c) {
.withSideInputs(sideInputView));
PAssert.that(input).containsInAnyOrder(4, 5, 6, 7, 8, 9, 10);
pipeline.run();

assertTrue("No usage of cache expected", !SESSION.hasCachedData());
}

@Test
@@ -158,6 +173,8 @@ public void processElement(ProcessContext c) {

PAssert.that(input).containsInAnyOrder(2, 3, 4, 5, 6, 7, 8, 9, 10);
pipeline.run();

assertTrue("No usage of cache expected", !SESSION.hasCachedData());
}

@Test
Expand All @@ -183,6 +200,8 @@ public void processElement(ProcessContext c) {
.withSideInputs(sideInputView));
PAssert.that(input).containsInAnyOrder(3, 4, 5, 6, 7, 8, 9, 10);
pipeline.run();

assertTrue("No usage of cache expected", !SESSION.hasCachedData());
}

private static final DoFn<Integer, Integer> PLUS_ONE_DOFN =