Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WS heartbeat timeouts. #507

Merged
merged 1 commit into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ import io.circe
import io.circe.syntax.EncoderOps
import io.circe.{Json, Printer}
import izumi.functional.bio.Exit.{Error, Interruption, Success, Termination}
import izumi.functional.bio.{Exit, F, IO2, Primitives2, Temporal2, UnsafeRun2}
import izumi.functional.bio.{Clock1, Exit, F, IO2, Primitives2, Temporal2, UnsafeRun2}
import izumi.fundamentals.platform.language.Quirks
import izumi.fundamentals.platform.language.Quirks.Discarder
import izumi.functional.bio.Clock1
import izumi.idealingua.runtime.rpc.*
import izumi.idealingua.runtime.rpc.http4s.HttpServer.{ServerWsRpcHandler, WsResponseMarker}
import izumi.idealingua.runtime.rpc.http4s.context.{HttpContextExtractor, WsContextExtractor}
Expand All @@ -26,7 +25,7 @@ import org.typelevel.vault.Key

import java.time.ZonedDateTime
import java.util.concurrent.RejectedExecutionException
import scala.concurrent.duration.DurationInt
import scala.concurrent.duration.{DurationInt, FiniteDuration}

class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx](
val contextServices: Set[IRTContextServices.AnyContext[F, AuthCtx]],
Expand All @@ -39,12 +38,13 @@ class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx](
)(implicit val AT: Async[F[Throwable, _]]
) {
import dsl.*
// WS Response attribute key, to differ from usual HTTP responses
private val wsAttributeKey = UnsafeRun2[F].unsafeRun(Key.newKey[F[Throwable, _], WsResponseMarker.type])

protected val serverMuxer: IRTServerMultiplexor[F, AuthCtx] = IRTServerMultiplexor.combine(contextServices.map(_.authorizedMuxer))
protected val wsContextsSessions: Set[WsContextSessions.AnyContext[F, AuthCtx]] = contextServices.map(_.authorizedWsSessions)

// WS Response attribute key, to differ from usual HTTP responses
private val wsAttributeKey = UnsafeRun2[F].unsafeRun(Key.newKey[F[Throwable, _], WsResponseMarker.type])
protected val wsHeartbeatTimeout: FiniteDuration = 1.minute
protected val wsHeartbeatInterval: FiniteDuration = 10.seconds

def service(ws: WebSocketBuilder2[F[Throwable, _]]): HttpRoutes[F[Throwable, _]] = {
val svc = HttpRoutes.of(router(ws))
Expand All @@ -62,18 +62,33 @@ class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx](
ws: WebSocketBuilder2[F[Throwable, _]],
): F[Throwable, Response[F[Throwable, _]]] = {
Quirks.discard(request)
def pingStream: Stream[F[Throwable, _], WebSocketFrame.Ping] = {
def pingStream(clientSession: WsClientSession[F, AuthCtx]): Stream[F[Throwable, _], WebSocketFrame] = {
Stream
.awakeEvery[F[Throwable, _]](5.second)
.evalMap(_ => logger.debug("WS Server: Sending ping frame.").as(WebSocketFrame.Ping()))
.awakeEvery[F[Throwable, _]](wsHeartbeatInterval)
.evalMap[F[Throwable, _], WebSocketFrame] {
_ =>
for {
expiration <- F.sync(Clock1.Standard.nowZoned().minusNanos(wsHeartbeatTimeout.toNanos))
frame <- clientSession.lastHeartbeat().flatMap {
case Some(lastHeartbeat) if lastHeartbeat.isBefore(expiration) =>
logger.warn(s"WS Session: Websocket client heartbeat timeout: ${clientSession.sessionId}, $wsHeartbeatTimeout.") *>
F.fromEither(WebSocketFrame.Close(1006, s"Ping-Pong heartbeat timed-out after '$wsHeartbeatTimeout'."))
case _ =>
logger.debug("WS Server: Sending ping frame.").as(WebSocketFrame.Ping())
}
} yield frame
}.takeThrough {
case _: WebSocketFrame.Close => false
case _ => true
}
}
for {
outQueue <- Queue.unbounded[F[Throwable, _], WebSocketFrame]
authContext <- F.syncThrowable(httpContextExtractor.extract(request))
clientSession = new WsClientSession.Queued(outQueue, authContext, wsContextsSessions, wsSessionsStorage, wsContextExtractor, logger, printer)
_ <- clientSession.start(onWsConnected)

outStream = Stream.fromQueueUnterminated(outQueue).merge(pingStream)
outStream = Stream.fromQueueUnterminated(outQueue).merge(pingStream(clientSession))
inStream = {
(inputStream: Stream[F[Throwable, _], WebSocketFrame]) =>
inputStream.evalMap {
Expand All @@ -86,6 +101,7 @@ class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx](
wsSessionIdHeader = Header.Raw(HttpServer.`X-Ws-Session-Id`, clientSession.sessionId.sessionId.toString)

response <- ws
.withFilterPingPongs(false)
.withOnClose(handleWsClose(clientSession))
.withHeaders(Headers(wsSessionIdHeader))
.build(outStream, inStream)
Expand All @@ -100,14 +116,16 @@ class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx](
)(frame: WebSocketFrame
): F[Throwable, Option[String]] = {
(frame match {
case WebSocketFrame.Text(msg, _) => wsHandler(clientSession).processRpcMessage(msg)
case WebSocketFrame.Close(_) => F.pure(None)
case _: WebSocketFrame.Pong => onWsHeartbeat(requestTime).as(None)
case WebSocketFrame.Text(msg, _) =>
wsHandler(clientSession).processRpcMessage(msg)
case WebSocketFrame.Close(_) =>
F.pure(None)
case _: WebSocketFrame.Pong =>
clientSession.heartbeat(requestTime) *>
onWsHeartbeat(requestTime).as(None)
case unknownMessage =>
val message = s"Unsupported WS frame: $unknownMessage."
logger
.error(s"WS request failed: $message.")
.as(Some(RpcPacket.rpcCritical(message, None)))
logger.error(s"WS request failed: $message.").as(Some(RpcPacket.rpcCritical(message, None)))
}).map(_.map(p => printer.print(p.asJson)))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ package izumi.idealingua.runtime.rpc.http4s.ws
import cats.effect.std.Queue
import io.circe.syntax.EncoderOps
import io.circe.{Json, Printer}
import izumi.functional.bio.{Applicative2, F, IO2, Primitives2, Temporal2}
import izumi.functional.bio.Clock1
import izumi.functional.bio.{Applicative2, Clock1, F, IO2, Primitives2, Temporal2}
import izumi.fundamentals.platform.uuid.UUIDGen
import izumi.idealingua.runtime.rpc.*
import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcherFactory.ClientWsRpcHandler
Expand All @@ -28,18 +27,24 @@ trait WsClientSession[F[+_, +_], SessionCtx] extends WsResponder[F] {

def start(onStart: SessionCtx => F[Throwable, Unit]): F[Throwable, Unit]
def finish(onFinish: SessionCtx => F[Throwable, Unit]): F[Throwable, Unit]

def heartbeat(at: ZonedDateTime): F[Nothing, Unit]
def lastHeartbeat(): F[Nothing, Option[ZonedDateTime]]
}

object WsClientSession {

def empty[F[+_, +_]: Applicative2, Ctx](wsSessionId: WsSessionId): WsClientSession[F, Ctx] = new WsClientSession[F, Ctx] {
override def sessionId: WsSessionId = wsSessionId
private val heartbeatTimestamp = new AtomicReference[Option[ZonedDateTime]](None)
override def sessionId: WsSessionId = wsSessionId
override def requestAndAwaitResponse(method: IRTMethodId, data: Json, timeout: FiniteDuration): F[Throwable, Option[RawResponse]] = F.pure(None)
override def updateRequestCtx(newContext: Ctx): F[Throwable, Ctx] = F.pure(newContext)
override def start(onStart: Ctx => F[Throwable, Unit]): F[Throwable, Unit] = F.unit
override def finish(onFinish: Ctx => F[Throwable, Unit]): F[Throwable, Unit] = F.unit
override def responseWith(id: RpcPacketId, response: RawResponse): F[Throwable, Unit] = F.unit
override def responseWithData(id: RpcPacketId, data: Json): F[Throwable, Unit] = F.unit
override def heartbeat(at: ZonedDateTime): F[Nothing, Unit] = F.pure(heartbeatTimestamp.set(Some(at)))
override def lastHeartbeat(): F[Nothing, Option[ZonedDateTime]] = F.pure(heartbeatTimestamp.get())
}

abstract class Base[F[+_, +_]: IO2: Temporal2: Primitives2, SessionCtx](
Expand All @@ -49,6 +54,7 @@ object WsClientSession {
wsContextExtractor: WsContextExtractor[SessionCtx],
logger: LogIO2[F],
) extends WsClientSession[F, SessionCtx] {
private val heartbeatTimestamp = new AtomicReference[Option[ZonedDateTime]](None)
private val requestCtxRef = new AtomicReference[SessionCtx](initialContext)
private val openingTime: ZonedDateTime = Clock1.Standard.nowZoned()

Expand Down Expand Up @@ -111,6 +117,10 @@ object WsClientSession {
onStart(requestCtx)
}

override def heartbeat(at: ZonedDateTime): F[Nothing, Unit] = F.sync(heartbeatTimestamp.set(Some(at)))

override def lastHeartbeat(): F[Nothing, Option[ZonedDateTime]] = F.sync(heartbeatTimestamp.get())

override def toString: String = s"[$sessionId, ${duration().toSeconds}s]"

private def duration(): FiniteDuration = {
Expand Down
Loading