Skip to content

Commit

Permalink
Fix state across search requests using weighted shard routing
Browse files Browse the repository at this point in the history
Signed-off-by: Anshu Agarwal <anshukag@amazon.com>
  • Loading branch information
Anshu Agarwal committed Jan 25, 2023
1 parent 493eae8 commit 2007e8b
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
public class IndexShardRoutingTable implements Iterable<ShardRouting> {

final ShardShuffler shuffler;
final ShardShuffler shufflerForWeightedRouting;
final ShardId shardId;

final ShardRouting primary;
Expand Down Expand Up @@ -105,6 +106,7 @@ public class IndexShardRoutingTable implements Iterable<ShardRouting> {
IndexShardRoutingTable(ShardId shardId, List<ShardRouting> shards) {
this.shardId = shardId;
this.shuffler = new RotationShardShuffler(Randomness.get().nextInt());
this.shufflerForWeightedRouting = new RotationShardShuffler(Randomness.get().nextInt());
this.shards = Collections.unmodifiableList(shards);

ShardRouting primary = null;
Expand Down Expand Up @@ -323,11 +325,11 @@ public ShardIterator activeInitializingShardsWeightedIt(
double defaultWeight,
boolean isFailOpenEnabled
) {
final int seed = shuffler.nextSeed();
final int seed = shufflerForWeightedRouting.nextSeed();
List<ShardRouting> ordered = new ArrayList<>();
List<ShardRouting> orderedActiveShards = getActiveShardsByWeight(weightedRouting, nodes, defaultWeight);
List<ShardRouting> orderedListWithDistinctShards;
ordered.addAll(shuffler.shuffle(orderedActiveShards, seed));
ordered.addAll(shufflerForWeightedRouting.shuffle(orderedActiveShards, seed));
if (!allInitializingShards.isEmpty()) {
List<ShardRouting> orderedInitializingShards = getInitializingShardsByWeight(weightedRouting, nodes, defaultWeight);
ordered.addAll(orderedInitializingShards);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,12 @@
import org.opensearch.test.ClusterServiceUtils;
import org.opensearch.threadpool.TestThreadPool;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import static java.util.Collections.singletonMap;
Expand Down Expand Up @@ -768,4 +770,94 @@ public void testWeightedRoutingShardState() {
terminate(threadPool);
}
}

/**
* Test to validate that shard routing state is maintained across requests, requests are assigned to nodes
* according to assigned routing weights
*/
public void testWeightedRoutingShardStateWithDifferentWeights() {
TestThreadPool threadPool = null;
try {
Settings.Builder settings = Settings.builder()
.put("cluster.routing.allocation.node_concurrent_recoveries", 10)
.put("cluster.routing.allocation.awareness.attributes", "zone");
AllocationService strategy = createAllocationService(settings.build());

Metadata metadata = Metadata.builder()
.put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(2))
.build();

RoutingTable routingTable = RoutingTable.builder().addAsNew(metadata.index("test")).build();

ClusterState clusterState = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY))
.metadata(metadata)
.routingTable(routingTable)
.build();

threadPool = new TestThreadPool("testThatOnlyNodesSupport");
ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool);

Map<String, String> node1Attributes = new HashMap<>();
node1Attributes.put("zone", "zone1");
Map<String, String> node2Attributes = new HashMap<>();
node2Attributes.put("zone", "zone2");
Map<String, String> node3Attributes = new HashMap<>();
node3Attributes.put("zone", "zone3");
clusterState = ClusterState.builder(clusterState)
.nodes(
DiscoveryNodes.builder()
.add(newNode("node1", unmodifiableMap(node1Attributes)))
.add(newNode("node2", unmodifiableMap(node2Attributes)))
.add(newNode("node3", unmodifiableMap(node3Attributes)))
.localNodeId("node1")
)
.build();
clusterState = strategy.reroute(clusterState, "reroute");

clusterState = startInitializingShardsAndReroute(strategy, clusterState);
clusterState = startInitializingShardsAndReroute(strategy, clusterState);
List<Map<String, Double>> weightsList = new ArrayList<>();
Map<String, Double> weights1 = Map.of("zone1", 1.0, "zone2", 1.0, "zone3", 0.0);
weightsList.add(weights1);

Map<String, Double> weights2 = Map.of("zone1", 1.0, "zone2", 0.0, "zone3", 1.0);
weightsList.add(weights2);

Map<String, Double> weights3 = Map.of("zone1", 0.0, "zone2", 1.0, "zone3", 1.0);
weightsList.add(weights3);

Map<String, Double> weights4 = Map.of("zone1", 1.0, "zone2", 1.0, "zone3", 0.0);
weightsList.add(weights4);

for (int i = 0; i < weightsList.size(); i++) {
WeightedRouting weightedRouting = new WeightedRouting("zone", weightsList.get(i));
ShardIterator shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true);

ShardRouting shardRouting1 = shardIterator.nextOrNull();

shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true);

ShardRouting shardRouting2 = shardIterator.nextOrNull();

shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true);

ShardRouting shardRouting3 = shardIterator.nextOrNull();

assertEquals(shardRouting1.currentNodeId(), shardRouting3.currentNodeId());
assertNotEquals(shardRouting1.currentNodeId(), shardRouting2.currentNodeId());
}

} finally {
terminate(threadPool);
}
}
}

0 comments on commit 2007e8b

Please sign in to comment.