diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCall.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCall.scala deleted file mode 100644 index a7dba3c6..00000000 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCall.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers - * - * Permission is hereby granted, free of charge, to any person obtaining a copy of - * this software and associated documentation files (the "Software"), to deal in - * the Software without restriction, including without limitation the rights to - * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of - * the Software, and to permit persons to whom the Software is furnished to do so, - * subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS - * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER - * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN - * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - */ - -package fs2 -package grpc -package server - -import cats.effect._ -import io.grpc._ - -private[server] class Fs2ServerCall[F[_], Request, Response](val call: ServerCall[Request, Response]) extends AnyVal { - def sendHeaders(headers: Metadata)(implicit F: Sync[F]): F[Unit] = - F.delay(call.sendHeaders(headers)) - - def closeStream(status: Status, trailers: Metadata)(implicit F: Sync[F]): F[Unit] = - F.delay(call.close(status, trailers)) - - def sendMessage(message: Response)(implicit F: Sync[F]): F[Unit] = - F.delay(call.sendMessage(message)) - - def request(numMessages: Int)(implicit F: Sync[F]): F[Unit] = - F.delay(call.request(numMessages)) -} - -private[server] object Fs2ServerCall { - - def apply[F[_]: Sync, Request, Response]( - call: ServerCall[Request, Response], - options: ServerOptions - ): F[Fs2ServerCall[F, Request, Response]] = Sync[F].delay { - val callOptions = options.callOptionsFn(ServerCallOptions.default) - - call.setMessageCompression(callOptions.messageCompression) - callOptions.compressor.map(_.name).foreach(call.setCompression) - - new Fs2ServerCall[F, Request, Response](call) - } - -} diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala index ebb49aff..d999a68a 100644 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala @@ -23,7 +23,6 @@ package fs2 package grpc package server -import cats.syntax.all._ import cats.effect._ import cats.effect.std.Dispatcher import io.grpc._ @@ -35,46 +34,23 @@ class Fs2ServerCallHandler[F[_]: Async] private ( def unaryToUnaryCall[Request, Response]( implementation: (Request, Metadata) => F[Response] - ): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { - def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { - val listener = dispatcher.unsafeRunSync(Fs2UnaryServerCallListener[F](call, dispatcher, options)) - listener.unsafeUnaryResponse(new Metadata(), _ flatMap { request => implementation(request, headers) }) - listener - } - } + ): ServerCallHandler[Request, Response] = + Fs2UnaryServerCallHandler.unary(implementation, options, dispatcher) def unaryToStreamingCall[Request, Response]( implementation: (Request, Metadata) => Stream[F, Response] - ): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { - def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { - val listener = dispatcher.unsafeRunSync(Fs2UnaryServerCallListener[F](call, dispatcher, options)) - listener.unsafeStreamResponse( - new Metadata(), - v => Stream.eval(v) flatMap { request => implementation(request, headers) } - ) - listener - } - } + ): ServerCallHandler[Request, Response] = + Fs2UnaryServerCallHandler.stream(implementation, options, dispatcher) def streamingToUnaryCall[Request, Response]( implementation: (Stream[F, Request], Metadata) => F[Response] - ): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { - def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { - val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, dispatcher, options)) - listener.unsafeUnaryResponse(new Metadata(), implementation(_, headers)) - listener - } - } + ): ServerCallHandler[Request, Response] = + Fs2StreamServerCallHandler.unary(implementation, options, dispatcher) def streamingToStreamingCall[Request, Response]( implementation: (Stream[F, Request], Metadata) => Stream[F, Response] - ): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { - def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { - val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, dispatcher, options)) - listener.unsafeStreamResponse(new Metadata(), implementation(_, headers)) - listener - } - } + ): ServerCallHandler[Request, Response] = + Fs2StreamServerCallHandler.stream(implementation, options, dispatcher) } object Fs2ServerCallHandler { diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala deleted file mode 100644 index 1c2b22c9..00000000 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers - * - * Permission is hereby granted, free of charge, to any person obtaining a copy of - * this software and associated documentation files (the "Software"), to deal in - * the Software without restriction, including without limitation the rights to - * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of - * the Software, and to permit persons to whom the Software is furnished to do so, - * subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS - * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER - * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN - * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - */ - -package fs2 -package grpc -package server - -import cats.syntax.all._ -import cats.effect._ -import cats.effect.std.Dispatcher -import io.grpc.{Metadata, Status, StatusException, StatusRuntimeException} - -private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] { - - def source: G[Request] - def isCancelled: Deferred[F, Unit] - def call: Fs2ServerCall[F, Request, Response] - def dispatcher: Dispatcher[F] - - private def reportError(t: Throwable)(implicit F: Sync[F]): F[Unit] = { - - t match { - case ex: StatusException => - call.closeStream(ex.getStatus, Option(ex.getTrailers).getOrElse(new Metadata())) - case ex: StatusRuntimeException => - call.closeStream(ex.getStatus, Option(ex.getTrailers).getOrElse(new Metadata())) - case ex => - // TODO: Customize failure trailers? - call.closeStream(Status.INTERNAL.withDescription(ex.getMessage).withCause(ex), new Metadata()) - } - } - - private def handleUnaryResponse(headers: Metadata, response: F[Response])(implicit F: Sync[F]): F[Unit] = - call.sendHeaders(headers) *> call.request(1) *> response >>= call.sendMessage - - private def handleStreamResponse(headers: Metadata, response: Stream[F, Response])(implicit F: Sync[F]): F[Unit] = - call.sendHeaders(headers) *> call.request(1) *> response.evalMap(call.sendMessage).compile.drain - - private def unsafeRun(f: F[Unit])(implicit F: Async[F]): Unit = { - val bracketed = F.guaranteeCase(f) { - case Outcome.Succeeded(_) => call.closeStream(Status.OK, new Metadata()) - case Outcome.Canceled() => call.closeStream(Status.CANCELLED, new Metadata()) - case Outcome.Errored(t) => reportError(t) - } - - // Exceptions are reported by closing the call - dispatcher.unsafeRunAndForget(F.race(bracketed, isCancelled.get)) - } - - def unsafeUnaryResponse(headers: Metadata, implementation: G[Request] => F[Response])(implicit - F: Async[F] - ): Unit = - unsafeRun(handleUnaryResponse(headers, implementation(source))) - - def unsafeStreamResponse(headers: Metadata, implementation: G[Request] => Stream[F, Response])(implicit - F: Async[F] - ): Unit = - unsafeRun(handleStreamResponse(headers, implementation(source))) -} diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2StatefulServerCall.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2StatefulServerCall.scala new file mode 100644 index 00000000..c2aa0837 --- /dev/null +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2StatefulServerCall.scala @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package fs2.grpc.server + +import cats.effect._ +import cats.effect.std.Dispatcher +import io.grpc._ + +object Fs2StatefulServerCall { + type Cancel = () => Any + + def setup[F[_], I, O]( + options: ServerCallOptions, + call: ServerCall[I, O], + dispatcher: Dispatcher[F] + ): Fs2StatefulServerCall[F, I, O] = { + call.setMessageCompression(options.messageCompression) + options.compressor.map(_.name).foreach(call.setCompression) + new Fs2StatefulServerCall[F, I, O](call, dispatcher) + } +} + +final class Fs2StatefulServerCall[F[_], Request, Response]( + call: ServerCall[Request, Response], + dispatcher: Dispatcher[F] +) { + + import Fs2StatefulServerCall.Cancel + + def stream(response: fs2.Stream[F, Response])(implicit F: Sync[F]): Cancel = + run(response.map(sendMessage).compile.drain) + + def unary(response: F[Response])(implicit F: Sync[F]): Cancel = + run(F.map(response)(sendMessage)) + + private var sentHeader: Boolean = false + + private def sendMessage(message: Response): Unit = + if (!sentHeader) { + sentHeader = true + call.sendHeaders(new Metadata()) + call.sendMessage(message) + } else { + call.sendMessage(message) + } + + private def run(completed: F[Unit])(implicit F: Sync[F]): Cancel = + dispatcher.unsafeRunCancelable(F.guaranteeCase(completed) { + case Outcome.Succeeded(_) => closeStream(Status.OK, new Metadata()) + case Outcome.Errored(e) => + e match { + case ex: StatusException => + closeStream(ex.getStatus, Option(ex.getTrailers).getOrElse(new Metadata())) + case ex: StatusRuntimeException => + closeStream(ex.getStatus, Option(ex.getTrailers).getOrElse(new Metadata())) + case ex => + closeStream(Status.INTERNAL.withDescription(ex.getMessage).withCause(ex), new Metadata()) + } + case Outcome.Canceled() => closeStream(Status.CANCELLED, new Metadata()) + }) + + private def closeStream(status: Status, metadata: Metadata)(implicit F: Sync[F]): F[Unit] = + F.delay(call.close(status, metadata)) +} diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallHandler.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallHandler.scala new file mode 100644 index 00000000..dea7804b --- /dev/null +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallHandler.scala @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package fs2.grpc.server + +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicReference +import cats.effect.Async +import cats.effect.std.Dispatcher +import io.grpc.Metadata +import io.grpc.ServerCall +import scala.annotation.tailrec +import scala.collection.immutable.Queue +import fs2._ +import io.grpc.ServerCallHandler + +object Fs2StreamServerCallHandler { + + import Fs2StatefulServerCall.Cancel + + private def mkListener[F[_]: Async, Request, Response]( + run: Stream[F, Request] => Cancel, + call: ServerCall[Request, Response] + ): ServerCall.Listener[Request] = + new ServerCall.Listener[Request] { + private[this] val ch = UnsafeChannel.empty[Request] + private[this] val cancel: Cancel = run(ch.stream.mapChunks { chunk => + val size = chunk.size + if (size > 0) call.request(size) + chunk + }) + + override def onCancel(): Unit = { + cancel() + () + } + + override def onMessage(message: Request): Unit = + ch.send(message) + + override def onHalfClose(): Unit = + ch.close() + } + + def unary[F[_]: Async, Request, Response]( + impl: (Stream[F, Request], Metadata) => F[Response], + options: ServerOptions, + dispatcher: Dispatcher[F] + ): ServerCallHandler[Request, Response] = + new ServerCallHandler[Request, Response] { + private val opt = options.callOptionsFn(ServerCallOptions.default) + + def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { + val responder = Fs2StatefulServerCall.setup(opt, call, dispatcher) + call.request(1) // prefetch size + mkListener[F, Request, Response](req => responder.unary(impl(req, headers)), call) + } + } + + def stream[F[_]: Async, Request, Response]( + impl: (Stream[F, Request], Metadata) => Stream[F, Response], + options: ServerOptions, + dispatcher: Dispatcher[F] + ): ServerCallHandler[Request, Response] = + new ServerCallHandler[Request, Response] { + private val opt = options.callOptionsFn(ServerCallOptions.default) + + def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { + val responder = Fs2StatefulServerCall.setup(opt, call, dispatcher) + call.request(1) // prefetch size + mkListener[F, Request, Response](req => responder.stream(impl(req, headers)), call) + } + } +} + +import UnsafeChannel._ + +final class UnsafeChannel[A] extends AtomicReference[State[A]](State.Consumed) { + + import UnsafeChannel.State._ + import scala.annotation.nowarn + + /** Send message to stream. This method is thread-unsafe + */ + @nowarn + @tailrec + def send(a: A): Unit = { + get() match { + case open: Open[A] => + if (!compareAndSet(open, open.append(a))) { + send(a) + } + case s: Suspended[A] => + lazySet(Consumed) + s.resume(new Open(Queue(a))) + case closed: Closed[A] => + } + } + + /** Close stream. This method is thread-unsafe + */ + @tailrec + def close(): Unit = + get() match { + case open: Open[_] => + if (!compareAndSet(open, open.close())) { + close() + } + case s: Suspended[_] => + lazySet(Done) + s.resume(Done) + case _ => + } + + import fs2._ + + /** This method can be called at most once + */ + def stream[F[_]](implicit F: Async[F]): Stream[F, A] = { + @nowarn + def go(): Pull[F, A, Unit] = + Pull + .suspend { + val got = getAndSet(Consumed) + if (got eq Consumed) { + Pull.eval(F.async[State[A]] { cb => + F.delay { + val next = new Suspended[A](s => cb(Right(s))) + if (!compareAndSet(Consumed, next)) { + cb(Right(getAndSet(Consumed))) + None + } else { + Some(F.delay(cb(Right(Cancelled)))) + } + } + }) + } else Pull.pure(got) + } + .flatMap { + case open: Open[A] => Pull.output(Chunk.queue(open.queue)) >> go() + case completed: Closed[A] => Pull.output(Chunk.queue(completed.queue)) + case suspended: Suspended[A] => Pull.done // unexpected + } + + go().stream + } +} + +object UnsafeChannel { + def empty[A]: UnsafeChannel[A] = new UnsafeChannel[A] + + sealed trait State[+A] + + object State { + private[UnsafeChannel] val Consumed: State[Nothing] = new Open(Queue.empty) + private[UnsafeChannel] val Cancelled: State[Nothing] = new Closed(Queue.empty) + private[UnsafeChannel] val Done: State[Nothing] = new Closed(Queue.empty) + + class Open[A](val queue: Queue[A]) extends State[A] { + def append(a: A): Open[A] = new Open(queue.enqueue(a)) + + def close(): Closed[A] = new Closed(queue) + } + + class Closed[A](val queue: Queue[A]) extends State[A] + + class Suspended[A](val f: State[A] => Unit) extends AtomicBoolean(false) with State[A] { + def resume(state: State[A]): Unit = + if (!getAndSet(true)) { + f(state) + } + } + } +} diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala deleted file mode 100644 index 137e9174..00000000 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers - * - * Permission is hereby granted, free of charge, to any person obtaining a copy of - * this software and associated documentation files (the "Software"), to deal in - * the Software without restriction, including without limitation the rights to - * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of - * the Software, and to permit persons to whom the Software is furnished to do so, - * subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS - * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER - * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN - * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - */ - -package fs2 -package grpc -package server - -import cats.Functor -import cats.syntax.all._ -import cats.effect.kernel.Deferred -import cats.effect.Async -import cats.effect.std.{Dispatcher, Queue} -import io.grpc.ServerCall - -class Fs2StreamServerCallListener[F[_], Request, Response] private ( - requestQ: Queue[F, Option[Request]], - val isCancelled: Deferred[F, Unit], - val call: Fs2ServerCall[F, Request, Response], - val dispatcher: Dispatcher[F] -)(implicit F: Functor[F]) - extends ServerCall.Listener[Request] - with Fs2ServerCallListener[F, Stream[F, *], Request, Response] { - - override def onCancel(): Unit = - dispatcher.unsafeRunSync(isCancelled.complete(()).void) - - override def onMessage(message: Request): Unit = { - call.call.request(1) - dispatcher.unsafeRunSync(requestQ.offer(message.some)) - } - - override def onHalfClose(): Unit = - dispatcher.unsafeRunSync(requestQ.offer(none)) - - override def source: Stream[F, Request] = - Stream.repeatEval(requestQ.take).unNoneTerminate -} - -object Fs2StreamServerCallListener { - - class PartialFs2StreamServerCallListener[F[_]](val dummy: Boolean = false) extends AnyVal { - - private[server] def apply[Request, Response]( - call: ServerCall[Request, Response], - dispatcher: Dispatcher[F], - options: ServerOptions - )(implicit F: Async[F]): F[Fs2StreamServerCallListener[F, Request, Response]] = for { - inputQ <- Queue.unbounded[F, Option[Request]] - isCancelled <- Deferred[F, Unit] - serverCall <- Fs2ServerCall[F, Request, Response](call, options) - } yield new Fs2StreamServerCallListener[F, Request, Response](inputQ, isCancelled, serverCall, dispatcher) - - } - - private[server] def apply[F[_]] = new PartialFs2StreamServerCallListener[F] - -} diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2UnaryServerCallHandler.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2UnaryServerCallHandler.scala new file mode 100644 index 00000000..744ff9d3 --- /dev/null +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2UnaryServerCallHandler.scala @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package fs2 +package grpc +package server + +import cats.effect.Async +import cats.effect.std.Dispatcher +import io.grpc._ + +object Fs2UnaryServerCallHandler { + import Fs2StatefulServerCall.Cancel + private val Noop: Cancel = () => () + private val Closed: Cancel = () => () + + private def mkListener[Request, Response]( + run: Request => Cancel, + call: ServerCall[Request, Response] + ): ServerCall.Listener[Request] = + new ServerCall.Listener[Request] { + + private[this] var request: Request = _ + private[this] var cancel: Cancel = Noop + + override def onCancel(): Unit = { + cancel() + () + } + + override def onMessage(message: Request): Unit = + if (request == null) { + request = message + } else if (cancel eq Noop) { + earlyClose(Status.INTERNAL.withDescription("Too many requests")) + } + + override def onHalfClose(): Unit = + if (cancel eq Noop) { + if (request == null) { + earlyClose(Status.INTERNAL.withDescription("Half-closed without a request")) + } else { + cancel = run(request) + } + } + + private def earlyClose(status: Status): Unit = { + cancel = Closed + call.close(status, new Metadata()) + } + } + + def unary[F[_]: Async, Request, Response]( + impl: (Request, Metadata) => F[Response], + options: ServerOptions, + dispatcher: Dispatcher[F] + ): ServerCallHandler[Request, Response] = + new ServerCallHandler[Request, Response] { + private val opt = options.callOptionsFn(ServerCallOptions.default) + + def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { + val responder = Fs2StatefulServerCall.setup(opt, call, dispatcher) + call.request(2) + mkListener[Request, Response](req => responder.unary(impl(req, headers)), call) + } + } + + def stream[F[_]: Async, Request, Response]( + impl: (Request, Metadata) => fs2.Stream[F, Response], + options: ServerOptions, + dispatcher: Dispatcher[F] + ): ServerCallHandler[Request, Response] = + new ServerCallHandler[Request, Response] { + private val opt = options.callOptionsFn(ServerCallOptions.default) + + def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { + val responder = Fs2StatefulServerCall.setup(opt, call, dispatcher) + call.request(2) + mkListener[Request, Response](req => responder.stream(impl(req, headers)), call) + } + } +} diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2UnaryServerCallListener.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2UnaryServerCallListener.scala deleted file mode 100644 index 4c82bf61..00000000 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2UnaryServerCallListener.scala +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers - * - * Permission is hereby granted, free of charge, to any person obtaining a copy of - * this software and associated documentation files (the "Software"), to deal in - * the Software without restriction, including without limitation the rights to - * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of - * the Software, and to permit persons to whom the Software is furnished to do so, - * subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS - * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER - * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN - * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - */ - -package fs2 -package grpc -package server - -import cats.MonadError -import cats.syntax.all._ -import cats.effect.kernel.{Async, Deferred, Ref} -import cats.effect.std.Dispatcher -import io.grpc._ - -class Fs2UnaryServerCallListener[F[_], Request, Response] private ( - request: Ref[F, Option[Request]], - isComplete: Deferred[F, Unit], - val isCancelled: Deferred[F, Unit], - val call: Fs2ServerCall[F, Request, Response], - val dispatcher: Dispatcher[F] -)(implicit F: MonadError[F, Throwable]) - extends ServerCall.Listener[Request] - with Fs2ServerCallListener[F, F, Request, Response] { - - import Fs2UnaryServerCallListener._ - - override def onCancel(): Unit = - dispatcher.unsafeRunSync(isCancelled.complete(()).void) - - override def onMessage(message: Request): Unit = { - dispatcher.unsafeRunSync( - request.access - .flatMap[Unit] { case (curValue, modify) => - if (curValue.isDefined) - F.raiseError(statusException(TooManyRequests)) - else - modify(message.some).void - } - ) - } - - override def onHalfClose(): Unit = - dispatcher.unsafeRunSync(isComplete.complete(()).void) - - override def source: F[Request] = - for { - _ <- isComplete.get - valueOrNone <- request.get - value <- valueOrNone.fold[F[Request]](F.raiseError(statusException(NoMessage)))(F.pure) - } yield value -} - -object Fs2UnaryServerCallListener { - - val TooManyRequests: String = "Too many requests" - val NoMessage: String = "No message for unary call" - - private val statusException: String => StatusRuntimeException = msg => - new StatusRuntimeException(Status.INTERNAL.withDescription(msg)) - - class PartialFs2UnaryServerCallListener[F[_]](val dummy: Boolean = false) extends AnyVal { - - private[server] def apply[Request, Response]( - call: ServerCall[Request, Response], - dispatch: Dispatcher[F], - options: ServerOptions - )(implicit F: Async[F]): F[Fs2UnaryServerCallListener[F, Request, Response]] = for { - request <- Ref.of[F, Option[Request]](none) - isComplete <- Deferred[F, Unit] - isCancelled <- Deferred[F, Unit] - serverCall <- Fs2ServerCall[F, Request, Response](call, options) - } yield new Fs2UnaryServerCallListener[F, Request, Response](request, isComplete, isCancelled, serverCall, dispatch) - - } - - private[server] def apply[F[_]] = new PartialFs2UnaryServerCallListener[F] -} diff --git a/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala b/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala index 799b9128..73ef64ce 100644 --- a/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala +++ b/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala @@ -42,9 +42,9 @@ class ServerSuite extends Fs2GrpcSuite { options: ServerOptions = ServerOptions.default ): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) => val dummy = new DummyServerCall + val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), options, d) + val listener = handler.startCall(dummy, new Metadata()) - val listener = Fs2UnaryServerCallListener[IO](dummy, d, options).unsafeRunSync() - listener.unsafeUnaryResponse(new Metadata(), _.map(_.length)) listener.onMessage("123") listener.onHalfClose() tc.tick() @@ -57,20 +57,33 @@ class ServerSuite extends Fs2GrpcSuite { runTest("cancellation for unaryToUnary") { (tc, d) => val dummy = new DummyServerCall - val listener = Fs2UnaryServerCallListener[IO](dummy, d, ServerOptions.default).unsafeRunSync() - - listener.unsafeUnaryResponse(new Metadata(), _.map(_.length)) + val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), ServerOptions.default, d) + val listener = handler.startCall(dummy, new Metadata()) listener.onCancel() tc.tick() - val cancelled = listener.isCancelled.get.unsafeToFuture() - tc.tick() + assertEquals(dummy.currentStatus, None) + assertEquals(dummy.messages.length, 0) + } - IO.sleep(50.millis).unsafeRunSync() + runTest("cancellation on the fly for unaryToUnary") { (tc, d) => + val dummy = new DummyServerCall + val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]( + (req, _) => IO(req.length).delayBy(10.seconds), + ServerOptions.default, + d + ) + val listener = handler.startCall(dummy, new Metadata()) - assertEquals(cancelled.isCompleted, true) + listener.onMessage("123") + listener.onHalfClose() + tc.tick() + listener.onCancel() + tc.tick() + assertEquals(dummy.currentStatus.map(_.getCode), Some(Status.Code.CANCELLED)) + assertEquals(dummy.messages.length, 0) } runTest("multiple messages to unaryToUnary")(multipleUnaryToUnary()) @@ -80,21 +93,31 @@ class ServerSuite extends Fs2GrpcSuite { options: ServerOptions = ServerOptions.default ): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) => val dummy = new DummyServerCall - val listener = Fs2UnaryServerCallListener[IO](dummy, d, options).unsafeRunSync() + val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), options, d) + val listener = handler.startCall(dummy, new Metadata()) - listener.unsafeUnaryResponse(new Metadata(), _.map(_.length)) listener.onMessage("123") + listener.onMessage("456") + listener.onHalfClose() + tc.tick() - intercept[StatusRuntimeException] { - listener.onMessage("456") - } + assertEquals(dummy.currentStatus.map(_.getCode), Some(Status.Code.INTERNAL)) + } - listener.onHalfClose() - tc.tickAll() + runTest("no messages to unaryToUnary")(noMessageUnaryToUnary()) + runTest("no messages to unaryToUnary with compression")(noMessageUnaryToUnary(compressionOps)) - assertEquals(dummy.currentStatus.isDefined, true) - assertEquals(dummy.currentStatus.get.isOk, true, "Current status true because stream completed successfully") + private def noMessageUnaryToUnary( + options: ServerOptions = ServerOptions.default + ): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) => + val dummy = new DummyServerCall + val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), options, d) + val listener = handler.startCall(dummy, new Metadata()) + + listener.onHalfClose() + tc.tick() + assertEquals(dummy.currentStatus.map(_.getCode), Some(Status.Code.INTERNAL)) } runTest0("resource awaits termination of server") { (tc, r, _) => @@ -115,9 +138,10 @@ class ServerSuite extends Fs2GrpcSuite { options: ServerOptions = ServerOptions.default ): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) => val dummy = new DummyServerCall - val listener = Fs2UnaryServerCallListener[IO][String, Int](dummy, d, options).unsafeRunSync() + val handler = + Fs2UnaryServerCallHandler.stream[IO, String, Int]((s, _) => Stream(s).map(_.length).repeat.take(5), options, d) + val listener = handler.startCall(dummy, new Metadata()) - listener.unsafeStreamResponse(new Metadata(), s => Stream.eval(s).map(_.length).repeat.take(5)) listener.onMessage("123") listener.onHalfClose() tc.tick() @@ -130,9 +154,14 @@ class ServerSuite extends Fs2GrpcSuite { runTest("zero messages to streamingToStreaming") { (tc, d) => val dummy = new DummyServerCall - val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy, d, ServerOptions.default).unsafeRunSync() - listener.unsafeStreamResponse(new Metadata(), _ => Stream.emit(3).repeat.take(5)) + val handler = Fs2StreamServerCallHandler.stream[IO, String, Int]( + (_, _) => Stream.emit(3).repeat.take(5), + ServerOptions.default, + d + ) + val listener = handler.startCall(dummy, new Metadata()) + listener.onHalfClose() tc.tick() @@ -144,15 +173,18 @@ class ServerSuite extends Fs2GrpcSuite { runTest("cancellation for streamingToStreaming") { (tc, d) => val dummy = new DummyServerCall - val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy, d, ServerOptions.default).unsafeRunSync() + val handler = Fs2StreamServerCallHandler.stream[IO, String, Int]( + (_, _) => Stream.emit(3).repeat.take(5).zipLeft(Stream.awakeDelay[IO](1.seconds)), + ServerOptions.default, + d + ) + val listener = handler.startCall(dummy, new Metadata()) - listener.unsafeStreamResponse(new Metadata(), _ => Stream.emit(3).repeat.take(5)) + tc.tick() listener.onCancel() + tc.tick() - val cancelled = listener.isCancelled.get.unsafeToFuture() - tc.tickAll() - - assertEquals(cancelled.isCompleted, true) + assertEquals(dummy.currentStatus.map(_.getCode), Some(Status.Code.CANCELLED)) } runTest("messages to streamingToStreaming")(multipleStreamingToStreaming()) @@ -162,9 +194,10 @@ class ServerSuite extends Fs2GrpcSuite { options: ServerOptions = ServerOptions.default ): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) => val dummy = new DummyServerCall - val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy, d, options).unsafeRunSync() + val handler = + Fs2StreamServerCallHandler.stream[IO, String, Int]((req, _) => req.map(_.length).intersperse(0), options, d) + val listener = handler.startCall(dummy, new Metadata()) - listener.unsafeStreamResponse(new Metadata(), _.map(_.length).intersperse(0)) listener.onMessage("a") listener.onMessage("ab") listener.onHalfClose() @@ -179,9 +212,14 @@ class ServerSuite extends Fs2GrpcSuite { runTest("messages to streamingToStreaming with error") { (tc, d) => val dummy = new DummyServerCall val error = new RuntimeException("hello") - val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy, d, ServerOptions.default).unsafeRunSync() - listener.unsafeStreamResponse(new Metadata(), _.map(_.length) ++ Stream.emit(0) ++ Stream.raiseError[IO](error)) + val handler = Fs2StreamServerCallHandler.stream[IO, String, Int]( + (req, _) => req.map(_.length) ++ Stream.emit(0) ++ Stream.raiseError[IO](error), + ServerOptions.default, + d + ) + val listener = handler.startCall(dummy, new Metadata()) + listener.onMessage("a") listener.onMessage("ab") listener.onHalfClose() @@ -204,9 +242,10 @@ class ServerSuite extends Fs2GrpcSuite { _.compile.foldMonoid.map(_.length) val dummy = new DummyServerCall - val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy, d, so).unsafeRunSync() - listener.unsafeUnaryResponse(new Metadata(), implementation) + val handler = Fs2StreamServerCallHandler.unary[IO, String, Int]((req, _) => implementation(req), so, d) + val listener = handler.startCall(dummy, new Metadata()) + listener.onMessage("ab") listener.onMessage("abc") listener.onHalfClose()