Skip to content

Commit

Permalink
add unary response trailers metadata support
Browse files Browse the repository at this point in the history
  • Loading branch information
TFT17 committed Feb 12, 2024
1 parent 18f1406 commit c88eda5
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d
val ctx = s"ctx: $Ctx"

s"def ${method.name}" + (method.streamType match {
case StreamType.Unary => s"(request: $scalaInType, $ctx): F[$scalaOutType]"
case StreamType.ClientStreaming => s"(request: $Stream[F, $scalaInType], $ctx): F[$scalaOutType]"
case StreamType.Unary => s"(request: $scalaInType, $ctx): F[($scalaOutType, $Metadata)]"
case StreamType.ClientStreaming => s"(request: $Stream[F, $scalaInType], $ctx): F[($scalaOutType, $Metadata)]"
case StreamType.ServerStreaming => s"(request: $scalaInType, $ctx): $Stream[F, $scalaOutType]"
case StreamType.Bidirectional => s"(request: $Stream[F, $scalaInType], $ctx): $Stream[F, $scalaOutType]"
})
Expand Down
8 changes: 4 additions & 4 deletions e2e/src/test/resources/TestServiceFs2Grpc.scala.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@ package hello.world
import _root_.cats.syntax.all._

trait TestServiceFs2Grpc[F[_], A] {
def noStreaming(request: hello.world.TestMessage, ctx: A): F[hello.world.TestMessage]
def clientStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): F[hello.world.TestMessage]
def noStreaming(request: hello.world.TestMessage, ctx: A): F[(hello.world.TestMessage, _root_.io.grpc.Metadata)]
def clientStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): F[(hello.world.TestMessage, _root_.io.grpc.Metadata)]
def serverStreaming(request: hello.world.TestMessage, ctx: A): _root_.fs2.Stream[F, hello.world.TestMessage]
def bothStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): _root_.fs2.Stream[F, hello.world.TestMessage]
}

object TestServiceFs2Grpc extends _root_.fs2.grpc.GeneratedCompanion[TestServiceFs2Grpc] {

def mkClient[F[_]: _root_.cats.effect.Async, A](dispatcher: _root_.cats.effect.std.Dispatcher[F], channel: _root_.io.grpc.Channel, mkMetadata: A => F[_root_.io.grpc.Metadata], clientOptions: _root_.fs2.grpc.client.ClientOptions): TestServiceFs2Grpc[F, A] = new TestServiceFs2Grpc[F, A] {
def noStreaming(request: hello.world.TestMessage, ctx: A): F[hello.world.TestMessage] = {
def noStreaming(request: hello.world.TestMessage, ctx: A): F[(hello.world.TestMessage, _root_.io.grpc.Metadata)] = {
mkMetadata(ctx).flatMap { m =>
_root_.fs2.grpc.client.Fs2ClientCall[F](channel, hello.world.TestServiceGrpc.METHOD_NO_STREAMING, dispatcher, clientOptions).flatMap(_.unaryToUnaryCall(request, m))
}
}
def clientStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): F[hello.world.TestMessage] = {
def clientStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): F[(hello.world.TestMessage, _root_.io.grpc.Metadata)] = {
mkMetadata(ctx).flatMap { m =>
_root_.fs2.grpc.client.Fs2ClientCall[F](channel, hello.world.TestServiceGrpc.METHOD_CLIENT_STREAMING, dispatcher, clientOptions).flatMap(_.streamingToUnaryCall(request, m))
}
Expand Down
8 changes: 2 additions & 6 deletions runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,10 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (
private def sendSingleMessage(message: Request): F[Unit] =
F.delay(call.sendMessage(message)) *> halfClose

//

def unaryToUnaryCall(message: Request, headers: Metadata): F[Response] =
def unaryToUnaryCall(message: Request, headers: Metadata): F[(Response, Metadata)] =
Fs2UnaryCallHandler.unary(call, options, message, headers)

def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata): F[Response] =
def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata): F[(Response, Metadata)] =
StreamOutput.client(call).flatMap { output =>
Fs2UnaryCallHandler.stream(call, options, dispatcher, messages, output, headers)
}
Expand All @@ -87,8 +85,6 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (
}
}

//

private def handleExitCase(cancelSucceed: Boolean): (ClientCall.Listener[Response], Resource.ExitCase) => F[Unit] = {
case (_, Resource.ExitCase.Succeeded) => cancel("call done".some, None).whenA(cancelSucceed)
case (_, Resource.ExitCase.Canceled) => cancel("call was cancelled".some, None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,28 +38,28 @@ private[client] object Fs2UnaryCallHandler {

object ReceiveState {
def init[F[_]: Sync, R](
callback: Either[Throwable, R] => Unit,
callback: Either[Throwable, (R, Metadata)] => Unit,
pf: PartialFunction[StatusRuntimeException, Exception]
): F[Ref[SyncIO, ReceiveState[R]]] =
Ref.in(new PendingMessage[R]({
case r: Right[Throwable, R] => callback(r)
case r: Right[Throwable, (R, Metadata)] => callback(r)
case Left(e: StatusRuntimeException) => callback(Left(pf.lift(e).getOrElse(e)))
case l: Left[Throwable, R] => callback(l)
case l: Left[Throwable, (R, Metadata)] => callback(l)
}))
}

class PendingMessage[R](callback: Either[Throwable, R] => Unit) extends ReceiveState[R] {
class PendingMessage[R](callback: Either[Throwable, (R, Metadata)] => Unit) extends ReceiveState[R] {
def receive(message: R): PendingHalfClose[R] = new PendingHalfClose(callback, message)

def sendError(error: Throwable): SyncIO[ReceiveState[R]] =
SyncIO(callback(Left(error))).as(new Done[R])
}

class PendingHalfClose[R](callback: Either[Throwable, R] => Unit, message: R) extends ReceiveState[R] {
class PendingHalfClose[R](callback: Either[Throwable, (R, Metadata)] => Unit, message: R) extends ReceiveState[R] {
def sendError(error: Throwable): SyncIO[ReceiveState[R]] =
SyncIO(callback(Left(error))).as(new Done[R])

def done: SyncIO[ReceiveState[R]] = SyncIO(callback(Right(message))).as(new Done[R])
def done(trailers: Metadata): SyncIO[ReceiveState[R]] = SyncIO(callback(Right((message, trailers)))).as(new Done[R])
}

class Done[R] extends ReceiveState[R]
Expand All @@ -69,6 +69,7 @@ private[client] object Fs2UnaryCallHandler {
signalReadiness: SyncIO[Unit]
): ClientCall.Listener[Response] =
new ClientCall.Listener[Response] {

override def onMessage(message: Response): Unit =
state.get
.flatMap {
Expand All @@ -90,7 +91,7 @@ private[client] object Fs2UnaryCallHandler {
if (status.isOk) {
state.get.flatMap {
case expected: PendingHalfClose[Response] =>
expected.done.flatMap(state.set)
expected.done(trailers).flatMap(state.set)
case current: PendingMessage[Response] =>
current
.sendError(
Expand Down Expand Up @@ -120,7 +121,7 @@ private[client] object Fs2UnaryCallHandler {
options: ClientOptions,
message: Request,
headers: Metadata
)(implicit F: Async[F]): F[Response] = F.async[Response] { cb =>
)(implicit F: Async[F]): F[(Response, Metadata)] = F.async[(Response, Metadata)] { cb =>
ReceiveState.init(cb, options.errorAdapter).map { state =>
call.start(mkListener[Response](state, SyncIO.unit), headers)
// Initially ask for two responses from flow-control so that if a misbehaving server
Expand All @@ -139,8 +140,8 @@ private[client] object Fs2UnaryCallHandler {
messages: Stream[F, Request],
output: StreamOutput[F, Request],
headers: Metadata
)(implicit F: Async[F]): F[Response] = F.async[Response] { cb =>
ReceiveState.init(cb, options.errorAdapter).flatMap { state =>
)(implicit F: Async[F]): F[(Response, Metadata)] = F.async[(Response, Metadata)] { cb =>
ReceiveState.init[F, Response](cb, options.errorAdapter).flatMap { state =>
call.start(mkListener[Response](state, output.onReadySync(dispatcher)), headers)
// Initially ask for two responses from flow-control so that if a misbehaving server
// sends more than one responses, we can catch it and fail it in the listener.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,17 @@ class Fs2ServerCallHandler[F[_]: Async] private (
) {

def unaryToUnaryCall[Request, Response](
implementation: (Request, Metadata) => F[Response]
implementation: (Request, Metadata) => F[(Response, Metadata)]
): ServerCallHandler[Request, Response] =
Fs2UnaryServerCallHandler.unary(implementation, options, dispatcher)
Fs2UnaryServerCallHandler.unary((req, meta) => implementation(req, meta), options, dispatcher)

def unaryToStreamingCall[Request, Response](
implementation: (Request, Metadata) => Stream[F, Response]
): ServerCallHandler[Request, Response] =
Fs2UnaryServerCallHandler.stream(implementation, options, dispatcher)

def streamingToUnaryCall[Request, Response](
implementation: (Stream[F, Request], Metadata) => F[Response]
implementation: (Stream[F, Request], Metadata) => F[(Response, Metadata)]
): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] {
def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = {
val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, SyncIO.unit, dispatcher, options))
Expand Down
33 changes: 25 additions & 8 deletions runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,30 +49,45 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] {
}
}

private def handleUnaryResponse(headers: Metadata, response: F[Response])(implicit F: Sync[F]): F[Unit] =
call.sendHeaders(headers) *> call.request(1) *> response >>= call.sendSingleMessage
private def handleUnaryResponse(headers: Metadata, response: F[(Response, Metadata)])(implicit
F: Sync[F]
): F[Metadata] = {
for {
_ <- call.sendHeaders(headers)
_ <- call.request(1)
responseWithTrailers <- response
_ <- call.sendSingleMessage(responseWithTrailers._1)
} yield responseWithTrailers._2
}

private def handleStreamResponse(headers: Metadata, sendResponse: Stream[F, Nothing])(implicit F: Sync[F]): F[Unit] =
call.sendHeaders(headers) *> call.request(1) *> sendResponse.compile.drain

private def unsafeRun(f: F[Unit])(implicit F: Async[F]): Unit = {
private def unsafeRun(f: F[Metadata])(implicit F: Async[F]): Unit = {
val bracketed =
F.handleError {
F.guaranteeCase(f) {
case Outcome.Succeeded(_) => call.closeStream(Status.OK, new Metadata())
case Outcome.Succeeded(mdF) =>
for {
md <- mdF
_ <- call.closeStream(Status.OK, md)
} yield ()

case Outcome.Canceled() => call.closeStream(Status.CANCELLED, new Metadata())
case Outcome.Errored(t) => reportError(t)
}
}.void
}(_ => ())

// Exceptions are reported by closing the call
dispatcher.unsafeRunAndForget(F.race(bracketed, isCancelled.get))
}

def unsafeUnaryResponse(headers: Metadata, implementation: G[Request] => F[Response])(implicit
def unsafeUnaryResponse(headers: Metadata, implementation: G[Request] => F[(Response, Metadata)])(implicit
F: Async[F]
): Unit =
unsafeRun(handleUnaryResponse(headers, implementation(source)))
unsafeRun(
handleUnaryResponse(headers, implementation(source))
)

def unsafeStreamResponse(
streamOutput: StreamOutput[F, Response],
Expand All @@ -81,5 +96,7 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] {
)(implicit
F: Async[F]
): Unit =
unsafeRun(handleStreamResponse(headers, streamOutput.writeStream(implementation(source))))
unsafeRun(
handleStreamResponse(headers, streamOutput.writeStream(implementation(source))).as(new Metadata())
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@

package fs2.grpc.server.internal

import cats.effect._
import cats.effect.*
import cats.syntax.all.*
import cats.effect.std.Dispatcher
import fs2._
import fs2.*
import fs2.grpc.server.ServerCallOptions
import io.grpc._
import io.grpc.*

private[server] object Fs2ServerCall {
type Cancel = SyncIO[Unit]
Expand Down Expand Up @@ -64,15 +65,18 @@ private[server] final class Fs2ServerCall[Request, Response](
}
.stream
.compile
.drain,
.drain
.as(new Metadata()),
dispatcher
)

def unary[F[_]](response: F[Response], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] =
def unary[F[_]](response: F[(Response, Metadata)], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] =
run(
F.map(response) { message =>
call.sendHeaders(new Metadata())
call.sendMessage(message)
val (response, trailers) = message
call.sendMessage(response)
trailers
},
dispatcher
)
Expand All @@ -83,15 +87,15 @@ private[server] final class Fs2ServerCall[Request, Response](
def close(status: Status, metadata: Metadata): SyncIO[Unit] =
SyncIO(call.close(status, metadata))

private def run[F[_]](completed: F[Unit], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] = {
private def run[F[_]](completed: F[Metadata], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] = {
SyncIO {
val cancel = dispatcher.unsafeRunCancelable(
F.handleError {
F.guaranteeCase(completed) {
case Outcome.Succeeded(_) => close(Status.OK, new Metadata()).to[F]
case Outcome.Succeeded(trailersF) => trailersF.flatMap(trailers => close(Status.OK, trailers).to[F])
case Outcome.Errored(e) => handleError(e).to[F]
case Outcome.Canceled() => close(Status.CANCELLED, new Metadata()).to[F]
}
}.void
}(_ => ())
)
SyncIO(cancel()).void
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ private[server] object Fs2UnaryServerCallHandler {
}

def unary[F[_]: Sync, Request, Response](
impl: (Request, Metadata) => F[Response],
impl: (Request, Metadata) => F[(Response, Metadata)],
options: ServerOptions,
dispatcher: Dispatcher[F]
): ServerCallHandler[Request, Response] =
Expand Down
7 changes: 4 additions & 3 deletions runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ class ClientSuite extends Fs2GrpcSuite {
assertEquals(result.value, None)

dummy.listener.get.onClose(Status.OK, new Metadata())
// TODO

// Check that call completes after status
tc.tick()
assertEquals(result.value, Some(Success(5)))
assertEquals(result.value.map(_.map(_._1)), Some(Success(5)))
assertEquals(dummy.messagesSent.size, 1)
assertEquals(dummy.requested, 2)

Expand Down Expand Up @@ -139,7 +140,7 @@ class ClientSuite extends Fs2GrpcSuite {

// Check that call completes after status
tc.tick()
assertEquals(result.value, Some(Success(5)))
assertEquals(result.value.map(_.map(_._1)), Some(Success(5)))
assertEquals(dummy.messagesSent.size, 3)
assertEquals(dummy.requested, 2)

Expand Down Expand Up @@ -192,7 +193,7 @@ class ClientSuite extends Fs2GrpcSuite {

// Check that call completes after status
tc.tick()
assertEquals(result.value, Some(Success(5)))
assertEquals(result.value.map(_.map(_._1)), Some(Success(5)))
assertEquals(dummy.messagesSent.size, 0)
assertEquals(dummy.requested, 2)

Expand Down
21 changes: 14 additions & 7 deletions runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class ServerSuite extends Fs2GrpcSuite {
options: ServerOptions = ServerOptions.default
): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) =>
val dummy = new DummyServerCall
val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), options, d)
val handler =
Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO((req.length, new Metadata())), options, d)
val listener = handler.startCall(dummy, new Metadata())

listener.onMessage("123")
Expand All @@ -58,7 +59,11 @@ class ServerSuite extends Fs2GrpcSuite {

runTest("cancellation for unaryToUnary") { (tc, d) =>
val dummy = new DummyServerCall
val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), ServerOptions.default, d)
val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int](
(req, _) => IO((req.length, new Metadata())),
ServerOptions.default,
d
)
val listener = handler.startCall(dummy, new Metadata())

listener.onCancel()
Expand All @@ -71,7 +76,7 @@ class ServerSuite extends Fs2GrpcSuite {
runTest("cancellation on the fly for unaryToUnary") { (tc, d) =>
val dummy = new DummyServerCall
val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int](
(req, _) => IO(req.length).delayBy(10.seconds),
(req, _) => IO((req.length, new Metadata())).delayBy(10.seconds),
ServerOptions.default,
d
)
Expand All @@ -94,7 +99,8 @@ class ServerSuite extends Fs2GrpcSuite {
options: ServerOptions = ServerOptions.default
): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) =>
val dummy = new DummyServerCall
val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), options, d)
val handler =
Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO((req.length, new Metadata())), options, d)
val listener = handler.startCall(dummy, new Metadata())

listener.onMessage("123")
Expand All @@ -112,7 +118,8 @@ class ServerSuite extends Fs2GrpcSuite {
options: ServerOptions = ServerOptions.default
): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) =>
val dummy = new DummyServerCall
val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), options, d)
val handler =
Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO((req.length, new Metadata())), options, d)
val listener = handler.startCall(dummy, new Metadata())

listener.onHalfClose()
Expand Down Expand Up @@ -320,7 +327,7 @@ class ServerSuite extends Fs2GrpcSuite {
val dummy = new DummyServerCall

val handler = Fs2ServerCallHandler[IO](d, so)
.streamingToUnaryCall[String, Int]((req, _) => implementation(req))
.streamingToUnaryCall[String, Int]((req, _) => implementation(req).map((_, new Metadata())))
val listener = handler.startCall(dummy, new Metadata())

listener.onMessage("ab")
Expand All @@ -339,7 +346,7 @@ class ServerSuite extends Fs2GrpcSuite {
val deferred = d.unsafeRunSync(Deferred[IO, Unit])
val handler = Fs2ServerCallHandler[IO](d, ServerOptions.default)
.streamingToUnaryCall[String, Int]((requests, _) => {
requests.evalMap(_ => deferred.get).compile.drain.as(1)
requests.evalMap(_ => deferred.get).compile.drain.as((1, new Metadata()))
})
val listener = handler.startCall(dummy, new Metadata())

Expand Down

0 comments on commit c88eda5

Please sign in to comment.