Skip to content

Commit

Permalink
chore: add workload labeling in init container (#13454)
Browse files Browse the repository at this point in the history
  • Loading branch information
tryangul committed Aug 9, 2024
1 parent 142184a commit d69775a
Show file tree
Hide file tree
Showing 16 changed files with 141 additions and 66 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.airbyte.workload.launcher.model
package io.airbyte.workers.input

import io.airbyte.config.ResourceRequirements
import io.airbyte.persistence.job.models.ReplicationInput
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package io.airbyte.workload.launcher.pods
package io.airbyte.workers.pod

import io.airbyte.workers.pod.PodLabeler.LabelKeys.AUTO_ID
import io.airbyte.workers.pod.PodLabeler.LabelKeys.MUTEX_KEY
import io.airbyte.workers.pod.PodLabeler.LabelKeys.SWEEPER_LABEL_KEY
import io.airbyte.workers.pod.PodLabeler.LabelKeys.SWEEPER_LABEL_VALUE
import io.airbyte.workers.pod.PodLabeler.LabelKeys.WORKLOAD_ID
import io.airbyte.workers.process.Metadata
import io.airbyte.workers.process.Metadata.CHECK_JOB
import io.airbyte.workers.process.Metadata.DISCOVER_JOB
Expand All @@ -11,19 +16,11 @@ import io.airbyte.workers.process.Metadata.SYNC_JOB
import io.airbyte.workers.process.Metadata.SYNC_STEP_KEY
import io.airbyte.workers.process.Metadata.WRITE_STEP
import io.airbyte.workers.process.ProcessFactory
import io.airbyte.workload.launcher.pods.PodLabeler.LabelKeys.AUTO_ID
import io.airbyte.workload.launcher.pods.PodLabeler.LabelKeys.MUTEX_KEY
import io.airbyte.workload.launcher.pods.PodLabeler.LabelKeys.SWEEPER_LABEL_KEY
import io.airbyte.workload.launcher.pods.PodLabeler.LabelKeys.SWEEPER_LABEL_VALUE
import io.airbyte.workload.launcher.pods.PodLabeler.LabelKeys.WORKLOAD_ID
import jakarta.inject.Named
import jakarta.inject.Singleton
import java.util.UUID

@Singleton
class PodLabeler(
@Named("containerOrchestratorImage") private val orchestratorImageName: String,
) {
class PodLabeler {
fun getSourceLabels(): Map<String, String> {
return mapOf(
SYNC_STEP_KEY to READ_STEP,
Expand All @@ -36,8 +33,8 @@ class PodLabeler(
)
}

fun getReplicationOrchestratorLabels(): Map<String, String> {
return getImageLabels() +
fun getReplicationOrchestratorLabels(orchestratorImageName: String): Map<String, String> {
return getImageLabels(orchestratorImageName) +
mapOf(
JOB_TYPE_KEY to SYNC_JOB,
SYNC_STEP_KEY to ORCHESTRATOR_REPLICATION_STEP,
Expand All @@ -62,7 +59,7 @@ class PodLabeler(
)
}

private fun getImageLabels(): Map<String, String> {
private fun getImageLabels(orchestratorImageName: String): Map<String, String> {
val shortImageName = ProcessFactory.getShortImageName(orchestratorImageName)
val imageVersion = ProcessFactory.getImageVersion(orchestratorImageName)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.airbyte.workers.orchestrator
package io.airbyte.workers.pod

import io.airbyte.workers.process.KubeProcessFactory.KUBE_NAME_LEN_LIMIT
import io.airbyte.workers.process.ProcessFactory
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package io.airbyte.workload.launcher.pods
package io.airbyte.workers.pod

import io.airbyte.workers.pod.PodLabeler.LabelKeys.AUTO_ID
import io.airbyte.workers.pod.PodLabeler.LabelKeys.MUTEX_KEY
import io.airbyte.workers.pod.PodLabeler.LabelKeys.WORKLOAD_ID
import io.airbyte.workers.process.Metadata.CHECK_JOB
import io.airbyte.workers.process.Metadata.DISCOVER_JOB
import io.airbyte.workers.process.Metadata.IMAGE_NAME
Expand All @@ -12,9 +15,6 @@ import io.airbyte.workers.process.Metadata.SYNC_JOB
import io.airbyte.workers.process.Metadata.SYNC_STEP_KEY
import io.airbyte.workers.process.Metadata.WRITE_STEP
import io.airbyte.workers.process.ProcessFactory
import io.airbyte.workload.launcher.pods.PodLabeler.LabelKeys.AUTO_ID
import io.airbyte.workload.launcher.pods.PodLabeler.LabelKeys.MUTEX_KEY
import io.airbyte.workload.launcher.pods.PodLabeler.LabelKeys.WORKLOAD_ID
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.Arguments
Expand All @@ -25,7 +25,7 @@ import java.util.stream.Stream
class PodLabelerTest {
@Test
fun getSourceLabels() {
val labeler = PodLabeler(ORCHESTRATOR_IMAGE_NAME)
val labeler = PodLabeler()
val result = labeler.getSourceLabels()

assert(
Expand All @@ -38,7 +38,7 @@ class PodLabelerTest {

@Test
fun getDestinationLabels() {
val labeler = PodLabeler(ORCHESTRATOR_IMAGE_NAME)
val labeler = PodLabeler()
val result = labeler.getDestinationLabels()

assert(
Expand All @@ -51,8 +51,8 @@ class PodLabelerTest {

@Test
fun getReplicationOrchestratorLabels() {
val labeler = PodLabeler(ORCHESTRATOR_IMAGE_NAME)
val result = labeler.getReplicationOrchestratorLabels()
val labeler = PodLabeler()
val result = labeler.getReplicationOrchestratorLabels(ORCHESTRATOR_IMAGE_NAME)
val shortImageName = ProcessFactory.getShortImageName(ORCHESTRATOR_IMAGE_NAME)
val imageVersion = ProcessFactory.getImageVersion(ORCHESTRATOR_IMAGE_NAME)

Expand All @@ -69,7 +69,7 @@ class PodLabelerTest {

@Test
fun getCheckLabels() {
val labeler = PodLabeler(ORCHESTRATOR_IMAGE_NAME)
val labeler = PodLabeler()
val result = labeler.getCheckLabels()

assert(
Expand All @@ -82,7 +82,7 @@ class PodLabelerTest {

@Test
fun getDiscoverLabels() {
val labeler = PodLabeler(ORCHESTRATOR_IMAGE_NAME)
val labeler = PodLabeler()
val result = labeler.getDiscoverLabels()

assert(
Expand All @@ -95,7 +95,7 @@ class PodLabelerTest {

@Test
fun getSpecLabels() {
val labeler = PodLabeler(ORCHESTRATOR_IMAGE_NAME)
val labeler = PodLabeler()
val result = labeler.getSpecLabels()

assert(
Expand All @@ -109,7 +109,7 @@ class PodLabelerTest {
@ParameterizedTest
@MethodSource("randomStringMatrix")
fun getWorkloadLabels(workloadId: String) {
val labeler = PodLabeler(ORCHESTRATOR_IMAGE_NAME)
val labeler = PodLabeler()
val result = labeler.getWorkloadLabels(workloadId)

assert(
Expand All @@ -123,7 +123,7 @@ class PodLabelerTest {
@ParameterizedTest
@MethodSource("randomStringMatrix")
fun getMutexLabels(key: String) {
val labeler = PodLabeler(ORCHESTRATOR_IMAGE_NAME)
val labeler = PodLabeler()
val result = labeler.getMutexLabels(key)

assert(
Expand All @@ -136,7 +136,7 @@ class PodLabelerTest {

@Test
fun getAutoIdLabels() {
val labeler = PodLabeler(ORCHESTRATOR_IMAGE_NAME)
val labeler = PodLabeler()
val id = UUID.randomUUID()
val result = labeler.getAutoIdLabels(id)

Expand All @@ -156,7 +156,7 @@ class PodLabelerTest {
passThroughLabels: Map<String, String>,
autoId: UUID,
) {
val labeler = PodLabeler(ORCHESTRATOR_IMAGE_NAME)
val labeler = PodLabeler()
val result = labeler.getSharedLabels(workloadId, mutexKey, passThroughLabels, autoId)

assert(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ enum class EnvVar {
METRIC_CLIENT,
MINIO_ENDPOINT,

OPERATION_TYPE,
OTEL_COLLECTOR_ENDPOINT,

PATH_TO_CONNECTORS,
Expand Down
11 changes: 5 additions & 6 deletions airbyte-workload-init-container/src/main/kotlin/InputFetcher.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.airbyte.initContainer

import io.airbyte.config.FailureReason.FailureOrigin
import io.airbyte.initContainer.input.ReplicationHydrationProcessor
import io.airbyte.initContainer.input.InputHydrationProcessor
import io.airbyte.initContainer.system.SystemClient
import io.airbyte.workload.api.client.WorkloadApiClient
import io.airbyte.workload.api.client.model.generated.WorkloadFailureRequest
Expand All @@ -13,20 +13,19 @@ private val logger = KotlinLogging.logger {}
@Singleton
class InputFetcher(
private val workloadApiClient: WorkloadApiClient,
private val replInputProcessor: ReplicationHydrationProcessor,
private val hydrationProcessor: InputHydrationProcessor,
private val systemClient: SystemClient,
) {
fun fetch(workloadId: String) {
val rawPayload =
val workload =
try {
val workload = workloadApiClient.workloadApi.workloadGet(workloadId)
workload.inputPayload
workloadApiClient.workloadApi.workloadGet(workloadId)
} catch (e: Exception) {
return failWorkloadAndExit(workloadId, "fetching workload", e)
}

try {
replInputProcessor.process(rawPayload)
hydrationProcessor.process(workload)
} catch (e: Exception) {
return failWorkloadAndExit(workloadId, "processing workload", e)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package io.airbyte.initContainer.input

import io.airbyte.workload.api.client.model.generated.Workload

interface InputHydrationProcessor {
fun process(workload: Workload)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,51 @@ package io.airbyte.initContainer.input
import io.airbyte.initContainer.system.FileClient
import io.airbyte.persistence.job.models.ReplicationInput
import io.airbyte.workers.ReplicationInputHydrator
import io.airbyte.workers.input.setDestinationLabels
import io.airbyte.workers.input.setSourceLabels
import io.airbyte.workers.models.ReplicationActivityInput
import io.airbyte.workers.pod.PodLabeler
import io.airbyte.workers.serde.ObjectSerializer
import io.airbyte.workers.serde.PayloadDeserializer
import io.airbyte.workers.sync.OrchestratorConstants
import io.airbyte.workload.api.client.model.generated.Workload
import io.micronaut.context.annotation.Requires
import jakarta.inject.Singleton

/**
* Parses, hydrates and writes input files for the replication orchestrator.
*/
@Requires(property = "airbyte.init.operation", pattern = "sync")
@Singleton
class ReplicationHydrationProcessor(
private val replicationInputHydrator: ReplicationInputHydrator,
private val deserializer: PayloadDeserializer,
private val serializer: ObjectSerializer,
private val fileClient: FileClient,
) {
fun process(rawPayload: String) {
private val labeler: PodLabeler,
) : InputHydrationProcessor {
override fun process(workload: Workload) {
val rawPayload = workload.inputPayload
val parsed: ReplicationActivityInput = deserializer.toReplicationActivityInput(rawPayload)

val hydrated: ReplicationInput = replicationInputHydrator.getHydratedReplicationInput(parsed)

val labels =
labeler.getSharedLabels(
workloadId = workload.id,
mutexKey = workload.mutexKey,
autoId = workload.autoId,
passThroughLabels = workload.labels.associate { it.key to it.value },
)

val inputWithLabels =
hydrated
.setSourceLabels(labels)
.setDestinationLabels(labels)

fileClient.writeInputFile(
OrchestratorConstants.INIT_FILE_INPUT,
serializer.serialize(hydrated),
serializer.serialize(inputWithLabels),
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ airbyte:
client: ${FEATURE_FLAG_CLIENT:}
path: ${FEATURE_FLAG_PATH:/flags}
api-key: ${LAUNCHDARKLY_KEY:}
init:
operation: ${OPERATION_TYPE:}
internal-api:
auth-header:
name: ${AIRBYTE_API_AUTH_HEADER_NAME:}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.airbyte.initContainer

import io.airbyte.initContainer.InputFetcherTest.Fixtures.WORKLOAD_ID
import io.airbyte.initContainer.InputFetcherTest.Fixtures.workload
import io.airbyte.initContainer.input.ReplicationHydrationProcessor
import io.airbyte.initContainer.input.InputHydrationProcessor
import io.airbyte.initContainer.system.SystemClient
import io.airbyte.workload.api.client.WorkloadApiClient
import io.airbyte.workload.api.client.model.generated.Workload
Expand All @@ -22,7 +22,7 @@ class InputFetcherTest {
lateinit var workloadApiClient: WorkloadApiClient

@MockK
lateinit var inputProcessor: ReplicationHydrationProcessor
lateinit var inputProcessor: InputHydrationProcessor

@MockK
lateinit var systemClient: SystemClient
Expand All @@ -42,12 +42,12 @@ class InputFetcherTest {
@Test
fun `fetches input and processes it`() {
every { workloadApiClient.workloadApi.workloadGet(WORKLOAD_ID) } returns workload
every { inputProcessor.process(workload.inputPayload) } returns Unit
every { inputProcessor.process(workload) } returns Unit

fetcher.fetch(WORKLOAD_ID)

verify { workloadApiClient.workloadApi.workloadGet(WORKLOAD_ID) }
verify { inputProcessor.process(workload.inputPayload) }
verify { inputProcessor.process(workload) }
}

@Test
Expand All @@ -65,7 +65,7 @@ class InputFetcherTest {
@Test
fun `fails workload on workload process error`() {
every { workloadApiClient.workloadApi.workloadGet(WORKLOAD_ID) } returns workload
every { inputProcessor.process(workload.inputPayload) } throws Exception("bang")
every { inputProcessor.process(workload) } throws Exception("bang")
every { workloadApiClient.workloadApi.workloadFailure(any()) } returns Unit
every { systemClient.exitProcess(1) } returns Unit

Expand Down
Loading

0 comments on commit d69775a

Please sign in to comment.