Skip to content

Commit

Permalink
Rebase on master; prevent requests after shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
Viktor Lövgren committed Apr 1, 2019
1 parent f6da986 commit 1d73f2c
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 53 deletions.
77 changes: 44 additions & 33 deletions src/main/scala/fs2/kafka/KafkaConsumer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

package fs2.kafka

import java.util.concurrent.atomic.AtomicInteger

import cats.{Foldable, Reducible}
import cats.data.{NonEmptyList, NonEmptySet}
import cats.effect._
Expand Down Expand Up @@ -342,6 +340,7 @@ private[kafka] object KafkaConsumer {
settings: ConsumerSettings[K, V],
actor: Fiber[F, Unit],
polls: Fiber[F, Unit],
streamIdRef: Ref[F, Int],
id: Int
)(implicit F: Concurrent[F]): KafkaConsumer[F, K, V] =
new KafkaConsumer[F, K, V] {
Expand All @@ -361,36 +360,36 @@ private[kafka] object KafkaConsumer {
actorFiber combine pollsFiber
}

private val streamIdCounter = new AtomicInteger(0)

override def partitionedStream: Stream[F, Stream[F, CommittableMessage[F, K, V]]] = {
val streamId = streamIdCounter.getAndIncrement()
val chunkQueue: F[Queue[F, Option[Chunk[CommittableMessage[F, K, V]]]]] =
Queue.bounded(settings.maxPrefetchBatches - 1)

type PartitionRequest =
(Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)

def enqueueStream(
streamId: Int,
partition: TopicPartition,
partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]]
): F[Unit] =
chunkQueue.flatMap { chunks =>
Deferred[F, Unit].flatMap { dequeueDone =>
Deferred.tryable[F, Unit].flatMap { partitionRevoked =>
Deferred.tryable[F, Unit].flatMap { stopRequests =>
val shutdown = F.race(fiber.join.attempt, dequeueDone.get).void
partitions.enqueue1 {
Stream.eval {
F.guarantee {
Stream
.repeatEval {
partitionRevoked.tryGet.flatMap {
stopRequests.tryGet.flatMap {
case None =>
Deferred[F, PartitionRequest].flatMap { deferred =>
val request = Request.Fetch(partition, streamId, deferred)
val fetch = requests.enqueue1(request) >> deferred.get
F.race(shutdown, fetch).flatMap {
case Left(()) => F.unit
case Left(()) =>
stopRequests.complete(())

case Right((chunk, reason)) =>
val enqueueChunk =
if (chunk.nonEmpty)
Expand All @@ -399,7 +398,7 @@ private[kafka] object KafkaConsumer {

val completeRevoked =
if (reason.topicPartitionRevoked)
partitionRevoked.complete(())
stopRequests.complete(())
else F.unit

enqueueChunk >> completeRevoked
Expand All @@ -408,11 +407,12 @@ private[kafka] object KafkaConsumer {

case Some(()) =>
// Prevent issuing additional requests after partition is
// revoked, in case stream interruption isn't fast enough
// revoked or shutdown happens, in case the stream isn't
// interrupted fast enough
F.unit
}
}
.interruptWhen(F.race(shutdown, partitionRevoked.get).void.attempt)
.interruptWhen(F.race(shutdown, stopRequests.get).void.attempt)
.compile
.drain
}(F.race(dequeueDone.get, chunks.enqueue1(None)).void)
Expand All @@ -430,22 +430,26 @@ private[kafka] object KafkaConsumer {
}

def enqueueStreams(
streamId: Int,
assigned: NonEmptySet[TopicPartition],
partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]]
): F[Unit] = assigned.foldLeft(F.unit)(_ >> enqueueStream(_, partitions))
): F[Unit] = assigned.foldLeft(F.unit)(_ >> enqueueStream(streamId, _, partitions))

def onRebalance(
streamId: Int,
partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]]
): OnRebalance[F, K, V] = OnRebalance(
onAssigned = assigned => enqueueStreams(assigned, partitions),
onAssigned = assigned => enqueueStreams(streamId, assigned, partitions),
onRevoked = _ => F.unit
)

def requestAssignment(
streamId: Int,
partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]]
): F[SortedSet[TopicPartition]] = {
Deferred[F, Either[Throwable, SortedSet[TopicPartition]]].flatMap { deferred =>
val request = Request.Assignment[F, K, V](deferred, Some(onRebalance(partitions)))
val request =
Request.Assignment[F, K, V](deferred, Some(onRebalance(streamId, partitions)))
val assignment = requests.enqueue1(request) >> deferred.get.rethrow
F.race(fiber.join.attempt, assignment).map {
case Left(_) => SortedSet.empty[TopicPartition]
Expand All @@ -454,20 +458,25 @@ private[kafka] object KafkaConsumer {
}
}

def initialEnqueue(partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]]): F[Unit] =
requestAssignment(partitions).flatMap { assigned =>
def initialEnqueue(
streamId: Int,
partitions: Queue[F, Stream[F, CommittableMessage[F, K, V]]]
): F[Unit] =
requestAssignment(streamId, partitions).flatMap { assigned =>
if (assigned.nonEmpty) {
val nonEmpty = NonEmptySet.fromSetUnsafe(assigned)
enqueueStreams(nonEmpty, partitions)
enqueueStreams(streamId, nonEmpty, partitions)
} else F.unit
}

val partitionQueue: F[Queue[F, Stream[F, CommittableMessage[F, K, V]]]] =
Queue.unbounded[F, Stream[F, CommittableMessage[F, K, V]]]

Stream.eval(partitionQueue).flatMap { partitions =>
Stream.eval(initialEnqueue(partitions)) >>
partitions.dequeue.interruptWhen(fiber.join.attempt)
Stream.eval(streamIdRef.modify(n => (n + 1, n))).flatMap { streamId =>
Stream.eval(initialEnqueue(streamId, partitions)) >>
partitions.dequeue.interruptWhen(fiber.join.attempt)
}
}
}

Expand Down Expand Up @@ -628,20 +637,22 @@ private[kafka] object KafkaConsumer {
Resource.liftF(Jitter.default[F]).flatMap { implicit jitter =>
Resource.liftF(F.delay(new Object().hashCode)).flatMap { id =>
Resource.liftF(Logging.default[F](id)).flatMap { implicit logging =>
executionContextResource(settings).flatMap { executionContext =>
createConsumer(settings, executionContext).flatMap { synchronized =>
val actor =
new KafkaConsumerActor(
settings = settings,
executionContext = executionContext,
ref = ref,
requests = requests,
synchronized = synchronized
)

startConsumerActor(requests, polls, actor).flatMap { actor =>
startPollScheduler(polls, settings.pollInterval).map { polls =>
createKafkaConsumer(requests, settings, actor, polls, id)
Resource.liftF(Ref.of[F, Int](0)).flatMap { streamId =>
executionContextResource(settings).flatMap { executionContext =>
createConsumer(settings, executionContext).flatMap { synchronized =>
val actor =
new KafkaConsumerActor(
settings = settings,
executionContext = executionContext,
ref = ref,
requests = requests,
synchronized = synchronized
)

startConsumerActor(requests, polls, actor).flatMap { actor =>
startPollScheduler(polls, settings.pollInterval).map { polls =>
createKafkaConsumer(requests, settings, actor, polls, streamId, id)
}
}
}
}
Expand Down
40 changes: 22 additions & 18 deletions src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,7 @@ import java.util.regex.Pattern
import cats.data.{Chain, NonEmptyList, NonEmptySet}
import cats.effect.concurrent.{Deferred, Ref}
import cats.effect.{ConcurrentEffect, ContextShift, IO, Timer}
import cats.instances.list._
import cats.instances.map._
import cats.syntax.applicativeError._
import cats.syntax.flatMap._
import cats.syntax.monadError._
import cats.syntax.semigroup._
import cats.syntax.traverse._
import cats.implicits._
import fs2.Chunk
import fs2.concurrent.Queue
import fs2.kafka._
Expand Down Expand Up @@ -267,16 +261,18 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
def storeFetch =
ref
.modify { state =>
val (newState, oldFetch) = state.withFetch(partition, streamId, deferred)
val (newState, oldFetch) =
state.withFetch(partition, streamId, deferred)
(newState, (newState, oldFetch))
}
.flatMap {
case (newState, None) => F.pure(newState)
case (newState, Some(oldFetch)) =>
// TODO: Log warning
oldFetch.completeRevoked(Chunk.empty).as(newState)
case (newState, oldFetch) =>
log(StoredFetch(partition, deferred, newState)) >>
oldFetch.fold(F.unit) { fetch =>
fetch.completeRevoked(Chunk.empty) >>
log(RevokedPreviousFetch(partition, streamId))
}
}
.log(StoredFetch(partition, deferred, _))

def completeRevoked =
deferred.complete((Chunk.empty, FetchCompletedReason.TopicPartitionRevoked))
Expand Down Expand Up @@ -561,13 +557,21 @@ private[kafka] object KafkaConsumerActor {
streamId: Int,
deferred: Deferred[F, (Chunk[CommittableMessage[F, K, V]], FetchCompletedReason)]
): (State[F, K, V], Option[FetchRequest[F, K, V]]) = {
val oldPartitionFetches =
fetches.get(partition)

val oldPartitionFetch =
oldPartitionFetches.flatMap(_.get(streamId))

val newPartitionFetches =
oldPartitionFetches
.getOrElse(Map.empty)
.updated(streamId, FetchRequest(deferred))

val maybeFetchesForPartition = fetches.get(partition)
val existing = maybeFetchesForPartition.flatMap(_.get(streamId))
val newFetches =
fetches.updated(partition, newPartitionFetches)

val fetchesForPartition =
maybeFetchesForPartition.getOrElse(Map.empty).updated(streamId, FetchRequest(deferred))
(copy(fetches = fetches.updated(partition, fetchesForPartition)), existing)
(copy(fetches = newFetches), oldPartitionFetch)
}

def withoutFetches(partitions: Set[TopicPartition]): State[F, K, V] =
Expand Down
9 changes: 9 additions & 0 deletions src/main/scala/fs2/kafka/internal/LogEntry.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ private[kafka] object LogEntry {
s"Stored records for partitions [${recordsString(records)}]. Current state [$state]."
}

final case class RevokedPreviousFetch(
partition: TopicPartition,
streamId: Int
) extends LogEntry {
override def level: LogLevel = Warn
override def message: String =
s"Revoked previous fetch for partition [$partition] in stream with id [$streamId]."
}

def recordsString[F[_], K, V](
records: Records[F, K, V]
): String =
Expand Down
3 changes: 1 addition & 2 deletions src/test/scala/fs2/kafka/KafkaConsumerSpec.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package fs2.kafka


import cats.effect.IO
import cats.effect.concurrent.Ref
import cats.implicits._
Expand Down Expand Up @@ -30,7 +29,7 @@ final class KafkaConsumerSpec extends BaseKafkaSpec {
.using(consumerSettings(config))
.evalTap(_.subscribeTo(topic))
.evalTap(consumer => IO(consumer.toString should startWith("KafkaConsumer$")).void)
.evalMap(c => IO.sleep(3.seconds) *> IO.pure(c)) // sleep a bit to trigger potential race condition with _.stream
.evalMap(IO.sleep(3.seconds).as) // sleep a bit to trigger potential race condition with _.stream
.flatMap(_.stream)
.map(message => message.record.key -> message.record.value)
.interruptAfter(10.seconds) // wait some time to catch potentially duplicated records
Expand Down

0 comments on commit 1d73f2c

Please sign in to comment.