diff --git a/codegen/src/main/scala/fs2/grpc/codegen/Fs2GrpcServicePrinter.scala b/codegen/src/main/scala/fs2/grpc/codegen/Fs2GrpcServicePrinter.scala index 51781e44..b6705054 100644 --- a/codegen/src/main/scala/fs2/grpc/codegen/Fs2GrpcServicePrinter.scala +++ b/codegen/src/main/scala/fs2/grpc/codegen/Fs2GrpcServicePrinter.scala @@ -40,8 +40,8 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d val ctx = s"ctx: $Ctx" s"def ${method.name}" + (method.streamType match { - case StreamType.Unary => s"(request: $scalaInType, $ctx): F[$scalaOutType]" - case StreamType.ClientStreaming => s"(request: $Stream[F, $scalaInType], $ctx): F[$scalaOutType]" + case StreamType.Unary => s"(request: $scalaInType, $ctx): F[($scalaOutType, $Metadata)]" + case StreamType.ClientStreaming => s"(request: $Stream[F, $scalaInType], $ctx): F[($scalaOutType, $Metadata)]" case StreamType.ServerStreaming => s"(request: $scalaInType, $ctx): $Stream[F, $scalaOutType]" case StreamType.Bidirectional => s"(request: $Stream[F, $scalaInType], $ctx): $Stream[F, $scalaOutType]" }) diff --git a/e2e/src/test/resources/TestServiceFs2Grpc.scala.txt b/e2e/src/test/resources/TestServiceFs2Grpc.scala.txt index 3c4b5add..1c6c895e 100644 --- a/e2e/src/test/resources/TestServiceFs2Grpc.scala.txt +++ b/e2e/src/test/resources/TestServiceFs2Grpc.scala.txt @@ -3,8 +3,8 @@ package hello.world import _root_.cats.syntax.all._ trait TestServiceFs2Grpc[F[_], A] { - def noStreaming(request: hello.world.TestMessage, ctx: A): F[hello.world.TestMessage] - def clientStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): F[hello.world.TestMessage] + def noStreaming(request: hello.world.TestMessage, ctx: A): F[(hello.world.TestMessage, _root_.io.grpc.Metadata)] + def clientStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): F[(hello.world.TestMessage, _root_.io.grpc.Metadata)] def serverStreaming(request: hello.world.TestMessage, ctx: A): _root_.fs2.Stream[F, hello.world.TestMessage] def bothStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): _root_.fs2.Stream[F, hello.world.TestMessage] } @@ -12,12 +12,12 @@ trait TestServiceFs2Grpc[F[_], A] { object TestServiceFs2Grpc extends _root_.fs2.grpc.GeneratedCompanion[TestServiceFs2Grpc] { def mkClient[F[_]: _root_.cats.effect.Async, A](dispatcher: _root_.cats.effect.std.Dispatcher[F], channel: _root_.io.grpc.Channel, mkMetadata: A => F[_root_.io.grpc.Metadata], clientOptions: _root_.fs2.grpc.client.ClientOptions): TestServiceFs2Grpc[F, A] = new TestServiceFs2Grpc[F, A] { - def noStreaming(request: hello.world.TestMessage, ctx: A): F[hello.world.TestMessage] = { + def noStreaming(request: hello.world.TestMessage, ctx: A): F[(hello.world.TestMessage, _root_.io.grpc.Metadata)] = { mkMetadata(ctx).flatMap { m => _root_.fs2.grpc.client.Fs2ClientCall[F](channel, hello.world.TestServiceGrpc.METHOD_NO_STREAMING, dispatcher, clientOptions).flatMap(_.unaryToUnaryCall(request, m)) } } - def clientStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): F[hello.world.TestMessage] = { + def clientStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): F[(hello.world.TestMessage, _root_.io.grpc.Metadata)] = { mkMetadata(ctx).flatMap { m => _root_.fs2.grpc.client.Fs2ClientCall[F](channel, hello.world.TestServiceGrpc.METHOD_CLIENT_STREAMING, dispatcher, clientOptions).flatMap(_.streamingToUnaryCall(request, m)) } diff --git a/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala b/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala index f9116816..b145105d 100644 --- a/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala +++ b/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala @@ -58,12 +58,10 @@ class Fs2ClientCall[F[_], Request, Response] private[client] ( private def sendSingleMessage(message: Request): F[Unit] = F.delay(call.sendMessage(message)) *> halfClose - // - - def unaryToUnaryCall(message: Request, headers: Metadata): F[Response] = + def unaryToUnaryCall(message: Request, headers: Metadata): F[(Response, Metadata)] = Fs2UnaryCallHandler.unary(call, options, message, headers) - def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata): F[Response] = + def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata): F[(Response, Metadata)] = StreamOutput.client(call).flatMap { output => Fs2UnaryCallHandler.stream(call, options, dispatcher, messages, output, headers) } @@ -87,8 +85,6 @@ class Fs2ClientCall[F[_], Request, Response] private[client] ( } } - // - private def handleExitCase(cancelSucceed: Boolean): (ClientCall.Listener[Response], Resource.ExitCase) => F[Unit] = { case (_, Resource.ExitCase.Succeeded) => cancel("call done".some, None).whenA(cancelSucceed) case (_, Resource.ExitCase.Canceled) => cancel("call was cancelled".some, None) diff --git a/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala b/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala index 70958ec5..e83d46f2 100644 --- a/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala @@ -38,28 +38,28 @@ private[client] object Fs2UnaryCallHandler { object ReceiveState { def init[F[_]: Sync, R]( - callback: Either[Throwable, R] => Unit, + callback: Either[Throwable, (R, Metadata)] => Unit, pf: PartialFunction[StatusRuntimeException, Exception] ): F[Ref[SyncIO, ReceiveState[R]]] = Ref.in(new PendingMessage[R]({ - case r: Right[Throwable, R] => callback(r) + case r: Right[Throwable, (R, Metadata)] => callback(r) case Left(e: StatusRuntimeException) => callback(Left(pf.lift(e).getOrElse(e))) - case l: Left[Throwable, R] => callback(l) + case l: Left[Throwable, (R, Metadata)] => callback(l) })) } - class PendingMessage[R](callback: Either[Throwable, R] => Unit) extends ReceiveState[R] { + class PendingMessage[R](callback: Either[Throwable, (R, Metadata)] => Unit) extends ReceiveState[R] { def receive(message: R): PendingHalfClose[R] = new PendingHalfClose(callback, message) def sendError(error: Throwable): SyncIO[ReceiveState[R]] = SyncIO(callback(Left(error))).as(new Done[R]) } - class PendingHalfClose[R](callback: Either[Throwable, R] => Unit, message: R) extends ReceiveState[R] { + class PendingHalfClose[R](callback: Either[Throwable, (R, Metadata)] => Unit, message: R) extends ReceiveState[R] { def sendError(error: Throwable): SyncIO[ReceiveState[R]] = SyncIO(callback(Left(error))).as(new Done[R]) - def done: SyncIO[ReceiveState[R]] = SyncIO(callback(Right(message))).as(new Done[R]) + def done(trailers: Metadata): SyncIO[ReceiveState[R]] = SyncIO(callback(Right((message, trailers)))).as(new Done[R]) } class Done[R] extends ReceiveState[R] @@ -69,6 +69,7 @@ private[client] object Fs2UnaryCallHandler { signalReadiness: SyncIO[Unit] ): ClientCall.Listener[Response] = new ClientCall.Listener[Response] { + override def onMessage(message: Response): Unit = state.get .flatMap { @@ -90,7 +91,7 @@ private[client] object Fs2UnaryCallHandler { if (status.isOk) { state.get.flatMap { case expected: PendingHalfClose[Response] => - expected.done.flatMap(state.set) + expected.done(trailers).flatMap(state.set) case current: PendingMessage[Response] => current .sendError( @@ -115,13 +116,15 @@ private[client] object Fs2UnaryCallHandler { override def onReady(): Unit = signalReadiness.unsafeRunSync() } + def unary[F[_], Request, Response]( - call: ClientCall[Request, Response], - options: ClientOptions, - message: Request, - headers: Metadata - )(implicit F: Async[F]): F[Response] = F.async[Response] { cb => + call: ClientCall[Request, Response], + options: ClientOptions, + message: Request, + headers: Metadata + )(implicit F: Async[F]): F[(Response, Metadata)] = F.async[(Response, Metadata)] { cb => ReceiveState.init(cb, options.errorAdapter).map { state => + call.start(mkListener[Response](state, SyncIO.unit), headers) // Initially ask for two responses from flow-control so that if a misbehaving server // sends more than one responses, we can catch it and fail it in the listener. @@ -139,8 +142,8 @@ private[client] object Fs2UnaryCallHandler { messages: Stream[F, Request], output: StreamOutput[F, Request], headers: Metadata - )(implicit F: Async[F]): F[Response] = F.async[Response] { cb => - ReceiveState.init(cb, options.errorAdapter).flatMap { state => + )(implicit F: Async[F]): F[(Response, Metadata)] = F.async[(Response, Metadata)] { cb => + ReceiveState.init[F, Response](cb, options.errorAdapter).flatMap { state => call.start(mkListener[Response](state, output.onReadySync(dispatcher)), headers) // Initially ask for two responses from flow-control so that if a misbehaving server // sends more than one responses, we can catch it and fail it in the listener. diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala index 37064909..e2f5f459 100644 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala @@ -36,9 +36,9 @@ class Fs2ServerCallHandler[F[_]: Async] private ( ) { def unaryToUnaryCall[Request, Response]( - implementation: (Request, Metadata) => F[Response] + implementation: (Request, Metadata) => F[(Response, Metadata)] ): ServerCallHandler[Request, Response] = - Fs2UnaryServerCallHandler.unary(implementation, options, dispatcher) + Fs2UnaryServerCallHandler.unary((req, meta) => implementation(req, meta), options, dispatcher) def unaryToStreamingCall[Request, Response]( implementation: (Request, Metadata) => Stream[F, Response] @@ -46,7 +46,7 @@ class Fs2ServerCallHandler[F[_]: Async] private ( Fs2UnaryServerCallHandler.stream(implementation, options, dispatcher) def streamingToUnaryCall[Request, Response]( - implementation: (Stream[F, Request], Metadata) => F[Response] + implementation: (Stream[F, Request], Metadata) => F[(Response, Metadata)] ): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, SyncIO.unit, dispatcher, options)) diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala index ae1ee98d..10317a57 100644 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala @@ -49,30 +49,43 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] { } } - private def handleUnaryResponse(headers: Metadata, response: F[Response])(implicit F: Sync[F]): F[Unit] = - call.sendHeaders(headers) *> call.request(1) *> response >>= call.sendSingleMessage + private def handleUnaryResponse(headers: Metadata, response: F[(Response, Metadata)])(implicit F: Sync[F]): F[Metadata] = { + for { + _ <- call.sendHeaders(headers) + _ <- call.request(1) + responseWithTrailers <- response + _ <- call.sendSingleMessage(responseWithTrailers._1) + } yield responseWithTrailers._2 + } private def handleStreamResponse(headers: Metadata, sendResponse: Stream[F, Nothing])(implicit F: Sync[F]): F[Unit] = call.sendHeaders(headers) *> call.request(1) *> sendResponse.compile.drain - private def unsafeRun(f: F[Unit])(implicit F: Async[F]): Unit = { + private def unsafeRun(f: F[Metadata])(implicit F: Async[F]): Unit = { val bracketed = F.handleError { F.guaranteeCase(f) { - case Outcome.Succeeded(_) => call.closeStream(Status.OK, new Metadata()) + case Outcome.Succeeded(mdF) => + for { + md <- mdF + _ <- call.closeStream(Status.OK, md) + } yield () + case Outcome.Canceled() => call.closeStream(Status.CANCELLED, new Metadata()) case Outcome.Errored(t) => reportError(t) - } + }.void }(_ => ()) // Exceptions are reported by closing the call dispatcher.unsafeRunAndForget(F.race(bracketed, isCancelled.get)) } - def unsafeUnaryResponse(headers: Metadata, implementation: G[Request] => F[Response])(implicit + def unsafeUnaryResponse(headers: Metadata, implementation: G[Request] => F[(Response, Metadata)])(implicit F: Async[F] ): Unit = - unsafeRun(handleUnaryResponse(headers, implementation(source))) + unsafeRun( + handleUnaryResponse(headers, implementation(source)) + ) def unsafeStreamResponse( streamOutput: StreamOutput[F, Response], @@ -81,5 +94,7 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] { )(implicit F: Async[F] ): Unit = - unsafeRun(handleStreamResponse(headers, streamOutput.writeStream(implementation(source)))) + unsafeRun( + handleStreamResponse(headers, streamOutput.writeStream(implementation(source))).as(new Metadata()) + ) } diff --git a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala index 008f7e40..3aa50c59 100644 --- a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala +++ b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala @@ -21,11 +21,12 @@ package fs2.grpc.server.internal -import cats.effect._ +import cats.effect.* +import cats.syntax.all.* import cats.effect.std.Dispatcher -import fs2._ +import fs2.* import fs2.grpc.server.ServerCallOptions -import io.grpc._ +import io.grpc.* private[server] object Fs2ServerCall { type Cancel = SyncIO[Unit] @@ -64,15 +65,17 @@ private[server] final class Fs2ServerCall[Request, Response]( } .stream .compile - .drain, + .drain.as(new Metadata()), dispatcher ) - def unary[F[_]](response: F[Response], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] = + def unary[F[_]](response: F[(Response, Metadata)], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] = run( F.map(response) { message => call.sendHeaders(new Metadata()) - call.sendMessage(message) + val (response, trailers) = message + call.sendMessage(response) + trailers }, dispatcher ) @@ -83,21 +86,24 @@ private[server] final class Fs2ServerCall[Request, Response]( def close(status: Status, metadata: Metadata): SyncIO[Unit] = SyncIO(call.close(status, metadata)) - private def run[F[_]](completed: F[Unit], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] = { + private def run[F[_]](completed: F[Metadata], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] = { SyncIO { val cancel = dispatcher.unsafeRunCancelable( F.handleError { F.guaranteeCase(completed) { - case Outcome.Succeeded(_) => close(Status.OK, new Metadata()).to[F] + case Outcome.Succeeded(trailersF) => trailersF.flatMap( + trailers => close(Status.OK, trailers).to[F] + ) case Outcome.Errored(e) => handleError(e).to[F] case Outcome.Canceled() => close(Status.CANCELLED, new Metadata()).to[F] - } + }.void }(_ => ()) ) SyncIO(cancel()).void } } + private def handleError(t: Throwable): SyncIO[Unit] = t match { case ex: StatusException => close(ex.getStatus, Option(ex.getTrailers).getOrElse(new Metadata())) case ex: StatusRuntimeException => close(ex.getStatus, Option(ex.getTrailers).getOrElse(new Metadata())) diff --git a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala index 1af54faa..9594cbda 100644 --- a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala @@ -90,7 +90,7 @@ private[server] object Fs2UnaryServerCallHandler { } def unary[F[_]: Sync, Request, Response]( - impl: (Request, Metadata) => F[Response], + impl: (Request, Metadata) => F[(Response, Metadata)], options: ServerOptions, dispatcher: Dispatcher[F] ): ServerCallHandler[Request, Response] = diff --git a/runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala b/runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala index a6a430c7..329590a6 100644 --- a/runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala +++ b/runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala @@ -47,10 +47,11 @@ class ClientSuite extends Fs2GrpcSuite { assertEquals(result.value, None) dummy.listener.get.onClose(Status.OK, new Metadata()) + //TODO // Check that call completes after status tc.tick() - assertEquals(result.value, Some(Success(5))) + assertEquals(result.value.map(_.map(_._1)), Some(Success(5))) assertEquals(dummy.messagesSent.size, 1) assertEquals(dummy.requested, 2) @@ -139,7 +140,7 @@ class ClientSuite extends Fs2GrpcSuite { // Check that call completes after status tc.tick() - assertEquals(result.value, Some(Success(5))) + assertEquals(result.value.map(_.map(_._1)), Some(Success(5))) assertEquals(dummy.messagesSent.size, 3) assertEquals(dummy.requested, 2) @@ -192,7 +193,7 @@ class ClientSuite extends Fs2GrpcSuite { // Check that call completes after status tc.tick() - assertEquals(result.value, Some(Success(5))) + assertEquals(result.value.map(_.map(_._1)), Some(Success(5))) assertEquals(dummy.messagesSent.size, 0) assertEquals(dummy.requested, 2) diff --git a/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala b/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala index 2eabf325..ac122494 100644 --- a/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala +++ b/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala @@ -43,7 +43,7 @@ class ServerSuite extends Fs2GrpcSuite { options: ServerOptions = ServerOptions.default ): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) => val dummy = new DummyServerCall - val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), options, d) + val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO((req.length, new Metadata())), options, d) val listener = handler.startCall(dummy, new Metadata()) listener.onMessage("123") @@ -58,7 +58,7 @@ class ServerSuite extends Fs2GrpcSuite { runTest("cancellation for unaryToUnary") { (tc, d) => val dummy = new DummyServerCall - val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), ServerOptions.default, d) + val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO((req.length, new Metadata())), ServerOptions.default, d) val listener = handler.startCall(dummy, new Metadata()) listener.onCancel() @@ -71,7 +71,7 @@ class ServerSuite extends Fs2GrpcSuite { runTest("cancellation on the fly for unaryToUnary") { (tc, d) => val dummy = new DummyServerCall val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]( - (req, _) => IO(req.length).delayBy(10.seconds), + (req, _) => IO((req.length, new Metadata())).delayBy(10.seconds), ServerOptions.default, d ) @@ -94,7 +94,7 @@ class ServerSuite extends Fs2GrpcSuite { options: ServerOptions = ServerOptions.default ): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) => val dummy = new DummyServerCall - val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), options, d) + val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO((req.length, new Metadata())), options, d) val listener = handler.startCall(dummy, new Metadata()) listener.onMessage("123") @@ -112,7 +112,7 @@ class ServerSuite extends Fs2GrpcSuite { options: ServerOptions = ServerOptions.default ): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) => val dummy = new DummyServerCall - val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), options, d) + val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO((req.length, new Metadata())), options, d) val listener = handler.startCall(dummy, new Metadata()) listener.onHalfClose() @@ -320,7 +320,7 @@ class ServerSuite extends Fs2GrpcSuite { val dummy = new DummyServerCall val handler = Fs2ServerCallHandler[IO](d, so) - .streamingToUnaryCall[String, Int]((req, _) => implementation(req)) + .streamingToUnaryCall[String, Int]((req, _) => implementation(req).map((_, new Metadata()))) val listener = handler.startCall(dummy, new Metadata()) listener.onMessage("ab") @@ -339,7 +339,7 @@ class ServerSuite extends Fs2GrpcSuite { val deferred = d.unsafeRunSync(Deferred[IO, Unit]) val handler = Fs2ServerCallHandler[IO](d, ServerOptions.default) .streamingToUnaryCall[String, Int]((requests, _) => { - requests.evalMap(_ => deferred.get).compile.drain.as(1) + requests.evalMap(_ => deferred.get).compile.drain.as((1, new Metadata())) }) val listener = handler.startCall(dummy, new Metadata())