diff --git a/airbyte-workers/src/main/java/io/airbyte/workers/temporal/sync/PersistStateActivityImpl.java b/airbyte-workers/src/main/java/io/airbyte/workers/temporal/sync/PersistStateActivityImpl.java index 4b6dd00753b1..210174640733 100644 --- a/airbyte-workers/src/main/java/io/airbyte/workers/temporal/sync/PersistStateActivityImpl.java +++ b/airbyte-workers/src/main/java/io/airbyte/workers/temporal/sync/PersistStateActivityImpl.java @@ -57,14 +57,8 @@ public boolean persist(final UUID connectionId, final StandardSyncOutput syncOut AirbyteApiClient.retryWithJitter( () -> airbyteApiClient.getStateApi().getState(new ConnectionIdRequestBody().connectionId(connectionId)), "get state"); - if (featureFlags.needStateValidation() && previousState != null) { - final StateType newStateType = maybeStateWrapper.get().getStateType(); - final StateType prevStateType = convertClientStateTypeToInternal(previousState.getStateType()); - if (isMigration(newStateType, prevStateType) && newStateType == StateType.STREAM) { - validateStreamStates(maybeStateWrapper.get(), configuredCatalog); - } - } + validate(configuredCatalog, maybeStateWrapper, previousState); AirbyteApiClient.retryWithJitter( () -> { @@ -85,6 +79,42 @@ public boolean persist(final UUID connectionId, final StandardSyncOutput syncOut } } + /** + * Validates whether it is safe to persist the new state based on the previously saved state. + * + * @param configuredCatalog The configured catalog of streams for the connection. + * @param newState The new state; must be present, because it is unwrapped with {@code get()} whenever validation runs. + * @param previousState The previous state; may be {@code null} or empty, in which case validation is skipped. + */ + private void validate(final ConfiguredAirbyteCatalog configuredCatalog, + final Optional<StateWrapper> newState, + final ConnectionState previousState) { + /* + * If state validation is enabled and the previous state exists and is not empty, make sure that + * state will not be lost as part of the migration from legacy -> per stream. + * + * Otherwise, it is okay to update if the previous state is missing or empty. 
+ */ + if (featureFlags.needStateValidation() && !isStateEmpty(previousState)) { + final StateType newStateType = newState.get().getStateType(); + final StateType prevStateType = convertClientStateTypeToInternal(previousState.getStateType()); + + if (isMigration(newStateType, prevStateType) && newStateType == StateType.STREAM) { + validateStreamStates(newState.get(), configuredCatalog); + } + } + } + + /** + * Test whether the connection state is empty. + * + * @param connectionState The connection state. + * @return {@code true} if the connection state is null or empty, {@code false} otherwise. + */ + private boolean isStateEmpty(final ConnectionState connectionState) { + return connectionState == null || connectionState.getState() == null || connectionState.getState().isEmpty(); + } + @VisibleForTesting void validateStreamStates(final StateWrapper state, final ConfiguredAirbyteCatalog configuredCatalog) { final List stateStreamDescriptors = diff --git a/airbyte-workers/src/test/java/io/airbyte/workers/temporal/sync/PersistStateActivityTest.java b/airbyte-workers/src/test/java/io/airbyte/workers/temporal/sync/PersistStateActivityTest.java index f132fda23a7b..80ab6f067e07 100644 --- a/airbyte-workers/src/test/java/io/airbyte/workers/temporal/sync/PersistStateActivityTest.java +++ b/airbyte-workers/src/test/java/io/airbyte/workers/temporal/sync/PersistStateActivityTest.java @@ -4,13 +4,19 @@ package io.airbyte.workers.temporal.sync; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.JsonNode; import io.airbyte.api.client.AirbyteApiClient; import io.airbyte.api.client.generated.StateApi; import io.airbyte.api.client.invoker.generated.ApiException; +import io.airbyte.api.client.model.generated.ConnectionIdRequestBody; +import io.airbyte.api.client.model.generated.ConnectionState; import 
io.airbyte.api.client.model.generated.ConnectionStateCreateOrUpdate; +import io.airbyte.api.client.model.generated.ConnectionStateType; import io.airbyte.commons.features.FeatureFlags; import io.airbyte.commons.json.Jsons; import io.airbyte.config.StandardSyncOutput; @@ -42,6 +48,10 @@ class PersistStateActivityTest { private final static UUID CONNECTION_ID = UUID.randomUUID(); + private static final String STREAM_A = "a"; + private static final String STREAM_A_NAMESPACE = "a1"; + private static final String STREAM_B = "b"; + private static final String STREAM_C = "c"; @Mock AirbyteApiClient airbyteApiClient; @@ -78,7 +88,7 @@ void testPersistEmpty() { @Test void testPersist() throws ApiException { - Mockito.when(featureFlags.useStreamCapableState()).thenReturn(true); + when(featureFlags.useStreamCapableState()).thenReturn(true); final JsonNode jsonState = Jsons.jsonNode(Map.ofEntries( Map.entry("some", "state"))); @@ -88,7 +98,7 @@ void testPersist() throws ApiException { persistStateActivity.persist(CONNECTION_ID, new StandardSyncOutput().withState(state), new ConfiguredAirbyteCatalog()); // The ser/der of the state into a state wrapper is tested in StateMessageHelperTest - Mockito.verify(stateApi).createOrUpdateState(Mockito.any(ConnectionStateCreateOrUpdate.class)); + Mockito.verify(stateApi).createOrUpdateState(any(ConnectionStateCreateOrUpdate.class)); } // For per-stream state, we expect there to be state for each stream within the configured catalog @@ -97,8 +107,9 @@ void testPersist() throws ApiException { // catalog has a state message when migrating from Legacy to Per-Stream @Test void testPersistWithValidMissingStateDuringMigration() throws ApiException { - final ConfiguredAirbyteStream stream = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName("a").withNamespace("a1")); - final ConfiguredAirbyteStream stream2 = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName("b")); + final ConfiguredAirbyteStream stream = + 
new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_A).withNamespace(STREAM_A_NAMESPACE)); + final ConfiguredAirbyteStream stream2 = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_B)); final AirbyteStateMessage stateMessage1 = new AirbyteStateMessage() .withType(AirbyteStateType.STREAM) @@ -110,19 +121,20 @@ void testPersistWithValidMissingStateDuringMigration() throws ApiException { final ConfiguredAirbyteCatalog migrationConfiguredCatalog = new ConfiguredAirbyteCatalog().withStreams(List.of(stream, stream2)); final StandardSyncOutput syncOutput = new StandardSyncOutput().withState(state); - Mockito.when(featureFlags.useStreamCapableState()).thenReturn(true); + when(featureFlags.useStreamCapableState()).thenReturn(true); - mockedStateMessageHelper.when(() -> StateMessageHelper.isMigration(Mockito.eq(StateType.STREAM), Mockito.any(StateType.class))).thenReturn(true); + mockedStateMessageHelper.when(() -> StateMessageHelper.isMigration(Mockito.eq(StateType.STREAM), any(StateType.class))).thenReturn(true); persistStateActivity.persist(CONNECTION_ID, syncOutput, migrationConfiguredCatalog); - Mockito.verify(stateApi).createOrUpdateState(Mockito.any(ConnectionStateCreateOrUpdate.class)); + Mockito.verify(stateApi).createOrUpdateState(any(ConnectionStateCreateOrUpdate.class)); } @Test void testPersistWithValidStateDuringMigration() throws ApiException { - final ConfiguredAirbyteStream stream = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName("a").withNamespace("a1")); - final ConfiguredAirbyteStream stream2 = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName("b")); + final ConfiguredAirbyteStream stream = + new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_A).withNamespace(STREAM_A_NAMESPACE)); + final ConfiguredAirbyteStream stream2 = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_B)); final ConfiguredAirbyteStream 
stream3 = - new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName("c")).withSyncMode(SyncMode.FULL_REFRESH); + new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_C)).withSyncMode(SyncMode.FULL_REFRESH); final AirbyteStateMessage stateMessage1 = new AirbyteStateMessage() .withType(AirbyteStateType.STREAM) @@ -138,17 +150,18 @@ void testPersistWithValidStateDuringMigration() throws ApiException { final ConfiguredAirbyteCatalog migrationConfiguredCatalog = new ConfiguredAirbyteCatalog().withStreams(List.of(stream, stream2, stream3)); final StandardSyncOutput syncOutput = new StandardSyncOutput().withState(state); - Mockito.when(featureFlags.useStreamCapableState()).thenReturn(true); - mockedStateMessageHelper.when(() -> StateMessageHelper.isMigration(Mockito.eq(StateType.STREAM), Mockito.any(StateType.class))).thenReturn(true); + when(featureFlags.useStreamCapableState()).thenReturn(true); + mockedStateMessageHelper.when(() -> StateMessageHelper.isMigration(Mockito.eq(StateType.STREAM), any(StateType.class))).thenReturn(true); persistStateActivity.persist(CONNECTION_ID, syncOutput, migrationConfiguredCatalog); - Mockito.verify(stateApi).createOrUpdateState(Mockito.any(ConnectionStateCreateOrUpdate.class)); + Mockito.verify(stateApi).createOrUpdateState(any(ConnectionStateCreateOrUpdate.class)); } // Global stream states do not need to be validated during the migration to per-stream state @Test void testPersistWithGlobalStateDuringMigration() throws ApiException { - final ConfiguredAirbyteStream stream = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName("a").withNamespace("a1")); - final ConfiguredAirbyteStream stream2 = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName("b")); + final ConfiguredAirbyteStream stream = + new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_A).withNamespace(STREAM_A_NAMESPACE)); + final ConfiguredAirbyteStream stream2 = new 
ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_B)); final AirbyteStateMessage stateMessage = new AirbyteStateMessage().withType(AirbyteStateType.GLOBAL); final JsonNode jsonState = Jsons.jsonNode(List.of(stateMessage)); @@ -156,12 +169,130 @@ void testPersistWithGlobalStateDuringMigration() throws ApiException { final ConfiguredAirbyteCatalog migrationConfiguredCatalog = new ConfiguredAirbyteCatalog().withStreams(List.of(stream, stream2)); final StandardSyncOutput syncOutput = new StandardSyncOutput().withState(state); - Mockito.when(featureFlags.useStreamCapableState()).thenReturn(true); - mockedStateMessageHelper.when(() -> StateMessageHelper.isMigration(Mockito.eq(StateType.GLOBAL), Mockito.any(StateType.class))).thenReturn(true); + when(featureFlags.useStreamCapableState()).thenReturn(true); + mockedStateMessageHelper.when(() -> StateMessageHelper.isMigration(Mockito.eq(StateType.GLOBAL), any(StateType.class))).thenReturn(true); - persistStateActivity.persist(CONNECTION_ID, syncOutput, migrationConfiguredCatalog); final PersistStateActivityImpl persistStateSpy = spy(persistStateActivity); + persistStateSpy.persist(CONNECTION_ID, syncOutput, migrationConfiguredCatalog); - Mockito.verify(persistStateSpy, Mockito.times(0)).validateStreamStates(Mockito.any(), Mockito.any()); - Mockito.verify(stateApi).createOrUpdateState(Mockito.any(ConnectionStateCreateOrUpdate.class)); + Mockito.verify(persistStateSpy, Mockito.times(0)).validateStreamStates(any(), any()); + Mockito.verify(stateApi).createOrUpdateState(any(ConnectionStateCreateOrUpdate.class)); + } + + @Test + void testPersistWithPerStreamStateDuringMigrationFromEmptyLegacyState() throws ApiException { + /* + * This test covers a scenario where a reset is executed before any successful syncs for a + * connection. When this occurs, an empty, legacy state is stored for the connection. 
+ */ + final ConfiguredAirbyteStream stream = + new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_A).withNamespace(STREAM_A_NAMESPACE)); + final ConfiguredAirbyteStream stream2 = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_B)); + final ConfiguredAirbyteStream stream3 = + new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_C)).withSyncMode(SyncMode.FULL_REFRESH); + + final AirbyteStateMessage stateMessage1 = new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream( + new AirbyteStreamState().withStreamDescriptor(CatalogHelpers.extractDescriptor(stream)) + .withStreamState(Jsons.emptyObject())); + final AirbyteStateMessage stateMessage2 = new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream( + new AirbyteStreamState().withStreamDescriptor(CatalogHelpers.extractDescriptor(stream2))); + final JsonNode jsonState = Jsons.jsonNode(List.of(stateMessage1, stateMessage2)); + final State state = new State().withState(jsonState); + + final AirbyteApiClient airbyteApiClient1 = mock(AirbyteApiClient.class); + final StateApi stateApi1 = mock(StateApi.class); + final ConnectionState connectionState = mock(ConnectionState.class); + Mockito.lenient().when(connectionState.getStateType()).thenReturn(ConnectionStateType.LEGACY); + Mockito.lenient().when(connectionState.getState()).thenReturn(Jsons.emptyObject()); + when(stateApi1.getState(any(ConnectionIdRequestBody.class))).thenReturn(connectionState); + Mockito.lenient().when(airbyteApiClient1.getStateApi()).thenReturn(stateApi1); + + final ConfiguredAirbyteCatalog migrationConfiguredCatalog = new ConfiguredAirbyteCatalog().withStreams(List.of(stream, stream2, stream3)); + final StandardSyncOutput syncOutput = new StandardSyncOutput().withState(state); + when(featureFlags.useStreamCapableState()).thenReturn(true); + + final PersistStateActivityImpl persistStateActivity1 = new 
PersistStateActivityImpl(airbyteApiClient1, featureFlags); + + persistStateActivity1.persist(CONNECTION_ID, syncOutput, migrationConfiguredCatalog); + + Mockito.verify(stateApi1).createOrUpdateState(any(ConnectionStateCreateOrUpdate.class)); + } + + @Test + void testPersistWithPerStreamStateDuringMigrationFromNullLegacyState() throws ApiException { + final ConfiguredAirbyteStream stream = + new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_A).withNamespace(STREAM_A_NAMESPACE)); + final ConfiguredAirbyteStream stream2 = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_B)); + final ConfiguredAirbyteStream stream3 = + new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_C)).withSyncMode(SyncMode.FULL_REFRESH); + + final AirbyteStateMessage stateMessage1 = new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream( + new AirbyteStreamState().withStreamDescriptor(CatalogHelpers.extractDescriptor(stream)) + .withStreamState(Jsons.emptyObject())); + final AirbyteStateMessage stateMessage2 = new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream( + new AirbyteStreamState().withStreamDescriptor(CatalogHelpers.extractDescriptor(stream2))); + final JsonNode jsonState = Jsons.jsonNode(List.of(stateMessage1, stateMessage2)); + final State state = new State().withState(jsonState); + + final AirbyteApiClient airbyteApiClient1 = mock(AirbyteApiClient.class); + final StateApi stateApi1 = mock(StateApi.class); + final ConnectionState connectionState = mock(ConnectionState.class); + Mockito.lenient().when(connectionState.getStateType()).thenReturn(ConnectionStateType.LEGACY); + Mockito.lenient().when(connectionState.getState()).thenReturn(null); + when(stateApi1.getState(any(ConnectionIdRequestBody.class))).thenReturn(connectionState); + Mockito.lenient().when(airbyteApiClient1.getStateApi()).thenReturn(stateApi1); + + final ConfiguredAirbyteCatalog 
migrationConfiguredCatalog = new ConfiguredAirbyteCatalog().withStreams(List.of(stream, stream2, stream3)); + final StandardSyncOutput syncOutput = new StandardSyncOutput().withState(state); + when(featureFlags.useStreamCapableState()).thenReturn(true); + + final PersistStateActivityImpl persistStateActivity1 = new PersistStateActivityImpl(airbyteApiClient1, featureFlags); + + persistStateActivity1.persist(CONNECTION_ID, syncOutput, migrationConfiguredCatalog); + + Mockito.verify(stateApi1).createOrUpdateState(any(ConnectionStateCreateOrUpdate.class)); + } + + @Test + void testPersistWithPerStreamStateDuringMigrationWithNoPreviousState() throws ApiException { + final ConfiguredAirbyteStream stream = + new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_A).withNamespace(STREAM_A_NAMESPACE)); + final ConfiguredAirbyteStream stream2 = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_B)); + final ConfiguredAirbyteStream stream3 = + new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_C)).withSyncMode(SyncMode.FULL_REFRESH); + + final AirbyteStateMessage stateMessage1 = new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream( + new AirbyteStreamState().withStreamDescriptor(CatalogHelpers.extractDescriptor(stream)) + .withStreamState(Jsons.emptyObject())); + final AirbyteStateMessage stateMessage2 = new AirbyteStateMessage() + .withType(AirbyteStateType.STREAM) + .withStream( + new AirbyteStreamState().withStreamDescriptor(CatalogHelpers.extractDescriptor(stream2))); + final JsonNode jsonState = Jsons.jsonNode(List.of(stateMessage1, stateMessage2)); + final State state = new State().withState(jsonState); + + final AirbyteApiClient airbyteApiClient1 = mock(AirbyteApiClient.class); + final StateApi stateApi1 = mock(StateApi.class); + when(stateApi1.getState(any(ConnectionIdRequestBody.class))).thenReturn(null); + 
Mockito.lenient().when(airbyteApiClient1.getStateApi()).thenReturn(stateApi1); + + final ConfiguredAirbyteCatalog migrationConfiguredCatalog = new ConfiguredAirbyteCatalog().withStreams(List.of(stream, stream2, stream3)); + final StandardSyncOutput syncOutput = new StandardSyncOutput().withState(state); + when(featureFlags.useStreamCapableState()).thenReturn(true); + + final PersistStateActivityImpl persistStateActivity1 = new PersistStateActivityImpl(airbyteApiClient1, featureFlags); + + persistStateActivity1.persist(CONNECTION_ID, syncOutput, migrationConfiguredCatalog); + + Mockito.verify(stateApi1).createOrUpdateState(any(ConnectionStateCreateOrUpdate.class)); } }