Skip to content

Commit

Permalink
Add WS heartbeat timeouts. (#507)
Browse files Browse the repository at this point in the history
  • Loading branch information
Caparow authored Apr 10, 2024
1 parent ffc4e9d commit b005a17
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ import io.circe
import io.circe.syntax.EncoderOps
import io.circe.{Json, Printer}
import izumi.functional.bio.Exit.{Error, Interruption, Success, Termination}
import izumi.functional.bio.{Exit, F, IO2, Primitives2, Temporal2, UnsafeRun2}
import izumi.functional.bio.{Clock1, Exit, F, IO2, Primitives2, Temporal2, UnsafeRun2}
import izumi.fundamentals.platform.language.Quirks
import izumi.fundamentals.platform.language.Quirks.Discarder
import izumi.functional.bio.Clock1
import izumi.idealingua.runtime.rpc.*
import izumi.idealingua.runtime.rpc.http4s.HttpServer.{ServerWsRpcHandler, WsResponseMarker}
import izumi.idealingua.runtime.rpc.http4s.context.{HttpContextExtractor, WsContextExtractor}
Expand All @@ -26,7 +25,7 @@ import org.typelevel.vault.Key

import java.time.ZonedDateTime
import java.util.concurrent.RejectedExecutionException
import scala.concurrent.duration.DurationInt
import scala.concurrent.duration.{DurationInt, FiniteDuration}

class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx](
val contextServices: Set[IRTContextServices.AnyContext[F, AuthCtx]],
Expand All @@ -39,12 +38,13 @@ class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx](
)(implicit val AT: Async[F[Throwable, _]]
) {
import dsl.*
// WS Response attribute key, to differ from usual HTTP responses
private val wsAttributeKey = UnsafeRun2[F].unsafeRun(Key.newKey[F[Throwable, _], WsResponseMarker.type])

protected val serverMuxer: IRTServerMultiplexor[F, AuthCtx] = IRTServerMultiplexor.combine(contextServices.map(_.authorizedMuxer))
protected val wsContextsSessions: Set[WsContextSessions.AnyContext[F, AuthCtx]] = contextServices.map(_.authorizedWsSessions)

// WS Response attribute key, to differ from usual HTTP responses
private val wsAttributeKey = UnsafeRun2[F].unsafeRun(Key.newKey[F[Throwable, _], WsResponseMarker.type])
protected val wsHeartbeatTimeout: FiniteDuration = 1.minute
protected val wsHeartbeatInterval: FiniteDuration = 10.seconds

def service(ws: WebSocketBuilder2[F[Throwable, _]]): HttpRoutes[F[Throwable, _]] = {
val svc = HttpRoutes.of(router(ws))
Expand All @@ -62,18 +62,33 @@ class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx](
ws: WebSocketBuilder2[F[Throwable, _]],
): F[Throwable, Response[F[Throwable, _]]] = {
Quirks.discard(request)
def pingStream: Stream[F[Throwable, _], WebSocketFrame.Ping] = {
def pingStream(clientSession: WsClientSession[F, AuthCtx]): Stream[F[Throwable, _], WebSocketFrame] = {
Stream
.awakeEvery[F[Throwable, _]](5.second)
.evalMap(_ => logger.debug("WS Server: Sending ping frame.").as(WebSocketFrame.Ping()))
.awakeEvery[F[Throwable, _]](wsHeartbeatInterval)
.evalMap[F[Throwable, _], WebSocketFrame] {
_ =>
for {
expiration <- F.sync(Clock1.Standard.nowZoned().minusNanos(wsHeartbeatTimeout.toNanos))
frame <- clientSession.lastHeartbeat().flatMap {
case Some(lastHeartbeat) if lastHeartbeat.isBefore(expiration) =>
logger.warn(s"WS Session: Websocket client heartbeat timeout: ${clientSession.sessionId}, $wsHeartbeatTimeout.") *>
F.fromEither(WebSocketFrame.Close(1006, s"Ping-Pong heartbeat timed-out after '$wsHeartbeatTimeout'."))
case _ =>
logger.debug("WS Server: Sending ping frame.").as(WebSocketFrame.Ping())
}
} yield frame
}.takeThrough {
case _: WebSocketFrame.Close => false
case _ => true
}
}
for {
outQueue <- Queue.unbounded[F[Throwable, _], WebSocketFrame]
authContext <- F.syncThrowable(httpContextExtractor.extract(request))
clientSession = new WsClientSession.Queued(outQueue, authContext, wsContextsSessions, wsSessionsStorage, wsContextExtractor, logger, printer)
_ <- clientSession.start(onWsConnected)

outStream = Stream.fromQueueUnterminated(outQueue).merge(pingStream)
outStream = Stream.fromQueueUnterminated(outQueue).merge(pingStream(clientSession))
inStream = {
(inputStream: Stream[F[Throwable, _], WebSocketFrame]) =>
inputStream.evalMap {
Expand All @@ -86,6 +101,7 @@ class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx](
wsSessionIdHeader = Header.Raw(HttpServer.`X-Ws-Session-Id`, clientSession.sessionId.sessionId.toString)

response <- ws
.withFilterPingPongs(false)
.withOnClose(handleWsClose(clientSession))
.withHeaders(Headers(wsSessionIdHeader))
.build(outStream, inStream)
Expand All @@ -100,14 +116,16 @@ class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx](
)(frame: WebSocketFrame
): F[Throwable, Option[String]] = {
(frame match {
case WebSocketFrame.Text(msg, _) => wsHandler(clientSession).processRpcMessage(msg)
case WebSocketFrame.Close(_) => F.pure(None)
case _: WebSocketFrame.Pong => onWsHeartbeat(requestTime).as(None)
case WebSocketFrame.Text(msg, _) =>
wsHandler(clientSession).processRpcMessage(msg)
case WebSocketFrame.Close(_) =>
F.pure(None)
case _: WebSocketFrame.Pong =>
clientSession.heartbeat(requestTime) *>
onWsHeartbeat(requestTime).as(None)
case unknownMessage =>
val message = s"Unsupported WS frame: $unknownMessage."
logger
.error(s"WS request failed: $message.")
.as(Some(RpcPacket.rpcCritical(message, None)))
logger.error(s"WS request failed: $message.").as(Some(RpcPacket.rpcCritical(message, None)))
}).map(_.map(p => printer.print(p.asJson)))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ package izumi.idealingua.runtime.rpc.http4s.ws
import cats.effect.std.Queue
import io.circe.syntax.EncoderOps
import io.circe.{Json, Printer}
import izumi.functional.bio.{Applicative2, F, IO2, Primitives2, Temporal2}
import izumi.functional.bio.Clock1
import izumi.functional.bio.{Applicative2, Clock1, F, IO2, Primitives2, Temporal2}
import izumi.fundamentals.platform.uuid.UUIDGen
import izumi.idealingua.runtime.rpc.*
import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcherFactory.ClientWsRpcHandler
Expand All @@ -28,18 +27,24 @@ trait WsClientSession[F[+_, +_], SessionCtx] extends WsResponder[F] {

def start(onStart: SessionCtx => F[Throwable, Unit]): F[Throwable, Unit]
def finish(onFinish: SessionCtx => F[Throwable, Unit]): F[Throwable, Unit]

def heartbeat(at: ZonedDateTime): F[Nothing, Unit]
def lastHeartbeat(): F[Nothing, Option[ZonedDateTime]]
}

object WsClientSession {

def empty[F[+_, +_]: Applicative2, Ctx](wsSessionId: WsSessionId): WsClientSession[F, Ctx] = new WsClientSession[F, Ctx] {
override def sessionId: WsSessionId = wsSessionId
private val heartbeatTimestamp = new AtomicReference[Option[ZonedDateTime]](None)
override def sessionId: WsSessionId = wsSessionId
override def requestAndAwaitResponse(method: IRTMethodId, data: Json, timeout: FiniteDuration): F[Throwable, Option[RawResponse]] = F.pure(None)
override def updateRequestCtx(newContext: Ctx): F[Throwable, Ctx] = F.pure(newContext)
override def start(onStart: Ctx => F[Throwable, Unit]): F[Throwable, Unit] = F.unit
override def finish(onFinish: Ctx => F[Throwable, Unit]): F[Throwable, Unit] = F.unit
override def responseWith(id: RpcPacketId, response: RawResponse): F[Throwable, Unit] = F.unit
override def responseWithData(id: RpcPacketId, data: Json): F[Throwable, Unit] = F.unit
override def heartbeat(at: ZonedDateTime): F[Nothing, Unit] = F.pure(heartbeatTimestamp.set(Some(at)))
override def lastHeartbeat(): F[Nothing, Option[ZonedDateTime]] = F.pure(heartbeatTimestamp.get())
}

abstract class Base[F[+_, +_]: IO2: Temporal2: Primitives2, SessionCtx](
Expand All @@ -49,6 +54,7 @@ object WsClientSession {
wsContextExtractor: WsContextExtractor[SessionCtx],
logger: LogIO2[F],
) extends WsClientSession[F, SessionCtx] {
private val heartbeatTimestamp = new AtomicReference[Option[ZonedDateTime]](None)
private val requestCtxRef = new AtomicReference[SessionCtx](initialContext)
private val openingTime: ZonedDateTime = Clock1.Standard.nowZoned()

Expand Down Expand Up @@ -111,6 +117,10 @@ object WsClientSession {
onStart(requestCtx)
}

override def heartbeat(at: ZonedDateTime): F[Nothing, Unit] = F.sync(heartbeatTimestamp.set(Some(at)))

override def lastHeartbeat(): F[Nothing, Option[ZonedDateTime]] = F.sync(heartbeatTimestamp.get())

override def toString: String = s"[$sessionId, ${duration().toSeconds}s]"

private def duration(): FiniteDuration = {
Expand Down

0 comments on commit b005a17

Please sign in to comment.