diff --git a/src/main/kotlin/flag/FlagConfigStreamApi.kt b/src/main/kotlin/flag/FlagConfigStreamApi.kt index 38d9c47..6e6e9dd 100644 --- a/src/main/kotlin/flag/FlagConfigStreamApi.kt +++ b/src/main/kotlin/flag/FlagConfigStreamApi.kt @@ -14,6 +14,8 @@ import java.util.concurrent.ExecutionException import java.util.concurrent.TimeUnit import java.util.concurrent.TimeoutException import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock internal open class FlagConfigStreamApiError(message: String?, cause: Throwable?): Exception(message, cause) { constructor(message: String?) : this(message, null) @@ -32,8 +34,9 @@ internal class FlagConfigStreamApi ( httpClient: OkHttpClient = OkHttpClient(), val connectionTimeoutMillis: Long = CONNECTION_TIMEOUT_MILLIS_DEFAULT, keepaliveTimeoutMillis: Long = KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT, - reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT + reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT, ) { + private val lock: ReentrantLock = ReentrantLock() var onInitUpdate: ((List) -> Unit)? = null var onUpdate: ((List) -> Unit)? = null var onError: ((Exception?) -> Unit)? = null @@ -47,84 +50,88 @@ internal class FlagConfigStreamApi ( reconnIntervalMillis) internal fun connect() { - val isInit = AtomicBoolean(true) - val connectTimeoutFuture = CompletableFuture() - val updateTimeoutFuture = CompletableFuture() - stream.onUpdate = { data -> - if (isInit.getAndSet(false)) { - // Stream is establishing. First data received. - // Resolve timeout. - connectTimeoutFuture.complete(Unit) - - // Make sure valid data. - try { - val flags = getFlagsFromData(data) + // Guarded by lock. Update to callbacks and waits can lead to race conditions. + lock.withLock { + val isInit = AtomicBoolean(true) + val connectTimeoutFuture = CompletableFuture() + val updateTimeoutFuture = CompletableFuture() + stream.onUpdate = { data -> + if (isInit.getAndSet(false)) { + // Stream is establishing. First data received. + // Resolve timeout. + connectTimeoutFuture.complete(Unit) + // Make sure valid data. try { - if (onInitUpdate != null) { - onInitUpdate?.let { it(flags) } - } else { - onUpdate?.let { it(flags) } + val flags = getFlagsFromData(data) + + try { + if (onInitUpdate != null) { + onInitUpdate?.let { it(flags) } + } else { + onUpdate?.let { it(flags) } + } + updateTimeoutFuture.complete(Unit) + } catch (e: Throwable) { + updateTimeoutFuture.completeExceptionally(e) } - updateTimeoutFuture.complete(Unit) - } catch (e: Throwable) { - updateTimeoutFuture.completeExceptionally(e) + } catch (_: Throwable) { + updateTimeoutFuture.completeExceptionally(FlagConfigStreamApiDataCorruptError()) } - } catch (_: Throwable) { - updateTimeoutFuture.completeExceptionally(FlagConfigStreamApiDataCorruptError()) - } - - } else { - // Stream has already established. - // Make sure valid data. - try { - val flags = getFlagsFromData(data) + } else { + // Stream has already established. + // Make sure valid data. try { - onUpdate?.let { it(flags) } + val flags = getFlagsFromData(data) + + try { + onUpdate?.let { it(flags) } + } catch (_: Throwable) { + // Don't care about application error. + } } catch (_: Throwable) { - // Don't care about application error. + // Stream corrupted. Reconnect. + handleError(FlagConfigStreamApiDataCorruptError()) } - } catch (_: Throwable) { - // Stream corrupted. Reconnect. - handleError(FlagConfigStreamApiDataCorruptError()) - } + } } - } - stream.onError = { t -> - if (isInit.getAndSet(false)) { - connectTimeoutFuture.completeExceptionally(t) - updateTimeoutFuture.completeExceptionally(t) - } else { - handleError(FlagConfigStreamApiStreamError(t)) + stream.onError = { t -> + if (isInit.getAndSet(false)) { + connectTimeoutFuture.completeExceptionally(t) + updateTimeoutFuture.completeExceptionally(t) + } else { + handleError(FlagConfigStreamApiStreamError(t)) + } } - } - stream.connect() + stream.connect() - val t: Throwable - try { - connectTimeoutFuture.get(connectionTimeoutMillis, TimeUnit.MILLISECONDS) - updateTimeoutFuture.get() - return - } catch (e: TimeoutException) { - // Timeouts should retry - t = FlagConfigStreamApiConnTimeoutError() - } catch (e: ExecutionException) { - val cause = e.cause - t = if (cause is StreamException) { - FlagConfigStreamApiStreamError(cause) - } else { - FlagConfigStreamApiError(e) + val t: Throwable + try { + connectTimeoutFuture.get(connectionTimeoutMillis, TimeUnit.MILLISECONDS) + updateTimeoutFuture.get() + return + } catch (e: TimeoutException) { + // Timeouts should retry + t = FlagConfigStreamApiConnTimeoutError() + } catch (e: ExecutionException) { + val cause = e.cause + t = if (cause is StreamException) { + FlagConfigStreamApiStreamError(cause) + } else { + FlagConfigStreamApiError(e) + } + } catch (e: Throwable) { + t = FlagConfigStreamApiError(e) } - } catch (e: Throwable) { - t = FlagConfigStreamApiError(e) + close() + throw t } - close() - throw t } internal fun close() { + // Not guarded by lock. close() can halt connect(). stream.cancel() } diff --git a/src/main/kotlin/flag/FlagConfigUpdater.kt b/src/main/kotlin/flag/FlagConfigUpdater.kt index b62ce9d..5d1ce5e 100644 --- a/src/main/kotlin/flag/FlagConfigUpdater.kt +++ b/src/main/kotlin/flag/FlagConfigUpdater.kt @@ -14,6 +14,8 @@ import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.Executors import java.util.concurrent.ScheduledFuture import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock import kotlin.math.max import kotlin.math.min @@ -89,42 +91,52 @@ internal class FlagConfigPoller( private val cohortLoader: CohortLoader?, private val cohortStorage: CohortStorage?, private val config: LocalEvaluationConfig, - private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper() + private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper(), ): FlagConfigUpdaterBase( storage, cohortLoader, cohortStorage ) { + private val lock: ReentrantLock = ReentrantLock() private val pool = Executors.newScheduledThreadPool(1, daemonFactory) - private var scheduledFuture: ScheduledFuture<*>? = null + private var scheduledFuture: ScheduledFuture<*>? = null // @GuardedBy(lock) override fun start(onError: (() -> Unit)?) { refresh() - if (scheduledFuture != null) { - stop() + lock.withLock { + stopInternal() + scheduledFuture = pool.scheduleWithFixedDelay( + { + try { + refresh() + } catch (t: Throwable) { + Logger.e("Refresh flag configs failed.", t) + stop() + onError?.invoke() + } + }, + config.flagConfigPollerIntervalMillis, + config.flagConfigPollerIntervalMillis, + TimeUnit.MILLISECONDS + ) } - scheduledFuture = pool.scheduleWithFixedDelay( - { - try { - refresh() - } catch (t: Throwable) { - Logger.e("Refresh flag configs failed.", t) - stop() - onError?.invoke() - } - }, - config.flagConfigPollerIntervalMillis, - config.flagConfigPollerIntervalMillis, - TimeUnit.MILLISECONDS - ) } - override fun stop() { + // @GuardedBy(lock) + private fun stopInternal() { // Pause only stop the task scheduled. It doesn't stop the executor. scheduledFuture?.cancel(true) scheduledFuture = null } + override fun stop() { + lock.withLock { + stopInternal() + } + } + override fun shutdown() { - // Stop the executor. - pool.shutdown() + lock.withLock { + // Stop the executor. + pool.shutdown() + } } private fun refresh() { @@ -151,21 +163,25 @@ internal class FlagConfigStreamer( ): FlagConfigUpdaterBase( storage, cohortLoader, cohortStorage ) { + private val lock: ReentrantLock = ReentrantLock() override fun start(onError: (() -> Unit)?) { - flagConfigStreamApi.onUpdate = {flags -> - update(flags) - } - flagConfigStreamApi.onError = {e -> - Logger.e("Stream flag configs streaming failed.", e) - metrics.onFlagConfigStreamFailure(e) - onError?.invoke() - } - wrapMetrics(metric = metrics::onFlagConfigStream, failure = metrics::onFlagConfigStreamFailure) { - flagConfigStreamApi.connect() + lock.withLock { + flagConfigStreamApi.onUpdate = { flags -> + update(flags) + } + flagConfigStreamApi.onError = { e -> + Logger.e("Stream flag configs streaming failed.", e) + metrics.onFlagConfigStreamFailure(e) + onError?.invoke() + } + wrapMetrics(metric = metrics::onFlagConfigStream, failure = metrics::onFlagConfigStreamFailure) { + flagConfigStreamApi.connect() + } } } override fun stop() { + // Not guarded by lock. close() can cancel start(). flagConfigStreamApi.close() } @@ -178,11 +194,12 @@ internal class FlagConfigFallbackRetryWrapper( private val mainUpdater: FlagConfigUpdater, private val fallbackUpdater: FlagConfigUpdater?, private val retryDelayMillis: Long = RETRY_DELAY_MILLIS_DEFAULT, - private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT + private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT, ): FlagConfigUpdater { + private val lock: ReentrantLock = ReentrantLock() private val reconnIntervalRange = max(0, retryDelayMillis - maxJitterMillis)..(min(retryDelayMillis, retryDelayMillis - maxJitterMillis) + maxJitterMillis) private val executor = Executors.newScheduledThreadPool(1, daemonFactory) - private var retryTask: ScheduledFuture<*>? = null + private var retryTask: ScheduledFuture<*>? = null // @GuardedBy(lock) /** * Since the wrapper retries, so there will never be error case. Thus, onError will never be called. @@ -192,37 +209,49 @@ internal class FlagConfigFallbackRetryWrapper( throw Error("Do not use FlagConfigFallbackRetryWrapper as main updater. Fallback updater will never be used. Rewrite retry and fallback logic.") } - try { - mainUpdater.start { - scheduleRetry() // Don't care if poller start error or not, always retry. - try { - fallbackUpdater?.start() - } catch (_: Throwable) { + lock.withLock { + retryTask?.cancel(true) + + try { + mainUpdater.start { + lock.withLock { + scheduleRetry() // Don't care if poller start error or not, always retry. + try { + fallbackUpdater?.start() + } catch (_: Throwable) { + } + } } + fallbackUpdater?.stop() + } catch (t: Throwable) { + Logger.e("Primary flag configs start failed, start fallback. Error: ", t) + if (fallbackUpdater == null) { + // No fallback, main start failed is wrapper start fail + throw t + } + fallbackUpdater.start() + scheduleRetry() } - } catch (t: Throwable) { - Logger.e("Primary flag configs start failed, start fallback. Error: ", t) - if (fallbackUpdater == null) { - // No fallback, main start failed is wrapper start fail - throw t - } - fallbackUpdater.start() - scheduleRetry() } } override fun stop() { - mainUpdater.stop() - fallbackUpdater?.stop() - retryTask?.cancel(true) + lock.withLock { + mainUpdater.stop() + fallbackUpdater?.stop() + retryTask?.cancel(true) + } } override fun shutdown() { - mainUpdater.shutdown() - fallbackUpdater?.shutdown() - retryTask?.cancel(true) + lock.withLock { + mainUpdater.shutdown() + fallbackUpdater?.shutdown() + retryTask?.cancel(true) + } } + // @GuardedBy(lock) private fun scheduleRetry() { retryTask = executor.schedule({ try { diff --git a/src/main/kotlin/util/SseStream.kt b/src/main/kotlin/util/SseStream.kt index 1224fcf..d80d4dc 100644 --- a/src/main/kotlin/util/SseStream.kt +++ b/src/main/kotlin/util/SseStream.kt @@ -13,7 +13,9 @@ import okhttp3.sse.EventSourceListener import okhttp3.sse.EventSources import java.util.* import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock import kotlin.concurrent.schedule +import kotlin.concurrent.withLock import kotlin.math.max import kotlin.math.min @@ -30,8 +32,9 @@ internal class SseStream ( private val connectionTimeoutMillis: Long, private val keepaliveTimeoutMillis: Long = KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT, private val reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT, - private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT + private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT, ) { + private val lock: ReentrantLock = ReentrantLock() private val reconnIntervalRange = max(0, reconnIntervalMillis - maxJitterMillis)..(min(reconnIntervalMillis, Long.MAX_VALUE - maxJitterMillis) + maxJitterMillis) private val eventSourceListener = object : EventSourceListener() { override fun onOpen(eventSource: EventSource, response: Response) { @@ -39,13 +42,15 @@ internal class SseStream ( } override fun onClosed(eventSource: EventSource) { - if ((eventSource != es)) { - // Not the current event source using right now, should cancel. - eventSource.cancel() - return + lock.withLock { + if ((eventSource != es)) { + // Not the current event source using right now, should cancel. + eventSource.cancel() + return + } } // Server closed the connection, just reconnect. - cancel() + cancelInternal() connect() } @@ -55,10 +60,12 @@ internal class SseStream ( type: String?, data: String ) { - if ((eventSource != es)) { - // Not the current event source using right now, should cancel. - eventSource.cancel() - return + lock.withLock { + if ((eventSource != es)) { + // Not the current event source using right now, should cancel. + eventSource.cancel() + return + } } // Keep alive data if (KEEP_ALIVE_DATA == data) { @@ -68,10 +75,12 @@ internal class SseStream ( } override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) { - if ((eventSource != es)) { - // Not the current event source using right now, should cancel. - eventSource.cancel() - return + lock.withLock { + if ((eventSource != es)) { + // Not the current event source using right now, should cancel. + eventSource.cancel() + return + } } if (t is StreamResetException && t.errorCode == ErrorCode.CANCEL) { // Relying on okhttp3.internal to differentiate cancel case. @@ -99,8 +108,8 @@ internal class SseStream ( .retryOnConnectionFailure(false) .build() - private var es: EventSource? = null - private var reconnectTimerTask: TimerTask? = null + private var es: EventSource? = null // @GuardedBy(lock) + private var reconnectTimerTask: TimerTask? = null // @GuardedBy(lock) internal var onUpdate: ((String) -> Unit)? = null internal var onError: ((Throwable?) -> Unit)? = null @@ -108,20 +117,29 @@ internal class SseStream ( * Creates an event source and immediately returns. The connection is performed async. Errors are informed through callbacks. */ internal fun connect() { - cancel() // Clear any existing event sources. - es = client.newEventSource(request, eventSourceListener) - reconnectTimerTask = Timer().schedule(reconnIntervalRange.random()) {// Timer for a new event source. - // This forces client side reconnection after interval. - this@SseStream.cancel() - connect() + lock.withLock { + cancelInternal() // Clear any existing event sources. + es = client.newEventSource(request, eventSourceListener) + reconnectTimerTask = Timer().schedule(reconnIntervalRange.random()) {// Timer for a new event source. + // This forces client side reconnection after interval. + this@SseStream.cancel() + connect() + } } } - internal fun cancel() { + // @GuardedBy(lock) + private fun cancelInternal() { reconnectTimerTask?.cancel() // There can be cases where an event source is being cancelled by these calls, but take a long time and made a callback to onFailure callback. es?.cancel() es = null } + + internal fun cancel() { + lock.withLock { + cancelInternal() + } + } } \ No newline at end of file diff --git a/src/test/kotlin/flag/FlagConfigUpdaterTest.kt b/src/test/kotlin/flag/FlagConfigUpdaterTest.kt index 4ee0b15..042d74e 100644 --- a/src/test/kotlin/flag/FlagConfigUpdaterTest.kt +++ b/src/test/kotlin/flag/FlagConfigUpdaterTest.kt @@ -292,16 +292,18 @@ class FlagConfigFallbackRetryWrapperTest { wrapper.start() verify(exactly = 1) { mainUpdater.start(any()) } verify(exactly = 0) { fallbackUpdater.start() } + verify(exactly = 1) { fallbackUpdater.stop() } // Stop wrapper.stop() verify(exactly = 1) { mainUpdater.stop() } - verify(exactly = 1) { fallbackUpdater.stop() } + verify(exactly = 2) { fallbackUpdater.stop() } // Start again wrapper.start() verify(exactly = 2) { mainUpdater.start(any()) } verify(exactly = 0) { fallbackUpdater.start() } + verify(exactly = 3) { fallbackUpdater.stop() } // Shutdown wrapper.shutdown() @@ -373,19 +375,19 @@ class FlagConfigFallbackRetryWrapperTest { // Retry success justRun { mainUpdater.start(capture(mainOnErrorCapture)) } - verify(exactly = 0) { fallbackUpdater.stop() } + verify(exactly = 1) { fallbackUpdater.stop() } Thread.sleep(1100) verify(exactly = 3) { mainUpdater.start(any()) } verify(exactly = 1) { fallbackUpdater.start(any()) } verify(exactly = 0) { mainUpdater.stop() } - verify(exactly = 1) { fallbackUpdater.stop() } + verify(exactly = 2) { fallbackUpdater.stop() } // No more start Thread.sleep(1100) verify(exactly = 3) { mainUpdater.start(any()) } verify(exactly = 1) { fallbackUpdater.start(any()) } verify(exactly = 0) { mainUpdater.stop() } - verify(exactly = 1) { fallbackUpdater.stop() } + verify(exactly = 2) { fallbackUpdater.stop() } wrapper.shutdown() }