From b6734fcc80daf692a3ea44555ce742a6c19d9321 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <github-actions[bot]@users.noreply.github.com>
Date: Tue, 19 Mar 2024 15:36:01 +0000
Subject: [PATCH] PrimaryShardAllocator refactor to abstract out shard state
 and method calls (#9760)

* PrimaryShardAllocator refactor to abstract out shard state and method calls

Signed-off-by: Shivansh Arora <shivansh.arora@protonmail.com>
Signed-off-by: Shivansh Arora <hishiv@amazon.com>
(cherry picked from commit afd3969e8df4edc0ba76901fad822fbfe84fe04b)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
---
 .../gateway/PrimaryShardAllocator.java        | 119 ++++++++++++------
 1 file changed, 83 insertions(+), 36 deletions(-)

diff --git a/server/src/main/java/org/opensearch/gateway/PrimaryShardAllocator.java b/server/src/main/java/org/opensearch/gateway/PrimaryShardAllocator.java
index 2807be00feeaa..5046873830c01 100644
--- a/server/src/main/java/org/opensearch/gateway/PrimaryShardAllocator.java
+++ b/server/src/main/java/org/opensearch/gateway/PrimaryShardAllocator.java
@@ -81,7 +81,7 @@ public abstract class PrimaryShardAllocator extends BaseGatewayShardAllocator {
     /**
      * Is the allocator responsible for allocating the given {@link ShardRouting}?
      */
-    private static boolean isResponsibleFor(final ShardRouting shard) {
+    protected static boolean isResponsibleFor(final ShardRouting shard) {
         return shard.primary() // must be primary
             && shard.unassigned() // must be unassigned
             // only handle either an existing store or a snapshot recovery
@@ -89,19 +89,20 @@ private static boolean isResponsibleFor(final ShardRouting shard) {
                 || shard.recoverySource().getType() == RecoverySource.Type.SNAPSHOT);
     }
 
-    @Override
-    public AllocateUnassignedDecision makeAllocationDecision(
-        final ShardRouting unassignedShard,
-        final RoutingAllocation allocation,
-        final Logger logger
-    ) {
+    /**
+     * Skip doing fetchData call for a shard if recovery mode is snapshot. Also do not take decision if allocator is
+     * not responsible for this particular shard.
+     *
+     * @param unassignedShard unassigned shard routing
+     * @param allocation      routing allocation object
+     * @return allocation decision taken for this shard
+     */
+    protected AllocateUnassignedDecision getInEligibleShardDecision(ShardRouting unassignedShard, RoutingAllocation allocation) {
         if (isResponsibleFor(unassignedShard) == false) {
             // this allocator is not responsible for allocating this shard
             return AllocateUnassignedDecision.NOT_TAKEN;
         }
-
         final boolean explain = allocation.debugDecision();
-
         if (unassignedShard.recoverySource().getType() == RecoverySource.Type.SNAPSHOT
             && allocation.snapshotShardSizeInfo().getShardSize(unassignedShard) == null) {
             List<NodeAllocationResult> nodeDecisions = null;
@@ -110,9 +111,45 @@ public AllocateUnassignedDecision makeAllocationDecision(
             }
             return AllocateUnassignedDecision.no(UnassignedInfo.AllocationStatus.FETCHING_SHARD_DATA, nodeDecisions);
         }
+        return null;
+    }
 
+    @Override
+    public AllocateUnassignedDecision makeAllocationDecision(
+        final ShardRouting unassignedShard,
+        final RoutingAllocation allocation,
+        final Logger logger
+    ) {
+        AllocateUnassignedDecision decision = getInEligibleShardDecision(unassignedShard, allocation);
+        if (decision != null) {
+            return decision;
+        }
         final FetchResult<NodeGatewayStartedShards> shardState = fetchData(unassignedShard, allocation);
-        if (shardState.hasData() == false) {
+        List<NodeGatewayStartedShards> nodeShardStates = adaptToNodeStartedShardList(shardState);
+        return getAllocationDecision(unassignedShard, allocation, nodeShardStates, logger);
+    }
+
+    /**
+    * Transforms {@link FetchResult} of {@link NodeGatewayStartedShards} to {@link List} of {@link NodeGatewayStartedShards}
+    * Returns null if {@link FetchResult} does not have any data.
+    */
+    private static List<NodeGatewayStartedShards> adaptToNodeStartedShardList(FetchResult<NodeGatewayStartedShards> shardsState) {
+        if (!shardsState.hasData()) {
+            return null;
+        }
+        List<NodeGatewayStartedShards> nodeShardStates = new ArrayList<>();
+        shardsState.getData().forEach((node, nodeGatewayStartedShard) -> { nodeShardStates.add(nodeGatewayStartedShard); });
+        return nodeShardStates;
+    }
+
+    protected AllocateUnassignedDecision getAllocationDecision(
+        ShardRouting unassignedShard,
+        RoutingAllocation allocation,
+        List<NodeGatewayStartedShards> shardState,
+        Logger logger
+    ) {
+        final boolean explain = allocation.debugDecision();
+        if (shardState == null) {
             allocation.setHasPendingAsyncFetch();
             List<NodeAllocationResult> nodeDecisions = null;
             if (explain) {
@@ -120,7 +157,6 @@ public AllocateUnassignedDecision makeAllocationDecision(
             }
             return AllocateUnassignedDecision.no(AllocationStatus.FETCHING_SHARD_DATA, nodeDecisions);
         }
-
         // don't create a new IndexSetting object for every shard as this could cause a lot of garbage
         // on cluster restart if we allocate a boat load of shards
         final IndexMetadata indexMetadata = allocation.metadata().getIndexSafe(unassignedShard.index());
@@ -260,11 +296,11 @@ public AllocateUnassignedDecision makeAllocationDecision(
      */
     private static List<NodeAllocationResult> buildNodeDecisions(
         NodesToAllocate nodesToAllocate,
-        FetchResult<NodeGatewayStartedShards> fetchedShardData,
+        List<NodeGatewayStartedShards> fetchedShardData,
         Set<String> inSyncAllocationIds
     ) {
         List<NodeAllocationResult> nodeResults = new ArrayList<>();
-        Collection<NodeGatewayStartedShards> ineligibleShards;
+        Collection<NodeGatewayStartedShards> ineligibleShards = new ArrayList<>();
         if (nodesToAllocate != null) {
             final Set<DiscoveryNode> discoNodes = new HashSet<>();
             nodeResults.addAll(
@@ -280,15 +316,13 @@ private static List<NodeAllocationResult> buildNodeDecisions(
                     })
                     .collect(Collectors.toList())
             );
-            ineligibleShards = fetchedShardData.getData()
-                .values()
-                .stream()
+            ineligibleShards = fetchedShardData.stream()
                 .filter(shardData -> discoNodes.contains(shardData.getNode()) == false)
                 .collect(Collectors.toList());
         } else {
             // there were no shard copies that were eligible for being assigned the allocation,
             // so all fetched shard data are ineligible shards
-            ineligibleShards = fetchedShardData.getData().values();
+            ineligibleShards = fetchedShardData;
         }
 
         nodeResults.addAll(
@@ -328,12 +362,12 @@ protected static NodeShardsResult buildNodeShardsResult(
         boolean matchAnyShard,
         Set<String> ignoreNodes,
         Set<String> inSyncAllocationIds,
-        FetchResult<NodeGatewayStartedShards> shardState,
+        List<NodeGatewayStartedShards> shardState,
         Logger logger
     ) {
         List<NodeGatewayStartedShards> nodeShardStates = new ArrayList<>();
         int numberOfAllocationsFound = 0;
-        for (NodeGatewayStartedShards nodeShardState : shardState.getData().values()) {
+        for (NodeGatewayStartedShards nodeShardState : shardState) {
             DiscoveryNode node = nodeShardState.getNode();
             String allocationId = nodeShardState.allocationId();
 
@@ -386,11 +420,27 @@ protected static NodeShardsResult buildNodeShardsResult(
             }
         }
 
-        /*
-          Orders the active shards copies based on below comparators
-          1. No store exception i.e. shard copy is readable
-          2. Prefer previous primary shard
-          3. Prefer shard copy with the highest replication checkpoint. It is NO-OP for doc rep enabled indices.
+        nodeShardStates.sort(createActiveShardComparator(matchAnyShard, inSyncAllocationIds));
+
+        if (logger.isTraceEnabled()) {
+            logger.trace(
+                "{} candidates for allocation: {}",
+                shard,
+                nodeShardStates.stream().map(s -> s.getNode().getName()).collect(Collectors.joining(", "))
+            );
+        }
+        return new NodeShardsResult(nodeShardStates, numberOfAllocationsFound);
+    }
+
+    private static Comparator<NodeGatewayStartedShards> createActiveShardComparator(
+        boolean matchAnyShard,
+        Set<String> inSyncAllocationIds
+    ) {
+        /**
+         * Orders the active shards copies based on below comparators
+         * 1. No store exception i.e. shard copy is readable
+         * 2. Prefer previous primary shard
+         * 3. Prefer shard copy with the highest replication checkpoint. It is NO-OP for doc rep enabled indices.
          */
         final Comparator<NodeGatewayStartedShards> comparator; // allocation preference
         if (matchAnyShard) {
@@ -406,16 +456,7 @@ protected static NodeShardsResult buildNodeShardsResult(
                 .thenComparing(HIGHEST_REPLICATION_CHECKPOINT_FIRST_COMPARATOR);
         }
 
-        nodeShardStates.sort(comparator);
-
-        if (logger.isTraceEnabled()) {
-            logger.trace(
-                "{} candidates for allocation: {}",
-                shard,
-                nodeShardStates.stream().map(s -> s.getNode().getName()).collect(Collectors.joining(", "))
-            );
-        }
-        return new NodeShardsResult(nodeShardStates, numberOfAllocationsFound);
+        return comparator;
     }
 
     /**
@@ -457,7 +498,10 @@ private static NodesToAllocate buildNodesToAllocate(
 
     protected abstract FetchResult<NodeGatewayStartedShards> fetchData(ShardRouting shard, RoutingAllocation allocation);
 
-    private static class NodeShardsResult {
+    /**
+     * This class encapsulates the result of a call to {@link #buildNodeShardsResult}
+     */
+    static class NodeShardsResult {
         final List<NodeGatewayStartedShards> orderedAllocationCandidates;
         final int allocationsFound;
 
@@ -467,7 +511,10 @@ private static class NodeShardsResult {
         }
     }
 
-    static class NodesToAllocate {
+    /**
+     * This class encapsulates the result of a call to {@link #buildNodesToAllocate}
+     */
+    protected static class NodesToAllocate {
         final List<DecidedNode> yesNodeShards;
         final List<DecidedNode> throttleNodeShards;
         final List<DecidedNode> noNodeShards;