diff --git a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/PubSub.scala b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/PubSub.scala index dd9b3305f..6eb733317 100644 --- a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/PubSub.scala +++ b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/PubSub.scala @@ -23,9 +23,8 @@ import cats.syntax.all._ import dev.profunktor.redis4cats.connection.RedisClient import dev.profunktor.redis4cats.data._ import dev.profunktor.redis4cats.effect._ -import dev.profunktor.redis4cats.pubsub.internals.{ LivePubSubCommands, Publisher, Subscriber } +import dev.profunktor.redis4cats.pubsub.internals.{ LivePubSubCommands, PubSubState, Publisher, Subscriber } import fs2.Stream -import dev.profunktor.redis4cats.pubsub.internals.PubSubState import io.lettuce.core.pubsub.StatefulRedisPubSubConnection object PubSub { @@ -58,7 +57,7 @@ object PubSub { val (acquire, release) = acquireAndRelease[F, K, V](client, codec) // One exclusive connection for subscriptions and another connection for publishing / stats for { - state <- Resource.eval(Ref.of[F, PubSubState[F, K, V]](PubSubState(Map.empty, Map.empty))) + state <- Resource.eval(PubSubState.make[F, K, V]) sConn <- Resource.make(acquire)(release) pConn <- Resource.make(acquire)(release) } yield new LivePubSubCommands[F, K, V](state, sConn, pConn) @@ -88,7 +87,7 @@ object PubSub { ): Resource[F, SubscribeCommands[F, Stream[F, *], K, V]] = { val (acquire, release) = acquireAndRelease[F, K, V](client, codec) for { - state <- Resource.eval(Ref.of[F, PubSubState[F, K, V]](PubSubState(Map.empty, Map.empty))) + state <- Resource.eval(PubSubState.make[F, K, V]) conn <- Resource.make(acquire)(release) } yield new Subscriber(state, conn) } diff --git a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/PubSubCommands.scala b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/PubSubCommands.scala index 408ddcfe0..12058d870 100644 --- a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/PubSubCommands.scala +++ b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/PubSubCommands.scala @@ -37,8 +37,12 @@ trait PubSubStats[F[_], K] { * @tparam V the value type */ trait PublishCommands[F[_], S[_], K, V] extends PubSubStats[F, K] { - def publish(channel: RedisChannel[K]): S[V] => S[Unit] - def publish(channel: RedisChannel[K], value: V): F[Unit] + + /** @return The number of clients that received the message. */ + def publish(channel: RedisChannel[K]): S[V] => S[Long] + + /** @return The number of clients that received the message. */ + def publish(channel: RedisChannel[K], value: V): F[Long] } /** @@ -51,17 +55,40 @@ trait SubscribeCommands[F[_], S[_], K, V] { /** * Subscribes to a channel. + * + * @note If you invoke `subscribe` multiple times for the same channel, we will not call 'SUBSCRIBE' in Redis multiple + * times but instead will return a stream that will use the existing subscription to that channel. The underlying + * subscription is cleaned up when all the streams terminate or when `unsubscribe` is invoked. */ def subscribe(channel: RedisChannel[K]): S[V] + /** Terminates all streams that are subscribed to the channel. */ def unsubscribe(channel: RedisChannel[K]): F[Unit] /** * Subscribes to a pattern. + * + * @note If you invoke `subscribe` multiple times for the same pattern, we will not call 'SUBSCRIBE' in Redis multiple + * times but instead will return a stream that will use the existing subscription to that pattern. The underlying + * subscription is cleaned up when all the streams terminate or when `unsubscribe` is invoked. */ def psubscribe(channel: RedisPattern[K]): S[RedisPatternEvent[K, V]] + /** Terminates all streams that are subscribed to the pattern. */ def punsubscribe(channel: RedisPattern[K]): F[Unit] + + /** Returns the channel subscriptions that the library keeps of. + * + * @return how many streams are subscribed to each channel. + * @see [[SubscribeCommands.subscribe]] for more information. + * */ + def internalChannelSubscriptions: F[Map[RedisChannel[K], Long]] + + /** Returns the pattern subscriptions that the library keeps of. + * + * @return how many streams are subscribed to each pattern. + * @see [[SubscribeCommands.psubscribe]] for more information. */ + def internalPatternSubscriptions: F[Map[RedisPattern[K], Long]] } /** diff --git a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/LivePubSubCommands.scala b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/LivePubSubCommands.scala index 2529fe63a..9376c302f 100644 --- a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/LivePubSubCommands.scala +++ b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/LivePubSubCommands.scala @@ -29,7 +29,7 @@ import fs2.Stream import io.lettuce.core.pubsub.StatefulRedisPubSubConnection private[pubsub] class LivePubSubCommands[F[_]: Async: Log, K, V]( - state: Ref[F, PubSubState[F, K, V]], + state: PubSubState[F, K, V], subConnection: StatefulRedisPubSubConnection[K, V], pubConnection: StatefulRedisPubSubConnection[K, V] ) extends PubSubCommands[F, Stream[F, *], K, V] { @@ -50,13 +50,17 @@ private[pubsub] class LivePubSubCommands[F[_]: Async: Log, K, V]( override def punsubscribe(pattern: RedisPattern[K]): F[Unit] = subCommands.punsubscribe(pattern) - override def publish(channel: RedisChannel[K]): Stream[F, V] => Stream[F, Unit] = + override def internalChannelSubscriptions: F[Map[RedisChannel[K], Long]] = + subCommands.internalChannelSubscriptions + + override def internalPatternSubscriptions: F[Map[RedisPattern[K], Long]] = + subCommands.internalPatternSubscriptions + + override def publish(channel: RedisChannel[K]): Stream[F, V] => Stream[F, Long] = _.evalMap(publish(channel, _)) - override def publish(channel: RedisChannel[K], message: V): F[Unit] = { - val resource = Resource.eval(state.get) >>= PubSubInternals.channel[F, K, V](state, subConnection).apply(channel) - resource.use(_ => FutureLift[F].lift(pubConnection.async().publish(channel.underlying, message)).void) - } + override def publish(channel: RedisChannel[K], message: V): F[Long] = + FutureLift[F].lift(pubConnection.async().publish(channel.underlying, message)).map(l => l: Long) override def numPat: F[Long] = pubSubStats.numPat @@ -78,5 +82,4 @@ private[pubsub] class LivePubSubCommands[F[_]: Async: Log, K, V]( override def shardNumSub(channels: List[RedisChannel[K]]): F[List[Subscription[K]]] = pubSubStats.shardNumSub(channels) - } diff --git a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/LivePubSubStats.scala b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/LivePubSubStats.scala index 9784e0622..0c5ff0ea8 100644 --- a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/LivePubSubStats.scala +++ b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/LivePubSubStats.scala @@ -63,7 +63,6 @@ private[pubsub] class LivePubSubStats[F[_]: FlatMap: FutureLift, K, V]( FutureLift[F] .lift(pubConnection.async().pubsubShardNumsub(channels.map(_.underlying): _*)) .map(toSubscription[K]) - } object LivePubSubStats { private def toSubscription[K](map: ju.Map[K, JLong]): List[Subscription[K]] = diff --git a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/PubSubInternals.scala b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/PubSubInternals.scala index 6ee382b45..6d8e4837f 100644 --- a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/PubSubInternals.scala +++ b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/PubSubInternals.scala @@ -17,31 +17,26 @@ package dev.profunktor.redis4cats.pubsub.internals import scala.util.control.NoStackTrace - -import cats.effect.kernel.{ Async, Ref, Resource, Sync } -import cats.effect.std.Dispatcher -import cats.syntax.all._ +import cats.effect.std.{ Dispatcher } import dev.profunktor.redis4cats.data.RedisChannel import dev.profunktor.redis4cats.data.RedisPattern import dev.profunktor.redis4cats.data.RedisPatternEvent -import dev.profunktor.redis4cats.effect.Log -import fs2.concurrent.Topic -import io.lettuce.core.pubsub.{ RedisPubSubListener, StatefulRedisPubSubConnection } +import io.lettuce.core.pubsub.{ RedisPubSubListener } import io.lettuce.core.pubsub.RedisPubSubAdapter object PubSubInternals { case class DispatcherAlreadyShutdown() extends NoStackTrace - private[redis4cats] def channelListener[F[_]: Async, K, V]( + private[redis4cats] def channelListener[F[_], K, V]( channel: RedisChannel[K], - topic: Topic[F, Option[V]], + publish: V => F[Unit], dispatcher: Dispatcher[F] ): RedisPubSubListener[K, V] = new RedisPubSubAdapter[K, V] { override def message(ch: K, msg: V): Unit = if (ch == channel.underlying) { try { - dispatcher.unsafeRunSync(topic.publish1(Option(msg)).void) + dispatcher.unsafeRunSync(publish(msg)) } catch { case _: IllegalStateException => throw DispatcherAlreadyShutdown() } @@ -50,65 +45,19 @@ object PubSubInternals { // Do not uncomment this, as if you will do this the channel listener will get a message twice // override def message(pattern: K, channel: K, message: V): Unit = {} } - private[redis4cats] def patternListener[F[_]: Async, K, V]( + private[redis4cats] def patternListener[F[_], K, V]( redisPattern: RedisPattern[K], - topic: Topic[F, Option[RedisPatternEvent[K, V]]], + publish: RedisPatternEvent[K, V] => F[Unit], dispatcher: Dispatcher[F] ): RedisPubSubListener[K, V] = new RedisPubSubAdapter[K, V] { override def message(pattern: K, channel: K, message: V): Unit = if (pattern == redisPattern.underlying) { try { - dispatcher.unsafeRunSync(topic.publish1(Option(RedisPatternEvent(pattern, channel, message))).void) + dispatcher.unsafeRunSync(publish(RedisPatternEvent(pattern, channel, message))) } catch { case _: IllegalStateException => throw DispatcherAlreadyShutdown() } } } - - private[redis4cats] def channel[F[_]: Async: Log, K, V]( - state: Ref[F, PubSubState[F, K, V]], - subConnection: StatefulRedisPubSubConnection[K, V] - ): GetOrCreateTopicListener[F, K, V] = { channel => st => - st.channels - .get(channel.underlying) - .fold { - for { - dispatcher <- Dispatcher.parallel[F] - topic <- Resource.eval(Topic[F, Option[V]]) - _ <- Resource.eval(Log[F].info(s"Creating listener for channel: $channel")) - listener = channelListener(channel, topic, dispatcher) - _ <- Resource.make { - Sync[F].delay(subConnection.addListener(listener)) *> - state.update(s => s.copy(channels = s.channels.updated(channel.underlying, topic))) - } { _ => - Sync[F].delay(subConnection.removeListener(listener)) *> - state.update(s => s.copy(channels = s.channels - channel.underlying)) - } - } yield topic - }(Resource.pure) - } - - private[redis4cats] def pattern[F[_]: Async: Log, K, V]( - state: Ref[F, PubSubState[F, K, V]], - subConnection: StatefulRedisPubSubConnection[K, V] - ): GetOrCreatePatternListener[F, K, V] = { channel => st => - st.patterns - .get(channel.underlying) - .fold { - for { - dispatcher <- Dispatcher.parallel[F] - topic <- Resource.eval(Topic[F, Option[RedisPatternEvent[K, V]]]) - _ <- Resource.eval(Log[F].info(s"Creating listener for pattern: $channel")) - listener = patternListener(channel, topic, dispatcher) - _ <- Resource.make { - Sync[F].delay(subConnection.addListener(listener)) *> - state.update(s => s.copy(patterns = s.patterns.updated(channel.underlying, topic))) - } { _ => - Sync[F].delay(subConnection.removeListener(listener)) *> - state.update(s => s.copy(patterns = s.patterns - channel.underlying)) - } - } yield topic - }(Resource.pure) - } } diff --git a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/PubSubState.scala b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/PubSubState.scala index c2a3d02b6..06f884ca9 100644 --- a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/PubSubState.scala +++ b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/PubSubState.scala @@ -16,10 +16,21 @@ package dev.profunktor.redis4cats.pubsub.internals -import dev.profunktor.redis4cats.data.RedisPatternEvent -import fs2.concurrent.Topic +import cats.syntax.all._ +import cats.effect.kernel.Concurrent +import cats.effect.std.AtomicCell +import dev.profunktor.redis4cats.data.{ RedisChannel, RedisPattern, RedisPatternEvent } -final case class PubSubState[F[_], K, V]( - channels: Map[K, Topic[F, Option[V]]], - patterns: Map[K, Topic[F, Option[RedisPatternEvent[K, V]]]] +/** We use `AtomicCell` instead of `Ref` because we need locking while side-effecting. */ +case class PubSubState[F[_], K, V]( + channelSubs: AtomicCell[F, Map[RedisChannel[K], Redis4CatsSubscription[F, V]]], + patternSubs: AtomicCell[F, Map[RedisPattern[K], Redis4CatsSubscription[F, RedisPatternEvent[K, V]]]] ) +object PubSubState { + def make[F[_]: Concurrent, K, V]: F[PubSubState[F, K, V]] = + for { + channelSubs <- AtomicCell[F].of(Map.empty[RedisChannel[K], Redis4CatsSubscription[F, V]]) + patternSubs <- AtomicCell[F].of(Map.empty[RedisPattern[K], Redis4CatsSubscription[F, RedisPatternEvent[K, V]]]) + } yield apply(channelSubs, patternSubs) + +} diff --git a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/Publisher.scala b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/Publisher.scala index e4fe367d0..f5e45901c 100644 --- a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/Publisher.scala +++ b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/Publisher.scala @@ -32,11 +32,11 @@ private[pubsub] class Publisher[F[_]: FlatMap: FutureLift, K, V]( private[redis4cats] val pubSubStats: PubSubStats[F, K] = new LivePubSubStats(pubConnection) - override def publish(channel: RedisChannel[K]): Stream[F, V] => Stream[F, Unit] = + override def publish(channel: RedisChannel[K]): Stream[F, V] => Stream[F, Long] = _.evalMap(publish(channel, _)) - override def publish(channel: RedisChannel[K], message: V): F[Unit] = - FutureLift[F].lift(pubConnection.async().publish(channel.underlying, message)).void + override def publish(channel: RedisChannel[K], message: V): F[Long] = + FutureLift[F].lift(pubConnection.async().publish(channel.underlying, message)).map(l => l: Long) override def pubSubChannels: F[List[RedisChannel[K]]] = pubSubStats.pubSubChannels diff --git a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/Redis4CatsSubscription.scala b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/Redis4CatsSubscription.scala new file mode 100644 index 000000000..751769f40 --- /dev/null +++ b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/Redis4CatsSubscription.scala @@ -0,0 +1,43 @@ +/* + * Copyright 2018-2021 ProfunKtor + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dev.profunktor.redis4cats.pubsub.internals + +import cats.Applicative +import fs2.concurrent.Topic + +/** + * Stores an ongoing subscription. + * + * @param topic single-publisher, multiple-subscribers. The same topic is reused if `subscribe` is invoked more than + * once. The subscribers' streams are terminated when `None` is published. + * @param subscribers subscriber count, when `subscribers` reaches 0 `cleanup` is called and `None` is published + * to the topic. + */ +final private[redis4cats] case class Redis4CatsSubscription[F[_], V]( + topic: Topic[F, Option[V]], + subscribers: Long, + cleanup: F[Unit] +) { + assert(subscribers > 0, s"subscribers must be > 0, was $subscribers") + + def addSubscriber: Redis4CatsSubscription[F, V] = copy(subscribers = subscribers + 1) + def removeSubscriber: Redis4CatsSubscription[F, V] = copy(subscribers = subscribers - 1) + def isLastSubscriber: Boolean = subscribers == 1 + + def stream(onTermination: F[Unit])(implicit F: Applicative[F]): fs2.Stream[F, V] = + topic.subscribe(500).unNoneTerminate.onFinalize(onTermination) +} diff --git a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/Subscriber.scala b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/Subscriber.scala index 97428e4d5..f36e9c102 100644 --- a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/Subscriber.scala +++ b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/Subscriber.scala @@ -21,54 +21,149 @@ package internals import cats.Applicative import cats.effect.kernel._ import cats.effect.kernel.implicits._ +import cats.effect.std.{ AtomicCell, Dispatcher } import cats.syntax.all._ import dev.profunktor.redis4cats.data.{ RedisChannel, RedisPattern, RedisPatternEvent } import dev.profunktor.redis4cats.effect.{ FutureLift, Log } import fs2.Stream -import io.lettuce.core.pubsub.StatefulRedisPubSubConnection +import fs2.concurrent.Topic +import io.lettuce.core.pubsub.{ RedisPubSubListener, StatefulRedisPubSubConnection } private[pubsub] class Subscriber[F[_]: Async: FutureLift: Log, K, V]( - state: Ref[F, PubSubState[F, K, V]], + state: PubSubState[F, K, V], subConnection: StatefulRedisPubSubConnection[K, V] ) extends SubscribeCommands[F, Stream[F, *], K, V] { override def subscribe(channel: RedisChannel[K]): Stream[F, V] = - Stream - .resource(Resource.eval(state.get) >>= PubSubInternals.channel[F, K, V](state, subConnection).apply(channel)) - .evalTap(_ => - FutureLift[F] - .lift(subConnection.async().subscribe(channel.underlying)) - ) - .flatMap(_.subscribe(500).unNone) + Subscriber.subscribe( + channel, + state.channelSubs, + subConnection, + subscribeToRedis = FutureLift[F].lift(subConnection.async().subscribe(channel.underlying)).void, + unsubscribeFromRedis = FutureLift[F].lift(subConnection.async().unsubscribe(channel.underlying)).void + )((dispatcher, topic) => + PubSubInternals.channelListener(channel, (v: V) => topic.publish1(Some(v)).void, dispatcher) + ) override def unsubscribe(channel: RedisChannel[K]): F[Unit] = - FutureLift[F] - .lift(subConnection.async().unsubscribe(channel.underlying)) - .void - .guarantee(state.get.flatMap { st => - st.channels.get(channel.underlying).fold(Applicative[F].unit)(_.publish1(none[V]).void) *> state - .update(s => s.copy(channels = s.channels - channel.underlying)) - }) + Subscriber.unsubscribeFrom(channel, state.channelSubs) override def psubscribe( pattern: RedisPattern[K] ): Stream[F, RedisPatternEvent[K, V]] = - Stream - .resource(Resource.eval(state.get) >>= PubSubInternals.pattern[F, K, V](state, subConnection).apply(pattern)) - .evalTap(_ => - FutureLift[F] - .lift(subConnection.async().psubscribe(pattern.underlying)) - ) - .flatMap(_.subscribe(500).unNone) + Subscriber.subscribe( + pattern, + state.patternSubs, + subConnection, + subscribeToRedis = FutureLift[F].lift(subConnection.async().psubscribe(pattern.underlying)).void, + unsubscribeFromRedis = FutureLift[F].lift(subConnection.async().punsubscribe(pattern.underlying)).void + )((dispatcher, topic) => + PubSubInternals + .patternListener(pattern, (evt: RedisPatternEvent[K, V]) => topic.publish1(Some(evt)).void, dispatcher) + ) override def punsubscribe(pattern: RedisPattern[K]): F[Unit] = - FutureLift[F] - .lift(subConnection.async().punsubscribe(pattern.underlying)) - .void - .guarantee(state.get.flatMap { st => - st.patterns - .get(pattern.underlying) - .fold(Applicative[F].unit)(_.publish1(none[RedisPatternEvent[K, V]]).void) *> state - .update(s => s.copy(patterns = s.patterns - pattern.underlying)) + Subscriber.unsubscribeFrom(pattern, state.patternSubs) + + override def internalChannelSubscriptions: F[Map[RedisChannel[K], Long]] = + state.channelSubs.get.map(_.iterator.map { case (k, v) => k -> v.subscribers }.toMap) + + override def internalPatternSubscriptions: F[Map[RedisPattern[K], Long]] = + state.patternSubs.get.map(_.iterator.map { case (k, v) => k -> v.subscribers }.toMap) +} +object Subscriber { + + /** + * Check if we have a subscriber for this channel and remove it if we do. + * + * If it is the last subscriber, perform the subscription cleanup. + */ + private def onStreamTermination[F[_]: Applicative: Log, K, V]( + subs: AtomicCell[F, Map[K, Redis4CatsSubscription[F, V]]], + key: K + ): F[Unit] = subs.evalUpdate { subscribers => + subscribers.get(key) match { + case None => + Log[F] + .error( + s"We were notified about stream termination for $key but we don't have a subscription, " + + s"this is a bug in redis4cats!" + ) + .as(subscribers) + case Some(sub) => + if (!sub.isLastSubscriber) subscribers.updated(key, sub.removeSubscriber).pure + else sub.cleanup.as(subscribers - key) + } + } + + private def unsubscribeFrom[F[_]: MonadCancelThrow: Log, K, V]( + key: K, + subs: AtomicCell[F, Map[K, Redis4CatsSubscription[F, V]]] + ): F[Unit] = + subs.evalUpdate { subscribers => + subscribers.get(key) match { + case None => + // No subscription = nothing to do + Log[F] + .debug(s"Not unsubscribing from $key because we don't have a subscription") + .as(subscribers) + case Some(sub) => + // Publish `None` which will terminate all streams, which will perform cleanup once the last stream + // terminates. + (Log[F].info( + s"Unsubscribing from $key with ${sub.subscribers} subscribers" + ) *> sub.topic.publish1(None)).uncancelable.as(subscribers) + } + } + + private def subscribe[F[_]: Async: Log, TypedKey, SubValue, K, V]( + key: TypedKey, + subs: AtomicCell[F, Map[TypedKey, Redis4CatsSubscription[F, SubValue]]], + subConnection: StatefulRedisPubSubConnection[K, V], + subscribeToRedis: F[Unit], + unsubscribeFromRedis: F[Unit] + )(makeListener: (Dispatcher[F], Topic[F, Option[SubValue]]) => RedisPubSubListener[K, V]): Stream[F, SubValue] = + Stream + .eval(subs.evalModify { subscribers => + def stream(sub: Redis4CatsSubscription[F, SubValue]) = + sub.stream(onStreamTermination(subs, key)) + + subscribers.get(key) match { + case Some(subscription) => + // We have an existing subscription, mark that it has one more subscriber. + val newSubscription = subscription.addSubscriber + val newSubscribers = subscribers.updated(key, newSubscription) + Log[F] + .debug( + s"Returning existing subscription for $key, " + + s"subscribers: ${subscription.subscribers} -> ${newSubscription.subscribers}" + ) + .as((newSubscribers, stream(newSubscription))) + + case None => + // No existing subscription, create a new one. + val makeSubscription = for { + _ <- Log[F].info(s"Creating subscription for $key") + // We use parallel dispatcher because multiple subscribers can be interested in the same key + dispatcherTpl <- Dispatcher.parallel[F].allocated + (dispatcher, cleanupDispatcher) = dispatcherTpl + topic <- Topic[F, Option[SubValue]] + listener = makeListener(dispatcher, topic) + cleanupListener = Sync[F].delay(subConnection.removeListener(listener)) + cleanup = ( + Log[F].debug(s"Cleaning up resources for $key subscription") *> + unsubscribeFromRedis *> cleanupListener *> cleanupDispatcher *> + Log[F].debug(s"Cleaned up resources for $key subscription") + ).uncancelable + _ <- Sync[F].delay(subConnection.addListener(listener)) + _ <- subscribeToRedis + sub = Redis4CatsSubscription(topic, subscribers = 1, cleanup) + newSubscribers = subscribers.updated(key, sub) + _ <- Log[F].debug(s"Created subscription for $key") + } yield (newSubscribers, stream(sub)) + + makeSubscription.uncancelable + } }) + .flatten } diff --git a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/package.scala b/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/package.scala deleted file mode 100644 index e8b624b74..000000000 --- a/modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/package.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright 2018-2021 ProfunKtor - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.profunktor.redis4cats.pubsub - -import cats.effect.kernel.Resource -import dev.profunktor.redis4cats.data.RedisChannel -import fs2.concurrent.Topic -import dev.profunktor.redis4cats.data.RedisPattern -import dev.profunktor.redis4cats.data.RedisPatternEvent - -package object internals { - private[pubsub] type GetOrCreateTopicListener[F[_], K, V] = - RedisChannel[K] => PubSubState[F, K, V] => Resource[F, Topic[F, Option[V]]] - - private[pubsub] type GetOrCreatePatternListener[F[_], K, V] = - RedisPattern[K] => PubSubState[F, K, V] => Resource[F, Topic[F, Option[RedisPatternEvent[K, V]]]] -} diff --git a/modules/tests/src/test/scala/dev/profunktor/redis4cats/Redis4CatsFunSuite.scala b/modules/tests/src/test/scala/dev/profunktor/redis4cats/Redis4CatsFunSuite.scala index 511745f1b..7ef9383a0 100644 --- a/modules/tests/src/test/scala/dev/profunktor/redis4cats/Redis4CatsFunSuite.scala +++ b/modules/tests/src/test/scala/dev/profunktor/redis4cats/Redis4CatsFunSuite.scala @@ -20,13 +20,15 @@ import cats.effect._ import cats.syntax.all._ import dev.profunktor.redis4cats.Redis4CatsFunSuite.{ Fs2PubSub, Fs2Streaming } import dev.profunktor.redis4cats.connection._ -import dev.profunktor.redis4cats.data.RedisCodec +import dev.profunktor.redis4cats.data.{ RedisChannel, RedisCodec } import dev.profunktor.redis4cats.effect.Log.NoOp._ +import dev.profunktor.redis4cats.pubsub.data.Subscription import dev.profunktor.redis4cats.pubsub.{ PubSub, PubSubCommands } import dev.profunktor.redis4cats.streams.{ RedisStream, Streaming } import io.lettuce.core.{ ClientOptions, TimeoutOptions } +import munit.{ Compare, Location } -import scala.concurrent.duration.{ Duration, DurationInt } +import scala.concurrent.duration.{ Duration, DurationInt, FiniteDuration } import scala.concurrent.{ Await, Future } abstract class Redis4CatsFunSuite(isCluster: Boolean) extends IOSuite { @@ -114,6 +116,53 @@ abstract class Redis4CatsFunSuite(isCluster: Boolean) extends IOSuite { def withRedisCluster[A](f: RedisCommands[IO, String, String] => IO[A]): Future[Unit] = withAbstractRedisCluster[A, String, String](f)(stringCodec) + implicit class PubSubExts(pubSub: Fs2PubSub[String, String]) { + + /** Assert that a given channel has the given number of subscriptions. + * + * @param waitFor max time to wait for the expected number of subscriptions to be present + * */ + def shouldHaveNSubs( + channel: RedisChannel[String], + count: Long, + waitFor: FiniteDuration = 0.nanos + )(implicit loc: Location): IO[Unit] = + waitUntilEquals( + pubSub.pubSubSubscriptions(List(channel)), + List(Subscription(channel, count)), + waitFor + ) + } + + case class FiberWithStatus[A](fiber: FiberIO[A], status: Ref[IO, Option[Either[Unit, OutcomeIO[A]]]]) { + def isRunning: IO[Boolean] = status.get.map(_.contains(Left(()))) + + def waitUntilRunning(timeout: FiniteDuration = 250.millis): IO[Unit] = + waitUntilEquals(isRunning, true, timeout, s"fiber $fiber should have started by now") + } + implicit class IOExts[A](io: IO[A]) { + def startWithStatus: IO[FiberWithStatus[A]] = + for { + status <- Ref[IO].of(Option.empty[Either[Unit, OutcomeIO[A]]]) + fiber <- (status.set(Some(Left(()))) *> io.guaranteeCase(outcome => status.set(Some(Right(outcome))))).start + } yield FiberWithStatus(fiber, status) + + def startAndWaitUntilRunning(timeout: FiniteDuration = 250.millis): IO[FiberIO[A]] = + io.startWithStatus.flatTap(_.waitUntilRunning(timeout)).map(_.fiber) + } + + /** Waits at most `waitFor` until the `io` starts returning `expected`, failing the assertion otherwise. */ + def waitUntilEquals[A, B]( + io: IO[A], + expected: B, + waitFor: FiniteDuration, + clue: => Any = "values are not the same" + )(implicit loc: Location, compare: Compare[A, B]): IO[Unit] = { + val checker = false.iterateUntilM(_ => + io.map(compare.isEqual(_, expected)).flatTap(if (_) IO.unit else IO.sleep(50.millis)) + )(identity) + checker.void.timeoutTo(waitFor, io.map(assertEquals(_, expected, clue))) + } } object Redis4CatsFunSuite { type Fs2PubSub[K, V] = PubSubCommands[IO, fs2.Stream[IO, *], K, V] diff --git a/modules/tests/src/test/scala/dev/profunktor/redis4cats/RedisPubSubSpec.scala b/modules/tests/src/test/scala/dev/profunktor/redis4cats/RedisPubSubSpec.scala index 6d1c014bb..9f60b017b 100644 --- a/modules/tests/src/test/scala/dev/profunktor/redis4cats/RedisPubSubSpec.scala +++ b/modules/tests/src/test/scala/dev/profunktor/redis4cats/RedisPubSubSpec.scala @@ -19,6 +19,7 @@ package dev.profunktor.redis4cats import dev.profunktor.redis4cats.data.{ RedisChannel, RedisPattern, RedisPatternEvent } import cats.effect.IO import cats.effect.kernel.Deferred + import scala.concurrent.duration._ class RedisPubSubSpec extends Redis4CatsFunSuite(isCluster = false) { @@ -56,6 +57,103 @@ class RedisPubSubSpec extends Redis4CatsFunSuite(isCluster = false) { } } + test("subscribe: to same channel should share an underlying subscription") { + withRedisPubSub { pubSub => + val channel = RedisChannel("test-pubsub-shared") + + for { + sub1 <- pubSub.subscribe(channel).compile.toVector.start + _ <- IO.sleep(200.millis) // Wait to make sure the fiber started. + _ <- pubSub.shouldHaveNSubs(channel, 1) + _ <- pubSub.internalChannelSubscriptions.map(assertEquals(_, Map(channel -> 1L))) + sub2 <- pubSub.subscribe(channel).compile.toVector.start + _ <- IO.sleep(200.millis) // Wait to make sure the fiber started. + _ <- pubSub.internalChannelSubscriptions.map(assertEquals(_, Map(channel -> 2L))) + _ <- pubSub.shouldHaveNSubs(channel, 1) + _ <- sub1.cancel + _ <- pubSub.internalChannelSubscriptions.map(assertEquals(_, Map(channel -> 1L))) + _ <- pubSub.shouldHaveNSubs(channel, 1) + _ <- sub2.cancel + _ <- pubSub.internalChannelSubscriptions.map(assertEquals(_, Map.empty[RedisChannel[String], Long])) + _ <- pubSub.shouldHaveNSubs(channel, 0) + } yield () + } + } + + test("subscribe: messages should be delivered to all subscribers") { + withRedisPubSub { pubSub => + val channel = RedisChannel("test-pubsub-shared") + + for { + sub1 <- pubSub.subscribe(channel).compile.toVector.start + sub2 <- pubSub.subscribe(channel).compile.toVector.start + _ <- IO.sleep(200.millis) // Wait to make sure the fiber started. + _ <- pubSub.publish(channel, "hello") + _ <- IO.sleep(200.millis) // Wait to make sure the message is delivered. + _ <- pubSub.unsubscribe(channel) + sub1Result <- sub1.joinWith(IO.raiseError(new Exception(s"sub1 should not have been cancelled"))) + _ <- IO(assertEquals(sub1Result, Vector("hello"))) + sub2Result <- sub2.joinWith(IO.raiseError(new Exception(s"sub2 should not have been cancelled"))) + _ <- IO(assertEquals(sub2Result, Vector("hello"))) + } yield () + } + } + + test("unsubscribe: should terminate all listening streams") { + withRedisPubSub { pubSub => + val channel = RedisChannel("test-pubsub-shared") + + for { + sub1 <- pubSub.subscribe(channel).compile.toVector.start + sub2 <- pubSub.subscribe(channel).compile.toVector.start + _ <- IO.sleep(200.millis) // Wait to make sure the fibers have started ands streams started processing. + _ <- pubSub.internalChannelSubscriptions.map(assertEquals(_, Map(channel -> 2L))) + _ <- pubSub.shouldHaveNSubs(channel, 1) + _ <- pubSub.unsubscribe(channel) + _ <- sub1.joinWith(IO.raiseError(new Exception("sub1 should not have been cancelled"))) + _ <- sub2.joinWith(IO.raiseError(new Exception("sub2 should not have been cancelled"))) + _ <- pubSub.shouldHaveNSubs(channel, 0) + _ <- pubSub.internalChannelSubscriptions.map(assertEquals(_, Map.empty[RedisChannel[String], Long])) + } yield () + } + } + + test("psubscribe: to same pattern should share an underlying subscription") { + withRedisPubSub { pubSub => + val pattern = RedisPattern("test-pubsub-shared:pattern:*") + + for { + sub1 <- pubSub.psubscribe(pattern).compile.toVector.start + _ <- IO.sleep(200.millis) // Wait to make sure the fiber started. + _ <- pubSub.internalPatternSubscriptions.map(assertEquals(_, Map(pattern -> 1L))) + sub2 <- pubSub.psubscribe(pattern).compile.toVector.start + _ <- IO.sleep(200.millis) // Wait to make sure the fiber started. + _ <- pubSub.internalPatternSubscriptions.map(assertEquals(_, Map(pattern -> 2L))) + _ <- sub1.cancel + _ <- pubSub.internalPatternSubscriptions.map(assertEquals(_, Map(pattern -> 1L))) + _ <- sub2.cancel + _ <- pubSub.internalPatternSubscriptions.map(assertEquals(_, Map.empty[RedisPattern[String], Long])) + } yield () + } + } + + test("punsubscribe: should terminate all streams") { + withRedisPubSub { pubSub => + val pattern = RedisPattern("test-pubsub-shared:pattern:*") + + for { + sub1 <- pubSub.psubscribe(pattern).compile.toVector.start + sub2 <- pubSub.psubscribe(pattern).compile.toVector.start + _ <- IO.sleep(200.millis) // Wait to make sure the fibers have started ands streams started processing. + _ <- pubSub.internalPatternSubscriptions.map(assertEquals(_, Map(pattern -> 2L))) + _ <- pubSub.punsubscribe(pattern) + _ <- sub1.joinWith(IO.raiseError(new Exception("sub1 should not have been cancelled"))) + _ <- sub2.joinWith(IO.raiseError(new Exception("sub2 should not have been cancelled"))) + _ <- pubSub.internalPatternSubscriptions.map(assertEquals(_, Map.empty[RedisPattern[String], Long])) + } yield () + } + } + test("subscribing to a silent channel should not fail with RedisCommandTimeoutException") { timeoutingOperationTest { (options, _) => fs2.Stream.resource(withRedisPubSubOptionsResource(options)).flatMap { pubSub => diff --git a/modules/tests/src/test/scala/dev/profunktor/redis4cats/TestScenarios.scala b/modules/tests/src/test/scala/dev/profunktor/redis4cats/TestScenarios.scala index 61f0cf334..f3a93c207 100644 --- a/modules/tests/src/test/scala/dev/profunktor/redis4cats/TestScenarios.scala +++ b/modules/tests/src/test/scala/dev/profunktor/redis4cats/TestScenarios.scala @@ -789,7 +789,7 @@ trait TestScenarios { self: FunSuite => .awakeEvery[IO](100.milli) .as(message) .through(pubsub.publish(RedisChannel(channel))) - .recover { case _: RedisException => () } + .recover { case _: RedisException => 0L } .interruptWhen(i) _ <- Resource.eval(Stream(s1, s2).parJoin(2).compile.drain) fe <- Resource.eval(gate.get)