Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ability to adapt status runtime exceptions #137

Merged
merged 1 commit into from
Nov 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 31 additions & 32 deletions java-runtime/src/main/scala/client/Fs2ClientCall.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@ import fs2._
final case class UnaryResult[A](value: Option[A], status: Option[GrpcStatus])
final case class GrpcStatus(status: Status, trailers: Metadata)

class Fs2ClientCall[F[_], Request, Response] private[client] (val call: ClientCall[Request, Response]) extends AnyVal {
class Fs2ClientCall[F[_], Request, Response] private[client] (
call: ClientCall[Request, Response],
errorAdapter: StatusRuntimeException => Option[Exception]
) {

private val ea: PartialFunction[Throwable, Throwable] = {
case e: StatusRuntimeException => errorAdapter(e).getOrElse(e)
}

private def cancel(message: Option[String], cause: Option[Throwable])(implicit F: Sync[F]): F[Unit] =
F.delay(call.cancel(message.orNull, cause.orNull))
Expand All @@ -27,65 +34,57 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (val call: ClientCa
private def start(listener: ClientCall.Listener[Response], metadata: Metadata)(implicit F: Sync[F]): F[Unit] =
F.delay(call.start(listener, metadata))

def startListener[A <: ClientCall.Listener[Response]](createListener: F[A], headers: Metadata)(implicit F: Sync[F]): F[A] = {
def startListener[A <: ClientCall.Listener[Response]](createListener: F[A], headers: Metadata)(implicit F: Sync[F]): F[A] =
createListener.flatTap(start(_, headers)) <* request(1)
}

def sendSingleMessage(message: Request)(implicit F: Sync[F]): F[Unit] = {
def sendSingleMessage(message: Request)(implicit F: Sync[F]): F[Unit] =
sendMessage(message) *> halfClose
}

def sendStream(stream: Stream[F, Request])(implicit F: Sync[F]): Stream[F, Unit] = {
def sendStream(stream: Stream[F, Request])(implicit F: Sync[F]): Stream[F, Unit] =
stream.evalMap(sendMessage) ++ Stream.eval(halfClose)
}

def handleCallError(
implicit F: ConcurrentEffect[F]): (ClientCall.Listener[Response], ExitCase[Throwable]) => F[Unit] = {
def handleCallError(implicit F: Sync[F]): (ClientCall.Listener[Response], ExitCase[Throwable]) => F[Unit] = {
case (_, ExitCase.Completed) => F.unit
case (_, ExitCase.Canceled) => cancel("call was cancelled".some, None)
case (_, ExitCase.Error(t)) => cancel(t.getMessage.some, t.some)
}

def unaryToUnaryCall(message: Request, headers: Metadata)(implicit F: ConcurrentEffect[F]): F[Response] = {
F.bracketCase(startListener(Fs2UnaryClientCallListener[F, Response], headers))({ listener =>
sendSingleMessage(message) *> listener.getValue
})(handleCallError)
}
def unaryToUnaryCall(message: Request, headers: Metadata)(implicit F: ConcurrentEffect[F]): F[Response] =
F.bracketCase(startListener(Fs2UnaryClientCallListener[F, Response], headers))(
l => sendSingleMessage(message) *> l.getValue.adaptError(ea)
)(handleCallError)

def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata)(
implicit F: ConcurrentEffect[F]): F[Response] = {
F.bracketCase(startListener(Fs2UnaryClientCallListener[F, Response], headers))({ listener =>
Stream.eval(listener.getValue).concurrently(sendStream(messages)).compile.lastOrError
})(handleCallError)
}
def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata)(implicit F: ConcurrentEffect[F]): F[Response] =
F.bracketCase(startListener(Fs2UnaryClientCallListener[F, Response], headers))(
l => Stream.eval(l.getValue.adaptError(ea)).concurrently(sendStream(messages)).compile.lastOrError
)(handleCallError)

def unaryToStreamingCall(message: Request, headers: Metadata)(
implicit F: ConcurrentEffect[F]): Stream[F, Response] = {
def unaryToStreamingCall(message: Request, headers: Metadata)(implicit F: ConcurrentEffect[F]): Stream[F, Response] =
Stream
.bracketCase(startListener(Fs2StreamClientCallListener[F, Response](call.request), headers))(handleCallError)
.flatMap(Stream.eval_(sendSingleMessage(message)) ++ _.stream)
}
.flatMap(Stream.eval_(sendSingleMessage(message)) ++ _.stream.adaptError(ea))

def streamingToStreamingCall(messages: Stream[F, Request], headers: Metadata)(
implicit F: ConcurrentEffect[F]): Stream[F, Response] = {
def streamingToStreamingCall(messages: Stream[F, Request], headers: Metadata)(implicit F: ConcurrentEffect[F]): Stream[F, Response] =
Stream
.bracketCase(startListener(Fs2StreamClientCallListener[F, Response](call.request), headers))(handleCallError)
.flatMap(_.stream.concurrently(sendStream(messages)))
}
.flatMap(_.stream.adaptError(ea).concurrently(sendStream(messages)))
}

object Fs2ClientCall {

def apply[F[_]]: PartiallyAppliedClientCall[F] = new PartiallyAppliedClientCall[F]

class PartiallyAppliedClientCall[F[_]](val dummy: Boolean = false) extends AnyVal {

def apply[Request, Response](
channel: Channel,
methodDescriptor: MethodDescriptor[Request, Response],
callOptions: CallOptions)(implicit F: Sync[F]): F[Fs2ClientCall[F, Request, Response]] =
F.delay(new Fs2ClientCall(channel.newCall[Request, Response](methodDescriptor, callOptions)))
callOptions: CallOptions,
errorAdapter: StatusRuntimeException => Option[Exception] = _ => None
)(implicit F: Sync[F]): F[Fs2ClientCall[F, Request, Response]] = {
F.delay(new Fs2ClientCall(channel.newCall[Request, Response](methodDescriptor, callOptions), errorAdapter))
}

}

def apply[F[_]]: PartiallyAppliedClientCall[F] =
new PartiallyAppliedClientCall[F]
}
67 changes: 56 additions & 11 deletions java-runtime/src/test/scala/client/ClientSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,28 @@ package org.lyranthe.fs2_grpc
package java_runtime
package client

import cats.implicits._
import cats.effect.{ContextShift, IO, Timer}
import cats.effect.laws.util.TestContext
import fs2._
import io.grpc._
import minitest._

import scala.concurrent.TimeoutException
import scala.concurrent.duration._
import scala.util.Success

object ClientSuite extends SimpleTestSuite {

def fs2ClientCall(dummy: DummyClientCall) =
new Fs2ClientCall[IO, String, Int](dummy, _ => None)

test("single message to unaryToUnary") {

implicit val ec: TestContext = TestContext()
implicit val cs: ContextShift[IO] = IO.contextShift(ec)

val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = fs2ClientCall(dummy)
val result = client.unaryToUnaryCall("hello", new Metadata()).unsafeToFuture()
dummy.listener.get.onMessage(5)

Expand All @@ -44,7 +47,7 @@ object ClientSuite extends SimpleTestSuite {
implicit val timer: Timer[IO] = ec.timer

val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = fs2ClientCall(dummy)
val result = client.unaryToUnaryCall("hello", new Metadata()).timeout(1.second).unsafeToFuture()

ec.tick()
Expand All @@ -68,7 +71,7 @@ object ClientSuite extends SimpleTestSuite {
implicit val cs: ContextShift[IO] = IO.contextShift(ec)

val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = fs2ClientCall(dummy)
val result = client.unaryToUnaryCall("hello", new Metadata()).unsafeToFuture()

dummy.listener.get.onClose(Status.OK, new Metadata())
Expand All @@ -89,7 +92,7 @@ object ClientSuite extends SimpleTestSuite {


val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = fs2ClientCall(dummy)
val result = client.unaryToUnaryCall("hello", new Metadata()).unsafeToFuture()
dummy.listener.get.onMessage(5)

Expand All @@ -115,7 +118,7 @@ object ClientSuite extends SimpleTestSuite {


val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = fs2ClientCall(dummy)
val result = client
.streamingToUnaryCall(Stream.emits(List("a", "b", "c")), new Metadata())
.unsafeToFuture()
Expand All @@ -142,7 +145,7 @@ object ClientSuite extends SimpleTestSuite {


val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = fs2ClientCall(dummy)
val result = client
.streamingToUnaryCall(Stream.empty, new Metadata())
.unsafeToFuture()
Expand All @@ -168,7 +171,7 @@ object ClientSuite extends SimpleTestSuite {
implicit val cs: ContextShift[IO] = IO.contextShift(ec)

val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = fs2ClientCall(dummy)
val result = client.unaryToStreamingCall("hello", new Metadata()).compile.toList.unsafeToFuture()

dummy.listener.get.onMessage(1)
Expand All @@ -194,7 +197,7 @@ object ClientSuite extends SimpleTestSuite {
implicit val cs: ContextShift[IO] = IO.contextShift(ec)

val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = fs2ClientCall(dummy)
val result =
client
.streamingToStreamingCall(Stream.emits(List("a", "b", "c", "d", "e")), new Metadata())
Expand Down Expand Up @@ -225,7 +228,7 @@ object ClientSuite extends SimpleTestSuite {
implicit val timer: Timer[IO] = ec.timer

val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = fs2ClientCall(dummy)
val result =
client
.streamingToStreamingCall(Stream.emits(List("a", "b", "c", "d", "e")), new Metadata())
Expand Down Expand Up @@ -255,7 +258,7 @@ object ClientSuite extends SimpleTestSuite {
implicit val cs: ContextShift[IO] = IO.contextShift(ec)

val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = fs2ClientCall(dummy)
val result =
client
.streamingToStreamingCall(Stream.emits(List("a", "b", "c", "d", "e")), new Metadata())
Expand Down Expand Up @@ -297,4 +300,46 @@ object ClientSuite extends SimpleTestSuite {
val channel = result.value.get.get
assert(channel.isTerminated)
}

test("error adapter is used when applicable") {

implicit val ec: TestContext = TestContext()
implicit val cs: ContextShift[IO] = IO.contextShift(ec)

def testCalls(shouldAdapt: Boolean): Unit = {

def testAdapter(call: Fs2ClientCall[IO, String, Int] => IO[Unit]): Unit = {

val (status, errorMsg) = if(shouldAdapt) (Status.ABORTED, "OhNoes!") else {
(Status.INVALID_ARGUMENT, Status.INVALID_ARGUMENT.asRuntimeException().getMessage)
}

val adapter: StatusRuntimeException => Option[Exception] = _.getStatus match {
case Status.ABORTED if shouldAdapt => Some(new RuntimeException(errorMsg))
case _ => None
}

val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy, adapter)
val result = call(client).unsafeToFuture()

dummy.listener.get.onClose(status, new Metadata())
ec.tick()

assertEquals(result.value.get.failed.get.getMessage, errorMsg)
}

testAdapter(_.unaryToUnaryCall("hello", new Metadata()).void)
testAdapter(_.unaryToStreamingCall("hello", new Metadata()).compile.toList.void)
testAdapter(_.streamingToUnaryCall(Stream.emit("hello"), new Metadata()).void)
testAdapter(_.streamingToStreamingCall(Stream.emit("hello"), new Metadata()).compile.toList.void)
}

///

testCalls(shouldAdapt = true)
testCalls(shouldAdapt = false)

}

}
12 changes: 8 additions & 4 deletions sbt-java-gen/src/main/scala/Fs2GrpcServicePrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d

private[this] def createClientCall(method: MethodDescriptor) = {
val basicClientCall =
s"$Fs2ClientCall[F](channel, _root_.$servicePkgName.${serviceName}Grpc.${method.descriptorName}, c($CallOptions.DEFAULT))"
s"$Fs2ClientCall[F](channel, _root_.$servicePkgName.${serviceName}Grpc.${method.descriptorName}, c($CallOptions.DEFAULT), $ErrorAdapterName)"
if (method.isServerStreaming)
s"$Stream.eval($basicClientCall)"
else
Expand Down Expand Up @@ -102,7 +102,7 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d

private[this] def serviceClient: PrinterEndo = {
_.add(
s"def client[F[_]: $ConcurrentEffect, $Ctx](channel: $Channel, f: $Ctx => $Metadata, c: $CallOptions => $CallOptions = identity): $serviceNameFs2[F, $Ctx] = new $serviceNameFs2[F, $Ctx] {")
s"def client[F[_]: $ConcurrentEffect, $Ctx](channel: $Channel, f: $Ctx => $Metadata, c: $CallOptions => $CallOptions = identity, $ErrorAdapterDefault): $serviceNameFs2[F, $Ctx] = new $serviceNameFs2[F, $Ctx] {")
.indent
.call(serviceMethodImplementations)
.outdent
Expand All @@ -126,9 +126,9 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d

private[this] def serviceClientMeta: PrinterEndo =
_.add(
s"def stub[F[_]: $ConcurrentEffect](channel: $Channel, callOptions: $CallOptions = $CallOptions.DEFAULT): $serviceNameFs2[F, $Metadata] = {")
s"def stub[F[_]: $ConcurrentEffect](channel: $Channel, callOptions: $CallOptions = $CallOptions.DEFAULT, $ErrorAdapterDefault): $serviceNameFs2[F, $Metadata] = {")
.indent
.add(s"client[F, $Metadata](channel, identity, _ => callOptions)")
.add(s"client[F, $Metadata](channel, identity, _ => callOptions, $ErrorAdapterName)")
.outdent
.add("}")

Expand Down Expand Up @@ -172,6 +172,10 @@ object Fs2GrpcServicePrinter {
val Fs2ServerCallHandler = s"$jrtPkg.server.Fs2ServerCallHandler"
val Fs2ClientCall = s"$jrtPkg.client.Fs2ClientCall"

val ErrorAdapter = s"$grpcPkg.StatusRuntimeException => Option[Exception]"
val ErrorAdapterName = "errorAdapter"
val ErrorAdapterDefault = s"$ErrorAdapterName: $ErrorAdapter = _ => None"

val ServerServiceDefinition = s"$grpcPkg.ServerServiceDefinition"
val CallOptions = s"$grpcPkg.CallOptions"
val Channel = s"$grpcPkg.Channel"
Expand Down