From 645bf35056dcd852a6e034cc912f37dfb7707314 Mon Sep 17 00:00:00 2001
From: Moritz Mack <mmack@talend.com>
Date: Wed, 28 Dec 2022 17:51:59 +0100
Subject: [PATCH] [Spark Dataset runner] Skip unconsumed additional outputs of
 ParDo.MultiOutput to avoid caching if not necessary (resolves #24710)
 (#24711)

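With multiple output tags the batch ParDo translation splits the ParDo
result into one Dataset per tag and caches the intermediate result. This
is wasteful when the additional outputs are never consumed downstream.
The translator now drops unconsumed (leaf) additional outputs up front,
and the OutputManager silently skips elements emitted to tags that were
dropped, so the single-output, uncached path can be used instead.

A minimal sketch of the case this optimizes, mirroring the new
testPardoWithUnusedOutputTags below. The tag names and DoFn are
illustrative only; the usual Beam SDK imports (Create, ParDo, DoFn,
TupleTag, TupleTagList, PCollectionTuple, PAssert) and a TestPipeline
named `pipeline` are assumed:

    TupleTag<Integer> mainTag = new TupleTag<Integer>() {};
    TupleTag<String> unusedTag = new TupleTag<String>() {};

    PCollectionTuple outputs =
        pipeline
            .apply(Create.of(1, 2, 3, 4, 5))
            .apply(
                ParDo.of(
                        new DoFn<Integer, Integer>() {
                          @ProcessElement
                          public void processElement(
                              @Element Integer i, MultiOutputReceiver out) {
                            if (i % 2 == 0) {
                              out.get(mainTag).output(i); // main output
                            } else {
                              // emitted, but never consumed downstream
                              out.get(unusedTag).output(i.toString());
                            }
                          }
                        })
                    .withOutputTags(mainTag, TupleTagList.of(unusedTag)));

    // Only the main output is consumed; outputs.get(unusedTag) has no
    // dependent transforms, so the runner no longer caches the ParDo
    // result just to serve it.
    PAssert.that(outputs.get(mainTag)).containsInAnyOrder(2, 4);
    pipeline.run();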
---
 .../translation/PipelineTranslator.java       |  7 ++
 .../translation/TransformTranslator.java      |  5 ++
 .../batch/DoFnPartitionIteratorFactory.java   | 36 ++++++----
 .../batch/ParDoTranslatorBatch.java           | 11 ++-
 .../structuredstreaming/SparkSessionRule.java | 20 ++++++
 .../translation/batch/ParDoTest.java          | 69 ++++++++++++-------
 6 files changed, 106 insertions(+), 42 deletions(-)

diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java
index 8d751d5d8173..05f542702f19 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java
@@ -152,6 +152,8 @@ public String name() {
   public interface TranslationState extends EncoderProvider {
     <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
 
+    boolean isLeave(PCollection<?> pCollection);
+
     <T> void putDataset(
         PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean cache);
 
@@ -256,6 +258,11 @@ public <T> void putDataset(
       }
     }
 
+    @Override
+    public boolean isLeave(PCollection<?> pCollection) {
+      return getResult(pCollection).dependentTransforms.isEmpty();
+    }
+
     @Override
     public <T> Broadcast<SideInputValues<T>> getSideInputBroadcast(
         PCollection<T> pCollection, SideInputValues.Loader<T> loader) {
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java
index 8a3c7579f541..e0bbb2af820e 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java
@@ -149,6 +149,11 @@ public <T> void putDataset(
       state.putDataset(pCollection, dataset, cache);
     }
 
+    @Override
+    public boolean isLeave(PCollection<?> pCollection) {
+      return state.isLeave(pCollection);
+    }
+
     @Override
     public Supplier<PipelineOptions> getOptionsSupplier() {
       return state.getOptionsSupplier();
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java
index c760efd229c8..64a4f591ff74 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java
@@ -19,7 +19,6 @@
 
 import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.scalaIterator;
 import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
-import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
 
 import java.io.Serializable;
 import java.util.ArrayDeque;
@@ -61,17 +60,17 @@
  */
 abstract class DoFnPartitionIteratorFactory<InT, FnOutT, OutT extends @NonNull Object>
     implements Function1<Iterator<WindowedValue<InT>>, Iterator<OutT>>, Serializable {
-  private final String stepName;
-  private final DoFn<InT, FnOutT> doFn;
-  private final DoFnSchemaInformation doFnSchema;
-  private final Supplier<PipelineOptions> options;
-  private final Coder<InT> coder;
-  private final WindowingStrategy<?, ?> windowingStrategy;
-  private final TupleTag<FnOutT> mainOutput;
-  private final List<TupleTag<?>> additionalOutputs;
-  private final Map<TupleTag<?>, Coder<?>> outputCoders;
-  private final Map<String, PCollectionView<?>> sideInputs;
-  private final SideInputReader sideInputReader;
+  protected final String stepName;
+  protected final DoFn<InT, FnOutT> doFn;
+  protected final DoFnSchemaInformation doFnSchema;
+  protected final Supplier<PipelineOptions> options;
+  protected final Coder<InT> coder;
+  protected final WindowingStrategy<?, ?> windowingStrategy;
+  protected final TupleTag<FnOutT> mainOutput;
+  protected final List<TupleTag<?>> additionalOutputs;
+  protected final Map<TupleTag<?>, Coder<?>> outputCoders;
+  protected final Map<String, PCollectionView<?>> sideInputs;
+  protected final SideInputReader sideInputReader;
 
   private DoFnPartitionIteratorFactory(
       AppliedPTransform<PCollection<? extends InT>, ?, MultiOutput<InT, FnOutT>> appliedPT,
@@ -147,7 +146,11 @@ DoFnRunners.OutputManager outputManager(Deque<WindowedValue<OutT>> buffer) {
       return new DoFnRunners.OutputManager() {
         @Override
         public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
-          buffer.add((WindowedValue<OutT>) output);
+          // SingleOut will only ever emit the mainOutput. However, there might be additional
+          // outputs, which are skipped if unused to avoid caching.
+          if (mainOutput.equals(tag)) {
+            buffer.add((WindowedValue<OutT>) output);
+          }
         }
       };
     }
@@ -177,8 +180,11 @@ DoFnRunners.OutputManager outputManager(Deque<Tuple2<Integer, WindowedValue<OutT
       return new DoFnRunners.OutputManager() {
         @Override
         public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
-          Integer columnIdx = checkStateNotNull(tagColIdx.get(tag.getId()), "Unknown tag %s", tag);
-          buffer.add(tuple(columnIdx, (WindowedValue<OutT>) output));
+          // Additional unused outputs can be skipped here. In that case, columnIdx is null.
+          Integer columnIdx = tagColIdx.get(tag.getId());
+          if (columnIdx != null) {
+            buffer.add(tuple(columnIdx, (WindowedValue<OutT>) output));
+          }
         }
       };
     }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
index 3083ff5101b9..4d545e438133 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
@@ -110,12 +110,19 @@ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt)
       throws IOException {
 
     PCollection<InputT> input = (PCollection<InputT>) cxt.getInput();
-    Map<TupleTag<?>, PCollection<?>> outputs = cxt.getOutputs();
 
     Dataset<WindowedValue<InputT>> inputDs = cxt.getDataset(input);
     SideInputReader sideInputReader =
         createSideInputReader(transform.getSideInputs().values(), cxt);
 
+    TupleTag<OutputT> mainOut = transform.getMainOutputTag();
+    // Filter out unconsumed output PCollections (except mainOut) to avoid the cost of caching
+    // when it is not actually beneficial.
+    Map<TupleTag<?>, PCollection<?>> outputs =
+        Maps.filterEntries(
+            cxt.getOutputs(),
+            e -> e != null && (e.getKey().equals(mainOut) || !cxt.isLeave(e.getValue())));
+
     if (outputs.size() > 1) {
       // In case of multiple outputs / tags, map each tag to a column by index.
       // At the end split the result into multiple datasets selecting one column each.
@@ -176,7 +183,7 @@ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt)
         }
       }
     } else {
-      PCollection<OutputT> output = cxt.getOutput(transform.getMainOutputTag());
+      PCollection<OutputT> output = cxt.getOutput(mainOut);
       DoFnPartitionIteratorFactory<InputT, ?, WindowedValue<OutputT>> doFnMapper =
           DoFnPartitionIteratorFactory.singleOutput(
               cxt.getCurrentTransform(), cxt.getOptionsSupplier(), input, sideInputReader);
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java
index 33eef26dddda..278fd012d77e 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java
@@ -18,6 +18,7 @@
 package org.apache.beam.runners.spark.structuredstreaming;
 
 import static java.util.stream.Collectors.toMap;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
 
 import java.io.Serializable;
 import java.util.Arrays;
@@ -29,6 +30,7 @@
 import org.apache.beam.sdk.values.KV;
 import org.apache.spark.sql.SparkSession;
 import org.junit.rules.ExternalResource;
+import org.junit.rules.TestRule;
 import org.junit.runner.Description;
 import org.junit.runners.model.Statement;
 
@@ -69,6 +71,24 @@ public PipelineOptions configure(PipelineOptions options) {
     return opts;
   }
 
+  /** {@code true} if the session contains cached Datasets or RDDs. */
+  public boolean hasCachedData() {
+    return !session.sharedState().cacheManager().isEmpty()
+        || !session.sparkContext().getPersistentRDDs().isEmpty();
+  }
+
+  public TestRule clearCache() {
+    return new ExternalResource() {
+      @Override
+      protected void after() {
+        // clear cached datasets
+        session.sharedState().cacheManager().clearCache();
+        // clear cached RDDs
+        session.sparkContext().getPersistentRDDs().foreach(fun1(t -> t._2.unpersist(true)));
+      }
+    };
+  }
+
   @Override
   public Statement apply(Statement base, Description description) {
     builder.appName(description.getDisplayName());
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
index f319173ed2bb..672a2db4fe1e 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
@@ -17,14 +17,13 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
+import static org.junit.Assert.assertTrue;
+
 import java.io.Serializable;
 import java.util.List;
 import java.util.Map;
 import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
-import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
-import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
@@ -37,23 +36,23 @@
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
+import org.junit.ClassRule;
 import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.TestRule;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
 /** Test class for beam to spark {@link ParDo} translation. */
 @RunWith(JUnit4.class)
 public class ParDoTest implements Serializable {
-  @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
-
-  private static PipelineOptions testOptions() {
-    SparkStructuredStreamingPipelineOptions options =
-        PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
-    options.setRunner(SparkStructuredStreamingRunner.class);
-    options.setTestMode(true);
-    return options;
-  }
+  @ClassRule public static final SparkSessionRule SESSION = new SparkSessionRule();
+
+  @Rule
+  public transient TestPipeline pipeline =
+      TestPipeline.fromOptions(SESSION.createPipelineOptions());
+
+  @Rule public transient TestRule clearCache = SESSION.clearCache();
 
   @Test
   public void testPardo() {
@@ -61,32 +60,42 @@ public void testPardo() {
         pipeline.apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)).apply(ParDo.of(PLUS_ONE_DOFN));
     PAssert.that(input).containsInAnyOrder(2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
     pipeline.run();
+
+    assertTrue("No usage of cache expected", !SESSION.hasCachedData());
   }
 
   @Test
   public void testPardoWithOutputTagsCachedRDD() {
-    pardoWithOutputTags("MEMORY_ONLY");
+    pardoWithOutputTags("MEMORY_ONLY", true);
+    assertTrue("Expected cached data", SESSION.hasCachedData());
   }
 
   @Test
   public void testPardoWithOutputTagsCachedDataset() {
-    pardoWithOutputTags("MEMORY_AND_DISK");
+    pardoWithOutputTags("MEMORY_AND_DISK", true);
+    assertTrue("Expected cached data", SESSION.hasCachedData());
+  }
+
+  @Test
+  public void testPardoWithUnusedOutputTags() {
+    pardoWithOutputTags("MEMORY_AND_DISK", false);
+    assertTrue("No usage of cache expected", !SESSION.hasCachedData());
   }
 
-  private void pardoWithOutputTags(String storageLevel) {
+  private void pardoWithOutputTags(String storageLevel, boolean evaluateAdditionalOutputs) {
     pipeline.getOptions().as(SparkCommonPipelineOptions.class).setStorageLevel(storageLevel);
 
-    TupleTag<Integer> even = new TupleTag<Integer>() {};
-    TupleTag<String> unevenAsString = new TupleTag<String>() {};
+    TupleTag<Integer> mainTag = new TupleTag<Integer>() {};
+    TupleTag<String> additionalUnevenTag = new TupleTag<String>() {};
 
     DoFn<Integer, Integer> doFn =
         new DoFn<Integer, Integer>() {
           @ProcessElement
           public void processElement(@Element Integer i, MultiOutputReceiver out) {
             if (i % 2 == 0) {
-              out.get(even).output(i);
+              out.get(mainTag).output(i);
             } else {
-              out.get(unevenAsString).output(i.toString());
+              out.get(additionalUnevenTag).output(i.toString());
             }
           }
         };
@@ -94,10 +103,12 @@ public void processElement(@Element Integer i, MultiOutputReceiver out) {
     PCollectionTuple outputs =
         pipeline
             .apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
-            .apply(ParDo.of(doFn).withOutputTags(even, TupleTagList.of(unevenAsString)));
+            .apply(ParDo.of(doFn).withOutputTags(mainTag, TupleTagList.of(additionalUnevenTag)));
 
-    PAssert.that(outputs.get(even)).containsInAnyOrder(2, 4, 6, 8, 10);
-    PAssert.that(outputs.get(unevenAsString)).containsInAnyOrder("1", "3", "5", "7", "9");
+    PAssert.that(outputs.get(mainTag)).containsInAnyOrder(2, 4, 6, 8, 10);
+    if (evaluateAdditionalOutputs) {
+      PAssert.that(outputs.get(additionalUnevenTag)).containsInAnyOrder("1", "3", "5", "7", "9");
+    }
     pipeline.run();
   }
 
@@ -106,10 +117,12 @@ public void testTwoPardoInRow() {
     PCollection<Integer> input =
         pipeline
             .apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
-            .apply(ParDo.of(PLUS_ONE_DOFN))
-            .apply(ParDo.of(PLUS_ONE_DOFN));
+            .apply("Plus 1 (1st)", ParDo.of(PLUS_ONE_DOFN))
+            .apply("Plus 1 (2nd)", ParDo.of(PLUS_ONE_DOFN));
     PAssert.that(input).containsInAnyOrder(3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
     pipeline.run();
+
+    assertTrue("No usage of cache expected", !SESSION.hasCachedData());
   }
 
   @Test
@@ -133,6 +146,8 @@ public void processElement(ProcessContext c) {
                     .withSideInputs(sideInputView));
     PAssert.that(input).containsInAnyOrder(4, 5, 6, 7, 8, 9, 10);
     pipeline.run();
+
+    assertTrue("No usage of cache expected", !SESSION.hasCachedData());
   }
 
   @Test
@@ -158,6 +173,8 @@ public void processElement(ProcessContext c) {
 
     PAssert.that(input).containsInAnyOrder(2, 3, 4, 5, 6, 7, 8, 9, 10);
     pipeline.run();
+
+    assertTrue("No usage of cache expected", !SESSION.hasCachedData());
   }
 
   @Test
@@ -183,6 +200,8 @@ public void processElement(ProcessContext c) {
                     .withSideInputs(sideInputView));
     PAssert.that(input).containsInAnyOrder(3, 4, 5, 6, 7, 8, 9, 10);
     pipeline.run();
+
+    assertTrue("No usage of cache expected", !SESSION.hasCachedData());
   }
 
   private static final DoFn<Integer, Integer> PLUS_ONE_DOFN =