Skip to content

Commit

Permalink
Guard against multiple FetchRequest on the same (partition,consumer)
Browse files Browse the repository at this point in the history
  • Loading branch information
backuitist authored and Viktor Lövgren committed Apr 1, 2019
1 parent 837894b commit f6da986
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 11 deletions.
7 changes: 6 additions & 1 deletion src/main/scala/fs2/kafka/KafkaConsumer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package fs2.kafka

import java.util.concurrent.atomic.AtomicInteger

import cats.{Foldable, Reducible}
import cats.data.{NonEmptyList, NonEmptySet}
import cats.effect._
Expand Down Expand Up @@ -359,7 +361,10 @@ private[kafka] object KafkaConsumer {
actorFiber combine pollsFiber
}

private val streamIdCounter = new AtomicInteger(0)

override def partitionedStream: Stream[F, Stream[F, CommittableMessage[F, K, V]]] = {
val streamId = streamIdCounter.getAndIncrement()
val chunkQueue: F[Queue[F, Option[Chunk[CommittableMessage[F, K, V]]]]] =
Queue.bounded(settings.maxPrefetchBatches - 1)

Expand All @@ -382,7 +387,7 @@ private[kafka] object KafkaConsumer {
partitionRevoked.tryGet.flatMap {
case None =>
Deferred[F, PartitionRequest].flatMap { deferred =>
val request = Request.Fetch(partition, deferred)
val request = Request.Fetch(partition, streamId, deferred)
val fetch = requests.enqueue1(request) >> deferred.get
F.race(shutdown, fetch).flatMap {
case Left(()) => F.unit
Expand Down
42 changes: 32 additions & 10 deletions src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import java.time.Duration
import java.util
import java.util.regex.Pattern

import cats.data.{Chain, NonEmptyChain, NonEmptyList, NonEmptySet}
import cats.data.{Chain, NonEmptyList, NonEmptySet}
import cats.effect.concurrent.{Deferred, Ref}
import cats.effect.{ConcurrentEffect, ContextShift, IO, Timer}
import cats.instances.list._
Expand Down Expand Up @@ -256,6 +256,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](

private[this] def fetch(
partition: TopicPartition,
streamId: Int,
deferred: Deferred[F, (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)]
): F[Unit] = {
val assigned =
Expand All @@ -265,7 +266,16 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](

def storeFetch =
ref
.updateAndGet(_.withFetch(partition, deferred))
.modify { state =>
val (newState, oldFetch) = state.withFetch(partition, streamId, deferred)
(newState, (newState, oldFetch))
}
.flatMap {
case (newState, None) => F.pure(newState)
case (newState, Some(oldFetch)) =>
// TODO: Log warning
oldFetch.completeRevoked(Chunk.empty).as(newState)
}
.log(StoredFetch(partition, deferred, _))

def completeRevoked =
Expand Down Expand Up @@ -319,7 +329,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
state.fetches.filterKeysStrictList(withRecords).traverse {
case (partition, partitionFetches) =>
val records = Chunk.buffer(state.records(partition))
partitionFetches.traverse(_.completeRevoked(records))
partitionFetches.values.toList.traverse(_.completeRevoked(records))
} >> ref
.updateAndGet(_.withoutFetchesAndRecords(withRecords))
.log(RevokedFetchesWithRecords(state.records.filterKeysStrict(withRecords), _))
Expand All @@ -329,7 +339,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
if (withoutRecords.nonEmpty) {
state.fetches
.filterKeysStrictValuesList(withoutRecords)
.traverse(_.traverse(_.completeRevoked(Chunk.empty))) >>
.traverse(_.values.toList.traverse(_.completeRevoked(Chunk.empty))) >>
ref
.updateAndGet(_.withoutFetches(withoutRecords))
.log(RevokedFetchesWithoutRecords(withoutRecords, _))
Expand Down Expand Up @@ -472,7 +482,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
state.fetches.filterKeysStrictList(canBeCompleted).traverse {
case (partition, fetches) =>
val records = Chunk.buffer(allRecords(partition))
fetches.traverse(_.completeRecords(records))
fetches.values.toList.traverse(_.completeRecords(records))
} >> ref
.updateAndGet(_.withoutFetchesAndRecords(canBeCompleted))
.log(CompletedFetchesWithRecords(allRecords.filterKeysStrict(canBeCompleted), _))
Expand Down Expand Up @@ -509,7 +519,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
case Request.Poll() => poll
case Request.SubscribeTopics(topics, deferred) => subscribe(topics, deferred)
case Request.SubscribePattern(pattern, deferred) => subscribe(pattern, deferred)
case Request.Fetch(partition, deferred) => fetch(partition, deferred)
case Request.Fetch(partition, streamId, deferred) => fetch(partition, streamId, deferred)
case Request.Commit(offsets, deferred) => commit(offsets, deferred)
case Request.Seek(partition, offset, deferred) => seek(partition, offset, deferred)
case Request.SeekToBeginning(partitions, deferred) => seekToBeginning(partitions, deferred)
Expand All @@ -532,21 +542,32 @@ private[kafka] object KafkaConsumerActor {
"FetchRequest$" + System.identityHashCode(this)
}

type StreamId = Int

final case class State[F[_], K, V](
fetches: Map[TopicPartition, NonEmptyChain[FetchRequest[F, K, V]]],
fetches: Map[TopicPartition, Map[StreamId, FetchRequest[F, K, V]]],
records: Map[TopicPartition, ArrayBuffer[CommittableMessage[F, K, V]]],
onRebalances: Chain[OnRebalance[F, K, V]],
subscribed: Boolean
) {
def withOnRebalance(onRebalance: OnRebalance[F, K, V]): State[F, K, V] =
copy(onRebalances = onRebalances append onRebalance)

/**
* @return (new-state, old-fetch-to-revoke)
*/
def withFetch(
partition: TopicPartition,
streamId: Int,
deferred: Deferred[F, (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)]
): State[F, K, V] = {
val fetch = NonEmptyChain.one(FetchRequest(deferred))
copy(fetches = fetches combine Map(partition -> fetch))
): (State[F, K, V], Option[FetchRequest[F, K, V]]) = {

val maybeFetchesForPartition = fetches.get(partition)
val existing = maybeFetchesForPartition.flatMap(_.get(streamId))

val fetchesForPartition =
maybeFetchesForPartition.getOrElse(Map.empty).updated(streamId, FetchRequest(deferred))
(copy(fetches = fetches.updated(partition, fetchesForPartition)), existing)
}

def withoutFetches(partitions: Set[TopicPartition]): State[F, K, V] =
Expand Down Expand Up @@ -640,6 +661,7 @@ private[kafka] object KafkaConsumerActor {

final case class Fetch[F[_], K, V](
partition: TopicPartition,
streamId: Int,
deferred: Deferred[F, (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)]
) extends Request[F, K, V]

Expand Down

0 comments on commit f6da986

Please sign in to comment.