Skip to content

Commit

Permalink
fix bugs, added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zhukaihan committed Sep 3, 2024
1 parent 8d63520 commit a6814aa
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 12 deletions.
6 changes: 5 additions & 1 deletion src/main/kotlin/flag/FlagConfigUpdater.kt
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,11 @@ internal class FlagConfigFallbackRetryWrapper(
}
} catch (t: Throwable) {
Logger.e("Primary flag configs start failed, start fallback. Error: ", t)
fallbackUpdater?.start()
if (fallbackUpdater == null) {
// No fallback, main start failed is wrapper start fail
throw t
}
fallbackUpdater.start()
scheduleRetry()
}
}
Expand Down
2 changes: 0 additions & 2 deletions src/main/kotlin/util/SseStream.kt
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ internal class SseStream (
}

override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) {
println(t)
println(response)
if ((eventSource != es)) {
// Not the current event source using right now, should cancel.
eventSource.cancel()
Expand Down
51 changes: 51 additions & 0 deletions src/test/kotlin/LocalEvaluationClientTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,26 @@ package com.amplitude.experiment

import com.amplitude.experiment.cohort.Cohort
import com.amplitude.experiment.cohort.CohortApi
import com.amplitude.experiment.flag.FlagConfigPoller
import io.mockk.clearAllMocks
import io.mockk.every
import io.mockk.mockk
import io.mockk.mockkConstructor
import org.junit.Assert
import org.junit.Assert.assertEquals
import org.junit.Assert.assertNull
import kotlin.system.measureNanoTime
import kotlin.test.AfterTest
import kotlin.test.BeforeTest
import kotlin.test.Test

private const val API_KEY = "server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz"

class LocalEvaluationClientTest {
@AfterTest
fun afterTest() {
clearAllMocks()
}

@Test
fun `test evaluate, all flags, success`() {
Expand Down Expand Up @@ -193,6 +202,7 @@ class LocalEvaluationClientTest {
assertEquals("on", userVariant?.key)
assertEquals("on", userVariant?.value)
}

@Test
fun `evaluate with user, cohort tester targeted`() {
val cohortConfig = LocalEvaluationConfig(
Expand Down Expand Up @@ -238,6 +248,7 @@ class LocalEvaluationClientTest {
assertEquals("on", groupVariant?.key)
assertEquals("on", groupVariant?.value)
}

@Test
fun `evaluate with group, cohort tester targeted`() {
val cohortConfig = LocalEvaluationConfig(
Expand All @@ -261,4 +272,44 @@ class LocalEvaluationClientTest {
assertEquals("on", groupVariant?.key)
assertEquals("on", groupVariant?.value)
}

@Test
fun `test evaluate, stream flags, all flags, success`() {
mockkConstructor(FlagConfigPoller::class)
every { anyConstructed<FlagConfigPoller>().start(any()) } answers {
throw Exception("Should use stream, may be flaky test when stream failed")
}
val client = LocalEvaluationClient(API_KEY, LocalEvaluationConfig(streamUpdates = true))
client.start()
val variants = client.evaluate(ExperimentUser(userId = "test_user"))
val variant = variants["sdk-local-evaluation-ci-test"]
Assert.assertEquals(Variant(key = "on", value = "on", payload = "payload"), variant?.copy(metadata = null))
}

@Test
fun `evaluate with user, stream flags, cohort segment targeted`() {
mockkConstructor(FlagConfigPoller::class)
every { anyConstructed<FlagConfigPoller>().start(any()) } answers {
throw Exception("Should use stream, may be flaky test when stream failed")
}
val cohortConfig = LocalEvaluationConfig(
streamUpdates = true,
cohortSyncConfig = CohortSyncConfig("api", "secret")
)
val cohortApi = mockk<CohortApi>().apply {
every { getCohort(eq("52gz3yi7"), allAny()) } returns Cohort("52gz3yi7", "User", 2, 1722363790000, setOf("1", "2"))
every { getCohort(eq("mv7fn2bp"), allAny()) } returns Cohort("mv7fn2bp", "User", 1, 1719350216000, setOf("67890", "12345"))
every { getCohort(eq("s4t57y32"), allAny()) } returns Cohort("s4t57y32", "org name", 1, 1722368285000, setOf("Amplitude Website (Portfolio)"))
every { getCohort(eq("k1lklnnb"), allAny()) } returns Cohort("k1lklnnb", "org id", 1, 1722466388000, setOf("1"))
}
val client = LocalEvaluationClient(API_KEY, cohortConfig, cohortApi = cohortApi)
client.start()
val user = ExperimentUser(
userId = "12345",
deviceId = "device_id",
)
val userVariant = client.evaluateV2(user, setOf("sdk-local-evaluation-user-cohort-ci-test"))["sdk-local-evaluation-user-cohort-ci-test"]
assertEquals("on", userVariant?.key)
assertEquals("on", userVariant?.value)
}
}
5 changes: 5 additions & 0 deletions src/test/kotlin/flag/FlagConfigStreamApiTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ class FlagConfigStreamApiTest {
every { anyConstructed<SseStream>().onError = capture(onErrorCapture) } answers {}
}

@AfterTest
fun afterTest() {
clearAllMocks()
}

private fun setupApi(
deploymentKey: String = "",
serverUrl: HttpUrl = "http://localhost".toHttpUrl(),
Expand Down
70 changes: 62 additions & 8 deletions src/test/kotlin/flag/FlagConfigUpdaterTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.amplitude.experiment.LocalEvaluationConfig
import com.amplitude.experiment.evaluation.EvaluationFlag
import com.amplitude.experiment.util.SseStream
import io.mockk.*
import java.lang.Exception
import kotlin.test.*

private val FLAG1 = EvaluationFlag("key1", emptyMap(), emptyList())
Expand All @@ -13,11 +14,16 @@ class FlagConfigPollerTest {
private var storage = InMemoryFlagConfigStorage()

@BeforeTest
fun beforeEach() {
fun beforeTest() {
fetchApi = mockk<FlagConfigApi>()
storage = InMemoryFlagConfigStorage()
}

@AfterTest
fun afterTest() {
clearAllMocks()
}

@Test
fun `Test Poller`() {
every { fetchApi.getFlagConfigs() } returns emptyList()
Expand Down Expand Up @@ -112,14 +118,19 @@ class FlagConfigStreamerTest {
private val config = LocalEvaluationConfig(streamUpdates = true, streamServerUrl = "", streamFlagConnTimeoutMillis = 2000)

@BeforeTest
fun beforeEach() {
fun beforeTest() {
streamApi = mockk<FlagConfigStreamApi>()
storage = InMemoryFlagConfigStorage()

justRun { streamApi.onUpdate = capture(onUpdateCapture) }
justRun { streamApi.onError = capture(onErrorCapture) }
}

@AfterTest
fun afterTest() {
clearAllMocks()
}

@Test
fun `Test Poller`() {
justRun { streamApi.connect() }
Expand Down Expand Up @@ -161,7 +172,7 @@ class FlagConfigStreamerTest {

@Test
fun `Test Streamer stream fails`(){
every { streamApi.connect() } answers { throw Error("Haha error") }
justRun { streamApi.connect() }
val streamer = FlagConfigStreamer(streamApi, storage, null, null, config)
var errorCount = 0
streamer.start { errorCount++ }
Expand All @@ -173,7 +184,7 @@ class FlagConfigStreamerTest {
assertEquals(0, errorCount)

// Stream fails
onErrorCapture.captured(Error("Haha error"))
onErrorCapture.captured(Exception("Haha error"))
assertEquals(1, errorCount) // Error callback is called
}
}
Expand All @@ -184,8 +195,9 @@ class FlagConfigFallbackRetryWrapperTest {

private var mainUpdater = mockk<FlagConfigUpdater>()
private var fallbackUpdater = mockk<FlagConfigUpdater>()

@BeforeTest
fun beforeEach() {
fun beforeTest() {
mainUpdater = mockk<FlagConfigUpdater>()
fallbackUpdater = mockk<FlagConfigUpdater>()

Expand All @@ -197,6 +209,11 @@ class FlagConfigFallbackRetryWrapperTest {
justRun { fallbackUpdater.shutdown() }
}

@AfterTest
fun afterTest() {
clearAllMocks()
}

@Test
fun `Test FallbackRetryWrapper main success no fallback updater`() {
val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, null, 1000, 0)
Expand Down Expand Up @@ -225,12 +242,15 @@ class FlagConfigFallbackRetryWrapperTest {
every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() }

// Main start fail, no error, same as success case
wrapper.start()
try {
wrapper.start()
fail("Start errors should throw")
} catch (_: Throwable) {}
verify(exactly = 1) { mainUpdater.start(any()) }

// Retries start
// Start errors no retry
Thread.sleep(1100)
verify(exactly = 2) { mainUpdater.start(any()) }
verify(exactly = 1) { mainUpdater.start(any()) }

wrapper.shutdown()
}
Expand Down Expand Up @@ -289,6 +309,29 @@ class FlagConfigFallbackRetryWrapperTest {
verify(exactly = 1) { mainUpdater.shutdown() }
}

@Test
fun `Test FallbackRetryWrapper main and fallback start error`() {
val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, fallbackUpdater, 1000, 0)

every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() }
every { fallbackUpdater.start() } answers { throw Error() }

// Main start fail, no error, same as success case
try {
wrapper.start()
fail("Start errors should throw")
} catch (_: Throwable) {}
verify(exactly = 1) { mainUpdater.start(any()) }
verify(exactly = 1) { fallbackUpdater.start(any()) }

// Start errors no retry
Thread.sleep(1100)
verify(exactly = 1) { mainUpdater.start(any()) }
verify(exactly = 1) { fallbackUpdater.start(any()) }

wrapper.shutdown()
}

@Test
fun `Test FallbackRetryWrapper main start error and retries`() {
val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, fallbackUpdater, 1000, 0)
Expand Down Expand Up @@ -346,4 +389,15 @@ class FlagConfigFallbackRetryWrapperTest {

wrapper.shutdown()
}

@Test
fun `Test FallbackRetryWrapper main updater cannot be FlagConfigFallbackRetryWrapper`() {
val wrapper = FlagConfigFallbackRetryWrapper(FlagConfigFallbackRetryWrapper(mainUpdater, null), null, 1000, 0)
try {
wrapper.start()
fail("Did not throw")
} catch (_: Throwable) {
}
verify(exactly = 0) { mainUpdater.start() }
}
}
7 changes: 6 additions & 1 deletion src/test/kotlin/util/SseStreamTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ import okhttp3.Request
import okhttp3.sse.EventSource
import okhttp3.sse.EventSourceListener
import org.mockito.Mockito
import kotlin.test.AfterTest
import kotlin.test.BeforeTest
import kotlin.test.Test
import kotlin.test.assertEquals


class SseStreamTest {
private val listenerCapture = slot<EventSourceListener>()
val clientMock = mockk<OkHttpClient>()
private val clientMock = mockk<OkHttpClient>()
private val es = mockk<EventSource>("mocked es")

private var data: List<String> = listOf()
Expand All @@ -33,7 +34,11 @@ class SseStreamTest {

mockkConstructor(OkHttpClient.Builder::class)
every { anyConstructed<OkHttpClient.Builder>().build() } returns clientMock
}

@AfterTest
fun afterTest() {
clearAllMocks()
}

private fun setupStream(
Expand Down

0 comments on commit a6814aa

Please sign in to comment.