From a1b9db50da0c0d5049ca9f5c02e25b086c3799b3 Mon Sep 17 00:00:00 2001 From: Davin Chia Date: Mon, 28 Nov 2022 21:15:58 -0800 Subject: [PATCH] Progress Bar Estimate (#19814) Implement estimate message processing allowing the platform to hold on to estimate message counts in memory. The estimate message is protocol message connectors can choose to emit to provide support for progress bar calculations. There are two kinds of estimates, per-Sync or per-Stream. Sources cannot emit both types in a single sync. Per-stream estimates are what we usually expect. Per-sync estimates are for sources that cannot provide more granular estimates for whatever reasons e.g. CDC sources. In a follow up PR, the platform will periodically save these messages through the save stats api. --- .../internal/AirbyteMessageTracker.java | 96 ++++++++++++- .../workers/internal/MessageTracker.java | 30 ++++ .../test_utils/AirbyteMessageUtils.java | 62 +++++++-- .../general/DefaultReplicationWorkerTest.java | 2 +- .../internal/AirbyteMessageTrackerTest.java | 128 +++++++++++++++--- .../DefaultCheckConnectionWorkerTest.java | 2 +- .../DefaultDiscoverCatalogWorkerTest.java | 2 +- .../general/DefaultGetSpecWorkerTest.java | 2 +- 8 files changed, 280 insertions(+), 44 deletions(-) diff --git a/airbyte-commons-worker/src/main/java/io/airbyte/workers/internal/AirbyteMessageTracker.java b/airbyte-commons-worker/src/main/java/io/airbyte/workers/internal/AirbyteMessageTracker.java index f7d8caa4be698..a94c00927431a 100644 --- a/airbyte-commons-worker/src/main/java/io/airbyte/workers/internal/AirbyteMessageTracker.java +++ b/airbyte-commons-worker/src/main/java/io/airbyte/workers/internal/AirbyteMessageTracker.java @@ -8,6 +8,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Charsets; +import com.google.common.base.Preconditions; import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; import com.google.common.hash.HashFunction; @@ -19,6 +20,7 @@ import io.airbyte.config.State; import io.airbyte.protocol.models.AirbyteControlConnectorConfigMessage; import io.airbyte.protocol.models.AirbyteControlMessage; +import io.airbyte.protocol.models.AirbyteEstimateTraceMessage; import io.airbyte.protocol.models.AirbyteMessage; import io.airbyte.protocol.models.AirbyteRecordMessage; import io.airbyte.protocol.models.AirbyteStateMessage; @@ -51,6 +53,7 @@ public class AirbyteMessageTracker implements MessageTracker { private final Map streamToRunningCount; private final HashFunction hashFunction; private final BiMap nameNamespacePairToIndex; + private final Map nameNamespacePairToStreamStats; private final Map streamToTotalBytesEmitted; private final Map streamToTotalRecordsEmitted; private final StateDeltaTracker stateDeltaTracker; @@ -60,6 +63,11 @@ public class AirbyteMessageTracker implements MessageTracker { private final StateAggregator stateAggregator; private final boolean logConnectorMessages = new EnvVariableFeatureFlags().logConnectorMessages(); + // These variables support SYNC level estimates and are meant for sources where stream level + // estimates are not possible e.g. CDC sources. + private Long totalRecordsEstimatedSync; + private Long totalBytesEstimatedSync; + private short nextStreamIndex; /** @@ -78,6 +86,11 @@ private enum ConnectorType { DESTINATION } + /** + * POJO for all per-stream stats. + */ + private record StreamStats(long estimatedBytes, long emittedBytes, long estimatedRecords, long emittedRecords) {} + public AirbyteMessageTracker() { this(new StateDeltaTracker(STATE_DELTA_TRACKER_MEMORY_LIMIT_BYTES), new DefaultStateAggregator(new EnvVariableFeatureFlags().useStreamCapableState()), @@ -93,6 +106,7 @@ protected AirbyteMessageTracker(final StateDeltaTracker stateDeltaTracker, this.streamToRunningCount = new HashMap<>(); this.nameNamespacePairToIndex = HashBiMap.create(); this.hashFunction = Hashing.murmur3_32_fixed(); + this.nameNamespacePairToStreamStats = new HashMap<>(); this.streamToTotalBytesEmitted = new HashMap<>(); this.streamToTotalRecordsEmitted = new HashMap<>(); this.stateDeltaTracker = stateDeltaTracker; @@ -252,7 +266,7 @@ private void handleEmittedOrchestratorConnectorConfig(final AirbyteControlConnec */ private void handleEmittedTrace(final AirbyteTraceMessage traceMessage, final ConnectorType connectorType) { switch (traceMessage.getType()) { - case ESTIMATE -> handleEmittedEstimateTrace(traceMessage, connectorType); + case ESTIMATE -> handleEmittedEstimateTrace(traceMessage.getEstimate()); case ERROR -> handleEmittedErrorTrace(traceMessage, connectorType); default -> log.warn("Invalid message type for trace message: {}", traceMessage); } @@ -266,8 +280,34 @@ private void handleEmittedErrorTrace(final AirbyteTraceMessage errorTraceMessage } } - @SuppressWarnings("PMD") // until method is implemented - private void handleEmittedEstimateTrace(final AirbyteTraceMessage estimateTraceMessage, final ConnectorType connectorType) { + /** + * There are several assumptions here: + *

+ * - Assume the estimate is a whole number and not a sum i.e. each estimate replaces the previous + * estimate. + *

+ * - Sources cannot emit both STREAM and SYNC estimates in a same sync. Error out if this happens. + */ + @SuppressWarnings("PMD.AvoidDuplicateLiterals") + private void handleEmittedEstimateTrace(final AirbyteEstimateTraceMessage estimate) { + switch (estimate.getType()) { + case STREAM -> { + Preconditions.checkArgument(totalBytesEstimatedSync == null, "STREAM and SYNC estimates should not be emitted in the same sync."); + Preconditions.checkArgument(totalRecordsEstimatedSync == null, "STREAM and SYNC estimates should not be emitted in the same sync."); + + log.debug("Saving stream estimates for namespace: {}, stream: {}", estimate.getNamespace(), estimate.getName()); + nameNamespacePairToStreamStats.put( + new AirbyteStreamNameNamespacePair(estimate.getName(), estimate.getNamespace()), + new StreamStats(estimate.getByteEstimate(), 0L, estimate.getRowEstimate(), 0L)); + } + case SYNC -> { + Preconditions.checkArgument(nameNamespacePairToStreamStats.isEmpty(), "STREAM and SYNC estimates should not be emitted in the same sync."); + + log.debug("Saving sync estimates"); + totalBytesEstimatedSync = estimate.getByteEstimate(); + totalRecordsEstimatedSync = estimate.getRowEstimate(); + } + } } @@ -368,6 +408,17 @@ public Map getStreamToEmittedRecords() { entry -> nameNamespacePairToIndex.inverse().get(entry.getKey()), Entry::getValue)); } + /** + * Swap out stream indices for stream names and return total records estimated by stream. + */ + @Override + public Map getStreamToEstimatedRecords() { + return nameNamespacePairToStreamStats.entrySet().stream().collect( + Collectors.toMap( + Entry::getKey, + entry -> entry.getValue().estimatedRecords())); + } + /** * Swap out stream indices for stream names and return total bytes emitted by stream. */ @@ -377,6 +428,17 @@ public Map getStreamToEmittedBytes() { entry -> nameNamespacePairToIndex.inverse().get(entry.getKey()), Entry::getValue)); } + /** + * Swap out stream indices for stream names and return total bytes estimated by stream. + */ + @Override + public Map getStreamToEstimatedBytes() { + return nameNamespacePairToStreamStats.entrySet().stream().collect( + Collectors.toMap( + Entry::getKey, + entry -> entry.getValue().estimatedBytes())); + } + /** * Compute sum of emitted record counts across all streams. */ @@ -385,6 +447,20 @@ public long getTotalRecordsEmitted() { return streamToTotalRecordsEmitted.values().stream().reduce(0L, Long::sum); } + /** + * Compute sum of estimated record counts across all streams. + */ + @Override + public long getTotalRecordsEstimated() { + if (!nameNamespacePairToStreamStats.isEmpty()) { + return nameNamespacePairToStreamStats.values().stream() + .map(e -> e.estimatedRecords) + .reduce(0L, Long::sum); + } + + return totalRecordsEstimatedSync; + } + /** * Compute sum of emitted bytes across all streams. */ @@ -393,6 +469,20 @@ public long getTotalBytesEmitted() { return streamToTotalBytesEmitted.values().stream().reduce(0L, Long::sum); } + /** + * Compute sum of estimated bytes across all streams. + */ + @Override + public long getTotalBytesEstimated() { + if (!nameNamespacePairToStreamStats.isEmpty()) { + return nameNamespacePairToStreamStats.values().stream() + .map(e -> e.estimatedBytes) + .reduce(0L, Long::sum); + } + + return totalBytesEstimatedSync; + } + /** * Compute sum of committed record counts across all streams. If the delta tracker has exceeded its * capacity, return empty because committed record counts cannot be reliably computed. diff --git a/airbyte-commons-worker/src/main/java/io/airbyte/workers/internal/MessageTracker.java b/airbyte-commons-worker/src/main/java/io/airbyte/workers/internal/MessageTracker.java index 09507ec7a374e..a2f31bf250d80 100644 --- a/airbyte-commons-worker/src/main/java/io/airbyte/workers/internal/MessageTracker.java +++ b/airbyte-commons-worker/src/main/java/io/airbyte/workers/internal/MessageTracker.java @@ -66,6 +66,14 @@ public interface MessageTracker { */ Map getStreamToEmittedRecords(); + /** + * Get the per-stream estimated record count provided by + * {@link io.airbyte.protocol.models.AirbyteEstimateTraceMessage}. + * + * @return returns a map of estimated record count by stream name. + */ + Map getStreamToEstimatedRecords(); + /** * Get the per-stream emitted byte count. This includes messages that were emitted by the source, * but never committed by the destination. @@ -74,6 +82,14 @@ public interface MessageTracker { */ Map getStreamToEmittedBytes(); + /** + * Get the per-stream estimated byte count provided by + * {@link io.airbyte.protocol.models.AirbyteEstimateTraceMessage}. + * + * @return returns a map of estimated bytes by stream name. + */ + Map getStreamToEstimatedBytes(); + /** * Get the overall emitted record count. This includes messages that were emitted by the source, but * never committed by the destination. @@ -82,6 +98,13 @@ public interface MessageTracker { */ long getTotalRecordsEmitted(); + /** + * Get the overall estimated record count. + * + * @return returns the total count of estimated records across all streams. + */ + long getTotalRecordsEstimated(); + /** * Get the overall emitted bytes. This includes messages that were emitted by the source, but never * committed by the destination. @@ -90,6 +113,13 @@ public interface MessageTracker { */ long getTotalBytesEmitted(); + /** + * Get the overall estimated bytes. + * + * @return returns the total count of estimated bytes across all streams. + */ + long getTotalBytesEstimated(); + /** * Get the overall committed record count. * diff --git a/airbyte-commons-worker/src/main/java/io/airbyte/workers/test_utils/AirbyteMessageUtils.java b/airbyte-commons-worker/src/main/java/io/airbyte/workers/test_utils/AirbyteMessageUtils.java index 2aede71597390..244e85303c8c2 100644 --- a/airbyte-commons-worker/src/main/java/io/airbyte/workers/test_utils/AirbyteMessageUtils.java +++ b/airbyte-commons-worker/src/main/java/io/airbyte/workers/test_utils/AirbyteMessageUtils.java @@ -8,6 +8,7 @@ import com.google.common.collect.ImmutableMap; import io.airbyte.commons.json.Jsons; import io.airbyte.protocol.models.AirbyteErrorTraceMessage; +import io.airbyte.protocol.models.AirbyteEstimateTraceMessage; import io.airbyte.protocol.models.AirbyteGlobalState; import io.airbyte.protocol.models.AirbyteLogMessage; import io.airbyte.protocol.models.AirbyteMessage; @@ -102,29 +103,60 @@ public static AirbyteStreamState createStreamState(final String streamName) { return new AirbyteStreamState().withStreamDescriptor(new StreamDescriptor().withName(streamName)); } + public static AirbyteMessage createStreamEstimateMessage(final String name, final String namespace, final long byteEst, final long rowEst) { + return createEstimateMessage(AirbyteEstimateTraceMessage.Type.STREAM, name, namespace, byteEst, rowEst); + } + + public static AirbyteMessage createSyncEstimateMessage(final long byteEst, final long rowEst) { + return createEstimateMessage(AirbyteEstimateTraceMessage.Type.SYNC, null, null, byteEst, rowEst); + } + + public static AirbyteMessage createEstimateMessage(AirbyteEstimateTraceMessage.Type type, + final String name, + final String namespace, + final long byteEst, + final long rowEst) { + final var est = new AirbyteEstimateTraceMessage() + .withType(type) + .withByteEstimate(byteEst) + .withRowEstimate(rowEst); + + if (name != null) { + est.withName(name); + } + if (namespace != null) { + est.withNamespace(namespace); + } + + return new AirbyteMessage() + .withType(Type.TRACE) + .withTrace(new AirbyteTraceMessage().withType(AirbyteTraceMessage.Type.ESTIMATE) + .withEstimate(est)); + } + + public static AirbyteMessage createErrorMessage(final String message, final Double emittedAt) { + return new AirbyteMessage() + .withType(AirbyteMessage.Type.TRACE) + .withTrace(createErrorTraceMessage(message, emittedAt)); + } + public static AirbyteTraceMessage createErrorTraceMessage(final String message, final Double emittedAt) { - return new AirbyteTraceMessage() - .withType(io.airbyte.protocol.models.AirbyteTraceMessage.Type.ERROR) - .withEmittedAt(emittedAt) - .withError(new AirbyteErrorTraceMessage().withMessage(message)); + return createErrorTraceMessage(message, emittedAt, null); } public static AirbyteTraceMessage createErrorTraceMessage(final String message, final Double emittedAt, final AirbyteErrorTraceMessage.FailureType failureType) { - return new AirbyteTraceMessage() + final var msg = new AirbyteTraceMessage() .withType(io.airbyte.protocol.models.AirbyteTraceMessage.Type.ERROR) - .withEmittedAt(emittedAt) - .withError(new AirbyteErrorTraceMessage().withMessage(message).withFailureType(failureType)); - } + .withError(new AirbyteErrorTraceMessage().withMessage(message)) + .withEmittedAt(emittedAt); - public static AirbyteMessage createTraceMessage(final String message, final Double emittedAt) { - return new AirbyteMessage() - .withType(AirbyteMessage.Type.TRACE) - .withTrace(new AirbyteTraceMessage() - .withType(AirbyteTraceMessage.Type.ERROR) - .withEmittedAt(emittedAt) - .withError(new AirbyteErrorTraceMessage().withMessage(message))); + if (failureType != null) { + msg.getError().withFailureType(failureType); + } + + return msg; } } diff --git a/airbyte-commons-worker/src/test/java/io/airbyte/workers/general/DefaultReplicationWorkerTest.java b/airbyte-commons-worker/src/test/java/io/airbyte/workers/general/DefaultReplicationWorkerTest.java index 7b915ed4943ca..e13f44edb6b4b 100644 --- a/airbyte-commons-worker/src/test/java/io/airbyte/workers/general/DefaultReplicationWorkerTest.java +++ b/airbyte-commons-worker/src/test/java/io/airbyte/workers/general/DefaultReplicationWorkerTest.java @@ -307,7 +307,7 @@ void testReplicationRunnableWorkerFailure() throws Exception { @Test void testOnlyStateAndRecordMessagesDeliveredToDestination() throws Exception { final AirbyteMessage LOG_MESSAGE = AirbyteMessageUtils.createLogMessage(Level.INFO, "a log message"); - final AirbyteMessage TRACE_MESSAGE = AirbyteMessageUtils.createTraceMessage("a trace message", 123456.0); + final AirbyteMessage TRACE_MESSAGE = AirbyteMessageUtils.createErrorMessage("a trace message", 123456.0); when(mapper.mapMessage(LOG_MESSAGE)).thenReturn(LOG_MESSAGE); when(mapper.mapMessage(TRACE_MESSAGE)).thenReturn(TRACE_MESSAGE); when(source.isFinished()).thenReturn(false, false, false, false, true); diff --git a/airbyte-commons-worker/src/test/java/io/airbyte/workers/internal/AirbyteMessageTrackerTest.java b/airbyte-commons-worker/src/test/java/io/airbyte/workers/internal/AirbyteMessageTrackerTest.java index 5123b299453ce..aed444225cf9d 100644 --- a/airbyte-commons-worker/src/test/java/io/airbyte/workers/internal/AirbyteMessageTrackerTest.java +++ b/airbyte-commons-worker/src/test/java/io/airbyte/workers/internal/AirbyteMessageTrackerTest.java @@ -5,6 +5,8 @@ package io.airbyte.workers.internal; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import io.airbyte.commons.json.Jsons; @@ -19,6 +21,8 @@ import java.util.HashMap; import java.util.Map; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -28,9 +32,10 @@ @ExtendWith(MockitoExtension.class) class AirbyteMessageTrackerTest { - private static final String STREAM_1 = "stream1"; - private static final String STREAM_2 = "stream2"; - private static final String STREAM_3 = "stream3"; + private static final String NAMESPACE_1 = "avengers"; + private static final String STREAM_1 = "iron man"; + private static final String STREAM_2 = "black widow"; + private static final String STREAM_3 = "hulk"; private static final String INDUCED_EXCEPTION = "induced exception"; private AirbyteMessageTracker messageTracker; @@ -277,11 +282,11 @@ void testGetTotalRecordsCommitted_emptyWhenCommitStateHashThrowsException() thro } @Test - void testGetFirstDestinationAndSourceMessages() throws Exception { - final AirbyteMessage sourceMessage1 = AirbyteMessageUtils.createTraceMessage("source trace 1", Double.valueOf(123)); - final AirbyteMessage sourceMessage2 = AirbyteMessageUtils.createTraceMessage("source trace 2", Double.valueOf(124)); - final AirbyteMessage destMessage1 = AirbyteMessageUtils.createTraceMessage("dest trace 1", Double.valueOf(125)); - final AirbyteMessage destMessage2 = AirbyteMessageUtils.createTraceMessage("dest trace 2", Double.valueOf(126)); + void testGetFirstDestinationAndSourceMessages() { + final AirbyteMessage sourceMessage1 = AirbyteMessageUtils.createErrorMessage("source trace 1", 123.0); + final AirbyteMessage sourceMessage2 = AirbyteMessageUtils.createErrorMessage("source trace 2", 124.0); + final AirbyteMessage destMessage1 = AirbyteMessageUtils.createErrorMessage("dest trace 1", 125.0); + final AirbyteMessage destMessage2 = AirbyteMessageUtils.createErrorMessage("dest trace 2", 126.0); messageTracker.acceptFromSource(sourceMessage1); messageTracker.acceptFromSource(sourceMessage2); messageTracker.acceptFromDestination(destMessage1); @@ -292,39 +297,118 @@ void testGetFirstDestinationAndSourceMessages() throws Exception { } @Test - void testGetFirstDestinationAndSourceMessagesWithNulls() throws Exception { - assertEquals(messageTracker.getFirstDestinationErrorTraceMessage(), null); - assertEquals(messageTracker.getFirstSourceErrorTraceMessage(), null); + void testGetFirstDestinationAndSourceMessagesWithNulls() { + assertNull(messageTracker.getFirstDestinationErrorTraceMessage()); + assertNull(messageTracker.getFirstSourceErrorTraceMessage()); } @Test - void testErrorTraceMessageFailureWithMultipleTraceErrors() throws Exception { - final AirbyteMessage sourceMessage1 = AirbyteMessageUtils.createTraceMessage("source trace 1", Double.valueOf(123)); - final AirbyteMessage sourceMessage2 = AirbyteMessageUtils.createTraceMessage("source trace 2", Double.valueOf(124)); - final AirbyteMessage destMessage1 = AirbyteMessageUtils.createTraceMessage("dest trace 1", Double.valueOf(125)); - final AirbyteMessage destMessage2 = AirbyteMessageUtils.createTraceMessage("dest trace 2", Double.valueOf(126)); + void testErrorTraceMessageFailureWithMultipleTraceErrors() { + final AirbyteMessage sourceMessage1 = AirbyteMessageUtils.createErrorMessage("source trace 1", 123.0); + final AirbyteMessage sourceMessage2 = AirbyteMessageUtils.createErrorMessage("source trace 2", 124.0); + final AirbyteMessage destMessage1 = AirbyteMessageUtils.createErrorMessage("dest trace 1", 125.0); + final AirbyteMessage destMessage2 = AirbyteMessageUtils.createErrorMessage("dest trace 2", 126.0); messageTracker.acceptFromSource(sourceMessage1); messageTracker.acceptFromSource(sourceMessage2); messageTracker.acceptFromDestination(destMessage1); messageTracker.acceptFromDestination(destMessage2); final FailureReason failureReason = FailureHelper.sourceFailure(sourceMessage1.getTrace(), Long.valueOf(123), 1); - assertEquals(messageTracker.errorTraceMessageFailure(Long.valueOf(123), 1), + assertEquals(messageTracker.errorTraceMessageFailure(123L, 1), failureReason); } @Test - void testErrorTraceMessageFailureWithOneTraceError() throws Exception { - final AirbyteMessage destMessage = AirbyteMessageUtils.createTraceMessage("dest trace 1", Double.valueOf(125)); + void testErrorTraceMessageFailureWithOneTraceError() { + final AirbyteMessage destMessage = AirbyteMessageUtils.createErrorMessage("dest trace 1", 125.0); messageTracker.acceptFromDestination(destMessage); final FailureReason failureReason = FailureHelper.destinationFailure(destMessage.getTrace(), Long.valueOf(123), 1); - assertEquals(messageTracker.errorTraceMessageFailure(Long.valueOf(123), 1), failureReason); + assertEquals(messageTracker.errorTraceMessageFailure(123L, 1), failureReason); } @Test - void testErrorTraceMessageFailureWithNoTraceErrors() throws Exception { - assertEquals(messageTracker.errorTraceMessageFailure(Long.valueOf(123), 1), null); + void testErrorTraceMessageFailureWithNoTraceErrors() { + assertEquals(messageTracker.errorTraceMessageFailure(123L, 1), null); + } + + @Nested + class Estimates { + + // receiving an estimate for two streams should save + @Test + @DisplayName("when given stream estimates, should return correct per-stream estimates") + void streamShouldSaveAndReturnIndividualStreamCountsCorrectly() { + final var est1 = AirbyteMessageUtils.createStreamEstimateMessage(STREAM_1, NAMESPACE_1, 100L, 10L); + final var est2 = AirbyteMessageUtils.createStreamEstimateMessage(STREAM_2, NAMESPACE_1, 200L, 10L); + + messageTracker.acceptFromSource(est1); + messageTracker.acceptFromSource(est2); + + final var streamToEstBytes = messageTracker.getStreamToEstimatedBytes(); + final var expStreamToEstBytes = Map.of( + new AirbyteStreamNameNamespacePair(STREAM_1, NAMESPACE_1), 100L, + new AirbyteStreamNameNamespacePair(STREAM_2, NAMESPACE_1), 200L); + assertEquals(expStreamToEstBytes, streamToEstBytes); + + final var streamToEstRecs = messageTracker.getStreamToEstimatedRecords(); + final var expStreamToEstRecs = Map.of( + new AirbyteStreamNameNamespacePair(STREAM_1, NAMESPACE_1), 10L, + new AirbyteStreamNameNamespacePair(STREAM_2, NAMESPACE_1), 10L); + assertEquals(expStreamToEstRecs, streamToEstRecs); + } + + @Test + @DisplayName("when given stream estimates, should return correct total estimates") + void streamShouldSaveAndReturnTotalCountsCorrectly() { + final var est1 = AirbyteMessageUtils.createStreamEstimateMessage(STREAM_1, NAMESPACE_1, 100L, 10L); + final var est2 = AirbyteMessageUtils.createStreamEstimateMessage(STREAM_2, NAMESPACE_1, 200L, 10L); + + messageTracker.acceptFromSource(est1); + messageTracker.acceptFromSource(est2); + + final var totalEstBytes = messageTracker.getTotalBytesEstimated(); + assertEquals(300L, totalEstBytes); + + final var totalEstRecs = messageTracker.getTotalRecordsEstimated(); + assertEquals(20L, totalEstRecs); + } + + @Test + @DisplayName("should error when given both Stream and Sync estimates") + void shouldErrorOnBothStreamAndSyncEstimates() { + final var est1 = AirbyteMessageUtils.createStreamEstimateMessage(STREAM_1, NAMESPACE_1, 100L, 10L); + final var est2 = AirbyteMessageUtils.createSyncEstimateMessage(200L, 10L); + + messageTracker.acceptFromSource(est1); + assertThrows(IllegalArgumentException.class, () -> messageTracker.acceptFromSource(est2)); + } + + @Test + @DisplayName("when given sync estimates, should return correct total estimates") + void syncShouldSaveAndReturnTotalCountsCorrectly() { + final var est = AirbyteMessageUtils.createSyncEstimateMessage(200L, 10L); + messageTracker.acceptFromSource(est); + + final var totalEstBytes = messageTracker.getTotalBytesEstimated(); + assertEquals(200L, totalEstBytes); + + final var totalEstRecs = messageTracker.getTotalRecordsEstimated(); + assertEquals(10L, totalEstRecs); + } + + @Test + @DisplayName("when given sync estimates, should not return any per-stream estimates") + void syncShouldNotHaveStreamEstimates() { + final var est = AirbyteMessageUtils.createSyncEstimateMessage(200L, 10L); + messageTracker.acceptFromSource(est); + + final var streamToEstBytes = messageTracker.getStreamToEstimatedBytes(); + assertTrue(streamToEstBytes.isEmpty()); + final var streamToEstRecs = messageTracker.getStreamToEstimatedRecords(); + assertTrue(streamToEstRecs.isEmpty()); + } + } } diff --git a/airbyte-workers/src/test/java/io/airbyte/workers/general/DefaultCheckConnectionWorkerTest.java b/airbyte-workers/src/test/java/io/airbyte/workers/general/DefaultCheckConnectionWorkerTest.java index e1443fbbb91a4..f5685c245e44d 100644 --- a/airbyte-workers/src/test/java/io/airbyte/workers/general/DefaultCheckConnectionWorkerTest.java +++ b/airbyte-workers/src/test/java/io/airbyte/workers/general/DefaultCheckConnectionWorkerTest.java @@ -77,7 +77,7 @@ void setup() throws IOException, WorkerException { .withConnectionStatus(new AirbyteConnectionStatus().withStatus(AirbyteConnectionStatus.Status.FAILED).withMessage("failed to connect")); failureStreamFactory = noop -> Lists.newArrayList(failureMessage).stream(); - final AirbyteMessage traceMessage = AirbyteMessageUtils.createTraceMessage("some error from the connector", 123.0); + final AirbyteMessage traceMessage = AirbyteMessageUtils.createErrorMessage("some error from the connector", 123.0); traceMessageStreamFactory = noop -> Lists.newArrayList(traceMessage).stream(); } diff --git a/airbyte-workers/src/test/java/io/airbyte/workers/general/DefaultDiscoverCatalogWorkerTest.java b/airbyte-workers/src/test/java/io/airbyte/workers/general/DefaultDiscoverCatalogWorkerTest.java index ff7534073108a..b9fe4417657d1 100644 --- a/airbyte-workers/src/test/java/io/airbyte/workers/general/DefaultDiscoverCatalogWorkerTest.java +++ b/airbyte-workers/src/test/java/io/airbyte/workers/general/DefaultDiscoverCatalogWorkerTest.java @@ -141,7 +141,7 @@ void testDiscoverSchemaProcessFail() throws Exception { @Test void testDiscoverSchemaProcessFailWithTraceMessage() throws Exception { final AirbyteStreamFactory traceStreamFactory = noop -> Lists.newArrayList( - AirbyteMessageUtils.createTraceMessage("some error from the connector", 123.0)).stream(); + AirbyteMessageUtils.createErrorMessage("some error from the connector", 123.0)).stream(); when(process.exitValue()).thenReturn(1); diff --git a/airbyte-workers/src/test/java/io/airbyte/workers/general/DefaultGetSpecWorkerTest.java b/airbyte-workers/src/test/java/io/airbyte/workers/general/DefaultGetSpecWorkerTest.java index 787641266f72b..7ddbdfc17d2c4 100644 --- a/airbyte-workers/src/test/java/io/airbyte/workers/general/DefaultGetSpecWorkerTest.java +++ b/airbyte-workers/src/test/java/io/airbyte/workers/general/DefaultGetSpecWorkerTest.java @@ -114,7 +114,7 @@ void testFailureOnNonzeroExitCode() throws InterruptedException, IOException { @Test void testFailureOnNonzeroExitCodeWithTraceMessage() throws WorkerException, InterruptedException { - final AirbyteMessage message = AirbyteMessageUtils.createTraceMessage("some error from the connector", 123.0); + final AirbyteMessage message = AirbyteMessageUtils.createErrorMessage("some error from the connector", 123.0); when(process.getInputStream()).thenReturn(new ByteArrayInputStream(Jsons.serialize(message).getBytes(Charsets.UTF_8))); when(process.waitFor(anyLong(), any())).thenReturn(true);