Skip to content

Commit

Permalink
wrap side effect with SyncIO
Browse files Browse the repository at this point in the history
  • Loading branch information
naoh87 committed Feb 6, 2022
1 parent c9356cd commit 1cc473c
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 93 deletions.
4 changes: 2 additions & 2 deletions runtime/src/main/scala/fs2/grpc/internal/UnsafeChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ final class UnsafeChannel[A] extends AtomicReference[State[A]](State.Consumed) {
import State._
import scala.annotation._

/** Send message to stream. This method is thread-unsafe
/** Send message to stream.
*/
@nowarn
@tailrec
Expand All @@ -49,7 +49,7 @@ final class UnsafeChannel[A] extends AtomicReference[State[A]](State.Consumed) {
}
}

/** Close stream. This method is thread-unsafe
/** Close stream.
*/
@tailrec
def close(): Unit =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,60 +25,82 @@ import cats.effect._
import cats.effect.std.Dispatcher
import fs2.grpc.server.ServerCallOptions
import io.grpc._
import fs2._

object Fs2StatefulServerCall {
type Cancel = () => Any
type Cancel = SyncIO[Unit]

def setup[F[_], I, O](
def setup[I, O](
options: ServerCallOptions,
call: ServerCall[I, O],
dispatcher: Dispatcher[F]
): Fs2StatefulServerCall[F, I, O] = {
call.setMessageCompression(options.messageCompression)
options.compressor.map(_.name).foreach(call.setCompression)
new Fs2StatefulServerCall[F, I, O](call, dispatcher)
}
call: ServerCall[I, O]
): SyncIO[Fs2StatefulServerCall[I, O]] =
SyncIO {
call.setMessageCompression(options.messageCompression)
options.compressor.map(_.name).foreach(call.setCompression)
new Fs2StatefulServerCall[I, O](call)
}
}

final class Fs2StatefulServerCall[F[_], Request, Response](
call: ServerCall[Request, Response],
dispatcher: Dispatcher[F]
final class Fs2StatefulServerCall[Request, Response](
call: ServerCall[Request, Response]
) {

import Fs2StatefulServerCall.Cancel

def stream(response: fs2.Stream[F, Response])(implicit F: Sync[F]): Cancel =
run(response.map(sendMessage).compile.drain)
def stream[F[_]](response: Stream[F, Response], dispatcher: Dispatcher[F])(implicit F: Async[F]): SyncIO[Cancel] =
run(
response.pull.peek1
.flatMap {
case Some((_, tail)) =>
Pull.suspend {
call.sendHeaders(new Metadata())
tail.map(call.sendMessage).pull.echo
}
case None => Pull.done
}
.stream
.compile
.drain,
dispatcher
)

def unary(response: F[Response])(implicit F: Sync[F]): Cancel =
run(F.map(response)(sendMessage))
def unary[F[_]](response: F[Response], dispatcher: Dispatcher[F])(implicit F: Async[F]): SyncIO[Cancel] =
run(
F.map(response) { message =>
call.sendHeaders(new Metadata())
call.sendMessage(message)
},
dispatcher
)

private var sentHeader: Boolean = false
def request(n: Int): SyncIO[Unit] =
SyncIO(call.request(n))

private def sendMessage(message: Response): Unit =
if (!sentHeader) {
sentHeader = true
call.sendHeaders(new Metadata())
call.sendMessage(message)
} else {
call.sendMessage(message)
}
def close(status: Status, metadata: Metadata): SyncIO[Unit] =
SyncIO(call.close(status, metadata))

private def run(completed: F[Unit])(implicit F: Sync[F]): Cancel =
dispatcher.unsafeRunCancelable(F.guaranteeCase(completed) {
case Outcome.Succeeded(_) => closeStream(Status.OK, new Metadata())
case Outcome.Errored(e) =>
e match {
case ex: StatusException =>
closeStream(ex.getStatus, Option(ex.getTrailers).getOrElse(new Metadata()))
case ex: StatusRuntimeException =>
closeStream(ex.getStatus, Option(ex.getTrailers).getOrElse(new Metadata()))
case ex =>
closeStream(Status.INTERNAL.withDescription(ex.getMessage).withCause(ex), new Metadata())
}
case Outcome.Canceled() => closeStream(Status.CANCELLED, new Metadata())
})
private def run[F[_]](completed: F[Unit], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] = {
SyncIO {
val cancel = dispatcher.unsafeRunCancelable(F.guaranteeCase(completed) {
case Outcome.Succeeded(_) => closeStreamF(Status.OK, new Metadata())
case Outcome.Errored(e) =>
e match {
case ex: StatusException =>
closeStreamF(ex.getStatus, Option(ex.getTrailers).getOrElse(new Metadata()))
case ex: StatusRuntimeException =>
closeStreamF(ex.getStatus, Option(ex.getTrailers).getOrElse(new Metadata()))
case ex =>
closeStreamF(Status.INTERNAL.withDescription(ex.getMessage).withCause(ex), new Metadata())
}
case Outcome.Canceled() => closeStreamF(Status.CANCELLED, new Metadata())
})
SyncIO {
cancel()
()
}
}
}

private def closeStream(status: Status, metadata: Metadata)(implicit F: Sync[F]): F[Unit] =
private def closeStreamF[F[_]](status: Status, metadata: Metadata)(implicit F: Sync[F]): F[Unit] =
F.delay(call.close(status, metadata))
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
package fs2.grpc.internal.server

import cats.effect.Async
import cats.effect.SyncIO
import cats.effect.std.Dispatcher
import io.grpc._
import fs2._
Expand All @@ -35,7 +36,7 @@ object Fs2StreamServerCallHandler {
import Fs2StatefulServerCall.Cancel

private def mkListener[F[_]: Async, Request, Response](
run: Stream[F, Request] => Cancel,
run: Stream[F, Request] => SyncIO[Cancel],
call: ServerCall[Request, Response]
): ServerCall.Listener[Request] =
new ServerCall.Listener[Request] {
Expand All @@ -44,12 +45,10 @@ object Fs2StreamServerCallHandler {
val size = chunk.size
if (size > 0) call.request(size)
chunk
})
}).unsafeRunSync()

override def onCancel(): Unit = {
cancel()
()
}
override def onCancel(): Unit =
cancel.unsafeRunSync()

override def onMessage(message: Request): Unit =
ch.send(message)
Expand All @@ -67,10 +66,11 @@ object Fs2StreamServerCallHandler {
private val opt = options.callOptionsFn(ServerCallOptions.default)

def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = {
val responder = Fs2StatefulServerCall.setup(opt, call, dispatcher)
call.request(1) // prefetch size
mkListener[F, Request, Response](req => responder.unary(impl(req, headers)), call)
}
for {
responder <- Fs2StatefulServerCall.setup(opt, call)
_ <- responder.request(1)
} yield mkListener[F, Request, Response](req => responder.unary(impl(req, headers), dispatcher), call)
}.unsafeRunSync()
}

def stream[F[_]: Async, Request, Response](
Expand All @@ -82,9 +82,10 @@ object Fs2StreamServerCallHandler {
private val opt = options.callOptionsFn(ServerCallOptions.default)

def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = {
val responder = Fs2StatefulServerCall.setup(opt, call, dispatcher)
call.request(1) // prefetch size
mkListener[F, Request, Response](req => responder.stream(impl(req, headers)), call)
}
for {
responder <- Fs2StatefulServerCall.setup(opt, call)
_ <- responder.request(1)
} yield mkListener[F, Request, Response](req => responder.stream(impl(req, headers), dispatcher), call)
}.unsafeRunSync()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,49 +22,54 @@
package fs2.grpc.internal.server

import cats.effect.Async
import cats.effect.Ref
import cats.effect.SyncIO
import cats.effect.std.Dispatcher
import fs2.grpc.server.ServerOptions
import fs2.grpc.server.ServerCallOptions
import fs2.grpc.server.ServerOptions
import io.grpc._

object Fs2UnaryServerCallHandler {

import Fs2StatefulServerCall.Cancel
private val Noop: Cancel = () => ()
private val Closed: Cancel = () => ()

case class State[Request](cancel: Option[Cancel], request: Option[Request])

private def mkListener[Request, Response](
run: Request => Cancel,
call: ServerCall[Request, Response]
run: Request => SyncIO[Cancel],
call: Fs2StatefulServerCall[Request, Response],
state: Ref[SyncIO, State[Request]]
): ServerCall.Listener[Request] =
new ServerCall.Listener[Request] {

private[this] var request: Request = _
private[this] var cancel: Cancel = Noop

override def onCancel(): Unit = {
cancel()
()
}
override def onCancel(): Unit =
state.get.flatMap(_.cancel.getOrElse(SyncIO.unit)).unsafeRunSync()

override def onMessage(message: Request): Unit =
if (request == null) {
request = message
} else if (cancel eq Noop) {
earlyClose(Status.INTERNAL.withDescription("Too many requests"))
}
state.get
.flatMap {
case cur if cur.request.isEmpty =>
state.set(cur.copy(request = Some(message)))
case cur =>
earlyClose(cur, Status.INTERNAL.withDescription("Too many requests"))
}
.unsafeRunSync()

override def onHalfClose(): Unit =
if (cancel eq Noop) {
if (request == null) {
earlyClose(Status.INTERNAL.withDescription("Half-closed without a request"))
} else {
cancel = run(request)
state.get
.flatMap {
case State(None, Some(request)) =>
run(request).flatMap(c => state.set(State(Some(c), None)))
case cur =>
earlyClose(cur, Status.INTERNAL.withDescription("Half-closed without a request"))
}
}
.unsafeRunSync()

private def earlyClose(status: Status): Unit = {
cancel = Closed
call.close(status, new Metadata())
private def earlyClose(current: State[Request], status: Status): SyncIO[Unit] = {
if (current.cancel.isEmpty) {
state.set(State(Some(SyncIO.unit), None)) >> call.close(status, new Metadata())
} else {
SyncIO.unit
}
}
}

Expand All @@ -76,11 +81,8 @@ object Fs2UnaryServerCallHandler {
new ServerCallHandler[Request, Response] {
private val opt = options.callOptionsFn(ServerCallOptions.default)

def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = {
val responder = Fs2StatefulServerCall.setup(opt, call, dispatcher)
call.request(2)
mkListener[Request, Response](req => responder.unary(impl(req, headers)), call)
}
def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] =
startCallSync(call, opt)(call => req => call.unary(impl(req, headers), dispatcher)).unsafeRunSync()
}

def stream[F[_]: Async, Request, Response](
Expand All @@ -91,10 +93,18 @@ object Fs2UnaryServerCallHandler {
new ServerCallHandler[Request, Response] {
private val opt = options.callOptionsFn(ServerCallOptions.default)

def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = {
val responder = Fs2StatefulServerCall.setup(opt, call, dispatcher)
call.request(2)
mkListener[Request, Response](req => responder.stream(impl(req, headers)), call)
}
def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] =
startCallSync(call, opt)(call => req => call.stream(impl(req, headers), dispatcher)).unsafeRunSync()
}

private def startCallSync[F[_], Request, Response](
call: ServerCall[Request, Response],
options: ServerCallOptions
)(f: Fs2StatefulServerCall[Request, Response] => Request => SyncIO[Cancel]): SyncIO[ServerCall.Listener[Request]] = {
for {
call <- Fs2StatefulServerCall.setup(options, call)
_ <- call.request(2)
state <- Ref.of[SyncIO, State[Request]](State(None, None))
} yield mkListener[Request, Response](f(call), call, state)
}
}

0 comments on commit 1cc473c

Please sign in to comment.