Skip to content

Commit

Permalink
Guard against multiple FetchRequest on the same (partition,consumer)
Browse files Browse the repository at this point in the history
  • Loading branch information
backuitist committed Mar 28, 2019
1 parent 255e767 commit 88229f5
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
7 changes: 6 additions & 1 deletion src/main/scala/fs2/kafka/KafkaConsumer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package fs2.kafka

import java.util.concurrent.atomic.AtomicInteger

import cats.{Foldable, Reducible}
import cats.data.{NonEmptyList, NonEmptySet}
import cats.effect._
Expand Down Expand Up @@ -358,7 +360,10 @@ private[kafka] object KafkaConsumer {
actorFiber combine pollsFiber
}

private val streamIdCounter = new AtomicInteger(0)

override def partitionedStream: Stream[F, Stream[F, CommittableMessage[F, K, V]]] = {
val streamId = streamIdCounter.getAndIncrement()
val chunkQueue: F[Queue[F, Option[Chunk[CommittableMessage[F, K, V]]]]] =
Queue.bounded(settings.maxPrefetchBatches - 1)

Expand All @@ -379,7 +384,7 @@ private[kafka] object KafkaConsumer {
Stream
.repeatEval {
Deferred[F, PartitionRequest].flatMap { deferred =>
val request = Request.Fetch(partition, deferred)
val request = Request.Fetch(partition, streamId, deferred)
val fetch = requests.enqueue1(request) >> deferred.get
F.race(shutdown, fetch).flatMap {
case Left(()) => F.unit
Expand Down
30 changes: 21 additions & 9 deletions src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import java.time.Duration
import java.util
import java.util.regex.Pattern

import cats.data.{Chain, NonEmptyChain, NonEmptyList, NonEmptySet}
import cats.data.{Chain, NonEmptyList, NonEmptySet}
import cats.effect.concurrent.{Deferred, Ref}
import cats.effect.{ConcurrentEffect, ContextShift, IO, Timer}
import cats.instances.list._
Expand Down Expand Up @@ -248,6 +248,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](

private[this] def fetch(
partition: TopicPartition,
streamId: Int,
deferred: Deferred[F, (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)]
): F[Unit] = {
val assigned =
Expand All @@ -256,7 +257,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
}

def storeFetch =
ref.update(_.withFetch(partition, deferred))
ref.update(_.withFetch(partition, streamId, deferred))

def completeRevoked =
deferred.complete((Chunk.empty, FetchCompletedReason.TopicPartitionRevoked))
Expand Down Expand Up @@ -303,15 +304,15 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
state.fetches.filterKeysStrictList(withRecords).traverse {
case (partition, partitionFetches) =>
val records = Chunk.buffer(state.records(partition))
partitionFetches.traverse(_.completeRevoked(records))
partitionFetches.values.toList.traverse(_.completeRevoked(records))
} >> ref.update(_.withoutFetchesAndRecords(withRecords))
} else F.unit

val completeWithoutRecords =
if (withoutRecords.nonEmpty) {
state.fetches
.filterKeysStrictValuesList(withoutRecords)
.traverse(_.traverse(_.completeRevoked(Chunk.empty))) >>
.traverse(_.values.toList.traverse(_.completeRevoked(Chunk.empty))) >>
ref.update(_.withoutFetches(withoutRecords))
} else F.unit

Expand Down Expand Up @@ -438,7 +439,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
state.fetches.filterKeysStrictList(canBeCompleted).traverse {
case (partition, fetches) =>
val records = Chunk.buffer(allRecords(partition))
fetches.traverse(_.completeRecords(records))
fetches.values.toList.traverse(_.completeRecords(records))
} >> ref.update(_.withoutFetchesAndRecords(canBeCompleted))
} else F.unit

Expand Down Expand Up @@ -470,7 +471,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
case Request.Poll() => poll
case Request.SubscribeTopics(topics, deferred) => subscribe(topics, deferred)
case Request.SubscribePattern(pattern, deferred) => subscribe(pattern, deferred)
case Request.Fetch(partition, deferred) => fetch(partition, deferred)
case Request.Fetch(partition, streamId, deferred) => fetch(partition, streamId, deferred)
case Request.Commit(offsets, deferred) => commit(offsets, deferred)
case request @ Request.Revoked(_) => revoked(request)
case Request.Seek(partition, offset, deferred) => seek(partition, offset, deferred)
Expand All @@ -491,8 +492,10 @@ private[kafka] object KafkaConsumerActor {
deferred.complete((chunk, FetchCompletedReason.FetchedRecords))
}

type StreamId = Int

final case class State[F[_], K, V](
fetches: Map[TopicPartition, NonEmptyChain[FetchRequest[F, K, V]]],
fetches: Map[TopicPartition, Map[StreamId, FetchRequest[F, K, V]]],
records: Map[TopicPartition, ArrayBuffer[CommittableMessage[F, K, V]]],
onRebalances: Chain[OnRebalance[F, K, V]],
subscribed: Boolean
Expand All @@ -502,10 +505,18 @@ private[kafka] object KafkaConsumerActor {

def withFetch(
partition: TopicPartition,
streamId: Int,
deferred: Deferred[F, (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)]
): State[F, K, V] = {
val fetch = NonEmptyChain.one(FetchRequest(deferred))
copy(fetches = fetches combine Map(partition -> fetch))
val fetch = Map(streamId -> FetchRequest(deferred))

// revoke any previous fetch on the same (partition,streamId)
// TODO log a warning as this shouldn't happen!
val existing = fetches.get(partition).flatMap(_.get(streamId))
existing.foreach(_.completeRevoked(Chunk.empty))

val fetchesForPartition = fetches.getOrElse(partition, Map.empty) ++ fetch
copy(fetches = fetches ++ Map(partition -> fetchesForPartition))
}

def withoutFetches(partitions: Set[TopicPartition]): State[F, K, V] =
Expand Down Expand Up @@ -590,6 +601,7 @@ private[kafka] object KafkaConsumerActor {

final case class Fetch[F[_], K, V](
partition: TopicPartition,
streamId: Int,
deferred: Deferred[F, (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)]
) extends Request[F, K, V]

Expand Down

0 comments on commit 88229f5

Please sign in to comment.