diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala
index 9f6faad245..4fe347b35a 100644
--- a/core/shared/src/main/scala/fs2/Stream.scala
+++ b/core/shared/src/main/scala/fs2/Stream.scala
@@ -1378,12 +1378,12 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
   }
 
   /** Splits this stream into a stream of chunks of elements, such that
-    *  1. each chunk in the output has at most `outputSize` elements, and
+    *  1. each chunk in the output has at most `chunkSize` elements, and
     *  2. the concatenation of those chunks, which is obtained by calling
     *     `unchunks`, yields the same element sequence as this stream.
     *
-    * As `this` stream emits input elements, the result stream them in a
-    * waiting buffer, until it has enough elements to emit next chunk.
+    * As `this` stream ingests input elements, they are collected in a
+    * waiting buffer until it holds enough elements to emit the next chunk.
     *
     * To avoid holding input elements for too long, this method takes a
     * `timeout`. This timeout is reset after each output chunk is emitted.
@@ -1403,6 +1403,9 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
     * When the input stream terminates, any accumulated elements are emitted
     * immediately in a chunk, even if `timeout` has not expired.
     *
+    * If `chunkSize` is zero, the resulting stream emits nothing: it sleeps
+    * until the `timeout` expires and then terminates.
+    *
     * @param chunkSize the maximum size of chunks emitted by resulting stream.
     * @param timeout maximum time that input elements are held in the buffer
     * before being emitted by the resulting stream.
@@ -1410,106 +1413,64 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
   def groupWithin[F2[x] >: F[x]](
       chunkSize: Int,
       timeout: FiniteDuration
-  )(implicit F: Temporal[F2]): Stream[F2, Chunk[O]] = {
-
-    case class JunctionBuffer[T](
-        data: Vector[T],
-        endOfSupply: Option[Either[Throwable, Unit]],
-        endOfDemand: Option[Either[Throwable, Unit]]
-    ) {
-      def splitAt(n: Int): (JunctionBuffer[T], JunctionBuffer[T]) =
-        if (this.data.size >= n) {
-          val (head, tail) = this.data.splitAt(n.toInt)
-          (this.copy(tail), this.copy(head))
-        } else {
-          (this.copy(Vector.empty), this)
-        }
-    }
-
-    val outputLong = chunkSize.toLong
-    fs2.Stream.force {
-      for {
-        demand <- Semaphore[F2](outputLong)
-        supply <- Semaphore[F2](0L)
-        buffer <- Ref[F2].of(
-          JunctionBuffer[O](Vector.empty[O], endOfSupply = None, endOfDemand = None)
-        )
-      } yield {
-        /* - Buffer: stores items from input to be sent on next output chunk
-         * - Demand Semaphore: to avoid adding too many items to buffer
-         * - Supply: counts filled positions for next output chunk */
-        def enqueue(t: O): F2[Boolean] =
-          for {
-            _ <- demand.acquire
-            buf <- buffer.modify(buf => (buf.copy(buf.data :+ t), buf))
-            _ <- supply.release
-          } yield buf.endOfDemand.isEmpty
-
-        val dequeueNextOutput: F2[Option[Vector[O]]] = {
-          // Trigger: waits until the supply buffer is full (with acquireN)
-          val waitSupply = supply.acquireN(outputLong).guaranteeCase {
-            case Outcome.Succeeded(_) => supply.releaseN(outputLong)
-            case _                    => F.unit
-          }
+  )(implicit F: Temporal[F2]): Stream[F2, Chunk[O]] =
+    if (chunkSize == 0) Stream.sleep_[F2](timeout)
+    else if (timeout.toNanos == 0 || chunkSize == 1) chunkN(chunkSize)
+    else
+      Stream.force {
+        for {
+          supply <- Semaphore[F2](0)
+          buffer <- Ref[F2].empty[Vector[O]]
+          backpressure <- Semaphore[F2](chunkSize.toLong)
+          supplyEnded <- SignallingRef[F2].of(false)
+        } yield {
-          val onTimeout: F2[Long] =
-            for {
-              _ <- supply.acquire // waits until there is at least one element in buffer
-              m <- supply.available
-              k = m.min(outputLong - 1)
-              b <- supply.tryAcquireN(k)
-            } yield if (b) k + 1 else 1
-
-          // in JS cancellation doesn't always seem to run, so race conditions should restore state on their own
-          for {
-            acq <- F.race(F.sleep(timeout), waitSupply).flatMap {
-              case Left(_)  => onTimeout
-              case Right(_) => supply.acquireN(outputLong).as(outputLong)
-            }
-            buf <- buffer.modify(_.splitAt(acq.toInt))
-            _ <- demand.releaseN(buf.data.size.toLong)
-            res <- buf.endOfSupply match {
-              case Some(Left(error))                  => F.raiseError(error)
-              case Some(Right(_)) if buf.data.isEmpty => F.pure(None)
-              case _                                  => F.pure(Some(buf.data))
+          def push(o: O): F2[Unit] =
+            backpressure.acquire *> buffer.update(_ :+ o)
+
+          val flush: F2[Vector[O]] =
+            buffer.getAndSet(Vector.empty).flatTap(os => backpressure.releaseN(os.size.toLong))
+
+          // wait until at least one element becomes available, or until the stream ends
+          val awaitSupply: F2[Unit] =
+            Stream.exec(supply.acquire).interruptWhen(supplyEnded).compile.drain
+
+          // To ensure prompt termination on interruption, when the timeout has not kicked in
+          // yet or we have not yet seen enough elements, we need to provide enough supply
+          // for two iterations.
+          val endSupply: F2[Unit] = supplyEnded.set(true) *> supply.releaseN(chunkSize * 2L)
+
+          // Flush immediately, or wait before doing so, then lower the supply by however many
+          // elements have been flushed (excluding the element already awaited, if any).
+          def flushOnSupplyReceived(noSupply: Boolean): F2[Vector[O]] = for {
+            flushed <- awaitSupply.whenA(noSupply) *> flush
+            awaitedCount = if (noSupply) 1L else 0L
+            _ <- supply.acquireN((flushed.size - awaitedCount).max(0))
+          } yield flushed
+
+          // Edge case: the supply semaphore loses the race but still acquires the permits.
+          // In that scenario we flush the buffer without lowering the supply, since it has
+          // already been lowered.
+          val onTimeout: F2[Vector[O]] = for {
+            bufferFull <- buffer.get.map(_.size == chunkSize)
+            noSupply <- supply.available.map(_ == 0)
+            edgeCase = bufferFull && noSupply
+            flushed <- if (edgeCase) flush else flushOnSupplyReceived(noSupply)
+          } yield flushed
+
+          val enqueue: F2[Unit] =
+            foreach(push(_) *> supply.release).compile.drain.guarantee(endSupply)
+
+          val dequeue: F2[Vector[O]] =
+            F.race(supply.acquireN(chunkSize.toLong), F.sleep(timeout)).flatMap {
+              case Left(_)  => flush
+              case Right(_) => onTimeout
             }
-          } yield res
-        }
-
-        def endSupply(result: Either[Throwable, Unit]): F2[Unit] =
-          buffer.update(_.copy(endOfSupply = Some(result))) *> supply.releaseN(Int.MaxValue)
-
-        def endDemand(result: Either[Throwable, Unit]): F2[Unit] =
-          buffer.update(_.copy(endOfDemand = Some(result))) *> demand.releaseN(Int.MaxValue)
-
-        def toEnding(ec: ExitCase): Either[Throwable, Unit] = ec match {
-          case ExitCase.Succeeded  => Right(())
-          case ExitCase.Errored(e) => Left(e)
-          case ExitCase.Canceled   => Right(())
-        }
-
-        val enqueueAsync = F.start {
-          this
-            .evalMap(enqueue)
-            .forall(identity)
-            .onFinalizeCase(ec => endSupply(toEnding(ec)))
-            .compile
-            .drain
-        }
-
-        val outputStream: Stream[F2, Chunk[O]] = Stream
-          .eval(dequeueNextOutput)
-          .repeat
-          .collectWhile { case Some(data) => Chunk.vector(data) }
-
-        Stream
-          .bracketCase(enqueueAsync) { case (upstream, exitCase) =>
-            endDemand(toEnding(exitCase)) *> upstream.cancel
-          } >> outputStream
+          Stream
+            .repeatEval(dequeue)
+            .collectWhile { case os if os.nonEmpty => Chunk.vector(os) }
+            .concurrently(Stream.eval(enqueue))
+        }
       }
-    }
-  }
 
   /** If `this` terminates with `Stream.raiseError(e)`, invoke `h(e)`.
    *
diff --git a/core/shared/src/test/scala/fs2/Fs2Suite.scala b/core/shared/src/test/scala/fs2/Fs2Suite.scala
index e50396ff67..a79fd699b7 100644
--- a/core/shared/src/test/scala/fs2/Fs2Suite.scala
+++ b/core/shared/src/test/scala/fs2/Fs2Suite.scala
@@ -89,6 +89,9 @@ abstract class Fs2Suite
         expect <- expected.compile.toList
       } yield assertEquals(actual.toSet, expect.toSet)
 
+    def assertCompletes: IO[Unit] =
+      str.compile.drain.assert
+
     def intercept[T <: Throwable](implicit T: ClassTag[T], loc: Location): IO[T] =
       str.compile.drain.intercept[T]
   }
diff --git a/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala b/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala
index fb49f21339..26623bb51c 100644
--- a/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala
+++ b/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala
@@ -23,7 +23,7 @@ package fs2
 
 import cats.effect.kernel.Deferred
 import cats.effect.kernel.Ref
-import cats.effect.std.{Semaphore, Queue}
+import cats.effect.std.{Queue, Semaphore}
 import cats.effect.testkit.TestControl
 import cats.effect.{IO, SyncIO}
 import cats.syntax.all._
@@ -34,6 +34,7 @@ import org.scalacheck.Prop.forAll
 
 import scala.concurrent.duration._
 import scala.concurrent.TimeoutException
+import scala.util.control.NoStackTrace
 
 class StreamCombinatorsSuite extends Fs2Suite {
   override def munitIOTimeout = 1.minute
@@ -748,7 +749,7 @@ class StreamCombinatorsSuite extends Fs2Suite {
       }
     }
 
-    test("accumulation and splitting".flaky) {
+    test("accumulation and splitting") {
      val t = 200.millis
      val size = 5
      val sleep = Stream.sleep_[IO](2 * t)
@@ -775,6 +776,36 @@ class StreamCombinatorsSuite extends Fs2Suite {
      source.groupWithin(size, t).map(_.toList).assertEmits(expected)
    }
 
+    test("accumulation and splitting 2") {
+      val t = 200.millis
+      val size = 5
+      val sleep = Stream.sleep_[IO](2 * t)
+      val longSleep = sleep.repeatN(5)
+
+      def chunk(from: Int, size: Int) =
+        Stream.range(from, from + size).chunkAll.unchunks
+
+      // this test case is designed to give good coverage of
+      // the chunk manipulation logic in groupWithin
+      val source =
+        chunk(from = 1, size = 3) ++
+          sleep ++
+          chunk(from = 4, size = 1) ++ longSleep ++
+          chunk(from = 5, size = 11) ++
+          chunk(from = 16, size = 7)
+
+      val expected = List(
+        List(1, 2, 3),
+        List(4),
+        List(5, 6, 7, 8, 9),
+        List(10, 11, 12, 13, 14),
+        List(15, 16, 17, 18, 19),
+        List(20, 21, 22)
+      )
+
+      source.groupWithin(size, t).map(_.toList).assertEmits(expected)
+    }
+
     test("does not reset timeout if nothing is emitted") {
       TestControl
         .executeEmbed(
@@ -834,6 +865,122 @@ class StreamCombinatorsSuite extends Fs2Suite {
         )
         .assertEquals(0.millis)
     }
+
+    test("stress test (short execution): all elements are processed") {
+      val rangeLength = 100000
+
+      Stream
+        .eval(Ref[IO].of(0))
+        .flatMap { counter =>
+          Stream
+            .range(0, rangeLength)
+            .covary[IO]
+            .groupWithin(256, 100.micros)
+            .evalTap(ch => counter.update(_ + ch.size)) *> Stream.eval(counter.get)
+        }
+        .compile
+        .lastOrError
+        .assertEquals(rangeLength)
+    }
+
+    // Ignored because it is a relatively long-running test (around 3 to 4 minutes), but it is
+    // useful for assessing the validity of permit management and timeout logic over an
+    // extended period of time.
+    test("stress test (long execution): all elements are processed".ignore) {
+      val rangeLength = 10000000
+
+      Stream
+        .eval(Ref[IO].of(0))
+        .flatMap { counter =>
+          Stream
+            .range(0, rangeLength)
+            .covary[IO]
+            .evalTap(d => IO.sleep((d % 10 + 2).micros))
+            .groupWithin(275, 5.millis)
+            .evalTap(ch => counter.update(_ + ch.size)) *> Stream.eval(counter.get)
+        }
+        .compile
+        .lastOrError
+        .assertEquals(rangeLength)
+    }
+
+    test("upstream failures are propagated downstream") {
+      case object SevenNotAllowed extends NoStackTrace
+
+      val source = Stream
+        .unfold(0)(s => Some((s, s + 1)))
+        .covary[IO]
+        .evalMap(n => if (n == 7) IO.raiseError(SevenNotAllowed) else IO.pure(n))
+
+      val downstream = source.groupWithin(100, 2.seconds)
+
+      downstream.compile.lastOrError.intercept[SevenNotAllowed.type]
+    }
+
+    test(
+      "upstream interruption causes immediate downstream termination with all elements being emitted"
+    ) {
+      val sourceTimeout = 5.5.seconds
+      val downstreamTimeout = sourceTimeout + 2.seconds
+
+      TestControl
+        .executeEmbed(
+          Ref[IO]
+            .of(0.millis)
+            .flatMap { ref =>
+              val source: Stream[IO, Int] =
+                Stream
+                  .unfold(0)(s => Some((s, s + 1)))
+                  .covary[IO]
+                  .meteredStartImmediately(1.second)
+                  .interruptAfter(sourceTimeout)
+
+              // With a chunkSize and timeout this large, no emissions are expected within
+              // the specified window unless the source ends, either through interruption
+              // or through natural termination (i.e. it runs out of elements).
+              val downstream: Stream[IO, Chunk[Int]] =
+                source.groupWithin(Int.MaxValue, 1.day)
+
+              downstream.compile.lastOrError
+                .map(_.toList)
+                .timeout(downstreamTimeout)
+                .flatTap(_ => IO.monotonic.flatMap(ref.set))
+                .flatMap(emit => ref.get.map(timeLapsed => (timeLapsed, emit)))
+            }
+        )
+        .assertEquals(
+          // the downstream ended immediately (i.e. timeLapsed == sourceTimeout),
+          // emitting whatever had been accumulated at the time of interruption
+          (sourceTimeout, List(0, 1, 2, 3, 4, 5))
+        )
+    }
+
+    test(
+      "Edge case: if the buffer fills up and the timeout expires at the same time there won't be a deadlock"
+    ) {
+      forAllF { (s0: Stream[Pure, Int], b: Byte) =>
+        TestControl
+          .executeEmbed {
+            // prevent empty or singleton streams, which would bypass the logic being tested
+            val n = b.max(2).toInt
+            val s = s0 ++ Stream.range(0, n)
+
+            // the buffer reaches its full capacity every n seconds,
+            // exactly when the timeout expires
+            s
+              .covary[IO]
+              .metered(1.second)
+              .groupWithin(n, n.seconds)
+              .map(_.toList)
+              .assertCompletes
+          }
+      }
+    }
   }
 
   property("head")(forAll((s: Stream[Pure, Int]) => assertEquals(s.head.toList, s.toList.take(1))))
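For reference, a minimal sketch of how the documented `groupWithin` semantics play out at a call site. This is not part of the diff: `GroupWithinDemo`, `ticks`, and the timings are illustrative names and values, and the sketch assumes fs2 running on cats-effect 3.

import cats.effect.{IO, IOApp}
import scala.concurrent.duration._

object GroupWithinDemo extends IOApp.Simple {
  // Emits 0, 1, 2, ... once every 100 milliseconds.
  val ticks: fs2.Stream[IO, Int] =
    fs2.Stream.iterate(0)(_ + 1).covary[IO].metered(100.millis)

  def run: IO[Unit] =
    ticks
      .take(10)
      // A chunk closes as soon as 3 elements have accumulated, or when
      // 350ms elapse without a full chunk, whichever happens first.
      .groupWithin(3, 350.millis)
      .evalMap(chunk => IO.println(chunk.toList))
      .compile
      .drain
}

With these numbers, full chunks of 3 are emitted while elements arrive fast enough, and the timeout flushes any partial chunk left at the end.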
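The `dequeue` logic above hinges on racing a full semaphore acquisition against the timeout. The following toy reproduction shows that pattern in isolation, again only as a hedged sketch under the same cats-effect 3 assumptions; `RaceDemo` and all timings are made up for illustration and do not appear in the diff.

import cats.effect.{IO, IOApp}
import cats.effect.std.Semaphore
import scala.concurrent.duration._

object RaceDemo extends IOApp.Simple {
  def run: IO[Unit] =
    for {
      supply <- Semaphore[IO](0)
      // A slow producer: one permit (i.e. one buffered element) every 50ms,
      // three in total, which is not enough for a "chunk" of 5.
      _ <- (IO.sleep(50.millis) *> supply.release).replicateA(3).start
      // Mirrors `dequeue`: Left means a full chunk became available before
      // the timeout, Right means the timeout fired and we flush what we have.
      result <- IO.race(supply.acquireN(5), IO.sleep(200.millis))
      _ <- result match {
        case Left(_) => IO.println("full chunk ready")
        case Right(_) =>
          supply.available.flatMap(n => IO.println(s"timeout, $n element(s) buffered"))
      }
    } yield ()
}

Here the timeout wins, the losing `acquireN` is canceled, and the partially acquired permits are returned, which is why the implementation's `onTimeout` has to check `supply.available` before deciding how to flush.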