diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index c76357cdf7a0..2336aa8ffb4b 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -29,6 +29,7 @@ import com.google.common.collect.Sets; import java.io.IOException; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.ServiceLoader; @@ -145,6 +146,7 @@ private void createRunnerAndConsumersForPTransformRecursively( ProcessBundleDescriptor processBundleDescriptor, SetMultimap pCollectionIdsToConsumingPTransforms, ListMultimap>> pCollectionIdsToConsumers, + Set processedPTransformIds, Consumer addStartFunction, Consumer addFinishFunction, BundleSplitListener splitListener) @@ -154,10 +156,6 @@ private void createRunnerAndConsumersForPTransformRecursively( // Since we are creating the consumers first, we know that the we are building the DAG // in reverse topological order. for (String pCollectionId : pTransform.getOutputsMap().values()) { - // If we have created the consumers for this PCollection we can skip it. - if (pCollectionIdsToConsumers.containsKey(pCollectionId)) { - continue; - } for (String consumingPTransformId : pCollectionIdsToConsumingPTransforms.get(pCollectionId)) { createRunnerAndConsumersForPTransformRecursively( @@ -168,6 +166,7 @@ private void createRunnerAndConsumersForPTransformRecursively( processBundleDescriptor, pCollectionIdsToConsumingPTransforms, pCollectionIdsToConsumers, + processedPTransformIds, addStartFunction, addFinishFunction, splitListener); @@ -185,23 +184,26 @@ private void createRunnerAndConsumersForPTransformRecursively( String.format( "Cannot process composite transform: %s", TextFormat.printToString(pTransform))); } - - urnToPTransformRunnerFactoryMap - .getOrDefault(pTransform.getSpec().getUrn(), defaultPTransformRunnerFactory) - .createRunnerForPTransform( - options, - beamFnDataClient, - beamFnStateClient, - pTransformId, - pTransform, - processBundleInstructionId, - processBundleDescriptor.getPcollectionsMap(), - processBundleDescriptor.getCodersMap(), - processBundleDescriptor.getWindowingStrategiesMap(), - pCollectionIdsToConsumers, - addStartFunction, - addFinishFunction, - splitListener); + // Skip reprocessing processed pTransforms. + if (!processedPTransformIds.contains(pTransformId)) { + urnToPTransformRunnerFactoryMap + .getOrDefault(pTransform.getSpec().getUrn(), defaultPTransformRunnerFactory) + .createRunnerForPTransform( + options, + beamFnDataClient, + beamFnStateClient, + pTransformId, + pTransform, + processBundleInstructionId, + processBundleDescriptor.getPcollectionsMap(), + processBundleDescriptor.getCodersMap(), + processBundleDescriptor.getWindowingStrategiesMap(), + pCollectionIdsToConsumers, + addStartFunction, + addFinishFunction, + splitListener); + processedPTransformIds.add(pTransformId); + } } public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.InstructionRequest request) @@ -213,6 +215,7 @@ public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.Instruction SetMultimap pCollectionIdsToConsumingPTransforms = HashMultimap.create(); ListMultimap>> pCollectionIdsToConsumers = ArrayListMultimap.create(); + HashSet processedPTransformIds = new HashSet<>(); List startFunctions = new ArrayList<>(); List finishFunctions = new ArrayList<>(); @@ -271,6 +274,7 @@ public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.Instruction bundleDescriptor, pCollectionIdsToConsumingPTransforms, pCollectionIdsToConsumers, + processedPTransformIds, startFunctions::add, finishFunctions::add, splitListener);