diff --git a/java-runtime/src/main/scala/client/Fs2ClientCall.scala b/java-runtime/src/main/scala/client/Fs2ClientCall.scala index 280e60e5..08544993 100644 --- a/java-runtime/src/main/scala/client/Fs2ClientCall.scala +++ b/java-runtime/src/main/scala/client/Fs2ClientCall.scala @@ -10,7 +10,14 @@ import fs2._ final case class UnaryResult[A](value: Option[A], status: Option[GrpcStatus]) final case class GrpcStatus(status: Status, trailers: Metadata) -class Fs2ClientCall[F[_], Request, Response] private[client] (val call: ClientCall[Request, Response]) extends AnyVal { +class Fs2ClientCall[F[_], Request, Response] private[client] ( + call: ClientCall[Request, Response], + errorAdapter: StatusRuntimeException => Option[Exception] +) { + + private val ea: PartialFunction[Throwable, Throwable] = { + case e: StatusRuntimeException => errorAdapter(e).getOrElse(e) + } private def cancel(message: Option[String], cause: Option[Throwable])(implicit F: Sync[F]): F[Unit] = F.delay(call.cancel(message.orNull, cause.orNull)) @@ -27,65 +34,57 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (val call: ClientCa private def start(listener: ClientCall.Listener[Response], metadata: Metadata)(implicit F: Sync[F]): F[Unit] = F.delay(call.start(listener, metadata)) - def startListener[A <: ClientCall.Listener[Response]](createListener: F[A], headers: Metadata)(implicit F: Sync[F]): F[A] = { + def startListener[A <: ClientCall.Listener[Response]](createListener: F[A], headers: Metadata)(implicit F: Sync[F]): F[A] = createListener.flatTap(start(_, headers)) <* request(1) - } - def sendSingleMessage(message: Request)(implicit F: Sync[F]): F[Unit] = { + def sendSingleMessage(message: Request)(implicit F: Sync[F]): F[Unit] = sendMessage(message) *> halfClose - } - def sendStream(stream: Stream[F, Request])(implicit F: Sync[F]): Stream[F, Unit] = { + def sendStream(stream: Stream[F, Request])(implicit F: Sync[F]): Stream[F, Unit] = stream.evalMap(sendMessage) ++ Stream.eval(halfClose) - } - def handleCallError( - implicit F: ConcurrentEffect[F]): (ClientCall.Listener[Response], ExitCase[Throwable]) => F[Unit] = { + def handleCallError(implicit F: Sync[F]): (ClientCall.Listener[Response], ExitCase[Throwable]) => F[Unit] = { case (_, ExitCase.Completed) => F.unit case (_, ExitCase.Canceled) => cancel("call was cancelled".some, None) case (_, ExitCase.Error(t)) => cancel(t.getMessage.some, t.some) } - def unaryToUnaryCall(message: Request, headers: Metadata)(implicit F: ConcurrentEffect[F]): F[Response] = { - F.bracketCase(startListener(Fs2UnaryClientCallListener[F, Response], headers))({ listener => - sendSingleMessage(message) *> listener.getValue - })(handleCallError) - } + def unaryToUnaryCall(message: Request, headers: Metadata)(implicit F: ConcurrentEffect[F]): F[Response] = + F.bracketCase(startListener(Fs2UnaryClientCallListener[F, Response], headers))( + l => sendSingleMessage(message) *> l.getValue.adaptError(ea) + )(handleCallError) - def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata)( - implicit F: ConcurrentEffect[F]): F[Response] = { - F.bracketCase(startListener(Fs2UnaryClientCallListener[F, Response], headers))({ listener => - Stream.eval(listener.getValue).concurrently(sendStream(messages)).compile.lastOrError - })(handleCallError) - } + def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata)(implicit F: ConcurrentEffect[F]): F[Response] = + F.bracketCase(startListener(Fs2UnaryClientCallListener[F, Response], headers))( + l => Stream.eval(l.getValue.adaptError(ea)).concurrently(sendStream(messages)).compile.lastOrError + )(handleCallError) - def unaryToStreamingCall(message: Request, headers: Metadata)( - implicit F: ConcurrentEffect[F]): Stream[F, Response] = { + def unaryToStreamingCall(message: Request, headers: Metadata)(implicit F: ConcurrentEffect[F]): Stream[F, Response] = Stream .bracketCase(startListener(Fs2StreamClientCallListener[F, Response](call.request), headers))(handleCallError) - .flatMap(Stream.eval_(sendSingleMessage(message)) ++ _.stream) - } + .flatMap(Stream.eval_(sendSingleMessage(message)) ++ _.stream.adaptError(ea)) - def streamingToStreamingCall(messages: Stream[F, Request], headers: Metadata)( - implicit F: ConcurrentEffect[F]): Stream[F, Response] = { + def streamingToStreamingCall(messages: Stream[F, Request], headers: Metadata)(implicit F: ConcurrentEffect[F]): Stream[F, Response] = Stream .bracketCase(startListener(Fs2StreamClientCallListener[F, Response](call.request), headers))(handleCallError) - .flatMap(_.stream.concurrently(sendStream(messages))) - } + .flatMap(_.stream.adaptError(ea).concurrently(sendStream(messages))) } object Fs2ClientCall { + def apply[F[_]]: PartiallyAppliedClientCall[F] = new PartiallyAppliedClientCall[F] + class PartiallyAppliedClientCall[F[_]](val dummy: Boolean = false) extends AnyVal { def apply[Request, Response]( channel: Channel, methodDescriptor: MethodDescriptor[Request, Response], - callOptions: CallOptions)(implicit F: Sync[F]): F[Fs2ClientCall[F, Request, Response]] = - F.delay(new Fs2ClientCall(channel.newCall[Request, Response](methodDescriptor, callOptions))) + callOptions: CallOptions, + errorAdapter: StatusRuntimeException => Option[Exception] = _ => None + )(implicit F: Sync[F]): F[Fs2ClientCall[F, Request, Response]] = { + F.delay(new Fs2ClientCall(channel.newCall[Request, Response](methodDescriptor, callOptions), errorAdapter)) + } } - def apply[F[_]]: PartiallyAppliedClientCall[F] = - new PartiallyAppliedClientCall[F] } diff --git a/java-runtime/src/test/scala/client/ClientSuite.scala b/java-runtime/src/test/scala/client/ClientSuite.scala index e337699f..c07de2fb 100644 --- a/java-runtime/src/test/scala/client/ClientSuite.scala +++ b/java-runtime/src/test/scala/client/ClientSuite.scala @@ -2,25 +2,28 @@ package org.lyranthe.fs2_grpc package java_runtime package client +import cats.implicits._ import cats.effect.{ContextShift, IO, Timer} import cats.effect.laws.util.TestContext import fs2._ import io.grpc._ import minitest._ - import scala.concurrent.TimeoutException import scala.concurrent.duration._ import scala.util.Success object ClientSuite extends SimpleTestSuite { + def fs2ClientCall(dummy: DummyClientCall) = + new Fs2ClientCall[IO, String, Int](dummy, _ => None) + test("single message to unaryToUnary") { implicit val ec: TestContext = TestContext() implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = fs2ClientCall(dummy) val result = client.unaryToUnaryCall("hello", new Metadata()).unsafeToFuture() dummy.listener.get.onMessage(5) @@ -44,7 +47,7 @@ object ClientSuite extends SimpleTestSuite { implicit val timer: Timer[IO] = ec.timer val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = fs2ClientCall(dummy) val result = client.unaryToUnaryCall("hello", new Metadata()).timeout(1.second).unsafeToFuture() ec.tick() @@ -68,7 +71,7 @@ object ClientSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = fs2ClientCall(dummy) val result = client.unaryToUnaryCall("hello", new Metadata()).unsafeToFuture() dummy.listener.get.onClose(Status.OK, new Metadata()) @@ -89,7 +92,7 @@ object ClientSuite extends SimpleTestSuite { val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = fs2ClientCall(dummy) val result = client.unaryToUnaryCall("hello", new Metadata()).unsafeToFuture() dummy.listener.get.onMessage(5) @@ -115,7 +118,7 @@ object ClientSuite extends SimpleTestSuite { val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = fs2ClientCall(dummy) val result = client .streamingToUnaryCall(Stream.emits(List("a", "b", "c")), new Metadata()) .unsafeToFuture() @@ -142,7 +145,7 @@ object ClientSuite extends SimpleTestSuite { val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = fs2ClientCall(dummy) val result = client .streamingToUnaryCall(Stream.empty, new Metadata()) .unsafeToFuture() @@ -168,7 +171,7 @@ object ClientSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = fs2ClientCall(dummy) val result = client.unaryToStreamingCall("hello", new Metadata()).compile.toList.unsafeToFuture() dummy.listener.get.onMessage(1) @@ -194,7 +197,7 @@ object ClientSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = fs2ClientCall(dummy) val result = client .streamingToStreamingCall(Stream.emits(List("a", "b", "c", "d", "e")), new Metadata()) @@ -225,7 +228,7 @@ object ClientSuite extends SimpleTestSuite { implicit val timer: Timer[IO] = ec.timer val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = fs2ClientCall(dummy) val result = client .streamingToStreamingCall(Stream.emits(List("a", "b", "c", "d", "e")), new Metadata()) @@ -255,7 +258,7 @@ object ClientSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = fs2ClientCall(dummy) val result = client .streamingToStreamingCall(Stream.emits(List("a", "b", "c", "d", "e")), new Metadata()) @@ -297,4 +300,46 @@ object ClientSuite extends SimpleTestSuite { val channel = result.value.get.get assert(channel.isTerminated) } + + test("error adapter is used when applicable") { + + implicit val ec: TestContext = TestContext() + implicit val cs: ContextShift[IO] = IO.contextShift(ec) + + def testCalls(shouldAdapt: Boolean): Unit = { + + def testAdapter(call: Fs2ClientCall[IO, String, Int] => IO[Unit]): Unit = { + + val (status, errorMsg) = if(shouldAdapt) (Status.ABORTED, "OhNoes!") else { + (Status.INVALID_ARGUMENT, Status.INVALID_ARGUMENT.asRuntimeException().getMessage) + } + + val adapter: StatusRuntimeException => Option[Exception] = _.getStatus match { + case Status.ABORTED if shouldAdapt => Some(new RuntimeException(errorMsg)) + case _ => None + } + + val dummy = new DummyClientCall() + val client = new Fs2ClientCall[IO, String, Int](dummy, adapter) + val result = call(client).unsafeToFuture() + + dummy.listener.get.onClose(status, new Metadata()) + ec.tick() + + assertEquals(result.value.get.failed.get.getMessage, errorMsg) + } + + testAdapter(_.unaryToUnaryCall("hello", new Metadata()).void) + testAdapter(_.unaryToStreamingCall("hello", new Metadata()).compile.toList.void) + testAdapter(_.streamingToUnaryCall(Stream.emit("hello"), new Metadata()).void) + testAdapter(_.streamingToStreamingCall(Stream.emit("hello"), new Metadata()).compile.toList.void) + } + + /// + + testCalls(shouldAdapt = true) + testCalls(shouldAdapt = false) + + } + } diff --git a/sbt-java-gen/src/main/scala/Fs2GrpcServicePrinter.scala b/sbt-java-gen/src/main/scala/Fs2GrpcServicePrinter.scala index 5dd3812a..f22db43d 100644 --- a/sbt-java-gen/src/main/scala/Fs2GrpcServicePrinter.scala +++ b/sbt-java-gen/src/main/scala/Fs2GrpcServicePrinter.scala @@ -37,7 +37,7 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d private[this] def createClientCall(method: MethodDescriptor) = { val basicClientCall = - s"$Fs2ClientCall[F](channel, _root_.$servicePkgName.${serviceName}Grpc.${method.descriptorName}, c($CallOptions.DEFAULT))" + s"$Fs2ClientCall[F](channel, _root_.$servicePkgName.${serviceName}Grpc.${method.descriptorName}, c($CallOptions.DEFAULT), $ErrorAdapterName)" if (method.isServerStreaming) s"$Stream.eval($basicClientCall)" else @@ -102,7 +102,7 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d private[this] def serviceClient: PrinterEndo = { _.add( - s"def client[F[_]: $ConcurrentEffect, $Ctx](channel: $Channel, f: $Ctx => $Metadata, c: $CallOptions => $CallOptions = identity): $serviceNameFs2[F, $Ctx] = new $serviceNameFs2[F, $Ctx] {") + s"def client[F[_]: $ConcurrentEffect, $Ctx](channel: $Channel, f: $Ctx => $Metadata, c: $CallOptions => $CallOptions = identity, $ErrorAdapterDefault): $serviceNameFs2[F, $Ctx] = new $serviceNameFs2[F, $Ctx] {") .indent .call(serviceMethodImplementations) .outdent @@ -126,9 +126,9 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d private[this] def serviceClientMeta: PrinterEndo = _.add( - s"def stub[F[_]: $ConcurrentEffect](channel: $Channel, callOptions: $CallOptions = $CallOptions.DEFAULT): $serviceNameFs2[F, $Metadata] = {") + s"def stub[F[_]: $ConcurrentEffect](channel: $Channel, callOptions: $CallOptions = $CallOptions.DEFAULT, $ErrorAdapterDefault): $serviceNameFs2[F, $Metadata] = {") .indent - .add(s"client[F, $Metadata](channel, identity, _ => callOptions)") + .add(s"client[F, $Metadata](channel, identity, _ => callOptions, $ErrorAdapterName)") .outdent .add("}") @@ -172,6 +172,10 @@ object Fs2GrpcServicePrinter { val Fs2ServerCallHandler = s"$jrtPkg.server.Fs2ServerCallHandler" val Fs2ClientCall = s"$jrtPkg.client.Fs2ClientCall" + val ErrorAdapter = s"$grpcPkg.StatusRuntimeException => Option[Exception]" + val ErrorAdapterName = "errorAdapter" + val ErrorAdapterDefault = s"$ErrorAdapterName: $ErrorAdapter = _ => None" + val ServerServiceDefinition = s"$grpcPkg.ServerServiceDefinition" val CallOptions = s"$grpcPkg.CallOptions" val Channel = s"$grpcPkg.Channel"