diff --git a/codegen/src/main/scala/fs2/grpc/codegen/Fs2AbstractServicePrinter.scala b/codegen/src/main/scala/fs2/grpc/codegen/Fs2AbstractServicePrinter.scala new file mode 100644 index 00000000..1877b23d --- /dev/null +++ b/codegen/src/main/scala/fs2/grpc/codegen/Fs2AbstractServicePrinter.scala @@ -0,0 +1,176 @@ +/* + * Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package fs2.grpc.codegen + +import com.google.protobuf.Descriptors.{MethodDescriptor, ServiceDescriptor} +import fs2.grpc.codegen.Fs2AbstractServicePrinter.constants.{ + Async, + Channel, + ClientOptions, + Companion, + Ctx, + Dispatcher, + Fs2ClientCall, + Fs2ServerCallHandler, + Metadata, + ServerOptions, + ServerServiceDefinition, + Stream +} +import scalapb.compiler.{DescriptorImplicits, FunctionalPrinter} +import scalapb.compiler.FunctionalPrinter.PrinterEndo + +abstract class Fs2AbstractServicePrinter extends Fs2ServicePrinter { + + val service: ServiceDescriptor + val serviceSuffix: String + val di: DescriptorImplicits + + import di._ + + private[this] val serviceName: String = service.name + private[this] val serviceNameFs2: String = s"$serviceName${serviceSuffix}" + private[this] val servicePkgName: String = service.getFile.scalaPackage.fullName + + protected def serviceMethodSignature(method: MethodDescriptor): String + + protected[this] def handleMethod(method: MethodDescriptor): String + + private[this] def createClientCall(method: MethodDescriptor) = { + val basicClientCall = + s"$Fs2ClientCall[F](channel, ${method.grpcDescriptor.fullName}, dispatcher, clientOptions)" + if (method.isServerStreaming) + s"$Stream.eval($basicClientCall)" + else + basicClientCall + } + + private[this] def serviceMethodImplementation(method: MethodDescriptor): PrinterEndo = { p => + val mkMetadata = if (method.isServerStreaming) s"$Stream.eval(mkMetadata(ctx))" else "mkMetadata(ctx)" + + p.add(serviceMethodSignature(method) + " = {") + .indent + .add(s"$mkMetadata.flatMap { m =>") + .indent + .add(s"${createClientCall(method)}.flatMap(_.${handleMethod(method)}(request, m))") + .outdent + .add("}") + .outdent + .add("}") + } + + private[this] def serviceBindingImplementation(method: MethodDescriptor): PrinterEndo = { p => + val inType = method.inputType.scalaType + val outType = method.outputType.scalaType + val descriptor = method.grpcDescriptor.fullName + val handler = s"$Fs2ServerCallHandler[F](dispatcher, serverOptions).${handleMethod(method)}[$inType, $outType]" + + val serviceCall = s"serviceImpl.${method.name}" + val eval = if (method.isServerStreaming) s"$Stream.eval(mkCtx(m))" else "mkCtx(m)" + + p.add(s".addMethod($descriptor, $handler((r, m) => $eval.flatMap($serviceCall(r, _))))") + } + + private[this] def serviceMethods: PrinterEndo = _.seq(service.methods.map(serviceMethodSignature)) + + private[this] def serviceMethodImplementations: PrinterEndo = + _.call(service.methods.map(serviceMethodImplementation): _*) + + private[this] def serviceBindingImplementations: PrinterEndo = + _.indent + .add(s".builder(${service.grpcDescriptor.fullName})") + .call(service.methods.map(serviceBindingImplementation): _*) + .add(".build()") + .outdent + + private[this] def serviceTrait: PrinterEndo = + _.add(s"trait $serviceNameFs2[F[_], $Ctx] {").indent.call(serviceMethods).outdent.add("}") + + private[this] def serviceObject: PrinterEndo = + _.add(s"object $serviceNameFs2 extends $Companion[$serviceNameFs2] {").indent.newline + .call(serviceClient) + .newline + .call(serviceBinding) + .outdent + .newline + .add("}") + + private[this] def serviceClient: PrinterEndo = { + _.add( + s"def mkClient[F[_]: $Async, $Ctx](dispatcher: $Dispatcher[F], channel: $Channel, mkMetadata: $Ctx => F[$Metadata], clientOptions: $ClientOptions): $serviceNameFs2[F, $Ctx] = new $serviceNameFs2[F, $Ctx] {" + ).indent + .call(serviceMethodImplementations) + .outdent + .add("}") + } + + private[this] def serviceBinding: PrinterEndo = { + _.add( + s"protected def serviceBinding[F[_]: $Async, $Ctx](dispatcher: $Dispatcher[F], serviceImpl: $serviceNameFs2[F, $Ctx], mkCtx: $Metadata => F[$Ctx], serverOptions: $ServerOptions): $ServerServiceDefinition = {" + ).indent + .add(s"$ServerServiceDefinition") + .call(serviceBindingImplementations) + .outdent + .add("}") + } + + // / + + def printService(printer: FunctionalPrinter): FunctionalPrinter = { + printer + .add(s"package $servicePkgName", "", "import _root_.cats.syntax.all._", "") + .call(serviceTrait) + .newline + .call(serviceObject) + } +} + +object Fs2AbstractServicePrinter { + private[codegen] object constants { + + private val effPkg = "_root_.cats.effect" + private val fs2Pkg = "_root_.fs2" + private val fs2grpcPkg = "_root_.fs2.grpc" + private val grpcPkg = "_root_.io.grpc" + + // / + + val Ctx = "A" + + val Async = s"$effPkg.Async" + val Resource = s"$effPkg.Resource" + val Dispatcher = s"$effPkg.std.Dispatcher" + val Stream = s"$fs2Pkg.Stream" + + val Fs2ServerCallHandler = s"$fs2grpcPkg.server.Fs2ServerCallHandler" + val Fs2ClientCall = s"$fs2grpcPkg.client.Fs2ClientCall" + val ClientOptions = s"$fs2grpcPkg.client.ClientOptions" + val ServerOptions = s"$fs2grpcPkg.server.ServerOptions" + val Companion = s"$fs2grpcPkg.GeneratedCompanion" + + val ServerServiceDefinition = s"$grpcPkg.ServerServiceDefinition" + val Channel = s"$grpcPkg.Channel" + val Metadata = s"$grpcPkg.Metadata" + + } + +} diff --git a/codegen/src/main/scala/fs2/grpc/codegen/Fs2CodeGenerator.scala b/codegen/src/main/scala/fs2/grpc/codegen/Fs2CodeGenerator.scala index 54a6a997..44601485 100644 --- a/codegen/src/main/scala/fs2/grpc/codegen/Fs2CodeGenerator.scala +++ b/codegen/src/main/scala/fs2/grpc/codegen/Fs2CodeGenerator.scala @@ -21,32 +21,55 @@ package fs2.grpc.codegen -import com.google.protobuf.Descriptors.FileDescriptor +import com.google.protobuf.Descriptors.{FileDescriptor, ServiceDescriptor} import com.google.protobuf.ExtensionRegistry import com.google.protobuf.compiler.PluginProtos import protocgen.{CodeGenApp, CodeGenRequest, CodeGenResponse} import scalapb.compiler.{DescriptorImplicits, FunctionalPrinter, GeneratorParams} import scalapb.options.Scalapb -import scala.jdk.CollectionConverters._ + +import scala.jdk.CollectionConverters.* final case class Fs2Params(serviceSuffix: String = "Fs2Grpc") object Fs2CodeGenerator extends CodeGenApp { + private def generateServiceFile( + file: FileDescriptor, + service: ServiceDescriptor, + serviceSuffix: String, + di: DescriptorImplicits, + p: ServiceDescriptor => Fs2ServicePrinter + ): PluginProtos.CodeGeneratorResponse.File = { + import di.{ExtendedServiceDescriptor, ExtendedFileDescriptor} + + val code = p(service).printService(FunctionalPrinter()).result() + val b = PluginProtos.CodeGeneratorResponse.File.newBuilder() + b.setName(file.scalaDirectory + "/" + service.name + s"$serviceSuffix.scala") + b.setContent(code) + b.build + } + def generateServiceFiles( file: FileDescriptor, fs2params: Fs2Params, di: DescriptorImplicits ): Seq[PluginProtos.CodeGeneratorResponse.File] = { - file.getServices.asScala.map { service => - import di.{ExtendedServiceDescriptor, ExtendedFileDescriptor} - - val p = new Fs2GrpcServicePrinter(service, fs2params.serviceSuffix, di) - val code = p.printService(FunctionalPrinter()).result() - val b = PluginProtos.CodeGeneratorResponse.File.newBuilder() - b.setName(file.scalaDirectory + "/" + service.name + s"${fs2params.serviceSuffix}.scala") - b.setContent(code) - b.build + file.getServices.asScala.flatMap { service => + generateServiceFile( + file, + service, + fs2params.serviceSuffix + "Trailers", + di, + new Fs2GrpcExhaustiveTrailersServicePrinter(_, fs2params.serviceSuffix + "Trailers", di) + ) :: + generateServiceFile( + file, + service, + fs2params.serviceSuffix, + di, + new Fs2GrpcServicePrinter(_, fs2params.serviceSuffix, di) + ) :: Nil }.toSeq } @@ -66,7 +89,9 @@ object Fs2CodeGenerator extends CodeGenApp { parseParameters(request.parameter) match { case Right((params, fs2params)) => val implicits = DescriptorImplicits.fromCodeGenRequest(params, request) - val srvFiles = request.filesToGenerate.flatMap(generateServiceFiles(_, fs2params, implicits)) + val srvFiles = request.filesToGenerate.flatMap( + generateServiceFiles(_, fs2params, implicits) + ) CodeGenResponse.succeed( srvFiles, Set(PluginProtos.CodeGeneratorResponse.Feature.FEATURE_PROTO3_OPTIONAL) diff --git a/codegen/src/main/scala/fs2/grpc/codegen/Fs2GrpcExhaustiveTrailersServicePrinter.scala b/codegen/src/main/scala/fs2/grpc/codegen/Fs2GrpcExhaustiveTrailersServicePrinter.scala new file mode 100644 index 00000000..67616c26 --- /dev/null +++ b/codegen/src/main/scala/fs2/grpc/codegen/Fs2GrpcExhaustiveTrailersServicePrinter.scala @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package fs2.grpc.codegen + +import com.google.protobuf.Descriptors.{MethodDescriptor, ServiceDescriptor} +import scalapb.compiler.{DescriptorImplicits, StreamType} + +class Fs2GrpcExhaustiveTrailersServicePrinter( + val service: ServiceDescriptor, + val serviceSuffix: String, + val di: DescriptorImplicits +) extends Fs2AbstractServicePrinter { + import fs2.grpc.codegen.Fs2AbstractServicePrinter.constants._ + import di._ + + override protected def serviceMethodSignature(method: MethodDescriptor): String = { + + val scalaInType = method.inputType.scalaType + val scalaOutType = method.outputType.scalaType + val ctx = s"ctx: $Ctx" + + s"def ${method.name}" + (method.streamType match { + case StreamType.Unary => s"(request: $scalaInType, $ctx): F[($scalaOutType, $Metadata)]" + case StreamType.ClientStreaming => s"(request: $Stream[F, $scalaInType], $ctx): F[($scalaOutType, $Metadata)]" + case StreamType.ServerStreaming => s"(request: $scalaInType, $ctx): $Stream[F, $scalaOutType]" + case StreamType.Bidirectional => s"(request: $Stream[F, $scalaInType], $ctx): $Stream[F, $scalaOutType]" + }) + } + + override protected def handleMethod(method: MethodDescriptor): String = { + method.streamType match { + case StreamType.Unary => "unaryToUnaryCallTrailers" + case StreamType.ClientStreaming => "streamingToUnaryCallTrailers" + case StreamType.ServerStreaming => "unaryToStreamingCall" + case StreamType.Bidirectional => "streamingToStreamingCall" + } + } + +} diff --git a/codegen/src/main/scala/fs2/grpc/codegen/Fs2GrpcServicePrinter.scala b/codegen/src/main/scala/fs2/grpc/codegen/Fs2GrpcServicePrinter.scala index 51781e44..0272af70 100644 --- a/codegen/src/main/scala/fs2/grpc/codegen/Fs2GrpcServicePrinter.scala +++ b/codegen/src/main/scala/fs2/grpc/codegen/Fs2GrpcServicePrinter.scala @@ -22,18 +22,14 @@ package fs2.grpc.codegen import com.google.protobuf.Descriptors.{MethodDescriptor, ServiceDescriptor} -import scalapb.compiler.FunctionalPrinter.PrinterEndo -import scalapb.compiler.{DescriptorImplicits, FunctionalPrinter, StreamType} +import scalapb.compiler.{DescriptorImplicits, StreamType} -class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, di: DescriptorImplicits) { +class Fs2GrpcServicePrinter(val service: ServiceDescriptor, val serviceSuffix: String, val di: DescriptorImplicits) + extends Fs2AbstractServicePrinter { + import fs2.grpc.codegen.Fs2AbstractServicePrinter.constants._ import di._ - import Fs2GrpcServicePrinter.constants._ - private[this] val serviceName: String = service.name - private[this] val serviceNameFs2: String = s"$serviceName$serviceSuffix" - private[this] val servicePkgName: String = service.getFile.scalaPackage.fullName - - private[this] def serviceMethodSignature(method: MethodDescriptor) = { + override protected def serviceMethodSignature(method: MethodDescriptor): String = { val scalaInType = method.inputType.scalaType val scalaOutType = method.outputType.scalaType @@ -47,7 +43,7 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d }) } - private[this] def handleMethod(method: MethodDescriptor) = { + override protected def handleMethod(method: MethodDescriptor): String = { method.streamType match { case StreamType.Unary => "unaryToUnaryCall" case StreamType.ClientStreaming => "streamingToUnaryCall" @@ -56,93 +52,6 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d } } - private[this] def createClientCall(method: MethodDescriptor) = { - val basicClientCall = - s"$Fs2ClientCall[F](channel, ${method.grpcDescriptor.fullName}, dispatcher, clientOptions)" - if (method.isServerStreaming) - s"$Stream.eval($basicClientCall)" - else - basicClientCall - } - - private[this] def serviceMethodImplementation(method: MethodDescriptor): PrinterEndo = { p => - val mkMetadata = if (method.isServerStreaming) s"$Stream.eval(mkMetadata(ctx))" else "mkMetadata(ctx)" - - p.add(serviceMethodSignature(method) + " = {") - .indent - .add(s"$mkMetadata.flatMap { m =>") - .indent - .add(s"${createClientCall(method)}.flatMap(_.${handleMethod(method)}(request, m))") - .outdent - .add("}") - .outdent - .add("}") - } - - private[this] def serviceBindingImplementation(method: MethodDescriptor): PrinterEndo = { p => - val inType = method.inputType.scalaType - val outType = method.outputType.scalaType - val descriptor = method.grpcDescriptor.fullName - val handler = s"$Fs2ServerCallHandler[F](dispatcher, serverOptions).${handleMethod(method)}[$inType, $outType]" - - val serviceCall = s"serviceImpl.${method.name}" - val eval = if (method.isServerStreaming) s"$Stream.eval(mkCtx(m))" else "mkCtx(m)" - - p.add(s".addMethod($descriptor, $handler((r, m) => $eval.flatMap($serviceCall(r, _))))") - } - - private[this] def serviceMethods: PrinterEndo = _.seq(service.methods.map(serviceMethodSignature)) - - private[this] def serviceMethodImplementations: PrinterEndo = - _.call(service.methods.map(serviceMethodImplementation): _*) - - private[this] def serviceBindingImplementations: PrinterEndo = - _.indent - .add(s".builder(${service.grpcDescriptor.fullName})") - .call(service.methods.map(serviceBindingImplementation): _*) - .add(".build()") - .outdent - - private[this] def serviceTrait: PrinterEndo = - _.add(s"trait $serviceNameFs2[F[_], $Ctx] {").indent.call(serviceMethods).outdent.add("}") - - private[this] def serviceObject: PrinterEndo = - _.add(s"object $serviceNameFs2 extends $Companion[$serviceNameFs2] {").indent.newline - .call(serviceClient) - .newline - .call(serviceBinding) - .outdent - .newline - .add("}") - - private[this] def serviceClient: PrinterEndo = { - _.add( - s"def mkClient[F[_]: $Async, $Ctx](dispatcher: $Dispatcher[F], channel: $Channel, mkMetadata: $Ctx => F[$Metadata], clientOptions: $ClientOptions): $serviceNameFs2[F, $Ctx] = new $serviceNameFs2[F, $Ctx] {" - ).indent - .call(serviceMethodImplementations) - .outdent - .add("}") - } - - private[this] def serviceBinding: PrinterEndo = { - _.add( - s"protected def serviceBinding[F[_]: $Async, $Ctx](dispatcher: $Dispatcher[F], serviceImpl: $serviceNameFs2[F, $Ctx], mkCtx: $Metadata => F[$Ctx], serverOptions: $ServerOptions): $ServerServiceDefinition = {" - ).indent - .add(s"$ServerServiceDefinition") - .call(serviceBindingImplementations) - .outdent - .add("}") - } - - // / - - def printService(printer: FunctionalPrinter): FunctionalPrinter = { - printer - .add(s"package $servicePkgName", "", s"import _root_.cats.syntax.all.${service.getFile.V.WildcardImport}", "") - .call(serviceTrait) - .newline - .call(serviceObject) - } } object Fs2GrpcServicePrinter { diff --git a/codegen/src/main/scala/fs2/grpc/codegen/Fs2ServicePrinter.scala b/codegen/src/main/scala/fs2/grpc/codegen/Fs2ServicePrinter.scala new file mode 100644 index 00000000..5bf80270 --- /dev/null +++ b/codegen/src/main/scala/fs2/grpc/codegen/Fs2ServicePrinter.scala @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package fs2.grpc.codegen + +import scalapb.compiler.FunctionalPrinter + +trait Fs2ServicePrinter { + def printService(printer: FunctionalPrinter): FunctionalPrinter +} diff --git a/e2e/src/test/resources/TestServiceFs2GrpcTrailers.scala.txt b/e2e/src/test/resources/TestServiceFs2GrpcTrailers.scala.txt new file mode 100644 index 00000000..8341d528 --- /dev/null +++ b/e2e/src/test/resources/TestServiceFs2GrpcTrailers.scala.txt @@ -0,0 +1,47 @@ +package hello.world + +import _root_.cats.syntax.all._ + +trait TestServiceFs2GrpcTrailers[F[_], A] { + def noStreaming(request: hello.world.TestMessage, ctx: A): F[(hello.world.TestMessage, _root_.io.grpc.Metadata)] + def clientStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): F[(hello.world.TestMessage, _root_.io.grpc.Metadata)] + def serverStreaming(request: hello.world.TestMessage, ctx: A): _root_.fs2.Stream[F, hello.world.TestMessage] + def bothStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): _root_.fs2.Stream[F, hello.world.TestMessage] +} + +object TestServiceFs2GrpcTrailers extends _root_.fs2.grpc.GeneratedCompanion[TestServiceFs2GrpcTrailers] { + + def mkClient[F[_]: _root_.cats.effect.Async, A](dispatcher: _root_.cats.effect.std.Dispatcher[F], channel: _root_.io.grpc.Channel, mkMetadata: A => F[_root_.io.grpc.Metadata], clientOptions: _root_.fs2.grpc.client.ClientOptions): TestServiceFs2GrpcTrailers[F, A] = new TestServiceFs2GrpcTrailers[F, A] { + def noStreaming(request: hello.world.TestMessage, ctx: A): F[(hello.world.TestMessage, _root_.io.grpc.Metadata)] = { + mkMetadata(ctx).flatMap { m => + _root_.fs2.grpc.client.Fs2ClientCall[F](channel, hello.world.TestServiceGrpc.METHOD_NO_STREAMING, dispatcher, clientOptions).flatMap(_.unaryToUnaryCallTrailers(request, m)) + } + } + def clientStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): F[(hello.world.TestMessage, _root_.io.grpc.Metadata)] = { + mkMetadata(ctx).flatMap { m => + _root_.fs2.grpc.client.Fs2ClientCall[F](channel, hello.world.TestServiceGrpc.METHOD_CLIENT_STREAMING, dispatcher, clientOptions).flatMap(_.streamingToUnaryCallTrailers(request, m)) + } + } + def serverStreaming(request: hello.world.TestMessage, ctx: A): _root_.fs2.Stream[F, hello.world.TestMessage] = { + _root_.fs2.Stream.eval(mkMetadata(ctx)).flatMap { m => + _root_.fs2.Stream.eval(_root_.fs2.grpc.client.Fs2ClientCall[F](channel, hello.world.TestServiceGrpc.METHOD_SERVER_STREAMING, dispatcher, clientOptions)).flatMap(_.unaryToStreamingCall(request, m)) + } + } + def bothStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): _root_.fs2.Stream[F, hello.world.TestMessage] = { + _root_.fs2.Stream.eval(mkMetadata(ctx)).flatMap { m => + _root_.fs2.Stream.eval(_root_.fs2.grpc.client.Fs2ClientCall[F](channel, hello.world.TestServiceGrpc.METHOD_BOTH_STREAMING, dispatcher, clientOptions)).flatMap(_.streamingToStreamingCall(request, m)) + } + } + } + + protected def serviceBinding[F[_]: _root_.cats.effect.Async, A](dispatcher: _root_.cats.effect.std.Dispatcher[F], serviceImpl: TestServiceFs2GrpcTrailers[F, A], mkCtx: _root_.io.grpc.Metadata => F[A], serverOptions: _root_.fs2.grpc.server.ServerOptions): _root_.io.grpc.ServerServiceDefinition = { + _root_.io.grpc.ServerServiceDefinition + .builder(hello.world.TestServiceGrpc.SERVICE) + .addMethod(hello.world.TestServiceGrpc.METHOD_NO_STREAMING, _root_.fs2.grpc.server.Fs2ServerCallHandler[F](dispatcher, serverOptions).unaryToUnaryCallTrailers[hello.world.TestMessage, hello.world.TestMessage]((r, m) => mkCtx(m).flatMap(serviceImpl.noStreaming(r, _)))) + .addMethod(hello.world.TestServiceGrpc.METHOD_CLIENT_STREAMING, _root_.fs2.grpc.server.Fs2ServerCallHandler[F](dispatcher, serverOptions).streamingToUnaryCallTrailers[hello.world.TestMessage, hello.world.TestMessage]((r, m) => mkCtx(m).flatMap(serviceImpl.clientStreaming(r, _)))) + .addMethod(hello.world.TestServiceGrpc.METHOD_SERVER_STREAMING, _root_.fs2.grpc.server.Fs2ServerCallHandler[F](dispatcher, serverOptions).unaryToStreamingCall[hello.world.TestMessage, hello.world.TestMessage]((r, m) => _root_.fs2.Stream.eval(mkCtx(m)).flatMap(serviceImpl.serverStreaming(r, _)))) + .addMethod(hello.world.TestServiceGrpc.METHOD_BOTH_STREAMING, _root_.fs2.grpc.server.Fs2ServerCallHandler[F](dispatcher, serverOptions).streamingToStreamingCall[hello.world.TestMessage, hello.world.TestMessage]((r, m) => _root_.fs2.Stream.eval(mkCtx(m)).flatMap(serviceImpl.bothStreaming(r, _)))) + .build() + } + +} \ No newline at end of file diff --git a/e2e/src/test/scala/fs2/grpc/e2e/CodegenSpec.scala b/e2e/src/test/scala/fs2/grpc/e2e/CodegenSpec.scala index 60338259..aa44d3e0 100644 --- a/e2e/src/test/scala/fs2/grpc/e2e/CodegenSpec.scala +++ b/e2e/src/test/scala/fs2/grpc/e2e/CodegenSpec.scala @@ -42,8 +42,21 @@ class Fs2CodeGeneratorSpec extends munit.FunSuite { } + test("code generator outputs correct service file for trailers") { + + val testFileName = "TestServiceFs2GrpcTrailers.scala" + val reference = Source.fromResource(s"${testFileName}.txt").getLines().mkString("\n") + val generated = Source.fromFile(new File(sourcesGenerated, testFileName)).getLines().mkString("\n") + + assertEquals(generated, reference) + + } + test("implicit of companion resolves") { implicitly[GeneratedCompanion[TestServiceFs2Grpc]] } + test("implicit of companion resolves trailers") { + implicitly[GeneratedCompanion[TestServiceFs2GrpcTrailers]] + } } diff --git a/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala b/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala index f9116816..258ffc30 100644 --- a/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala +++ b/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala @@ -61,13 +61,21 @@ class Fs2ClientCall[F[_], Request, Response] private[client] ( // def unaryToUnaryCall(message: Request, headers: Metadata): F[Response] = + Fs2UnaryCallHandler.unary(call, options, message, headers).map(_._1) + + def unaryToUnaryCallTrailers(message: Request, headers: Metadata): F[(Response, Metadata)] = Fs2UnaryCallHandler.unary(call, options, message, headers) - def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata): F[Response] = + def streamingToUnaryCallTrailers(messages: Stream[F, Request], headers: Metadata): F[(Response, Metadata)] = StreamOutput.client(call).flatMap { output => Fs2UnaryCallHandler.stream(call, options, dispatcher, messages, output, headers) } + def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata): F[Response] = + StreamOutput.client(call).flatMap { output => + Fs2UnaryCallHandler.stream(call, options, dispatcher, messages, output, headers).map(_._1) + } + def unaryToStreamingCall(message: Request, md: Metadata): Stream[F, Response] = Stream .resource(mkStreamListenerR(md, SyncIO.unit)) diff --git a/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala b/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala index 70958ec5..11da81b0 100644 --- a/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala @@ -38,28 +38,29 @@ private[client] object Fs2UnaryCallHandler { object ReceiveState { def init[F[_]: Sync, R]( - callback: Either[Throwable, R] => Unit, + callback: Either[Throwable, (R, Metadata)] => Unit, pf: PartialFunction[StatusRuntimeException, Exception] ): F[Ref[SyncIO, ReceiveState[R]]] = Ref.in(new PendingMessage[R]({ - case r: Right[Throwable, R] => callback(r) + case r: Right[Throwable, (R, Metadata)] => callback(r) case Left(e: StatusRuntimeException) => callback(Left(pf.lift(e).getOrElse(e))) - case l: Left[Throwable, R] => callback(l) + case l: Left[Throwable, (R, Metadata)] => callback(l) })) } - class PendingMessage[R](callback: Either[Throwable, R] => Unit) extends ReceiveState[R] { + class PendingMessage[R](callback: Either[Throwable, (R, Metadata)] => Unit) extends ReceiveState[R] { def receive(message: R): PendingHalfClose[R] = new PendingHalfClose(callback, message) def sendError(error: Throwable): SyncIO[ReceiveState[R]] = SyncIO(callback(Left(error))).as(new Done[R]) } - class PendingHalfClose[R](callback: Either[Throwable, R] => Unit, message: R) extends ReceiveState[R] { + class PendingHalfClose[R](callback: Either[Throwable, (R, Metadata)] => Unit, message: R) extends ReceiveState[R] { def sendError(error: Throwable): SyncIO[ReceiveState[R]] = SyncIO(callback(Left(error))).as(new Done[R]) + def done: SyncIO[ReceiveState[R]] = SyncIO(callback(Right((message, new Metadata())))).as(new Done[R]) - def done: SyncIO[ReceiveState[R]] = SyncIO(callback(Right(message))).as(new Done[R]) + def done(trailers: Metadata): SyncIO[ReceiveState[R]] = SyncIO(callback(Right((message, trailers)))).as(new Done[R]) } class Done[R] extends ReceiveState[R] @@ -69,6 +70,7 @@ private[client] object Fs2UnaryCallHandler { signalReadiness: SyncIO[Unit] ): ClientCall.Listener[Response] = new ClientCall.Listener[Response] { + override def onMessage(message: Response): Unit = state.get .flatMap { @@ -90,7 +92,7 @@ private[client] object Fs2UnaryCallHandler { if (status.isOk) { state.get.flatMap { case expected: PendingHalfClose[Response] => - expected.done.flatMap(state.set) + expected.done(trailers).flatMap(state.set) case current: PendingMessage[Response] => current .sendError( @@ -120,7 +122,7 @@ private[client] object Fs2UnaryCallHandler { options: ClientOptions, message: Request, headers: Metadata - )(implicit F: Async[F]): F[Response] = F.async[Response] { cb => + )(implicit F: Async[F]): F[(Response, Metadata)] = F.async[(Response, Metadata)] { cb => ReceiveState.init(cb, options.errorAdapter).map { state => call.start(mkListener[Response](state, SyncIO.unit), headers) // Initially ask for two responses from flow-control so that if a misbehaving server @@ -139,8 +141,8 @@ private[client] object Fs2UnaryCallHandler { messages: Stream[F, Request], output: StreamOutput[F, Request], headers: Metadata - )(implicit F: Async[F]): F[Response] = F.async[Response] { cb => - ReceiveState.init(cb, options.errorAdapter).flatMap { state => + )(implicit F: Async[F]): F[(Response, Metadata)] = F.async[(Response, Metadata)] { cb => + ReceiveState.init[F, Response](cb, options.errorAdapter).flatMap { state => call.start(mkListener[Response](state, output.onReadySync(dispatcher)), headers) // Initially ask for two responses from flow-control so that if a misbehaving server // sends more than one responses, we can catch it and fail it in the listener. diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala index 37064909..7abc4933 100644 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala @@ -37,6 +37,15 @@ class Fs2ServerCallHandler[F[_]: Async] private ( def unaryToUnaryCall[Request, Response]( implementation: (Request, Metadata) => F[Response] + ): ServerCallHandler[Request, Response] = + Fs2UnaryServerCallHandler.unary( + (req, meta) => implementation(req, meta).map((_, new Metadata())), + options, + dispatcher + ) + + def unaryToUnaryCallTrailers[Request, Response]( + implementation: (Request, Metadata) => F[(Response, Metadata)] ): ServerCallHandler[Request, Response] = Fs2UnaryServerCallHandler.unary(implementation, options, dispatcher) @@ -45,12 +54,22 @@ class Fs2ServerCallHandler[F[_]: Async] private ( ): ServerCallHandler[Request, Response] = Fs2UnaryServerCallHandler.stream(implementation, options, dispatcher) + def streamingToUnaryCallTrailers[Request, Response]( + implementation: (Stream[F, Request], Metadata) => F[(Response, Metadata)] + ): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { + def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { + val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, SyncIO.unit, dispatcher, options)) + listener.unsafeUnaryResponse(new Metadata(), implementation(_, headers)) + listener + } + } + def streamingToUnaryCall[Request, Response]( implementation: (Stream[F, Request], Metadata) => F[Response] ): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, SyncIO.unit, dispatcher, options)) - listener.unsafeUnaryResponse(new Metadata(), implementation(_, headers)) + listener.unsafeUnaryResponse(new Metadata(), implementation(_, headers).map((_, new Metadata()))) listener } } diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala index ae1ee98d..f6c40c3c 100644 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala @@ -49,30 +49,45 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] { } } - private def handleUnaryResponse(headers: Metadata, response: F[Response])(implicit F: Sync[F]): F[Unit] = - call.sendHeaders(headers) *> call.request(1) *> response >>= call.sendSingleMessage + private def handleUnaryResponse(headers: Metadata, response: F[(Response, Metadata)])(implicit + F: Sync[F] + ): F[Metadata] = { + for { + _ <- call.sendHeaders(headers) + _ <- call.request(1) + responseWithTrailers <- response + _ <- call.sendSingleMessage(responseWithTrailers._1) + } yield responseWithTrailers._2 + } private def handleStreamResponse(headers: Metadata, sendResponse: Stream[F, Nothing])(implicit F: Sync[F]): F[Unit] = call.sendHeaders(headers) *> call.request(1) *> sendResponse.compile.drain - private def unsafeRun(f: F[Unit])(implicit F: Async[F]): Unit = { + private def unsafeRun(f: F[Metadata])(implicit F: Async[F]): Unit = { val bracketed = F.handleError { F.guaranteeCase(f) { - case Outcome.Succeeded(_) => call.closeStream(Status.OK, new Metadata()) + case Outcome.Succeeded(mdF) => + for { + md <- mdF + _ <- call.closeStream(Status.OK, md) + } yield () + case Outcome.Canceled() => call.closeStream(Status.CANCELLED, new Metadata()) case Outcome.Errored(t) => reportError(t) - } + }.void }(_ => ()) // Exceptions are reported by closing the call dispatcher.unsafeRunAndForget(F.race(bracketed, isCancelled.get)) } - def unsafeUnaryResponse(headers: Metadata, implementation: G[Request] => F[Response])(implicit + def unsafeUnaryResponse(headers: Metadata, implementation: G[Request] => F[(Response, Metadata)])(implicit F: Async[F] ): Unit = - unsafeRun(handleUnaryResponse(headers, implementation(source))) + unsafeRun( + handleUnaryResponse(headers, implementation(source)) + ) def unsafeStreamResponse( streamOutput: StreamOutput[F, Response], @@ -81,5 +96,7 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] { )(implicit F: Async[F] ): Unit = - unsafeRun(handleStreamResponse(headers, streamOutput.writeStream(implementation(source)))) + unsafeRun( + handleStreamResponse(headers, streamOutput.writeStream(implementation(source))).as(new Metadata()) + ) } diff --git a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala index 008f7e40..3facc019 100644 --- a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala +++ b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala @@ -22,6 +22,7 @@ package fs2.grpc.server.internal import cats.effect._ +import cats.syntax.all._ import cats.effect.std.Dispatcher import fs2._ import fs2.grpc.server.ServerCallOptions @@ -64,15 +65,18 @@ private[server] final class Fs2ServerCall[Request, Response]( } .stream .compile - .drain, + .drain + .as(new Metadata()), dispatcher ) - def unary[F[_]](response: F[Response], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] = + def unary[F[_]](response: F[(Response, Metadata)], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] = run( F.map(response) { message => call.sendHeaders(new Metadata()) - call.sendMessage(message) + val (response, trailers) = message + call.sendMessage(response) + trailers }, dispatcher ) @@ -83,15 +87,15 @@ private[server] final class Fs2ServerCall[Request, Response]( def close(status: Status, metadata: Metadata): SyncIO[Unit] = SyncIO(call.close(status, metadata)) - private def run[F[_]](completed: F[Unit], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] = { + private def run[F[_]](completed: F[Metadata], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] = { SyncIO { val cancel = dispatcher.unsafeRunCancelable( F.handleError { F.guaranteeCase(completed) { - case Outcome.Succeeded(_) => close(Status.OK, new Metadata()).to[F] + case Outcome.Succeeded(trailersF) => trailersF.flatMap(trailers => close(Status.OK, trailers).to[F]) case Outcome.Errored(e) => handleError(e).to[F] case Outcome.Canceled() => close(Status.CANCELLED, new Metadata()).to[F] - } + }.void }(_ => ()) ) SyncIO(cancel()).void diff --git a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala index 1af54faa..9594cbda 100644 --- a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala @@ -90,7 +90,7 @@ private[server] object Fs2UnaryServerCallHandler { } def unary[F[_]: Sync, Request, Response]( - impl: (Request, Metadata) => F[Response], + impl: (Request, Metadata) => F[(Response, Metadata)], options: ServerOptions, dispatcher: Dispatcher[F] ): ServerCallHandler[Request, Response] = diff --git a/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala b/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala index 2eabf325..7d4e0339 100644 --- a/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala +++ b/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala @@ -43,7 +43,11 @@ class ServerSuite extends Fs2GrpcSuite { options: ServerOptions = ServerOptions.default ): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) => val dummy = new DummyServerCall - val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), options, d) + val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]( + (req, _) => IO(req.length).map(i => (i, new Metadata())), + options, + d + ) val listener = handler.startCall(dummy, new Metadata()) listener.onMessage("123") @@ -58,7 +62,11 @@ class ServerSuite extends Fs2GrpcSuite { runTest("cancellation for unaryToUnary") { (tc, d) => val dummy = new DummyServerCall - val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), ServerOptions.default, d) + val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]( + (req, _) => IO(req.length).map((_, new Metadata())), + ServerOptions.default, + d + ) val listener = handler.startCall(dummy, new Metadata()) listener.onCancel() @@ -71,7 +79,7 @@ class ServerSuite extends Fs2GrpcSuite { runTest("cancellation on the fly for unaryToUnary") { (tc, d) => val dummy = new DummyServerCall val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]( - (req, _) => IO(req.length).delayBy(10.seconds), + (req, _) => IO(req.length).delayBy(10.seconds).map((_, new Metadata())), ServerOptions.default, d ) @@ -94,7 +102,8 @@ class ServerSuite extends Fs2GrpcSuite { options: ServerOptions = ServerOptions.default ): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) => val dummy = new DummyServerCall - val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), options, d) + val handler = + Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length).map((_, new Metadata())), options, d) val listener = handler.startCall(dummy, new Metadata()) listener.onMessage("123") @@ -112,7 +121,8 @@ class ServerSuite extends Fs2GrpcSuite { options: ServerOptions = ServerOptions.default ): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) => val dummy = new DummyServerCall - val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), options, d) + val handler = + Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length).map((_, new Metadata())), options, d) val listener = handler.startCall(dummy, new Metadata()) listener.onHalfClose()