diff --git a/zio-kafka-test/src/test/scala/zio/kafka/consumer/ConsumerSpec.scala b/zio-kafka-test/src/test/scala/zio/kafka/consumer/ConsumerSpec.scala index defc9484e..2df17a78d 100644 --- a/zio-kafka-test/src/test/scala/zio/kafka/consumer/ConsumerSpec.scala +++ b/zio-kafka-test/src/test/scala/zio/kafka/consumer/ConsumerSpec.scala @@ -292,7 +292,7 @@ object ConsumerSpec extends ZIOSpecDefaultSlf4j with KafkaRandom { }.tap { case (_, idx) => ZIO.logDebug(s"Consumed $idx") } } .runDrain - .tap(_ => ZIO.logDebug("Stream completed")) + .zipLeft(ZIO.logDebug("Stream completed")) .provideSomeLayer[Kafka]( consumer(client, Some(group)) ) @@ -530,18 +530,30 @@ object ConsumerSpec extends ZIOSpecDefaultSlf4j with KafkaRandom { // Consume messages subscription = Subscription.topics(topic) + assignedPartitionsRef <- Ref.make(Set.empty[Int]) // Set of partition numbers + // Create a Promise to signal when consumer1 has processed half the partitions + consumer1Ready <- Promise.make[Nothing, Unit] consumer1 <- Consumer .partitionedStream(subscription, Serde.string, Serde.string) .flatMapPar(nrPartitions) { case (tp, partition) => ZStream - .fromZIO(partition.runDrain) + .fromZIO( + consumer1Ready + .succeed(()) + .whenZIO( + assignedPartitionsRef + .updateAndGet(_ + tp.partition()) + .map(_.size >= (nrPartitions / 2)) + ) *> + partition.runDrain + ) .as(tp) } .take(nrPartitions.toLong / 2) .runDrain .provideSomeLayer[Kafka](consumer(client1, Some(group))) .fork - _ <- Live.live(ZIO.sleep(5.seconds)) + _ <- consumer1Ready.await consumer2 <- Consumer .partitionedStream(subscription, Serde.string, Serde.string) .take(nrPartitions.toLong / 2) @@ -574,11 +586,22 @@ object ConsumerSpec extends ZIOSpecDefaultSlf4j with KafkaRandom { // Consume messages subscription = Subscription.topics(topic) + consumer1Ready <- Promise.make[Nothing, Unit] + assignedPartitionsRef <- Ref.make(Set.empty[Int]) // Set of partition numbers consumer1 <- Consumer .partitionedStream(subscription, Serde.string, Serde.string) .flatMapPar(nrPartitions) { case (tp, partition) => ZStream - .fromZIO(partition.runDrain) + .fromZIO( + consumer1Ready + .succeed(()) + .whenZIO( + assignedPartitionsRef + .updateAndGet(_ + tp.partition()) + .map(_.size >= (nrPartitions / 2)) + ) *> + partition.runDrain + ) .as(tp) } .take(nrPartitions.toLong / 2) @@ -592,7 +615,7 @@ object ConsumerSpec extends ZIOSpecDefaultSlf4j with KafkaRandom { .collect { case rebalance: DiagnosticEvent.Rebalance => rebalance } .runCollect .fork - _ <- ZIO.sleep(5.seconds) + _ <- consumer1Ready.await consumer2 <- Consumer .partitionedStream(subscription, Serde.string, Serde.string) .take(nrPartitions.toLong / 2) @@ -600,7 +623,6 @@ object ConsumerSpec extends ZIOSpecDefaultSlf4j with KafkaRandom { .provideSomeLayer[Kafka](consumer(client2, Some(group))) .fork _ <- consumer1.join - _ <- consumer1.join _ <- consumer2.join } yield diagnosticStream.join } @@ -1481,6 +1503,7 @@ object ConsumerSpec extends ZIOSpecDefaultSlf4j with KafkaRandom { "it's possible to start a new consumption session from a Consumer that had a consumption session stopped previously" ) { val numberOfMessages: Int = 100000 + val messagesToConsumeBeforeStop = 1000 // Adjust this threshold as needed val kvs: Iterable[(String, String)] = Iterable.tabulate(numberOfMessages)(i => (s"key-$i", s"msg-$i")) def test(diagnostics: Diagnostics): ZIO[Producer & Scope & Kafka, Throwable, TestResult] = @@ -1490,22 +1513,28 @@ object ConsumerSpec extends ZIOSpecDefaultSlf4j with KafkaRandom { settings <- consumerSettings(clientId = clientId) consumer <- Consumer.make(settings, diagnostics = diagnostics) _ <- produceMany(topic, kvs) + // Create a Ref to track messages consumed and a Promise to signal when to stop consumption + messagesConsumedRef <- Ref.make(0) + stopPromise <- Promise.make[Nothing, Unit] // Starting a consumption session to start the Runloop. - fiber <- - consumer - .plainStream(Subscription.manual(topic -> 0), Serde.string, Serde.string) - .tap(_ => ZIO.sleep(1.millisecond)) // sleep to avoid consuming all messages in under 200 millis - .take(numberOfMessages.toLong) - .runCount - .forkScoped - _ <- ZIO.sleep(200.millis) + fiber <- consumer + .plainStream(Subscription.manual(topic -> 0), Serde.string, Serde.string) + .mapZIO { _ => + messagesConsumedRef.updateAndGet(_ + 1).flatMap { count => + if (count >= messagesToConsumeBeforeStop) stopPromise.succeed(()).as(1L) + else ZIO.succeed(1L) + } + } + .take(numberOfMessages.toLong) + .runSum + .forkScoped + + // Wait for the consumption to reach the desired threshold + _ <- stopPromise.await _ <- consumer.stopConsumption consumed0 <- fiber.join _ <- ZIO.logDebug(s"consumed0: $consumed0") - _ <- ZIO.logDebug("About to sleep 5 seconds") - _ <- ZIO.sleep(5.seconds) - _ <- ZIO.logDebug("Slept 5 seconds") consumed1 <- consumer .plainStream(Subscription.manual(topic -> 0), Serde.string, Serde.string) .take(numberOfMessages.toLong)