diff --git a/src/main/java/io/lettuce/core/dynamic/SimpleBatcher.java b/src/main/java/io/lettuce/core/dynamic/SimpleBatcher.java index a2ecf48834..53e23fad6b 100644 --- a/src/main/java/io/lettuce/core/dynamic/SimpleBatcher.java +++ b/src/main/java/io/lettuce/core/dynamic/SimpleBatcher.java @@ -40,6 +40,7 @@ * * @author Mark Paluch * @author Lucio Paiva + * @author Ivo Gaydajiev */ class SimpleBatcher implements Batcher { @@ -51,6 +52,11 @@ class SimpleBatcher implements Batcher { private final AtomicBoolean flushing = new AtomicBoolean(); + // forceFlushRequested indicates that a flush was requested while there is already a flush in progress + // This flag is used to ensure we will flush again after the current flush is done + // to ensure that any commands added while dispatching the current flush are also dispatched + private final AtomicBoolean forceFlushRequested = new AtomicBoolean(); + public SimpleBatcher(StatefulConnection connection, int batchSize) { LettuceAssert.isTrue(batchSize == -1 || batchSize > 1, "Batch size must be greater zero or -1"); @@ -95,37 +101,56 @@ protected BatchTasks flush(boolean forcedFlush) { List> commands = newDrainTarget(); - while (flushing.compareAndSet(false, true)) { + while (true) { + if (flushing.compareAndSet(false, true)) { + try { - try { + int consume = -1; - int consume = -1; + if (!forcedFlush) { + long queuedItems = queue.size(); + if (queuedItems >= batchSize) { + consume = batchSize; + defaultFlush = true; + } + } - if (!forcedFlush) { - long queuedItems = queue.size(); - if (queuedItems >= batchSize) { - consume = batchSize; - defaultFlush = true; + List> batch = doFlush(forcedFlush, defaultFlush, consume); + if (batch != null) { + commands.addAll(batch); } - } - List> batch = doFlush(forcedFlush, defaultFlush, consume); - if (batch != null) { - commands.addAll(batch); - } + if (defaultFlush && !queue.isEmpty() && queue.size() > batchSize) { + continue; + } + + if (forceFlushRequested.compareAndSet(true, false)) { + continue; + } - if (defaultFlush && !queue.isEmpty() && queue.size() > batchSize) { - continue; + return new BatchTasks(commands); + + } finally { + flushing.set(false); } - return new BatchTasks(commands); + } else { + // Another thread is already flushing + if (forcedFlush) { + forceFlushRequested.set(true); + } - } finally { - flushing.set(false); + if (commands.isEmpty()) { + return BatchTasks.EMPTY; + } else { + // Scenario: A default flush is started in Thread T1. + // If multiple default batches need processing, T1 will release `flushing` and try to reacquire it. + // However, in the brief moment when T1 releases `flushing`, another thread (T2) might acquire it. + // This lead to a state where T2 has taken over processing from T1 and T1 should return commands processed + return new BatchTasks(commands); + } } } - - return BatchTasks.EMPTY; } private List> doFlush(boolean forcedFlush, boolean defaultFlush, int consume) { diff --git a/src/test/java/io/lettuce/core/dynamic/SimpleBatcherUnitTests.java b/src/test/java/io/lettuce/core/dynamic/SimpleBatcherUnitTests.java index 592bfe29ba..4b8861dfc5 100644 --- a/src/test/java/io/lettuce/core/dynamic/SimpleBatcherUnitTests.java +++ b/src/test/java/io/lettuce/core/dynamic/SimpleBatcherUnitTests.java @@ -5,6 +5,7 @@ import static org.mockito.Mockito.*; import java.util.Arrays; +import java.util.concurrent.CountDownLatch; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; @@ -21,6 +22,7 @@ /** * @author Mark Paluch + * @author Ivo Gaydajiev */ @Tag(UNIT_TEST) @ExtendWith(MockitoExtension.class) @@ -127,6 +129,44 @@ void shouldBatchWithBatchControlFlush() { verify(connection).dispatch(Arrays.asList(c1, c2)); } + @Test + void shouldDispatchCommandsQueuedDuringOngoingFlush() throws InterruptedException { + RedisCommand c1 = createCommand(); + RedisCommand c2 = createCommand(); + + CountDownLatch batchFlushLatch1 = new CountDownLatch(1); + CountDownLatch batchFlushLatch2 = new CountDownLatch(1); + + when(connection.dispatch((RedisCommand) any())).thenAnswer(invocation -> { + batchFlushLatch1.countDown(); + batchFlushLatch2.await(); + + return null; + }); + + SimpleBatcher batcher = new SimpleBatcher(connection, 4); + + Thread batchThread1 = new Thread(() -> { + batcher.batch(c1, CommandBatching.flush()); + }); + batchThread1.start(); + + Thread batchThread2 = new Thread(() -> { + try { + batchFlushLatch1.await(); + } catch (InterruptedException ignored) { + } + batcher.batch(c2, CommandBatching.flush()); + batchFlushLatch2.countDown(); + }); + batchThread2.start(); + + batchThread1.join(); + batchThread2.join(); + verify(connection, times(1)).dispatch(c1); + verify(connection, times(1)).dispatch(c2); + } + private static RedisCommand createCommand() { return new AsyncCommand<>(new Command<>(CommandType.COMMAND, null, null)); }