Code cleanup.
Signed-off-by: Marc Handalian <handalm@amazon.com>
mch2 committed May 26, 2022
1 parent 16cafd1 commit 4c3e74d
Showing 7 changed files with 49 additions and 56 deletions.
@@ -525,7 +525,7 @@ class FileChunkTransportRequestHandler implements TransportRequestHandler<FileCh
final AtomicLong bytesSinceLastPause = new AtomicLong();

@Override
public void messageReceived(FileChunkRequest request, TransportChannel channel, Task task) throws Exception {
public void messageReceived(final FileChunkRequest request, TransportChannel channel, Task task) throws Exception {
try (ReplicationRef<RecoveryTarget> recoveryRef = onGoingRecoveries.getSafe(request.recoveryId(), request.shardId())) {
final RecoveryTarget recoveryTarget = recoveryRef.get();
final ActionListener<Void> listener = recoveryTarget.createOrFinishListener(channel, Actions.FILE_CHUNK, request);
@@ -32,6 +32,8 @@

package org.opensearch.indices.recovery;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.store.RateLimiter;
import org.opensearch.OpenSearchException;
import org.opensearch.action.ActionListener;
@@ -61,6 +63,8 @@
*/
public class RemoteRecoveryTargetHandler implements RecoveryTargetHandler {

private static final Logger logger = LogManager.getLogger(RemoteRecoveryTargetHandler.class);

private final TransportService transportService;
private final long recoveryId;
private final ShardId shardId;
@@ -85,10 +89,14 @@ public RemoteRecoveryTargetHandler(
Consumer<Long> onSourceThrottle
) {
this.transportService = transportService;
// It is safe to pass the retry timeout value here because RemoteRecoveryTargetHandler
// is created per recovery. Any change to RecoverySettings will be applied on the next
// recovery.
this.retryableTransportClient = new RetryableTransportClient(
transportService,
targetNode,
recoverySettings.internalActionRetryTimeout()
recoverySettings.internalActionRetryTimeout(),
logger
);
this.recoveryId = recoveryId;
this.shardId = shardId;
@@ -8,7 +8,6 @@

package org.opensearch.indices.recovery;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.LegacyESVersion;
@@ -40,20 +39,21 @@
*/
public final class RetryableTransportClient {

private static final Logger logger = LogManager.getLogger(RetryableTransportClient.class);

private final ThreadPool threadPool;
private final Map<Object, RetryableAction<?>> onGoingRetryableActions = ConcurrentCollections.newConcurrentMap();
private volatile boolean isCancelled = false;
private final TransportService transportService;
private final TimeValue retryTimeout;
private final DiscoveryNode targetNode;

public RetryableTransportClient(TransportService transportService, DiscoveryNode targetNode, TimeValue retryTimeout) {
private final Logger logger;

public RetryableTransportClient(TransportService transportService, DiscoveryNode targetNode, TimeValue retryTimeout, Logger logger) {
this.threadPool = transportService.getThreadPool();
this.transportService = transportService;
this.retryTimeout = retryTimeout;
this.targetNode = targetNode;
this.logger = logger;
}

/**
@@ -105,7 +105,7 @@ public boolean shouldRetry(Exception e) {
onGoingRetryableActions.put(key, retryableAction);
retryableAction.run();
if (isCancelled) {
retryableAction.cancel(new CancellableThreads.ExecutionCancelledException("recovery was cancelled"));
retryableAction.cancel(new CancellableThreads.ExecutionCancelledException("retryable action was cancelled"));
}
}

@@ -114,7 +114,7 @@ public void cancel() {
if (onGoingRetryableActions.isEmpty()) {
return;
}
final RuntimeException exception = new CancellableThreads.ExecutionCancelledException("recovery was cancelled");
final RuntimeException exception = new CancellableThreads.ExecutionCancelledException("retryable action was cancelled");
// Dispatch to generic as cancellation calls can come on the cluster state applier thread
threadPool.generic().execute(() -> {
for (RetryableAction<?> action : onGoingRetryableActions.values()) {
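
For context on the two changes above: RetryableTransportClient no longer declares its own static logger and instead receives the owning component's Logger, so retry and cancellation messages are attributed to the caller. A minimal sketch of a caller after this change (ExampleCaller and buildClient are hypothetical; only the four-argument constructor and internalActionRetryTimeout() are taken from this diff):

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.indices.recovery.RecoverySettings;
import org.opensearch.indices.recovery.RetryableTransportClient;
import org.opensearch.transport.TransportService;

public class ExampleCaller {

    // Per-class logger handed to the client so its retry log lines name this component.
    private static final Logger logger = LogManager.getLogger(ExampleCaller.class);

    static RetryableTransportClient buildClient(
        TransportService transportService,
        DiscoveryNode targetNode,
        RecoverySettings recoverySettings
    ) {
        // The retry timeout is read once at construction; as the new comment in
        // RemoteRecoveryTargetHandler notes, a dynamic RecoverySettings change only
        // affects clients created for later recoveries.
        return new RetryableTransportClient(
            transportService,
            targetNode,
            recoverySettings.internalActionRetryTimeout(),
            logger
        );
    }
}
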
@@ -25,9 +25,9 @@ public class SegmentReplicationState implements ReplicationState {
* @opensearch.internal
*/
public enum Stage {
INIT((byte) 0),
DONE((byte) 0),

DONE((byte) 1);
INIT((byte) 1);

private static final Stage[] STAGES = new Stage[Stage.values().length];

@@ -120,6 +120,9 @@ default void onFailure(ReplicationState state, OpenSearchException e, boolean se
void onReplicationFailure(SegmentReplicationState state, OpenSearchException e, boolean sendShardFailure);
}

/**
* Runnable implementation to trigger a replication event.
*/
private class ReplicationRunner implements Runnable {

final long replicationId;
@@ -156,8 +159,7 @@ private class FileChunkTransportRequestHandle
final AtomicLong bytesSinceLastPause = new AtomicLong();

@Override
public void messageReceived(FileChunkRequest request, TransportChannel channel, Task task) throws Exception {
// How many bytes we've copied since we last called RateLimiter.pause
public void messageReceived(final FileChunkRequest request, TransportChannel channel, Task task) throws Exception {
try (ReplicationRef<SegmentReplicationTarget> ref = onGoingReplications.getSafe(request.recoveryId(), request.shardId())) {
final SegmentReplicationTarget target = ref.get();
final ActionListener<Void> listener = target.createOrFinishListener(channel, Actions.FILE_CHUNK, request);
@@ -15,6 +15,7 @@
import org.opensearch.action.ActionListener;
import org.opensearch.action.support.ChannelActionListener;
import org.opensearch.common.CheckedFunction;
import org.opensearch.common.Nullable;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.logging.Loggers;
import org.opensearch.common.util.CancellableThreads;
@@ -108,6 +109,7 @@ public void setLastAccessTime() {
lastAccessTime = System.nanoTime();
}

@Nullable
public ActionListener<Void> markRequestReceivedAndCreateListener(long requestSeqNo, ActionListener<Void> listener) {
return requestTracker.markReceivedAndCreateListener(requestSeqNo, listener);
}
@@ -182,6 +184,7 @@ protected void ensureRefCount() {
}
}

@Nullable
public ActionListener<Void> createOrFinishListener(
final TransportChannel channel,
final String action,
@@ -190,6 +193,7 @@ public ActionListener<Void> createOrFinishListener(
return createOrFinishListener(channel, action, request, nullVal -> TransportResponse.Empty.INSTANCE);
}

@Nullable
public ActionListener<Void> createOrFinishListener(
final TransportChannel channel,
final String action,
@@ -216,6 +220,7 @@ public ActionListener<Void> createOrFinishListener(
* @param request {@link FileChunkRequest} Request containing the file chunk.
* @param bytesSinceLastPause {@link AtomicLong} Bytes since the last pause.
* @param rateLimiter {@link RateLimiter} Rate limiter.
* @param listener {@link ActionListener} listener that completes when the chunk has been written.
* @throws IOException When there is an issue pausing the rate limiter.
*/
public void handleFileChunk(
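
The @Nullable annotations above document an existing contract rather than new behavior: these factory methods return null when the request's sequence number has already been processed, for example a retried file chunk arriving after the response went out. Call sites such as the file-chunk handlers shown earlier are expected to treat null as "already handled"; a rough sketch of that guard (not part of this diff; variable names follow the handler above):

final ActionListener<Void> listener = target.createOrFinishListener(channel, Actions.FILE_CHUNK, request);
if (listener == null) {
    // Null means a response was already sent for this sequence number, so drop the chunk.
    return;
}
// Otherwise hand the chunk to handleFileChunk(...), which completes the listener once written.
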
@@ -27,24 +27,34 @@

public class SegmentReplicationTargetServiceTests extends IndexShardTestCase {

public void testTargetReturnsSuccess_listenerCompletes() throws IOException {
Settings settings = Settings.builder().put("node.name", SegmentReplicationTargetServiceTests.class.getSimpleName()).build();
private IndexShard indexShard;
private ReplicationCheckpoint checkpoint;
private SegmentReplicationSource replicationSource;
private SegmentReplicationTargetService sut;

@Override
public void setUp() throws Exception {
super.setUp();
final Settings settings = Settings.builder().put("node.name", SegmentReplicationTargetServiceTests.class.getSimpleName()).build();
final ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
final RecoverySettings recoverySettings = new RecoverySettings(settings, clusterSettings);
final TransportService transportService = mock(TransportService.class);
final IndexShard indexShard = newShard(false, settings);
ReplicationCheckpoint checkpoint = new ReplicationCheckpoint(indexShard.shardId(), 0L, 0L, 0L, 0L);
indexShard = newShard(false, settings);
checkpoint = new ReplicationCheckpoint(indexShard.shardId(), 0L, 0L, 0L, 0L);
SegmentReplicationSourceFactory replicationSourceFactory = mock(SegmentReplicationSourceFactory.class);
final SegmentReplicationSource replicationSource = mock(SegmentReplicationSource.class);
replicationSource = mock(SegmentReplicationSource.class);
when(replicationSourceFactory.get(indexShard)).thenReturn(replicationSource);

SegmentReplicationTargetService sut = new SegmentReplicationTargetService(
threadPool,
recoverySettings,
transportService,
replicationSourceFactory
);
sut = new SegmentReplicationTargetService(threadPool, recoverySettings, transportService, replicationSourceFactory);
}

@Override
public void tearDown() throws Exception {
closeShards(indexShard);
super.tearDown();
}

public void testTargetReturnsSuccess_listenerCompletes() throws IOException {
final SegmentReplicationTarget target = new SegmentReplicationTarget(
checkpoint,
indexShard,
@@ -72,22 +82,6 @@ public void onReplicationFailure(SegmentReplicationState state, OpenSearchExcept
}

public void testTargetThrowsException() throws IOException {
Settings settings = Settings.builder().put("node.name", SegmentReplicationTargetServiceTests.class.getSimpleName()).build();
final ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
final RecoverySettings recoverySettings = new RecoverySettings(settings, clusterSettings);
final TransportService transportService = mock(TransportService.class);
final IndexShard indexShard = newShard(false, settings);
ReplicationCheckpoint checkpoint = new ReplicationCheckpoint(indexShard.shardId(), 0L, 0L, 0L, 0L);
SegmentReplicationSourceFactory replicationSourceFactory = mock(SegmentReplicationSourceFactory.class);
SegmentReplicationSource replicationSource = mock(SegmentReplicationSource.class);

SegmentReplicationTargetService sut = new SegmentReplicationTargetService(
threadPool,
recoverySettings,
transportService,
replicationSourceFactory
);

final OpenSearchException expectedError = new OpenSearchException("Fail");
final SegmentReplicationTarget target = new SegmentReplicationTarget(
checkpoint,
@@ -118,22 +112,6 @@ public void onReplicationFailure(SegmentReplicationState state, OpenSearchExcept
}

public void testBeforeIndexShardClosed_CancelsOngoingReplications() throws IOException {
Settings settings = Settings.builder().put("node.name", SegmentReplicationTargetServiceTests.class.getSimpleName()).build();
final ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
final RecoverySettings recoverySettings = new RecoverySettings(settings, clusterSettings);
final TransportService transportService = mock(TransportService.class);
final IndexShard indexShard = newShard(false, settings);
ReplicationCheckpoint checkpoint = new ReplicationCheckpoint(indexShard.shardId(), 0L, 0L, 0L, 0L);
SegmentReplicationSourceFactory replicationSourceFactory = mock(SegmentReplicationSourceFactory.class);
SegmentReplicationSource replicationSource = mock(SegmentReplicationSource.class);

SegmentReplicationTargetService sut = new SegmentReplicationTargetService(
threadPool,
recoverySettings,
transportService,
replicationSourceFactory
);

final SegmentReplicationTarget target = new SegmentReplicationTarget(
checkpoint,
indexShard,
@@ -142,7 +120,7 @@ public void testBeforeIndexShardClosed_CancelsOngoingReplications() throws IOExc
);
final SegmentReplicationTarget spy = Mockito.spy(target);
sut.startReplication(spy);
sut.beforeIndexShardClosed(indexShard.shardId(), indexShard, settings);
sut.beforeIndexShardClosed(indexShard.shardId(), indexShard, Settings.EMPTY);
Mockito.verify(spy, times(1)).cancel(any());
closeShards(indexShard);
}
