diff --git a/build.sbt b/build.sbt index be273a235e..8eade5020a 100644 --- a/build.sbt +++ b/build.sbt @@ -205,9 +205,9 @@ lazy val core = crossProject(JVMPlatform, JSPlatform, NativePlatform) libraryDependencies ++= Seq( "org.typelevel" %%% "cats-core" % "2.8.0", "org.typelevel" %%% "cats-laws" % "2.8.0" % Test, - "org.typelevel" %%% "cats-effect" % "3.4.0-RC2", - "org.typelevel" %%% "cats-effect-laws" % "3.4.0-RC2" % Test, - "org.typelevel" %%% "cats-effect-testkit" % "3.4.0-RC2" % Test, + "org.typelevel" %%% "cats-effect" % "3.4-7154d08", + "org.typelevel" %%% "cats-effect-laws" % "3.4-7154d08" % Test, + "org.typelevel" %%% "cats-effect-testkit" % "3.4-7154d08" % Test, "org.scodec" %%% "scodec-bits" % "1.1.34", "org.typelevel" %%% "scalacheck-effect-munit" % "2.0.0-M2" % Test, "org.typelevel" %%% "munit-cats-effect" % "2.0.0-M3" % Test, diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala index bc0383b64c..1cb2e715a8 100644 --- a/core/shared/src/main/scala/fs2/Stream.scala +++ b/core/shared/src/main/scala/fs2/Stream.scala @@ -2106,7 +2106,10 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, val action = ( Semaphore[F2](concurrency.toLong), - Channel.bounded[F2, F2[Either[Throwable, O2]]](concurrency), + if (concurrency >= Short.MaxValue) + Channel.unbounded[F2, F2[Either[Throwable, O2]]] + else + Channel.bounded[F2, F2[Either[Throwable, O2]]](concurrency), Deferred[F2, Unit], Deferred[F2, Unit] ).mapN { (semaphore, channel, stop, end) => diff --git a/core/shared/src/main/scala/fs2/concurrent/Channel.scala b/core/shared/src/main/scala/fs2/concurrent/Channel.scala index 86bda25773..6996163742 100644 --- a/core/shared/src/main/scala/fs2/concurrent/Channel.scala +++ b/core/shared/src/main/scala/fs2/concurrent/Channel.scala @@ -22,8 +22,10 @@ package fs2 package concurrent +import cats.Applicative import cats.effect._ -import cats.effect.implicits._ +import cats.effect.std.Queue +import cats.effect.syntax.all._ import cats.syntax.all._ /** Stream aware, multiple producer, single consumer closeable channel. @@ -116,174 +118,170 @@ sealed trait Channel[F[_], A] { /** Semantically blocks until the channel gets closed. */ def closed: F[Unit] } + object Channel { type Closed = Closed.type object Closed def unbounded[F[_]: Concurrent, A]: F[Channel[F, A]] = - bounded(Int.MaxValue) + Queue.unbounded[F, AnyRef].flatMap(impl(_)) def synchronous[F[_]: Concurrent, A]: F[Channel[F, A]] = - bounded(0) + Queue.synchronous[F, AnyRef].flatMap(impl(_)) def bounded[F[_], A](capacity: Int)(implicit F: Concurrent[F]): F[Channel[F, A]] = { - case class State( - values: List[A], - size: Int, - waiting: Option[Deferred[F, Unit]], - producers: List[(A, Deferred[F, Unit])], - closed: Boolean - ) + require(capacity < Short.MaxValue) + Queue.bounded[F, AnyRef](capacity).flatMap(impl(_)) + } + + // used as a marker to wake up q.take when the channel is closed + private[this] val Sentinel = new AnyRef + + private[this] val LeftClosed: Either[Channel.Closed, Unit] = Left(Channel.Closed) + private[this] val RightUnit: Either[Channel.Closed, Unit] = Right(()) - val open = State(List.empty, 0, None, List.empty, closed = false) + private final case class State(leases: Int, closed: Boolean) - def empty(isClosed: Boolean): State = - if (isClosed) State(List.empty, 0, None, List.empty, closed = true) - else open + private object State { + val Empty: State = State(0, false) + } - (F.ref(open), F.deferred[Unit]).mapN { (state, closedGate) => + // technically this should be A | Sentinel.type + // the queue will consist of exclusively As until we shut down, when there will be one Sentinel + private[this] def impl[F[_]: Concurrent, A](q: Queue[F, AnyRef]): F[Channel[F, A]] = + (Concurrent[F].ref(State.Empty), Concurrent[F].deferred[Unit]).mapN { (stateR, closedLatch) => new Channel[F, A] { - def sendAll: Pipe[F, A, Nothing] = { in => - (in ++ Stream.exec(close.void)) - .evalMap(send) + private[this] val LeftClosedF = LeftClosed.pure[F] + private[this] val FalseF = false.pure[F] + + // might be interesting to try to optimize this more, but it needs support from CE + val sendAll: Pipe[F, A, Nothing] = + _.evalMapChunk(send(_)) .takeWhile(_.isRight) + .onComplete(Stream.exec(close.void)) .drain + + // setting the flag means we won't accept any more sends + val close: F[Either[Channel.Closed, Unit]] = { + val modifyF = stateR.modify { + case State(0, false) => + State(0, true) -> closedLatch.complete(()) *> q.offer(Sentinel).start.as(RightUnit) + + case State(leases, false) => + State(leases, true) -> closedLatch.complete(()).as(RightUnit) + + case st @ State(_, true) => + st -> LeftClosedF + } + + modifyF.flatten.uncancelable } - def send(a: A) = - F.deferred[Unit].flatMap { producer => - F.uncancelable { poll => - state.modify { - case s @ State(_, _, _, _, closed @ true) => - (s, Channel.closed[Unit].pure[F]) - - case State(values, size, waiting, producers, closed @ false) => - if (size < capacity) - ( - State(a :: values, size + 1, None, producers, false), - notifyStream(waiting).as(rightUnit) - ) - else - ( - State(values, size, None, (a, producer) :: producers, false), - notifyStream(waiting).as(rightUnit) <* waitOnBound(producer, poll) - ) - }.flatten - } + val isClosed: F[Boolean] = stateR.get.map(_.closed) + + // there are four states to worry about: open, closing, draining, quiesced + // in the second state, we have outstanding blocked sends + // in the third state we have data in the queue but no sends + // in the fourth state we are completely drained and can shut down the stream + private[this] val isQuiesced: F[Boolean] = + stateR.get.flatMap { + case State(0, true) => q.size.map(_ == 0) + case _ => FalseF } - def trySend(a: A) = - state.modify { - case s @ State(_, _, _, _, closed @ true) => - (s, Channel.closed[Boolean].pure[F]) - - case s @ State(values, size, waiting, producers, closed @ false) => - if (size < capacity) - ( - State(a :: values, size + 1, None, producers, false), - notifyStream(waiting).as(rightTrue) - ) - else - (s, rightFalse.pure[F]) - }.flatten - - def close = - state - .modify { - case s @ State(_, _, _, _, closed @ true) => - (s, Channel.closed[Unit].pure[F]) - - case State(values, size, waiting, producers, closed @ false) => - ( - State(values, size, None, producers, true), - notifyStream(waiting).as(rightUnit) <* signalClosure - ) - } - .flatten - .uncancelable + def send(a: A): F[Either[Channel.Closed, Unit]] = + MonadCancel[F].uncancelable { poll => + // we track the outstanding blocked offers so we can distinguish closing from draining + // the very last blocked send, when closed, is responsible for triggering the sentinel - def isClosed = closedGate.tryGet.map(_.isDefined) + val modifyF = stateR.modify { + case st @ State(_, true) => + st -> LeftClosedF - def closed = closedGate.get + case State(leases, false) => + val cleanupF = { + val modifyF = stateR.modify { + case State(1, true) => + State(0, true) -> q.offer(Sentinel).start.void - def stream = consumeLoop.stream + case State(leases, closed) => + State(leases - 1, closed) -> Applicative[F].unit + } - def consumeLoop: Pull[F, A, Unit] = - Pull.eval { - F.deferred[Unit].flatMap { waiting => - state - .modify { state => - if (shouldEmit(state)) (empty(state.closed), state) - else (state.copy(waiting = waiting.some), state) - } - .flatMap { - case s @ State( - initValues, - stateSize, - ignorePreviousWaiting @ _, - producers, - closed - ) => - if (shouldEmit(s)) { - var size = stateSize - val tailValues = List.newBuilder[A] - var unblock = F.unit - - producers.foreach { case (value, producer) => - size += 1 - tailValues += value - unblock = unblock <* producer.complete(()) - } - - val toEmit = makeChunk(initValues, tailValues.result(), size) - - unblock.as(Pull.output(toEmit) >> consumeLoop) - } else { - F.pure( - if (closed) Pull.done - else Pull.eval(waiting.get) >> consumeLoop - ) - } + modifyF.flatten } - .uncancelable - } - }.flatten - def notifyStream(waitForChanges: Option[Deferred[F, Unit]]) = - waitForChanges.traverse(_.complete(())) + val offerF = poll(q.offer(a.asInstanceOf[AnyRef]).as(RightUnit)) - def waitOnBound(producer: Deferred[F, Unit], poll: Poll[F]) = - poll(producer.get).onCancel { - state.update { s => - s.copy(producers = s.producers.filter(_._2 ne producer)) + State(leases + 1, false) -> offerF.guarantee(cleanupF).as(RightUnit) } + + modifyF.flatten + } + + def trySend(a: A): F[Either[Channel.Closed, Boolean]] = + isClosed.flatMap { b => + if (b) + LeftClosedF.asInstanceOf[F[Either[Channel.Closed, Boolean]]] + else + q.tryOffer(a.asInstanceOf[AnyRef]).map(_.asRight[Channel.Closed]) } - def signalClosure = closedGate.complete(()) + val stream: Stream[F, A] = { + lazy val loop: Pull[F, A, Unit] = { + val pullF = q.tryTakeN(None).flatMap { + case Nil => + // if we land here, it either means we're consuming faster than producing + // or it means we're actually closed and we need to shut down + // this is the unhappy path either way + + val fallback = q.take.map { a => + // if we get the sentinel, shut down all the things, otherwise emit + if (a eq Sentinel) + Pull.done + else + Pull.output1(a.asInstanceOf[A]) >> loop + } - @inline private def shouldEmit(s: State) = s.values.nonEmpty || s.producers.nonEmpty + // check to see if we're closed and done processing + // if we're all done, complete the latch and terminate the stream + isQuiesced.map { b => + if (b) + Pull.done + else + Pull.eval(fallback).flatten + } - private def makeChunk(init: List[A], tail: List[A], size: Int): Chunk[A] = { - val arr = new Array[Any](size) - var i = size - 1 - var values = tail - while (i >= 0) { - if (values.isEmpty) values = init - arr(i) = values.head - values = values.tail - i -= 1 + case as => + // this is the happy path: we were able to take a chunk + // meaning we're producing as fast or faster than we're consuming + + isClosed.map { b => + if (b) { + // if we're closed, we have to check for the sentinel and strip it out + val as2 = as.filter(_ ne Sentinel) + + // if it's empty, we definitely stripped a sentinel, so just be done + // if it's non-empty, we can't know without expensive comparisons, so fall through + if (as2.isEmpty) + Pull.done + else + Pull.output(Chunk.seq(as2.asInstanceOf[List[A]])) >> loop + } else { + Pull.output(Chunk.seq(as.asInstanceOf[List[A]])) >> loop + } + } + } + + Pull.eval(pullF).flatten } - Chunk.array(arr).asInstanceOf[Chunk[A]] + + loop.stream } + + // closedLatch solely exists to support this function + val closed: F[Unit] = closedLatch.get } } - } - - // allocate once - @inline private final def closed[A]: Either[Closed, A] = _closed - private[this] final val _closed: Either[Closed, Nothing] = Left(Closed) - private final val rightUnit: Either[Closed, Unit] = Right(()) - private final val rightTrue: Either[Closed, Boolean] = Right(true) - private final val rightFalse: Either[Closed, Boolean] = Right(false) } diff --git a/core/shared/src/main/scala/fs2/concurrent/Topic.scala b/core/shared/src/main/scala/fs2/concurrent/Topic.scala index 7868041760..7f5d9bff5e 100644 --- a/core/shared/src/main/scala/fs2/concurrent/Topic.scala +++ b/core/shared/src/main/scala/fs2/concurrent/Topic.scala @@ -71,7 +71,8 @@ abstract class Topic[F[_], A] { self => * * If at any point, the queue backing the subscription has `maxQueued` elements in it, * any further publications semantically block until elements are dequeued from the - * subscription queue. + * subscription queue. Any value of `maxQueued` which is greater than `Short.MaxValue` + * is treated as unbounded. * * @param maxQueued maximum number of elements to enqueue to the subscription * queue before blocking publishers @@ -80,7 +81,9 @@ abstract class Topic[F[_], A] { self => /** Like `subscribe`, but represents the subscription explicitly as * a `Resource` which returns after the subscriber is subscribed, - * but before it has started pulling elements. + * but before it has started pulling elements. Note that any value + * of `maxQueued` which is greater than `Short.MaxValue` will be + * treated as "unbounded". */ def subscribeAwait(maxQueued: Int): Resource[F, Stream[F, A]] @@ -159,7 +162,10 @@ object Topic { def subscribeAwait(maxQueued: Int): Resource[F, Stream[F, A]] = Resource - .eval(Channel.bounded[F, A](maxQueued)) + .eval( + if (maxQueued >= Short.MaxValue) Channel.unbounded[F, A] + else Channel.bounded[F, A](maxQueued) + ) .flatMap { chan => val subscribe = state.modify { case (subs, id) => (subs.updated(id, chan), id + 1) -> id diff --git a/core/shared/src/test/scala/fs2/concurrent/BroadcastSuite.scala b/core/shared/src/test/scala/fs2/concurrent/BroadcastSuite.scala index ef78504371..ddd7480d6b 100644 --- a/core/shared/src/test/scala/fs2/concurrent/BroadcastSuite.scala +++ b/core/shared/src/test/scala/fs2/concurrent/BroadcastSuite.scala @@ -53,7 +53,7 @@ class BroadcastSuite extends Fs2Suite { test("all subscribers see all elements, pipe immediately interrupted") { forAllF { (source: Stream[Pure, Int], concurrent0: Int) => val concurrent = (concurrent0 % 20).abs.max(1) - val interruptedPipe = scala.util.Random.nextInt(concurrent) + val interruptedPipe = 0 val expected = source.compile.toVector.map(_.toString) def pipe(idx: Int): Pipe[IO, Int, (Int, String)] = @@ -79,6 +79,7 @@ class BroadcastSuite extends Fs2Suite { result.foreach(it => assertEquals(it, expected)) } else assert(result.isEmpty) } + // .replicateA(100) } } diff --git a/core/shared/src/test/scala/fs2/concurrent/ChannelSuite.scala b/core/shared/src/test/scala/fs2/concurrent/ChannelSuite.scala index 450c402126..ce250c0e07 100644 --- a/core/shared/src/test/scala/fs2/concurrent/ChannelSuite.scala +++ b/core/shared/src/test/scala/fs2/concurrent/ChannelSuite.scala @@ -30,6 +30,19 @@ import scala.concurrent.duration._ import org.scalacheck.effect.PropF.forAllF class ChannelSuite extends Fs2Suite { + + test("receives some simple elements above capacity and closes") { + val test = Channel.bounded[IO, Int](5).flatMap { chan => + val senders = 0.until(10).toList.parTraverse_ { i => + IO.sleep(i.millis) *> chan.send(i) + } + + senders &> (IO.sleep(15.millis) *> chan.close *> chan.stream.compile.toVector) + } + + TestControl.executeEmbed(test) + } + test("Channel receives all elements and closes") { forAllF { (source: Stream[Pure, Int]) => Channel.unbounded[IO, Int].flatMap { chan => @@ -133,7 +146,7 @@ class ChannelSuite extends Fs2Suite { p.assertEquals(true) } - test("Channel.synchronous respects fifo") { + test("synchronous respects fifo") { val l = for { chan <- Channel.synchronous[IO, Int] _ <- (0 until 5).toList.traverse_ { i => @@ -147,7 +160,95 @@ class ChannelSuite extends Fs2Suite { result <- IO.sleep(5.seconds) *> chan.stream.compile.toList } yield result - TestControl.executeEmbed(l).assertEquals((0 until 5).toList) + TestControl.executeEmbed(l).assertEquals((0 until 5).toList).parReplicateA_(100) } + test("complete all blocked sends after closure") { + val test = for { + chan <- Channel.bounded[IO, Int](2) + + fiber <- 0.until(5).toList.parTraverse(chan.send(_)).start + _ <- IO.sleep(1.second) + _ <- chan.close + + results <- chan.stream.compile.toList + _ <- IO(assert(results.length == 5)) + _ <- IO(assert(0.until(5).forall(results.contains(_)))) + + sends <- fiber.joinWithNever + _ <- IO(assert(sends.forall(_ == Right(())))) + } yield () + + TestControl.executeEmbed(test).parReplicateA_(100) + } + + test("eagerly close sendAll upstream") { + for { + countR <- IO.ref(0) + chan <- Channel.unbounded[IO, Unit] + + incrementer = Stream.eval(countR.update(_ + 1)).repeat.take(5) + upstream = incrementer ++ Stream.eval(chan.close).drain ++ incrementer + + results <- chan.stream.concurrently(upstream.through(chan.sendAll)).compile.toList + + _ <- IO(assert(results.length == 5)) + count <- countR.get + _ <- IO(assert(count == 6)) // we have to overrun the closure to detect it + } yield () + } + + def blackHole(s: Stream[IO, Unit]) = + s.repeatPull(_.uncons.flatMap { + case None => Pull.pure(None) + case Some((hd, tl)) => + val action = IO.delay(0.until(hd.size).foreach(_ => ())) + Pull.eval(action).as(Some(tl)) + }) + + @inline + private def sendAll(list: List[Unit], action: IO[Unit]) = + list.foldLeft(IO.unit)((acc, _) => acc *> action) + + test("sendPull") { + val test = Channel.bounded[IO, Unit](8).flatMap { channel => + val action = List.fill(64)(()).traverse_(_ => channel.send(()).void) *> channel.close + action.start *> channel.stream.through(blackHole).compile.drain + } + + test.replicateA_(if (isJVM) 1000 else 1) + } + + test("sendPullPar8") { + val lists = List.fill(8)(List.fill(8)(())) + + val test = Channel.bounded[IO, Unit](8).flatMap { channel => + val action = lists.parTraverse_(sendAll(_, channel.send(()).void)) *> channel.close + + action &> channel.stream.through(blackHole).compile.drain + } + + test.replicateA_(if (isJVM) 10000 else 1) + } + + test("synchronous with many concurrents and close") { + val test = Channel.synchronous[IO, Int].flatMap { ch => + 0.until(20).toList.parTraverse_(i => ch.send(i).iterateWhile(_.isRight)) &> + ch.stream.concurrently(Stream.eval(ch.close.delayBy(1.seconds))).compile.drain + } + + test.parReplicateA(100) + } + + test("complete closed immediately without draining") { + val test = Channel.bounded[IO, Int](20).flatMap { ch => + for { + _ <- 0.until(10).toList.parTraverse_(ch.send(_)) + _ <- ch.close + _ <- ch.closed + } yield () + } + + TestControl.executeEmbed(test) + } }