Skip to content

Commit

Permalink
implement non-blocking unary request server runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
naoh87 committed Feb 13, 2022
1 parent d2da1ad commit 920a3a3
Show file tree
Hide file tree
Showing 4 changed files with 301 additions and 54 deletions.
23 changes: 5 additions & 18 deletions runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ package fs2
package grpc
package server

import cats.syntax.all._
import cats.effect._
import cats.effect.std.Dispatcher
import fs2.grpc.server.internal.Fs2UnaryServerCallHandler
import io.grpc._

class Fs2ServerCallHandler[F[_]: Async] private (
Expand All @@ -35,26 +35,13 @@ class Fs2ServerCallHandler[F[_]: Async] private (

def unaryToUnaryCall[Request, Response](
implementation: (Request, Metadata) => F[Response]
): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] {
def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = {
val listener = dispatcher.unsafeRunSync(Fs2UnaryServerCallListener[F](call, dispatcher, options))
listener.unsafeUnaryResponse(new Metadata(), _ flatMap { request => implementation(request, headers) })
listener
}
}
): ServerCallHandler[Request, Response] =
Fs2UnaryServerCallHandler.unary(implementation, options, dispatcher)

def unaryToStreamingCall[Request, Response](
implementation: (Request, Metadata) => Stream[F, Response]
): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] {
def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = {
val listener = dispatcher.unsafeRunSync(Fs2UnaryServerCallListener[F](call, dispatcher, options))
listener.unsafeStreamResponse(
new Metadata(),
v => Stream.eval(v) flatMap { request => implementation(request, headers) }
)
listener
}
}
): ServerCallHandler[Request, Response] =
Fs2UnaryServerCallHandler.stream(implementation, options, dispatcher)

def streamingToUnaryCall[Request, Response](
implementation: (Stream[F, Request], Metadata) => F[Response]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of
* this software and associated documentation files (the "Software"), to deal in
* the Software without restriction, including without limitation the rights to
* use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
* the Software, and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

package fs2.grpc.server.internal

import cats.effect._
import cats.effect.std.Dispatcher
import fs2._
import fs2.grpc.server.ServerCallOptions
import io.grpc._

private[server] object Fs2ServerCall {
type Cancel = SyncIO[Unit]

def setup[I, O](
options: ServerCallOptions,
call: ServerCall[I, O]
): SyncIO[Fs2ServerCall[I, O]] =
SyncIO {
call.setMessageCompression(options.messageCompression)
options.compressor.map(_.name).foreach(call.setCompression)
new Fs2ServerCall[I, O](call)
}
}

private[server] final class Fs2ServerCall[Request, Response](
call: ServerCall[Request, Response]
) {

import Fs2ServerCall.Cancel

def stream[F[_]](response: Stream[F, Response], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] =
run(
response.pull.peek1
.flatMap {
case Some((_, stream)) =>
Pull.suspend {
call.sendHeaders(new Metadata())
stream.map(call.sendMessage).pull.echo
}
case None => Pull.done
}
.stream
.compile
.drain,
dispatcher
)

def unary[F[_]](response: F[Response], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] =
run(
F.map(response) { message =>
call.sendHeaders(new Metadata())
call.sendMessage(message)
},
dispatcher
)

def request(n: Int): SyncIO[Unit] =
SyncIO(call.request(n))

def close(status: Status, metadata: Metadata): SyncIO[Unit] =
SyncIO(call.close(status, metadata))

private def run[F[_]](completed: F[Unit], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] = {
SyncIO {
val cancel = dispatcher.unsafeRunCancelable(F.guaranteeCase(completed) {
case Outcome.Succeeded(_) => close(Status.OK, new Metadata()).to[F]
case Outcome.Errored(e) => handleError(e).to[F]
case Outcome.Canceled() => close(Status.CANCELLED, new Metadata()).to[F]
})
SyncIO(cancel()).void
}
}

private def handleError(t: Throwable): SyncIO[Unit] = t match {
case ex: StatusException => close(ex.getStatus, Option(ex.getTrailers).getOrElse(new Metadata()))
case ex: StatusRuntimeException => close(ex.getStatus, Option(ex.getTrailers).getOrElse(new Metadata()))
case ex => close(Status.INTERNAL.withDescription(ex.getMessage).withCause(ex), new Metadata())
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of
* this software and associated documentation files (the "Software"), to deal in
* the Software without restriction, including without limitation the rights to
* use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
* the Software, and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

package fs2.grpc.server.internal

import cats.effect.Async
import cats.effect.Ref
import cats.effect.Sync
import cats.effect.SyncIO
import cats.effect.std.Dispatcher
import fs2.grpc.server.ServerCallOptions
import fs2.grpc.server.ServerOptions
import io.grpc._

private[server] object Fs2UnaryServerCallHandler {

import Fs2ServerCall.Cancel

sealed trait CallerState[A]
object CallerState {
def init[A](cb: A => SyncIO[Cancel]): SyncIO[Ref[SyncIO, CallerState[A]]] =
Ref[SyncIO].of[CallerState[A]](PendingMessage(cb))
}
case class PendingMessage[A](callback: A => SyncIO[Cancel]) extends CallerState[A] {
def receive(a: A): PendingHalfClose[A] = PendingHalfClose(callback, a)
}
case class PendingHalfClose[A](callback: A => SyncIO[Cancel], received: A) extends CallerState[A] {
def call(): SyncIO[Called[A]] = callback(received).map(Called.apply)
}
case class Called[A](cancel: Cancel) extends CallerState[A]
case class Cancelled[A]() extends CallerState[A]

private def mkListener[Request, Response](
call: Fs2ServerCall[Request, Response],
state: Ref[SyncIO, CallerState[Request]]
): ServerCall.Listener[Request] =
new ServerCall.Listener[Request] {
override def onCancel(): Unit =
state.get
.flatMap {
case Called(cancel) => cancel >> state.set(Cancelled())
case _ => SyncIO.unit
}
.unsafeRunSync()

override def onMessage(message: Request): Unit =
state.get
.flatMap {
case s: PendingMessage[Request] =>
state.set(s.receive(message))
case _: PendingHalfClose[Request] =>
sendError(Status.INTERNAL.withDescription("Too many requests"))
case _ =>
SyncIO.unit
}
.unsafeRunSync()

override def onHalfClose(): Unit =
state.get
.flatMap {
case s: PendingHalfClose[Request] =>
s.call().flatMap(state.set)
case _: PendingMessage[Request] =>
sendError(Status.INTERNAL.withDescription("Half-closed without a request"))
case _ =>
SyncIO.unit
}
.unsafeRunSync()

private def sendError(status: Status): SyncIO[Unit] =
state.set(Cancelled()) >> call.close(status, new Metadata())
}

def unary[F[_]: Sync, Request, Response](
impl: (Request, Metadata) => F[Response],
options: ServerOptions,
dispatcher: Dispatcher[F]
): ServerCallHandler[Request, Response] =
new ServerCallHandler[Request, Response] {
private val opt = options.callOptionsFn(ServerCallOptions.default)

def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] =
startCallSync(call, opt)(call => req => call.unary(impl(req, headers), dispatcher)).unsafeRunSync()
}

def stream[F[_]: Sync, Request, Response](
impl: (Request, Metadata) => fs2.Stream[F, Response],
options: ServerOptions,
dispatcher: Dispatcher[F]
): ServerCallHandler[Request, Response] =
new ServerCallHandler[Request, Response] {
private val opt = options.callOptionsFn(ServerCallOptions.default)

def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] =
startCallSync(call, opt)(call => req => call.stream(impl(req, headers), dispatcher)).unsafeRunSync()
}

private def startCallSync[F[_], Request, Response](
call: ServerCall[Request, Response],
options: ServerCallOptions
)(f: Fs2ServerCall[Request, Response] => Request => SyncIO[Cancel]): SyncIO[ServerCall.Listener[Request]] = {
for {
call <- Fs2ServerCall.setup(options, call)
// We expect only 1 request, but we ask for 2 requests here so that if a misbehaving client
// sends more than 1 requests, ServerCall will catch it.
_ <- call.request(2)
state <- CallerState.init(f(call))
} yield mkListener[Request, Response](call, state)
}
}
Loading

0 comments on commit 920a3a3

Please sign in to comment.