Skip to content

Commit

Permalink
Guard against multiple FetchRequest on the same (partition,consumer)
Browse files Browse the repository at this point in the history
  • Loading branch information
backuitist committed Mar 29, 2019
1 parent e99bf0f commit e38d3ed
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 12 deletions.
7 changes: 6 additions & 1 deletion src/main/scala/fs2/kafka/KafkaConsumer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package fs2.kafka

import java.util.concurrent.atomic.AtomicInteger

import cats.{Foldable, Reducible}
import cats.data.{NonEmptyList, NonEmptySet}
import cats.effect._
Expand Down Expand Up @@ -358,7 +360,10 @@ private[kafka] object KafkaConsumer {
actorFiber combine pollsFiber
}

private val streamIdCounter = new AtomicInteger(0)

override def partitionedStream: Stream[F, Stream[F, CommittableMessage[F, K, V]]] = {
val streamId = streamIdCounter.getAndIncrement()
val chunkQueue: F[Queue[F, Option[Chunk[CommittableMessage[F, K, V]]]]] =
Queue.bounded(settings.maxPrefetchBatches - 1)

Expand All @@ -379,7 +384,7 @@ private[kafka] object KafkaConsumer {
Stream
.repeatEval {
Deferred[F, PartitionRequest].flatMap { deferred =>
val request = Request.Fetch(partition, deferred)
val request = Request.Fetch(partition, streamId, deferred)
val fetch = requests.enqueue1(request) >> deferred.get
F.race(shutdown, fetch).flatMap {
case Left(()) => F.unit
Expand Down
41 changes: 30 additions & 11 deletions src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import java.time.Duration
import java.util
import java.util.regex.Pattern

import cats.data.{Chain, NonEmptyChain, NonEmptyList, NonEmptySet}
import cats.data.{Chain, NonEmptyList, NonEmptySet}
import cats.effect.concurrent.{Deferred, Ref}
import cats.effect.{ConcurrentEffect, ContextShift, IO, Timer}
import cats.instances.list._
Expand Down Expand Up @@ -246,15 +246,23 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](

private[this] def fetch(
partition: TopicPartition,
streamId: Int,
deferred: Deferred[F, (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)]
): F[Unit] = {
val assigned =
withConsumer { consumer =>
F.delay(consumer.assignment.contains(partition))
}

def storeFetch =
ref.update(_.withFetch(partition, deferred))
def storeFetch = {
// revoke any previous fetch on the same (partition,streamId)
// TODO log a warning as this shouldn't happen!
ref.modify(_.withFetch(partition, streamId, deferred)).flatMap {
case None => F.unit
case Some(oldFetch) => oldFetch.completeRevoked(Chunk.empty)
}
}


def completeRevoked =
deferred.complete((Chunk.empty, FetchCompletedReason.TopicPartitionRevoked))
Expand Down Expand Up @@ -301,15 +309,15 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
state.fetches.filterKeysStrictList(withRecords).traverse {
case (partition, partitionFetches) =>
val records = Chunk.buffer(state.records(partition))
partitionFetches.traverse(_.completeRevoked(records))
partitionFetches.values.toList.traverse(_.completeRevoked(records))
} >> ref.update(_.withoutFetchesAndRecords(withRecords))
} else F.unit

val completeWithoutRecords =
if (withoutRecords.nonEmpty) {
state.fetches
.filterKeysStrictValuesList(withoutRecords)
.traverse(_.traverse(_.completeRevoked(Chunk.empty))) >>
.traverse(_.values.toList.traverse(_.completeRevoked(Chunk.empty))) >>
ref.update(_.withoutFetches(withoutRecords))
} else F.unit

Expand Down Expand Up @@ -437,7 +445,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
state.fetches.filterKeysStrictList(canBeCompleted).traverse {
case (partition, fetches) =>
val records = Chunk.buffer(allRecords(partition))
fetches.traverse(_.completeRecords(records))
fetches.values.toList.traverse(_.completeRecords(records))
} >> ref.update(_.withoutFetchesAndRecords(canBeCompleted))
} else F.unit

Expand Down Expand Up @@ -469,7 +477,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
case Request.Poll() => poll
case Request.SubscribeTopics(topics, deferred) => subscribe(topics, deferred)
case Request.SubscribePattern(pattern, deferred) => subscribe(pattern, deferred)
case Request.Fetch(partition, deferred) => fetch(partition, deferred)
case Request.Fetch(partition, streamId, deferred) => fetch(partition, streamId, deferred)
case Request.Commit(offsets, deferred) => commit(offsets, deferred)
case Request.Seek(partition, offset, deferred) => seek(partition, offset, deferred)
case Request.SeekToBeginning(partitions, deferred) => seekToBeginning(partitions, deferred)
Expand All @@ -489,21 +497,31 @@ private[kafka] object KafkaConsumerActor {
deferred.complete((chunk, FetchCompletedReason.FetchedRecords))
}

type StreamId = Int

final case class State[F[_], K, V](
fetches: Map[TopicPartition, NonEmptyChain[FetchRequest[F, K, V]]],
fetches: Map[TopicPartition, Map[StreamId, FetchRequest[F, K, V]]],
records: Map[TopicPartition, ArrayBuffer[CommittableMessage[F, K, V]]],
onRebalances: Chain[OnRebalance[F, K, V]],
subscribed: Boolean
) {
def withOnRebalance(onRebalance: OnRebalance[F, K, V]): State[F, K, V] =
copy(onRebalances = onRebalances append onRebalance)

/**
* @return (new-state, old-fetch-to-revoke)
*/
def withFetch(
partition: TopicPartition,
streamId: Int,
deferred: Deferred[F, (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)]
): State[F, K, V] = {
val fetch = NonEmptyChain.one(FetchRequest(deferred))
copy(fetches = fetches combine Map(partition -> fetch))
): (State[F, K, V], Option[FetchRequest[F, K, V]]) = {

val maybeFetchesForPartition = fetches.get(partition)
val existing = maybeFetchesForPartition.flatMap(_.get(streamId))

val fetchesForPartition = maybeFetchesForPartition.getOrElse(Map.empty).updated(streamId, FetchRequest(deferred))
(copy(fetches = fetches.updated(partition, fetchesForPartition)), existing)
}

def withoutFetches(partitions: Set[TopicPartition]): State[F, K, V] =
Expand Down Expand Up @@ -580,6 +598,7 @@ private[kafka] object KafkaConsumerActor {

final case class Fetch[F[_], K, V](
partition: TopicPartition,
streamId: Int,
deferred: Deferred[F, (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)]
) extends Request[F, K, V]

Expand Down

0 comments on commit e38d3ed

Please sign in to comment.