From 3fb01cbc177d6f8e3f1ecd1c208e0bb766ae2c49 Mon Sep 17 00:00:00 2001 From: Ruslan Shevchenko Date: Mon, 18 Nov 2024 19:54:03 +0200 Subject: [PATCH] unwrap exceptions from CompletableFuture --- .../monads/CompletableFutureCpsMonad.scala | 54 ++++++++++++++----- ...estAsyncExceptionInCompletableFuture.scala | 46 ++++++++++++++++ 2 files changed, 86 insertions(+), 14 deletions(-) create mode 100644 jvm/src/test/scala/cpstest/TestAsyncExceptionInCompletableFuture.scala diff --git a/jvm/src/main/scala/cps/monads/CompletableFutureCpsMonad.scala b/jvm/src/main/scala/cps/monads/CompletableFutureCpsMonad.scala index 02fd325a5..01e140397 100644 --- a/jvm/src/main/scala/cps/monads/CompletableFutureCpsMonad.scala +++ b/jvm/src/main/scala/cps/monads/CompletableFutureCpsMonad.scala @@ -1,10 +1,10 @@ package cps.monads -import cps._ -import java.util.concurrent.CompletableFuture -import scala.util.Try -import scala.util.Failure -import scala.util.Success +import cps.* + +import java.util.concurrent.{CompletableFuture, CompletionException} +import scala.concurrent.Future +import scala.util.{Failure, NotGiven, Success, Try} import scala.util.control.NonFatal @@ -30,7 +30,7 @@ given CompletableFutureCpsMonad: CpsSchedulingMonad[CompletableFuture] with CpsT if (e == null) then f(Success(v.nn)) else - f(Failure(e.nn)) + f(Failure(unwrapCompletableException(e.nn))) }.nn.toCompletableFuture.nn override def flatMapTry[A,B](fa:CompletableFuture[A])(f: Try[A]=>CompletableFuture[B]):CompletableFuture[B] = @@ -42,18 +42,18 @@ given CompletableFutureCpsMonad: CpsSchedulingMonad[CompletableFuture] with CpsT if (e1 == null) then retval.complete(v1.nn) else - retval.completeExceptionally(e1.nn) + retval.completeExceptionally(unwrapCompletableException(e1)) } catch case NonFatal(ex) => - retval.completeExceptionally(ex) + retval.completeExceptionally(unwrapCompletableException(ex)) else try - f(Failure(e.nn)).handle{ (v1,e1) => + f(Failure(unwrapCompletableException(e.nn))).handle{ (v1,e1) => if (e1 == null) then retval.complete(v1.nn) else - retval.completeExceptionally(e1.nn) + retval.completeExceptionally(unwrapCompletableException(e1.nn)) } catch case NonFatal(ex) => @@ -69,11 +69,11 @@ given CompletableFutureCpsMonad: CpsSchedulingMonad[CompletableFuture] with CpsT retval.complete(v.nn) else try - fx(e).handle{ (v1,e1) => + fx(unwrapCompletableException(e)).handle{ (v1,e1) => if (e1 == null) then retval.complete(v1.nn) else - retval.completeExceptionally(e1.nn) + retval.completeExceptionally(unwrapCompletableException(e1.nn)) } catch case NonFatal(ex) => @@ -98,7 +98,7 @@ given CompletableFutureCpsMonad: CpsSchedulingMonad[CompletableFuture] with CpsT if (e == null) r.complete(v) else - r.completeExceptionally(e) + r.completeExceptionally(unwrapCompletableException(e)) } catch case NonFatal(e) => @@ -106,12 +106,38 @@ given CompletableFutureCpsMonad: CpsSchedulingMonad[CompletableFuture] with CpsT } r - def tryCancel[A](op: CompletableFuture[A]): CompletableFuture[Unit] = + def tryCancel[A](op: CompletableFuture[A]): CompletableFuture[Unit] = if (op.cancel(true)) then CompletableFuture.completedFuture(()).nn else CompletableFuture.failedFuture(new IllegalStateException("CompletableFuture is not cancelled")).nn + + private def unwrapCompletableException(ex: Throwable): Throwable = + if (ex.isInstanceOf[CompletionException] && ex.getCause() != null) then + ex.getCause().nn + else + ex.nn + + + given fromCompletableFutureConversion[G[_], T](using CpsAsyncMonad[G], CpsMonadContext[G]): CpsMonadConversion[CompletableFuture, G] with + + def apply[T](ft: CompletableFuture[T]): G[T] = + summon[CpsAsyncMonad[G]].adoptCallbackStyle(listener => + val _unused = ft.whenComplete( + (v, e) => + if (e == null) then + listener(Success(v)) + else + if (e.isInstanceOf[CompletionException] && e.getCause() != null) then + listener(Failure(e.getCause())) + else + listener(Failure(e)) + ) + ) + + } + diff --git a/jvm/src/test/scala/cpstest/TestAsyncExceptionInCompletableFuture.scala b/jvm/src/test/scala/cpstest/TestAsyncExceptionInCompletableFuture.scala new file mode 100644 index 000000000..d39d896f0 --- /dev/null +++ b/jvm/src/test/scala/cpstest/TestAsyncExceptionInCompletableFuture.scala @@ -0,0 +1,46 @@ +package cpstest + +import org.junit.{Ignore, Test} +import org.junit.Assert.* + +import java.util.concurrent.CompletableFuture +import scala.concurrent.* +import scala.concurrent.duration.* +import scala.util.* +import scala.util.control.* +import scala.concurrent.ExecutionContext.Implicits.global + +import cps.* +import cps.monads.{*, given} + +class TestAsyncExceptionInCompletableFuture { + + object X { + + def completableFutureMethod(): CompletableFuture[Int] = { + val cf = new CompletableFuture[Int]() + cf.completeExceptionally(new IllegalStateException("test exception")) + cf + } + + } + + @Test + def testExceptinInCompletableFuture(): Unit = { + val f = async[Future] { + try { + val x = X.completableFutureMethod().await + x + } catch { + case ex: IllegalStateException => + -1 + case NonFatal(ex) => + println(s"unexpected exception: $ex") + throw ex + } + } + val x = Await.ready(f, 1.second) + assert(f.value.get.get == -1) + } + +}