Skip to content

Commit

Permalink
Handle ControlThrowable exception in 'race' function (#216)
Browse files Browse the repository at this point in the history
Co-authored-by: adamw <adam@warski.org>
  • Loading branch information
rcardin and adamw authored Sep 18, 2024
1 parent 8634884 commit 47d490f
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 5 deletions.
39 changes: 34 additions & 5 deletions core/src/main/scala/ox/race.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import java.util.concurrent.ArrayBlockingQueue
import scala.annotation.tailrec
import scala.concurrent.TimeoutException
import scala.concurrent.duration.FiniteDuration
import scala.util.control.{ControlThrowable, NonFatal}
import scala.util.{Failure, Success, Try}

/** A `Some` if the computation `t` took less than `duration`, and `None` otherwise. if the computation `t` throws an exception, it is
Expand Down Expand Up @@ -37,8 +38,22 @@ def race[T](fs: Seq[() => T]): T = race(NoErrorMode)(fs)
*/
def race[E, F[_], T](em: ErrorMode[E, F])(fs: Seq[() => F[T]]): F[T] =
unsupervised {
val result = new ArrayBlockingQueue[Try[F[T]]](fs.size)
fs.foreach(f => forkUnsupervised(result.put(Try(f()))))
val result = new ArrayBlockingQueue[RaceBranchResult[F[T]]](fs.size)
fs.foreach(f =>
forkUnsupervised {
val r =
try RaceBranchResult.Success(f())
catch
case NonFatal(e) => RaceBranchResult.NonFatalException(e)
// #213: we treat ControlThrowables as non-fatal, as in the context of `race` they should count as a
// "failed branch", but not cause immediate interruption
case e: ControlThrowable => RaceBranchResult.NonFatalException(e)
// #213: any other fatal exceptions must cause `race` to be interrupted immediately; this is needed as we
// are in an unsupervised scope, so by default exceptions aren't propagated
case e => RaceBranchResult.FatalException(e)
result.put(r)
}
)

@tailrec
def takeUntilSuccess(failures: Vector[Either[E, Throwable]], left: Int): F[T] =
Expand All @@ -57,10 +72,11 @@ def race[E, F[_], T](em: ErrorMode[E, F])(fs: Seq[() => F[T]]): F[T] =
throw e
else
result.take() match
case Success(v) =>
case RaceBranchResult.Success(v) =>
if em.isError(v) then takeUntilSuccess(failures :+ Left(em.getError(v)), left - 1)
else v
case Failure(e) => takeUntilSuccess(failures :+ Right(e), left - 1)
case RaceBranchResult.NonFatalException(e) => takeUntilSuccess(failures :+ Right(e), left - 1)
case RaceBranchResult.FatalException(e) => throw e

takeUntilSuccess(Vector.empty, fs.size)
}
Expand Down Expand Up @@ -113,7 +129,15 @@ def raceEither[E, T](f1: => Either[E, T], f2: => Either[E, T], f3: => Either[E,
//

/** Returns the result of the first computation to complete (either successfully or with an exception). */
def raceResult[T](fs: Seq[() => T]): T = race(fs.map(f => () => Try(f()))).get // TODO optimize
def raceResult[T](fs: Seq[() => T]): T = race(
fs.map(f =>
() =>
// #213: the Try() constructor doesn't catch fatal exceptions; in this context, we want to propagate *all*
// exceptions as fast as possible
try Success(f())
catch case e: Throwable => Failure(e)
)
).get // TODO optimize

/** Returns the result of the first computation to complete (either successfully or with an exception). */
def raceResult[T](f1: => T, f2: => T): T = raceResult(List(() => f1, () => f2))
Expand All @@ -123,3 +147,8 @@ def raceResult[T](f1: => T, f2: => T, f3: => T): T = raceResult(List(() => f1, (

/** Returns the result of the first computation to complete (either successfully or with an exception). */
def raceResult[T](f1: => T, f2: => T, f3: => T, f4: => T): T = raceResult(List(() => f1, () => f2, () => f3, () => f4))

private enum RaceBranchResult[+T]:
case Success(value: T)
case NonFatalException(throwable: Throwable)
case FatalException(throwable: Throwable)
80 changes: 80 additions & 0 deletions core/src/test/scala/ox/RaceTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import org.scalatest.matchers.should.Matchers
import ox.*
import ox.util.Trail

import java.util.concurrent.atomic.AtomicBoolean
import scala.concurrent.TimeoutException
import scala.concurrent.duration.DurationInt
import scala.util.control.ControlThrowable

class RaceTest extends AnyFlatSpec with Matchers {
"timeout" should "short-circuit a long computation" in {
Expand Down Expand Up @@ -131,6 +133,40 @@ class RaceTest extends AnyFlatSpec with Matchers {
e.getSuppressed.map(_.getMessage).toSet shouldBe Set("boom2!", "boom3!")
}

it should "treat ControlThrowable as a non-fatal exception" in {
try
race(
throw new NastyControlThrowable("boom1!"), {
sleep(200.millis)
throw new NastyControlThrowable("boom2!")
}, {
sleep(200.millis)
throw new NastyControlThrowable("boom3!")
}
)
fail("Race should throw")
catch
case e: Throwable =>
e.getMessage shouldBe "boom1!"
// Suppressed exceptions are not available for ControlThrowable
}

it should "immediately rethrow other fatal exceptions" in {
val flag = new AtomicBoolean(false)
try
race(
throw new StackOverflowError(), {
sleep(1.second)
flag.set(true)
throw new RuntimeException()
}
)
fail("Race should throw")
catch
case e: StackOverflowError => // the expected exception
flag.get() shouldBe false // because a fatal exception was thrown, the second computation should be interrupted
}

"raceEither" should "return the first successful computation to complete" in {
val trail = Trail()
val start = System.currentTimeMillis()
Expand All @@ -155,4 +191,48 @@ class RaceTest extends AnyFlatSpec with Matchers {
trail.get shouldBe Vector("error", "slow")
end - start should be < 1000L
}

"raceResult" should "immediately return when a normal exception occurs" in {
val flag = new AtomicBoolean(false)
try
raceResult(
throw new RuntimeException("boom!"), {
sleep(1.second)
flag.set(true)
}
)
fail("raceResult should throw")
catch
case e: Throwable =>
e.getMessage shouldBe "boom!"
flag.get() shouldBe false
}

it should "immediately return when a control exception occurs" in {
val flag = new AtomicBoolean(false)
try
raceResult(
throw new NastyControlThrowable("boom!"), {
sleep(1.second)
flag.set(true)
}
)
fail("raceResult should throw")
catch case e: NastyControlThrowable => flag.get() shouldBe false
}

it should "immediately return when a fatal exception occurs" in {
val flag = new AtomicBoolean(false)
try
raceResult(
throw new StackOverflowError(), {
sleep(1.second)
flag.set(true)
}
)
fail("raceResult should throw")
catch case e: StackOverflowError => flag.get() shouldBe false
}

class NastyControlThrowable(val message: String) extends ControlThrowable(message) {}
}

0 comments on commit 47d490f

Please sign in to comment.