diff --git a/atomix/cluster/src/main/java/io/atomix/raft/impl/RaftContext.java b/atomix/cluster/src/main/java/io/atomix/raft/impl/RaftContext.java index 07e6bf87140d..c5e361b61389 100644 --- a/atomix/cluster/src/main/java/io/atomix/raft/impl/RaftContext.java +++ b/atomix/cluster/src/main/java/io/atomix/raft/impl/RaftContext.java @@ -469,12 +469,15 @@ public long setCommitIndex(final long commitIndex) { */ public void addSnapshotReplicationListener( final SnapshotReplicationListener snapshotReplicationListener) { - snapshotReplicationListeners.add(snapshotReplicationListener); - if (ongoingSnapshotReplication) { - // Notify listener immediately if it registered during an ongoing replication. - // This is to prevent missing necessary state transitions. - snapshotReplicationListener.onSnapshotReplicationStarted(); - } + threadContext.execute( + () -> { + snapshotReplicationListeners.add(snapshotReplicationListener); + if (ongoingSnapshotReplication) { + // Notify listener immediately if it registered during an ongoing replication. + // This is to prevent missing necessary state transitions. + snapshotReplicationListener.onSnapshotReplicationStarted(); + } + }); } /** @@ -484,17 +487,24 @@ public void addSnapshotReplicationListener( */ public void removeSnapshotReplicationListener( final SnapshotReplicationListener snapshotReplicationListener) { - snapshotReplicationListeners.remove(snapshotReplicationListener); + threadContext.execute(() -> snapshotReplicationListeners.remove(snapshotReplicationListener)); } public void notifySnapshotReplicationStarted() { - ongoingSnapshotReplication = true; - snapshotReplicationListeners.forEach(SnapshotReplicationListener::onSnapshotReplicationStarted); + threadContext.execute( + () -> { + ongoingSnapshotReplication = true; + snapshotReplicationListeners.forEach( + SnapshotReplicationListener::onSnapshotReplicationStarted); + }); } public void notifySnapshotReplicationCompleted() { - snapshotReplicationListeners.forEach(l -> l.onSnapshotReplicationCompleted(term)); - ongoingSnapshotReplication = false; + threadContext.execute( + () -> { + snapshotReplicationListeners.forEach(l -> l.onSnapshotReplicationCompleted(term)); + ongoingSnapshotReplication = false; + }); } /** diff --git a/atomix/cluster/src/test/java/io/atomix/raft/SnapshotReplicationListenerTest.java b/atomix/cluster/src/test/java/io/atomix/raft/SnapshotReplicationListenerTest.java index da5a90a06789..8c0d6afe36a6 100644 --- a/atomix/cluster/src/test/java/io/atomix/raft/SnapshotReplicationListenerTest.java +++ b/atomix/cluster/src/test/java/io/atomix/raft/SnapshotReplicationListenerTest.java @@ -17,7 +17,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; @@ -68,8 +67,7 @@ public void shouldNotifyOnRegisteringListener() { final var follower = raftRule.getFollower().orElseThrow(); // then follower.getContext().notifySnapshotReplicationStarted(); - verify(snapshotReplicationListener, never()).onSnapshotReplicationStarted(); follower.getContext().addSnapshotReplicationListener(snapshotReplicationListener); - verify(snapshotReplicationListener).onSnapshotReplicationStarted(); + verify(snapshotReplicationListener, timeout(1_000).times(1)).onSnapshotReplicationStarted(); } }