Skip to content

Commit

Permalink
Merge pull request #506 from naoh87/add_unary_client
Browse files Browse the repository at this point in the history
add lightweight toUnary client runtime
  • Loading branch information
ahjohannessen authored Feb 26, 2022
2 parents 141d20d + c976963 commit 6684a90
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 22 deletions.
19 changes: 3 additions & 16 deletions runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ package client
import cats.syntax.all._
import cats.effect.{Async, Resource}
import cats.effect.std.Dispatcher
import fs2.grpc.client.internal.Fs2UnaryCallHandler
import io.grpc.{Metadata, _}

final case class UnaryResult[A](value: Option[A], status: Option[GrpcStatus])
Expand Down Expand Up @@ -65,15 +66,10 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (
//

def unaryToUnaryCall(message: Request, headers: Metadata): F[Response] =
mkUnaryListenerR(headers)
.use(sendSingleMessage(message) *> _.getValue.adaptError(ea))
Fs2UnaryCallHandler.unary(call, options, message, headers)

def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata): F[Response] =
Stream
.resource(mkUnaryListenerR(headers))
.flatMap(l => Stream.eval(l.getValue.adaptError(ea)).concurrently(sendStream(messages)))
.compile
.lastOrError
Fs2UnaryCallHandler.stream(call, options, messages, headers)

def unaryToStreamingCall(message: Request, md: Metadata): Stream[F, Response] =
Stream
Expand All @@ -93,15 +89,6 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (
case (_, Resource.ExitCase.Errored(t)) => cancel(t.getMessage.some, t.some)
}

private def mkUnaryListenerR(md: Metadata): Resource[F, Fs2UnaryClientCallListener[F, Response]] = {

val create = Fs2UnaryClientCallListener.create[F, Response](dispatcher)
val acquire = start(create, md) <* request(1)
val release = handleExitCase(cancelSucceed = false)

Resource.makeCase(acquire)(release)
}

private def mkStreamListenerR(md: Metadata): Resource[F, Fs2StreamClientCallListener[F, Response]] = {

val prefetchN = options.prefetchN.max(1)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of
* this software and associated documentation files (the "Software"), to deal in
* the Software without restriction, including without limitation the rights to
* use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
* the Software, and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

package fs2.grpc.client.internal

import cats.effect.Sync
import cats.effect.SyncIO
import cats.effect.syntax.all._
import cats.effect.kernel.Async
import cats.effect.kernel.Outcome
import cats.effect.kernel.Ref
import cats.syntax.functor._
import cats.syntax.flatMap._
import fs2._
import fs2.grpc.client.ClientOptions
import io.grpc._

private[client] object Fs2UnaryCallHandler {
sealed trait ReceiveState[R]

object ReceiveState {
def init[F[_]: Sync, R](
callback: Either[Throwable, R] => Unit,
pf: PartialFunction[StatusRuntimeException, Exception]
): F[Ref[SyncIO, ReceiveState[R]]] =
Ref.in(new PendingMessage[R]({
case r: Right[Throwable, R] => callback(r)
case Left(e: StatusRuntimeException) => callback(Left(pf.lift(e).getOrElse(e)))
case l: Left[Throwable, R] => callback(l)
}))
}

class PendingMessage[R](callback: Either[Throwable, R] => Unit) extends ReceiveState[R] {
def receive(message: R): PendingHalfClose[R] = new PendingHalfClose(callback, message)

def sendError(error: Throwable): SyncIO[ReceiveState[R]] =
SyncIO(callback(Left(error))).as(new Done[R])
}

class PendingHalfClose[R](callback: Either[Throwable, R] => Unit, message: R) extends ReceiveState[R] {
def sendError(error: Throwable): SyncIO[ReceiveState[R]] =
SyncIO(callback(Left(error))).as(new Done[R])

def done: SyncIO[ReceiveState[R]] = SyncIO(callback(Right(message))).as(new Done[R])
}

class Done[R] extends ReceiveState[R]

private def mkListener[Response](
state: Ref[SyncIO, ReceiveState[Response]]
): ClientCall.Listener[Response] =
new ClientCall.Listener[Response] {
override def onMessage(message: Response): Unit =
state.get
.flatMap {
case expected: PendingMessage[Response] =>
state.set(expected.receive(message))
case current: PendingHalfClose[Response] =>
current
.sendError(
Status.INTERNAL
.withDescription("More than one value received for unary call")
.asRuntimeException()
)
.flatMap(state.set)
case _ => SyncIO.unit
}
.unsafeRunSync()

override def onClose(status: Status, trailers: Metadata): Unit = {
if (status.isOk) {
state.get.flatMap {
case expected: PendingHalfClose[Response] =>
expected.done.flatMap(state.set)
case current: PendingMessage[Response] =>
current
.sendError(
Status.INTERNAL
.withDescription("No value received for unary call")
.asRuntimeException(trailers)
)
.flatMap(state.set)
case _ => SyncIO.unit
}
} else {
state.get.flatMap {
case current: PendingHalfClose[Response] =>
current.sendError(status.asRuntimeException(trailers)).flatMap(state.set)
case current: PendingMessage[Response] =>
current.sendError(status.asRuntimeException(trailers)).flatMap(state.set)
case _ => SyncIO.unit
}
}
}.unsafeRunSync()
}

def unary[F[_], Request, Response](
call: ClientCall[Request, Response],
options: ClientOptions,
message: Request,
headers: Metadata
)(implicit F: Async[F]): F[Response] = F.async[Response] { cb =>
ReceiveState.init(cb, options.errorAdapter).map { state =>
call.start(mkListener[Response](state), headers)
// Initially ask for two responses from flow-control so that if a misbehaving server
// sends more than one responses, we can catch it and fail it in the listener.
call.request(2)
call.sendMessage(message)
call.halfClose()
Some(onCancel(call))
}
}

def stream[F[_], Request, Response](
call: ClientCall[Request, Response],
options: ClientOptions,
messages: Stream[F, Request],
headers: Metadata
)(implicit F: Async[F]): F[Response] = F.async[Response] { cb =>
ReceiveState.init(cb, options.errorAdapter).flatMap { state =>
call.start(mkListener[Response](state), headers)
// Initially ask for two responses from flow-control so that if a misbehaving server
// sends more than one responses, we can catch it and fail it in the listener.
call.request(2)
messages
.map(call.sendMessage)
.compile
.drain
.guaranteeCase {
case Outcome.Succeeded(_) => F.delay(call.halfClose())
case Outcome.Errored(e) => F.delay(call.cancel(e.getMessage, e))
case Outcome.Canceled() => onCancel(call)
}
.start
.map(sending => Some(sending.cancel >> onCancel(call)))
}
}

private def onCancel[F[_]](call: ClientCall[_, _])(implicit F: Async[F]): F[Unit] =
F.delay(call.cancel("call was cancelled", null))

}
10 changes: 5 additions & 5 deletions runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class ClientSuite extends Fs2GrpcSuite {
tc.tick()
assertEquals(result.value, Some(Success(5)))
assertEquals(dummy.messagesSent.size, 1)
assertEquals(dummy.requested, 1)
assertEquals(dummy.requested, 2)

}

Expand Down Expand Up @@ -93,7 +93,7 @@ class ClientSuite extends Fs2GrpcSuite {
assert(result.value.get.isFailure)
assert(result.value.get.failed.get.isInstanceOf[StatusRuntimeException])
assertEquals(dummy.messagesSent.size, 1)
assertEquals(dummy.requested, 1)
assertEquals(dummy.requested, 2)

}

Expand All @@ -118,7 +118,7 @@ class ClientSuite extends Fs2GrpcSuite {
Status.INTERNAL
)
assertEquals(dummy.messagesSent.size, 1)
assertEquals(dummy.requested, 1)
assertEquals(dummy.requested, 2)

}

Expand All @@ -142,7 +142,7 @@ class ClientSuite extends Fs2GrpcSuite {
tc.tick()
assertEquals(result.value, Some(Success(5)))
assertEquals(dummy.messagesSent.size, 3)
assertEquals(dummy.requested, 1)
assertEquals(dummy.requested, 2)

}

Expand All @@ -166,7 +166,7 @@ class ClientSuite extends Fs2GrpcSuite {
tc.tick()
assertEquals(result.value, Some(Success(5)))
assertEquals(dummy.messagesSent.size, 0)
assertEquals(dummy.requested, 1)
assertEquals(dummy.requested, 2)

}

Expand Down
4 changes: 3 additions & 1 deletion runtime/src/test/scala/fs2/grpc/client/DummyClientCall.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class DummyClientCall extends ClientCall[String, Int] {
val messagesSent: ArrayBuffer[String] = ArrayBuffer[String]()
var listener: Option[ClientCall.Listener[Int]] = None
var cancelled: Option[(String, Throwable)] = None
var halfClosed = false

override def start(responseListener: ClientCall.Listener[Int], headers: Metadata): Unit =
listener = Some(responseListener)
Expand All @@ -40,7 +41,8 @@ class DummyClientCall extends ClientCall[String, Int] {
override def cancel(message: String, cause: Throwable): Unit =
cancelled = Some((message, cause))

override def halfClose(): Unit = ()
override def halfClose(): Unit =
halfClosed = true

override def sendMessage(message: String): Unit = {
messagesSent += message
Expand Down

0 comments on commit 6684a90

Please sign in to comment.