Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Progress Bar Estimate #19814

Merged
merged 13 commits into from
Nov 29, 2022
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ public class AirbyteMessageTracker implements MessageTracker {
private final BiMap<AirbyteStreamNameNamespacePair, Short> nameNamespacePairToIndex;
private final Map<Short, Long> streamToTotalBytesEmitted;
private final Map<Short, Long> streamToTotalRecordsEmitted;
private final Map<Short, Long> streamToTotalBytesEstimated;
private final Map<Short, Long> streamToTotalRecordsEstimated;
private final StateDeltaTracker stateDeltaTracker;
private final StateMetricsTracker stateMetricsTracker;
private final List<AirbyteTraceMessage> destinationErrorTraceMessages;
Expand Down Expand Up @@ -95,6 +97,8 @@ protected AirbyteMessageTracker(final StateDeltaTracker stateDeltaTracker,
this.hashFunction = Hashing.murmur3_32_fixed();
this.streamToTotalBytesEmitted = new HashMap<>();
this.streamToTotalRecordsEmitted = new HashMap<>();
this.streamToTotalBytesEstimated = new HashMap<>();
this.streamToTotalRecordsEstimated = new HashMap<>();
this.stateDeltaTracker = stateDeltaTracker;
this.stateMetricsTracker = stateMetricsTracker;
this.nextStreamIndex = 0;
Expand Down Expand Up @@ -252,7 +256,7 @@ private void handleEmittedOrchestratorConnectorConfig(final AirbyteControlConnec
*/
private void handleEmittedTrace(final AirbyteTraceMessage traceMessage, final ConnectorType connectorType) {
switch (traceMessage.getType()) {
case ESTIMATE -> handleEmittedEstimateTrace(traceMessage, connectorType);
case ESTIMATE -> handleEmittedEstimateTrace(traceMessage);
case ERROR -> handleEmittedErrorTrace(traceMessage, connectorType);
default -> log.warn("Invalid message type for trace message: {}", traceMessage);
}
Expand All @@ -266,9 +270,16 @@ private void handleEmittedErrorTrace(final AirbyteTraceMessage errorTraceMessage
}
}

@SuppressWarnings("PMD") // until method is implemented
private void handleEmittedEstimateTrace(final AirbyteTraceMessage estimateTraceMessage, final ConnectorType connectorType) {
/**
 * Records the row and byte estimates carried by an ESTIMATE trace message for the stream it names.
 *
 * <p>
 * Each estimate is treated as an absolute figure, not an increment: a newer estimate for the same
 * stream overwrites the older one. A missing (null) row or byte estimate is skipped and any earlier
 * value is kept, because a null stored in these maps would make the per-stream getters
 * ({@code Collectors.toMap}) and the totals ({@code Long::sum}) throw a NullPointerException.
 *
 * @param estimateTraceMessage trace message whose estimate payload is read
 */
private void handleEmittedEstimateTrace(final AirbyteTraceMessage estimateTraceMessage) {
  final var estimate = estimateTraceMessage.getEstimate();
  log.info("Saving records estimates for namespace: {}, stream: {}", estimate.getNamespace(), estimate.getName());
  final var index = getStreamIndex(new AirbyteStreamNameNamespacePair(estimate.getName(), estimate.getNamespace()));

  // Either figure may be absent from the message; only overwrite what was actually reported.
  final Long rowEstimate = estimate.getRowEstimate();
  if (rowEstimate != null) {
    streamToTotalRecordsEstimated.put(index, rowEstimate);
  }

  final Long byteEstimate = estimate.getByteEstimate();
  if (byteEstimate != null) {
    streamToTotalBytesEstimated.put(index, byteEstimate);
  }
}

private short getStreamIndex(final AirbyteStreamNameNamespacePair pair) {
Expand Down Expand Up @@ -368,6 +379,12 @@ public Map<AirbyteStreamNameNamespacePair, Long> getStreamToEmittedRecords() {
entry -> nameNamespacePairToIndex.inverse().get(entry.getKey()), Entry::getValue));
}

/**
 * Swap out stream indices for stream names and return the estimated record count by stream.
 */
@Override
public Map<AirbyteStreamNameNamespacePair, Long> getStreamToEstimatedRecords() {
davinchia marked this conversation as resolved.
Show resolved Hide resolved
return streamToTotalRecordsEstimated.entrySet().stream().collect(Collectors.toMap(
entry -> nameNamespacePairToIndex.inverse().get(entry.getKey()), Entry::getValue));
}

/**
* Swap out stream indices for stream names and return total bytes emitted by stream.
*/
Expand All @@ -377,6 +394,12 @@ public Map<AirbyteStreamNameNamespacePair, Long> getStreamToEmittedBytes() {
entry -> nameNamespacePairToIndex.inverse().get(entry.getKey()), Entry::getValue));
}

/**
 * Swap out stream indices for stream names and return the estimated bytes by stream.
 */
@Override
public Map<AirbyteStreamNameNamespacePair, Long> getStreamToEstimatedBytes() {
  final var indexToPair = nameNamespacePairToIndex.inverse();
  return streamToTotalBytesEstimated.entrySet()
      .stream()
      .collect(Collectors.toMap(
          indexAndEstimate -> indexToPair.get(indexAndEstimate.getKey()),
          indexAndEstimate -> indexAndEstimate.getValue()));
}

/**
* Compute sum of emitted record counts across all streams.
*/
Expand All @@ -385,6 +408,11 @@ public long getTotalRecordsEmitted() {
return streamToTotalRecordsEmitted.values().stream().reduce(0L, Long::sum);
}

/**
 * Compute sum of estimated record counts across all streams.
 */
@Override
public long getTotalRecordsEstimated() {
  long total = 0L;
  for (final Long streamEstimate : streamToTotalRecordsEstimated.values()) {
    total += streamEstimate;
  }
  return total;
}

/**
* Compute sum of emitted bytes across all streams.
*/
Expand All @@ -393,6 +421,11 @@ public long getTotalBytesEmitted() {
return streamToTotalBytesEmitted.values().stream().reduce(0L, Long::sum);
}

/**
 * Compute sum of estimated bytes across all streams.
 */
@Override
public long getTotalBytesEstimated() {
  return streamToTotalBytesEstimated.values()
      .stream()
      .mapToLong(Long::longValue)
      .sum();
}

/**
* Compute sum of committed record counts across all streams. If the delta tracker has exceeded its
* capacity, return empty because committed record counts cannot be reliably computed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ public interface MessageTracker {
*/
Map<AirbyteStreamNameNamespacePair, Long> getStreamToEmittedRecords();

/**
 * Get the per-stream estimated record count.
 *
 * @return a map of estimated record count keyed by stream name/namespace pair.
 */
Map<AirbyteStreamNameNamespacePair, Long> getStreamToEstimatedRecords();

/**
* Get the per-stream emitted byte count. This includes messages that were emitted by the source,
* but never committed by the destination.
Expand All @@ -74,6 +76,8 @@ public interface MessageTracker {
*/
Map<AirbyteStreamNameNamespacePair, Long> getStreamToEmittedBytes();

/**
 * Get the per-stream estimated byte count.
 *
 * @return a map of estimated bytes keyed by stream name/namespace pair.
 */
Map<AirbyteStreamNameNamespacePair, Long> getStreamToEstimatedBytes();

/**
* Get the overall emitted record count. This includes messages that were emitted by the source, but
* never committed by the destination.
Expand All @@ -82,6 +86,8 @@ public interface MessageTracker {
*/
long getTotalRecordsEmitted();

/**
 * Get the overall estimated record count.
 *
 * @return the total estimated number of records.
 */
long getTotalRecordsEstimated();

/**
* Get the overall emitted bytes. This includes messages that were emitted by the source, but never
* committed by the destination.
Expand All @@ -90,6 +96,8 @@ public interface MessageTracker {
*/
long getTotalBytesEmitted();

/**
 * Get the overall estimated bytes.
 *
 * @return the total estimated number of bytes.
 */
long getTotalBytesEstimated();

/**
* Get the overall committed record count.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.google.common.collect.ImmutableMap;
import io.airbyte.commons.json.Jsons;
import io.airbyte.protocol.models.AirbyteErrorTraceMessage;
import io.airbyte.protocol.models.AirbyteEstimateTraceMessage;
import io.airbyte.protocol.models.AirbyteGlobalState;
import io.airbyte.protocol.models.AirbyteLogMessage;
import io.airbyte.protocol.models.AirbyteMessage;
Expand Down Expand Up @@ -102,29 +103,43 @@ public static AirbyteStreamState createStreamState(final String streamName) {
return new AirbyteStreamState().withStreamDescriptor(new StreamDescriptor().withName(streamName));
}

/**
 * Build a TRACE message whose trace is of type ESTIMATE and carries a STREAM-type estimate for the
 * given stream.
 *
 * @param name stream name set on the estimate
 * @param namespace stream namespace set on the estimate
 * @param byteEst byte estimate set on the estimate
 * @param rowEst row estimate set on the estimate
 * @return the wrapping AirbyteMessage
 */
public static AirbyteMessage createEstimateMessage(final String name, final String namespace, final long byteEst, final long rowEst) {
final var est = new AirbyteEstimateTraceMessage()
.withType(AirbyteEstimateTraceMessage.Type.STREAM)
davinchia marked this conversation as resolved.
Show resolved Hide resolved
.withNamespace(namespace)
.withName(name)
.withByteEstimate(byteEst)
.withRowEstimate(rowEst);

return new AirbyteMessage()
.withType(Type.TRACE)
.withTrace(new AirbyteTraceMessage().withType(AirbyteTraceMessage.Type.ESTIMATE)
.withEstimate(est));
}

/**
 * Wrap an error trace, built from the given text and timestamp, in a TRACE-typed message.
 *
 * @param message error text placed on the trace
 * @param emittedAt emission timestamp placed on the trace
 * @return the wrapping AirbyteMessage
 */
public static AirbyteMessage createErrorMessage(final String message, final Double emittedAt) {
  final AirbyteTraceMessage errorTrace = createErrorTraceMessage(message, emittedAt);
  final AirbyteMessage wrapper = new AirbyteMessage();
  return wrapper
      .withType(AirbyteMessage.Type.TRACE)
      .withTrace(errorTrace);
}

public static AirbyteTraceMessage createErrorTraceMessage(final String message, final Double emittedAt) {
return new AirbyteTraceMessage()
.withType(io.airbyte.protocol.models.AirbyteTraceMessage.Type.ERROR)
.withEmittedAt(emittedAt)
.withError(new AirbyteErrorTraceMessage().withMessage(message));
return createErrorTraceMessage(message, emittedAt, null);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consolidated the various trace-message creation helpers into one function to avoid duplication.

}

public static AirbyteTraceMessage createErrorTraceMessage(final String message,
final Double emittedAt,
final AirbyteErrorTraceMessage.FailureType failureType) {
return new AirbyteTraceMessage()
final var msg = new AirbyteTraceMessage()
.withType(io.airbyte.protocol.models.AirbyteTraceMessage.Type.ERROR)
.withEmittedAt(emittedAt)
.withError(new AirbyteErrorTraceMessage().withMessage(message).withFailureType(failureType));
}
.withError(new AirbyteErrorTraceMessage().withMessage(message))
.withEmittedAt(emittedAt);

public static AirbyteMessage createTraceMessage(final String message, final Double emittedAt) {
return new AirbyteMessage()
.withType(AirbyteMessage.Type.TRACE)
.withTrace(new AirbyteTraceMessage()
.withType(AirbyteTraceMessage.Type.ERROR)
.withEmittedAt(emittedAt)
.withError(new AirbyteErrorTraceMessage().withMessage(message)));
if (failureType != null) {
msg.getError().withFailureType(failureType);
}

return msg;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ void testReplicationRunnableWorkerFailure() throws Exception {
@Test
void testOnlyStateAndRecordMessagesDeliveredToDestination() throws Exception {
final AirbyteMessage LOG_MESSAGE = AirbyteMessageUtils.createLogMessage(Level.INFO, "a log message");
final AirbyteMessage TRACE_MESSAGE = AirbyteMessageUtils.createTraceMessage("a trace message", 123456.0);
final AirbyteMessage TRACE_MESSAGE = AirbyteMessageUtils.createErrorMessage("a trace message", 123456.0);
when(mapper.mapMessage(LOG_MESSAGE)).thenReturn(LOG_MESSAGE);
when(mapper.mapMessage(TRACE_MESSAGE)).thenReturn(TRACE_MESSAGE);
when(source.isFinished()).thenReturn(false, false, false, false, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package io.airbyte.workers.internal;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import io.airbyte.commons.json.Jsons;
Expand All @@ -19,6 +20,7 @@
import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
Expand All @@ -28,9 +30,10 @@
@ExtendWith(MockitoExtension.class)
class AirbyteMessageTrackerTest {

private static final String STREAM_1 = "stream1";
private static final String STREAM_2 = "stream2";
private static final String STREAM_3 = "stream3";
private static final String NAMESPACE_1 = "avengers";
private static final String STREAM_1 = "iron man";
private static final String STREAM_2 = "black widow";
private static final String STREAM_3 = "hulk";
private static final String INDUCED_EXCEPTION = "induced exception";

private AirbyteMessageTracker messageTracker;
Expand Down Expand Up @@ -277,11 +280,11 @@ void testGetTotalRecordsCommitted_emptyWhenCommitStateHashThrowsException() thro
}

@Test
void testGetFirstDestinationAndSourceMessages() throws Exception {
final AirbyteMessage sourceMessage1 = AirbyteMessageUtils.createTraceMessage("source trace 1", Double.valueOf(123));
final AirbyteMessage sourceMessage2 = AirbyteMessageUtils.createTraceMessage("source trace 2", Double.valueOf(124));
final AirbyteMessage destMessage1 = AirbyteMessageUtils.createTraceMessage("dest trace 1", Double.valueOf(125));
final AirbyteMessage destMessage2 = AirbyteMessageUtils.createTraceMessage("dest trace 2", Double.valueOf(126));
void testGetFirstDestinationAndSourceMessages() {
final AirbyteMessage sourceMessage1 = AirbyteMessageUtils.createErrorMessage("source trace 1", 123.0);
final AirbyteMessage sourceMessage2 = AirbyteMessageUtils.createErrorMessage("source trace 2", 124.0);
final AirbyteMessage destMessage1 = AirbyteMessageUtils.createErrorMessage("dest trace 1", 125.0);
final AirbyteMessage destMessage2 = AirbyteMessageUtils.createErrorMessage("dest trace 2", 126.0);
messageTracker.acceptFromSource(sourceMessage1);
messageTracker.acceptFromSource(sourceMessage2);
messageTracker.acceptFromDestination(destMessage1);
Expand All @@ -292,39 +295,81 @@ void testGetFirstDestinationAndSourceMessages() throws Exception {
}

@Test
void testGetFirstDestinationAndSourceMessagesWithNulls() throws Exception {
assertEquals(messageTracker.getFirstDestinationErrorTraceMessage(), null);
assertEquals(messageTracker.getFirstSourceErrorTraceMessage(), null);
void testGetFirstDestinationAndSourceMessagesWithNulls() {
  // No trace messages have been accepted, so neither side has a first error trace.
  final var firstDestinationError = messageTracker.getFirstDestinationErrorTraceMessage();
  final var firstSourceError = messageTracker.getFirstSourceErrorTraceMessage();

  assertNull(firstDestinationError);
  assertNull(firstSourceError);
}

@Test
void testErrorTraceMessageFailureWithMultipleTraceErrors() throws Exception {
final AirbyteMessage sourceMessage1 = AirbyteMessageUtils.createTraceMessage("source trace 1", Double.valueOf(123));
final AirbyteMessage sourceMessage2 = AirbyteMessageUtils.createTraceMessage("source trace 2", Double.valueOf(124));
final AirbyteMessage destMessage1 = AirbyteMessageUtils.createTraceMessage("dest trace 1", Double.valueOf(125));
final AirbyteMessage destMessage2 = AirbyteMessageUtils.createTraceMessage("dest trace 2", Double.valueOf(126));
void testErrorTraceMessageFailureWithMultipleTraceErrors() {
final AirbyteMessage sourceMessage1 = AirbyteMessageUtils.createErrorMessage("source trace 1", 123.0);
final AirbyteMessage sourceMessage2 = AirbyteMessageUtils.createErrorMessage("source trace 2", 124.0);
final AirbyteMessage destMessage1 = AirbyteMessageUtils.createErrorMessage("dest trace 1", 125.0);
final AirbyteMessage destMessage2 = AirbyteMessageUtils.createErrorMessage("dest trace 2", 126.0);
messageTracker.acceptFromSource(sourceMessage1);
messageTracker.acceptFromSource(sourceMessage2);
messageTracker.acceptFromDestination(destMessage1);
messageTracker.acceptFromDestination(destMessage2);

final FailureReason failureReason = FailureHelper.sourceFailure(sourceMessage1.getTrace(), Long.valueOf(123), 1);
assertEquals(messageTracker.errorTraceMessageFailure(Long.valueOf(123), 1),
assertEquals(messageTracker.errorTraceMessageFailure(123L, 1),
failureReason);
}

@Test
void testErrorTraceMessageFailureWithOneTraceError() throws Exception {
final AirbyteMessage destMessage = AirbyteMessageUtils.createTraceMessage("dest trace 1", Double.valueOf(125));
void testErrorTraceMessageFailureWithOneTraceError() {
final AirbyteMessage destMessage = AirbyteMessageUtils.createErrorMessage("dest trace 1", 125.0);
messageTracker.acceptFromDestination(destMessage);

final FailureReason failureReason = FailureHelper.destinationFailure(destMessage.getTrace(), Long.valueOf(123), 1);
assertEquals(messageTracker.errorTraceMessageFailure(Long.valueOf(123), 1), failureReason);
assertEquals(messageTracker.errorTraceMessageFailure(123L, 1), failureReason);
}

@Test
void testErrorTraceMessageFailureWithNoTraceErrors() throws Exception {
assertEquals(messageTracker.errorTraceMessageFailure(Long.valueOf(123), 1), null);
void testErrorTraceMessageFailureWithNoTraceErrors() {
  // With no error traces accepted there is nothing to build a failure reason from.
  // assertNull replaces assertEquals(actual, null), which had expected/actual reversed and
  // did not match the assertNull style used for the other null checks in this class.
  assertNull(messageTracker.errorTraceMessageFailure(123L, 1));
}

@Nested
class Estimates {

  /**
   * Estimates accepted for two different streams are reported per stream, keyed by the stream's
   * name/namespace pair.
   */
  @Test
  void shouldSaveAndReturnIndividualStreamCountsCorrectly() {
    final var firstStream = new AirbyteStreamNameNamespacePair(STREAM_1, NAMESPACE_1);
    final var secondStream = new AirbyteStreamNameNamespacePair(STREAM_2, NAMESPACE_1);

    messageTracker.acceptFromSource(AirbyteMessageUtils.createEstimateMessage(STREAM_1, NAMESPACE_1, 100L, 10L));
    messageTracker.acceptFromSource(AirbyteMessageUtils.createEstimateMessage(STREAM_2, NAMESPACE_1, 200L, 10L));

    assertEquals(
        Map.of(firstStream, 100L, secondStream, 200L),
        messageTracker.getStreamToEstimatedBytes());
    assertEquals(
        Map.of(firstStream, 10L, secondStream, 10L),
        messageTracker.getStreamToEstimatedRecords());
  }

  /**
   * The total byte and record estimates are the sums of the per-stream estimates.
   */
  @Test
  void shouldSaveAndReturnTotalCountsCorrectly() {
    messageTracker.acceptFromSource(AirbyteMessageUtils.createEstimateMessage(STREAM_1, NAMESPACE_1, 100L, 10L));
    messageTracker.acceptFromSource(AirbyteMessageUtils.createEstimateMessage(STREAM_2, NAMESPACE_1, 200L, 10L));

    assertEquals(300L, messageTracker.getTotalBytesEstimated());
    assertEquals(20L, messageTracker.getTotalRecordsEstimated());
  }

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void setup() throws IOException, WorkerException {
.withConnectionStatus(new AirbyteConnectionStatus().withStatus(AirbyteConnectionStatus.Status.FAILED).withMessage("failed to connect"));
failureStreamFactory = noop -> Lists.newArrayList(failureMessage).stream();

final AirbyteMessage traceMessage = AirbyteMessageUtils.createTraceMessage("some error from the connector", 123.0);
final AirbyteMessage traceMessage = AirbyteMessageUtils.createErrorMessage("some error from the connector", 123.0);
traceMessageStreamFactory = noop -> Lists.newArrayList(traceMessage).stream();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ void testDiscoverSchemaProcessFail() throws Exception {
@Test
void testDiscoverSchemaProcessFailWithTraceMessage() throws Exception {
final AirbyteStreamFactory traceStreamFactory = noop -> Lists.newArrayList(
AirbyteMessageUtils.createTraceMessage("some error from the connector", 123.0)).stream();
AirbyteMessageUtils.createErrorMessage("some error from the connector", 123.0)).stream();

when(process.exitValue()).thenReturn(1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ void testFailureOnNonzeroExitCode() throws InterruptedException, IOException {

@Test
void testFailureOnNonzeroExitCodeWithTraceMessage() throws WorkerException, InterruptedException {
final AirbyteMessage message = AirbyteMessageUtils.createTraceMessage("some error from the connector", 123.0);
final AirbyteMessage message = AirbyteMessageUtils.createErrorMessage("some error from the connector", 123.0);

when(process.getInputStream()).thenReturn(new ByteArrayInputStream(Jsons.serialize(message).getBytes(Charsets.UTF_8)));
when(process.waitFor(anyLong(), any())).thenReturn(true);
Expand Down