Add background async task to fail stale replica shards. (#7022)
Signed-off-by: Rishikesh1159 <rishireddy1159@gmail.com>
Rishikesh1159 authored Apr 6, 2023
1 parent d2873dc commit 2ce07f2
Showing 4 changed files with 212 additions and 4 deletions.
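In short: when segment replication backpressure is enabled, this change schedules a background task every 30 seconds that fails the stale replicas of the single most-lagging shard; a replica becomes failable once its current replication time exceeds double the MAX_REPLICATION_TIME_SETTING value. A minimal sketch of turning the feature on dynamically, assuming the Java high-level REST client (the helper class and method names here are illustrative, not part of this commit):

import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest;
import org.opensearch.client.RequestOptions;
import org.opensearch.client.RestHighLevelClient;
import org.opensearch.common.settings.Settings;

public final class EnableSegRepBackpressure {
    // Hypothetical helper: enables "segrep.pressure.enabled" (default false) so the
    // AsyncFailStaleReplicaTask added in this commit starts failing stale replicas.
    static void enable(RestHighLevelClient client) throws Exception {
        ClusterUpdateSettingsRequest request = new ClusterUpdateSettingsRequest();
        request.persistentSettings(Settings.builder().put("segrep.pressure.enabled", true));
        client.cluster().putSettings(request, RequestOptions.DEFAULT);
    }
}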
@@ -16,6 +16,7 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.IndexShardState;
import org.opensearch.indices.replication.SegmentReplicationBaseIT;
import org.opensearch.indices.replication.common.ReplicationType;
import org.opensearch.plugins.Plugin;
@@ -29,6 +30,7 @@
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import static java.util.Arrays.asList;
@@ -200,6 +202,42 @@ public void testBelowReplicaLimit() throws Exception {
verifyStoreContent();
}

public void testFailStaleReplica() throws Exception {

Settings settings = Settings.builder().put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueMillis(500)).build();
// Start a primary node, then a replica node.
final String primaryNode = internalCluster().startNode(settings);
createIndex(INDEX_NAME);
ensureYellowAndNoInitializingShards(INDEX_NAME);
final String replicaNode = internalCluster().startNode(settings);
ensureGreen(INDEX_NAME);

final IndexShard primaryShard = getIndexShard(primaryNode, INDEX_NAME);
final List<String> replicaNodes = asList(replicaNode);
assertEqualSegmentInfosVersion(replicaNodes, primaryShard);
IndexShard replicaShard = getIndexShard(replicaNode, INDEX_NAME);

final CountDownLatch latch = new CountDownLatch(1);
final AtomicInteger totalDocs = new AtomicInteger(0);
try (final Releasable ignored = blockReplication(replicaNodes, latch)) {
// Index docs until the replica becomes stale.
totalDocs.getAndSet(indexUntilCheckpointCount());
latch.await();
// Index again while the replica is stale.
indexDoc();
refresh(INDEX_NAME);
totalDocs.incrementAndGet();

// Verify that the replica shard is closed.
assertBusy(() -> { assertTrue(replicaShard.state().equals(IndexShardState.CLOSED)); }, 1, TimeUnit.MINUTES);
}
ensureGreen(INDEX_NAME);
final IndexShard replicaAfterFailure = getIndexShard(replicaNode, INDEX_NAME);

// Verify that the new replica shard allocated after the failure differs from the old one.
assertNotEquals(replicaAfterFailure.routingEntry().allocationId().getId(), replicaShard.routingEntry().allocationId().getId());
}

public void testBulkWritesRejected() throws Exception {
final String primaryNode = internalCluster().startNode();
createIndex(INDEX_NAME);
@@ -10,17 +10,25 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.cluster.action.shard.ShardStateAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.AbstractAsyncTask;
import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.ShardId;
import org.opensearch.indices.IndicesService;
import org.opensearch.threadpool.ThreadPool;

import java.io.Closeable;
import java.io.IOException;
import java.util.Comparator;
import java.util.Set;
import java.util.stream.Collectors;

@@ -29,7 +37,7 @@
*
* @opensearch.internal
*/
- public class SegmentReplicationPressureService {
+ public class SegmentReplicationPressureService implements Closeable {

private volatile boolean isSegmentReplicationBackpressureEnabled;
private volatile int maxCheckpointsBehind;
@@ -38,6 +46,10 @@ public class SegmentReplicationPressureService {

private static final Logger logger = LogManager.getLogger(SegmentReplicationPressureService.class);

/**
* When enabled, writes will be rejected once a replica shard falls behind by both the MAX_REPLICATION_TIME_SETTING time value and the MAX_INDEXING_CHECKPOINTS number of checkpoints.
* Once a shard falls behind by double the MAX_REPLICATION_TIME_SETTING time value, it will be marked as failed.
*/
public static final Setting<Boolean> SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED = Setting.boolSetting(
"segrep.pressure.enabled",
false,
@@ -70,13 +82,28 @@ public class SegmentReplicationPressureService {
);

private final IndicesService indicesService;

private final ThreadPool threadPool;
private final SegmentReplicationStatsTracker tracker;

private final ShardStateAction shardStateAction;

private final AsyncFailStaleReplicaTask failStaleReplicaTask;

@Inject
- public SegmentReplicationPressureService(Settings settings, ClusterService clusterService, IndicesService indicesService) {
+ public SegmentReplicationPressureService(
+ Settings settings,
+ ClusterService clusterService,
+ IndicesService indicesService,
+ ShardStateAction shardStateAction,
+ ThreadPool threadPool
+ ) {
this.indicesService = indicesService;
this.tracker = new SegmentReplicationStatsTracker(this.indicesService);

this.shardStateAction = shardStateAction;
this.threadPool = threadPool;

final ClusterSettings clusterSettings = clusterService.getClusterSettings();
this.isSegmentReplicationBackpressureEnabled = SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.get(settings);
clusterSettings.addSettingsUpdateConsumer(
@@ -92,6 +119,13 @@ public SegmentReplicationPressureService(Settings settings, ClusterService clust

this.maxAllowedStaleReplicas = MAX_ALLOWED_STALE_SHARDS.get(settings);
clusterSettings.addSettingsUpdateConsumer(MAX_ALLOWED_STALE_SHARDS, this::setMaxAllowedStaleReplicas);

this.failStaleReplicaTask = new AsyncFailStaleReplicaTask(this);
}

// visible for testing
AsyncFailStaleReplicaTask getFailStaleReplicaTask() {
return failStaleReplicaTask;
}

public void isSegrepLimitBreached(ShardId shardId) {
@@ -154,4 +188,94 @@ public void setMaxAllowedStaleReplicas(double maxAllowedStaleReplicas) {
public void setMaxReplicationTime(TimeValue maxReplicationTime) {
this.maxReplicationTime = maxReplicationTime;
}

@Override
public void close() throws IOException {
failStaleReplicaTask.close();
}

// Background task that fails replica shards when they fall too far behind the primary shard.
static final class AsyncFailStaleReplicaTask extends AbstractAsyncTask {

final SegmentReplicationPressureService pressureService;

static final TimeValue INTERVAL = TimeValue.timeValueSeconds(30);

AsyncFailStaleReplicaTask(SegmentReplicationPressureService pressureService) {
super(logger, pressureService.threadPool, INTERVAL, true);
this.pressureService = pressureService;
rescheduleIfNecessary();
}

@Override
protected boolean mustReschedule() {
return true;
}

@Override
protected void runInternal() {
if (pressureService.isSegmentReplicationBackpressureEnabled) {
final SegmentReplicationStats stats = pressureService.tracker.getStats();

// Find the shardId on this node whose stale replicas have the highest current replication time.
// This way we fail only one shardId's stale replicas in each iteration of this background async task,
// thereby reducing load on the node gradually.
stats.getShardStats()
.entrySet()
.stream()
.flatMap(
entry -> pressureService.getStaleReplicas(entry.getValue().getReplicaStats())
.stream()
.map(r -> Tuple.tuple(entry.getKey(), r.getCurrentReplicationTimeMillis()))
)
.max(Comparator.comparingLong(Tuple::v2))
.map(Tuple::v1)
.ifPresent(shardId -> {
final Set<SegmentReplicationShardStats> staleReplicas = pressureService.getStaleReplicas(
stats.getShardStats().get(shardId).getReplicaStats()
);
final IndexService indexService = pressureService.indicesService.indexService(shardId.getIndex());
final IndexShard primaryShard = indexService.getShard(shardId.getId());
for (SegmentReplicationShardStats staleReplica : staleReplicas) {
if (staleReplica.getCurrentReplicationTimeMillis() > 2 * pressureService.maxReplicationTime.millis()) {
pressureService.shardStateAction.remoteShardFailed(
shardId,
staleReplica.getAllocationId(),
primaryShard.getOperationPrimaryTerm(),
true,
"replica too far behind primary, marking as stale",
null,
new ActionListener<>() {
@Override
public void onResponse(Void unused) {
logger.trace(
"Successfully failed remote shardId [{}] allocation id [{}]",
shardId,
staleReplica.getAllocationId()
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to send remote shard failure", e);
}
}
);
}
}
});
}
}

@Override
protected String getThreadPool() {
return ThreadPool.Names.GENERIC;
}

@Override
public String toString() {
return "fail_stale_replica";
}

}
}
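The stream in runInternal() above selects one victim shard per 30-second run: it flattens (shardId, stale replica) pairs across all tracked shards, keeps the pair with the highest current replication time, and fails only that shard's stale replicas, spreading recovery load across iterations. A self-contained sketch of that selection over plain data (ReplicaStat and pickShardToFail are illustrative stand-ins, not OpenSearch API):

import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public final class VictimSelectionSketch {
    // Illustrative stand-in for SegmentReplicationShardStats (requires Java 16+ records).
    record ReplicaStat(String allocationId, long currentReplicationTimeMillis) {}

    // Mirrors the stream in AsyncFailStaleReplicaTask.runInternal(): choose the shard
    // whose stale replica has the single highest current replication time.
    static Optional<String> pickShardToFail(Map<String, List<ReplicaStat>> staleReplicasByShard) {
        return staleReplicasByShard.entrySet()
            .stream()
            .flatMap(e -> e.getValue().stream().map(r -> Map.entry(e.getKey(), r.currentReplicationTimeMillis())))
            .max(Comparator.comparingLong(Map.Entry::getValue))
            .map(Map.Entry::getKey);
    }

    public static void main(String[] args) {
        Map<String, List<ReplicaStat>> stats = Map.of(
            "[test-index][0]", List.of(new ReplicaStat("a1", 40_000L)),
            "[test-index][1]", List.of(new ReplicaStat("b1", 90_000L)) // most lagged, so selected
        );
        System.out.println(pickShardToFail(stats)); // Optional[[test-index][1]]
    }
}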
@@ -8,7 +8,9 @@

package org.opensearch.index;

import org.mockito.Mockito;
import org.mockito.stubbing.Answer;
import org.opensearch.cluster.action.shard.ShardStateAction;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
@@ -21,6 +23,7 @@
import org.opensearch.index.shard.ShardId;
import org.opensearch.indices.IndicesService;
import org.opensearch.indices.replication.common.ReplicationType;
import org.opensearch.threadpool.ThreadPool;

import java.util.Iterator;
import java.util.List;
@@ -29,13 +32,20 @@
import java.util.concurrent.TimeUnit;

import static java.util.Arrays.asList;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.times;
import static org.opensearch.index.SegmentReplicationPressureService.MAX_REPLICATION_TIME_SETTING;
import static org.opensearch.index.SegmentReplicationPressureService.SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED;

public class SegmentReplicationPressureServiceTests extends OpenSearchIndexLevelReplicationTestCase {

private static ShardStateAction shardStateAction = Mockito.mock(ShardStateAction.class);
private static final Settings settings = Settings.builder()
.put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT)
.put(SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.getKey(), true)
@@ -181,6 +191,36 @@ public void testIsSegrepLimitBreached_underStaleNodeLimit() throws Exception {
}
}

public void testFailStaleReplicaTask() throws Exception {
final Settings settings = Settings.builder()
.put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT)
.put(SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.getKey(), true)
.put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueMillis(10))
.build();

try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) {
shards.startAll();
final IndexShard primaryShard = shards.getPrimary();
SegmentReplicationPressureService service = buildPressureService(settings, primaryShard);

// Index docs in batches without refreshing.
indexInBatches(5, shards, primaryShard);

// Assert that the replica shard is a few checkpoints behind the primary.
Set<SegmentReplicationShardStats> replicationStats = primaryShard.getReplicationStats();
assertEquals(1, replicationStats.size());
SegmentReplicationShardStats shardStats = replicationStats.stream().findFirst().get();
assertEquals(5, shardStats.getCheckpointsBehindCount());

// Run the background task directly.
service.getFailStaleReplicaTask().runInternal();

// Verify that remoteShardFailed() is invoked, failing the replica shards that have fallen behind.
verify(shardStateAction, times(1)).remoteShardFailed(any(), anyString(), anyLong(), anyBoolean(), anyString(), any(), any());
replicateSegments(primaryShard, shards.getReplicas());
}
}

private int indexInBatches(int count, ReplicationGroup shards, IndexShard primaryShard) throws Exception {
int totalDocs = 0;
for (int i = 0; i < count; i++) {
@@ -202,6 +242,6 @@ private SegmentReplicationPressureService buildPressureService(Settings settings
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS));

- return new SegmentReplicationPressureService(settings, clusterService, indicesService);
+ return new SegmentReplicationPressureService(settings, clusterService, indicesService, shardStateAction, mock(ThreadPool.class));
}
}
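The unit test above uses a 10 ms MAX_REPLICATION_TIME_SETTING, so stale replicas become failable almost immediately when runInternal() is invoked directly. A minimal sketch of the double-threshold cutoff the task applies (shouldFail is an illustrative helper, not part of this commit):

import org.opensearch.common.unit.TimeValue;

public final class StaleCutoffSketch {
    // Illustrative restatement of the check in AsyncFailStaleReplicaTask.runInternal():
    // a stale replica is failed once it lags more than twice the max replication time.
    static boolean shouldFail(long currentReplicationTimeMillis, TimeValue maxReplicationTime) {
        return currentReplicationTimeMillis > 2 * maxReplicationTime.millis();
    }

    public static void main(String[] args) {
        TimeValue max = TimeValue.timeValueMillis(10); // value used in testFailStaleReplicaTask
        System.out.println(shouldFail(15, max)); // false: 15 ms is within the 20 ms cutoff
        System.out.println(shouldFail(25, max)); // true: 25 ms exceeds the 20 ms cutoff
    }
}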
@@ -1984,7 +1984,13 @@ public void onFailure(final Exception e) {
new UpdateHelper(scriptService),
actionFilters,
new IndexingPressureService(settings, clusterService),
- new SegmentReplicationPressureService(settings, clusterService, mock(IndicesService.class)),
+ new SegmentReplicationPressureService(
+ settings,
+ clusterService,
+ mock(IndicesService.class),
+ mock(ShardStateAction.class),
+ mock(ThreadPool.class)
+ ),
new SystemIndices(emptyMap())
);
actions.put(
