Skip to content

Commit 9d56ce6

Browse files
refactor poller and handle race conditions on consumer reinitializations
Signed-off-by: Varun Bharadwaj <varunbharadwaj1995@gmail.com>
1 parent 31119ad commit 9d56ce6

File tree

5 files changed

+118
-75
lines changed

5 files changed

+118
-75
lines changed

CHANGELOG.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
3030
- Implement GRPC FunctionScoreQuery ([#19888](https://github.com/opensearch-project/OpenSearch/pull/19888))
3131
- Implement error_trace parameter for bulk requests ([#19985](https://github.com/opensearch-project/OpenSearch/pull/19985))
3232
- Allow the truncate filter in normalizers ([#19778](https://github.com/opensearch-project/OpenSearch/issues/19778))
33-
- Support pull-based ingestion message mappers and raw payload support ([#19765](https://github.com/opensearch-project/OpenSearch/pull/19765)]
3433
- Support pull-based ingestion message mappers and raw payload support ([#19765](https://github.com/opensearch-project/OpenSearch/pull/19765))
3534
- Support dynamic consumer configuration update in pull-based ingestion ([#19963](https://github.com/opensearch-project/OpenSearch/pull/19963))
3635

server/src/main/java/org/opensearch/index/engine/IngestionEngine.java

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -534,25 +534,16 @@ private void updateIngestionSourceParams(Map<String, Object> updatedParams) {
534534
return;
535535
}
536536

537-
try {
538-
logger.info("Ingestion source params updated, reinitializing consumer");
539-
540-
// Get current ingestion source with updated params from index metadata
541-
IndexMetadata indexMetadata = engineConfig.getIndexSettings().getIndexMetadata();
542-
assert indexMetadata != null;
543-
IngestionSource updatedIngestionSource = Objects.requireNonNull(indexMetadata.getIngestionSource());
537+
logger.info("Ingestion source params updated, reinitializing consumer");
544538

545-
// Initialize the factory with updated params
546-
ingestionConsumerFactory.initialize(updatedIngestionSource);
547-
548-
// Request consumer reinitialization in the poller
549-
streamPoller.requestConsumerReinitialization();
539+
// Get current ingestion source with updated params from index metadata
540+
IndexMetadata indexMetadata = engineConfig.getIndexSettings().getIndexMetadata();
541+
assert indexMetadata != null;
542+
IngestionSource updatedIngestionSource = Objects.requireNonNull(indexMetadata.getIngestionSource());
550543

551-
logger.info("Successfully processed ingestion source params update");
552-
} catch (Exception e) {
553-
logger.error("Failed to update ingestion source params", e);
554-
throw new OpenSearchException("Failed to update ingestion source params", e);
555-
}
544+
// Request consumer reinitialization in the poller
545+
streamPoller.requestConsumerReinitialization(updatedIngestionSource);
546+
logger.info("Successfully processed ingestion source params update");
556547
}
557548

558549
/**

server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java

Lines changed: 48 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.opensearch.cluster.ClusterChangedEvent;
1414
import org.opensearch.cluster.ClusterState;
1515
import org.opensearch.cluster.block.ClusterBlockLevel;
16+
import org.opensearch.cluster.metadata.IngestionSource;
1617
import org.opensearch.common.Nullable;
1718
import org.opensearch.common.metrics.CounterMetric;
1819
import org.opensearch.index.IndexSettings;
@@ -66,7 +67,6 @@ public class DefaultStreamPoller implements StreamPoller {
6667

6768
// start of the batch, inclusive
6869
private IngestionShardPointer initialBatchStartPointer;
69-
private boolean includeBatchStartPointer = false;
7070

7171
private ResetState resetState;
7272
private final String resetValue;
@@ -86,6 +86,9 @@ public class DefaultStreamPoller implements StreamPoller {
8686

8787
private PartitionedBlockingQueueContainer blockingQueueContainer;
8888

89+
// Force the consumer to start reading from this pointer. This is used in case of failures, or during reinitialization.
90+
private IngestionShardPointer forcedShardPointer = null;
91+
8992
private DefaultStreamPoller(
9093
IngestionShardPointer startPointer,
9194
IngestionConsumerFactory consumerFactory,
@@ -173,8 +176,6 @@ public void start() {
173176
}
174177

175178
started = true;
176-
// when we start, we need to include the batch start pointer in the read for the first read
177-
includeBatchStartPointer = true;
178179
consumerThread.submit(this::startPoll);
179180
blockingQueueContainer.startProcessorThreads();
180181
}
@@ -191,33 +192,20 @@ protected void startPoll() {
191192
}
192193
logger.info("Starting poller for shard {}", shardId);
193194

194-
// Force the consumer to start reading from this pointer. This is used in case of failures, or during reinitialization.
195-
IngestionShardPointer forcedShardPointer = null;
196195
while (true) {
197196
try {
198197
if (closed) {
199198
state = State.CLOSED;
199+
closeConsumer();
200200
break;
201201
}
202202

203-
// Initialize consumer if not already initialized or if reinitialization is requested
203+
// Initialize/reinitialization consumer
204204
if (this.consumer == null || reinitializeConsumer) {
205-
handleConsumerInitialization(CONSUMER_INIT_RETRY_INTERVAL_MS);
206-
207-
// If consumer reinitialization is requested, update the new consumer's start offset to the latest batch start pointer.
208-
// First clear the blocking queue partitions, and then retrieve the latest batch start pointer. This
209-
// will ensure we resume from the earliest offset possible without too much duplicate processing.
210-
if (reinitializeConsumer && includeBatchStartPointer == false) {
211-
blockingQueueContainer.clearAllQueues();
212-
forcedShardPointer = getBatchStartPointer();
213-
}
214-
reinitializeConsumer = false;
205+
handleConsumerInitialization();
215206
continue;
216207
}
217208

218-
// reset the consumer offset
219-
handleResetState();
220-
221209
// Update lag periodically. Lag is updated even if the poller is paused.
222210
updatePointerBasedLagIfNeeded();
223211

@@ -234,10 +222,8 @@ protected void startPoll() {
234222
state = State.POLLING;
235223
List<IngestionShardConsumer.ReadResult<? extends IngestionShardPointer, ? extends Message>> results;
236224

237-
if (includeBatchStartPointer) {
238-
results = consumer.readNext(initialBatchStartPointer, true, maxPollSize, pollTimeout);
239-
includeBatchStartPointer = false;
240-
} else if (forcedShardPointer != null) {
225+
// Force the consumer to start from forcedShardPointer if available
226+
if (forcedShardPointer != null) {
241227
results = consumer.readNext(forcedShardPointer, true, maxPollSize, pollTimeout);
242228
forcedShardPointer = null;
243229
} else {
@@ -335,7 +321,6 @@ public void close() {
335321
closed = true;
336322
if (!started) {
337323
logger.info("consumer thread not started");
338-
closeConsumer();
339324
return;
340325
}
341326
long startTime = System.currentTimeMillis(); // Record the start time
@@ -354,7 +339,6 @@ public void close() {
354339
}
355340
consumerThread.shutdown();
356341
blockingQueueContainer.close();
357-
closeConsumer();
358342
logger.info("closed the poller of shard {}", shardId);
359343
}
360344

@@ -478,15 +462,18 @@ public IngestionShardConsumer getConsumer() {
478462

479463
/**
480464
* Mark the poller's consumer for reinitialization. A new consumer will be initialized and start consuming from the
481-
* latest batchStartPointer.
465+
* latest batchStartPointer. This method also reinitializes the consumer factory with the updated ingestion source.
466+
* @param updatedIngestionSource the updated ingestion source with new configuration parameters
482467
*/
483468
@Override
484-
public void requestConsumerReinitialization() {
469+
public synchronized void requestConsumerReinitialization(IngestionSource updatedIngestionSource) {
485470
if (closed) {
486471
logger.warn("Cannot reinitialize consumer for closed poller of shard {}", shardId);
487472
return;
488473
}
489474

475+
// Reinitialize the consumer factory with updated configuration
476+
consumerFactory.initialize(updatedIngestionSource);
490477
logger.info("Configuration parameters updated for index {} shard {}, requesting consumer reinitialization", indexName, shardId);
491478
reinitializeConsumer = true;
492479
}
@@ -508,63 +495,75 @@ public void clusterChanged(ClusterChangedEvent event) {
508495
}
509496

510497
/**
511-
* Handles the reset state logic, updating the initialBatchStartPointer based on the reset state.
498+
* Handles the reset state logic.
499+
* Returns the resulting IngestionShardPointer for reset or null if no reset required.
512500
*/
513-
private void handleResetState() {
514-
if (resetState != ResetState.NONE) {
501+
private IngestionShardPointer getResetShardPointer() {
502+
IngestionShardPointer resetPointer = null;
503+
if (resetState != ResetState.NONE && consumer != null) {
515504
switch (resetState) {
516505
case EARLIEST:
517-
initialBatchStartPointer = consumer.earliestPointer();
518-
logger.info("Resetting pointer by seeking to earliest pointer {}", initialBatchStartPointer.asString());
506+
resetPointer = consumer.earliestPointer();
507+
logger.info("Resetting pointer by seeking to earliest pointer {}", resetPointer.asString());
519508
break;
520509
case LATEST:
521-
initialBatchStartPointer = consumer.latestPointer();
522-
logger.info("Resetting pointer by seeking to latest pointer {}", initialBatchStartPointer.asString());
510+
resetPointer = consumer.latestPointer();
511+
logger.info("Resetting pointer by seeking to latest pointer {}", resetPointer.asString());
523512
break;
524513
case RESET_BY_OFFSET:
525-
initialBatchStartPointer = consumer.pointerFromOffset(resetValue);
526-
logger.info("Resetting pointer by seeking to pointer {}", initialBatchStartPointer.asString());
514+
resetPointer = consumer.pointerFromOffset(resetValue);
515+
logger.info("Resetting pointer by seeking to pointer {}", resetPointer.asString());
527516
break;
528517
case RESET_BY_TIMESTAMP:
529-
initialBatchStartPointer = consumer.pointerFromTimestampMillis(Long.parseLong(resetValue));
518+
resetPointer = consumer.pointerFromTimestampMillis(Long.parseLong(resetValue));
530519
logger.info(
531520
"Resetting pointer by seeking to timestamp {}, corresponding pointer {}",
532521
resetValue,
533-
initialBatchStartPointer.asString()
522+
resetPointer.asString()
534523
);
535524
break;
536525
}
537526
resetState = ResetState.NONE;
538527
}
528+
529+
return resetPointer;
539530
}
540531

541532
/**
542-
* Handles consumer initialization and reinitialization logic.
543-
* If reinitializing, closes the existing consumer before creating a new one.
544-
*
545-
* @param sleepDurationOnError duration to sleep if initialization fails
533+
* Handles consumer initialization and reinitialization logic. Closes existing consumer if available and clears the
534+
* blocking queues before initializing a new consumer. Also forces the consumer to start reading from the initial
535+
* batchStartPointer if first time initialization, or from the latest available batchStartPointer on reinitialization.
546536
*/
547-
private void handleConsumerInitialization(int sleepDurationOnError) {
548-
if (reinitializeConsumer) {
549-
logger.info("Reinitializing consumer for index {} shard {} due to configuration changes", indexName, shardId);
550-
closeConsumer();
537+
private void handleConsumerInitialization() {
538+
closeConsumer();
539+
blockingQueueContainer.clearAllQueues();
540+
initializeConsumer();
541+
542+
// Handle consumer offset reset the first time an index is created. The reset offset takes precedence if available.
543+
IngestionShardPointer resetShardPointer = getResetShardPointer();
544+
if (resetShardPointer != null) {
545+
initialBatchStartPointer = resetShardPointer;
551546
}
552-
initializeConsumer(sleepDurationOnError);
547+
548+
// Force the consumer to start from the batchStartPointer. This will be the initialBatchStartPointer for first
549+
// time initialization, or the latest batchStartPointer based on processed messages.
550+
forcedShardPointer = getBatchStartPointer();
553551
}
554552

555553
/**
556554
* Initializes the consumer using the provided consumerFactory. If an error is encountered during initialization,
557555
* the poller thread sleeps for the provided duration before retrying/proceeding with the polling loop.
558556
*/
559-
private void initializeConsumer(int sleepDurationOnError) {
557+
private synchronized void initializeConsumer() {
560558
try {
559+
reinitializeConsumer = false;
561560
this.consumer = consumerFactory.createShardConsumer(consumerClientId, shardId);
562561
logger.info("Successfully initialized consumer for shard {}", shardId);
563562
} catch (Exception e) {
564563
logger.warn("Failed to create consumer for shard {}: {}", shardId, e.getMessage());
565564
totalConsumerErrorCount.inc();
566565
try {
567-
Thread.sleep(sleepDurationOnError);
566+
Thread.sleep(CONSUMER_INIT_RETRY_INTERVAL_MS);
568567
} catch (InterruptedException ie) {
569568
Thread.currentThread().interrupt();
570569
}

server/src/main/java/org/opensearch/indices/pollingingest/StreamPoller.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
package org.opensearch.indices.pollingingest;
1010

1111
import org.opensearch.cluster.ClusterStateListener;
12+
import org.opensearch.cluster.metadata.IngestionSource;
1213
import org.opensearch.common.annotation.ExperimentalApi;
1314
import org.opensearch.index.IngestionShardConsumer;
1415
import org.opensearch.index.IngestionShardPointer;
@@ -76,10 +77,11 @@ public interface StreamPoller extends Closeable, ClusterStateListener {
7677
IngestionShardConsumer getConsumer();
7778

7879
/**
79-
* Requests the poller to reinitialize the consumer.
80+
* Requests the poller to reinitialize the consumer with updated ingestion source configuration.
8081
* This is called when ingestion source params are dynamically updated.
82+
* @param updatedIngestionSource the updated ingestion source with new configuration parameters
8183
*/
82-
void requestConsumerReinitialization();
84+
void requestConsumerReinitialization(IngestionSource updatedIngestionSource);
8385

8486
/**
8587
* A state to indicate the current state of the poller

server/src/test/java/org/opensearch/indices/pollingingest/DefaultStreamPollerTests.java

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.opensearch.cluster.block.ClusterBlock;
1515
import org.opensearch.cluster.block.ClusterBlockLevel;
1616
import org.opensearch.cluster.block.ClusterBlocks;
17+
import org.opensearch.cluster.metadata.IngestionSource;
1718
import org.opensearch.common.settings.Settings;
1819
import org.opensearch.core.rest.RestStatus;
1920
import org.opensearch.index.IndexSettings;
@@ -29,6 +30,7 @@
2930

3031
import java.nio.charset.StandardCharsets;
3132
import java.util.ArrayList;
33+
import java.util.Arrays;
3234
import java.util.Collections;
3335
import java.util.EnumSet;
3436
import java.util.List;
@@ -183,6 +185,10 @@ public void testResetStateEarliest() throws InterruptedException {
183185
}
184186

185187
public void testResetStateLatest() throws InterruptedException {
188+
// Clear messages first and add them after poller starts
189+
// This ensures latestPointer() returns the correct value at initialization time
190+
messages.clear();
191+
186192
poller = new DefaultStreamPoller(
187193
new FakeIngestionSource.FakeIngestionShardPointer(0),
188194
fakeConsumerFactory,
@@ -200,12 +206,28 @@ public void testResetStateLatest() throws InterruptedException {
200206
new DefaultIngestionMessageMapper()
201207
);
202208

209+
// Set up latch to wait for 2 messages to be processed
210+
CountDownLatch latch = new CountDownLatch(2);
211+
doAnswer(invocation -> {
212+
latch.countDown();
213+
return null;
214+
}).when(processor).process(any(), any());
215+
203216
poller.start();
204217
waitUntil(() -> poller.getState() == DefaultStreamPoller.State.POLLING, awaitTime, TimeUnit.MILLISECONDS);
205-
// no messages processed
206-
verify(processor, never()).process(any(), any());
207-
// reset to the latest
208-
assertEquals(new FakeIngestionSource.FakeIngestionShardPointer(2), poller.getBatchStartPointer());
218+
219+
// Verify batch start pointer was set to latest (which is 0 since messages list was empty)
220+
assertEquals(new FakeIngestionSource.FakeIngestionShardPointer(0), poller.getBatchStartPointer());
221+
222+
// Now add messages after poller has started with LATEST reset
223+
messages.add("{\"_id\":\"1\",\"_source\":{\"name\":\"bob\", \"age\": 24}}".getBytes(StandardCharsets.UTF_8));
224+
messages.add("{\"_id\":\"2\",\"_source\":{\"name\":\"alice\", \"age\": 21}}".getBytes(StandardCharsets.UTF_8));
225+
226+
// Wait for messages to be processed
227+
latch.await();
228+
229+
// Verify that the messages added after starting from latest are processed
230+
verify(processor, times(2)).process(any(), any());
209231
}
210232

211233
public void testResetStateRewindByOffset() throws InterruptedException {
@@ -610,7 +632,9 @@ public void testConsumerReinitializationAfterProcessingMessages() throws Excepti
610632
return null;
611633
}).when(processor).process(any(), any());
612634

613-
poller.requestConsumerReinitialization();
635+
// Create a mock ingestion source for reinitialization
636+
IngestionSource mockIngestionSource = new IngestionSource.Builder("test").build();
637+
poller.requestConsumerReinitialization(mockIngestionSource);
614638

615639
// Add a 3rd message
616640
messages.add("{\"_id\":\"3\",\"_source\":{\"name\":\"charlie\", \"age\": 30}}".getBytes(StandardCharsets.UTF_8));
@@ -653,7 +677,8 @@ public void testConsumerReinitializationWithNoInitialMessages() throws Exception
653677
verify(processor, never()).process(any(), any());
654678

655679
// Request consumer reinitialization
656-
poller.requestConsumerReinitialization();
680+
IngestionSource mockIngestionSource = new IngestionSource.Builder("test").build();
681+
poller.requestConsumerReinitialization(mockIngestionSource);
657682

658683
// Add a message
659684
messages.add("{\"_id\":\"1\",\"_source\":{\"name\":\"bob\", \"age\": 24}}".getBytes(StandardCharsets.UTF_8));
@@ -670,4 +695,31 @@ public void testConsumerReinitializationWithNoInitialMessages() throws Exception
670695
// Verify 1 message was processed
671696
verify(processor, times(1)).process(any(), any());
672697
}
698+
699+
public void testGetBatchStartPointerWithNullInitialPointer() {
700+
// Create a mock blocking queue container that returns null pointers
701+
PartitionedBlockingQueueContainer mockContainer = mock(PartitionedBlockingQueueContainer.class);
702+
when(mockContainer.getCurrentShardPointers()).thenReturn(Arrays.asList(null, null, null));
703+
704+
// Create poller with null initial batch start pointer
705+
poller = new DefaultStreamPoller(
706+
null,
707+
fakeConsumerFactory,
708+
"",
709+
0,
710+
mockContainer,
711+
StreamPoller.ResetState.NONE,
712+
"",
713+
errorStrategy,
714+
StreamPoller.State.NONE,
715+
1000,
716+
1000,
717+
10000,
718+
indexSettings,
719+
new DefaultIngestionMessageMapper()
720+
);
721+
722+
// When all queues return null and initialBatchStartPointer is null, getBatchStartPointer should return null
723+
assertNull(poller.getBatchStartPointer());
724+
}
673725
}

0 commit comments

Comments
 (0)