Cache index shard limit to optimise ShardLimitsAllocationDecider (opensearch-project#14962)

* Cache index shard limit per node

Signed-off-by: Rishab Nahata <rnnahata@amazon.com>
imRishN authored Jul 29, 2024
1 parent 95fe9cb commit 122f3f0
Showing 3 changed files with 150 additions and 5 deletions.
RerouteBenchmark.java (new file)
@@ -0,0 +1,135 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.benchmark.routing.allocation;

import org.opensearch.Version;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.metadata.Metadata;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.routing.RoutingTable;
import org.opensearch.cluster.routing.ShardRouting;
import org.opensearch.cluster.routing.allocation.AllocationService;
import org.opensearch.common.logging.LogConfigurator;
import org.opensearch.common.settings.Settings;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.opensearch.cluster.routing.ShardRoutingState.INITIALIZING;

@Fork(1)
@Warmup(iterations = 3)
@Measurement(iterations = 3)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@SuppressWarnings("unused") // invoked by benchmarking framework
public class RerouteBenchmark {
    @Param({
        // indices| nodes
        " 10000| 500|", })
    public String indicesNodes = "1|1";
    public int numIndices;
    public int numNodes;
    public int numShards = 10;
    public int numReplicas = 1;

    private AllocationService allocationService;
    private ClusterState initialClusterState;

    @Setup
    public void setUp() throws Exception {
        LogConfigurator.setNodeName("test");
        final String[] params = indicesNodes.split("\\|");
        numIndices = toInt(params[0]);
        numNodes = toInt(params[1]);

        int totalShardCount = (numReplicas + 1) * numShards * numIndices;
        Metadata.Builder mb = Metadata.builder();
        for (int i = 1; i <= numIndices; i++) {
            mb.put(
                IndexMetadata.builder("test_" + i)
                    .settings(Settings.builder().put("index.version.created", Version.CURRENT))
                    .numberOfShards(numShards)
                    .numberOfReplicas(numReplicas)
            );
        }

        Metadata metadata = mb.build();
        RoutingTable.Builder rb = RoutingTable.builder();
        for (int i = 1; i <= numIndices; i++) {
            rb.addAsNew(metadata.index("test_" + i));
        }
        RoutingTable routingTable = rb.build();
        initialClusterState = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY))
            .metadata(metadata)
            .routingTable(routingTable)
            .nodes(setUpClusterNodes(numNodes))
            .build();
    }

    @Benchmark
    public ClusterState measureShardAllocationEmptyCluster() throws Exception {
        ClusterState clusterState = initialClusterState;
        allocationService = Allocators.createAllocationService(
            Settings.builder()
                .put("cluster.routing.allocation.awareness.attributes", "zone")
                .put("cluster.routing.allocation.load_awareness.provisioned_capacity", numNodes)
                .put("cluster.routing.allocation.load_awareness.skew_factor", "50")
                .put("cluster.routing.allocation.node_concurrent_recoveries", "2")
                .build()
        );
        clusterState = allocationService.reroute(clusterState, "reroute");
        while (clusterState.getRoutingNodes().hasUnassignedShards()) {
            clusterState = startInitializingShardsAndReroute(allocationService, clusterState);
        }
        return clusterState;
    }

    private int toInt(String v) {
        return Integer.valueOf(v.trim());
    }

    private DiscoveryNodes.Builder setUpClusterNodes(int nodes) {
        DiscoveryNodes.Builder nb = DiscoveryNodes.builder();
        for (int i = 1; i <= nodes; i++) {
            Map<String, String> attributes = new HashMap<>();
            attributes.put("zone", "zone_" + (i % 3));
            nb.add(Allocators.newNode("node_0_" + i, attributes));
        }
        return nb;
    }

    private static ClusterState startInitializingShardsAndReroute(AllocationService allocationService, ClusterState clusterState) {
        return startShardsAndReroute(allocationService, clusterState, clusterState.routingTable().shardsWithState(INITIALIZING));
    }

    private static ClusterState startShardsAndReroute(
        AllocationService allocationService,
        ClusterState clusterState,
        List<ShardRouting> initializingShards
    ) {
        return allocationService.reroute(allocationService.applyStartedShards(clusterState, initializingShards), "reroute after starting");
    }
}
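
The benchmark above builds a cluster state with the configured number of indices and nodes, then measures how long repeated reroute passes take until every shard is assigned. As a minimal sketch, it could be launched through JMH's programmatic Runner API as shown below; the launcher class name and the parameter override are illustrative assumptions, since the OpenSearch benchmarks module is normally driven by its Gradle setup rather than a hand-written main method.

import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

public class RerouteBenchmarkLauncher {
    public static void main(String[] args) throws RunnerException {
        // Hypothetical launcher: select the benchmark above and override its
        // indices|nodes parameter (10,000 indices across 500 nodes).
        Options options = new OptionsBuilder()
            .include(".*RerouteBenchmark.*")
            .param("indicesNodes", "10000| 500|")
            .build();
        new Runner(options).run();
    }
}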

IndexMetadata.java

@@ -43,6 +43,7 @@
import org.opensearch.cluster.block.ClusterBlockLevel;
import org.opensearch.cluster.node.DiscoveryNodeFilters;
import org.opensearch.cluster.routing.allocation.IndexMetadataUpdater;
+import org.opensearch.cluster.routing.allocation.decider.ShardsLimitAllocationDecider;
import org.opensearch.common.Nullable;
import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.collect.MapBuilder;

@@ -686,6 +687,8 @@ public static APIBlock readFrom(StreamInput input) throws IOException {
    private final boolean isSystem;
    private final boolean isRemoteSnapshot;

+    private final int indexTotalShardsPerNodeLimit;
+
    private IndexMetadata(
        final Index index,
        final long version,

@@ -711,7 +714,8 @@ private IndexMetadata(
        final int routingPartitionSize,
        final ActiveShardCount waitForActiveShards,
        final Map<String, RolloverInfo> rolloverInfos,
-        final boolean isSystem
+        final boolean isSystem,
+        final int indexTotalShardsPerNodeLimit
    ) {

        this.index = index;

@@ -746,6 +750,7 @@ private IndexMetadata(
        this.rolloverInfos = Collections.unmodifiableMap(rolloverInfos);
        this.isSystem = isSystem;
        this.isRemoteSnapshot = IndexModule.Type.REMOTE_SNAPSHOT.match(this.settings);
+        this.indexTotalShardsPerNodeLimit = indexTotalShardsPerNodeLimit;
        assert numberOfShards * routingFactor == routingNumShards : routingNumShards + " must be a multiple of " + numberOfShards;
    }

@@ -899,6 +904,10 @@ public Set<String> inSyncAllocationIds(int shardId) {
        return inSyncAllocationIds.get(shardId);
    }

+    public int getIndexTotalShardsPerNodeLimit() {
+        return this.indexTotalShardsPerNodeLimit;
+    }
+
    @Nullable
    public DiscoveryNodeFilters requireFilters() {
        return requireFilters;

@@ -1583,6 +1592,8 @@ public IndexMetadata build() {
            );
        }

+        final int indexTotalShardsPerNodeLimit = ShardsLimitAllocationDecider.INDEX_TOTAL_SHARDS_PER_NODE_SETTING.get(settings);
+
        final String uuid = settings.get(SETTING_INDEX_UUID, INDEX_UUID_NA_VALUE);

        return new IndexMetadata(

@@ -1610,7 +1621,8 @@ public IndexMetadata build() {
            routingPartitionSize,
            waitForActiveShards,
            rolloverInfos,
-            isSystem
+            isSystem,
+            indexTotalShardsPerNodeLimit
        );
    }
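
To make the effect of the IndexMetadata change concrete, the sketch below contrasts the two lookup paths. The helper class and method names are hypothetical; only getIndexTotalShardsPerNodeLimit() and INDEX_TOTAL_SHARDS_PER_NODE_SETTING come from the diff. Both paths should return the same value for a given index, but the cached path is a plain field read instead of a settings parse.

import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.routing.allocation.decider.ShardsLimitAllocationDecider;

final class ShardLimitLookup {

    // Old path: parse index.routing.allocation.total_shards_per_node from the
    // index settings on every call.
    static int limitFromSettings(IndexMetadata indexMetadata) {
        return ShardsLimitAllocationDecider.INDEX_TOTAL_SHARDS_PER_NODE_SETTING.get(indexMetadata.getSettings());
    }

    // New path: read the value that IndexMetadata.Builder#build() resolved once
    // and cached on the IndexMetadata instance.
    static int limitFromCache(IndexMetadata indexMetadata) {
        return indexMetadata.getIndexTotalShardsPerNodeLimit();
    }
}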

ShardsLimitAllocationDecider.java

@@ -32,7 +32,6 @@

package org.opensearch.cluster.routing.allocation.decider;

-import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.routing.RoutingNode;
import org.opensearch.cluster.routing.ShardRouting;
import org.opensearch.cluster.routing.ShardRoutingState;

@@ -125,8 +124,7 @@ private Decision doDecide(
        RoutingAllocation allocation,
        BiPredicate<Integer, Integer> decider
    ) {
-        IndexMetadata indexMd = allocation.metadata().getIndexSafe(shardRouting.index());
-        final int indexShardLimit = INDEX_TOTAL_SHARDS_PER_NODE_SETTING.get(indexMd.getSettings(), settings);
+        final int indexShardLimit = allocation.metadata().getIndexSafe(shardRouting.index()).getIndexTotalShardsPerNodeLimit();
        // Capture the limit here in case it changes during this method's
        // execution
        final int clusterShardLimit = this.clusterShardLimit;
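
The decider change matters because doDecide() can run for every shard copy against every candidate node during a reroute, and each of those calls previously re-parsed the per-index setting. The estimate below is a rough illustration using the benchmark's parameters; the class is hypothetical, and real call counts are lower when other deciders short-circuit.

public final class DeciderCallEstimate {
    public static void main(String[] args) {
        long indices = 10_000;        // benchmark @Param: indices
        long shardsPerIndex = 10;     // numShards
        long copiesPerShard = 2;      // 1 primary + 1 replica (numReplicas = 1)
        long nodes = 500;             // benchmark @Param: nodes

        long shardCopies = indices * shardsPerIndex * copiesPerShard;   // 200,000 shard copies to place
        long deciderChecks = shardCopies * nodes;                       // up to 100,000,000 per-node checks per reroute pass

        System.out.println(shardCopies + " shard copies, up to " + deciderChecks + " shard-limit lookups per pass");
    }
}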

0 comments on commit 122f3f0
