Convert commitAsync callback handling to ZIO sooner (#1404)
KafkaConsumer's `commitAsync` takes a callback, which we currently program against
with complicated follow-up code. This PR converts everything to ZIO effects
earlier on, making follow-up effects easier to chain and reason about.
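
For illustration only (this is not the code in this PR), a minimal sketch of the idea: wrapping the callback-based `commitAsync` in `ZIO.async` so the result can be chained with ordinary ZIO combinators. All names here are hypothetical.

```scala
import java.util.{ Map => JavaMap }
import org.apache.kafka.clients.consumer.{ KafkaConsumer, OffsetAndMetadata, OffsetCommitCallback }
import org.apache.kafka.common.TopicPartition
import scala.jdk.CollectionConverters._
import zio._

// Hypothetical helper, not part of zio-kafka: turn one commitAsync call into a Task.
// The caller must hold exclusive access to `consumer`.
def commitAsZIO(
  consumer: KafkaConsumer[Array[Byte], Array[Byte]],
  offsets: Map[TopicPartition, OffsetAndMetadata]
): Task[Unit] =
  ZIO.async { (resume: Task[Unit] => Unit) =>
    consumer.commitAsync(
      offsets.asJava,
      new OffsetCommitCallback {
        override def onComplete(
          committed: JavaMap[TopicPartition, OffsetAndMetadata],
          error: Exception
        ): Unit =
          // a null error means the commit succeeded
          if (error == null) resume(ZIO.unit) else resume(ZIO.fail(error))
      }
    )
  }
```

Note that the Kafka client only invokes the callback from a later `poll()` (or on close), so an effect like the one above only completes while the poll loop keeps running; the code in this PR instead queues commits and completes per-commit promises from within the run loop.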

As this changes some functionality around locking and same-/single-thread
execution, here's a summary of what we need to ensure:
* We have an exclusive lock on the consumer when calling `commitAsync`.
In `Runloop.run` this is done using `ConsumerAccess`. In the rebalance
coordinator (while rebalancing) we already hold the lock because we are
inside a `poll()` call, so no extra locking is needed.
* The consumer is not used from more than one thread at the same time.
In `Runloop.run` we get this for free by guaranteeing exclusive access.
In the rebalance coordinator a `poll()` call is in the middle of being
executed, so we must call `commitAsync` on the same thread on which the
rebalance listener is invoked (see the second sketch below).

Anything that does not call `commitAsync` is free to run on any thread, as
scheduled by the default ZIO runtime.
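
The same-thread requirement can be illustrated with another hypothetical sketch (again, not the PR's exact code): inside the rebalance listener a ZIO effect is run synchronously on the current Kafka thread via the unsafe runtime API. This stays on that thread only as long as the effect does not fork or shift to another executor.

```scala
import zio._

// Hypothetical helper: run `commitEffect` synchronously on the calling thread
// (the thread that invoked the rebalance listener inside poll()).
// Assumes the effect does not fork or shift to another executor.
def runOnCallingThread(runtime: Runtime[Any])(commitEffect: Task[Unit]): Unit =
  Unsafe.unsafe { implicit unsafe =>
    runtime.unsafe.run(commitEffect).getOrThrowFiberFailure()
  }
```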

---------

Co-authored-by: Erik van Oosten <e.vanoosten@grons.nl>
svroonland and erikvanoosten authored Dec 7, 2024
1 parent e689977 commit 4a0176f
Showing 7 changed files with 198 additions and 169 deletions.
@@ -1591,6 +1591,6 @@ object ConsumerSpec extends ZIOSpecDefaultSlf4j with KafkaRandom {
.provideSome[Scope & Kafka](producer)
.provideSomeShared[Scope](
Kafka.embedded
) @@ withLiveClock @@ timeout(2.minutes)
) @@ withLiveClock @@ timeout(2.minutes) @@ TestAspect.timed

}
@@ -1,28 +1,26 @@
package zio.kafka.consumer.internal

import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.clients.consumer.{ MockConsumer, OffsetAndMetadata, OffsetCommitCallback, OffsetResetStrategy }
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.errors.RebalanceInProgressException
import zio.kafka.consumer.diagnostics.Diagnostics
import zio.test._
import zio.{ durationInt, Promise, Ref, ZIO }
import zio.{ durationInt, Promise, Queue, Ref, Task, Unsafe, ZIO }

import java.util.{ Map => JavaMap }
import scala.jdk.CollectionConverters.MapHasAsJava
import scala.jdk.CollectionConverters.{ MapHasAsJava, MapHasAsScala }

object CommitterSpec extends ZIOSpecDefault {
override def spec = suite("Committer")(
test("signals that a new commit is available") {
for {
runtime <- ZIO.runtime[Any]
commitAvailable <- Promise.make[Nothing, Unit]
committer <- LiveCommitter
.make(
10.seconds,
Diagnostics.NoOp,
new DummyMetrics,
onCommitAvailable = commitAvailable.succeed(()).unit,
sameThreadRuntime = runtime
onCommitAvailable = commitAvailable.succeed(()).unit
)
tp = new TopicPartition("topic", 0)
_ <- committer.commit(Map(tp -> new OffsetAndMetadata(0))).forkScoped
@@ -31,98 +29,97 @@ object CommitterSpec extends ZIOSpecDefault {
},
test("handles a successful commit by completing the commit effect") {
for {
runtime <- ZIO.runtime[Any]
commitAvailable <- Promise.make[Nothing, Unit]
committer <- LiveCommitter.make(
10.seconds,
Diagnostics.NoOp,
new DummyMetrics,
onCommitAvailable = commitAvailable.succeed(()).unit,
sameThreadRuntime = runtime
onCommitAvailable = commitAvailable.succeed(()).unit
)
tp = new TopicPartition("topic", 0)
commitFiber <- committer.commit(Map(tp -> new OffsetAndMetadata(0))).forkScoped
_ <- commitAvailable.await
_ <- committer.processQueuedCommits((offsets, callback) => ZIO.attempt(callback.onComplete(offsets, null)))
_ <- commitFiber.join
consumer <- createMockConsumer(offsets => ZIO.succeed(offsets))
_ <- committer.processQueuedCommits(consumer)
_ <- commitFiber.join
} yield assertCompletes
},
test("handles a failed commit by completing the commit effect with a failure") {
for {
runtime <- ZIO.runtime[Any]
commitAvailable <- Promise.make[Nothing, Unit]
committer <- LiveCommitter.make(
10.seconds,
Diagnostics.NoOp,
new DummyMetrics,
onCommitAvailable = commitAvailable.succeed(()).unit,
sameThreadRuntime = runtime
onCommitAvailable = commitAvailable.succeed(()).unit
)
tp = new TopicPartition("topic", 0)
commitFiber <- committer.commit(Map(tp -> new OffsetAndMetadata(0))).forkScoped
_ <- commitAvailable.await
_ <- committer.processQueuedCommits((offsets, callback) =>
ZIO.attempt(callback.onComplete(offsets, new RuntimeException("Commit failed")))
)
result <- commitFiber.await
consumer <- createMockConsumer(_ => ZIO.fail(new RuntimeException("Commit failed")))
_ <- committer.processQueuedCommits(consumer)
result <- commitFiber.await
} yield assertTrue(result.isFailure)
},
test("retries when rebalancing") {
for {
runtime <- ZIO.runtime[Any]
commitAvailable <- Promise.make[Nothing, Unit]
commitAvailable <- Queue.bounded[Unit](1)
committer <- LiveCommitter.make(
10.seconds,
Diagnostics.NoOp,
new DummyMetrics,
onCommitAvailable = commitAvailable.succeed(()).unit,
sameThreadRuntime = runtime
onCommitAvailable = commitAvailable.offer(()).unit
)
tp = new TopicPartition("topic", 0)
commitFiber <- committer.commit(Map(tp -> new OffsetAndMetadata(0))).forkScoped
_ <- commitAvailable.await
_ <- committer.processQueuedCommits((offsets, callback) =>
ZIO.attempt(callback.onComplete(offsets, new RebalanceInProgressException("Rebalance in progress")))
)
_ <- committer.processQueuedCommits((offsets, callback) => ZIO.attempt(callback.onComplete(offsets, null)))
result <- commitFiber.await
} yield assertTrue(result.isSuccess)
_ <- commitAvailable.take
callCount <- Ref.make(0)
consumer <-
createMockConsumer { offsets =>
callCount.updateAndGet(_ + 1).flatMap { count =>
if (count == 1) {
ZIO.fail(new RebalanceInProgressException("Rebalance in progress"))
} else {
ZIO.succeed(offsets)
}
}
}
_ <- committer.processQueuedCommits(consumer)
_ <- commitAvailable.take
_ <- committer.processQueuedCommits(consumer)
_ <- commitFiber.join
} yield assertCompletes
},
test("adds 1 to the committed last offset") {
for {
runtime <- ZIO.runtime[Any]
commitAvailable <- Promise.make[Nothing, Unit]
committer <- LiveCommitter.make(
10.seconds,
Diagnostics.NoOp,
new DummyMetrics,
onCommitAvailable = commitAvailable.succeed(()).unit,
sameThreadRuntime = runtime
onCommitAvailable = commitAvailable.succeed(()).unit
)
tp = new TopicPartition("topic", 0)
_ <- committer.commit(Map(tp -> new OffsetAndMetadata(1))).forkScoped
_ <- commitAvailable.await
committedOffsets <- Promise.make[Nothing, JavaMap[TopicPartition, OffsetAndMetadata]]
_ <- committer.processQueuedCommits((offsets, callback) =>
committedOffsets.succeed(offsets) *> ZIO.attempt(callback.onComplete(offsets, null))
)
consumer <- createMockConsumer(offsets => committedOffsets.succeed(offsets.asJava).as(offsets))
_ <- committer.processQueuedCommits(consumer)
offsetsCommitted <- committedOffsets.await
} yield assertTrue(
offsetsCommitted == Map(tp -> new OffsetAndMetadata(2)).asJava
)
},
test("batches commits from multiple partitions and offsets") {
for {
runtime <- ZIO.runtime[Any]
commitsAvailable <- Promise.make[Nothing, Unit]
nrCommitsDone <- Ref.make(0)
committer <- LiveCommitter.make(
10.seconds,
Diagnostics.NoOp,
new DummyMetrics,
onCommitAvailable =
ZIO.whenZIO(nrCommitsDone.updateAndGet(_ + 1).map(_ == 3))(commitsAvailable.succeed(())).unit,
sameThreadRuntime = runtime
ZIO.whenZIO(nrCommitsDone.updateAndGet(_ + 1).map(_ == 3))(commitsAvailable.succeed(())).unit
)
tp = new TopicPartition("topic", 0)
tp2 = new TopicPartition("topic", 1)
@@ -131,9 +128,8 @@ object CommitterSpec extends ZIOSpecDefault {
commitFiber3 <- committer.commit(Map(tp2 -> new OffsetAndMetadata(3))).forkScoped
_ <- commitsAvailable.await
committedOffsets <- Promise.make[Nothing, JavaMap[TopicPartition, OffsetAndMetadata]]
_ <- committer.processQueuedCommits((offsets, callback) =>
committedOffsets.succeed(offsets) *> ZIO.attempt(callback.onComplete(offsets, null))
)
consumer <- createMockConsumer(offsets => committedOffsets.succeed(offsets.asJava).as(offsets))
_ <- committer.processQueuedCommits(consumer)
_ <- commitFiber1.join zip commitFiber2.join zip commitFiber3.join
offsetsCommitted <- committedOffsets.await
} yield assertTrue(
@@ -142,63 +138,85 @@ object CommitterSpec extends ZIOSpecDefault {
},
test("keeps track of pending commits") {
for {
runtime <- ZIO.runtime[Any]
commitAvailable <- Promise.make[Nothing, Unit]
committer <- LiveCommitter.make(
10.seconds,
Diagnostics.NoOp,
new DummyMetrics,
onCommitAvailable = commitAvailable.succeed(()).unit,
sameThreadRuntime = runtime
onCommitAvailable = commitAvailable.succeed(()).unit
)
tp = new TopicPartition("topic", 0)
commitFiber <- committer.commit(Map(tp -> new OffsetAndMetadata(0))).forkScoped
_ <- commitAvailable.await
_ <- committer.processQueuedCommits((offsets, callback) => ZIO.attempt(callback.onComplete(offsets, null)))
commitFiber <- committer.commit(Map(tp -> new OffsetAndMetadata(0))).forkScoped
_ <- commitAvailable.await
consumer <- createMockConsumer(offsets => ZIO.succeed(offsets))
_ <- committer.processQueuedCommits(consumer)
pendingCommitsDuringCommit <- committer.pendingCommitCount
_ <- commitFiber.join
_ <- committer.cleanupPendingCommits
pendingCommitsAfterCommit <- committer.pendingCommitCount
_ <- commitFiber.join
} yield assertTrue(pendingCommitsDuringCommit == 1 && pendingCommitsAfterCommit == 0)
},
test("keep track of committed offsets") {
for {
runtime <- ZIO.runtime[Any]
commitAvailable <- Promise.make[Nothing, Unit]
committer <- LiveCommitter.make(
10.seconds,
Diagnostics.NoOp,
new DummyMetrics,
onCommitAvailable = commitAvailable.succeed(()).unit,
sameThreadRuntime = runtime
onCommitAvailable = commitAvailable.succeed(()).unit
)
tp = new TopicPartition("topic", 0)
commitFiber <- committer.commit(Map(tp -> new OffsetAndMetadata(0))).forkScoped
_ <- commitAvailable.await
_ <- committer.processQueuedCommits((offsets, callback) => ZIO.attempt(callback.onComplete(offsets, null)))
committedOffsets <- committer.getCommittedOffsets
commitFiber <- committer.commit(Map(tp -> new OffsetAndMetadata(0))).forkScoped
_ <- commitAvailable.await
consumer <- createMockConsumer(offsets => ZIO.succeed(offsets))
_ <- committer.processQueuedCommits(consumer)
_ <- commitFiber.join
committedOffsets <- committer.getCommittedOffsets
} yield assertTrue(committedOffsets.offsets == Map(tp -> 0L))
},
test("clean committed offsets of no-longer assigned partitions") {
for {
runtime <- ZIO.runtime[Any]
commitAvailable <- Promise.make[Nothing, Unit]
committer <- LiveCommitter.make(
10.seconds,
Diagnostics.NoOp,
new DummyMetrics,
onCommitAvailable = commitAvailable.succeed(()).unit,
sameThreadRuntime = runtime
onCommitAvailable = commitAvailable.succeed(()).unit
)
tp = new TopicPartition("topic", 0)
commitFiber <- committer.commit(Map(tp -> new OffsetAndMetadata(0))).forkScoped
_ <- commitAvailable.await
_ <- committer.processQueuedCommits((offsets, callback) => ZIO.attempt(callback.onComplete(offsets, null)))
_ <- committer.keepCommitsForPartitions(Set.empty)
committedOffsets <- committer.getCommittedOffsets
commitFiber <- committer.commit(Map(tp -> new OffsetAndMetadata(0))).forkScoped
_ <- commitAvailable.await
consumer <- createMockConsumer(offsets => ZIO.succeed(offsets))
_ <- committer.processQueuedCommits(consumer)
_ <- commitFiber.join
_ <- committer.keepCommitsForPartitions(Set.empty)
committedOffsets <- committer.getCommittedOffsets
} yield assertTrue(committedOffsets.offsets.isEmpty)
}
) @@ TestAspect.withLiveClock @@ TestAspect.nonFlaky(100)

private def createMockConsumer(
onCommitAsync: Map[TopicPartition, OffsetAndMetadata] => Task[Map[TopicPartition, OffsetAndMetadata]]
): ZIO[Any, Nothing, MockConsumer[Array[Byte], Array[Byte]]] =
ZIO.runtime[Any].map { runtime =>
new MockConsumer[Array[Byte], Array[Byte]](OffsetResetStrategy.LATEST) {
override def commitAsync(
offsets: JavaMap[TopicPartition, OffsetAndMetadata],
callback: OffsetCommitCallback
): Unit =
Unsafe.unsafe { implicit unsafe =>
runtime.unsafe
.run(
onCommitAsync(offsets.asScala.toMap)
.tapBoth(
e => ZIO.attempt(callback.onComplete(offsets, e.asInstanceOf[Exception])),
offsets => ZIO.attempt(callback.onComplete(offsets.asJava, null))
)
.ignore
)
.getOrThrowFiberFailure()
}

}
}
}
@@ -5,6 +5,7 @@ import org.apache.kafka.common.TopicPartition
import zio.kafka.ZIOSpecDefaultSlf4j
import zio.kafka.consumer.diagnostics.Diagnostics
import zio.kafka.consumer.internal.Committer.CommitOffsets
import zio.kafka.consumer.internal.ConsumerAccess.ByteArrayKafkaConsumer
import zio.kafka.consumer.internal.LiveCommitter.Commit
import zio.kafka.consumer.internal.RebalanceCoordinator.RebalanceEvent
import zio.kafka.consumer.internal.Runloop.ByteArrayCommittableRecord
@@ -96,8 +97,7 @@ object RebalanceCoordinatorSpec extends ZIOSpecDefaultSlf4j {
records = createTestRecords(3)
recordsPulled <- Promise.make[Nothing, Unit]
_ <- streamControl.offerRecords(records)
runtime <- ZIO.runtime[Any]
committer <- LiveCommitter.make(10.seconds, Diagnostics.NoOp, mockMetrics, ZIO.unit, runtime)
committer <- LiveCommitter.make(10.seconds, Diagnostics.NoOp, mockMetrics, ZIO.unit)

streamDrain <-
streamControl.stream
@@ -173,16 +173,19 @@
settings: ConsumerSettings = ConsumerSettings(List("")).withCommitTimeout(1.second),
rebalanceSafeCommits: Boolean = false
): ZIO[Scope, Throwable, RebalanceCoordinator] =
Semaphore.make(1).map(new ConsumerAccess(mockConsumer, _)).map { consumerAccess =>
new RebalanceCoordinator(
lastEvent,
settings.withRebalanceSafeCommits(rebalanceSafeCommits),
consumerAccess,
5.seconds,
ZIO.succeed(assignedStreams),
committer
)
}
Semaphore
.make(1)
.map(new ConsumerAccess(mockConsumer, _))
.map { consumerAccess =>
new RebalanceCoordinator(
lastEvent,
settings.withRebalanceSafeCommits(rebalanceSafeCommits),
consumerAccess,
5.seconds,
ZIO.succeed(assignedStreams),
committer
)
}

private def createTestRecords(count: Int): Chunk[ByteArrayCommittableRecord] =
Chunk.fromIterable(
@@ -205,10 +208,8 @@
abstract class MockCommitter extends Committer {
override val commit: Map[TopicPartition, OffsetAndMetadata] => Task[Unit] = _ => ZIO.unit

override def processQueuedCommits(
commitAsync: (java.util.Map[TopicPartition, OffsetAndMetadata], OffsetCommitCallback) => zio.Task[Unit],
executeOnEmpty: Boolean
): zio.Task[Unit] = ZIO.unit
override def processQueuedCommits(consumer: ByteArrayKafkaConsumer, executeOnEmpty: Boolean): Task[Unit] = ZIO.unit

override def queueSize: UIO[Int] = ZIO.succeed(0)
override def pendingCommitCount: UIO[Int] = ZIO.succeed(0)
override def getPendingCommits: UIO[CommitOffsets] = ZIO.succeed(CommitOffsets.empty)
@@ -1,13 +1,13 @@
package zio.kafka.consumer.internal

import org.apache.kafka.clients.consumer.{ OffsetAndMetadata, OffsetCommitCallback }
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.common.TopicPartition
import zio.kafka.consumer.internal.Committer.CommitOffsets
import zio.kafka.consumer.internal.ConsumerAccess.ByteArrayKafkaConsumer
import zio.kafka.consumer.internal.LiveCommitter.Commit
import zio.{ Chunk, Task, UIO }

import java.lang.Math.max
import java.util.{ Map => JavaMap }
import scala.collection.mutable

private[internal] trait Committer {
@@ -21,14 +21,13 @@ private[internal] trait Committer {
* WARNING: this method is used during a rebalance from the same-thread-runtime. This restricts what ZIO operations
* may be used. Please see [[RebalanceCoordinator]] for more information.
*
* @param commitAsync
* Function 'commitAsync' on the KafkaConsumer. This is isolated from the whole KafkaConsumer for testing purposes.
* The caller should ensure exclusive access to the KafkaConsumer.
* @param consumer
KafkaConsumer to use. The caller is responsible for guaranteeing exclusive access.
* @param executeOnEmpty
* Execute commitAsync() even if there are no commits
*/
def processQueuedCommits(
commitAsync: (JavaMap[TopicPartition, OffsetAndMetadata], OffsetCommitCallback) => Task[Unit],
consumer: ByteArrayKafkaConsumer,
executeOnEmpty: Boolean = false
): Task[Unit]

