diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/HttpServer.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/HttpServer.scala index df84fbd1..e9471df7 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/HttpServer.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/HttpServer.scala @@ -8,10 +8,9 @@ import io.circe import io.circe.syntax.EncoderOps import io.circe.{Json, Printer} import izumi.functional.bio.Exit.{Error, Interruption, Success, Termination} -import izumi.functional.bio.{Exit, F, IO2, Primitives2, Temporal2, UnsafeRun2} +import izumi.functional.bio.{Clock1, Exit, F, IO2, Primitives2, Temporal2, UnsafeRun2} import izumi.fundamentals.platform.language.Quirks import izumi.fundamentals.platform.language.Quirks.Discarder -import izumi.functional.bio.Clock1 import izumi.idealingua.runtime.rpc.* import izumi.idealingua.runtime.rpc.http4s.HttpServer.{ServerWsRpcHandler, WsResponseMarker} import izumi.idealingua.runtime.rpc.http4s.context.{HttpContextExtractor, WsContextExtractor} @@ -26,7 +25,7 @@ import org.typelevel.vault.Key import java.time.ZonedDateTime import java.util.concurrent.RejectedExecutionException -import scala.concurrent.duration.DurationInt +import scala.concurrent.duration.{DurationInt, FiniteDuration} class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx]( val contextServices: Set[IRTContextServices.AnyContext[F, AuthCtx]], @@ -39,12 +38,13 @@ class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx]( )(implicit val AT: Async[F[Throwable, _]] ) { import dsl.* + // WS Response attribute key, to differ from usual HTTP responses + private val wsAttributeKey = UnsafeRun2[F].unsafeRun(Key.newKey[F[Throwable, _], WsResponseMarker.type]) protected val serverMuxer: IRTServerMultiplexor[F, AuthCtx] = IRTServerMultiplexor.combine(contextServices.map(_.authorizedMuxer)) protected val wsContextsSessions: Set[WsContextSessions.AnyContext[F, AuthCtx]] = contextServices.map(_.authorizedWsSessions) - - // WS Response attribute key, to differ from usual HTTP responses - private val wsAttributeKey = UnsafeRun2[F].unsafeRun(Key.newKey[F[Throwable, _], WsResponseMarker.type]) + protected val wsHeartbeatTimeout: FiniteDuration = 1.minute + protected val wsHeartbeatInterval: FiniteDuration = 10.seconds def service(ws: WebSocketBuilder2[F[Throwable, _]]): HttpRoutes[F[Throwable, _]] = { val svc = HttpRoutes.of(router(ws)) @@ -62,10 +62,25 @@ class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx]( ws: WebSocketBuilder2[F[Throwable, _]], ): F[Throwable, Response[F[Throwable, _]]] = { Quirks.discard(request) - def pingStream: Stream[F[Throwable, _], WebSocketFrame.Ping] = { + def pingStream(clientSession: WsClientSession[F, AuthCtx]): Stream[F[Throwable, _], WebSocketFrame] = { Stream - .awakeEvery[F[Throwable, _]](5.second) - .evalMap(_ => logger.debug("WS Server: Sending ping frame.").as(WebSocketFrame.Ping())) + .awakeEvery[F[Throwable, _]](wsHeartbeatInterval) + .evalMap[F[Throwable, _], WebSocketFrame] { + _ => + for { + expiration <- F.sync(Clock1.Standard.nowZoned().minusNanos(wsHeartbeatTimeout.toNanos)) + frame <- clientSession.lastHeartbeat().flatMap { + case Some(lastHeartbeat) if lastHeartbeat.isBefore(expiration) => + logger.warn(s"WS Session: Websocket client heartbeat timeout: ${clientSession.sessionId}, $wsHeartbeatTimeout.") *> + F.fromEither(WebSocketFrame.Close(1006, s"Ping-Pong heartbeat timed-out after '$wsHeartbeatTimeout'.")) + case _ => + logger.debug("WS Server: Sending ping frame.").as(WebSocketFrame.Ping()) + } + } yield frame + }.takeThrough { + case _: WebSocketFrame.Close => false + case _ => true + } } for { outQueue <- Queue.unbounded[F[Throwable, _], WebSocketFrame] @@ -73,7 +88,7 @@ class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx]( clientSession = new WsClientSession.Queued(outQueue, authContext, wsContextsSessions, wsSessionsStorage, wsContextExtractor, logger, printer) _ <- clientSession.start(onWsConnected) - outStream = Stream.fromQueueUnterminated(outQueue).merge(pingStream) + outStream = Stream.fromQueueUnterminated(outQueue).merge(pingStream(clientSession)) inStream = { (inputStream: Stream[F[Throwable, _], WebSocketFrame]) => inputStream.evalMap { @@ -86,6 +101,7 @@ class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx]( wsSessionIdHeader = Header.Raw(HttpServer.`X-Ws-Session-Id`, clientSession.sessionId.sessionId.toString) response <- ws + .withFilterPingPongs(false) .withOnClose(handleWsClose(clientSession)) .withHeaders(Headers(wsSessionIdHeader)) .build(outStream, inStream) @@ -100,14 +116,16 @@ class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx]( )(frame: WebSocketFrame ): F[Throwable, Option[String]] = { (frame match { - case WebSocketFrame.Text(msg, _) => wsHandler(clientSession).processRpcMessage(msg) - case WebSocketFrame.Close(_) => F.pure(None) - case _: WebSocketFrame.Pong => onWsHeartbeat(requestTime).as(None) + case WebSocketFrame.Text(msg, _) => + wsHandler(clientSession).processRpcMessage(msg) + case WebSocketFrame.Close(_) => + F.pure(None) + case _: WebSocketFrame.Pong => + clientSession.heartbeat(requestTime) *> + onWsHeartbeat(requestTime).as(None) case unknownMessage => val message = s"Unsupported WS frame: $unknownMessage." - logger - .error(s"WS request failed: $message.") - .as(Some(RpcPacket.rpcCritical(message, None))) + logger.error(s"WS request failed: $message.").as(Some(RpcPacket.rpcCritical(message, None))) }).map(_.map(p => printer.print(p.asJson))) } diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsClientSession.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsClientSession.scala index dc066223..14e5bef5 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsClientSession.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsClientSession.scala @@ -3,8 +3,7 @@ package izumi.idealingua.runtime.rpc.http4s.ws import cats.effect.std.Queue import io.circe.syntax.EncoderOps import io.circe.{Json, Printer} -import izumi.functional.bio.{Applicative2, F, IO2, Primitives2, Temporal2} -import izumi.functional.bio.Clock1 +import izumi.functional.bio.{Applicative2, Clock1, F, IO2, Primitives2, Temporal2} import izumi.fundamentals.platform.uuid.UUIDGen import izumi.idealingua.runtime.rpc.* import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcherFactory.ClientWsRpcHandler @@ -28,18 +27,24 @@ trait WsClientSession[F[+_, +_], SessionCtx] extends WsResponder[F] { def start(onStart: SessionCtx => F[Throwable, Unit]): F[Throwable, Unit] def finish(onFinish: SessionCtx => F[Throwable, Unit]): F[Throwable, Unit] + + def heartbeat(at: ZonedDateTime): F[Nothing, Unit] + def lastHeartbeat(): F[Nothing, Option[ZonedDateTime]] } object WsClientSession { def empty[F[+_, +_]: Applicative2, Ctx](wsSessionId: WsSessionId): WsClientSession[F, Ctx] = new WsClientSession[F, Ctx] { - override def sessionId: WsSessionId = wsSessionId + private val heartbeatTimestamp = new AtomicReference[Option[ZonedDateTime]](None) + override def sessionId: WsSessionId = wsSessionId override def requestAndAwaitResponse(method: IRTMethodId, data: Json, timeout: FiniteDuration): F[Throwable, Option[RawResponse]] = F.pure(None) override def updateRequestCtx(newContext: Ctx): F[Throwable, Ctx] = F.pure(newContext) override def start(onStart: Ctx => F[Throwable, Unit]): F[Throwable, Unit] = F.unit override def finish(onFinish: Ctx => F[Throwable, Unit]): F[Throwable, Unit] = F.unit override def responseWith(id: RpcPacketId, response: RawResponse): F[Throwable, Unit] = F.unit override def responseWithData(id: RpcPacketId, data: Json): F[Throwable, Unit] = F.unit + override def heartbeat(at: ZonedDateTime): F[Nothing, Unit] = F.pure(heartbeatTimestamp.set(Some(at))) + override def lastHeartbeat(): F[Nothing, Option[ZonedDateTime]] = F.pure(heartbeatTimestamp.get()) } abstract class Base[F[+_, +_]: IO2: Temporal2: Primitives2, SessionCtx]( @@ -49,6 +54,7 @@ object WsClientSession { wsContextExtractor: WsContextExtractor[SessionCtx], logger: LogIO2[F], ) extends WsClientSession[F, SessionCtx] { + private val heartbeatTimestamp = new AtomicReference[Option[ZonedDateTime]](None) private val requestCtxRef = new AtomicReference[SessionCtx](initialContext) private val openingTime: ZonedDateTime = Clock1.Standard.nowZoned() @@ -111,6 +117,10 @@ object WsClientSession { onStart(requestCtx) } + override def heartbeat(at: ZonedDateTime): F[Nothing, Unit] = F.sync(heartbeatTimestamp.set(Some(at))) + + override def lastHeartbeat(): F[Nothing, Option[ZonedDateTime]] = F.sync(heartbeatTimestamp.get()) + override def toString: String = s"[$sessionId, ${duration().toSeconds}s]" private def duration(): FiniteDuration = {