Skip to content

Commit

Permalink
mongo test
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaohansong committed May 10, 2024
1 parent 4eace2c commit 71844fa
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,32 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {

protected abstract fun assertExpectedStateMessages(stateMessages: List<AirbyteStateMessage>)

protected open fun assertStreamStatusTraceMessageIndex(idx: Int, allMessages: List<AirbyteMessage>, expectedStreamStatus: AirbyteStreamStatusTraceMessage) {
protected open fun assertStreamStatusTraceMessageIndex(
idx: Int,
allMessages: List<AirbyteMessage>,
expectedStreamStatus: AirbyteStreamStatusTraceMessage
) {
var actualMessage = allMessages[idx]
Assertions.assertEquals(actualMessage.type, AirbyteMessage.Type.TRACE)
var traceMessage = actualMessage.trace
Assertions.assertNotNull(traceMessage.streamStatus)
Assertions.assertEquals(expectedStreamStatus, traceMessage.streamStatus)
}

private fun createAirbteStreanStatusTraceMessage(namespace: String, streamName: String, status:AirbyteStreamStatusTraceMessage.AirbyteStreamStatus) : AirbyteStreamStatusTraceMessage {

return AirbyteStreamStatusTraceMessage().withStreamDescriptor(io.airbyte.protocol.models.StreamDescriptor().withNamespace(namespace).withName(
streamName)).withStatus(status)
private fun createAirbteStreanStatusTraceMessage(
namespace: String,
streamName: String,
status: AirbyteStreamStatusTraceMessage.AirbyteStreamStatus
): AirbyteStreamStatusTraceMessage {

return AirbyteStreamStatusTraceMessage()
.withStreamDescriptor(
io.airbyte.protocol.models
.StreamDescriptor()
.withNamespace(namespace)
.withName(streamName)
)
.withStatus(status)
}

protected open fun assertExpectedStateMessagesForFullRefresh(
Expand Down Expand Up @@ -325,7 +339,9 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
.collect(Collectors.toList())
}

protected fun extractTraceMessages(messages: List<AirbyteMessage>): MutableList<io.airbyte.protocol.models.v0.AirbyteTraceMessage>? {
protected fun extractTraceMessages(
messages: List<AirbyteMessage>
): MutableList<io.airbyte.protocol.models.v0.AirbyteTraceMessage>? {
return messages
.stream()
.filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.TRACE }
Expand Down Expand Up @@ -404,10 +420,24 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
val recordMessages = extractRecordMessages(actualRecords)
val stateMessages = extractStateMessages(actualRecords)

assertStreamStatusTraceMessageIndex(0, actualRecords,
createAirbteStreanStatusTraceMessage(modelsSchema(), MODELS_STREAM_NAME, AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED))
assertStreamStatusTraceMessageIndex(actualRecords.size - 1 , actualRecords,
createAirbteStreanStatusTraceMessage(modelsSchema(), MODELS_STREAM_NAME, AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE))
assertStreamStatusTraceMessageIndex(
0,
actualRecords,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)
assertStreamStatusTraceMessageIndex(
actualRecords.size - 1,
actualRecords,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)

Assertions.assertNotNull(targetPosition)
recordMessages.forEach(
Expand Down Expand Up @@ -470,10 +500,24 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
val stateMessages1 = extractStateMessages(actualRecords1)
assertExpectedStateMessages(stateMessages1)

assertStreamStatusTraceMessageIndex(0, actualRecords1,
createAirbteStreanStatusTraceMessage(modelsSchema(), MODELS_STREAM_NAME, AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED))
assertStreamStatusTraceMessageIndex(actualRecords1.size - 1 , actualRecords1,
createAirbteStreanStatusTraceMessage(modelsSchema(), MODELS_STREAM_NAME, AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE))
assertStreamStatusTraceMessageIndex(
0,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)
assertStreamStatusTraceMessageIndex(
actualRecords1.size - 1,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)

updateCommand(MODELS_STREAM_NAME, COL_MODEL, updatedModel, COL_ID, 11)
waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, 1)
Expand Down Expand Up @@ -648,14 +692,28 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
val read1 = source().read(config()!!, configuredCatalog, null)
val actualRecords1 = AutoCloseableIterators.toListAndClose(read1)


// The first message will be start of the full refresh stream.
// The last message will be the end of the incremental stream.
// Index start of the incremental stream will be depending on if connector supports resumeable full refresh.
assertStreamStatusTraceMessageIndex(0, actualRecords1,
createAirbteStreanStatusTraceMessage(modelsSchema(), MODELS_STREAM_NAME_2, AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED))
assertStreamStatusTraceMessageIndex(actualRecords1.size - 1 , actualRecords1,
createAirbteStreanStatusTraceMessage(modelsSchema(), MODELS_STREAM_NAME, AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE))
// Index start of the incremental stream will be depending on if connector supports
// resumeable full refresh.
assertStreamStatusTraceMessageIndex(
0,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME_2,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)
assertStreamStatusTraceMessageIndex(
actualRecords1.size - 1,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)

val recordMessages1 = extractRecordMessages(actualRecords1)
val stateMessages1 = extractStateMessages(actualRecords1)
Expand All @@ -680,10 +738,24 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
modelsSchema(),
)

assertStreamStatusTraceMessageIndex(MODEL_RECORDS_2.size, actualRecords1,
createAirbteStreanStatusTraceMessage(modelsSchema(), MODELS_STREAM_NAME_2, AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE))
assertStreamStatusTraceMessageIndex(MODEL_RECORDS_2.size + 1, actualRecords1,
createAirbteStreanStatusTraceMessage(modelsSchema(), MODELS_STREAM_NAME, AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED))
assertStreamStatusTraceMessageIndex(
MODEL_RECORDS_2.size,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME_2,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)
assertStreamStatusTraceMessageIndex(
MODEL_RECORDS_2.size + 1,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)

val state = Jsons.jsonNode(listOf(stateMessages1[stateMessages1.size - 1]))
val read2 = source().read(config()!!, configuredCatalog, state)
Expand All @@ -709,10 +781,24 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
)

// Expect state and record message from MODEL_RECORDS_2.
assertStreamStatusTraceMessageIndex(2 * MODEL_RECORDS_2.size, actualRecords1,
createAirbteStreanStatusTraceMessage(modelsSchema(), MODELS_STREAM_NAME_2, AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE))
assertStreamStatusTraceMessageIndex(2 * MODEL_RECORDS_2.size + 1, actualRecords1,
createAirbteStreanStatusTraceMessage(modelsSchema(), MODELS_STREAM_NAME, AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED))
assertStreamStatusTraceMessageIndex(
2 * MODEL_RECORDS_2.size,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME_2,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)
assertStreamStatusTraceMessageIndex(
2 * MODEL_RECORDS_2.size + 1,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)

assertExpectedRecords(
Streams.concat(MODEL_RECORDS_2.stream(), MODEL_RECORDS.stream())
Expand Down Expand Up @@ -895,10 +981,24 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
assertExpectedStateMessagesForNoData(stateMessages)
assertExpectedStateMessageCountMatches(stateMessages, 0)

assertStreamStatusTraceMessageIndex(0, actualRecords,
createAirbteStreanStatusTraceMessage(modelsSchema(), MODELS_STREAM_NAME, AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED))
assertStreamStatusTraceMessageIndex(1 , actualRecords,
createAirbteStreanStatusTraceMessage(modelsSchema(), MODELS_STREAM_NAME, AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE))
assertStreamStatusTraceMessageIndex(
0,
actualRecords,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)
assertStreamStatusTraceMessageIndex(
1,
actualRecords,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)
}

protected open fun assertExpectedStateMessagesForNoData(
Expand Down
Loading

0 comments on commit 71844fa

Please sign in to comment.