diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle
index 8729bc2032ca..769128f4706d 100644
--- a/runners/google-cloud-dataflow-java/build.gradle
+++ b/runners/google-cloud-dataflow-java/build.gradle
@@ -205,7 +205,6 @@ def commonLegacyExcludeCategories = [
   'org.apache.beam.sdk.testing.UsesGaugeMetrics',
   'org.apache.beam.sdk.testing.UsesMultimapState',
   'org.apache.beam.sdk.testing.UsesTestStream',
-  'org.apache.beam.sdk.testing.UsesParDoLifecycle', // doesn't support remote runner
   'org.apache.beam.sdk.testing.UsesMetricsPusher',
   'org.apache.beam.sdk.testing.UsesBundleFinalizer',
   'org.apache.beam.sdk.testing.UsesBoundedTrieMetrics', // Dataflow QM as of now does not support returning back BoundedTrie in metric result.
@@ -520,8 +519,7 @@ task validatesRunnerV2 {
       excludedTests: [
         'org.apache.beam.sdk.transforms.ReshuffleTest.testReshuffleWithTimestampsStreaming',

-        // TODO(https://github.com/apache/beam/issues/18592): respect ParDo lifecycle.
-        'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testFnCallSequenceStateful',
+        // These tests use static state and don't work with remote execution.
         'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle',
         'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful',
         'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement',
@@ -563,7 +561,7 @@ task validatesRunnerV2Streaming {
         'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testAfterProcessingTimeContinuationTriggerUsingState',
         'org.apache.beam.sdk.transforms.GroupByKeyTest.testCombiningAccumulatingProcessingTime',

-        // TODO(https://github.com/apache/beam/issues/18592): respect ParDo lifecycle.
+        // These tests use static state and don't work with remote execution.
         'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle',
         'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful',
         'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement',
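Context for the comment rewrite above: the excluded ParDoLifecycleTest cases record lifecycle callbacks in JVM-static variables, so their assertions only hold when the pipeline executes in the same JVM as the test. A minimal illustration of that limitation (a hypothetical demo class, not Beam code):

```java
import java.util.concurrent.atomic.AtomicInteger;

/** Why static-state assertions cannot observe work done by a remote runner. */
class StaticStateDemo {
  // Every JVM gets its own copy of this field. A remote Dataflow worker
  // increments its copy; the JVM running the test's assertions still sees 0.
  static final AtomicInteger teardownCalls = new AtomicInteger();

  public static void main(String[] args) {
    teardownCalls.incrementAndGet(); // visible only within this process
    System.out.println("teardownCalls = " + teardownCalls.get());
  }
}
```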
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java
index 91fb640a1757..d3f2aacc74d0 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java
@@ -105,11 +105,32 @@ public DataflowMapTaskExecutor create(
     Networks.replaceDirectedNetworkNodes(
         network, createOutputReceiversTransform(stageName, counterSet));

-    // Swap out all the ParallelInstruction nodes with Operation nodes
-    Networks.replaceDirectedNetworkNodes(
-        network,
-        createOperationTransformForParallelInstructionNodes(
-            stageName, network, options, readerFactory, sinkFactory, executionContext));
+    // Swap out all the ParallelInstruction nodes with Operation nodes. While updating the
+    // network, we keep track of the created Operations so that if an exception is encountered
+    // we can properly abort started operations.
+    ArrayList<Operation> createdOperations = new ArrayList<>();
+    try {
+      Networks.replaceDirectedNetworkNodes(
+          network,
+          createOperationTransformForParallelInstructionNodes(
+              stageName,
+              network,
+              options,
+              readerFactory,
+              sinkFactory,
+              executionContext,
+              createdOperations));
+    } catch (RuntimeException exn) {
+      for (Operation o : createdOperations) {
+        try {
+          o.abort();
+        } catch (Exception exn2) {
+          exn.addSuppressed(exn2);
+        }
+      }
+      throw exn;
+    }

     // Collect all the operations within the network and attach all the operations as receivers
     // to preceding output receivers.
@@ -144,7 +165,8 @@ Function<Node, Node> createOperationTransformForParallelInstructionNodes(
       final PipelineOptions options,
       final ReaderFactory readerFactory,
       final SinkFactory sinkFactory,
-      final DataflowExecutionContext executionContext) {
+      final DataflowExecutionContext executionContext,
+      final List<Operation> createdOperations) {
     return new TypeSafeNodeFunction<ParallelInstructionNode>(ParallelInstructionNode.class) {
       @Override
@@ -156,20 +178,22 @@ public Node typedApply(ParallelInstructionNode node) {
               instruction.getOriginalName(),
               instruction.getSystemName(),
               instruction.getName());
+        OperationNode result;
         try {
           DataflowOperationContext context = executionContext.createOperationContext(nameContext);
           if (instruction.getRead() != null) {
-            return createReadOperation(
-                network, node, options, readerFactory, executionContext, context);
+            result =
+                createReadOperation(
+                    network, node, options, readerFactory, executionContext, context);
           } else if (instruction.getWrite() != null) {
-            return createWriteOperation(node, options, sinkFactory, executionContext, context);
+            result = createWriteOperation(node, options, sinkFactory, executionContext, context);
           } else if (instruction.getParDo() != null) {
-            return createParDoOperation(network, node, options, executionContext, context);
+            result = createParDoOperation(network, node, options, executionContext, context);
           } else if (instruction.getPartialGroupByKey() != null) {
-            return createPartialGroupByKeyOperation(
-                network, node, options, executionContext, context);
+            result =
+                createPartialGroupByKeyOperation(network, node, options, executionContext, context);
           } else if (instruction.getFlatten() != null) {
-            return createFlattenOperation(network, node, context);
+            result = createFlattenOperation(network, node, context);
           } else {
             throw new IllegalArgumentException(
                 String.format("Unexpected instruction: %s", instruction));
@@ -177,6 +201,8 @@ public Node typedApply(ParallelInstructionNode node) {
         } catch (Exception e) {
           throw new RuntimeException(e);
         }
+        createdOperations.add(result.getOperation());
+        return result;
       }
     };
   }
@@ -328,7 +354,6 @@ public Node typedApply(InstructionOutputNode input) {
       Coder<?> coder =
          CloudObjects.coderFromCloudObject(CloudObject.fromSpec(cloudOutput.getCodec()));

-      @SuppressWarnings("unchecked")
       ElementCounter outputCounter =
           new DataflowOutputCounter(
               cloudOutput.getName(),
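The create() change above is the classic partial-construction rollback idiom: record each operation as it is created and, if a later creation fails, abort the ones that already exist, attaching their failures to the original exception via addSuppressed. A self-contained sketch of the idiom (the Op and OpFactory interfaces are illustrative stand-ins, not the worker's real types):

```java
import java.util.ArrayList;
import java.util.List;

final class RollbackSketch {
  interface Op {
    void abort() throws Exception;
  }

  interface OpFactory {
    Op create(); // may throw unchecked, e.g. from a DoFn's @Setup method
  }

  /** Creates every op, or aborts the ones already created and rethrows. */
  static List<Op> createAll(List<OpFactory> factories) {
    ArrayList<Op> created = new ArrayList<>();
    try {
      for (OpFactory factory : factories) {
        created.add(factory.create());
      }
      return created;
    } catch (RuntimeException exn) {
      for (Op op : created) {
        try {
          op.abort();
        } catch (Exception exn2) {
          exn.addSuppressed(exn2); // keep the original failure primary
        }
      }
      throw exn;
    }
  }
}
```

Aborting in creation order mirrors the new catch block in create(); the first exception stays primary so the worker reports the root cause rather than a cleanup failure.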
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java
index 877e3198e91d..3128cb84f810 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java
@@ -18,8 +18,8 @@
 package org.apache.beam.runners.dataflow.worker.util.common.worker;

 import java.io.Closeable;
+import java.util.ArrayList;
 import java.util.List;
-import java.util.ListIterator;
 import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
 import org.apache.beam.runners.dataflow.worker.counters.CounterSet;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
@@ -36,7 +36,9 @@ public class MapTaskExecutor implements WorkExecutor {
   private static final Logger LOG = LoggerFactory.getLogger(MapTaskExecutor.class);

   /** The operations in the map task, in execution order. */
-  public final List<Operation> operations;
+  public final ArrayList<Operation> operations;
+
+  private boolean closed = false;

   private final ExecutionStateTracker executionStateTracker;

@@ -54,7 +56,7 @@ public MapTaskExecutor(
       CounterSet counters,
       ExecutionStateTracker executionStateTracker) {
     this.counters = counters;
-    this.operations = operations;
+    this.operations = new ArrayList<>(operations);
     this.executionStateTracker = executionStateTracker;
   }

@@ -63,6 +65,7 @@ public CounterSet getOutputCounters() {
     return counters;
   }

+  /** May be reused if execute() returns without an exception being thrown. */
   @Override
   public void execute() throws Exception {
     LOG.debug("Executing map task");
@@ -74,13 +77,11 @@ public void execute() throws Exception {
       // Starting a root operation such as a ReadOperation does the work
       // of processing the input dataset.
       LOG.debug("Starting operations");
-      ListIterator<Operation> iterator = operations.listIterator(operations.size());
-      while (iterator.hasPrevious()) {
+      for (int i = operations.size() - 1; i >= 0; --i) {
        if (Thread.currentThread().isInterrupted()) {
           throw new InterruptedException("Worker aborted");
         }
-        Operation op = iterator.previous();
-        op.start();
+        operations.get(i).start();
       }

       // Finish operations, in forward-execution-order, so that a
@@ -94,16 +95,13 @@ public void execute() throws Exception {
         op.finish();
       }
     } catch (Exception | Error exn) {
-      LOG.debug("Aborting operations", exn);
-      for (Operation op : operations) {
-        try {
-          op.abort();
-        } catch (Exception | Error exn2) {
-          exn.addSuppressed(exn2);
-          if (exn2 instanceof InterruptedException) {
-            Thread.currentThread().interrupt();
-          }
-        }
+      try {
+        closeInternal();
+      } catch (Exception closeExn) {
+        exn.addSuppressed(closeExn);
+      }
+      if (exn instanceof InterruptedException) {
+        Thread.currentThread().interrupt();
       }
       throw exn;
     }
@@ -164,6 +162,45 @@ public void abort() {
     }
   }

+  private void closeInternal() throws Exception {
+    Preconditions.checkState(!closed);
+    LOG.debug("Aborting operations");
+    @Nullable Exception exn = null;
+    for (Operation op : operations) {
+      try {
+        op.abort();
+      } catch (Exception | Error exn2) {
+        if (exn2 instanceof InterruptedException) {
+          Thread.currentThread().interrupt();
+        }
+        if (exn == null) {
+          if (exn2 instanceof Exception) {
+            exn = (Exception) exn2;
+          } else {
+            exn = new RuntimeException(exn2);
+          }
+        } else {
+          exn.addSuppressed(exn2);
+        }
+      }
+    }
+    closed = true;
+    if (exn != null) {
+      throw exn;
+    }
+  }
+
+  @Override
+  public void close() {
+    if (!closed) {
+      try {
+        closeInternal();
+      } catch (Exception e) {
+        LOG.error("Exception while closing MapTaskExecutor, ignoring", e);
+      }
+    }
+  }
+
   @Override
   public List<Integer> reportProducedEmptyOutput() {
     List<Integer> emptyOutputSinkIndexes = Lists.newArrayList();
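The @Override on close() and the try-with-resources in the new tests indicate that WorkExecutor is (Auto)Closeable, so callers can scope cleanup with try-with-resources. A sketch of the intended caller-side contract, assuming an already-wired executor (the surrounding class is hypothetical):

```java
/** Sketch of how the new close() semantics compose with execute(). */
class WorkItemRunner {
  void runWorkItem(MapTaskExecutor executor) throws Exception {
    try (MapTaskExecutor e = executor) {
      // If execute() throws, it aborts its operations itself via closeInternal(),
      // and the implicit close() below becomes a no-op thanks to the closed flag.
      e.execute();
      // If execute() succeeds, the executor may be reused here; the implicit
      // close() then aborts the operations exactly once on exit.
    }
  }
}
```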
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java
index e77ae309d359..3443ae0022bc 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java
@@ -24,11 +24,16 @@
 import static org.apache.beam.sdk.util.SerializableUtils.serializeToByteArray;
 import static org.apache.beam.sdk.util.StringUtils.byteArrayToJsonString;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.hasItems;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.ArgumentMatchers.eq;
@@ -52,6 +57,7 @@
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Function;
 import org.apache.beam.runners.dataflow.util.CloudObject;
 import org.apache.beam.runners.dataflow.util.CloudObjects;
@@ -254,8 +260,9 @@ public void testExecutionContextPlumbing() throws Exception {
     List<ParallelInstruction> instructions =
         Arrays.asList(
             createReadInstruction("Read", ReaderFactoryTest.SingletonTestReaderFactory.class),
-            createParDoInstruction(0, 0, "DoFn1", "DoFnUserName"),
-            createParDoInstruction(1, 0, "DoFnWithContext", "DoFnWithContextUserName"));
+            createParDoInstruction(0, 0, "DoFn1", "DoFnUserName", new TestDoFn()),
+            createParDoInstruction(
+                1, 0, "DoFnWithContext", "DoFnWithContextUserName", new TestDoFn()));

     MapTask mapTask = new MapTask();
     mapTask.setStageName(STAGE);
@@ -330,6 +337,7 @@ public void testCreateReadOperation() throws Exception {
                         PCOLLECTION_ID))));
     when(network.outDegree(instructionNode)).thenReturn(1);

+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -338,11 +346,13 @@ public void testCreateReadOperation() throws Exception {
                 PipelineOptionsFactory.create(),
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ReadOperation.class));
     ReadOperation readOperation = (ReadOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(readOperation));

     assertEquals(1, readOperation.receivers.length);
     assertEquals(0, readOperation.receivers[0].getReceiverCount());
@@ -391,6 +401,7 @@ public void testCreateWriteOperation() throws Exception {
         ParallelInstructionNode.create(
             createWriteInstruction(producerIndex, producerOutputNum, "WriteOperation"),
             ExecutionLocation.UNKNOWN);
+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -399,11 +410,13 @@ public void testCreateWriteOperation() throws Exception {
                 options,
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(WriteOperation.class));
     WriteOperation writeOperation = (WriteOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(writeOperation));

     assertEquals(0, writeOperation.receivers.length);
     assertEquals(Operation.InitializationState.UNSTARTED, writeOperation.initializationState);
@@ -461,17 +474,15 @@ public TestSink create(
   static ParallelInstruction createParDoInstruction(
       int producerIndex, int producerOutputNum, String systemName) {
-    return createParDoInstruction(producerIndex, producerOutputNum, systemName, "");
+    return createParDoInstruction(producerIndex, producerOutputNum, systemName, "", new TestDoFn());
   }

   static ParallelInstruction createParDoInstruction(
-      int producerIndex, int producerOutputNum, String systemName, String userName) {
+      int producerIndex, int producerOutputNum, String systemName, String userName, DoFn<?, ?> fn) {
     InstructionInput cloudInput = new InstructionInput();
     cloudInput.setProducerInstructionIndex(producerIndex);
     cloudInput.setOutputNum(producerOutputNum);

-    TestDoFn fn = new TestDoFn();
-
     String serializedFn =
         StringUtils.byteArrayToJsonString(
             SerializableUtils.serializeToByteArray(
@@ -541,14 +552,16 @@ public void testCreateParDoOperation() throws Exception {
                     .getMultiOutputInfos()
                     .get(0))));

+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
-                STAGE, network, options, readerRegistry, sinkRegistry, context)
+                STAGE, network, options, readerRegistry, sinkRegistry, context, createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class));
     ParDoOperation parDoOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(parDoOperation));

     assertEquals(1, parDoOperation.receivers.length);
     assertEquals(0, parDoOperation.receivers[0].getReceiverCount());
@@ -608,6 +621,7 @@ public void testCreatePartialGroupByKeyOperation() throws Exception {
                         PCOLLECTION_ID))));
     when(network.outDegree(instructionNode)).thenReturn(1);

+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -616,11 +630,13 @@ public void testCreatePartialGroupByKeyOperation() throws Exception {
                 PipelineOptionsFactory.create(),
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class));
     ParDoOperation pgbkOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(pgbkOperation));

     assertEquals(1, pgbkOperation.receivers.length);
     assertEquals(0, pgbkOperation.receivers[0].getReceiverCount());
@@ -660,6 +676,7 @@ public void testCreatePartialGroupByKeyOperationWithCombine() throws Exception {
                         PCOLLECTION_ID))));
     when(network.outDegree(instructionNode)).thenReturn(1);

+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -668,11 +685,13 @@ public void testCreatePartialGroupByKeyOperationWithCombine() throws Exception {
                 options,
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class));
     ParDoOperation pgbkOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(pgbkOperation));

     assertEquals(1, pgbkOperation.receivers.length);
     assertEquals(0, pgbkOperation.receivers[0].getReceiverCount());
@@ -738,6 +757,7 @@ public void testCreateFlattenOperation() throws Exception {
                         PCOLLECTION_ID))));
     when(network.outDegree(instructionNode)).thenReturn(1);

+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -746,15 +766,108 @@ public void testCreateFlattenOperation() throws Exception {
                 options,
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(FlattenOperation.class));
     FlattenOperation flattenOperation =
         (FlattenOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(flattenOperation));

     assertEquals(1, flattenOperation.receivers.length);
     assertEquals(0, flattenOperation.receivers[0].getReceiverCount());
     assertEquals(Operation.InitializationState.UNSTARTED, flattenOperation.initializationState);
   }
+
+  static class TestTeardownDoFn extends DoFn<Integer, String> {
+    static AtomicInteger setupCalls = new AtomicInteger();
+    static AtomicInteger teardownCalls = new AtomicInteger();
+
+    private final boolean throwExceptionOnSetup;
+    private boolean setupCalled = false;
+
+    TestTeardownDoFn(boolean throwExceptionOnSetup) {
+      this.throwExceptionOnSetup = throwExceptionOnSetup;
+    }
+
+    @Setup
+    public void setup() {
+      assertFalse(setupCalled);
+      setupCalled = true;
+      setupCalls.addAndGet(1);
+      if (throwExceptionOnSetup) {
+        throw new RuntimeException("Test setup exception");
+      }
+    }
+
+    @ProcessElement
+    public void process(ProcessContext c) {
+      fail("no elements should be processed");
+    }
+
+    @Teardown
+    public void teardown() {
+      assertTrue(setupCalled);
+      setupCalled = false;
+      teardownCalls.addAndGet(1);
+    }
+  }
+
+  @Test
+  public void testCreateMapTaskExecutorException() throws Exception {
+    List<ParallelInstruction> instructions =
+        Arrays.asList(
+            createReadInstruction("Read"),
+            createParDoInstruction(0, 0, "DoFn1", "DoFn1", new TestTeardownDoFn(false)),
+            createParDoInstruction(0, 0, "DoFn2", "DoFn2", new TestTeardownDoFn(false)),
+            createParDoInstruction(0, 0, "ErrorFn", "", new TestTeardownDoFn(true)),
+            createParDoInstruction(0, 0, "DoFn3", "DoFn3", new TestTeardownDoFn(false)),
+            createFlattenInstruction(1, 0, 2, 0, "Flatten"),
+            createWriteInstruction(3, 0, "Write"));
+
+    MapTask mapTask = new MapTask();
+    mapTask.setStageName(STAGE);
+    mapTask.setSystemName("systemName");
+    mapTask.setInstructions(instructions);
+    mapTask.setFactory(Transport.getJsonFactory());
+
+    assertThrows(
+        "Test setup exception",
+        RuntimeException.class,
+        () ->
+            mapTaskExecutorFactory.create(
+                mapTaskToNetwork.apply(mapTask),
+                options,
+                STAGE,
+                readerRegistry,
+                sinkRegistry,
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                counterSet,
+                idGenerator));
+    assertEquals(3, TestTeardownDoFn.setupCalls.getAndSet(0));
+    // We only tear down the instruction we were unable to create. The other
+    // infos are cached within UserParDoFnFactory and not torn down.
+    assertEquals(1, TestTeardownDoFn.teardownCalls.getAndSet(0));
+
+    assertThrows(
+        "Test setup exception",
+        RuntimeException.class,
+        () ->
+            mapTaskExecutorFactory.create(
+                mapTaskToNetwork.apply(mapTask),
+                options,
+                STAGE,
+                readerRegistry,
+                sinkRegistry,
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                counterSet,
+                idGenerator));
+    // The non-erroring functions are cached, so a new setup call happens only on
+    // the erroring DoFn.
+    assertEquals(1, TestTeardownDoFn.setupCalls.get());
+    // We only tear down the instruction we were unable to create. The other
+    // infos are cached within UserParDoFnFactory and not torn down.
+    assertEquals(1, TestTeardownDoFn.teardownCalls.get());
+  }
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java
index bb92fca3d8be..9e45425562a3 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java
@@ -198,7 +198,7 @@ public void testOutputReceivers() throws Exception {
         new TestDoFn(
             ImmutableList.of(
                 new TupleTag<>("tag1"), new TupleTag<>("tag2"), new TupleTag<>("tag3")));
-    DoFnInfo fnInfo =
+    DoFnInfo<?, ?> fnInfo =
         DoFnInfo.forFn(
             fn,
             WindowingStrategy.globalDefault(),
@@ -279,7 +279,7 @@ public void testOutputReceivers() throws Exception {
   @SuppressWarnings("AssertionFailureIgnored")
   public void testUnexpectedNumberOfReceivers() throws Exception {
     TestDoFn fn = new TestDoFn(Collections.emptyList());
-    DoFnInfo fnInfo =
+    DoFnInfo<?, ?> fnInfo =
         DoFnInfo.forFn(
             fn,
             WindowingStrategy.globalDefault(),
@@ -330,7 +330,7 @@ private List stackTraceFrameStrings(Throwable t) {
   @Test
   public void testErrorPropagation() throws Exception {
     TestErrorDoFn fn = new TestErrorDoFn();
-    DoFnInfo fnInfo =
+    DoFnInfo<?, ?> fnInfo =
         DoFnInfo.forFn(
             fn,
             WindowingStrategy.globalDefault(),
@@ -423,7 +423,7 @@ public void testUndeclaredSideOutputs() throws Exception {
             new TupleTag<>("undecl1"),
             new TupleTag<>("undecl2"),
             new TupleTag<>("undecl3")));
-    DoFnInfo fnInfo =
+    DoFnInfo<?, ?> fnInfo =
         DoFnInfo.forFn(
             fn,
             WindowingStrategy.globalDefault(),
@@ -485,7 +485,7 @@ public void processElement(ProcessContext c) throws Exception {
     }

     StateTestingDoFn fn = new StateTestingDoFn();
-    DoFnInfo fnInfo =
+    DoFnInfo<?, ?> fnInfo =
         DoFnInfo.forFn(
             fn,
             WindowingStrategy.globalDefault(),
@@ -578,7 +578,7 @@ public void processElement(ProcessContext c) {
     }

     DoFn fn = new RepeaterDoFn();
-    DoFnInfo fnInfo =
+    DoFnInfo<?, ?> fnInfo =
         DoFnInfo.forFn(
             fn,
             WindowingStrategy.globalDefault(),
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java
index 2eeaa06eb5eb..aa10ad48d081 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java
@@ -519,4 +519,45 @@ public void testAbort() throws Exception {
     Mockito.verify(o2, atLeastOnce()).abortReadLoop();
     Mockito.verify(stateTracker).deactivate();
   }
+
+  @Test
+  public void testCloseAbortsOperations() throws Exception {
+    Operation o1 = Mockito.mock(Operation.class);
+    Operation o2 = Mockito.mock(Operation.class);
+    List<Operation> operations = Arrays.asList(o1, o2);
+    ExecutionStateTracker stateTracker = Mockito.spy(ExecutionStateTracker.newForTest());
+
+    try (MapTaskExecutor executor = new MapTaskExecutor(operations, counterSet, stateTracker)) {}
+
+    Mockito.verify(o1).abort();
+    Mockito.verify(o2).abort();
+    Mockito.verify(stateTracker).deactivate();
+  }
+
+  @Test
+  public void testExceptionAndThenCloseAbortsJustOnce() throws Exception {
+    Operation o1 = Mockito.mock(Operation.class);
+    Operation o2 = Mockito.mock(Operation.class);
+    Mockito.doThrow(new Exception("in start")).when(o2).start();
+
+    ExecutionStateTracker stateTracker = Mockito.spy(ExecutionStateTracker.newForTest());
+    MapTaskExecutor executor = new MapTaskExecutor(Arrays.asList(o1, o2), counterSet, stateTracker);
+    try {
+      executor.execute();
+      fail("Should have thrown");
+    } catch (Exception e) {
+      // Expected.
+    }
+    InOrder inOrder = Mockito.inOrder(o1, o2, stateTracker);
+    inOrder.verify(stateTracker).activate();
+    inOrder.verify(o2).start();
+
+    // Order of abort doesn't matter.
+    Mockito.verify(o1).abort();
+    Mockito.verify(o2).abort();
+    Mockito.verify(stateTracker).deactivate();
+    Mockito.verifyNoMoreInteractions(o1, o2);
+
+    // Closing after already closed should not call abort again.
+    executor.close();
+  }
 }
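The two new MapTaskExecutorTest cases pin down an at-most-once cleanup contract: operations are aborted whether cleanup is triggered by a failing execute() or by close(), but never twice. A distilled, self-contained illustration of that contract (a counting stub instead of Mockito; all names here are invented for the demo):

```java
import java.util.concurrent.atomic.AtomicInteger;

final class IdempotentCloseDemo {
  static final AtomicInteger aborts = new AtomicInteger();

  static class Executor implements AutoCloseable {
    private boolean closed = false;

    void execute() throws Exception {
      Exception failure = new Exception("in start"); // simulate Operation.start() failing
      close(); // the failure path cleans up immediately...
      throw failure;
    }

    @Override
    public void close() {
      if (closed) {
        return; // ...so a later close() must not abort again
      }
      closed = true;
      aborts.incrementAndGet(); // stands in for aborting every operation
    }
  }

  public static void main(String[] args) {
    Executor executor = new Executor();
    try {
      executor.execute();
    } catch (Exception expected) {
      // Expected: "in start".
    }
    executor.close(); // the second cleanup request is a no-op
    System.out.println("aborts = " + aborts.get()); // prints 1
  }
}
```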