From cef2aae63325c6109296b79ae8448b5c78afa2b8 Mon Sep 17 00:00:00 2001
From: Ankit Kala <ankikala@amazon.com>
Date: Thu, 31 Aug 2023 11:02:06 +0530
Subject: [PATCH] Decouple replication lag from logic to fail stale replicas
 (#9507)

* Decouple replication lag from replication timer logic used to fail stale replicas

Signed-off-by: Ankit Kala <ankikala@amazon.com>

* Added changelog entry

Signed-off-by: Ankit Kala <ankikala@amazon.com>

* Addressed comments

Signed-off-by: Ankit Kala <ankikala@amazon.com>

* Addressed comments 2

Signed-off-by: Ankit Kala <ankikala@amazon.com>

* Addressed comments

Signed-off-by: Ankit Kala <ankikala@amazon.com>

* Retry gradle

Signed-off-by: Ankit Kala <ankikala@amazon.com>

* fix UT

Signed-off-by: Ankit Kala <ankikala@amazon.com>

* Addressed comments

Signed-off-by: Ankit Kala <ankikala@amazon.com>

* Retry Gradle

Signed-off-by: Ankit Kala <ankikala@amazon.com>
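
For example (the values below are only illustrative), the two limits can now be
tuned independently, and leaving "segrep.replication.time.limit" at its default
of 0 keeps failing of stale replicas disabled:

    Settings.builder()
        // "segrep.pressure.time.limit": staleness after which backpressure kicks in on the primary (default 5m)
        .put(MAX_REPLICATION_TIME_BACKPRESSURE_SETTING.getKey(), TimeValue.timeValueMinutes(5))
        // "segrep.replication.time.limit": staleness after which a stale replica is failed (default 0 = disabled)
        .put(MAX_REPLICATION_LIMIT_STALE_REPLICA_SETTING.getKey(), TimeValue.timeValueMinutes(10))
        .build();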

---------

Signed-off-by: Ankit Kala <ankikala@amazon.com>
(cherry picked from commit d66df10b248457d3d9778131d6939dd1a2185e39)
---
 CHANGELOG.md                                  |  1 +
 .../index/SegmentReplicationPressureIT.java   | 18 +++-
 .../common/settings/ClusterSettings.java      |  3 +-
 .../SegmentReplicationPressureService.java    | 37 ++++++--
 .../index/SegmentReplicationShardStats.java   | 24 +++++
 .../index/seqno/ReplicationTracker.java       | 90 +++++++++++++------
 .../opensearch/index/shard/IndexShard.java    | 31 +++++++
 .../common/SegmentReplicationLagTimer.java    | 48 ++++++++++
 .../cat/RestCatSegmentReplicationAction.java  |  2 +-
 ...egmentReplicationPressureServiceTests.java | 51 +++++++++--
 .../index/seqno/ReplicationTrackerTests.java  | 85 ++++++++++++++++++
 .../RestCatSegmentReplicationActionTests.java |  3 +-
 12 files changed, 342 insertions(+), 51 deletions(-)
 create mode 100644 server/src/main/java/org/opensearch/indices/replication/common/SegmentReplicationLagTimer.java
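
As a rough, illustrative sketch of the new timer's lifecycle (class and method
names are the ones introduced in this patch; the sleeps only stand in for real
work): the lag timer is created when the primary refreshes and only started
once the checkpoint is published, so totalElapsedTime() reflects the full
replication lag (including any remote store upload) while time() covers just
the replication event.

    import org.opensearch.indices.replication.common.SegmentReplicationLagTimer;

    public class SegmentReplicationLagTimerSketch {
        public static void main(String[] args) throws InterruptedException {
            SegmentReplicationLagTimer timer = new SegmentReplicationLagTimer(); // primary refreshed, timer created
            Thread.sleep(100);                                                   // e.g. primary uploads segments to remote store
            timer.start();                                                       // checkpoint published to replicas
            Thread.sleep(100);                                                   // replica syncing to the checkpoint
            timer.stop();                                                        // replica caught up
            // Replication lag (measured from creation) is always >= replication time (measured from start()).
            System.out.println("lag=" + timer.totalElapsedTime() + "ms, replication time=" + timer.time() + "ms");
        }
    }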

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bfadf4f34d6ba..52e69ba3e7a4c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,12 +5,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 
 ## [Unreleased 2.x]
 ### Added
 - Add task cancellation monitoring service ([#7642](https://github.com/opensearch-project/OpenSearch/pull/7642))
 - Add TokenManager Interface ([#7452](https://github.com/opensearch-project/OpenSearch/pull/7452))
 - Add Remote store as a segment replication source ([#7653](https://github.com/opensearch-project/OpenSearch/pull/7653))
 - Implement concurrent aggregations support without profile option ([#7514](https://github.com/opensearch-project/OpenSearch/pull/7514))
 - Add dynamic index and cluster setting for concurrent segment search ([#7956](https://github.com/opensearch-project/OpenSearch/pull/7956))
 - Add descending order search optimization through reverse segment read. ([#7967](https://github.com/opensearch-project/OpenSearch/pull/7967))
+- Decouple replication lag from logic to fail stale replicas ([#9507](https://github.com/opensearch-project/OpenSearch/pull/9507))
 
 ### Dependencies
 - Bump `com.azure:azure-storage-common` from 12.21.0 to 12.21.1 (#7566, #7814)
diff --git a/server/src/internalClusterTest/java/org/opensearch/index/SegmentReplicationPressureIT.java b/server/src/internalClusterTest/java/org/opensearch/index/SegmentReplicationPressureIT.java
index 98541310649db..85c61b8c83cc0 100644
--- a/server/src/internalClusterTest/java/org/opensearch/index/SegmentReplicationPressureIT.java
+++ b/server/src/internalClusterTest/java/org/opensearch/index/SegmentReplicationPressureIT.java
@@ -37,7 +37,8 @@
 import static java.util.Arrays.asList;
 import static org.opensearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS;
 import static org.opensearch.index.SegmentReplicationPressureService.MAX_INDEXING_CHECKPOINTS;
-import static org.opensearch.index.SegmentReplicationPressureService.MAX_REPLICATION_TIME_SETTING;
+import static org.opensearch.index.SegmentReplicationPressureService.MAX_REPLICATION_LIMIT_STALE_REPLICA_SETTING;
+import static org.opensearch.index.SegmentReplicationPressureService.MAX_REPLICATION_TIME_BACKPRESSURE_SETTING;
 import static org.opensearch.index.SegmentReplicationPressureService.SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED;
 import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;
 import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertHitCount;
@@ -52,7 +53,7 @@ protected Settings nodeSettings(int nodeOrdinal) {
         return Settings.builder()
             .put(super.nodeSettings(nodeOrdinal))
             .put(SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.getKey(), true)
-            .put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueSeconds(1))
+            .put(MAX_REPLICATION_TIME_BACKPRESSURE_SETTING.getKey(), TimeValue.timeValueSeconds(1))
             .put(MAX_INDEXING_CHECKPOINTS.getKey(), MAX_CHECKPOINTS_BEHIND)
             .build();
     }
@@ -223,7 +224,10 @@ public void testBelowReplicaLimit() throws Exception {
 
     public void testFailStaleReplica() throws Exception {
 
-        Settings settings = Settings.builder().put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueMillis(500)).build();
+        Settings settings = Settings.builder()
+            .put(MAX_REPLICATION_TIME_BACKPRESSURE_SETTING.getKey(), TimeValue.timeValueMillis(500))
+            .put(MAX_REPLICATION_LIMIT_STALE_REPLICA_SETTING.getKey(), TimeValue.timeValueMillis(1000))
+            .build();
         // Starts a primary and replica node.
         final String primaryNode = internalCluster().startNode(settings);
         createIndex(INDEX_NAME);
@@ -258,7 +262,13 @@ public void testFailStaleReplica() throws Exception {
     }
 
     public void testWithDocumentReplicationEnabledIndex() throws Exception {
-        Settings settings = Settings.builder().put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueMillis(500)).build();
+        assumeTrue(
+            "Can't create DocRep index with remote store enabled. Skipping.",
+            Objects.equals(featureFlagSettings().get(FeatureFlags.REMOTE_STORE, "false"), "false")
+        );
+        Settings settings = Settings.builder()
+            .put(MAX_REPLICATION_TIME_BACKPRESSURE_SETTING.getKey(), TimeValue.timeValueMillis(500))
+            .build();
         // Starts a primary and replica node.
         final String primaryNode = internalCluster().startNode(settings);
         createIndex(
diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java
index e00bc039fa0d7..e501ed0f34080 100644
--- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java
+++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java
@@ -639,7 +639,8 @@ public void apply(Settings value, Settings current, Settings previous) {
                 SearchBackpressureSettings.SETTING_CANCELLATION_BURST,   // deprecated
                 SegmentReplicationPressureService.SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED,
                 SegmentReplicationPressureService.MAX_INDEXING_CHECKPOINTS,
-                SegmentReplicationPressureService.MAX_REPLICATION_TIME_SETTING,
+                SegmentReplicationPressureService.MAX_REPLICATION_TIME_BACKPRESSURE_SETTING,
+                SegmentReplicationPressureService.MAX_REPLICATION_LIMIT_STALE_REPLICA_SETTING,
                 SegmentReplicationPressureService.MAX_ALLOWED_STALE_SHARDS,
 
                 // Settings related to Searchable Snapshots
diff --git a/server/src/main/java/org/opensearch/index/SegmentReplicationPressureService.java b/server/src/main/java/org/opensearch/index/SegmentReplicationPressureService.java
index 1fb8f0be52296..f8f415fbcc752 100644
--- a/server/src/main/java/org/opensearch/index/SegmentReplicationPressureService.java
+++ b/server/src/main/java/org/opensearch/index/SegmentReplicationPressureService.java
@@ -42,7 +42,8 @@ public class SegmentReplicationPressureService implements Closeable {
     private volatile boolean isSegmentReplicationBackpressureEnabled;
     private volatile int maxCheckpointsBehind;
     private volatile double maxAllowedStaleReplicas;
-    private volatile TimeValue maxReplicationTime;
+    private volatile TimeValue replicationTimeLimitBackpressure;
+    private volatile TimeValue replicationTimeLimitFailReplica;
 
     private static final Logger logger = LogManager.getLogger(SegmentReplicationPressureService.class);
 
@@ -65,13 +66,23 @@ public class SegmentReplicationPressureService implements Closeable {
         Setting.Property.NodeScope
     );
 
-    public static final Setting<TimeValue> MAX_REPLICATION_TIME_SETTING = Setting.positiveTimeSetting(
+    // Time limit on max allowed replica staleness after which backpressure kicks in on primary.
+    public static final Setting<TimeValue> MAX_REPLICATION_TIME_BACKPRESSURE_SETTING = Setting.positiveTimeSetting(
         "segrep.pressure.time.limit",
         TimeValue.timeValueMinutes(5),
         Setting.Property.Dynamic,
         Setting.Property.NodeScope
     );
 
+    // Time limit on max allowed replica staleness after which we start failing the replica shard.
+    // Defaults to 0 (disabled).
+    public static final Setting<TimeValue> MAX_REPLICATION_LIMIT_STALE_REPLICA_SETTING = Setting.positiveTimeSetting(
+        "segrep.replication.time.limit",
+        TimeValue.timeValueMinutes(0),
+        Setting.Property.Dynamic,
+        Setting.Property.NodeScope
+    );
+
     public static final Setting<Double> MAX_ALLOWED_STALE_SHARDS = Setting.doubleSetting(
         "segrep.pressure.replica.stale.limit",
         .5,
@@ -114,8 +125,11 @@ public SegmentReplicationPressureService(
         this.maxCheckpointsBehind = MAX_INDEXING_CHECKPOINTS.get(settings);
         clusterSettings.addSettingsUpdateConsumer(MAX_INDEXING_CHECKPOINTS, this::setMaxCheckpointsBehind);
 
-        this.maxReplicationTime = MAX_REPLICATION_TIME_SETTING.get(settings);
-        clusterSettings.addSettingsUpdateConsumer(MAX_REPLICATION_TIME_SETTING, this::setMaxReplicationTime);
+        this.replicationTimeLimitBackpressure = MAX_REPLICATION_TIME_BACKPRESSURE_SETTING.get(settings);
+        clusterSettings.addSettingsUpdateConsumer(MAX_REPLICATION_TIME_BACKPRESSURE_SETTING, this::setReplicationTimeLimitBackpressure);
+
+        this.replicationTimeLimitFailReplica = MAX_REPLICATION_LIMIT_STALE_REPLICA_SETTING.get(settings);
+        clusterSettings.addSettingsUpdateConsumer(MAX_REPLICATION_LIMIT_STALE_REPLICA_SETTING, this::setReplicationTimeLimitFailReplica);
 
         this.maxAllowedStaleReplicas = MAX_ALLOWED_STALE_SHARDS.get(settings);
         clusterSettings.addSettingsUpdateConsumer(MAX_ALLOWED_STALE_SHARDS, this::setMaxAllowedStaleReplicas);
@@ -159,7 +173,7 @@ private void validateReplicationGroup(IndexShard shard) {
     private Set<SegmentReplicationShardStats> getStaleReplicas(final Set<SegmentReplicationShardStats> replicas) {
         return replicas.stream()
             .filter(entry -> entry.getCheckpointsBehindCount() > maxCheckpointsBehind)
-            .filter(entry -> entry.getCurrentReplicationTimeMillis() > maxReplicationTime.millis())
+            .filter(entry -> entry.getCurrentReplicationTimeMillis() > replicationTimeLimitBackpressure.millis())
             .collect(Collectors.toSet());
     }
 
@@ -187,8 +201,12 @@ public void setMaxAllowedStaleReplicas(double maxAllowedStaleReplicas) {
         this.maxAllowedStaleReplicas = maxAllowedStaleReplicas;
     }
 
-    public void setMaxReplicationTime(TimeValue maxReplicationTime) {
-        this.maxReplicationTime = maxReplicationTime;
+    public void setReplicationTimeLimitFailReplica(TimeValue replicationTimeLimitFailReplica) {
+        this.replicationTimeLimitFailReplica = replicationTimeLimitFailReplica;
+    }
+
+    public void setReplicationTimeLimitBackpressure(TimeValue replicationTimeLimitBackpressure) {
+        this.replicationTimeLimitBackpressure = replicationTimeLimitBackpressure;
     }
 
     @Override
@@ -216,7 +234,8 @@ protected boolean mustReschedule() {
 
         @Override
         protected void runInternal() {
-            if (pressureService.isSegmentReplicationBackpressureEnabled) {
+            // Do not fail the replicas if time limit is set to 0 (i.e. disabled).
+            if (TimeValue.ZERO.equals(pressureService.replicationTimeLimitFailReplica) == false) {
                 final SegmentReplicationStats stats = pressureService.tracker.getStats();
 
                 // Find the shardId in node which is having stale replicas with highest current replication time.
@@ -242,7 +261,7 @@ protected void runInternal() {
                         }
                         final IndexShard primaryShard = indexService.getShard(shardId.getId());
                         for (SegmentReplicationShardStats staleReplica : staleReplicas) {
-                            if (staleReplica.getCurrentReplicationTimeMillis() > 2 * pressureService.maxReplicationTime.millis()) {
+                            if (staleReplica.getCurrentReplicationTimeMillis() > pressureService.replicationTimeLimitFailReplica.millis()) {
                                 pressureService.shardStateAction.remoteShardFailed(
                                     shardId,
                                     staleReplica.getAllocationId(),
diff --git a/server/src/main/java/org/opensearch/index/SegmentReplicationShardStats.java b/server/src/main/java/org/opensearch/index/SegmentReplicationShardStats.java
index eca57195db81c..5f5b8b513f15c 100644
--- a/server/src/main/java/org/opensearch/index/SegmentReplicationShardStats.java
+++ b/server/src/main/java/org/opensearch/index/SegmentReplicationShardStats.java
@@ -29,6 +29,10 @@ public class SegmentReplicationShardStats implements Writeable, ToXContentFragme
     private final String allocationId;
     private final long checkpointsBehindCount;
     private final long bytesBehindCount;
+    // Total replication lag observed.
+    private final long currentReplicationLagMillis;
+    // Total time taken for replicas to catch up. Similar to replication lag except this
+    // doesn't include time taken by primary to upload data to remote store.
     private final long currentReplicationTimeMillis;
     private final long lastCompletedReplicationTimeMillis;
 
@@ -40,12 +44,14 @@ public SegmentReplicationShardStats(
         long checkpointsBehindCount,
         long bytesBehindCount,
         long currentReplicationTimeMillis,
+        long currentReplicationLagMillis,
         long lastCompletedReplicationTime
     ) {
         this.allocationId = allocationId;
         this.checkpointsBehindCount = checkpointsBehindCount;
         this.bytesBehindCount = bytesBehindCount;
         this.currentReplicationTimeMillis = currentReplicationTimeMillis;
+        this.currentReplicationLagMillis = currentReplicationLagMillis;
         this.lastCompletedReplicationTimeMillis = lastCompletedReplicationTime;
     }
 
@@ -55,6 +61,7 @@ public SegmentReplicationShardStats(StreamInput in) throws IOException {
         this.bytesBehindCount = in.readVLong();
         this.currentReplicationTimeMillis = in.readVLong();
         this.lastCompletedReplicationTimeMillis = in.readVLong();
+        this.currentReplicationLagMillis = in.readVLong();
     }
 
     public String getAllocationId() {
@@ -73,6 +80,19 @@ public long getCurrentReplicationTimeMillis() {
         return currentReplicationTimeMillis;
     }
 
+    /**
+     * Total replication lag observed.
+     * @return currentReplicationLagMillis
+     */
+    public long getCurrentReplicationLagMillis() {
+        return currentReplicationLagMillis;
+    }
+
+    /**
+     * Time, in milliseconds, taken by the most recently completed replication event
+     * to bring this replica up to date with the primary.
+     * @return lastCompletedReplicationTimeMillis
+     */
     public long getLastCompletedReplicationTimeMillis() {
         return lastCompletedReplicationTimeMillis;
     }
@@ -93,6 +113,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         builder.field("checkpoints_behind", checkpointsBehindCount);
         builder.field("bytes_behind", new ByteSizeValue(bytesBehindCount).toString());
         builder.field("current_replication_time", new TimeValue(currentReplicationTimeMillis));
+        builder.field("current_replication_lag", new TimeValue(currentReplicationLagMillis));
         builder.field("last_completed_replication_time", new TimeValue(lastCompletedReplicationTimeMillis));
         if (currentReplicationState != null) {
             builder.startObject();
@@ -110,6 +131,7 @@ public void writeTo(StreamOutput out) throws IOException {
         out.writeVLong(bytesBehindCount);
         out.writeVLong(currentReplicationTimeMillis);
         out.writeVLong(lastCompletedReplicationTimeMillis);
+        out.writeVLong(currentReplicationLagMillis);
     }
 
     @Override
@@ -121,6 +143,8 @@ public String toString() {
             + checkpointsBehindCount
             + ", bytesBehindCount="
             + bytesBehindCount
+            + ", currentReplicationLagMillis="
+            + currentReplicationLagMillis
             + ", currentReplicationTimeMillis="
             + currentReplicationTimeMillis
             + ", lastCompletedReplicationTimeMillis="
diff --git a/server/src/main/java/org/opensearch/index/seqno/ReplicationTracker.java b/server/src/main/java/org/opensearch/index/seqno/ReplicationTracker.java
index 6b34d6641fcf2..e8498bc6628c7 100644
--- a/server/src/main/java/org/opensearch/index/seqno/ReplicationTracker.java
+++ b/server/src/main/java/org/opensearch/index/seqno/ReplicationTracker.java
@@ -60,7 +60,7 @@
 import org.opensearch.index.shard.ReplicationGroup;
 import org.opensearch.index.shard.ShardId;
 import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint;
-import org.opensearch.indices.replication.common.ReplicationTimer;
+import org.opensearch.indices.replication.common.SegmentReplicationLagTimer;
 
 import java.io.IOException;
 import java.nio.file.Path;
@@ -716,7 +716,7 @@ public static class CheckpointState implements Writeable {
          * Map of ReplicationCheckpoints to ReplicationTimers.  Timers are added as new checkpoints are published, and removed when
          * the replica is caught up.
          */
-        Map<ReplicationCheckpoint, ReplicationTimer> checkpointTimers;
+        Map<ReplicationCheckpoint, SegmentReplicationLagTimer> checkpointTimers;
 
         /**
          * The time it took to complete the most recent replication event.
@@ -1184,9 +1184,9 @@ public synchronized void updateVisibleCheckpointForShard(final String allocation
             cps.checkpointTimers.entrySet().removeIf((entry) -> {
                 boolean result = visibleCheckpoint.equals(entry.getKey()) || visibleCheckpoint.isAheadOf(entry.getKey());
                 if (result) {
-                    final ReplicationTimer timer = entry.getValue();
+                    final SegmentReplicationLagTimer timer = entry.getValue();
                     timer.stop();
-                    lastFinished.set(Math.max(lastFinished.get(), timer.time()));
+                    lastFinished.set(Math.max(lastFinished.get(), timer.totalElapsedTime()));
                 }
                 return result;
             });
@@ -1206,36 +1206,71 @@ public synchronized void updateVisibleCheckpointForShard(final String allocation
     }
 
     /**
-     * After a new checkpoint is published, start a timer for each replica to the checkpoint.
+     * After a new checkpoint is created post-refresh, create a replication lag timer for each replica for that checkpoint.
      * @param checkpoint {@link ReplicationCheckpoint}
      */
     public synchronized void setLatestReplicationCheckpoint(ReplicationCheckpoint checkpoint) {
         assert indexSettings.isSegRepEnabled();
-        assert handoffInProgress == false;
-        if (checkpoint.equals(lastPublishedReplicationCheckpoint) == false) {
-            this.lastPublishedReplicationCheckpoint = checkpoint;
-            for (Map.Entry<String, CheckpointState> entry : checkpoints.entrySet()) {
-                if (entry.getKey().equals(this.shardAllocationId) == false) {
-                    final CheckpointState cps = entry.getValue();
-                    if (cps.inSync) {
-                        cps.checkpointTimers.computeIfAbsent(checkpoint, ignored -> {
-                            final ReplicationTimer replicationTimer = new ReplicationTimer();
-                            replicationTimer.start();
-                            return replicationTimer;
-                        });
-                        logger.trace(
-                            () -> new ParameterizedMessage(
-                                "updated last published checkpoint to {} - timers [{}]",
-                                checkpoint,
-                                cps.checkpointTimers.keySet()
-                            )
-                        );
-                    }
+        if (checkpoint.equals(latestReplicationCheckpoint) == false) {
+            this.latestReplicationCheckpoint = checkpoint;
+        }
+        if (primaryMode) {
+            createReplicationLagTimers();
+        }
+    }
+
+    public ReplicationCheckpoint getLatestReplicationCheckpoint() {
+        return this.latestReplicationCheckpoint;
+    }
+
+    private void createReplicationLagTimers() {
+        for (Map.Entry<String, CheckpointState> entry : checkpoints.entrySet()) {
+            final String allocationId = entry.getKey();
+            if (allocationId.equals(this.shardAllocationId) == false) {
+                final CheckpointState cps = entry.getValue();
+                // if the shard is in checkpoints but is unavailable or out of sync we will not track its replication state.
+                // it is possible for a shard to be in-sync but not yet removed from the checkpoints collection after a failover event.
+                if (cps.inSync
+                    && replicationGroup.getUnavailableInSyncShards().contains(allocationId) == false
+                    && latestReplicationCheckpoint.isAheadOf(cps.visibleReplicationCheckpoint)) {
+                    cps.checkpointTimers.computeIfAbsent(latestReplicationCheckpoint, ignored -> new SegmentReplicationLagTimer());
+                    logger.trace(
+                        () -> new ParameterizedMessage(
+                            "updated last published checkpoint for {} at visible cp {} to {} - timers [{}]",
+                            allocationId,
+                            cps.visibleReplicationCheckpoint,
+                            latestReplicationCheckpoint,
+                            cps.checkpointTimers.keySet()
+                        )
+                    );
                 }
             }
         }
     }
 
+    /**
+     * After a new checkpoint is published, start a timer per replica for the checkpoint.
+     * @param checkpoint {@link ReplicationCheckpoint}
+     */
+    public synchronized void startReplicationLagTimers(ReplicationCheckpoint checkpoint) {
+        assert indexSettings.isSegRepEnabled();
+        if (checkpoint.equals(latestReplicationCheckpoint) == false) {
+            this.latestReplicationCheckpoint = checkpoint;
+        }
+        if (primaryMode) {
+            checkpoints.entrySet().stream().filter(e -> !e.getKey().equals(this.shardAllocationId)).forEach(e -> {
+                String allocationId = e.getKey();
+                final CheckpointState cps = e.getValue();
+                if (cps.inSync
+                    && replicationGroup.getUnavailableInSyncShards().contains(allocationId) == false
+                    && latestReplicationCheckpoint.isAheadOf(cps.visibleReplicationCheckpoint)
+                    && cps.checkpointTimers.containsKey(latestReplicationCheckpoint)) {
+                    cps.checkpointTimers.get(latestReplicationCheckpoint).start();
+                }
+            });
+        }
+    }
+
     /**
      * Fetch stats on segment replication.
      * @return {@link Tuple} V1 - TimeValue in ms - mean replication lag for this primary to its entire group,
@@ -1259,14 +1294,15 @@ private SegmentReplicationShardStats buildShardStats(
         final String allocationId,
         final CheckpointState checkpointState
     ) {
-        final Map<ReplicationCheckpoint, ReplicationTimer> checkpointTimers = checkpointState.checkpointTimers;
+        final Map<ReplicationCheckpoint, SegmentReplicationLagTimer> checkpointTimers = checkpointState.checkpointTimers;
         return new SegmentReplicationShardStats(
             allocationId,
             checkpointTimers.size(),
             checkpointState.visibleReplicationCheckpoint == null
                 ? latestCheckpointLength
                 : Math.max(latestCheckpointLength - checkpointState.visibleReplicationCheckpoint.getLength(), 0),
-            checkpointTimers.values().stream().mapToLong(ReplicationTimer::time).max().orElse(0),
+            checkpointTimers.values().stream().mapToLong(SegmentReplicationLagTimer::time).max().orElse(0),
+            checkpointTimers.values().stream().mapToLong(SegmentReplicationLagTimer::totalElapsedTime).max().orElse(0),
             checkpointState.lastCompletedReplicationLag
         );
     }
diff --git a/server/src/main/java/org/opensearch/index/shard/IndexShard.java b/server/src/main/java/org/opensearch/index/shard/IndexShard.java
index f5e349eb54b99..1404c4991d2fb 100644
--- a/server/src/main/java/org/opensearch/index/shard/IndexShard.java
+++ b/server/src/main/java/org/opensearch/index/shard/IndexShard.java
@@ -1836,6 +1836,10 @@ static Engine.Searcher wrapSearcher(
         }
     }
 
+    public void onCheckpointPublished(ReplicationCheckpoint checkpoint) {
+        replicationTracker.startReplicationLagTimers(checkpoint);
+    }
+
     /**
      * Used with segment replication during relocation handoff, this method updates current read only engine to global
      * checkpoint followed by changing to writeable engine
@@ -4386,6 +4390,33 @@ public void afterRefresh(boolean didRefresh) throws IOException {
         }
     }
 
+    /**
+     * Refresh listener to update the Shard's ReplicationCheckpoint post refresh.
+     */
+    private class ReplicationCheckpointUpdater implements ReferenceManager.RefreshListener {
+        @Override
+        public void beforeRefresh() throws IOException {}
+
+        @Override
+        public void afterRefresh(boolean didRefresh) throws IOException {
+            if (didRefresh) {
+                // We're only starting to track the replication checkpoint. The timers for replication are started when
+                // the checkpoint is published. This is done so that the timers do not include the time spent by primary
+                // in uploading the segments to remote store.
+                updateReplicationCheckpoint();
+            }
+        }
+    }
+
+    private void updateReplicationCheckpoint() {
+        final Tuple<GatedCloseable<SegmentInfos>, ReplicationCheckpoint> tuple = getLatestSegmentInfosAndCheckpoint();
+        try (final GatedCloseable<SegmentInfos> ignored = tuple.v1()) {
+            replicationTracker.setLatestReplicationCheckpoint(tuple.v2());
+        } catch (IOException e) {
+            throw new OpenSearchException("Error Closing SegmentInfos Snapshot", e);
+        }
+    }
+
     private EngineConfig.TombstoneDocSupplier tombstoneDocSupplier() {
         final RootObjectMapper.Builder noopRootMapper = new RootObjectMapper.Builder("__noop");
         final DocumentMapper noopDocumentMapper = mapperService != null
diff --git a/server/src/main/java/org/opensearch/indices/replication/common/SegmentReplicationLagTimer.java b/server/src/main/java/org/opensearch/indices/replication/common/SegmentReplicationLagTimer.java
new file mode 100644
index 0000000000000..c97edba72da0d
--- /dev/null
+++ b/server/src/main/java/org/opensearch/indices/replication/common/SegmentReplicationLagTimer.java
@@ -0,0 +1,48 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.indices.replication.common;
+
+import org.opensearch.common.unit.TimeValue;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+
+import java.io.IOException;
+
+/**
+ * Extension of ReplicationTimer that also tracks the time elapsed since the timer was created.
+ * Currently this is used to calculate:
+ * 1. Replication lag: total time taken by a replica to sync after the primary refreshed.
+ * 2. Replication event time: total time taken by a replica to sync after the primary published the checkpoint
+ *                     (excludes the time spent by the primary uploading segments to the remote store).
+ *
+ * @opensearch.internal
+ */
+public class SegmentReplicationLagTimer extends ReplicationTimer {
+    private long creationTime;
+
+    public SegmentReplicationLagTimer() {
+        super();
+        creationTime = System.nanoTime();
+    }
+
+    public SegmentReplicationLagTimer(StreamInput in) throws IOException {
+        super(in);
+        creationTime = in.readVLong();
+    }
+
+    @Override
+    public synchronized void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+        out.writeVLong(creationTime);
+    }
+
+    public long totalElapsedTime() {
+        return TimeValue.nsecToMSec(Math.max(System.nanoTime() - creationTime, 0));
+    }
+}
diff --git a/server/src/main/java/org/opensearch/rest/action/cat/RestCatSegmentReplicationAction.java b/server/src/main/java/org/opensearch/rest/action/cat/RestCatSegmentReplicationAction.java
index 0130f9cd14c36..9e9fb1e74d2be 100644
--- a/server/src/main/java/org/opensearch/rest/action/cat/RestCatSegmentReplicationAction.java
+++ b/server/src/main/java/org/opensearch/rest/action/cat/RestCatSegmentReplicationAction.java
@@ -170,7 +170,7 @@ public Table buildSegmentReplicationTable(RestRequest request, SegmentReplicatio
                     t.addCell(state.getTargetNode().getHostName());
                     t.addCell(shardStats.getCheckpointsBehindCount());
                     t.addCell(new ByteSizeValue(shardStats.getBytesBehindCount()));
-                    t.addCell(new TimeValue(shardStats.getCurrentReplicationTimeMillis()));
+                    t.addCell(new TimeValue(shardStats.getCurrentReplicationLagMillis()));
                     t.addCell(new TimeValue(shardStats.getLastCompletedReplicationTimeMillis()));
                     t.addCell(perGroupStats.getRejectedRequestCount());
                     if (detailed) {
diff --git a/server/src/test/java/org/opensearch/index/SegmentReplicationPressureServiceTests.java b/server/src/test/java/org/opensearch/index/SegmentReplicationPressureServiceTests.java
index 1ebdd111bfed3..54539602a5d95 100644
--- a/server/src/test/java/org/opensearch/index/SegmentReplicationPressureServiceTests.java
+++ b/server/src/test/java/org/opensearch/index/SegmentReplicationPressureServiceTests.java
@@ -32,16 +32,18 @@
 import java.util.concurrent.TimeUnit;
 
 import static java.util.Arrays.asList;
+import static org.opensearch.index.SegmentReplicationPressureService.MAX_REPLICATION_LIMIT_STALE_REPLICA_SETTING;
+import static org.opensearch.index.SegmentReplicationPressureService.MAX_REPLICATION_TIME_BACKPRESSURE_SETTING;
+import static org.opensearch.index.SegmentReplicationPressureService.SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyString;
-import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
-import static org.opensearch.index.SegmentReplicationPressureService.MAX_REPLICATION_TIME_SETTING;
-import static org.opensearch.index.SegmentReplicationPressureService.SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 public class SegmentReplicationPressureServiceTests extends OpenSearchIndexLevelReplicationTestCase {
 
@@ -49,7 +51,7 @@ public class SegmentReplicationPressureServiceTests extends OpenSearchIndexLevel
     private static final Settings settings = Settings.builder()
         .put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT)
         .put(SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.getKey(), true)
-        .put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueSeconds(5))
+        .put(MAX_REPLICATION_TIME_BACKPRESSURE_SETTING.getKey(), TimeValue.timeValueSeconds(5))
         .build();
 
     public void testIsSegrepLimitBreached() throws Exception {
@@ -195,7 +197,8 @@ public void testFailStaleReplicaTask() throws Exception {
         final Settings settings = Settings.builder()
             .put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT)
             .put(SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.getKey(), true)
-            .put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueMillis(10))
+            .put(MAX_REPLICATION_TIME_BACKPRESSURE_SETTING.getKey(), TimeValue.timeValueMillis(10))
+            .put(MAX_REPLICATION_LIMIT_STALE_REPLICA_SETTING.getKey(), TimeValue.timeValueMillis(20))
             .build();
 
         try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) {
@@ -221,6 +224,38 @@ public void testFailStaleReplicaTask() throws Exception {
         }
     }
 
+    public void testFailStaleReplicaTaskDisabled() throws Exception {
+        final Settings settings = Settings.builder()
+            .put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT)
+            .put(SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.getKey(), true)
+            .put(MAX_REPLICATION_TIME_BACKPRESSURE_SETTING.getKey(), TimeValue.timeValueMillis(10))
+            .put(MAX_REPLICATION_LIMIT_STALE_REPLICA_SETTING.getKey(), TimeValue.timeValueMillis(0))
+            .build();
+
+        try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) {
+            shards.startAll();
+            final IndexShard primaryShard = shards.getPrimary();
+            SegmentReplicationPressureService service = buildPressureService(settings, primaryShard);
+            Mockito.reset(shardStateAction);
+
+            // index docs in batches without refreshing
+            indexInBatches(5, shards, primaryShard);
+
+            // assert that replica shard is few checkpoints behind primary
+            Set<SegmentReplicationShardStats> replicationStats = primaryShard.getReplicationStats();
+            assertEquals(1, replicationStats.size());
+            SegmentReplicationShardStats shardStats = replicationStats.stream().findFirst().get();
+            assertEquals(5, shardStats.getCheckpointsBehindCount());
+
+            // call the background task
+            service.getFailStaleReplicaTask().runInternal();
+
+            // verify that remote shard failed method is never called as it is disabled.
+            verify(shardStateAction, never()).remoteShardFailed(any(), anyString(), anyLong(), anyBoolean(), anyString(), any(), any());
+            replicateSegments(primaryShard, shards.getReplicas());
+        }
+    }
+
     private int indexInBatches(int count, ReplicationGroup shards, IndexShard primaryShard) throws Exception {
         int totalDocs = 0;
         for (int i = 0; i < count; i++) {
diff --git a/server/src/test/java/org/opensearch/index/seqno/ReplicationTrackerTests.java b/server/src/test/java/org/opensearch/index/seqno/ReplicationTrackerTests.java
index 7cfc95d7f5cff..9b5653531d540 100644
--- a/server/src/test/java/org/opensearch/index/seqno/ReplicationTrackerTests.java
+++ b/server/src/test/java/org/opensearch/index/seqno/ReplicationTrackerTests.java
@@ -51,6 +51,7 @@
 import org.opensearch.index.shard.ShardId;
 import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint;
 import org.opensearch.indices.replication.common.ReplicationType;
+import org.opensearch.indices.replication.common.SegmentReplicationLagTimer;
 import org.opensearch.test.IndexSettingsModule;
 
 import java.io.IOException;
@@ -1827,14 +1828,18 @@ public void testSegmentReplicationCheckpointTracking() {
         );
 
         tracker.setLatestReplicationCheckpoint(initialCheckpoint);
+        tracker.startReplicationLagTimers(initialCheckpoint);
         tracker.setLatestReplicationCheckpoint(secondCheckpoint);
+        tracker.startReplicationLagTimers(secondCheckpoint);
         tracker.setLatestReplicationCheckpoint(thirdCheckpoint);
+        tracker.startReplicationLagTimers(thirdCheckpoint);
 
         Set<SegmentReplicationShardStats> groupStats = tracker.getSegmentReplicationStats();
         assertEquals(inSyncAllocationIds.size(), groupStats.size());
         for (SegmentReplicationShardStats shardStat : groupStats) {
             assertEquals(3, shardStat.getCheckpointsBehindCount());
             assertEquals(100L, shardStat.getBytesBehindCount());
+            assertTrue(shardStat.getCurrentReplicationLagMillis() >= shardStat.getCurrentReplicationTimeMillis());
         }
 
         // simulate replicas moved up to date.
@@ -1868,6 +1873,75 @@ public void testSegmentReplicationCheckpointTracking() {
         }
     }
 
+    public void testSegmentReplicationCheckpointTrackingInvalidAllocationIDs() {
+        Settings settings = Settings.builder().put(SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT).build();
+        final long initialClusterStateVersion = randomNonNegativeLong();
+        final int numberOfActiveAllocationsIds = randomIntBetween(2, 16);
+        final int numberOfInitializingIds = randomIntBetween(2, 16);
+        final Tuple<Set<AllocationId>, Set<AllocationId>> activeAndInitializingAllocationIds = randomActiveAndInitializingAllocationIds(
+            numberOfActiveAllocationsIds,
+            numberOfInitializingIds
+        );
+        final Set<AllocationId> activeAllocationIds = activeAndInitializingAllocationIds.v1();
+        final Set<AllocationId> initializingIds = activeAndInitializingAllocationIds.v2();
+        AllocationId primaryId = activeAllocationIds.iterator().next();
+        IndexShardRoutingTable routingTable = routingTable(initializingIds, primaryId);
+        final ReplicationTracker tracker = newTracker(primaryId, settings);
+        tracker.updateFromClusterManager(initialClusterStateVersion, ids(activeAllocationIds), routingTable);
+        tracker.activatePrimaryMode(NO_OPS_PERFORMED);
+
+        initializingIds.forEach(aId -> markAsTrackingAndInSyncQuietly(tracker, aId.getId(), NO_OPS_PERFORMED));
+
+        assertEquals(tracker.getReplicationGroup().getRoutingTable(), routingTable);
+        assertEquals(
+            "All active & initializing ids are now marked in-sync",
+            Sets.union(ids(activeAllocationIds), ids(initializingIds)),
+            tracker.getReplicationGroup().getInSyncAllocationIds()
+        );
+
+        assertEquals(
+            "Active ids are in-sync but still unavailable",
+            tracker.getReplicationGroup().getUnavailableInSyncShards(),
+            Sets.difference(ids(activeAllocationIds), Set.of(primaryId.getId()))
+        );
+        assertTrue(activeAllocationIds.stream().allMatch(a -> tracker.getTrackedLocalCheckpointForShard(a.getId()).inSync));
+
+        final ReplicationCheckpoint initialCheckpoint = new ReplicationCheckpoint(
+            tracker.shardId(),
+            0L,
+            1,
+            1,
+            1L,
+            Codec.getDefault().getName()
+        );
+        tracker.setLatestReplicationCheckpoint(initialCheckpoint);
+        tracker.startReplicationLagTimers(initialCheckpoint);
+
+        // we expect that the only returned ids from getSegmentReplicationStats will be the initializing ids we marked with
+        // markAsTrackingAndInSyncQuietly.
+        // This is because the ids marked active initially are still unavailable (don't have an associated routing entry).
+        final Set<String> expectedIds = ids(initializingIds);
+        Set<SegmentReplicationShardStats> groupStats = tracker.getSegmentReplicationStats();
+        final Set<String> actualIds = groupStats.stream().map(SegmentReplicationShardStats::getAllocationId).collect(Collectors.toSet());
+        assertEquals(expectedIds, actualIds);
+        for (SegmentReplicationShardStats shardStat : groupStats) {
+            assertEquals(1, shardStat.getCheckpointsBehindCount());
+        }
+
+        // simulate replicas moved up to date.
+        final Map<String, ReplicationTracker.CheckpointState> checkpoints = tracker.checkpoints;
+        for (String id : expectedIds) {
+            final ReplicationTracker.CheckpointState checkpointState = checkpoints.get(id);
+            assertEquals(1, checkpointState.checkpointTimers.size());
+            tracker.updateVisibleCheckpointForShard(id, initialCheckpoint);
+            assertEquals(0, checkpointState.checkpointTimers.size());
+        }
+
+        // Unknown allocation ID will be ignored.
+        tracker.updateVisibleCheckpointForShard("randomAllocationID", initialCheckpoint);
+        assertThrows(AssertionError.class, () -> tracker.updateVisibleCheckpointForShard(tracker.shardAllocationId, initialCheckpoint));
+    }
+
     public void testPrimaryContextHandoffWithRemoteTranslogEnabled() throws IOException {
         Settings settings = Settings.builder().put("index.remote_store.translog.enabled", "true").build();
         final IndexSettings indexSettings = IndexSettingsModule.newIndexSettings("test", settings);
@@ -2061,4 +2135,15 @@ public void testIllegalStateExceptionIfUnknownAllocationIdWithRemoteTranslogEnab
         expectThrows(IllegalStateException.class, () -> tracker.markAllocationIdAsInSync(randomAlphaOfLength(10), randomNonNegativeLong()));
     }
 
+    public void testSegRepTimer() throws Throwable {
+        SegmentReplicationLagTimer timer = new SegmentReplicationLagTimer();
+        Thread.sleep(100);
+        timer.start();
+        Thread.sleep(100);
+        timer.stop();
+        assertTrue("Total time since timer started should be greater than 100", timer.time() >= 100);
+        assertTrue("Total time since timer was created should be greater than 200", timer.totalElapsedTime() >= 200);
+        assertTrue("Total elapsed time should be greater than time since timer start", timer.totalElapsedTime() - timer.time() >= 100);
+    }
+
 }
diff --git a/server/src/test/java/org/opensearch/rest/action/cat/RestCatSegmentReplicationActionTests.java b/server/src/test/java/org/opensearch/rest/action/cat/RestCatSegmentReplicationActionTests.java
index 7a0d80d9538ad..67dd7a0684084 100644
--- a/server/src/test/java/org/opensearch/rest/action/cat/RestCatSegmentReplicationActionTests.java
+++ b/server/src/test/java/org/opensearch/rest/action/cat/RestCatSegmentReplicationActionTests.java
@@ -84,6 +84,7 @@ public void testSegmentReplicationAction() throws IOException {
                 0L,
                 0L,
                 0L,
+                0L,
                 0L
             );
             segmentReplicationShardStats.setCurrentReplicationState(state);
@@ -141,7 +142,7 @@ public void testSegmentReplicationAction() throws IOException {
                 currentReplicationState.getTargetNode().getHostName(),
                 shardStats.getCheckpointsBehindCount(),
                 new ByteSizeValue(shardStats.getBytesBehindCount()),
-                new TimeValue(shardStats.getCurrentReplicationTimeMillis()),
+                new TimeValue(shardStats.getCurrentReplicationLagMillis()),
                 new TimeValue(shardStats.getLastCompletedReplicationTimeMillis()),
                 rejectedRequestCount
             );