diff --git a/server/src/main/java/org/opensearch/index/SegmentReplicationPressureService.java b/server/src/main/java/org/opensearch/index/SegmentReplicationPressureService.java index d9d480e7b2b27..ce38dd3bb236c 100644 --- a/server/src/main/java/org/opensearch/index/SegmentReplicationPressureService.java +++ b/server/src/main/java/org/opensearch/index/SegmentReplicationPressureService.java @@ -98,7 +98,7 @@ public class SegmentReplicationPressureService implements Closeable { private final SegmentReplicationStatsTracker tracker; private final ShardStateAction shardStateAction; - private final AsyncFailStaleReplicaTask failStaleReplicaTask; + private volatile AsyncFailStaleReplicaTask failStaleReplicaTask; @Inject public SegmentReplicationPressureService( @@ -202,6 +202,15 @@ public void setMaxAllowedStaleReplicas(double maxAllowedStaleReplicas) { public void setReplicationTimeLimitFailReplica(TimeValue replicationTimeLimitFailReplica) { this.replicationTimeLimitFailReplica = replicationTimeLimitFailReplica; + updateAsyncFailReplicaTask(); + } + + private synchronized void updateAsyncFailReplicaTask() { + try { + failStaleReplicaTask.close(); + } finally { + failStaleReplicaTask = new AsyncFailStaleReplicaTask(this); + } } public void setReplicationTimeLimitBackpressure(TimeValue replicationTimeLimitBackpressure) { @@ -228,13 +237,13 @@ final static class AsyncFailStaleReplicaTask extends AbstractAsyncTask { @Override protected boolean mustReschedule() { - return true; + return pressureService.shouldScheduleAsyncFailTask(); } @Override protected void runInternal() { // Do not fail the replicas if time limit is set to 0 (i.e. disabled). - if (TimeValue.ZERO.equals(pressureService.replicationTimeLimitFailReplica) == false) { + if (pressureService.shouldScheduleAsyncFailTask()) { final SegmentReplicationStats stats = pressureService.tracker.getStats(); // Find the shardId in node which is having stale replicas with highest current replication time. @@ -302,4 +311,8 @@ public String toString() { } + boolean shouldScheduleAsyncFailTask() { + return TimeValue.ZERO.equals(replicationTimeLimitFailReplica) == false; + } + } diff --git a/server/src/test/java/org/opensearch/index/SegmentReplicationPressureServiceTests.java b/server/src/test/java/org/opensearch/index/SegmentReplicationPressureServiceTests.java index 478fdcb24f76a..a9725f638cc53 100644 --- a/server/src/test/java/org/opensearch/index/SegmentReplicationPressureServiceTests.java +++ b/server/src/test/java/org/opensearch/index/SegmentReplicationPressureServiceTests.java @@ -217,6 +217,8 @@ public void testFailStaleReplicaTask() throws Exception { assertEquals(5, shardStats.getCheckpointsBehindCount()); // call the background task + assertTrue(service.getFailStaleReplicaTask().mustReschedule()); + assertTrue(service.getFailStaleReplicaTask().isScheduled()); service.getFailStaleReplicaTask().runInternal(); // verify that remote shard failed method is called which fails the replica shards falling behind. @@ -257,6 +259,41 @@ public void testFailStaleReplicaTaskDisabled() throws Exception { } } + public void testFailStaleReplicaTaskToggleOnOff() throws Exception { + final Settings settings = Settings.builder() + .put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT) + .put(SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.getKey(), true) + .put(MAX_REPLICATION_TIME_BACKPRESSURE_SETTING.getKey(), TimeValue.timeValueMillis(10)) + .put(MAX_REPLICATION_LIMIT_STALE_REPLICA_SETTING.getKey(), TimeValue.timeValueMillis(1)) + .build(); + + try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { + shards.startAll(); + final IndexShard primaryShard = shards.getPrimary(); + SegmentReplicationPressureService service = buildPressureService(settings, primaryShard); + + // index docs in batches without refreshing + indexInBatches(5, shards, primaryShard); + + // assert that replica shard is few checkpoints behind primary + Set replicationStats = primaryShard.getReplicationStatsForTrackedReplicas(); + assertEquals(1, replicationStats.size()); + SegmentReplicationShardStats shardStats = replicationStats.stream().findFirst().get(); + assertEquals(5, shardStats.getCheckpointsBehindCount()); + + assertTrue(service.getFailStaleReplicaTask().mustReschedule()); + assertTrue(service.getFailStaleReplicaTask().isScheduled()); + replicateSegments(primaryShard, shards.getReplicas()); + + service.setReplicationTimeLimitFailReplica(TimeValue.ZERO); + assertFalse(service.getFailStaleReplicaTask().mustReschedule()); + assertFalse(service.getFailStaleReplicaTask().isScheduled()); + service.setReplicationTimeLimitFailReplica(TimeValue.timeValueMillis(1)); + assertTrue(service.getFailStaleReplicaTask().mustReschedule()); + assertTrue(service.getFailStaleReplicaTask().isScheduled()); + } + } + private int indexInBatches(int count, ReplicationGroup shards, IndexShard primaryShard) throws Exception { int totalDocs = 0; for (int i = 0; i < count; i++) {