diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperator.java index 7c9e0b9dcf11d..ee81685daecc1 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperator.java @@ -105,6 +105,9 @@ public class AsyncWaitOperator private transient TimestampedCollector timestampedCollector; + /** Whether object reuse has been enabled or disabled. */ + private transient boolean isObjectReuseEnabled; + public AsyncWaitOperator( @Nonnull AsyncFunction asyncFunction, long timeout, @@ -158,6 +161,8 @@ public void setup( public void open() throws Exception { super.open(); + this.isObjectReuseEnabled = getExecutionConfig().isObjectReuseEnabled(); + if (recoveredStreamElements != null) { for (StreamElement element : recoveredStreamElements.get()) { if (element.isRecord()) { @@ -178,7 +183,16 @@ public void open() throws Exception { } @Override - public void processElement(StreamRecord element) throws Exception { + public void processElement(StreamRecord record) throws Exception { + StreamRecord element; + // copy the element avoid the element is reused + if (isObjectReuseEnabled) { + //noinspection unchecked + element = (StreamRecord) inStreamElementSerializer.copy(record); + } else { + element = record; + } + // add element first to the queue final ResultFuture entry = addToWorkQueue(element); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java index 19184ac1e1653..986989f320b61 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java @@ -22,9 +22,12 @@ import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.java.Utils; +import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.api.java.typeutils.runtime.TupleSerializer; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; @@ -100,9 +103,10 @@ public class AsyncWaitOperatorTest extends TestLogger { private static final long TIMEOUT = 1000L; - @Rule public Timeout timeoutRule = new Timeout(10, TimeUnit.SECONDS); + @Rule public Timeout timeoutRule = new Timeout(100, TimeUnit.SECONDS); - private static class MyAsyncFunction extends RichAsyncFunction { + private abstract static class MyAbstractAsyncFunction + extends RichAsyncFunction { private static final long serialVersionUID = 8522411971886428444L; private static final long TERMINATION_TIMEOUT = 5000L; @@ -115,7 +119,7 @@ private static class MyAsyncFunction extends RichAsyncFunction public void open(Configuration parameters) throws Exception { super.open(parameters); - synchronized (MyAsyncFunction.class) { + synchronized (MyAbstractAsyncFunction.class) { if (counter == 0) { executorService = Executors.newFixedThreadPool(THREAD_POOL_SIZE); } @@ -132,7 +136,7 @@ public void close() throws Exception { } private void freeExecutor() { - synchronized (MyAsyncFunction.class) { + synchronized (MyAbstractAsyncFunction.class) { --counter; if (counter == 0) { @@ -151,6 +155,10 @@ private void freeExecutor() { } } } + } + + private static class MyAsyncFunction extends MyAbstractAsyncFunction { + private static final long serialVersionUID = -1504699677704123889L; @Override public void asyncInvoke(final Integer input, final ResultFuture resultFuture) @@ -183,7 +191,7 @@ public LazyAsyncFunction() { @Override public void asyncInvoke(final Integer input, final ResultFuture resultFuture) throws Exception { - this.executorService.submit( + executorService.submit( new Runnable() { @Override public void run() { @@ -203,6 +211,23 @@ public static void countDown() { } } + private static class InputReusedAsyncFunction extends MyAbstractAsyncFunction> { + + private static final long serialVersionUID = 8627909616410487720L; + + @Override + public void asyncInvoke(Tuple1 input, ResultFuture resultFuture) + throws Exception { + executorService.submit( + new Runnable() { + @Override + public void run() { + resultFuture.complete(Collections.singletonList(input.f0 * 2)); + } + }); + } + } + /** * A special {@link LazyAsyncFunction} for timeout handling. Complete the result future with 3 * times the input when the timeout occurred. @@ -596,6 +621,67 @@ public void testStateSnapshotAndRestore() throws Exception { restoredTaskHarness.getOutput()); } + @SuppressWarnings("rawtypes") + @Test + public void testStateSnapshotAndRestoreWithObjectReused() throws Exception { + TypeSerializer[] fieldSerializers = new TypeSerializer[] {IntSerializer.INSTANCE}; + TupleSerializer inputSerializer = + new TupleSerializer<>(Tuple1.class, fieldSerializers); + AsyncWaitOperatorFactory, Integer> factory = + new AsyncWaitOperatorFactory<>( + new InputReusedAsyncFunction(), + TIMEOUT, + 4, + AsyncDataStream.OutputMode.ORDERED); + + //noinspection unchecked + final OneInputStreamOperatorTestHarness, Integer> testHarness = + new OneInputStreamOperatorTestHarness(factory, inputSerializer); + // enable object reuse + testHarness.getExecutionConfig().enableObjectReuse(); + + final long initialTime = 0L; + Tuple1 reusedTuple = new Tuple1<>(); + StreamRecord> reusedRecord = new StreamRecord<>(reusedTuple, -1L); + + testHarness.setup(); + testHarness.open(); + + synchronized (testHarness.getCheckpointLock()) { + reusedTuple.setFields(1); + reusedRecord.setTimestamp(initialTime + 1); + testHarness.processElement(reusedRecord); + + reusedTuple.setFields(2); + reusedRecord.setTimestamp(initialTime + 2); + testHarness.processElement(reusedRecord); + + reusedTuple.setFields(3); + reusedRecord.setTimestamp(initialTime + 3); + testHarness.processElement(reusedRecord); + + reusedTuple.setFields(4); + reusedRecord.setTimestamp(initialTime + 4); + testHarness.processElement(reusedRecord); + } + + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + expectedOutput.add(new StreamRecord<>(2, initialTime + 1)); + expectedOutput.add(new StreamRecord<>(4, initialTime + 2)); + expectedOutput.add(new StreamRecord<>(6, initialTime + 3)); + expectedOutput.add(new StreamRecord<>(8, initialTime + 4)); + + synchronized (testHarness.getCheckpointLock()) { + testHarness.endInput(); + testHarness.close(); + } + + TestHarnessUtil.assertOutputEquals( + "StateAndRestoredWithObjectReuse Test Output was not correct.", + expectedOutput, + testHarness.getOutput()); + } + @Test public void testAsyncTimeoutFailure() throws Exception { testAsyncTimeout(