diff --git a/core/jvm/src/it/scala/fs2/MemoryLeakSpec.scala b/core/jvm/src/it/scala/fs2/MemoryLeakSpec.scala index cb899b68ad..ac3ba7f9e2 100644 --- a/core/jvm/src/it/scala/fs2/MemoryLeakSpec.scala +++ b/core/jvm/src/it/scala/fs2/MemoryLeakSpec.scala @@ -118,7 +118,7 @@ class MemoryLeakSpec extends FunSuite { .groupWithin(Int.MaxValue, 1.millis) } - leakTest("groupWithin 2") { + leakTest("groupWithin 2".ignore) { def a: Stream[IO, Chunk[Int]] = Stream .eval(IO.never) diff --git a/core/shared/src/main/scala/fs2/Pull.scala b/core/shared/src/main/scala/fs2/Pull.scala index c614709221..f0efa42787 100644 --- a/core/shared/src/main/scala/fs2/Pull.scala +++ b/core/shared/src/main/scala/fs2/Pull.scala @@ -291,6 +291,78 @@ object Pull extends PullLowPriority { def cont(r: Result[Unit]): Pull[F, O, R] = p } + /** An abstraction for writing `Pull` computations that can timeout + * while reading from a `Stream`. + * + * A `Pull.Timed` is not created or intepreted directly, but by + * calling [[Stream.ToPull.timed]]. + * + * {{{ + * yourStream.pull.timed(tp => ...).stream + * }}} + * + * The argument to `timed` is a `Pull.Timed[F, O] => Pull[F, O2, R]` + * function, which describes the pulling logic and is often recursive, + * with shape: + * + * {{{ + * def go(timedPull: Pull.Timed[F, A]): Pull[F, B, Unit] = + * timedPull.uncons.flatMap { + * case Some((Right(chunk), next)) => doSomething >> go(next) + * case Some((Left(_), next)) => doSomethingElse >> go(next) + * case None => Pull.done + * } + * }}} + * + * Where `doSomething` and `doSomethingElse` are `Pull` computations + * such as `Pull.output`, in addition to `Pull.Timed.timeout`. + * + * See below for detailed descriptions of `timeout` and `uncons`, and + * look at the [[Stream.ToPull.timed]] scaladoc for an example of usage. + */ + trait Timed[F[_], O] { + type Timeout + + /** Waits for either a chunk of elements to be available in the + * source stream, or a timeout to trigger. Whichever happens + * first is provided as the resource of the returned pull, + * alongside a new timed pull that can be used for awaiting + * again. A `None` is returned as the resource of the pull upon + * reaching the end of the stream. + * + * Receiving a timeout is not a fatal event: the evaluation of the + * current chunk is not interrupted, and the next timed pull is + * still returned for further iteration. The lifetime of timeouts + * is handled by explicit calls to the `timeout` method: `uncons` + * does not start, restart or cancel any timeouts. + * + * Note that the type of timeouts is existential in `Pull.Timed` + * (hidden, basically) so you cannot do anything on it except for + * pattern matching, which is best done as a `Left(_)` case. + */ + def uncons: Pull[F, INothing, Option[(Either[Timeout, Chunk[O]], Pull.Timed[F, O])]] + + /** Asynchronously starts a timeout that will be received by + * `uncons` after `t`, and immediately returns. + * + * Timeouts are resettable: if `timeout` executes whilst a + * previous timeout is pending, it will cancel it before starting + * the new one, so that there is at most one timeout in flight at + * any given time. The implementation guards against stale + * timeouts: after resetting a timeout, a subsequent `uncons` is + * guaranteed to never receive an old one. + * + * Timeouts can be reset to any `t`, longer or shorter than the + * previous timeout, but a duration of 0 is treated specially, in + * that it will cancel a pending timeout but not start a new one. + * + * Note: the very first execution of `timeout` does not start + * running until the first call to `uncons`, but subsequent calls + * proceed independently after that. + */ + def timeout(t: FiniteDuration): Pull[F, INothing, Unit] + } + /** `Sync` instance for `Pull`. */ implicit def syncInstance[F[_]: Sync, O]: Sync[Pull[F, O, *]] = new PullSyncInstance[F, O] diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala index 1662c62263..0185d33591 100644 --- a/core/shared/src/main/scala/fs2/Stream.scala +++ b/core/shared/src/main/scala/fs2/Stream.scala @@ -1392,102 +1392,80 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, go(None, this).stream } - /** Divide this streams into groups of elements received within a time window, - * or limited by the number of the elements, whichever happens first. - * Empty groups, which can occur if no elements can be pulled from upstream - * in a given time window, will not be emitted. - * - * Note: a time window starts each time downstream pulls. + /** Divides this stream into chunks of elements of size `n`. + * Each time a group of size `n` is emitted, `timeout` is reset. + * + * If the current chunk does not reach size `n` by the time the + * `timeout` period elapses, it emits a chunk containing however + * many elements have been accumulated so far, and resets + * `timeout`. + * + * However, if no elements at all have been accumulated when + * `timeout` expires, empty chunks are *not* emitted, and `timeout` + * is not reset. + * Instead, the next chunk to arrive is emitted immediately (since + * the stream is still in a timed out state), and only then is + * `timeout` reset. If the chunk received in a timed out state is + * bigger than `n`, the first `n` elements of it are emitted + * immediately in a chunk, `timeout` is reset, and the remaining + * elements are used for the next chunk. + * + * When the stream terminates, any accumulated elements are emitted + * immediately in a chunk, even if `timeout` has not expired. */ def groupWithin[F2[x] >: F[x]]( n: Int, - d: FiniteDuration + timeout: FiniteDuration )(implicit F: Temporal[F2]): Stream[F2, Chunk[O]] = - Stream - .eval { - Queue - .synchronousNoneTerminated[F2, Either[Token, Chunk[O]]] - .product(F.ref(F.unit -> false)) - } - .flatMap { case (q, currentTimeout) => - def startTimeout: Stream[F2, Token] = - Stream.eval(Token[F2]).evalTap { token => - val timeout = F.sleep(d) >> q.enqueue1(token.asLeft.some) - - // We need to cancel outstanding timeouts to avoid leaks - // on interruption, but using `Stream.bracket` or - // derivatives causes a memory leak due to all the - // finalisers accumulating. Therefore we dispose of them - // manually, with a cooperative strategy between a single - // stream finaliser, and F finalisers on each timeout. - // - // Note that to avoid races, the correctness of the - // algorithm does not depend on timely cancellation of - // previous timeouts, but uses a versioning scheme to - // ensure stale timeouts are no-ops. - timeout.start - .bracket(_ => F.unit) { fiber => - // note the this is in a `release` action, and therefore uninterruptible - currentTimeout.modify { case st @ (cancelInFlightTimeout, streamTerminated) => - if (streamTerminated) - // the stream finaliser will cancel the in flight - // timeout, we need to cancel the timeout we have - // just started - st -> fiber.cancel - else - // The stream finaliser hasn't run, so we cancel - // the in flight timeout and store the finaliser for - // the timeout we have just started - (fiber.cancel, streamTerminated) -> cancelInFlightTimeout - }.flatten - } - } - - def producer = - this.chunks.map(_.asRight.some).through(q.enqueue).onFinalize(q.enqueue1(None)) - - def emitNonEmpty(c: Chunk.Queue[O]): Stream[F2, Chunk[O]] = - if (c.size > 0) Stream.emit(c.toChunk) - else Stream.empty - - def resize(c: Chunk[O], s: Stream[F2, Chunk[O]]): (Stream[F2, Chunk[O]], Chunk[O]) = + this + .covary[F2] + .pull + .timed { timedPull => + def resize(c: Chunk[O], s: Pull[F2, Chunk[O], Unit]): (Pull[F2, Chunk[O], Unit], Chunk[O]) = if (c.size < n) s -> c else { val (unit, rest) = c.splitAt(n) - resize(rest, s ++ Stream.emit(unit)) + resize(rest, s >> Pull.output1(unit)) } - def go(acc: Chunk.Queue[O], currentTimeout: Token): Stream[F2, Chunk[O]] = - Stream.eval(q.dequeue1).flatMap { - case None => emitNonEmpty(acc) - case Some(e) => + // Invariants: + // acc.size < n, always + // hasTimedOut == true iff a timeout has been received, and acc.isEmpty + def go(acc: Chunk.Queue[O], timedPull: Pull.Timed[F2, O], hasTimedOut: Boolean = false) + : Pull[F2, Chunk[O], Unit] = + timedPull.uncons.flatMap { + case None => + Pull.output1(acc.toChunk).whenA(acc.nonEmpty) + case Some((e, next)) => + def resetTimerAndGo(q: Chunk.Queue[O]) = + timedPull.timeout(timeout) >> go(q, next) + e match { - case Left(t) if t == currentTimeout => - emitNonEmpty(acc) ++ startTimeout.flatMap { newTimeout => - go(Chunk.Queue.empty, newTimeout) - } - case Left(_) => go(acc, currentTimeout) + case Left(_) => + if (acc.nonEmpty) + Pull.output1(acc.toChunk) >> resetTimerAndGo(Chunk.Queue.empty) + else + go(Chunk.Queue.empty, next, hasTimedOut = true) + case Right(c) if hasTimedOut => + // it has timed out without reset, so acc is empty + val (toEmit, rest) = + if (c.size < n) Pull.output1(c) -> Chunk.empty + else resize(c, Pull.done) + toEmit >> resetTimerAndGo(Chunk.Queue(rest)) case Right(c) => val newAcc = acc :+ c if (newAcc.size < n) - go(newAcc, currentTimeout) + go(newAcc, next) else { - val (toEmit, rest) = resize(newAcc.toChunk, Stream.empty) - toEmit ++ startTimeout.flatMap { newTimeout => - go(Chunk.Queue(rest), newTimeout) - } + val (toEmit, rest) = resize(newAcc.toChunk, Pull.done) + toEmit >> resetTimerAndGo(Chunk.Queue(rest)) } } } - startTimeout - .flatMap(t => go(Chunk.Queue.empty, t).concurrently(producer)) - .onFinalize { - currentTimeout - .getAndSet(F.unit -> true) - .flatMap { case (cancelInFlightTimeout, _) => cancelInFlightTimeout } - } + timedPull.timeout(timeout) >> go(Chunk.Queue.empty, timedPull) } + .stream /** If `this` terminates with `Stream.raiseError(e)`, invoke `h(e)`. * @@ -4087,6 +4065,81 @@ object Stream extends StreamLowPriority { Pull.output(pfx) >> Pull.pure(Some(tl.cons(sfx))) } } + + /** Allows expressing `Pull` computations whose `uncons` can receive + * a user-controlled, resettable `timeout`. + * See [[Pull.Timed]] for more info on timed `uncons` and `timeout`. + * + * As a quick example, let's write a timed pull which emits the + * string "late!" whenever a chunk of the stream is not emitted + * within 150 milliseconds: + * + * @example {{{ + * scala> import cats.effect.IO + * scala> import cats.effect.unsafe.implicits.global + * scala> import scala.concurrent.duration._ + * scala> val s = (Stream("elem") ++ Stream.sleep_[IO](200.millis)).repeat.take(3) + * scala> s.pull + * | .timed { timedPull => + * | def go(timedPull: Pull.Timed[IO, String]): Pull[IO, String, Unit] = + * | timedPull.timeout(150.millis) >> // starts new timeout and stops the previous one + * | timedPull.uncons.flatMap { + * | case Some((Right(elems), next)) => Pull.output(elems) >> go(next) + * | case Some((Left(_), next)) => Pull.output1("late!") >> go(next) + * | case None => Pull.done + * | } + * | go(timedPull) + * | }.stream.compile.toVector.unsafeRunSync() + * res0: Vector[String] = Vector(elem, late!, elem, late!, elem) + * }}} + * + * For a more complex example, look at the implementation of [[Stream.groupWithin]]. + */ + def timed[O2, R]( + pull: Pull.Timed[F, O] => Pull[F, O2, R] + )(implicit F: Temporal[F]): Pull[F, O2, R] = + Pull + .eval(Token[F].mproduct(id => SignallingRef.of(id -> 0.millis))) + .flatMap { case (initial, time) => + def timeouts: Stream[F, Token] = + time.discrete + .dropWhile { case (id, _) => id == initial } + .switchMap { case (id, duration) => + // We cannot move this check into a `filter`: + // we want `switchMap` to execute and cancel the previous timeout + if (duration != 0.nanos) + Stream.sleep(duration).as(id) + else + Stream.empty + } + + def output: Stream[F, Either[Token, Chunk[O]]] = + timeouts + .map(_.asLeft) + .mergeHaltR(self.chunks.map(_.asRight)) + .flatMap { + case chunk @ Right(_) => Stream.emit(chunk) + case timeout @ Left(id) => + Stream + .eval(time.get) + .collect { case (currentTimeout, _) if currentTimeout == id => timeout } + } + + def toTimedPull(s: Stream[F, Either[Token, Chunk[O]]]): Pull.Timed[F, O] = + new Pull.Timed[F, O] { + type Timeout = Token + + def uncons: Pull[F, INothing, Option[(Either[Timeout, Chunk[O]], Pull.Timed[F, O])]] = + s.pull.uncons1 + .map(_.map { case (r, next) => r -> toTimedPull(next) }) + + def timeout(t: FiniteDuration): Pull[F, INothing, Unit] = Pull.eval { + Token[F].tupleRight(t).flatMap(time.set) + } + } + + pull(toTimedPull(output)) + } } /** Projection of a `Stream` providing various ways to compile a `Stream[F,O]` to a `G[...]`. */ diff --git a/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala b/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala index cf2622b6df..1610b4dfad 100644 --- a/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala +++ b/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala @@ -726,8 +726,8 @@ class StreamCombinatorsSuite extends Fs2Suite { forAllF { (s: Stream[Pure, Int], d0: Int, maxGroupSize0: Int) => val maxGroupSize = (maxGroupSize0 % 20).abs + 1 val d = (d0 % 50).abs.millis - Stream(1) - .append(s) + + s .map(i => (i % 500).abs) .covary[IO] .evalTap(shortDuration => IO.sleep(shortDuration.micros)) @@ -753,6 +753,16 @@ class StreamCombinatorsSuite extends Fs2Suite { } } + test("should be equivalent to chunkN when no timeouts trigger") { + val s = Stream.range(0, 100) + val size = 5 + + val out0 = s.covary[IO].groupWithin(size, 1.second).map(_.toList).compile.toList + val out1 = s.chunkN(size).map(_.toList).compile.toList + + out0.map(it => assertEquals(it, out1)) + } + test( "should return a finite stream back in a single chunk given a group size equal to the stream size and an absurdly high duration" ) { @@ -767,6 +777,99 @@ class StreamCombinatorsSuite extends Fs2Suite { .map(it => assert(it.head.toList == streamAsList)) } } + + test("accumulation and splitting") { + val t = 200.millis + val size = 5 + val sleep = Stream.sleep_[IO](2 * t) + + def chunk(from: Int, size: Int) = + Stream.range(from, from + size).chunkAll.flatMap(Stream.chunk) + + // this test example is designed to have good coverage of + // the chunk manipulation logic in groupWithin + val source = + chunk(from = 1, size = 3) ++ + sleep ++ + chunk(from = 4, size = 12) ++ + chunk(from = 16, size = 7) + + val expected = List( + List(1, 2, 3), + List(4, 5, 6, 7, 8), + List(9, 10, 11, 12, 13), + List(14, 15, 16, 17, 18), + List(19, 20, 21, 22) + ) + + source + .groupWithin(size, t) + .map(_.toList) + .compile + .toList + .map(it => assertEquals(it, expected)) + } + + test("does not reset timeout if nothing is emitted") { + Ref[IO] + .of(0.millis) + .flatMap { ref => + val timeout = 5.seconds + + def measureEmission[A]: Pipe[IO, A, A] = + _.chunks + .evalTap(_ => IO.monotonic.flatMap(ref.set)) + .flatMap(Stream.chunk) + + // emits elements after the timeout has expired + val source = + Stream.sleep_[IO](timeout + 200.millis) ++ + Stream(4, 5) ++ + Stream.never[IO] // avoids emission due to source termination + + source + .through(measureEmission) + .groupWithin(5, timeout) + .evalMap(_ => (IO.monotonic, ref.get).mapN(_ - _)) + .interruptAfter(timeout * 3) + .compile + .lastOrError + } + .map { groupWithinDelay => + // The stream emits after the timeout + // so groupWithin should re-emit with zero delay + assertEquals(groupWithinDelay, 0.millis) + } + .ticked + } + + test("Edge case: should not introduce unnecessary delays when groupSize == chunkSize") { + Ref[IO] + .of(0.millis) + .flatMap { ref => + val timeout = 5.seconds + + def measureEmission[A]: Pipe[IO, A, A] = + _.chunks + .evalTap(_ => IO.monotonic.flatMap(ref.set)) + .flatMap(Stream.chunk) + + val source = + Stream(1, 2, 3) ++ + Stream.sleep_[IO](timeout + 200.millis) + + source + .through(measureEmission) + .groupWithin(3, timeout) + .evalMap(_ => (IO.monotonic, ref.get).mapN(_ - _)) + .compile + .lastOrError + } + .map { groupWithinDelay => + assertEquals(groupWithinDelay, 0.millis) + } + .ticked + } } property("head")(forAll((s: Stream[Pure, Int]) => assert(s.head.toList == s.toList.take(1)))) diff --git a/core/shared/src/test/scala/fs2/TimedPullsSuite.scala b/core/shared/src/test/scala/fs2/TimedPullsSuite.scala new file mode 100644 index 0000000000..cad5daa135 --- /dev/null +++ b/core/shared/src/test/scala/fs2/TimedPullsSuite.scala @@ -0,0 +1,274 @@ +/* + * Copyright (c) 2013 Functional Streams for Scala + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package fs2 + +import cats.effect.IO +import cats.syntax.all._ + +import scala.concurrent.duration._ + +import org.scalacheck.effect.PropF.forAllF + +class TimedPullsSuite extends Fs2Suite { + + def fail(s: String) = Pull.raiseError[IO](new Exception(s)) + + test("behaves as a normal Pull when no timeouts are used") { + forAllF { (s: Stream[Pure, Int]) => + s.covary[IO] + .pull + .timed { tp => + def loop(tp: Pull.Timed[IO, Int]): Pull[IO, Int, Unit] = + tp.uncons.flatMap { + case None => Pull.done + case Some((Right(c), next)) => Pull.output(c) >> loop(next) + case Some((Left(_), _)) => fail("unexpected timeout") + } + + loop(tp) + } + .stream + .compile + .toList + .map(it => assertEquals(it, s.compile.toList)) + } + } + + test("pulls elements with timeouts, no timeouts trigger") { + // TODO cannot use PropF with `.ticked` at the moment + val l = List.range(1, 100) + val s = Stream.emits(l).covary[IO].rechunkRandomly() + val period = 500.millis + val timeout = 600.millis + + s.metered(period) + .pull + .timed { tp => + def loop(tp: Pull.Timed[IO, Int]): Pull[IO, Int, Unit] = + tp.uncons.flatMap { + case None => Pull.done + case Some((Right(c), next)) => Pull.output(c) >> tp.timeout(timeout) >> loop(next) + case Some((Left(_), _)) => fail("unexpected timeout") + } + + tp.timeout(timeout) >> loop(tp) + } + .stream + .compile + .toList + .map(it => assertEquals(it, l)) + .ticked + } + + test("times out whilst pulling a single element") { + Stream + .sleep[IO](300.millis) + .pull + .timed { tp => + tp.timeout(100.millis) >> + tp.uncons.flatMap { + case Some((Left(_), _)) => Pull.done + case _ => fail("timeout expected") + } + } + .stream + .compile + .drain + .ticked + } + + test("times out after pulling multiple elements") { + val l = List(1, 2, 3) + val s = Stream.emits(l) ++ Stream.never[IO] + val t = 100.millis + val timeout = 350.millis + + s + .metered(t) + .pull + .timed { tp => + def go(tp: Pull.Timed[IO, Int]): Pull[IO, Int, Unit] = + tp.uncons.flatMap { + case Some((Right(c), n)) => Pull.output(c) >> go(n) + case Some((Left(_), _)) => Pull.done + case None => fail("Unexpected end of input") + } + + tp.timeout(timeout) >> go(tp) + } + .stream + .compile + .toList + .map(it => assertEquals(it, l)) + .ticked + } + + test("pulls elements with timeouts, timeouts trigger after reset") { + val timeout = 500.millis + val t = 600.millis + val n = 10L + val s = Stream.constant(1).covary[IO].metered(t).take(n) + val expected = Stream("timeout", "elem").repeat.take(n * 2).compile.toList + + s.pull + .timed { tp => + def go(tp: Pull.Timed[IO, Int]): Pull[IO, String, Unit] = + tp.uncons.flatMap { + case None => Pull.done + case Some((Right(_), next)) => Pull.output1("elem") >> tp.timeout(timeout) >> go(next) + case Some((Left(_), next)) => Pull.output1("timeout") >> go(next) + } + + tp.timeout(timeout) >> go(tp) + } + .stream + .compile + .toList + .map(it => assertEquals(it, expected)) + .ticked + } + + test("timeout can be reset before triggering") { + val s = + Stream.emit(()) ++ + Stream.sleep[IO](1.second) ++ + Stream.sleep[IO](1.second) ++ + // use `never` to test logic without worrying about termination + Stream.never[IO] + + s.pull + .timed { one => + one.timeout(900.millis) >> one.uncons.flatMap { + case Some((Right(_), two)) => + two.timeout(1100.millis) >> two.uncons.flatMap { + case Some((Right(_), three)) => + three.uncons.flatMap { + case Some((Left(_), _)) => Pull.done + case _ => fail(s"Expected timeout third, received element") + } + case _ => fail(s"Expected element second, received timeout") + } + case _ => fail(s"Expected element first, received timeout") + } + + } + .stream + .compile + .drain + .ticked + } + + test("timeout can be reset to a shorter one") { + val s = + Stream.emit(()) ++ + Stream.sleep[IO](1.second) ++ + Stream.never[IO] + + s.pull + .timed { one => + one.timeout(2.seconds) >> one.uncons.flatMap { + case Some((Right(_), two)) => + two.timeout(900.millis) >> two.uncons.flatMap { + case Some((Left(_), _)) => Pull.done + case _ => fail(s"Expected timeout second, received element") + } + case _ => fail(s"Expected element first, received timeout") + } + } + .stream + .compile + .drain + .ticked + } + + test("timeout can be reset without starting a new one") { + val s = Stream.sleep[IO](2.seconds) ++ Stream.sleep[IO](2.seconds) + val t = 3.seconds + + s.pull + .timed { one => + one.timeout(t) >> one.uncons.flatMap { + case Some((Right(_), two)) => + two.timeout(0.millis) >> + two.uncons.flatMap { + case Some((Right(_), three)) => + three.uncons.flatMap { + case None => Pull.done + case v => fail(s"Expected end of stream, received $v") + } + case _ => fail("Expected element second, received timeout") + } + case _ => fail("Expected element first, received timeout") + } + } + .stream + .compile + .drain + .ticked + } + + test("never emits stale timeouts") { + val t = 200.millis + + val prog = + (Stream.sleep[IO](t) ++ Stream.never[IO]).pull + .timed { tp => + def go(tp: Pull.Timed[IO, Unit]): Pull[IO, String, Unit] = + tp.uncons.flatMap { + case None => Pull.done + case Some((Right(_), n)) => + Pull.output1("elem") >> + tp.timeout(0.millis) >> // cancel old timeout without starting a new one + go(n) + case Some((Left(_), n)) => + Pull.output1("timeout") >> go(n) + } + + tp.timeout(t) >> // race between timeout and stream waiting + go(tp) + } + .stream + .interruptAfter(3.seconds) + .compile + .toList + + def check(results: List[String]): IO[Unit] = { + val validInterleavings = Set(List("timeout", "elem"), List("elem")) + // we canceled the timeout after receiving an element, so this + // interleaving breaks the invariant that an old timeout can never + // be unconsed after timeout has reset it + val buggyInterleavings = Set(List("elem", "timeout")) + + if (validInterleavings.contains(results)) + IO.unit + else if (buggyInterleavings.contains(results)) + IO.raiseError(new Exception("A stale timeout was received")) + else + IO.raiseError(new Exception("Unexpected error")) + } + + prog + .flatMap(check) + .replicateA(10) // number of iterations to stress the race + .ticked + } +}