Skip to content

Commit

Permalink
Optimize checksum creation for remote cluster state (opensearch-proje…
Browse files Browse the repository at this point in the history
…ct#16046)

* Support parallelisation in remote publication checksum computation

Signed-off-by: Himshikha Gupta <himshikh@amazon.com>
  • Loading branch information
himshikha authored and dk2k committed Oct 16, 2024
1 parent 421389b commit 561622d
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.metadata.DiffableStringMap;
import org.opensearch.common.CheckedFunction;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.common.io.stream.BufferedChecksumStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -22,11 +24,15 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParseException;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;

import com.jcraft.jzlib.JZlib;

Expand All @@ -37,6 +43,7 @@
*/
public class ClusterStateChecksum implements ToXContentFragment, Writeable {

public static final int COMPONENT_SIZE = 11;
static final String ROUTING_TABLE_CS = "routing_table";
static final String NODES_CS = "discovery_nodes";
static final String BLOCKS_CS = "blocks";
Expand Down Expand Up @@ -65,62 +72,103 @@ public class ClusterStateChecksum implements ToXContentFragment, Writeable {
long indicesChecksum;
long clusterStateChecksum;

public ClusterStateChecksum(ClusterState clusterState) {
try (
BytesStreamOutput out = new BytesStreamOutput();
BufferedChecksumStreamOutput checksumOut = new BufferedChecksumStreamOutput(out)
) {
clusterState.routingTable().writeVerifiableTo(checksumOut);
routingTableChecksum = checksumOut.getChecksum();

checksumOut.reset();
clusterState.nodes().writeVerifiableTo(checksumOut);
nodesChecksum = checksumOut.getChecksum();

checksumOut.reset();
clusterState.coordinationMetadata().writeVerifiableTo(checksumOut);
coordinationMetadataChecksum = checksumOut.getChecksum();

// Settings create sortedMap by default, so no explicit sorting required here.
checksumOut.reset();
Settings.writeSettingsToStream(clusterState.metadata().persistentSettings(), checksumOut);
settingMetadataChecksum = checksumOut.getChecksum();

checksumOut.reset();
Settings.writeSettingsToStream(clusterState.metadata().transientSettings(), checksumOut);
transientSettingsMetadataChecksum = checksumOut.getChecksum();

checksumOut.reset();
clusterState.metadata().templatesMetadata().writeVerifiableTo(checksumOut);
templatesMetadataChecksum = checksumOut.getChecksum();

checksumOut.reset();
checksumOut.writeStringCollection(clusterState.metadata().customs().keySet());
customMetadataMapChecksum = checksumOut.getChecksum();

checksumOut.reset();
((DiffableStringMap) clusterState.metadata().hashesOfConsistentSettings()).writeTo(checksumOut);
hashesOfConsistentSettingsChecksum = checksumOut.getChecksum();

checksumOut.reset();
checksumOut.writeMapValues(
public ClusterStateChecksum(ClusterState clusterState, ThreadPool threadpool) {
long start = threadpool.relativeTimeInNanos();
ExecutorService executorService = threadpool.executor(ThreadPool.Names.REMOTE_STATE_CHECKSUM);
CountDownLatch latch = new CountDownLatch(COMPONENT_SIZE);

executeChecksumTask((stream) -> {
clusterState.routingTable().writeVerifiableTo(stream);
return null;
}, checksum -> routingTableChecksum = checksum, executorService, latch);

executeChecksumTask((stream) -> {
clusterState.nodes().writeVerifiableTo(stream);
return null;
}, checksum -> nodesChecksum = checksum, executorService, latch);

executeChecksumTask((stream) -> {
clusterState.coordinationMetadata().writeVerifiableTo(stream);
return null;
}, checksum -> coordinationMetadataChecksum = checksum, executorService, latch);

executeChecksumTask((stream) -> {
Settings.writeSettingsToStream(clusterState.metadata().persistentSettings(), stream);
return null;
}, checksum -> settingMetadataChecksum = checksum, executorService, latch);

executeChecksumTask((stream) -> {
Settings.writeSettingsToStream(clusterState.metadata().transientSettings(), stream);
return null;
}, checksum -> transientSettingsMetadataChecksum = checksum, executorService, latch);

executeChecksumTask((stream) -> {
clusterState.metadata().templatesMetadata().writeVerifiableTo(stream);
return null;
}, checksum -> templatesMetadataChecksum = checksum, executorService, latch);

executeChecksumTask((stream) -> {
stream.writeStringCollection(clusterState.metadata().customs().keySet());
return null;
}, checksum -> customMetadataMapChecksum = checksum, executorService, latch);

executeChecksumTask((stream) -> {
((DiffableStringMap) clusterState.metadata().hashesOfConsistentSettings()).writeTo(stream);
return null;
}, checksum -> hashesOfConsistentSettingsChecksum = checksum, executorService, latch);

executeChecksumTask((stream) -> {
stream.writeMapValues(
clusterState.metadata().indices(),
(stream, value) -> value.writeVerifiableTo((BufferedChecksumStreamOutput) stream)
(checksumStream, value) -> value.writeVerifiableTo((BufferedChecksumStreamOutput) checksumStream)
);
indicesChecksum = checksumOut.getChecksum();

checksumOut.reset();
clusterState.blocks().writeVerifiableTo(checksumOut);
blocksChecksum = checksumOut.getChecksum();

checksumOut.reset();
checksumOut.writeStringCollection(clusterState.customs().keySet());
clusterStateCustomsChecksum = checksumOut.getChecksum();
} catch (IOException e) {
logger.error("Failed to create checksum for cluster state.", e);
return null;
}, checksum -> indicesChecksum = checksum, executorService, latch);

executeChecksumTask((stream) -> {
clusterState.blocks().writeVerifiableTo(stream);
return null;
}, checksum -> blocksChecksum = checksum, executorService, latch);

executeChecksumTask((stream) -> {
stream.writeStringCollection(clusterState.customs().keySet());
return null;
}, checksum -> clusterStateCustomsChecksum = checksum, executorService, latch);

try {
latch.await();
} catch (InterruptedException e) {
throw new RemoteStateTransferException("Failed to create checksum for cluster state.", e);
}
createClusterStateChecksum();
logger.debug("Checksum execution time {}", TimeValue.nsecToMSec(threadpool.relativeTimeInNanos() - start));
}

private void executeChecksumTask(
CheckedFunction<BufferedChecksumStreamOutput, Void, IOException> checksumTask,
Consumer<Long> checksumConsumer,
ExecutorService executorService,
CountDownLatch latch
) {
executorService.execute(() -> {
try {
long checksum = createChecksum(checksumTask);
checksumConsumer.accept(checksum);
latch.countDown();
} catch (IOException e) {
throw new RemoteStateTransferException("Failed to execute checksum task", e);
}
});
}

private long createChecksum(CheckedFunction<BufferedChecksumStreamOutput, Void, IOException> task) throws IOException {
try (
BytesStreamOutput out = new BytesStreamOutput();
BufferedChecksumStreamOutput checksumOut = new BufferedChecksumStreamOutput(out)
) {
task.apply(checksumOut);
return checksumOut.getChecksum();
}
}

private void createClusterStateChecksum() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,9 @@ public RemoteClusterStateManifestInfo writeFullMetadata(ClusterState clusterStat
uploadedMetadataResults,
previousClusterUUID,
clusterStateDiffManifest,
!remoteClusterStateValidationMode.equals(RemoteClusterStateValidationMode.NONE) ? new ClusterStateChecksum(clusterState) : null,
!remoteClusterStateValidationMode.equals(RemoteClusterStateValidationMode.NONE)
? new ClusterStateChecksum(clusterState, threadpool)
: null,
false,
codecVersion
);
Expand Down Expand Up @@ -539,7 +541,9 @@ public RemoteClusterStateManifestInfo writeIncrementalMetadata(
uploadedMetadataResults,
previousManifest.getPreviousClusterUUID(),
clusterStateDiffManifest,
!remoteClusterStateValidationMode.equals(RemoteClusterStateValidationMode.NONE) ? new ClusterStateChecksum(clusterState) : null,
!remoteClusterStateValidationMode.equals(RemoteClusterStateValidationMode.NONE)
? new ClusterStateChecksum(clusterState, threadpool)
: null,
false,
previousManifest.getCodecVersion()
);
Expand Down Expand Up @@ -1010,7 +1014,9 @@ public RemoteClusterStateManifestInfo markLastStateAsCommitted(
uploadedMetadataResults,
previousManifest.getPreviousClusterUUID(),
previousManifest.getDiffManifest(),
!remoteClusterStateValidationMode.equals(RemoteClusterStateValidationMode.NONE) ? new ClusterStateChecksum(clusterState) : null,
!remoteClusterStateValidationMode.equals(RemoteClusterStateValidationMode.NONE)
? new ClusterStateChecksum(clusterState, threadpool)
: null,
true,
previousManifest.getCodecVersion()
);
Expand Down Expand Up @@ -1631,7 +1637,7 @@ void validateClusterStateFromChecksum(
String localNodeId,
boolean isFullStateDownload
) {
ClusterStateChecksum newClusterStateChecksum = new ClusterStateChecksum(clusterState);
ClusterStateChecksum newClusterStateChecksum = new ClusterStateChecksum(clusterState, threadpool);
List<String> failedValidation = newClusterStateChecksum.getMismatchEntities(manifest.getClusterStateChecksum());
if (failedValidation.isEmpty()) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.opensearch.core.service.ReportingService;
import org.opensearch.core.xcontent.ToXContentFragment;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.gateway.remote.ClusterStateChecksum;
import org.opensearch.node.Node;

import java.io.IOException;
Expand Down Expand Up @@ -118,6 +119,7 @@ public static class Names {
public static final String REMOTE_RECOVERY = "remote_recovery";
public static final String REMOTE_STATE_READ = "remote_state_read";
public static final String INDEX_SEARCHER = "index_searcher";
public static final String REMOTE_STATE_CHECKSUM = "remote_state_checksum";
}

/**
Expand Down Expand Up @@ -191,6 +193,7 @@ public static ThreadPoolType fromType(String type) {
map.put(Names.REMOTE_RECOVERY, ThreadPoolType.SCALING);
map.put(Names.REMOTE_STATE_READ, ThreadPoolType.SCALING);
map.put(Names.INDEX_SEARCHER, ThreadPoolType.RESIZABLE);
map.put(Names.REMOTE_STATE_CHECKSUM, ThreadPoolType.FIXED);
THREAD_POOL_TYPES = Collections.unmodifiableMap(map);
}

Expand Down Expand Up @@ -307,6 +310,10 @@ public ThreadPool(
runnableTaskListener
)
);
builders.put(
Names.REMOTE_STATE_CHECKSUM,
new FixedExecutorBuilder(settings, Names.REMOTE_STATE_CHECKSUM, ClusterStateChecksum.COMPONENT_SIZE, 1000)
);

for (final ExecutorBuilder<?> builder : customBuilders) {
if (builders.containsKey(builder.name())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
import org.opensearch.gateway.remote.ClusterMetadataManifest.UploadedMetadataAttribute;
import org.opensearch.test.EqualsHashCodeTestUtils;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.junit.After;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -64,6 +67,14 @@

public class ClusterMetadataManifestTests extends OpenSearchTestCase {

private final ThreadPool threadPool = new TestThreadPool(getClass().getName());

@After
public void teardown() throws Exception {
super.tearDown();
threadPool.shutdown();
}

public void testClusterMetadataManifestXContentV0() throws IOException {
UploadedIndexMetadata uploadedIndexMetadata = new UploadedIndexMetadata("test-index", "test-uuid", "/test/upload/path", CODEC_V0);
ClusterMetadataManifest originalManifest = ClusterMetadataManifest.builder()
Expand Down Expand Up @@ -214,7 +225,7 @@ public void testClusterMetadataManifestSerializationEqualsHashCode() {
"indicesRoutingDiffPath"
)
)
.checksum(new ClusterStateChecksum(createClusterState()))
.checksum(new ClusterStateChecksum(createClusterState(), threadPool))
.build();
{ // Mutate Cluster Term
EqualsHashCodeTestUtils.checkEqualsAndHashCode(
Expand Down Expand Up @@ -647,7 +658,7 @@ public void testClusterMetadataManifestXContentV4() throws IOException {
UploadedIndexMetadata uploadedIndexMetadata = new UploadedIndexMetadata("test-index", "test-uuid", "/test/upload/path");
UploadedMetadataAttribute uploadedMetadataAttribute = new UploadedMetadataAttribute("attribute_name", "testing_attribute");
final StringKeyDiffProvider<IndexRoutingTable> routingTableIncrementalDiff = Mockito.mock(StringKeyDiffProvider.class);
ClusterStateChecksum checksum = new ClusterStateChecksum(createClusterState());
ClusterStateChecksum checksum = new ClusterStateChecksum(createClusterState(), threadPool);
ClusterMetadataManifest originalManifest = ClusterMetadataManifest.builder()
.clusterTerm(1L)
.stateVersion(1L)
Expand Down
Loading

0 comments on commit 561622d

Please sign in to comment.