Skip to content

Commit

Permalink
Merge pull request #107 from backuitist/multiple_fetch_request
Browse files Browse the repository at this point in the history
Multiple fetch request
  • Loading branch information
vlovgr authored Apr 1, 2019
2 parents 72db472 + 1d73f2c commit 2c08d2c
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 62 deletions.
74 changes: 45 additions & 29 deletions src/main/scala/fs2/kafka/KafkaConsumer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ private[kafka] object KafkaConsumer {
settings: ConsumerSettings[K, V],
actor: Fiber[F, Unit],
polls: Fiber[F, Unit],
streamIdRef: Ref[F, Int],
id: Int
)(implicit F: Concurrent[F]): KafkaConsumer[F, K, V] =
new KafkaConsumer[F, K, V] {
Expand Down Expand Up @@ -367,25 +368,28 @@ private[kafka] object KafkaConsumer {
(Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)

def enqueueStream(
streamId: Int,
partition: TopicPartition,
partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]]
): F[Unit] =
chunkQueue.flatMap { chunks =>
Deferred[F, Unit].flatMap { dequeueDone =>
Deferred.tryable[F, Unit].flatMap { partitionRevoked =>
Deferred.tryable[F, Unit].flatMap { stopRequests =>
val shutdown = F.race(fiber.join.attempt, dequeueDone.get).void
partitions.enqueue1 {
Stream.eval {
F.guarantee {
Stream
.repeatEval {
partitionRevoked.tryGet.flatMap {
stopRequests.tryGet.flatMap {
case None =>
Deferred[F, PartitionRequest].flatMap { deferred =>
val request = Request.Fetch(partition, deferred)
val request = Request.Fetch(partition, streamId, deferred)
val fetch = requests.enqueue1(request) >> deferred.get
F.race(shutdown, fetch).flatMap {
case Left(()) => F.unit
case Left(()) =>
stopRequests.complete(())

case Right((chunk, reason)) =>
val enqueueChunk =
if (chunk.nonEmpty)
Expand All @@ -394,7 +398,7 @@ private[kafka] object KafkaConsumer {

val completeRevoked =
if (reason.topicPartitionRevoked)
partitionRevoked.complete(())
stopRequests.complete(())
else F.unit

enqueueChunk >> completeRevoked
Expand All @@ -403,11 +407,12 @@ private[kafka] object KafkaConsumer {

case Some(()) =>
// Prevent issuing additional requests after partition is
// revoked, in case stream interruption isn't fast enough
// revoked or shutdown happens, in case the stream isn't
// interrupted fast enough
F.unit
}
}
.interruptWhen(F.race(shutdown, partitionRevoked.get).void.attempt)
.interruptWhen(F.race(shutdown, stopRequests.get).void.attempt)
.compile
.drain
}(F.race(dequeueDone.get, chunks.enqueue1(None)).void)
Expand All @@ -425,22 +430,26 @@ private[kafka] object KafkaConsumer {
}

def enqueueStreams(
streamId: Int,
assigned: NonEmptySet[TopicPartition],
partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]]
): F[Unit] = assigned.foldLeft(F.unit)(_ >> enqueueStream(_, partitions))
): F[Unit] = assigned.foldLeft(F.unit)(_ >> enqueueStream(streamId, _, partitions))

def onRebalance(
streamId: Int,
partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]]
): OnRebalance[F, K, V] = OnRebalance(
onAssigned = assigned => enqueueStreams(assigned, partitions),
onAssigned = assigned => enqueueStreams(streamId, assigned, partitions),
onRevoked = _ => F.unit
)

def requestAssignment(
streamId: Int,
partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]]
): F[SortedSet[TopicPartition]] = {
Deferred[F, Either[Throwable, SortedSet[TopicPartition]]].flatMap { deferred =>
val request = Request.Assignment[F, K, V](deferred, Some(onRebalance(partitions)))
val request =
Request.Assignment[F, K, V](deferred, Some(onRebalance(streamId, partitions)))
val assignment = requests.enqueue1(request) >> deferred.get.rethrow
F.race(fiber.join.attempt, assignment).map {
case Left(_) => SortedSet.empty[TopicPartition]
Expand All @@ -449,20 +458,25 @@ private[kafka] object KafkaConsumer {
}
}

def initialEnqueue(partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]]): F[Unit] =
requestAssignment(partitions).flatMap { assigned =>
def initialEnqueue(
streamId: Int,
partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]]
): F[Unit] =
requestAssignment(streamId, partitions).flatMap { assigned =>
if (assigned.nonEmpty) {
val nonEmpty = NonEmptySet.fromSetUnsafe(assigned)
enqueueStreams(nonEmpty, partitions)
enqueueStreams(streamId, nonEmpty, partitions)
} else F.unit
}

val partitionQueue: F[Queue[F, Stream[F, CommittableMessage[F, K, V]]]] =
Queue.unbounded[F, Stream[F, CommittableMessage[F, K, V]]]

Stream.eval(partitionQueue).flatMap { partitions =>
Stream.eval(initialEnqueue(partitions)) >>
partitions.dequeue.interruptWhen(fiber.join.attempt)
Stream.eval(streamIdRef.modify(n => (n + 1, n))).flatMap { streamId =>
Stream.eval(initialEnqueue(streamId, partitions)) >>
partitions.dequeue.interruptWhen(fiber.join.attempt)
}
}
}

Expand Down Expand Up @@ -623,20 +637,22 @@ private[kafka] object KafkaConsumer {
Resource.liftF(Jitter.default[F]).flatMap { implicit jitter =>
Resource.liftF(F.delay(new Object().hashCode)).flatMap { id =>
Resource.liftF(Logging.default[F](id)).flatMap { implicit logging =>
executionContextResource(settings).flatMap { executionContext =>
createConsumer(settings, executionContext).flatMap { synchronized =>
val actor =
new KafkaConsumerActor(
settings = settings,
executionContext = executionContext,
ref = ref,
requests = requests,
synchronized = synchronized
)

startConsumerActor(requests, polls, actor).flatMap { actor =>
startPollScheduler(polls, settings.pollInterval).map { polls =>
createKafkaConsumer(requests, settings, actor, polls, id)
Resource.liftF(Ref.of[F, Int](0)).flatMap { streamId =>
executionContextResource(settings).flatMap { executionContext =>
createConsumer(settings, executionContext).flatMap { synchronized =>
val actor =
new KafkaConsumerActor(
settings = settings,
executionContext = executionContext,
ref = ref,
requests = requests,
synchronized = synchronized
)

startConsumerActor(requests, polls, actor).flatMap { actor =>
startPollScheduler(polls, settings.pollInterval).map { polls =>
createKafkaConsumer(requests, settings, actor, polls, streamId, id)
}
}
}
}
Expand Down
62 changes: 44 additions & 18 deletions src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,10 @@ import java.time.Duration
import java.util
import java.util.regex.Pattern

import cats.data.{Chain, NonEmptyChain, NonEmptyList, NonEmptySet}
import cats.data.{Chain, NonEmptyList, NonEmptySet}
import cats.effect.concurrent.{Deferred, Ref}
import cats.effect.{ConcurrentEffect, ContextShift, IO, Timer}
import cats.instances.list._
import cats.instances.map._
import cats.syntax.applicativeError._
import cats.syntax.flatMap._
import cats.syntax.monadError._
import cats.syntax.semigroup._
import cats.syntax.traverse._
import cats.implicits._
import fs2.Chunk
import fs2.concurrent.Queue
import fs2.kafka._
Expand Down Expand Up @@ -256,6 +250,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](

private[this] def fetch(
partition: TopicPartition,
streamId: Int,
deferred: Deferred[F, (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)]
): F[Unit] = {
val assigned =
Expand All @@ -265,8 +260,19 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](

def storeFetch =
ref
.updateAndGet(_.withFetch(partition, deferred))
.log(StoredFetch(partition, deferred, _))
.modify { state =>
val (newState, oldFetch) =
state.withFetch(partition, streamId, deferred)
(newState, (newState, oldFetch))
}
.flatMap {
case (newState, oldFetch) =>
log(StoredFetch(partition, deferred, newState)) >>
oldFetch.fold(F.unit) { fetch =>
fetch.completeRevoked(Chunk.empty) >>
log(RevokedPreviousFetch(partition, streamId))
}
}

def completeRevoked =
deferred.complete((Chunk.empty, FetchCompletedReason.TopicPartitionRevoked))
Expand Down Expand Up @@ -319,7 +325,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
state.fetches.filterKeysStrictList(withRecords).traverse {
case (partition, partitionFetches) =>
val records = Chunk.buffer(state.records(partition))
partitionFetches.traverse(_.completeRevoked(records))
partitionFetches.values.toList.traverse(_.completeRevoked(records))
} >> ref
.updateAndGet(_.withoutFetchesAndRecords(withRecords))
.log(RevokedFetchesWithRecords(state.records.filterKeysStrict(withRecords), _))
Expand All @@ -329,7 +335,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
if (withoutRecords.nonEmpty) {
state.fetches
.filterKeysStrictValuesList(withoutRecords)
.traverse(_.traverse(_.completeRevoked(Chunk.empty))) >>
.traverse(_.values.toList.traverse(_.completeRevoked(Chunk.empty))) >>
ref
.updateAndGet(_.withoutFetches(withoutRecords))
.log(RevokedFetchesWithoutRecords(withoutRecords, _))
Expand Down Expand Up @@ -475,7 +481,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
state.fetches.filterKeysStrictList(canBeCompleted).traverse {
case (partition, fetches) =>
val records = Chunk.buffer(allRecords(partition))
fetches.traverse(_.completeRecords(records))
fetches.values.toList.traverse(_.completeRecords(records))
} >> ref
.updateAndGet(_.withoutFetchesAndRecords(canBeCompleted))
.log(CompletedFetchesWithRecords(allRecords.filterKeysStrict(canBeCompleted), _))
Expand Down Expand Up @@ -512,7 +518,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
case Request.Poll() => poll
case Request.SubscribeTopics(topics, deferred) => subscribe(topics, deferred)
case Request.SubscribePattern(pattern, deferred) => subscribe(pattern, deferred)
case Request.Fetch(partition, deferred) => fetch(partition, deferred)
case Request.Fetch(partition, streamId, deferred) => fetch(partition, streamId, deferred)
case Request.Commit(offsets, deferred) => commit(offsets, deferred)
case Request.Seek(partition, offset, deferred) => seek(partition, offset, deferred)
case Request.SeekToBeginning(partitions, deferred) => seekToBeginning(partitions, deferred)
Expand All @@ -535,8 +541,10 @@ private[kafka] object KafkaConsumerActor {
"FetchRequest$" + System.identityHashCode(this)
}

/** Identifier of a consumer stream; used as the inner key of `State.fetches`. */
type StreamId = Int

final case class State[F[_], K, V](
fetches: Map[TopicPartition, NonEmptyChain[FetchRequest[F, K, V]]],
fetches: Map[TopicPartition, Map[StreamId, FetchRequest[F, K, V]]],
records: Map[TopicPartition, ArrayBuffer[CommittableMessage[F, K, V]]],
onRebalances: Chain[OnRebalance[F, K, V]],
subscribed: Boolean,
Expand All @@ -545,12 +553,29 @@ private[kafka] object KafkaConsumerActor {
def withOnRebalance(onRebalance: OnRebalance[F, K, V]): State[F, K, V] =
copy(onRebalances = onRebalances append onRebalance)

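/**
 * Registers `deferred` as the pending fetch for `partition` in the stream
 * identified by `streamId`, replacing any fetch already stored under that
 * partition and stream id.
 *
 * @return the updated state, paired with the replaced fetch (if any) so
 *         that the caller can revoke it
 */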
def withFetch(
partition: TopicPartition,
streamId: Int,
deferred: Deferred[F, (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)]
): State[F, K, V] = {
val fetch = NonEmptyChain.one(FetchRequest(deferred))
copy(fetches = fetches combine Map(partition -> fetch))
): (State[F, K, V], Option[FetchRequest[F, K, V]]) = {
val oldPartitionFetches =
fetches.get(partition)

val oldPartitionFetch =
oldPartitionFetches.flatMap(_.get(streamId))

val newPartitionFetches =
oldPartitionFetches
.getOrElse(Map.empty)
.updated(streamId, FetchRequest(deferred))

val newFetches =
fetches.updated(partition, newPartitionFetches)

(copy(fetches = newFetches), oldPartitionFetch)
}

def withoutFetches(partitions: Set[TopicPartition]): State[F, K, V] =
Expand Down Expand Up @@ -648,6 +673,7 @@ private[kafka] object KafkaConsumerActor {

/**
 * Request to fetch records for `partition` on behalf of the stream
 * identified by `streamId`.
 *
 * @param partition the topic-partition to fetch records for
 * @param streamId  id of the stream issuing the request
 * @param deferred  completed with the fetched chunk together with the
 *                  reason the fetch was completed
 */
final case class Fetch[F[_], K, V](
partition: TopicPartition,
streamId: Int,
deferred: Deferred[F, (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)]
) extends Request[F, K, V]

Expand Down
9 changes: 9 additions & 0 deletions src/main/scala/fs2/kafka/internal/LogEntry.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ private[kafka] object LogEntry {
s"Stored records for partitions [${recordsString(records)}]. Current state [$state]."
}

/**
 * Log entry emitted when a fetch that was still pending for `partition`
 * in the stream `streamId` is revoked because a newer fetch for the same
 * partition and stream has been stored.
 *
 * @param partition the topic-partition whose previous fetch was revoked
 * @param streamId  id of the stream that owned the previous fetch
 */
final case class RevokedPreviousFetch(partition: TopicPartition, streamId: Int)
    extends LogEntry {

  // Logged as a warning rather than at a lower level.
  override def level: LogLevel = Warn

  override def message: String =
    s"Revoked previous fetch for partition [$partition] in stream with id [$streamId]."
}

def recordsString[F[_], K, V](
records: Records[F, K, V]
): String =
Expand Down
Loading

0 comments on commit 2c08d2c

Please sign in to comment.