diff --git a/core/shared/src/main/scala/fs2/concurrent/Signal.scala b/core/shared/src/main/scala/fs2/concurrent/Signal.scala index 8188b815f7..0e9502b6e4 100644 --- a/core/shared/src/main/scala/fs2/concurrent/Signal.scala +++ b/core/shared/src/main/scala/fs2/concurrent/Signal.scala @@ -24,6 +24,7 @@ package concurrent import cats.data.OptionT import cats.effect.kernel.{Concurrent, Deferred, Ref} +import cats.effect.std.MapRef import cats.syntax.all._ import cats.{Applicative, Functor, Invariant, Monad} @@ -287,6 +288,70 @@ object SignallingRef { } } + /** Creates an instance focused on a component of another SignallingRef's value. Delegates every get and + * modification to underlying SignallingRef, so both instances are always in sync. + */ + def lens[F[_], A, B]( + ref: SignallingRef[F, A] + )(get: A => B, set: A => B => A)(implicit F: Functor[F]): SignallingRef[F, B] = + new LensSignallingRef(ref)(get, set) + + private final class LensSignallingRef[F[_], A, B](underlying: SignallingRef[F, A])( + lensGet: A => B, + lensSet: A => B => A + )(implicit F: Functor[F]) + extends SignallingRef[F, B] { + + def discrete: Stream[F, B] = underlying.discrete.map(lensGet) + + def continuous: Stream[F, B] = underlying.continuous.map(lensGet) + + def get: F[B] = F.map(underlying.get)(a => lensGet(a)) + + def set(b: B): F[Unit] = underlying.update(a => lensModify(a)(_ => b)) + + override def getAndSet(b: B): F[B] = + underlying.modify(a => (lensModify(a)(_ => b), lensGet(a))) + + def update(f: B => B): F[Unit] = + underlying.update(a => lensModify(a)(f)) + + def modify[C](f: B => (B, C)): F[C] = + underlying.modify { a => + val oldB = lensGet(a) + val (b, c) = f(oldB) + (lensSet(a)(b), c) + } + + def tryUpdate(f: B => B): F[Boolean] = + F.map(tryModify(a => (f(a), ())))(_.isDefined) + + def tryModify[C](f: B => (B, C)): F[Option[C]] = + underlying.tryModify { a => + val oldB = lensGet(a) + val (b, result) = f(oldB) + (lensSet(a)(b), result) + } + + def tryModifyState[C](state: cats.data.State[B, C]): F[Option[C]] = { + val f = state.runF.value + tryModify(a => f(a).value) + } + + def modifyState[C](state: cats.data.State[B, C]): F[C] = { + val f = state.runF.value + modify(a => f(a).value) + } + + val access: F[(B, B => F[Boolean])] = + F.map(underlying.access) { case (a, update) => + (lensGet(a), b => update(lensSet(a)(b))) + } + + private def lensModify(s: A)(f: B => B): A = lensSet(s)(f(lensGet(s))) + + } + implicit def invariantInstance[F[_]: Functor]: Invariant[SignallingRef[F, *]] = new Invariant[SignallingRef[F, *]] { override def imap[A, B](fa: SignallingRef[F, A])(f: A => B)(g: B => A): SignallingRef[F, B] = @@ -317,6 +382,131 @@ object SignallingRef { } } +/** A [[MapRef]] with a [[SignallingRef]] for each key. */ +trait SignallingMapRef[F[_], K, V] extends MapRef[F, K, V] { + override def apply(k: K): SignallingRef[F, V] +} + +object SignallingMapRef { + + /** Builds a `SignallingMapRef` for effect `F`, initialized to the supplied value. + */ + def ofSingleImmutableMap[F[_], K, V]( + initial: Map[K, V] = Map.empty[K, V] + )(implicit F: Concurrent[F]): F[SignallingMapRef[F, K, Option[V]]] = { + case class State( + value: Map[K, V], + lastUpdate: Long, + listeners: Map[K, LongMap[Deferred[F, (Option[V], Long)]]] + ) + + F.ref(State(initial, 0L, initial.flatMap(_ => Nil))) + .product(F.ref(1L)) + .map { case (state, ids) => + def newId = ids.getAndUpdate(_ + 1) + + def updateAndNotify[U](state: State, k: K, f: Option[V] => (Option[V], U)) + : (State, F[U]) = { + val (newValue, result) = f(state.value.get(k)) + val newMap = newValue.fold(state.value - k)(v => state.value + (k -> v)) + val lastUpdate = state.lastUpdate + 1 + val newListeners = state.listeners - k + val newState = State(newMap, lastUpdate, newListeners) + val notifyListeners = state.listeners.get(k).fold(F.unit) { listeners => + listeners.values.toVector.traverse_ { listener => + listener.complete(newValue -> lastUpdate) + } + } + + newState -> notifyListeners.as(result) + } + + k => + new SignallingRef[F, Option[V]] { + def get: F[Option[V]] = state.get.map(_.value.get(k)) + + def continuous: Stream[F, Option[V]] = Stream.repeatEval(get) + + def discrete: Stream[F, Option[V]] = { + def go(id: Long, lastSeen: Long): Stream[F, Option[V]] = { + def getNext: F[(Option[V], Long)] = + F.deferred[(Option[V], Long)].flatMap { wait => + state.modify { case state @ State(value, lastUpdate, listeners) => + if (lastUpdate != lastSeen) + state -> (value.get(k) -> lastUpdate).pure[F] + else { + val newListeners = + listeners + .updated(k, listeners.getOrElse(k, LongMap.empty) + (id -> wait)) + state.copy(listeners = newListeners) -> wait.get + } + }.flatten + } + + Stream.eval(getNext).flatMap { case (v, lastUpdate) => + Stream.emit(v) ++ go(id, lastSeen = lastUpdate) + } + } + + def cleanup(id: Long): F[Unit] = + state.update { s => + val newListeners = s.listeners + .get(k) + .map(_ - id) + .filterNot(_.isEmpty) + .fold(s.listeners - k)(s.listeners.updated(k, _)) + s.copy(listeners = newListeners) + } + + Stream.bracket(newId)(cleanup).flatMap { id => + Stream.eval(state.get).flatMap { state => + Stream.emit(state.value.get(k)) ++ go(id, state.lastUpdate) + } + } + } + + def set(v: Option[V]): F[Unit] = update(_ => v) + + def update(f: Option[V] => Option[V]): F[Unit] = modify(v => (f(v), ())) + + def modify[U](f: Option[V] => (Option[V], U)): F[U] = + state.modify(updateAndNotify(_, k, f)).flatten + + def tryModify[U](f: Option[V] => (Option[V], U)): F[Option[U]] = + state.tryModify(updateAndNotify(_, k, f)).flatMap(_.sequence) + + def tryUpdate(f: Option[V] => Option[V]): F[Boolean] = + tryModify(a => (f(a), ())).map(_.isDefined) + + def access: F[(Option[V], Option[V] => F[Boolean])] = + state.access.map { case (state, set) => + val setter = { (newValue: Option[V]) => + val (newState, notifyListeners) = + updateAndNotify(state, k, _ => (newValue, ())) + + set(newState).flatTap { succeeded => + notifyListeners.whenA(succeeded) + } + } + + (state.value.get(k), setter) + } + + def tryModifyState[U](state: cats.data.State[Option[V], U]): F[Option[U]] = { + val f = state.runF.value + tryModify(v => f(v).value) + } + + def modifyState[U](state: cats.data.State[Option[V], U]): F[U] = { + val f = state.runF.value + modify(v => f(v).value) + } + } + } + } + +} + private[concurrent] trait SignalInstances extends SignalLowPriorityInstances { implicit def applicativeInstance[F[_]: Concurrent]: Applicative[Signal[F, *]] = { def nondeterministicZip[A0, A1](xs: Stream[F, A0], ys: Stream[F, A1]): Stream[F, (A0, A1)] = { diff --git a/core/shared/src/test/scala/fs2/concurrent/SignalSuite.scala b/core/shared/src/test/scala/fs2/concurrent/SignalSuite.scala index 9c423140df..655086e8df 100644 --- a/core/shared/src/test/scala/fs2/concurrent/SignalSuite.scala +++ b/core/shared/src/test/scala/fs2/concurrent/SignalSuite.scala @@ -59,6 +59,53 @@ class SignalSuite extends Fs2Suite { } } + test("lens - get/set/discrete") { + case class Foo(bar: Long, baz: Long) + object Foo { + def get(foo: Foo): Long = foo.bar + def set(foo: Foo)(bar: Long): Foo = foo.copy(bar = bar) + } + + forAllF { (vs0: List[Long]) => + val vs = vs0.map(n => if (n == 0) 1 else n) + SignallingRef[IO].of(Foo(0L, -1L)).flatMap { s => + val l = SignallingRef.lens(s)(Foo.get, Foo.set) + Ref.of[IO, Foo](Foo(0L, -1L)).flatMap { r => + val publisher = s.discrete.evalMap(r.set) + val consumer = vs.traverse { v => + l.set(v) >> waitFor(l.get.map(_ == v)) >> waitFor( + r.get.flatMap(rval => + if (rval == Foo(0L, -1L)) IO.pure(true) + else waitFor(r.get.map(_ == Foo(v, -1L))).as(true) + ) + ) + } + Stream.eval(consumer).concurrently(publisher).compile.drain + } + } + } + } + + test("mapref - get/set/discrete") { + forAllF { (vs0: List[Option[Long]]) => + val vs = vs0.map(_.map(n => if (n == 0) 1 else n)) + SignallingMapRef.ofSingleImmutableMap[IO, Unit, Long](Map(() -> 0L)).map(_(())).flatMap { s => + Ref.of[IO, Option[Long]](Some(0)).flatMap { r => + val publisher = s.discrete.evalMap(r.set) + val consumer = vs.traverse { v => + s.set(v) >> waitFor(s.get.map(_ == v)) >> waitFor( + r.get.flatMap(rval => + if (rval == Some(0)) IO.pure(true) + else waitFor(r.get.map(_ == v)).as(true) + ) + ) + } + Stream.eval(consumer).concurrently(publisher).compile.drain + } + } + } + } + test("discrete") { // verifies that discrete always receives the most recent value, even when updates occur rapidly forAllF { (v0: Long, vsTl: List[Long]) => @@ -80,6 +127,27 @@ class SignalSuite extends Fs2Suite { } } + test("mapref - discrete") { + // verifies that discrete always receives the most recent value, even when updates occur rapidly + forAllF { (v0: Option[Long], vsTl: List[Option[Long]]) => + val vs = v0 :: vsTl + SignallingMapRef.ofSingleImmutableMap[IO, Unit, Long](Map(() -> 0L)).map(_(())).flatMap { s => + Ref.of[IO, Option[Long]](Some(0L)).flatMap { r => + val publisherR = s.discrete.evalMap(i => IO.sleep(10.millis) >> r.set(i)) + val publisherS = vs.traverse(s.set) + val last = vs.last + val consumer = waitFor(r.get.map(_ == last)) + Stream + .eval(consumer) + .concurrently(publisherR) + .concurrently(Stream.eval(publisherS)) + .compile + .drain + } + } + } + } + test("access cannot be used twice") { for { s <- SignallingRef[IO, Long](0L) @@ -97,6 +165,23 @@ class SignalSuite extends Fs2Suite { } } + test("mapref - access cannot be used twice") { + for { + s <- SignallingMapRef.ofSingleImmutableMap[IO, Unit, Long](Map(() -> 0L)).map(_(())) + access <- s.access + (v, set) = access + v1 = v.map(_ + 1) + v2 = v1.map(_ + 1) + r1 <- set(v1) + r2 <- set(v2) + r3 <- s.get + } yield { + assert(r1) + assert(!r2) + assertEquals(r3, v1) + } + } + test("access updates discrete") { SignallingRef[IO, Int](0).flatMap { s => def cas: IO[Unit] = @@ -113,6 +198,22 @@ class SignalSuite extends Fs2Suite { } } + test("mapref - access updates discrete") { + SignallingMapRef.ofSingleImmutableMap[IO, Unit, Int](Map(() -> 0)).map(_(())).flatMap { s => + def cas: IO[Unit] = + s.access.flatMap { case (v, set) => + set(v.map(_ + 1)).ifM(IO.unit, cas) + } + + def updates = + s.discrete.takeWhile(_ != Some(1)).compile.drain + + updates.start.flatMap { fiber => + cas >> fiber.join.timeout(5.seconds) + } + } + } + test("holdOption") { val s = Stream.range(1, 10).covary[IO].holdOption s.compile.drain