diff --git a/okio/api/okio.api b/okio/api/okio.api index 10c4f049e1..597369de2a 100644 --- a/okio/api/okio.api +++ b/okio/api/okio.api @@ -772,6 +772,7 @@ public class okio/Timeout { public static final field NONE Lokio/Timeout; public fun ()V public final fun awaitSignal (Ljava/util/concurrent/locks/Condition;)V + public fun cancel ()V public fun clearDeadline ()Lokio/Timeout; public fun clearTimeout ()Lokio/Timeout; public final fun deadline (JLjava/util/concurrent/TimeUnit;)Lokio/Timeout; diff --git a/okio/src/jvmMain/kotlin/okio/Timeout.kt b/okio/src/jvmMain/kotlin/okio/Timeout.kt index 956612e0aa..5fb88ce47e 100644 --- a/okio/src/jvmMain/kotlin/okio/Timeout.kt +++ b/okio/src/jvmMain/kotlin/okio/Timeout.kt @@ -19,6 +19,7 @@ import java.io.IOException import java.io.InterruptedIOException import java.util.concurrent.TimeUnit import java.util.concurrent.locks.Condition +import kotlin.concurrent.Volatile import kotlin.time.Duration import kotlin.time.DurationUnit import kotlin.time.toTimeUnit @@ -32,6 +33,12 @@ actual open class Timeout { private var deadlineNanoTime = 0L private var timeoutNanos = 0L + /** + * A sentinel that is updated to a new object on each call to [cancel]. Sample this property + * before and after an operation to test if the timeout was canceled during the operation. + */ + @Volatile private var cancelMark: Any? = null + /** * Wait at most `timeout` time before aborting an operation. Using a per-operation timeout means * that as long as forward progress is being made, no sequence of operations will fail. @@ -107,6 +114,20 @@ actual open class Timeout { } } + /** + * Prevent all current applications of this timeout from firing. Use this when a time-limited + * operation should no longer be time-limited because the nature of the operation has changed. + * + * This function does not mutate the [deadlineNanoTime] or [timeoutNanos] properties of this + * timeout. It only applies to active operations that are limited by this timeout, and applies by + * allowing those operations to run indefinitely. + * + * Subclasses that override this method must call `super.cancel()`. + */ + open fun cancel() { + cancelMark = Any() + } + /** * Waits on `monitor` until it is signaled. Throws [InterruptedIOException] if either the thread * is interrupted or if this timeout elapses before `monitor` is signaled. @@ -239,18 +260,23 @@ actual open class Timeout { timeoutNanos } - // Attempt to wait that long. This will break out early if the monitor is notified. - var elapsedNanos = 0L - if (waitNanos > 0L) { - val waitMillis = waitNanos / 1000000L - (monitor as Object).wait(waitMillis, (waitNanos - waitMillis * 1000000L).toInt()) - elapsedNanos = System.nanoTime() - start - } + if (waitNanos <= 0) throw InterruptedIOException("timeout") - // Throw if the timeout elapsed before the monitor was notified. - if (elapsedNanos >= waitNanos) { - throw InterruptedIOException("timeout") - } + val cancelMarkBefore = cancelMark + + // Attempt to wait that long. This will return early if the monitor is notified. + val waitMillis = waitNanos / 1000000L + (monitor as Object).wait(waitMillis, (waitNanos - waitMillis * 1000000L).toInt()) + val elapsedNanos = System.nanoTime() - start + + // If there's time remaining, we probably got the call we were waiting for. + if (elapsedNanos < waitNanos) return + + // Return without throwing if this timeout was canceled while we were waiting. Note that this + // return is a 'spurious wakeup' because Object.notify() was not called. + if (cancelMark !== cancelMarkBefore) return + + throw InterruptedIOException("timeout") } catch (e: InterruptedException) { Thread.currentThread().interrupt() // Retain interrupted status. throw InterruptedIOException("interrupted") diff --git a/okio/src/jvmTest/kotlin/okio/WaitUntilNotifiedTest.kt b/okio/src/jvmTest/kotlin/okio/WaitUntilNotifiedTest.kt index 78226151d4..f226a3ff48 100644 --- a/okio/src/jvmTest/kotlin/okio/WaitUntilNotifiedTest.kt +++ b/okio/src/jvmTest/kotlin/okio/WaitUntilNotifiedTest.kt @@ -166,6 +166,49 @@ class WaitUntilNotifiedTest { } } + @Test + @Synchronized + fun cancelBeforeWaitDoesNothing() { + assumeNotWindows() + val timeout = Timeout() + timeout.timeout(1000, TimeUnit.MILLISECONDS) + timeout.cancel() + val start = now() + try { + timeout.waitUntilNotified(this) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + } + assertElapsed(1000.0, start) + } + + @Test + @Synchronized + fun canceledTimeoutDoesNotThrowWhenNotNotifiedOnTime() { + val timeout = Timeout() + timeout.timeout(1000, TimeUnit.MILLISECONDS) + timeout.cancelLater(500) + + val start = now() + timeout.waitUntilNotified(this) // Returns early but doesn't throw. + assertElapsed(1000.0, start) + } + + @Test + @Synchronized + fun multipleCancelsAreIdempotent() { + val timeout = Timeout() + timeout.timeout(1000, TimeUnit.MILLISECONDS) + timeout.cancelLater(250) + timeout.cancelLater(500) + timeout.cancelLater(750) + + val start = now() + timeout.waitUntilNotified(this) // Returns early but doesn't throw. + assertElapsed(1000.0, start) + } + /** Returns the nanotime in milliseconds as a double for measuring timeouts. */ private fun now(): Double { return System.nanoTime() / 1000000.0 @@ -178,4 +221,14 @@ class WaitUntilNotifiedTest { private fun assertElapsed(duration: Double, start: Double) { assertEquals(duration, now() - start - 200.0, 250.0) } + + private fun Timeout.cancelLater(delay: Long) { + executorService.schedule( + { + cancel() + }, + delay, + TimeUnit.MILLISECONDS, + ) + } }