diff --git a/src/main/scala/fs2/kafka/KafkaConsumer.scala b/src/main/scala/fs2/kafka/KafkaConsumer.scala index 0ecf7bfb0..b94c385f1 100644 --- a/src/main/scala/fs2/kafka/KafkaConsumer.scala +++ b/src/main/scala/fs2/kafka/KafkaConsumer.scala @@ -340,6 +340,7 @@ private[kafka] object KafkaConsumer { settings: ConsumerSettings[K, V], actor: Fiber[F, Unit], polls: Fiber[F, Unit], + streamIdRef: Ref[F, Int], id: Int )(implicit F: Concurrent[F]): KafkaConsumer[F, K, V] = new KafkaConsumer[F, K, V] { @@ -367,25 +368,28 @@ private[kafka] object KafkaConsumer { (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason) def enqueueStream( + streamId: Int, partition: TopicPartition, partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]] ): F[Unit] = chunkQueue.flatMap { chunks => Deferred[F, Unit].flatMap { dequeueDone => - Deferred.tryable[F, Unit].flatMap { partitionRevoked => + Deferred.tryable[F, Unit].flatMap { stopRequests => val shutdown = F.race(fiber.join.attempt, dequeueDone.get).void partitions.enqueue1 { Stream.eval { F.guarantee { Stream .repeatEval { - partitionRevoked.tryGet.flatMap { + stopRequests.tryGet.flatMap { case None => Deferred[F, PartitionRequest].flatMap { deferred => - val request = Request.Fetch(partition, deferred) + val request = Request.Fetch(partition, streamId, deferred) val fetch = requests.enqueue1(request) >> deferred.get F.race(shutdown, fetch).flatMap { - case Left(()) => F.unit + case Left(()) => + stopRequests.complete(()) + case Right((chunk, reason)) => val enqueueChunk = if (chunk.nonEmpty) @@ -394,7 +398,7 @@ private[kafka] object KafkaConsumer { val completeRevoked = if (reason.topicPartitionRevoked) - partitionRevoked.complete(()) + stopRequests.complete(()) else F.unit enqueueChunk >> completeRevoked @@ -403,11 +407,12 @@ private[kafka] object KafkaConsumer { case Some(()) => // Prevent issuing additional requests after partition is - // revoked, in case stream interruption isn't fast enough + // revoked or shutdown happens, in case the stream isn't + // interrupted fast enough F.unit } } - .interruptWhen(F.race(shutdown, partitionRevoked.get).void.attempt) + .interruptWhen(F.race(shutdown, stopRequests.get).void.attempt) .compile .drain }(F.race(dequeueDone.get, chunks.enqueue1(None)).void) @@ -425,22 +430,26 @@ private[kafka] object KafkaConsumer { } def enqueueStreams( + streamId: Int, assigned: NonEmptySet[TopicPartition], partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]] - ): F[Unit] = assigned.foldLeft(F.unit)(_ >> enqueueStream(_, partitions)) + ): F[Unit] = assigned.foldLeft(F.unit)(_ >> enqueueStream(streamId, _, partitions)) def onRebalance( + streamId: Int, partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]] ): OnRebalance[F, K, V] = OnRebalance( - onAssigned = assigned => enqueueStreams(assigned, partitions), + onAssigned = assigned => enqueueStreams(streamId, assigned, partitions), onRevoked = _ => F.unit ) def requestAssignment( + streamId: Int, partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]] ): F[SortedSet[TopicPartition]] = { Deferred[F, Either[Throwable, SortedSet[TopicPartition]]].flatMap { deferred => - val request = Request.Assignment[F, K, V](deferred, Some(onRebalance(partitions))) + val request = + Request.Assignment[F, K, V](deferred, Some(onRebalance(streamId, partitions))) val assignment = requests.enqueue1(request) >> deferred.get.rethrow F.race(fiber.join.attempt, assignment).map { case Left(_) => SortedSet.empty[TopicPartition] @@ -449,11 +458,14 @@ private[kafka] object KafkaConsumer { } } - def initialEnqueue(partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]]): F[Unit] = - requestAssignment(partitions).flatMap { assigned => + def initialEnqueue( + streamId: Int, + partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]] + ): F[Unit] = + requestAssignment(streamId, partitions).flatMap { assigned => if (assigned.nonEmpty) { val nonEmpty = NonEmptySet.fromSetUnsafe(assigned) - enqueueStreams(nonEmpty, partitions) + enqueueStreams(streamId, nonEmpty, partitions) } else F.unit } @@ -461,8 +473,10 @@ private[kafka] object KafkaConsumer { Queue.unbounded[F, Stream[F, CommittableMessage[F, K, V]]] Stream.eval(partitionQueue).flatMap { partitions => - Stream.eval(initialEnqueue(partitions)) >> - partitions.dequeue.interruptWhen(fiber.join.attempt) + Stream.eval(streamIdRef.modify(n => (n + 1, n))).flatMap { streamId => + Stream.eval(initialEnqueue(streamId, partitions)) >> + partitions.dequeue.interruptWhen(fiber.join.attempt) + } } } @@ -623,20 +637,22 @@ private[kafka] object KafkaConsumer { Resource.liftF(Jitter.default[F]).flatMap { implicit jitter => Resource.liftF(F.delay(new Object().hashCode)).flatMap { id => Resource.liftF(Logging.default[F](id)).flatMap { implicit logging => - executionContextResource(settings).flatMap { executionContext => - createConsumer(settings, executionContext).flatMap { synchronized => - val actor = - new KafkaConsumerActor( - settings = settings, - executionContext = executionContext, - ref = ref, - requests = requests, - synchronized = synchronized - ) - - startConsumerActor(requests, polls, actor).flatMap { actor => - startPollScheduler(polls, settings.pollInterval).map { polls => - createKafkaConsumer(requests, settings, actor, polls, id) + Resource.liftF(Ref.of[F, Int](0)).flatMap { streamId => + executionContextResource(settings).flatMap { executionContext => + createConsumer(settings, executionContext).flatMap { synchronized => + val actor = + new KafkaConsumerActor( + settings = settings, + executionContext = executionContext, + ref = ref, + requests = requests, + synchronized = synchronized + ) + + startConsumerActor(requests, polls, actor).flatMap { actor => + startPollScheduler(polls, settings.pollInterval).map { polls => + createKafkaConsumer(requests, settings, actor, polls, streamId, id) + } } } } diff --git a/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala b/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala index a148dd516..06a9a71de 100644 --- a/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala +++ b/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala @@ -20,16 +20,10 @@ import java.time.Duration import java.util import java.util.regex.Pattern -import cats.data.{Chain, NonEmptyChain, NonEmptyList, NonEmptySet} +import cats.data.{Chain, NonEmptyList, NonEmptySet} import cats.effect.concurrent.{Deferred, Ref} import cats.effect.{ConcurrentEffect, ContextShift, IO, Timer} -import cats.instances.list._ -import cats.instances.map._ -import cats.syntax.applicativeError._ -import cats.syntax.flatMap._ -import cats.syntax.monadError._ -import cats.syntax.semigroup._ -import cats.syntax.traverse._ +import cats.implicits._ import fs2.Chunk import fs2.concurrent.Queue import fs2.kafka._ @@ -256,6 +250,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V]( private[this] def fetch( partition: TopicPartition, + streamId: Int, deferred: Deferred[F, (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)] ): F[Unit] = { val assigned = @@ -265,8 +260,19 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V]( def storeFetch = ref - .updateAndGet(_.withFetch(partition, deferred)) - .log(StoredFetch(partition, deferred, _)) + .modify { state => + val (newState, oldFetch) = + state.withFetch(partition, streamId, deferred) + (newState, (newState, oldFetch)) + } + .flatMap { + case (newState, oldFetch) => + log(StoredFetch(partition, deferred, newState)) >> + oldFetch.fold(F.unit) { fetch => + fetch.completeRevoked(Chunk.empty) >> + log(RevokedPreviousFetch(partition, streamId)) + } + } def completeRevoked = deferred.complete((Chunk.empty, FetchCompletedReason.TopicPartitionRevoked)) @@ -319,7 +325,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V]( state.fetches.filterKeysStrictList(withRecords).traverse { case (partition, partitionFetches) => val records = Chunk.buffer(state.records(partition)) - partitionFetches.traverse(_.completeRevoked(records)) + partitionFetches.values.toList.traverse(_.completeRevoked(records)) } >> ref .updateAndGet(_.withoutFetchesAndRecords(withRecords)) .log(RevokedFetchesWithRecords(state.records.filterKeysStrict(withRecords), _)) @@ -329,7 +335,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V]( if (withoutRecords.nonEmpty) { state.fetches .filterKeysStrictValuesList(withoutRecords) - .traverse(_.traverse(_.completeRevoked(Chunk.empty))) >> + .traverse(_.values.toList.traverse(_.completeRevoked(Chunk.empty))) >> ref .updateAndGet(_.withoutFetches(withoutRecords)) .log(RevokedFetchesWithoutRecords(withoutRecords, _)) @@ -472,7 +478,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V]( state.fetches.filterKeysStrictList(canBeCompleted).traverse { case (partition, fetches) => val records = Chunk.buffer(allRecords(partition)) - fetches.traverse(_.completeRecords(records)) + fetches.values.toList.traverse(_.completeRecords(records)) } >> ref .updateAndGet(_.withoutFetchesAndRecords(canBeCompleted)) .log(CompletedFetchesWithRecords(allRecords.filterKeysStrict(canBeCompleted), _)) @@ -509,7 +515,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V]( case Request.Poll() => poll case Request.SubscribeTopics(topics, deferred) => subscribe(topics, deferred) case Request.SubscribePattern(pattern, deferred) => subscribe(pattern, deferred) - case Request.Fetch(partition, deferred) => fetch(partition, deferred) + case Request.Fetch(partition, streamId, deferred) => fetch(partition, streamId, deferred) case Request.Commit(offsets, deferred) => commit(offsets, deferred) case Request.Seek(partition, offset, deferred) => seek(partition, offset, deferred) case Request.SeekToBeginning(partitions, deferred) => seekToBeginning(partitions, deferred) @@ -532,8 +538,10 @@ private[kafka] object KafkaConsumerActor { "FetchRequest$" + System.identityHashCode(this) } + type StreamId = Int + final case class State[F[_], K, V]( - fetches: Map[TopicPartition, NonEmptyChain[FetchRequest[F, K, V]]], + fetches: Map[TopicPartition, Map[StreamId, FetchRequest[F, K, V]]], records: Map[TopicPartition, ArrayBuffer[CommittableMessage[F, K, V]]], onRebalances: Chain[OnRebalance[F, K, V]], subscribed: Boolean @@ -541,12 +549,29 @@ private[kafka] object KafkaConsumerActor { def withOnRebalance(onRebalance: OnRebalance[F, K, V]): State[F, K, V] = copy(onRebalances = onRebalances append onRebalance) + /** + * @return (new-state, old-fetch-to-revoke) + */ def withFetch( partition: TopicPartition, + streamId: Int, deferred: Deferred[F, (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)] - ): State[F, K, V] = { - val fetch = NonEmptyChain.one(FetchRequest(deferred)) - copy(fetches = fetches combine Map(partition -> fetch)) + ): (State[F, K, V], Option[FetchRequest[F, K, V]]) = { + val oldPartitionFetches = + fetches.get(partition) + + val oldPartitionFetch = + oldPartitionFetches.flatMap(_.get(streamId)) + + val newPartitionFetches = + oldPartitionFetches + .getOrElse(Map.empty) + .updated(streamId, FetchRequest(deferred)) + + val newFetches = + fetches.updated(partition, newPartitionFetches) + + (copy(fetches = newFetches), oldPartitionFetch) } def withoutFetches(partitions: Set[TopicPartition]): State[F, K, V] = @@ -640,6 +665,7 @@ private[kafka] object KafkaConsumerActor { final case class Fetch[F[_], K, V]( partition: TopicPartition, + streamId: Int, deferred: Deferred[F, (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)] ) extends Request[F, K, V] diff --git a/src/main/scala/fs2/kafka/internal/LogEntry.scala b/src/main/scala/fs2/kafka/internal/LogEntry.scala index c1ce4516e..da4211b45 100644 --- a/src/main/scala/fs2/kafka/internal/LogEntry.scala +++ b/src/main/scala/fs2/kafka/internal/LogEntry.scala @@ -126,6 +126,15 @@ private[kafka] object LogEntry { s"Stored records for partitions [${recordsString(records)}]. Current state [$state]." } + final case class RevokedPreviousFetch( + partition: TopicPartition, + streamId: Int + ) extends LogEntry { + override def level: LogLevel = Warn + override def message: String = + s"Revoked previous fetch for partition [$partition] in stream with id [$streamId]." + } + def recordsString[F[_], K, V]( records: Records[F, K, V] ): String = diff --git a/src/test/scala/fs2/kafka/KafkaConsumerSpec.scala b/src/test/scala/fs2/kafka/KafkaConsumerSpec.scala index e183cbda1..857239672 100644 --- a/src/test/scala/fs2/kafka/KafkaConsumerSpec.scala +++ b/src/test/scala/fs2/kafka/KafkaConsumerSpec.scala @@ -12,19 +12,16 @@ import org.apache.kafka.common.TopicPartition import scala.concurrent.duration._ final class KafkaConsumerSpec extends BaseKafkaSpec { - describe("KafkaConsumer#stream") { - tests(_.stream) - } type Consumer = KafkaConsumer[IO, String, String] type ConsumerStream = Stream[IO, CommittableMessage[IO, String, String]] - def tests(stream: Consumer => ConsumerStream): Unit = { + describe("KafkaConsumer#stream") { it("should consume all messages") { withKafka { (config, topic) => createCustomTopic(topic, partitions = 3) - val produced = (0 until 100).map(n => s"key-$n" -> s"value->$n") + val produced = (0 until 5).map(n => s"key-$n" -> s"value->$n") publishToKafka(topic, produced) val consumed = @@ -32,9 +29,10 @@ final class KafkaConsumerSpec extends BaseKafkaSpec { .using(consumerSettings(config)) .evalTap(_.subscribeTo(topic)) .evalTap(consumer => IO(consumer.toString should startWith("KafkaConsumer$")).void) - .flatMap(stream) - .take(produced.size.toLong) + .evalMap(IO.sleep(3.seconds).as) // sleep a bit to trigger potential race condition with _.stream + .flatMap(_.stream) .map(message => message.record.key -> message.record.value) + .interruptAfter(10.seconds) // wait some time to catch potentially duplicated records .compile .toVector .unsafeRunSync @@ -54,7 +52,7 @@ final class KafkaConsumerSpec extends BaseKafkaSpec { consumerStream[IO] .using(consumerSettings(config)) .evalTap(_.subscribeTo(topic)) - .evalMap(stream(_).evalMap(queue.enqueue1).compile.drain.start.void) + .evalMap(_.stream.evalMap(queue.enqueue1).compile.drain.start.void) (for { queue <- Stream.eval(Queue.unbounded[IO, CommittableMessage[IO, String, String]]) @@ -126,7 +124,7 @@ final class KafkaConsumerSpec extends BaseKafkaSpec { .flatMap(consumerStream[IO].using) .evalTap(_.subscribe(topic.r)) .flatMap { consumer => - stream(consumer) + consumer.stream .take(produced.size.toLong) .map(_.committableOffset) .fold(CommittableOffsetBatch.empty[IO])(_ updated _) @@ -155,7 +153,7 @@ final class KafkaConsumerSpec extends BaseKafkaSpec { .using(consumerSettings(config)) .evalTap(_.subscribeTo(topic)) .evalTap(_.fiber.cancel) - .flatTap(stream) + .flatTap(_.stream) .evalTap(_.fiber.join) .compile .toVector @@ -172,7 +170,7 @@ final class KafkaConsumerSpec extends BaseKafkaSpec { val consumed = consumerStream[IO] .using(consumerSettings(config)) - .flatMap(stream) + .flatMap(_.stream) .compile .lastOrError .attempt @@ -232,7 +230,7 @@ final class KafkaConsumerSpec extends BaseKafkaSpec { .withAutoOffsetReset(AutoOffsetReset.None) } .evalTap(_.subscribeTo(topic)) - .flatMap(stream) + .flatMap(_.stream) .compile .lastOrError .attempt @@ -260,7 +258,7 @@ final class KafkaConsumerSpec extends BaseKafkaSpec { .using(consumerSettings(config)) .evalTap(_.subscribeTo(topic)) .flatTap { consumer => - stream(consumer) + consumer.stream .take(produced.size.toLong) .map(_.committableOffset) .through(commitBatch) @@ -319,7 +317,7 @@ final class KafkaConsumerSpec extends BaseKafkaSpec { .using(consumerSettings(config)) .flatMap { consumer => val validSeekParams = - stream(consumer) + consumer.stream .take(Math.max(readOffset, 1)) .map(_.committableOffset) .compile @@ -339,7 +337,7 @@ final class KafkaConsumerSpec extends BaseKafkaSpec { val setOffset = seekParams.flatMap { case (tp, o) => consumer.seek(tp, o) } - val consume = stream(consumer).take(numMessages - readOffset) + val consume = consumer.stream.take(numMessages - readOffset) Stream.eval(consumer.subscribeTo(topic)).drain ++ (Stream.eval_(setOffset) ++ consume)