From cad6a1fa6f8e0dc2cd89adb561711b22acf0ebdb Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Fri, 17 Aug 2018 18:56:31 +0900
Subject: [PATCH 01/13] [SPARK-25151][SS] Apply Apache Commons Pool to KafkaDataConsumer

* Fix scala style check violation
* Add commons-pool2 version property to root pom.xml
* Add eviction to the pool, which enables cleaning up idle consumers
* Fix missed spots
* New approach: pool both Kafka consumers and fetched data individually
** This approach enables applying different policies to each pool
** This approach also enables evicting consumers and fetched data for invalid topics/partitions
** This approach can handle the case where multiple tasks access the same topic partition and group id
* Address review comments from @gaborgsomogyi
* Address more review comments from @gaborgsomogyi
* Fix silly mistake
* Add sanity/edge-case tests on KafkaDataConsumer
** also add basic metrics to verify behavior of the fetched data pool
* Apply new fetched data pool metrics to unit tests in FetchedDataPoolSuite
* Fix test failures (forgot to reset TestContext)
* Address further review comments from @gaborgsomogyi
* Try best-effort to isolate environments for UTs
* Fix scalastyle
* Reflect low-hanging-fruit review comments
* Use ConfigBuilder, rename config, etc.
* Address build failure because of class rename
* Address review comment
---
 external/kafka-0-10-sql/pom.xml               |   5 +
 .../spark/sql/kafka010/FetchedDataPool.scala  | 190 +++++
 .../kafka010/InternalKafkaConsumerPool.scala  | 226 ++++++
 .../spark/sql/kafka010/KafkaBatch.scala       |   2 +-
 .../kafka010/KafkaBatchPartitionReader.scala  |  11 +-
 .../sql/kafka010/KafkaContinuousStream.scala  |   2 +-
 .../sql/kafka010/KafkaDataConsumer.scala      | 677 +++++++++---------
 .../sql/kafka010/KafkaMicroBatchStream.scala  |   7 +-
 .../spark/sql/kafka010/KafkaRelation.scala    |   2 +-
 .../spark/sql/kafka010/KafkaSource.scala      |   3 +-
 .../spark/sql/kafka010/KafkaSourceRDD.scala   |   5 +-
 .../apache/spark/sql/kafka010/package.scala   |  21 +
 .../sql/kafka010/FetchedDataPoolSuite.scala   | 337 +++++++++
 .../InternalKafkaConsumerPoolSuite.scala      | 316 ++++++++
 .../sql/kafka010/KafkaDataConsumerSuite.scala | 203 +++++-
 .../kafka010/KafkaMicroBatchSourceSuite.scala |   1 -
 pom.xml                                       |   2 +
 17 files changed, 1617 insertions(+), 393 deletions(-)
 create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/FetchedDataPool.scala
 create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPool.scala
 create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/FetchedDataPoolSuite.scala
 create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPoolSuite.scala

diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml
index 827ceb89a0c3..feba787e9901 100644
--- a/external/kafka-0-10-sql/pom.xml
+++ b/external/kafka-0-10-sql/pom.xml
@@ -72,6 +72,11 @@
       <artifactId>kafka-clients</artifactId>
       <version>${kafka.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.commons</groupId>
+      <artifactId>commons-pool2</artifactId>
+      <version>${commons-pool2.version}</version>
+    </dependency>
     <dependency>
       <groupId>org.apache.kafka</groupId>
       <artifactId>kafka_${scala.binary.version}</artifactId>
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/FetchedDataPool.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/FetchedDataPool.scala
new file mode 100644
index 000000000000..a408c27d21f8
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/FetchedDataPool.scala
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.util.concurrent.{ScheduledFuture, TimeUnit} +import java.util.concurrent.atomic.LongAdder + +import scala.collection.mutable +import scala.util.control.NonFatal + +import org.apache.kafka.clients.consumer.ConsumerRecord + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.sql.kafka010.KafkaDataConsumer.{CacheKey, UNKNOWN_OFFSET} +import org.apache.spark.util.ThreadUtils + +/** + * Provides object pool for [[FetchedData]] which is grouped by [[CacheKey]]. + * + * Along with CacheKey, it receives desired start offset to find cached FetchedData which + * may be stored from previous batch. If it can't find one to match, it will create + * a new FetchedData. + */ +private[kafka010] class FetchedDataPool extends Logging { + import FetchedDataPool._ + + private val cache: mutable.Map[CacheKey, CachedFetchedDataList] = mutable.HashMap.empty + + private val (minEvictableIdleTimeMillis, evictorThreadRunIntervalMillis): (Long, Long) = { + val conf = SparkEnv.get.conf + + val minEvictIdleTime = conf.getLong(CONFIG_NAME_MIN_EVICTABLE_IDLE_TIME_MILLIS, + DEFAULT_VALUE_MIN_EVICTABLE_IDLE_TIME_MILLIS) + + val evictorThreadInterval = conf.getLong( + CONFIG_NAME_EVICTOR_THREAD_RUN_INTERVAL_MILLIS, + DEFAULT_VALUE_EVICTOR_THREAD_RUN_INTERVAL_MILLIS) + + (minEvictIdleTime, evictorThreadInterval) + } + + private val executorService = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + "kafka-fetched-data-cache-evictor") + + private def startEvictorThread(): ScheduledFuture[_] = { + executorService.scheduleAtFixedRate(new Runnable { + override def run(): Unit = { + try { + removeIdleFetchedData() + } catch { + case NonFatal(e) => + logWarning("Exception occurred while removing idle fetched data.", e) + } + } + }, 0, evictorThreadRunIntervalMillis, TimeUnit.MILLISECONDS) + } + + private var scheduled = startEvictorThread() + + private val numCreatedFetchedData = new LongAdder() + private val numTotalElements = new LongAdder() + + def getNumCreated: Long = numCreatedFetchedData.sum() + def getNumTotal: Long = numTotalElements.sum() + + def acquire(key: CacheKey, desiredStartOffset: Long): FetchedData = synchronized { + val fetchedDataList = cache.getOrElseUpdate(key, new CachedFetchedDataList()) + + val cachedFetchedDataOption = fetchedDataList.find { p => + !p.inUse && p.getObject.nextOffsetInFetchedData == desiredStartOffset + } + + var cachedFetchedData: CachedFetchedData = null + if (cachedFetchedDataOption.isDefined) { + cachedFetchedData = cachedFetchedDataOption.get + } else { + cachedFetchedData = CachedFetchedData.empty() + fetchedDataList += cachedFetchedData + + numCreatedFetchedData.increment() + numTotalElements.increment() + } + + 
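+    // Record the acquisition time and mark this entry as in use: the idle evictor
+    // (removeIdleFetchedData) only removes entries that are not in use and whose
+    // lastReleasedTimestamp is older than minEvictableIdleTimeMillis.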
cachedFetchedData.lastAcquiredTimestamp = System.currentTimeMillis() + cachedFetchedData.inUse = true + + cachedFetchedData.getObject + } + + def invalidate(key: CacheKey): Unit = synchronized { + cache.remove(key) match { + case Some(lst) => numTotalElements.add(-1 * lst.size) + case None => + } + } + + def release(key: CacheKey, fetchedData: FetchedData): Unit = synchronized { + cache.get(key) match { + case Some(fetchedDataList) => + val cachedFetchedDataOption = fetchedDataList.find { p => + p.inUse && p.getObject == fetchedData + } + + if (cachedFetchedDataOption.isDefined) { + val cachedFetchedData = cachedFetchedDataOption.get + cachedFetchedData.inUse = false + cachedFetchedData.lastReleasedTimestamp = System.currentTimeMillis() + } + + case None => logWarning(s"No matching data in pool for $fetchedData in key $key. " + + "It might be released before, or it was not a part of pool.") + } + } + + def shutdown(): Unit = { + executorService.shutdownNow() + } + + def reset(): Unit = synchronized { + scheduled.cancel(true) + + cache.clear() + numTotalElements.reset() + numCreatedFetchedData.reset() + + scheduled = startEvictorThread() + } + + private def removeIdleFetchedData(): Unit = synchronized { + val timestamp = System.currentTimeMillis() + val maxAllowedIdleTimestamp = timestamp - minEvictableIdleTimeMillis + cache.values.foreach { p: CachedFetchedDataList => + val idles = p.filter(q => !q.inUse && q.lastReleasedTimestamp < maxAllowedIdleTimestamp) + val lstSize = p.size + idles.foreach(idle => p -= idle) + numTotalElements.add(-1 * (lstSize - p.size)) + } + } +} + +private[kafka010] object FetchedDataPool { + private[kafka010] case class CachedFetchedData(fetchedData: FetchedData) { + var lastReleasedTimestamp: Long = Long.MaxValue + var lastAcquiredTimestamp: Long = Long.MinValue + var inUse: Boolean = false + + def getObject: FetchedData = fetchedData + } + + private object CachedFetchedData { + def empty(): CachedFetchedData = { + val emptyData = FetchedData( + ju.Collections.emptyListIterator[ConsumerRecord[Array[Byte], Array[Byte]]], + UNKNOWN_OFFSET, + UNKNOWN_OFFSET) + + CachedFetchedData(emptyData) + } + } + + private[kafka010] type CachedFetchedDataList = mutable.ListBuffer[CachedFetchedData] + + val CONFIG_NAME_PREFIX = "spark.sql.kafkaFetchedDataCache." + val CONFIG_NAME_MIN_EVICTABLE_IDLE_TIME_MILLIS = CONFIG_NAME_PREFIX + + "minEvictableIdleTimeMillis" + val CONFIG_NAME_EVICTOR_THREAD_RUN_INTERVAL_MILLIS = CONFIG_NAME_PREFIX + + "evictorThreadRunIntervalMillis" + + val DEFAULT_VALUE_MIN_EVICTABLE_IDLE_TIME_MILLIS = 10 * 60 * 1000 // 10 minutes + val DEFAULT_VALUE_EVICTOR_THREAD_RUN_INTERVAL_MILLIS = 5 * 60 * 1000 // 3 minutes + + def build: FetchedDataPool = new FetchedDataPool() +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPool.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPool.scala new file mode 100644 index 000000000000..f268508a7c61 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPool.scala @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.util.concurrent.ConcurrentHashMap + +import org.apache.commons.pool2.{BaseKeyedPooledObjectFactory, PooledObject, SwallowedExceptionListener} +import org.apache.commons.pool2.impl.{DefaultEvictionPolicy, DefaultPooledObject, GenericKeyedObjectPool, GenericKeyedObjectPoolConfig} + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.sql.kafka010.InternalKafkaConsumerPool._ +import org.apache.spark.sql.kafka010.KafkaDataConsumer.CacheKey + +/** + * Provides object pool for [[InternalKafkaConsumer]] which is grouped by [[CacheKey]]. + * + * This class leverages [[GenericKeyedObjectPool]] internally, hence providing methods based on + * the class, and same contract applies: after using the borrowed object, you must either call + * returnObject() if the object is healthy to return to pool, or invalidateObject() if the object + * should be destroyed. + * + * The soft capacity of pool is determined by "spark.sql.kafkaConsumerCache.capacity" config value, + * and the pool will have reasonable default value if the value is not provided. + * (The instance will do its best effort to respect soft capacity but it can exceed when there's + * a borrowing request and there's neither free space nor idle object to clear.) + * + * This class guarantees that no caller will get pooled object once the object is borrowed and + * not yet returned, hence provide thread-safety usage of non-thread-safe [[InternalKafkaConsumer]] + * unless caller shares the object to multiple threads. + */ +private[kafka010] class InternalKafkaConsumerPool( + objectFactory: ObjectFactory, + poolConfig: PoolConfig) { + + // the class is intended to have only soft capacity + assert(poolConfig.getMaxTotal < 0) + + private lazy val pool = { + val internalPool = new GenericKeyedObjectPool[CacheKey, InternalKafkaConsumer]( + objectFactory, poolConfig) + internalPool.setSwallowedExceptionListener(CustomSwallowedExceptionListener) + internalPool + } + + /** + * Borrows [[InternalKafkaConsumer]] object from the pool. If there's no idle object for the key, + * the pool will create the [[InternalKafkaConsumer]] object. + * + * If the pool doesn't have idle object for the key and also exceeds the soft capacity, + * pool will try to clear some of idle objects. + * + * Borrowed object must be returned by either calling returnObject or invalidateObject, otherwise + * the object will be kept in pool as active object. + */ + def borrowObject(key: CacheKey, kafkaParams: ju.Map[String, Object]): InternalKafkaConsumer = { + updateKafkaParamForKey(key, kafkaParams) + + if (getTotal == poolConfig.getSoftMaxTotal()) { + pool.clearOldest() + } + + pool.borrowObject(key) + } + + /** Returns borrowed object to the pool. */ + def returnObject(consumer: InternalKafkaConsumer): Unit = { + pool.returnObject(extractCacheKey(consumer), consumer) + } + + /** Invalidates (destroy) borrowed object to the pool. 
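+   * The underlying Kafka consumer is closed via ObjectFactory.destroyObject and is not reused.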
*/ + def invalidateObject(consumer: InternalKafkaConsumer): Unit = { + pool.invalidateObject(extractCacheKey(consumer), consumer) + } + + /** Invalidates all idle consumers for the key */ + def invalidateKey(key: CacheKey): Unit = { + pool.clear(key) + } + + /** + * Closes the keyed object pool. Once the pool is closed, + * borrowObject will fail with [[IllegalStateException]], but returnObject and invalidateObject + * will continue to work, with returned objects destroyed on return. + * + * Also destroys idle instances in the pool. + */ + def close(): Unit = { + pool.close() + } + + def reset(): Unit = { + // this is the best-effort of clearing up. otherwise we should close the pool and create again + // but we don't want to make it "var" only because of tests. + pool.clear() + } + + def getNumIdle: Int = pool.getNumIdle + + def getNumIdle(key: CacheKey): Int = pool.getNumIdle(key) + + def getNumActive: Int = pool.getNumActive + + def getNumActive(key: CacheKey): Int = pool.getNumActive(key) + + def getTotal: Int = getNumIdle + getNumActive + + def getTotal(key: CacheKey): Int = getNumIdle(key) + getNumActive(key) + + private def updateKafkaParamForKey(key: CacheKey, kafkaParams: ju.Map[String, Object]): Unit = { + // We can assume that kafkaParam should not be different for same cache key, + // otherwise we can't reuse the cached object and cache key should contain kafkaParam. + // So it should be safe to put the key/value pair only when the key doesn't exist. + val oldKafkaParams = objectFactory.keyToKafkaParams.putIfAbsent(key, kafkaParams) + require(oldKafkaParams == null || kafkaParams == oldKafkaParams, "Kafka parameters for same " + + s"cache key should be equal. old parameters: $oldKafkaParams new parameters: $kafkaParams") + } + + private def extractCacheKey(consumer: InternalKafkaConsumer): CacheKey = { + new CacheKey(consumer.topicPartition, consumer.kafkaParams) + } +} + +private[kafka010] object InternalKafkaConsumerPool { + + /** + * Builds the pool for [[InternalKafkaConsumer]]. The pool instance is created per each call. + */ + def build: InternalKafkaConsumerPool = { + val objFactory = new ObjectFactory + val poolConfig = new PoolConfig + new InternalKafkaConsumerPool(objFactory, poolConfig) + } + + object CustomSwallowedExceptionListener extends SwallowedExceptionListener with Logging { + override def onSwallowException(e: Exception): Unit = { + logError(s"Error closing Kafka consumer", e) + } + } + + class PoolConfig extends GenericKeyedObjectPoolConfig[InternalKafkaConsumer] { + private var softMaxTotal = Int.MaxValue + + def getSoftMaxTotal(): Int = softMaxTotal + + init() + + def init(): Unit = { + val conf = SparkEnv.get.conf + + softMaxTotal = conf.get(CONSUMER_CACHE_CAPACITY) + + val jmxEnabled = conf.get(CONSUMER_CACHE_JMX_ENABLED) + val minEvictableIdleTimeMillis = conf.get(CONSUMER_CACHE_MIN_EVICTABLE_IDLE_TIME_MILLIS) + val evictorThreadRunIntervalMillis = conf.get( + CONSUMER_CACHE_EVICTOR_THREAD_RUN_INTERVAL_MILLIS) + + // NOTE: Below lines define the behavior, so do not modify unless you know what you are + // doing, and update the class doc accordingly if necessary when you modify. + + // 1. Set min idle objects per key to 0 to avoid creating unnecessary object. + // 2. Set max idle objects per key to 3 but set total objects per key to infinite + // which ensures borrowing per key is not restricted. + // 3. Set max total objects to infinite which ensures all objects are managed in this pool. 
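+      // (The soft capacity is not enforced by these settings: InternalKafkaConsumerPool.borrowObject()
+      // calls pool.clearOldest() to evict idle consumers once the total reaches the soft max.)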
+ setMinIdlePerKey(0) + setMaxIdlePerKey(3) + setMaxTotalPerKey(-1) + setMaxTotal(-1) + + // Set minimum evictable idle time which will be referred from evictor thread + setMinEvictableIdleTimeMillis(minEvictableIdleTimeMillis) + setSoftMinEvictableIdleTimeMillis(-1) + + // evictor thread will run test with ten idle objects + setTimeBetweenEvictionRunsMillis(evictorThreadRunIntervalMillis) + setNumTestsPerEvictionRun(10) + setEvictionPolicy(new DefaultEvictionPolicy[InternalKafkaConsumer]()) + + // Immediately fail on exhausted pool while borrowing + setBlockWhenExhausted(false) + + setJmxEnabled(jmxEnabled) + setJmxNamePrefix("kafka010-cached-simple-kafka-consumer-pool") + } + } + + class ObjectFactory extends BaseKeyedPooledObjectFactory[CacheKey, InternalKafkaConsumer] + with Logging { + + val keyToKafkaParams: ConcurrentHashMap[CacheKey, ju.Map[String, Object]] = + new ConcurrentHashMap[CacheKey, ju.Map[String, Object]]() + + override def create(key: CacheKey): InternalKafkaConsumer = { + Option(keyToKafkaParams.get(key)) match { + case Some(kafkaParams) => new InternalKafkaConsumer(key.topicPartition, kafkaParams) + case None => throw new IllegalStateException("Kafka params should be set before " + + "borrowing object.") + } + } + + override def wrap(value: InternalKafkaConsumer): PooledObject[InternalKafkaConsumer] = { + new DefaultPooledObject[InternalKafkaConsumer](value) + } + + override def destroyObject(key: CacheKey, p: PooledObject[InternalKafkaConsumer]): Unit = { + p.getObject.close() + } + } +} + diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala index 839a64ed3132..700414167f3e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala @@ -91,7 +91,7 @@ private[kafka010] class KafkaBatch( KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId) offsetRanges.map { range => new KafkaBatchInputPartition( - range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, false) + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss) }.toArray } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala index cbc2fbfce319..53b0b3c46854 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala @@ -30,14 +30,13 @@ private[kafka010] case class KafkaBatchInputPartition( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends InputPartition + failOnDataLoss: Boolean) extends InputPartition private[kafka010] object KafkaBatchReaderFactory extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val p = partition.asInstanceOf[KafkaBatchInputPartition] KafkaBatchPartitionReader(p.offsetRange, p.executorKafkaParams, p.pollTimeoutMs, - p.failOnDataLoss, p.reuseKafkaConsumer) + p.failOnDataLoss) } } @@ -46,11 +45,9 @@ private case class KafkaBatchPartitionReader( offsetRange: KafkaOffsetRange, executorKafkaParams: 
ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends PartitionReader[InternalRow] with Logging { + failOnDataLoss: Boolean) extends PartitionReader[InternalRow] with Logging { - private val consumer = KafkaDataConsumer.acquire( - offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) + private val consumer = KafkaDataConsumer.acquire(offsetRange.topicPartition, executorKafkaParams) private val rangeToRead = resolveRange(offsetRange) private val converter = new KafkaRecordToUnsafeRowConverter diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala index 18d740eaa968..a9c1181a01c5 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala @@ -185,7 +185,7 @@ class KafkaContinuousPartitionReader( kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean) extends ContinuousPartitionReader[InternalRow] { - private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) + private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams) private val converter = new KafkaRecordToUnsafeRowConverter private var nextKafkaOffset = startOffset diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index af240dc04eea..9e98ae562937 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} +import java.io.Closeable import java.util.concurrent.TimeoutException import scala.collection.JavaConverters._ @@ -25,169 +26,200 @@ import scala.collection.JavaConverters._ import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer, OffsetOutOfRangeException} import org.apache.kafka.common.TopicPartition -import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.kafka010.KafkaConfigUpdater -import org.apache.spark.sql.kafka010.KafkaDataConsumer.AvailableOffsetRange +import org.apache.spark.sql.kafka010.KafkaDataConsumer.{AvailableOffsetRange, UNKNOWN_OFFSET} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.util.UninterruptibleThread +import org.apache.spark.util.{ShutdownHookManager, UninterruptibleThread} + +/** + * This class simplifies the usages of Kafka consumer in Spark SQL Kafka connector. + * + * NOTE: Like KafkaConsumer, this class is not thread-safe. + * NOTE for contributors: It is possible for the instance to be used from multiple callers, + * so all the methods should not rely on current cursor and use seek manually. 
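+ * NOTE: Instances are pooled by [[InternalKafkaConsumerPool]]; a borrowed instance must be
+ * returned to (or invalidated in) that pool after use so it can be reused or closed.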
+ */ +private[kafka010] class InternalKafkaConsumer( + val topicPartition: TopicPartition, + val kafkaParams: ju.Map[String, Object]) extends Closeable with Logging { + + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + + private val consumer = createConsumer -private[kafka010] sealed trait KafkaDataConsumer { /** - * Get the record for the given offset if available. - * - * If the record is invisible (either a - * transaction message, or an aborted message when the consumer's `isolation.level` is - * `read_committed`), it will be skipped and this method will try to fetch next available record - * within [offset, untilOffset). + * Poll messages from Kafka starting from `offset` and returns a pair of "list of consumer record" + * and "offset after poll". The list of consumer record may be empty if the Kafka consumer fetches + * some messages but all of them are not visible messages (either transaction messages, + * or aborted messages when `isolation.level` is `read_committed`). * - * This method also will try its best to detect data loss. If `failOnDataLoss` is `true`, it will - * throw an exception when we detect an unavailable offset. If `failOnDataLoss` is `false`, this - * method will try to fetch next available record within [offset, untilOffset). - * - * When this method tries to skip offsets due to either invisible messages or data loss and - * reaches `untilOffset`, it will return `null`. - * - * @param offset the offset to fetch. - * @param untilOffset the max offset to fetch. Exclusive. - * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. - * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return record at - * offset if available, or throw exception.when `failOnDataLoss` is `false`, - * this method will either return record at offset if available, or return - * the next earliest available record less than untilOffset, or null. It - * will not throw any exception. + * @throws OffsetOutOfRangeException if `offset` is out of range. + * @throws TimeoutException if the consumer position is not changed after polling. It means the + * consumer polls nothing before timeout. */ - def get( - offset: Long, - untilOffset: Long, - pollTimeoutMs: Long, - failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = { - internalConsumer.get(offset, untilOffset, pollTimeoutMs, failOnDataLoss) + def fetch(offset: Long, pollTimeoutMs: Long): + (ju.List[ConsumerRecord[Array[Byte], Array[Byte]]], Long) = { + + // Seek to the offset because we may call seekToBeginning or seekToEnd before this. + seek(offset) + val p = consumer.poll(pollTimeoutMs) + val r = p.records(topicPartition) + logDebug(s"Polled $groupId ${p.partitions()} ${r.size}") + val offsetAfterPoll = consumer.position(topicPartition) + logDebug(s"Offset changed from $offset to $offsetAfterPoll after polling") + val fetchedData = (r, offsetAfterPoll) + if (r.isEmpty) { + // We cannot fetch anything after `poll`. Two possible cases: + // - `offset` is out of range so that Kafka returns nothing. `OffsetOutOfRangeException` will + // be thrown. + // - Cannot fetch any data before timeout. `TimeoutException` will be thrown. + // - Fetched something but all of them are not invisible. This is a valid case and let the + // caller handles this. 
+ val range = getAvailableOffsetRange() + if (offset < range.earliest || offset >= range.latest) { + throw new OffsetOutOfRangeException( + Map(topicPartition -> java.lang.Long.valueOf(offset)).asJava) + } else if (offset == offsetAfterPoll) { + throw new TimeoutException( + s"Cannot fetch record for offset $offset in $pollTimeoutMs milliseconds") + } + } + fetchedData } /** * Return the available offset range of the current partition. It's a pair of the earliest offset * and the latest offset. */ - def getAvailableOffsetRange(): AvailableOffsetRange = internalConsumer.getAvailableOffsetRange() + def getAvailableOffsetRange(): AvailableOffsetRange = { + consumer.seekToBeginning(Set(topicPartition).asJava) + val earliestOffset = consumer.position(topicPartition) + consumer.seekToEnd(Set(topicPartition).asJava) + val latestOffset = consumer.position(topicPartition) + AvailableOffsetRange(earliestOffset, latestOffset) + } - /** - * Release this consumer from being further used. Depending on its implementation, - * this consumer will be either finalized, or reset for reuse later. - */ - def release(): Unit + override def close(): Unit = { + consumer.close() + } - /** Reference to the internal implementation that this wrapper delegates to */ - def internalConsumer: InternalKafkaConsumer -} + /** Create a KafkaConsumer to fetch records for `topicPartition` */ + private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = { + val updatedKafkaParams = KafkaConfigUpdater("executor", kafkaParams.asScala.toMap) + .setAuthenticationConfigIfNeeded() + .build() + val c = new KafkaConsumer[Array[Byte], Array[Byte]](updatedKafkaParams) + val tps = new ju.ArrayList[TopicPartition]() + tps.add(topicPartition) + c.assign(tps) + c + } + private def seek(offset: Long): Unit = { + logDebug(s"Seeking to $groupId $topicPartition $offset") + consumer.seek(topicPartition, offset) + } +} /** - * A wrapper around Kafka's KafkaConsumer that throws error when data loss is detected. - * This is not for direct use outside this file. + * The internal object to store the fetched data from Kafka consumer and the next offset to poll. + * + * @param _records the pre-fetched Kafka records. + * @param _nextOffsetInFetchedData the next offset in `records`. We use this to verify if we + * should check if the pre-fetched data is still valid. + * @param _offsetAfterPoll the Kafka offset after calling `poll`. We will use this offset to + * poll when `records` is drained. */ -private[kafka010] case class InternalKafkaConsumer( - topicPartition: TopicPartition, - kafkaParams: ju.Map[String, Object]) extends Logging { - import InternalKafkaConsumer._ - - /** - * The internal object to store the fetched data from Kafka consumer and the next offset to poll. - * - * @param _records the pre-fetched Kafka records. - * @param _nextOffsetInFetchedData the next offset in `records`. We use this to verify if we - * should check if the pre-fetched data is still valid. - * @param _offsetAfterPoll the Kafka offset after calling `poll`. We will use this offset to - * poll when `records` is drained. 
- */ - private case class FetchedData( - private var _records: ju.ListIterator[ConsumerRecord[Array[Byte], Array[Byte]]], - private var _nextOffsetInFetchedData: Long, - private var _offsetAfterPoll: Long) { - - def withNewPoll( - records: ju.ListIterator[ConsumerRecord[Array[Byte], Array[Byte]]], - offsetAfterPoll: Long): FetchedData = { - this._records = records - this._nextOffsetInFetchedData = UNKNOWN_OFFSET - this._offsetAfterPoll = offsetAfterPoll - this - } - - /** Whether there are more elements */ - def hasNext: Boolean = _records.hasNext - - /** Move `records` forward and return the next record. */ - def next(): ConsumerRecord[Array[Byte], Array[Byte]] = { - val record = _records.next() - _nextOffsetInFetchedData = record.offset + 1 - record - } +private[kafka010] case class FetchedData( + private var _records: ju.ListIterator[ConsumerRecord[Array[Byte], Array[Byte]]], + private var _nextOffsetInFetchedData: Long, + private var _offsetAfterPoll: Long) { + + def withNewPoll( + records: ju.ListIterator[ConsumerRecord[Array[Byte], Array[Byte]]], + offsetAfterPoll: Long): FetchedData = { + this._records = records + this._nextOffsetInFetchedData = UNKNOWN_OFFSET + this._offsetAfterPoll = offsetAfterPoll + this + } - /** Move `records` backward and return the previous record. */ - def previous(): ConsumerRecord[Array[Byte], Array[Byte]] = { - assert(_records.hasPrevious, "fetchedData cannot move back") - val record = _records.previous() - _nextOffsetInFetchedData = record.offset - record - } + /** Whether there are more elements */ + def hasNext: Boolean = _records.hasNext - /** Reset the internal pre-fetched data. */ - def reset(): Unit = { - _records = ju.Collections.emptyListIterator() - _nextOffsetInFetchedData = UNKNOWN_OFFSET - _offsetAfterPoll = UNKNOWN_OFFSET - } + /** Move `records` forward and return the next record. */ + def next(): ConsumerRecord[Array[Byte], Array[Byte]] = { + val record = _records.next() + _nextOffsetInFetchedData = record.offset + 1 + record + } - /** - * Returns the next offset in `records`. We use this to verify if we should check if the - * pre-fetched data is still valid. - */ - def nextOffsetInFetchedData: Long = _nextOffsetInFetchedData + /** Move `records` backward and return the previous record. */ + def previous(): ConsumerRecord[Array[Byte], Array[Byte]] = { + assert(_records.hasPrevious, "fetchedData cannot move back") + val record = _records.previous() + _nextOffsetInFetchedData = record.offset + record + } - /** - * Returns the next offset to poll after draining the pre-fetched records. - */ - def offsetAfterPoll: Long = _offsetAfterPoll + /** Reset the internal pre-fetched data. */ + def reset(): Unit = { + _records = ju.Collections.emptyListIterator() + _nextOffsetInFetchedData = UNKNOWN_OFFSET + _offsetAfterPoll = UNKNOWN_OFFSET } /** - * The internal object returned by the `fetchRecord` method. If `record` is empty, it means it is - * invisible (either a transaction message, or an aborted message when the consumer's - * `isolation.level` is `read_committed`), and the caller should use `nextOffsetToFetch` to fetch - * instead. + * Returns the next offset in `records`. We use this to verify if we should check if the + * pre-fetched data is still valid. 
*/ - private case class FetchedRecord( - var record: ConsumerRecord[Array[Byte], Array[Byte]], - var nextOffsetToFetch: Long) { - - def withRecord( - record: ConsumerRecord[Array[Byte], Array[Byte]], - nextOffsetToFetch: Long): FetchedRecord = { - this.record = record - this.nextOffsetToFetch = nextOffsetToFetch - this - } - } + def nextOffsetInFetchedData: Long = _nextOffsetInFetchedData - private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + /** + * Returns the next offset to poll after draining the pre-fetched records. + */ + def offsetAfterPoll: Long = _offsetAfterPoll +} - @volatile private var consumer = createConsumer +/** + * The internal object returned by the `fetchRecord` method. If `record` is empty, it means it is + * invisible (either a transaction message, or an aborted message when the consumer's + * `isolation.level` is `read_committed`), and the caller should use `nextOffsetToFetch` to fetch + * instead. + */ +private[kafka010] case class FetchedRecord( + var record: ConsumerRecord[Array[Byte], Array[Byte]], + var nextOffsetToFetch: Long) { + + def withRecord( + record: ConsumerRecord[Array[Byte], Array[Byte]], + nextOffsetToFetch: Long): FetchedRecord = { + this.record = record + this.nextOffsetToFetch = nextOffsetToFetch + this + } +} - /** indicates whether this consumer is in use or not */ - @volatile var inUse = true +/** + * This class helps caller to read from Kafka leveraging consumer pool as well as fetched data pool. + * This class throws error when data loss is detected while reading from Kafka. + * + * NOTE for contributors: we need to ensure all the public methods to initialize necessary resources + * via calling `getOrRetrieveConsumer` and `getOrRetrieveFetchedData`. + */ +private[kafka010] class KafkaDataConsumer( + topicPartition: TopicPartition, + kafkaParams: ju.Map[String, Object], + consumerPool: InternalKafkaConsumerPool, + fetchedDataPool: FetchedDataPool) extends Logging { + import KafkaDataConsumer._ - /** indicate whether this consumer is going to be stopped in the next release */ - @volatile var markedForClose = false + @volatile private[kafka010] var _consumer: Option[InternalKafkaConsumer] = None + @volatile private var _fetchedData: Option[FetchedData] = None - /** - * The fetched data returned from Kafka consumer. This is a reusable private object to avoid - * memory allocation. - */ - private val fetchedData = FetchedData( - ju.Collections.emptyListIterator[ConsumerRecord[Array[Byte], Array[Byte]]], - UNKNOWN_OFFSET, - UNKNOWN_OFFSET) + private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + private val cacheKey = CacheKey(groupId, topicPartition) /** * The fetched record returned from the `fetchRecord` method. 
This is a reusable private object to @@ -195,41 +227,30 @@ private[kafka010] case class InternalKafkaConsumer( */ private val fetchedRecord: FetchedRecord = FetchedRecord(null, UNKNOWN_OFFSET) - - /** Create a KafkaConsumer to fetch records for `topicPartition` */ - private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = { - val updatedKafkaParams = KafkaConfigUpdater("executor", kafkaParams.asScala.toMap) - .setAuthenticationConfigIfNeeded() - .build() - val c = new KafkaConsumer[Array[Byte], Array[Byte]](updatedKafkaParams) - val tps = new ju.ArrayList[TopicPartition]() - tps.add(topicPartition) - c.assign(tps) - c - } - - private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match { - case ut: UninterruptibleThread => - ut.runUninterruptibly(body) - case _ => - logWarning("CachedKafkaConsumer is not running in UninterruptibleThread. " + - "It may hang when CachedKafkaConsumer's methods are interrupted because of KAFKA-1894") - body - } - /** - * Return the available offset range of the current partition. It's a pair of the earliest offset - * and the latest offset. + * Get the record for the given offset if available. + * + * If the record is invisible (either a + * transaction message, or an aborted message when the consumer's `isolation.level` is + * `read_committed`), it will be skipped and this method will try to fetch next available record + * within [offset, untilOffset). + * + * This method also will try its best to detect data loss. If `failOnDataLoss` is `true`, it will + * throw an exception when we detect an unavailable offset. If `failOnDataLoss` is `false`, this + * method will try to fetch next available record within [offset, untilOffset). + * + * When this method tries to skip offsets due to either invisible messages or data loss and + * reaches `untilOffset`, it will return `null`. + * + * @param offset the offset to fetch. + * @param untilOffset the max offset to fetch. Exclusive. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return record at + * offset if available, or throw exception.when `failOnDataLoss` is `false`, + * this method will either return record at offset if available, or return + * the next earliest available record less than untilOffset, or null. It + * will not throw any exception. */ - def getAvailableOffsetRange(): AvailableOffsetRange = runUninterruptiblyIfPossible { - consumer.seekToBeginning(Set(topicPartition).asJava) - val earliestOffset = consumer.position(topicPartition) - consumer.seekToEnd(Set(topicPartition).asJava) - val latestOffset = consumer.position(topicPartition) - AvailableOffsetRange(earliestOffset, latestOffset) - } - - /** @see [[KafkaDataConsumer.get]] */ def get( offset: Long, untilOffset: Long, @@ -238,8 +259,13 @@ private[kafka010] case class InternalKafkaConsumer( ConsumerRecord[Array[Byte], Array[Byte]] = runUninterruptiblyIfPossible { require(offset < untilOffset, s"offset must always be less than untilOffset [offset: $offset, untilOffset: $untilOffset]") + + val consumer = getOrRetrieveConsumer() + val fetchedData = getOrRetrieveFetchedData(offset) + logDebug(s"Get $groupId $topicPartition nextOffset ${fetchedData.nextOffsetInFetchedData} " + - s"requested $offset") + "requested $offset") + // The following loop is basically for `failOnDataLoss = false`. When `failOnDataLoss` is // `false`, first, we will try to fetch the record at `offset`. 
If no such record exists, then // we will move to the next available offset within `[offset, untilOffset)` and retry. @@ -252,7 +278,8 @@ private[kafka010] case class InternalKafkaConsumer( while (toFetchOffset != UNKNOWN_OFFSET && !isFetchComplete) { try { - fetchedRecord = fetchRecord(toFetchOffset, untilOffset, pollTimeoutMs, failOnDataLoss) + fetchedRecord = fetchRecord(consumer, fetchedData, toFetchOffset, untilOffset, + pollTimeoutMs, failOnDataLoss) if (fetchedRecord.record != null) { isFetchComplete = true } else { @@ -266,12 +293,9 @@ private[kafka010] case class InternalKafkaConsumer( } } catch { case e: OffsetOutOfRangeException => - // When there is some error thrown, it's better to use a new consumer to drop all cached - // states in the old consumer. We don't need to worry about the performance because this - // is not a common path. - resetConsumer() - reportDataLoss(failOnDataLoss, s"Cannot fetch offset $toFetchOffset", e) - toFetchOffset = getEarliestAvailableOffsetBetween(toFetchOffset, untilOffset) + reportDataLoss(topicPartition, groupId, failOnDataLoss, + s"Cannot fetch offset $toFetchOffset", e) + toFetchOffset = getEarliestAvailableOffsetBetween(consumer, toFetchOffset, untilOffset) } } @@ -283,14 +307,45 @@ private[kafka010] case class InternalKafkaConsumer( } } + /** + * Return the available offset range of the current partition. It's a pair of the earliest offset + * and the latest offset. + */ + def getAvailableOffsetRange(): AvailableOffsetRange = runUninterruptiblyIfPossible { + val consumer = getOrRetrieveConsumer() + consumer.getAvailableOffsetRange() + } + + /** + * Release borrowed objects in data reader to the pool. Once the instance is created, caller + * must call method after using the instance to make sure resources are not leaked. + */ + def release(): Unit = { + if (_consumer.isDefined) { + consumerPool.returnObject(_consumer.get) + _consumer = None + } + + if (_fetchedData.isDefined) { + fetchedDataPool.release(cacheKey, _fetchedData.get) + _fetchedData = None + } + } + /** * Return the next earliest available offset in [offset, untilOffset). If all offsets in * [offset, untilOffset) are invalid (e.g., the topic is deleted and recreated), it will return * `UNKNOWN_OFFSET`. */ - private def getEarliestAvailableOffsetBetween(offset: Long, untilOffset: Long): Long = { - val range = getAvailableOffsetRange() + private def getEarliestAvailableOffsetBetween( + consumer: InternalKafkaConsumer, + offset: Long, + untilOffset: Long): Long = { + val range = consumer.getAvailableOffsetRange() logWarning(s"Some data may be lost. Recovering from the earliest offset: ${range.earliest}") + + val topicPartition = consumer.topicPartition + val groupId = consumer.groupId if (offset >= range.latest || range.earliest >= untilOffset) { // [offset, untilOffset) and [earliestOffset, latestOffset) have no overlap, // either @@ -305,10 +360,10 @@ private[kafka010] case class InternalKafkaConsumer( // | | | | // offset untilOffset earliestOffset latestOffset val warningMessage = - s""" - |The current available offset range is $range. - | Offset ${offset} is out of range, and records in [$offset, $untilOffset) will be - | skipped ${additionalMessage(failOnDataLoss = false)} + s""" + |The current available offset range is $range. 
+ | Offset $offset is out of range, and records in [$offset, $untilOffset) will be + | skipped ${additionalMessage(topicPartition, groupId, failOnDataLoss = false)} """.stripMargin logWarning(warningMessage) UNKNOWN_OFFSET @@ -321,8 +376,8 @@ private[kafka010] case class InternalKafkaConsumer( // This will happen when a topic is deleted and recreated, and new data are pushed very fast, // then we will see `offset` disappears first then appears again. Although the parameters // are same, the state in Kafka cluster is changed, so the outer loop won't be endless. - logWarning(s"Found a disappeared offset $offset. " + - s"Some data may be lost ${additionalMessage(failOnDataLoss = false)}") + logWarning(s"Found a disappeared offset $offset. Some data may be lost " + + s"${additionalMessage(topicPartition, groupId, failOnDataLoss = false)}") offset } else { // ------------------------------------------------------------------------------ @@ -330,10 +385,10 @@ private[kafka010] case class InternalKafkaConsumer( // | | | | // offset earliestOffset min(untilOffset,latestOffset) max(untilOffset, latestOffset) val warningMessage = - s""" - |The current available offset range is $range. - | Offset ${offset} is out of range, and records in [$offset, ${range.earliest}) will be - | skipped ${additionalMessage(failOnDataLoss = false)} + s""" + |The current available offset range is $range. + | Offset ${offset} is out of range, and records in [$offset, ${range.earliest}) will be + | skipped ${additionalMessage(topicPartition, groupId, failOnDataLoss = false)} """.stripMargin logWarning(warningMessage) range.earliest @@ -355,6 +410,8 @@ private[kafka010] case class InternalKafkaConsumer( * @throws TimeoutException if cannot fetch the record in `pollTimeoutMs` milliseconds. */ private def fetchRecord( + consumer: InternalKafkaConsumer, + fetchedData: FetchedData, offset: Long, untilOffset: Long, pollTimeoutMs: Long, @@ -362,7 +419,7 @@ private[kafka010] case class InternalKafkaConsumer( if (offset != fetchedData.nextOffsetInFetchedData) { // This is the first fetch, or the fetched data has been reset. // Fetch records from Kafka and update `fetchedData`. - fetchData(offset, pollTimeoutMs) + fetchData(consumer, fetchedData, offset, pollTimeoutMs) } else if (!fetchedData.hasNext) { // The last pre-fetched data has been drained. if (offset < fetchedData.offsetAfterPoll) { // Offsets in [offset, fetchedData.offsetAfterPoll) are invisible. Return a record to ask @@ -372,7 +429,7 @@ private[kafka010] case class InternalKafkaConsumer( return fetchedRecord.withRecord(null, nextOffsetToFetch) } else { // Fetch records from Kafka and update `fetchedData`. - fetchData(offset, pollTimeoutMs) + fetchData(consumer, fetchedData, offset, pollTimeoutMs) } } @@ -388,7 +445,7 @@ private[kafka010] case class InternalKafkaConsumer( // In general, Kafka uses the specified offset as the start point, and tries to fetch the next // available offset. Hence we need to handle offset mismatch. if (record.offset > offset) { - val range = getAvailableOffsetRange() + val range = consumer.getAvailableOffsetRange() if (range.earliest <= offset) { // `offset` is still valid but the corresponding message is invisible. We should skip it // and jump to `record.offset`. 
Here we move `fetchedData` back so that the next call of @@ -398,16 +455,19 @@ private[kafka010] case class InternalKafkaConsumer( } // This may happen when some records aged out but their offsets already got verified if (failOnDataLoss) { - reportDataLoss(true, s"Cannot fetch records in [$offset, ${record.offset})") + reportDataLoss(consumer.topicPartition, consumer.groupId, failOnDataLoss = true, + s"Cannot fetch records in [$offset, ${record.offset})") // Never happen as "reportDataLoss" will throw an exception throw new IllegalStateException( "reportDataLoss didn't throw an exception when 'failOnDataLoss' is true") } else if (record.offset >= untilOffset) { - reportDataLoss(false, s"Skip missing records in [$offset, $untilOffset)") + reportDataLoss(consumer.topicPartition, consumer.groupId, failOnDataLoss = false, + s"Skip missing records in [$offset, $untilOffset)") // Set `nextOffsetToFetch` to `untilOffset` to finish the current batch. fetchedRecord.withRecord(null, untilOffset) } else { - reportDataLoss(false, s"Skip missing records in [$offset, ${record.offset})") + reportDataLoss(consumer.topicPartition, consumer.groupId, failOnDataLoss = false, + s"Skip missing records in [$offset, ${record.offset})") fetchedRecord.withRecord(record, fetchedData.nextOffsetInFetchedData) } } else if (record.offset < offset) { @@ -421,17 +481,49 @@ private[kafka010] case class InternalKafkaConsumer( } } - /** Create a new consumer and reset cached states */ - private def resetConsumer(): Unit = { - consumer.close() - consumer = createConsumer - fetchedData.reset() + /** + * Poll messages from Kafka starting from `offset` and update `fetchedData`. `fetchedData` may be + * empty if the Kafka consumer fetches some messages but all of them are not visible messages + * (either transaction messages, or aborted messages when `isolation.level` is `read_committed`). + * + * @throws OffsetOutOfRangeException if `offset` is out of range. + * @throws TimeoutException if the consumer position is not changed after polling. It means the + * consumer polls nothing before timeout. + */ + private def fetchData( + consumer: InternalKafkaConsumer, + fetchedData: FetchedData, + offset: Long, + pollTimeoutMs: Long): Unit = { + val (records, offsetAfterPoll) = consumer.fetch(offset, pollTimeoutMs) + fetchedData.withNewPoll(records.listIterator, offsetAfterPoll) + } + + private def getOrRetrieveConsumer(): InternalKafkaConsumer = _consumer match { + case None => + _consumer = Option(consumerPool.borrowObject(cacheKey, kafkaParams)) + require(_consumer.isDefined, "borrowing consumer from pool must always succeed.") + _consumer.get + + case Some(consumer) => consumer + } + + private def getOrRetrieveFetchedData(offset: Long): FetchedData = _fetchedData match { + case None => + _fetchedData = Option(fetchedDataPool.acquire(cacheKey, offset)) + require(_fetchedData.isDefined, "acquiring fetched data from cache must always succeed.") + _fetchedData.get + + case Some(fetchedData) => fetchedData } /** * Return an addition message including useful message and instruction. */ - private def additionalMessage(failOnDataLoss: Boolean): String = { + private def additionalMessage( + topicPartition: TopicPartition, + groupId: String, + failOnDataLoss: Boolean): String = { if (failOnDataLoss) { s"(GroupId: $groupId, TopicPartition: $topicPartition). " + s"$INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE" @@ -445,197 +537,69 @@ private[kafka010] case class InternalKafkaConsumer( * Throw an exception or log a warning as per `failOnDataLoss`. 
*/ private def reportDataLoss( + topicPartition: TopicPartition, + groupId: String, failOnDataLoss: Boolean, message: String, cause: Throwable = null): Unit = { - val finalMessage = s"$message ${additionalMessage(failOnDataLoss)}" + val finalMessage = s"$message ${additionalMessage(topicPartition, groupId, failOnDataLoss)}" reportDataLoss0(failOnDataLoss, finalMessage, cause) } - def close(): Unit = consumer.close() - - private def seek(offset: Long): Unit = { - logDebug(s"Seeking to $groupId $topicPartition $offset") - consumer.seek(topicPartition, offset) - } - - /** - * Poll messages from Kafka starting from `offset` and update `fetchedData`. `fetchedData` may be - * empty if the Kafka consumer fetches some messages but all of them are not visible messages - * (either transaction messages, or aborted messages when `isolation.level` is `read_committed`). - * - * @throws OffsetOutOfRangeException if `offset` is out of range. - * @throws TimeoutException if the consumer position is not changed after polling. It means the - * consumer polls nothing before timeout. - */ - private def fetchData(offset: Long, pollTimeoutMs: Long): Unit = { - // Seek to the offset because we may call seekToBeginning or seekToEnd before this. - seek(offset) - val p = consumer.poll(pollTimeoutMs) - val r = p.records(topicPartition) - logDebug(s"Polled $groupId ${p.partitions()} ${r.size}") - val offsetAfterPoll = consumer.position(topicPartition) - logDebug(s"Offset changed from $offset to $offsetAfterPoll after polling") - fetchedData.withNewPoll(r.listIterator, offsetAfterPoll) - if (!fetchedData.hasNext) { - // We cannot fetch anything after `poll`. Two possible cases: - // - `offset` is out of range so that Kafka returns nothing. `OffsetOutOfRangeException` will - // be thrown. - // - Cannot fetch any data before timeout. `TimeoutException` will be thrown. - // - Fetched something but all of them are not invisible. This is a valid case and let the - // caller handles this. - val range = getAvailableOffsetRange() - if (offset < range.earliest || offset >= range.latest) { - throw new OffsetOutOfRangeException( - Map(topicPartition -> java.lang.Long.valueOf(offset)).asJava) - } else if (offset == offsetAfterPoll) { - throw new TimeoutException( - s"Cannot fetch record for offset $offset in $pollTimeoutMs milliseconds") - } - } + private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match { + case ut: UninterruptibleThread => + ut.runUninterruptibly(body) + case _ => + logWarning("KafkaDataConsumer is not running in UninterruptibleThread. 
" + + "It may hang when KafkaDataConsumer's methods are interrupted because of KAFKA-1894") + body } } - private[kafka010] object KafkaDataConsumer extends Logging { + val UNKNOWN_OFFSET = -2L case class AvailableOffsetRange(earliest: Long, latest: Long) - private case class CachedKafkaDataConsumer(internalConsumer: InternalKafkaConsumer) - extends KafkaDataConsumer { - assert(internalConsumer.inUse) // make sure this has been set to true - override def release(): Unit = { KafkaDataConsumer.release(internalConsumer) } - } - - private case class NonCachedKafkaDataConsumer(internalConsumer: InternalKafkaConsumer) - extends KafkaDataConsumer { - override def release(): Unit = { internalConsumer.close() } - } - - private[kafka010] case class CacheKey(groupId: String, topicPartition: TopicPartition) { + case class CacheKey(groupId: String, topicPartition: TopicPartition) { def this(topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) = this(kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String], topicPartition) } - // This cache has the following important properties. - // - We make a best-effort attempt to maintain the max size of the cache as configured capacity. - // The capacity is not guaranteed to be maintained, especially when there are more active - // tasks simultaneously using consumers than the capacity. - private[kafka010] lazy val cache = { - val conf = SparkEnv.get.conf - val capacity = conf.get(CONSUMER_CACHE_CAPACITY) - new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer](capacity, 0.75f, true) { - override def removeEldestEntry( - entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer]): Boolean = { - - // Try to remove the least-used entry if its currently not in use. - // - // If you cannot remove it, then the cache will keep growing. In the worst case, - // the cache will grow to the max number of concurrent tasks that can run in the executor, - // (that is, number of tasks slots) after which it will never reduce. This is unlikely to - // be a serious problem because an executor with more than 64 (default) tasks slots is - // likely running on a beefy machine that can handle a large number of simultaneously - // active consumers. - - if (!entry.getValue.inUse && this.size > capacity) { - logWarning( - s"KafkaConsumer cache hitting max capacity of $capacity, " + - s"removing consumer for ${entry.getKey}") - try { - entry.getValue.close() - } catch { - case e: SparkException => - logError(s"Error closing earliest Kafka consumer for ${entry.getKey}", e) - } - true - } else { - false - } - } + private val consumerPool = InternalKafkaConsumerPool.build + private val fetchedDataPool = FetchedDataPool.build + + ShutdownHookManager.addShutdownHook { () => + try { + fetchedDataPool.shutdown() + consumerPool.close() + } catch { + case e: Throwable => + logWarning("Ignoring Exception while shutting down pools from shutdown hook", e) } } /** - * Get a cached consumer for groupId, assigned to topic and partition. + * Get a data reader for groupId, assigned to topic and partition. * If matching consumer doesn't already exist, will be created using kafkaParams. - * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]]. - * - * Note: This method guarantees that the consumer returned is not currently in use by any one - * else. Within this guarantee, this method will make a best effort attempt to re-use consumers by - * caching them and tracking when they are in use. + * The returned data reader must be released explicitly. 
*/ def acquire( topicPartition: TopicPartition, - kafkaParams: ju.Map[String, Object], - useCache: Boolean): KafkaDataConsumer = synchronized { - val key = new CacheKey(topicPartition, kafkaParams) - val existingInternalConsumer = cache.get(key) - - lazy val newInternalConsumer = new InternalKafkaConsumer(topicPartition, kafkaParams) - + kafkaParams: ju.Map[String, Object]): KafkaDataConsumer = { if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) { - // If this is reattempt at running the task, then invalidate cached consumer if any and - // start with a new one. - if (existingInternalConsumer != null) { - // Consumer exists in cache. If its in use, mark it for closing later, or close it now. - if (existingInternalConsumer.inUse) { - existingInternalConsumer.markedForClose = true - } else { - existingInternalConsumer.close() - } - } - cache.remove(key) // Invalidate the cache in any case - NonCachedKafkaDataConsumer(newInternalConsumer) - - } else if (!useCache) { - // If planner asks to not reuse consumers, then do not use it, return a new consumer - NonCachedKafkaDataConsumer(newInternalConsumer) + val cacheKey = new CacheKey(topicPartition, kafkaParams) - } else if (existingInternalConsumer == null) { - // If consumer is not already cached, then put a new in the cache and return it - cache.put(key, newInternalConsumer) - newInternalConsumer.inUse = true - CachedKafkaDataConsumer(newInternalConsumer) + // If this is reattempt at running the task, then invalidate cached consumer if any. + consumerPool.invalidateKey(cacheKey) - } else if (existingInternalConsumer.inUse) { - // If consumer is already cached but is currently in use, then return a new consumer - NonCachedKafkaDataConsumer(newInternalConsumer) - - } else { - // If consumer is already cached and is currently not in use, then return that consumer - existingInternalConsumer.inUse = true - CachedKafkaDataConsumer(existingInternalConsumer) + // invalidate all fetched data for the key as well + // sadly we can't pinpoint specific data and invalidate cause we don't have unique id + fetchedDataPool.invalidate(cacheKey) } - } - private def release(intConsumer: InternalKafkaConsumer): Unit = { - synchronized { - - // Clear the consumer from the cache if this is indeed the consumer present in the cache - val key = new CacheKey(intConsumer.topicPartition, intConsumer.kafkaParams) - val cachedIntConsumer = cache.get(key) - if (intConsumer.eq(cachedIntConsumer)) { - // The released consumer is the same object as the cached one. - if (intConsumer.markedForClose) { - intConsumer.close() - cache.remove(key) - } else { - intConsumer.inUse = false - } - } else { - // The released consumer is either not the same one as in the cache, or not in the cache - // at all. This may happen if the cache was invalidate while this consumer was being used. - // Just close this consumer. 
- intConsumer.close() - logInfo(s"Released a supposedly cached consumer that was not found in the cache") - } - } + new KafkaDataConsumer(topicPartition, kafkaParams, consumerPool, fetchedDataPool) } -} - -private[kafka010] object InternalKafkaConsumer extends Logging { - - private val UNKNOWN_OFFSET = -2L private def reportDataLoss0( failOnDataLoss: Boolean, @@ -655,4 +619,5 @@ private[kafka010] object InternalKafkaConsumer extends Logging { } } } + } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala index 08a52ddbd19b..9cd16c8e1624 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala @@ -144,14 +144,9 @@ private[kafka010] class KafkaMicroBatchStream( untilOffsets = untilOffsets, executorLocations = getSortedExecutorList()) - // Reuse Kafka consumers only when all the offset ranges have distinct TopicPartitions, - // that is, concurrent tasks will not read the same TopicPartitions. - val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size - // Generate factories based on the offset ranges offsetRanges.map { range => - KafkaBatchInputPartition( - range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) + KafkaBatchInputPartition(range, executorKafkaParams, pollTimeoutMs, failOnDataLoss) }.toArray } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index dd584a5987a0..dc7087821b10 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -102,7 +102,7 @@ private[kafka010] class KafkaRelation( KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId) val rdd = new KafkaSourceRDD( sqlContext.sparkContext, executorKafkaParams, offsetRanges, - pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer = false).map { cr => + pollTimeoutMs, failOnDataLoss).map { cr => InternalRow( cr.key, cr.value, diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index f477c35dcf39..d1a35ec53bc9 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -268,8 +268,7 @@ private[kafka010] class KafkaSource( // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays. 
val rdd = new KafkaSourceRDD( - sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss, - reuseKafkaConsumer = true).map { cr => + sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss).map { cr => InternalRow( cr.key, cr.value, diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index f8b90056d293..dae9515205f5 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -63,8 +63,7 @@ private[kafka010] class KafkaSourceRDD( executorKafkaParams: ju.Map[String, Object], offsetRanges: Seq[KafkaSourceRDDOffsetRange], pollTimeoutMs: Long, - failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) + failOnDataLoss: Boolean) extends RDD[ConsumerRecord[Array[Byte], Array[Byte]]](sc, Nil) { override def persist(newLevel: StorageLevel): this.type = { @@ -87,7 +86,7 @@ private[kafka010] class KafkaSourceRDD( context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = { val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition] val consumer = KafkaDataConsumer.acquire( - sourcePartition.offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) + sourcePartition.offsetRange.topicPartition, executorKafkaParams) val range = resolveRange(consumer, sourcePartition.offsetRange) assert( diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala index ff19862c20cc..b24c0f1aa143 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala @@ -38,4 +38,25 @@ package object kafka010 { // scalastyle:ignore " (check Structured Streaming Kafka integration guide for further details).") .intConf .createWithDefault(64) + + private[kafka010] val CONSUMER_CACHE_JMX_ENABLED = + ConfigBuilder("spark.kafka.consumer.cache.jmx.enable") + .doc("Enable or disable JMX for pools created with this configuration instance.") + .booleanConf + .createWithDefault(false) + + private[kafka010] val CONSUMER_CACHE_MIN_EVICTABLE_IDLE_TIME_MILLIS = + ConfigBuilder("spark.kafka.consumer.cache.minEvictableIdleTimeMillis") + .doc("The minimum amount of time an object may sit idle in the pool before " + + "it is eligible for eviction by the idle object evictor. " + + "When non-positive, no objects will be evicted from the pool due to idle time alone.") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("5m") + + private[kafka010] val CONSUMER_CACHE_EVICTOR_THREAD_RUN_INTERVAL_MILLIS = + ConfigBuilder("spark.kafka.consumer.cache.evictorThreadRunIntervalMillis") + .doc("The number of milliseconds to sleep between runs of the idle object evictor thread. 
" + + "When non-positive, no idle object evictor thread will be run.") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("3m") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/FetchedDataPoolSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/FetchedDataPoolSuite.scala new file mode 100644 index 000000000000..ad3975d673f3 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/FetchedDataPoolSuite.scala @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.TopicPartition +import org.scalatest.PrivateMethodTester + +import org.apache.spark.SparkEnv +import org.apache.spark.sql.kafka010.KafkaDataConsumer.CacheKey +import org.apache.spark.sql.test.SharedSparkSession + +class FetchedDataPoolSuite extends SharedSparkSession with PrivateMethodTester { + import FetchedDataPool._ + type Record = ConsumerRecord[Array[Byte], Array[Byte]] + + private val dummyBytes = "dummy".getBytes + + // Helper private method accessors for FetchedDataPool + private type PoolCacheType = mutable.Map[CacheKey, CachedFetchedDataList] + private val _cache = PrivateMethod[PoolCacheType]('cache) + + def getCache(pool: FetchedDataPool): PoolCacheType = { + pool.invokePrivate(_cache()) + } + + test("acquire fresh one") { + val dataPool = FetchedDataPool.build + + val cacheKey = CacheKey("testgroup", new TopicPartition("topic", 0)) + + assert(getCache(dataPool).get(cacheKey).isEmpty) + + val data = dataPool.acquire(cacheKey, 0) + + assertFetchedDataPoolStatistic(dataPool, expectedNumCreated = 1, expectedNumTotal = 1) + assert(getCache(dataPool)(cacheKey).size === 1) + assert(getCache(dataPool)(cacheKey).head.inUse) + + data.withNewPoll(testRecords(0, 5).listIterator, 5) + + dataPool.release(cacheKey, data) + + assertFetchedDataPoolStatistic(dataPool, expectedNumCreated = 1, expectedNumTotal = 1) + assert(getCache(dataPool)(cacheKey).size === 1) + assert(!getCache(dataPool)(cacheKey).head.inUse) + + dataPool.shutdown() + } + + test("acquire fetched data from multiple keys") { + val dataPool = FetchedDataPool.build + + val cacheKeys = (0 until 10).map { partId => + CacheKey("testgroup", new TopicPartition("topic", partId)) + } + + assert(getCache(dataPool).size === 0) + cacheKeys.foreach { key => assert(getCache(dataPool).get(key).isEmpty) } + + val dataList = cacheKeys.map(key => (key, dataPool.acquire(key, 0))) + + assert(getCache(dataPool).size === cacheKeys.size) + cacheKeys.map { key => + 
assert(getCache(dataPool)(key).size === 1) + assert(getCache(dataPool)(key).head.inUse) + } + + assertFetchedDataPoolStatistic(dataPool, expectedNumCreated = 10, expectedNumTotal = 10) + + dataList.map { case (_, data) => + data.withNewPoll(testRecords(0, 5).listIterator, 5) + } + + dataList.foreach { case (key, data) => + dataPool.release(key, data) + } + + assert(getCache(dataPool).size === cacheKeys.size) + cacheKeys.map { key => + assert(getCache(dataPool)(key).size === 1) + assert(!getCache(dataPool)(key).head.inUse) + } + + dataPool.shutdown() + } + + test("continuous use of fetched data from single key") { + val dataPool = FetchedDataPool.build + + val cacheKey = CacheKey("testgroup", new TopicPartition("topic", 0)) + + assert(getCache(dataPool).get(cacheKey).isEmpty) + + val data = dataPool.acquire(cacheKey, 0) + + assertFetchedDataPoolStatistic(dataPool, expectedNumCreated = 1, expectedNumTotal = 1) + assert(getCache(dataPool)(cacheKey).size === 1) + assert(getCache(dataPool)(cacheKey).head.inUse) + + data.withNewPoll(testRecords(0, 5).listIterator, 5) + + (0 to 3).foreach { _ => data.next() } + + dataPool.release(cacheKey, data) + + // suppose next batch + + val data2 = dataPool.acquire(cacheKey, data.nextOffsetInFetchedData) + + assert(data.eq(data2)) + + assertFetchedDataPoolStatistic(dataPool, expectedNumCreated = 1, expectedNumTotal = 1) + assert(getCache(dataPool)(cacheKey).size === 1) + assert(getCache(dataPool)(cacheKey).head.inUse) + + dataPool.release(cacheKey, data2) + + assert(getCache(dataPool)(cacheKey).size === 1) + assert(!getCache(dataPool)(cacheKey).head.inUse) + + dataPool.shutdown() + } + + test("multiple tasks referring same key continuously using fetched data") { + val dataPool = FetchedDataPool.build + + val cacheKey = CacheKey("testgroup", new TopicPartition("topic", 0)) + + assert(getCache(dataPool).get(cacheKey).isEmpty) + + val dataFromTask1 = dataPool.acquire(cacheKey, 0) + + assertFetchedDataPoolStatistic(dataPool, expectedNumCreated = 1, expectedNumTotal = 1) + assert(getCache(dataPool)(cacheKey).size === 1) + assert(getCache(dataPool)(cacheKey).head.inUse) + + val dataFromTask2 = dataPool.acquire(cacheKey, 0) + + // it shouldn't give same object as dataFromTask1 though it asks same offset + // it definitely works when offsets are not overlapped: skip adding test for that + assertFetchedDataPoolStatistic(dataPool, expectedNumCreated = 2, expectedNumTotal = 2) + assert(getCache(dataPool)(cacheKey).size === 2) + assert(getCache(dataPool)(cacheKey)(1).inUse) + + // reading from task 1 + dataFromTask1.withNewPoll(testRecords(0, 5).listIterator, 5) + + (0 to 3).foreach { _ => dataFromTask1.next() } + + dataPool.release(cacheKey, dataFromTask1) + + // reading from task 2 + dataFromTask2.withNewPoll(testRecords(0, 30).listIterator, 30) + + (0 to 5).foreach { _ => dataFromTask2.next() } + + dataPool.release(cacheKey, dataFromTask2) + + // suppose next batch for task 1 + val data2FromTask1 = dataPool.acquire(cacheKey, dataFromTask1.nextOffsetInFetchedData) + assert(data2FromTask1.eq(dataFromTask1)) + + assertFetchedDataPoolStatistic(dataPool, expectedNumCreated = 2, expectedNumTotal = 2) + assert(getCache(dataPool)(cacheKey).head.inUse) + + // suppose next batch for task 2 + val data2FromTask2 = dataPool.acquire(cacheKey, dataFromTask2.nextOffsetInFetchedData) + assert(data2FromTask2.eq(dataFromTask2)) + + assertFetchedDataPoolStatistic(dataPool, expectedNumCreated = 2, expectedNumTotal = 2) + assert(getCache(dataPool)(cacheKey)(1).inUse) + + // release from 
task 2 + dataPool.release(cacheKey, data2FromTask2) + assert(!getCache(dataPool)(cacheKey)(1).inUse) + + // release from task 1 + dataPool.release(cacheKey, data2FromTask1) + assert(!getCache(dataPool)(cacheKey).head.inUse) + + dataPool.shutdown() + } + + test("evict idle fetched data") { + import FetchedDataPool._ + import org.scalatest.time.SpanSugar._ + + val minEvictableIdleTimeMillis = 1000 + val evictorThreadRunIntervalMillis = 500 + + val newConf = Seq( + CONFIG_NAME_MIN_EVICTABLE_IDLE_TIME_MILLIS -> minEvictableIdleTimeMillis.toString, + CONFIG_NAME_EVICTOR_THREAD_RUN_INTERVAL_MILLIS -> evictorThreadRunIntervalMillis.toString) + + withSparkConf(newConf: _*) { + val dataPool = FetchedDataPool.build + + val cacheKeys = (0 until 10).map { partId => + CacheKey("testgroup", new TopicPartition("topic", partId)) + } + + val dataList = cacheKeys.map(key => (key, dataPool.acquire(key, 0))) + + assertFetchedDataPoolStatistic(dataPool, expectedNumCreated = 10, expectedNumTotal = 10) + + dataList.map { case (_, data) => + data.withNewPoll(testRecords(0, 5).listIterator, 5) + } + + val dataToEvict = dataList.take(3) + dataToEvict.foreach { case (key, data) => + dataPool.release(key, data) + } + + // wait up to twice than minEvictableIdleTimeMillis to ensure evictor thread to clear up + // idle objects + eventually(timeout((minEvictableIdleTimeMillis.toLong * 2).milliseconds), + interval(evictorThreadRunIntervalMillis.milliseconds)) { + // idle objects should be evicted + dataToEvict.map { case (key, _) => + assert(getCache(dataPool)(key).isEmpty) + } + } + + assertFetchedDataPoolStatistic(dataPool, expectedNumCreated = 10, expectedNumTotal = 7) + assert(getCache(dataPool).values.map(_.size).sum === dataList.size - dataToEvict.size) + + dataList.takeRight(3).foreach { case (key, data) => + dataPool.release(key, data) + } + + // ensure releasing more objects don't trigger eviction immediately + assertFetchedDataPoolStatistic(dataPool, expectedNumCreated = 10, expectedNumTotal = 7) + assert(getCache(dataPool).values.map(_.size).sum === dataList.size - dataToEvict.size) + + dataPool.shutdown() + } + } + + test("invalidate key") { + val dataPool = FetchedDataPool.build + + val cacheKey = CacheKey("testgroup", new TopicPartition("topic", 0)) + + val dataFromTask1 = dataPool.acquire(cacheKey, 0) + val dataFromTask2 = dataPool.acquire(cacheKey, 0) + + assertFetchedDataPoolStatistic(dataPool, expectedNumCreated = 2, expectedNumTotal = 2) + + // 1 idle, 1 active + dataPool.release(cacheKey, dataFromTask1) + + val cacheKey2 = CacheKey("testgroup", new TopicPartition("topic", 1)) + + dataPool.acquire(cacheKey2, 0) + + assertFetchedDataPoolStatistic(dataPool, expectedNumCreated = 3, expectedNumTotal = 3) + assert(getCache(dataPool).size === 2) + assert(getCache(dataPool)(cacheKey).size === 2) + assert(getCache(dataPool)(cacheKey2).size === 1) + + dataPool.invalidate(cacheKey) + + assertFetchedDataPoolStatistic(dataPool, expectedNumCreated = 3, expectedNumTotal = 1) + assert(getCache(dataPool).size === 1) + assert(getCache(dataPool).get(cacheKey).isEmpty) + + // it doesn't affect other keys + assert(getCache(dataPool)(cacheKey2).size === 1) + + dataPool.release(cacheKey, dataFromTask2) + + // it doesn't throw error on invalidated objects, but it doesn't cache them again + assert(getCache(dataPool).size === 1) + assert(getCache(dataPool).get(cacheKey).isEmpty) + + dataPool.shutdown() + } + + + private def testRecords(startOffset: Long, count: Int): ju.List[Record] = { + (0 until count).map { offset => + new 
Record("topic", 0, startOffset + offset, dummyBytes, dummyBytes) + }.toList.asJava + } + + private def withSparkConf(pairs: (String, String)*)(f: => Unit): Unit = { + val conf = SparkEnv.get.conf + + val (keys, values) = pairs.unzip + val currentValues = keys.map { key => + if (conf.contains(key)) { + Some(conf.get(key)) + } else { + None + } + } + + (keys, values).zipped.foreach { conf.set } + + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.set(key, value) + case (key, None) => conf.remove(key) + } + } + } + + private def assertFetchedDataPoolStatistic( + fetchedDataPool: FetchedDataPool, + expectedNumCreated: Long, + expectedNumTotal: Long): Unit = { + assert(fetchedDataPool.getNumCreated === expectedNumCreated) + assert(fetchedDataPool.getNumTotal === expectedNumTotal) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPoolSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPoolSuite.scala new file mode 100644 index 000000000000..7aa13b7042e3 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPoolSuite.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import scala.collection.JavaConverters._ + +import org.apache.kafka.clients.consumer.ConsumerConfig._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.ByteArrayDeserializer + +import org.apache.spark.SparkEnv +import org.apache.spark.sql.kafka010.KafkaDataConsumer.CacheKey +import org.apache.spark.sql.test.SharedSparkSession + +class InternalKafkaConsumerPoolSuite extends SharedSparkSession { + + test("basic multiple borrows and returns for single key") { + val pool = InternalKafkaConsumerPool.build + + val topic = "topic" + val partitionId = 0 + val topicPartition = new TopicPartition(topic, partitionId) + + val kafkaParams: ju.Map[String, Object] = getTestKafkaParams + + val key = new CacheKey(topicPartition, kafkaParams) + + val pooledObjects = (0 to 2).map { _ => + val pooledObject = pool.borrowObject(key, kafkaParams) + assertPooledObject(pooledObject, topicPartition, kafkaParams) + pooledObject + } + + assertPoolStateForKey(pool, key, numIdle = 0, numActive = 3, numTotal = 3) + assertPoolState(pool, numIdle = 0, numActive = 3, numTotal = 3) + + val pooledObject2 = pool.borrowObject(key, kafkaParams) + + assertPooledObject(pooledObject2, topicPartition, kafkaParams) + assertPoolStateForKey(pool, key, numIdle = 0, numActive = 4, numTotal = 4) + assertPoolState(pool, numIdle = 0, numActive = 4, numTotal = 4) + + pooledObjects.foreach(pool.returnObject) + + assertPoolStateForKey(pool, key, numIdle = 3, numActive = 1, numTotal = 4) + assertPoolState(pool, numIdle = 3, numActive = 1, numTotal = 4) + + pool.returnObject(pooledObject2) + + // we only allow three idle objects per key + assertPoolStateForKey(pool, key, numIdle = 3, numActive = 0, numTotal = 3) + assertPoolState(pool, numIdle = 3, numActive = 0, numTotal = 3) + + pool.close() + } + + test("basic borrow and return for multiple keys") { + val pool = InternalKafkaConsumerPool.build + + val kafkaParams = getTestKafkaParams + val topicPartitions = createTopicPartitions(Seq("topic", "topic2"), 6) + val keys = createCacheKeys(topicPartitions, kafkaParams) + + // while in loop pool doesn't still exceed total pool size + val keyToPooledObjectPairs = borrowObjectsPerKey(pool, kafkaParams, keys) + + assertPoolState(pool, numIdle = 0, numActive = keyToPooledObjectPairs.length, + numTotal = keyToPooledObjectPairs.length) + + returnObjects(pool, keyToPooledObjectPairs) + + assertPoolState(pool, numIdle = keyToPooledObjectPairs.length, numActive = 0, + numTotal = keyToPooledObjectPairs.length) + + pool.close() + } + + test("borrow more than soft max capacity from pool which is neither free space nor idle object") { + testWithPoolBorrowedSoftMaxCapacity { (pool, kafkaParams, keyToPooledObjectPairs) => + val moreTopicPartition = new TopicPartition("topic2", 0) + val newCacheKey = new CacheKey(moreTopicPartition, kafkaParams) + + // exceeds soft max pool size, and also no idle object for cleaning up + // but pool will borrow a new object + pool.borrowObject(newCacheKey, kafkaParams) + + assertPoolState(pool, numIdle = 0, numActive = keyToPooledObjectPairs.length + 1, + numTotal = keyToPooledObjectPairs.length + 1) + } + } + + test("borrow more than soft max capacity from pool frees up idle objects automatically") { + testWithPoolBorrowedSoftMaxCapacity { (pool, kafkaParams, keyToPooledObjectPairs) => + // return 20% of objects to ensure there're some idle objects to free up later + val numToReturn = 
(keyToPooledObjectPairs.length * 0.2).toInt + returnObjects(pool, keyToPooledObjectPairs.take(numToReturn)) + + assertPoolState(pool, numIdle = numToReturn, + numActive = keyToPooledObjectPairs.length - numToReturn, + numTotal = keyToPooledObjectPairs.length) + + // borrow a new object: there should be some idle objects to clean up + val moreTopicPartition = new TopicPartition("topic2", 0) + val newCacheKey = new CacheKey(moreTopicPartition, kafkaParams) + + val newObject = pool.borrowObject(newCacheKey, kafkaParams) + assertPooledObject(newObject, moreTopicPartition, kafkaParams) + assertPoolStateForKey(pool, newCacheKey, numIdle = 0, numActive = 1, numTotal = 1) + + // at least one of idle object should be freed up + assert(pool.getNumIdle < numToReturn) + // we can determine number of active objects correctly + assert(pool.getNumActive === keyToPooledObjectPairs.length - numToReturn + 1) + // total objects should be more than number of active + 1 but can't expect exact number + assert(pool.getTotal > keyToPooledObjectPairs.length - numToReturn + 1) + } + } + + + private def testWithPoolBorrowedSoftMaxCapacity( + testFn: (InternalKafkaConsumerPool, + ju.Map[String, Object], + Seq[(CacheKey, InternalKafkaConsumer)]) => Unit): Unit = { + val capacity = 16 + val newConf = newConfForKafkaPool(Some(capacity), Some(-1), Some(-1)) + + withSparkConf(newConf: _*) { + val pool = InternalKafkaConsumerPool.build + + try { + val kafkaParams = getTestKafkaParams + val topicPartitions = createTopicPartitions(Seq("topic"), capacity) + val keys = createCacheKeys(topicPartitions, kafkaParams) + + // borrow objects which makes pool reaching soft capacity + val keyToPooledObjectPairs = borrowObjectsPerKey(pool, kafkaParams, keys) + + testFn(pool, kafkaParams, keyToPooledObjectPairs) + } finally { + pool.close() + } + } + } + + test("evicting idle objects on background") { + import org.scalatest.time.SpanSugar._ + + val minEvictableIdleTimeMillis = 3 * 1000 // 3 seconds + val evictorThreadRunIntervalMillis = 500 // triggering multiple evictions by intention + + val newConf = newConfForKafkaPool(None, Some(minEvictableIdleTimeMillis), + Some(evictorThreadRunIntervalMillis)) + withSparkConf(newConf: _*) { + val pool = InternalKafkaConsumerPool.build + + val kafkaParams = getTestKafkaParams + val topicPartitions = createTopicPartitions(Seq("topic"), 10) + val keys = createCacheKeys(topicPartitions, kafkaParams) + + // borrow and return some consumers to ensure some partitions are being idle + // this test covers the use cases: rebalance / topic removal happens while running query + val keyToPooledObjectPairs = borrowObjectsPerKey(pool, kafkaParams, keys) + val objectsToReturn = keyToPooledObjectPairs.filter(_._1.topicPartition.partition() % 2 == 0) + returnObjects(pool, objectsToReturn) + + // wait up to twice than minEvictableIdleTimeMillis to ensure evictor thread to clear up + // idle objects + eventually(timeout((minEvictableIdleTimeMillis.toLong * 2).seconds), + interval(evictorThreadRunIntervalMillis.milliseconds)) { + assertPoolState(pool, numIdle = 0, numActive = 5, numTotal = 5) + } + + pool.close() + } + } + + private def newConfForKafkaPool( + capacity: Option[Int], + minEvictableIdleTimeMillis: Option[Long], + evictorThreadRunIntervalMillis: Option[Long]): Seq[(String, String)] = { + Seq( + CONSUMER_CACHE_CAPACITY.key -> capacity, + CONSUMER_CACHE_MIN_EVICTABLE_IDLE_TIME_MILLIS.key -> minEvictableIdleTimeMillis, + CONSUMER_CACHE_EVICTOR_THREAD_RUN_INTERVAL_MILLIS.key -> 
evictorThreadRunIntervalMillis + ).filter(_._2.isDefined).map(e => (e._1 -> e._2.get.toString)) + } + + private def createTopicPartitions( + topicNames: Seq[String], + countPartition: Int): List[TopicPartition] = { + for ( + topic <- topicNames.toList; + partitionId <- 0 until countPartition + ) yield new TopicPartition(topic, partitionId) + } + + private def createCacheKeys( + topicPartitions: List[TopicPartition], + kafkaParams: ju.Map[String, Object]): List[CacheKey] = { + topicPartitions.map(new CacheKey(_, kafkaParams)) + } + + private def assertPooledObject( + pooledObject: InternalKafkaConsumer, + expectedTopicPartition: TopicPartition, + expectedKafkaParams: ju.Map[String, Object]): Unit = { + assert(pooledObject != null) + assert(pooledObject.kafkaParams === expectedKafkaParams) + assert(pooledObject.topicPartition === expectedTopicPartition) + } + + private def assertPoolState( + pool: InternalKafkaConsumerPool, + numIdle: Int, + numActive: Int, + numTotal: Int): Unit = { + assert(pool.getNumIdle === numIdle) + assert(pool.getNumActive === numActive) + assert(pool.getTotal === numTotal) + } + + private def assertPoolStateForKey( + pool: InternalKafkaConsumerPool, + key: CacheKey, + numIdle: Int, + numActive: Int, + numTotal: Int): Unit = { + assert(pool.getNumIdle(key) === numIdle) + assert(pool.getNumActive(key) === numActive) + assert(pool.getTotal(key) === numTotal) + } + + private def getTestKafkaParams: ju.Map[String, Object] = Map[String, Object]( + GROUP_ID_CONFIG -> "groupId", + BOOTSTRAP_SERVERS_CONFIG -> "PLAINTEXT://localhost:9092", + KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + AUTO_OFFSET_RESET_CONFIG -> "earliest", + ENABLE_AUTO_COMMIT_CONFIG -> "false" + ).asJava + + private def borrowObjectsPerKey( + pool: InternalKafkaConsumerPool, + kafkaParams: ju.Map[String, Object], + keys: List[CacheKey]): Seq[(CacheKey, InternalKafkaConsumer)] = { + keys.map { key => + val numActiveBeforeBorrowing = pool.getNumActive + val numIdleBeforeBorrowing = pool.getNumIdle + val numTotalBeforeBorrowing = pool.getTotal + + val pooledObj = pool.borrowObject(key, kafkaParams) + + assertPoolStateForKey(pool, key, numIdle = 0, numActive = 1, numTotal = 1) + assertPoolState(pool, numIdle = numIdleBeforeBorrowing, + numActive = numActiveBeforeBorrowing + 1, numTotal = numTotalBeforeBorrowing + 1) + + (key, pooledObj) + } + } + + private def returnObjects( + pool: InternalKafkaConsumerPool, + objects: Seq[(CacheKey, InternalKafkaConsumer)]): Unit = { + objects.foreach { case (key, pooledObj) => + val numActiveBeforeReturning = pool.getNumActive + val numIdleBeforeReturning = pool.getNumIdle + val numTotalBeforeReturning = pool.getTotal + + pool.returnObject(pooledObj) + + // we only allow one idle object per key + assertPoolStateForKey(pool, key, numIdle = 1, numActive = 0, numTotal = 1) + assertPoolState(pool, numIdle = numIdleBeforeReturning + 1, + numActive = numActiveBeforeReturning - 1, numTotal = numTotalBeforeReturning) + } + } + + private def withSparkConf(pairs: (String, String)*)(f: => Unit): Unit = { + val conf = SparkEnv.get.conf + + val (keys, values) = pairs.unzip + val currentValues = keys.map { key => + if (conf.contains(key)) { + Some(conf.get(key)) + } else { + None + } + } + + (keys, values).zipped.foreach { conf.set } + + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.set(key, value) + case (key, None) => conf.remove(key) + } + 
} + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala index 8aa7e06e772a..3c89f5f7efd6 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.kafka010 import java.util.concurrent.{Executors, TimeUnit} import scala.collection.JavaConverters._ +import scala.collection.immutable import scala.util.Random import org.apache.kafka.clients.consumer.ConsumerConfig._ @@ -60,49 +61,83 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester AUTO_OFFSET_RESET_CONFIG -> "earliest", ENABLE_AUTO_COMMIT_CONFIG -> "false" ).asJava + private var fetchedDataPool: FetchedDataPool = _ + private var consumerPool: InternalKafkaConsumerPool = _ + + override def beforeEach(): Unit = { + fetchedDataPool = { + val fetchedDataPoolMethod = PrivateMethod[FetchedDataPool]('fetchedDataPool) + KafkaDataConsumer.invokePrivate(fetchedDataPoolMethod()) + } + + consumerPool = { + val internalKafkaConsumerPoolMethod = PrivateMethod[InternalKafkaConsumerPool]('consumerPool) + KafkaDataConsumer.invokePrivate(internalKafkaConsumerPoolMethod()) + } + + fetchedDataPool.reset() + consumerPool.reset() + } test("SPARK-19886: Report error cause correctly in reportDataLoss") { val cause = new Exception("D'oh!") val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0) val e = intercept[IllegalStateException] { - InternalKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause)) + KafkaDataConsumer.invokePrivate(reportDataLoss(true, "message", cause)) } assert(e.getCause === cause) } test("new KafkaDataConsumer instance in case of Task retry") { try { - KafkaDataConsumer.cache.clear() - val kafkaParams = getKafkaParams() val key = new CacheKey(groupId, topicPartition) val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null) TaskContext.setTaskContext(context1) - val consumer1 = KafkaDataConsumer.acquire(topicPartition, kafkaParams, true) + val consumer1 = KafkaDataConsumer.acquire(topicPartition, kafkaParams) + + // any method call which requires consumer is necessary + consumer1.getAvailableOffsetRange() + + val consumer1Underlying = consumer1._consumer + assert(consumer1Underlying.isDefined) + consumer1.release() - assert(KafkaDataConsumer.cache.size() == 1) - assert(KafkaDataConsumer.cache.get(key).eq(consumer1.internalConsumer)) + assert(consumerPool.getTotal(key) === 1) + val pooledObj = consumerPool.borrowObject(key, kafkaParams) + assert(consumer1Underlying.get.eq(pooledObj)) + consumerPool.returnObject(pooledObj) val context2 = new TaskContextImpl(0, 0, 0, 0, 1, null, null, null) TaskContext.setTaskContext(context2) - val consumer2 = KafkaDataConsumer.acquire(topicPartition, kafkaParams, true) + val consumer2 = KafkaDataConsumer.acquire(topicPartition, kafkaParams) + + // any method call which requires consumer is necessary + consumer2.getAvailableOffsetRange() + + val consumer2Underlying = consumer2._consumer + assert(consumer2Underlying.isDefined) + // here we expect different consumer as pool will invalidate for task reattempt + assert(consumer2Underlying.get.ne(consumer1Underlying.get)) + consumer2.release() - // The first consumer should be removed from cache and new non-cached should be returned - 
assert(KafkaDataConsumer.cache.size() == 0) - assert(consumer1.internalConsumer.ne(consumer2.internalConsumer)) + // The first consumer should be removed from cache, but second consumer should be cached. + assert(consumerPool.getTotal(key) === 1) + val pooledObj2 = consumerPool.borrowObject(key, kafkaParams) + assert(consumer2Underlying.get.eq(pooledObj2)) + consumerPool.returnObject(pooledObj2) } finally { TaskContext.unset() } } test("SPARK-23623: concurrent use of KafkaDataConsumer") { - val data = (1 to 1000).map(_.toString) - testUtils.createTopic(topic, 1) - testUtils.sendMessages(topic, data.toArray) + val data: immutable.IndexedSeq[String] = prepareTestTopicHavingTestMessages(topic) + val topicPartition = new TopicPartition(topic, 0) val kafkaParams = getKafkaParams() val numThreads = 100 val numConsumerUsages = 500 @@ -110,14 +145,13 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester @volatile var error: Throwable = null def consume(i: Int): Unit = { - val useCache = Random.nextBoolean val taskContext = if (Random.nextBoolean) { new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null) } else { null } TaskContext.setTaskContext(taskContext) - val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache) + val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams) try { val range = consumer.getAvailableOffsetRange() val rcvd = range.earliest until range.latest map { offset => @@ -147,4 +181,143 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester threadpool.shutdown() } } + + test("SPARK-25151 Handles multiple tasks in executor fetching same (topic, partition) pair") { + prepareTestTopicHavingTestMessages(topic) + val topicPartition = new TopicPartition(topic, 0) + + val kafkaParams = getKafkaParams() + + withTaskContext(TaskContext.empty()) { + // task A trying to fetch offset 0 to 100, and read 5 records + val consumer1 = KafkaDataConsumer.acquire(topicPartition, kafkaParams) + val lastOffsetForConsumer1 = readAndGetLastOffset(consumer1, 0, 100, 5) + consumer1.release() + + assertFetchedDataPoolStatistic(fetchedDataPool, expectedNumCreated = 1, expectedNumTotal = 1) + + // task B trying to fetch offset 300 to 500, and read 5 records + val consumer2 = KafkaDataConsumer.acquire(topicPartition, kafkaParams) + val lastOffsetForConsumer2 = readAndGetLastOffset(consumer2, 300, 500, 5) + consumer2.release() + + assertFetchedDataPoolStatistic(fetchedDataPool, expectedNumCreated = 2, expectedNumTotal = 2) + + // task A continue reading from the last offset + 1, with upper bound 100 again + val consumer1a = KafkaDataConsumer.acquire(topicPartition, kafkaParams) + + consumer1a.get(lastOffsetForConsumer1 + 1, 100, 10000, failOnDataLoss = false) + consumer1a.release() + + // pool should succeed to provide cached data instead of creating one + assertFetchedDataPoolStatistic(fetchedDataPool, expectedNumCreated = 2, expectedNumTotal = 2) + + // task B also continue reading from the last offset + 1, with upper bound 500 again + val consumer2a = KafkaDataConsumer.acquire(topicPartition, kafkaParams) + + consumer2a.get(lastOffsetForConsumer2 + 1, 500, 10000, failOnDataLoss = false) + consumer2a.release() + + // same expectation: pool should succeed to provide cached data instead of creating one + assertFetchedDataPoolStatistic(fetchedDataPool, expectedNumCreated = 2, expectedNumTotal = 2) + } + } + + test("SPARK-25151 Handles multiple tasks in executor fetching same (topic, 
partition) pair " + + "and same offset (edge-case) - data in use") { + prepareTestTopicHavingTestMessages(topic) + val topicPartition = new TopicPartition(topic, 0) + + val kafkaParams = getKafkaParams() + + withTaskContext(TaskContext.empty()) { + // task A trying to fetch offset 0 to 100, and read 5 records (still reading) + val consumer1 = KafkaDataConsumer.acquire(topicPartition, kafkaParams) + val lastOffsetForConsumer1 = readAndGetLastOffset(consumer1, 0, 100, 5) + + assertFetchedDataPoolStatistic(fetchedDataPool, expectedNumCreated = 1, expectedNumTotal = 1) + + // task B trying to fetch offset the last offset task A is reading so far + 1 to 500 + // this is a condition for edge case + val consumer2 = KafkaDataConsumer.acquire(topicPartition, kafkaParams) + consumer2.get(lastOffsetForConsumer1 + 1, 100, 10000, failOnDataLoss = false) + + // Pool must create a new fetched data instead of returning existing on now in use even + // there's fetched data matching start offset. + assertFetchedDataPoolStatistic(fetchedDataPool, expectedNumCreated = 2, expectedNumTotal = 2) + + consumer1.release() + consumer2.release() + } + } + + test("SPARK-25151 Handles multiple tasks in executor fetching same (topic, partition) pair " + + "and same offset (edge-case) - data not in use") { + prepareTestTopicHavingTestMessages(topic) + val topicPartition = new TopicPartition(topic, 0) + + val kafkaParams = getKafkaParams() + + withTaskContext(TaskContext.empty()) { + // task A trying to fetch offset 0 to 100, and read 5 records (still reading) + val consumer1 = KafkaDataConsumer.acquire(topicPartition, kafkaParams) + val lastOffsetForConsumer1 = readAndGetLastOffset(consumer1, 0, 100, 5) + consumer1.release() + + assertFetchedDataPoolStatistic(fetchedDataPool, expectedNumCreated = 1, expectedNumTotal = 1) + + // task B trying to fetch offset the last offset task A is reading so far + 1 to 500 + // this is a condition for edge case + val consumer2 = KafkaDataConsumer.acquire(topicPartition, kafkaParams) + consumer2.get(lastOffsetForConsumer1 + 1, 100, 10000, failOnDataLoss = false) + + // Pool cannot determine the origin task, so it has to just provide matching one. + // task A may come back and try to fetch, and cannot find previous data + // (or the data is in use). + // If then task A may have to fetch from Kafka, but we already avoided fetching from Kafka in + // task B, so it is not a big deal in overall. 
+ assertFetchedDataPoolStatistic(fetchedDataPool, expectedNumCreated = 1, expectedNumTotal = 1) + + consumer2.release() + } + } + + private def assertFetchedDataPoolStatistic( + fetchedDataPool: FetchedDataPool, + expectedNumCreated: Long, + expectedNumTotal: Long): Unit = { + assert(fetchedDataPool.getNumCreated === expectedNumCreated) + assert(fetchedDataPool.getNumTotal === expectedNumTotal) + } + + private def readAndGetLastOffset( + consumer: KafkaDataConsumer, + startOffset: Long, + untilOffset: Long, + numToRead: Int): Long = { + var lastOffset: Long = startOffset - 1 + (0 until numToRead).foreach { _ => + val record = consumer.get(lastOffset + 1, untilOffset, 10000, failOnDataLoss = false) + // validation for fetched record is covered by other tests, so skip on validating + lastOffset = record.offset() + } + lastOffset + } + + private def prepareTestTopicHavingTestMessages(topic: String) = { + val data = (1 to 1000).map(_.toString) + testUtils.createTopic(topic, 1) + testUtils.sendMessages(topic, data.toArray) + data + } + + private def withTaskContext(context: TaskContext)(task: => Unit): Unit = { + try { + TaskContext.setTaskContext(context) + task + } finally { + TaskContext.unset() + } + } + } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 8663a5d8d26c..ae8a6886b2b4 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -1146,7 +1146,6 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { KafkaSourceOffset(Map(tp -> 100L))).map(_.asInstanceOf[KafkaBatchInputPartition]) withClue(s"minPartitions = $minPartitions generated factories $inputPartitions\n\t") { assert(inputPartitions.size == numPartitionsGenerated) - inputPartitions.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } } } } diff --git a/pom.xml b/pom.xml index 6a8424cc1328..17947451ea38 100644 --- a/pom.xml +++ b/pom.xml @@ -180,6 +180,8 @@ 2.6 3.8.1 + + 2.6.2 3.2.10 3.0.15 2.29 From 8cb52e39092e74a1db35ee909c5f93ccda9d1e55 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Tue, 27 Aug 2019 06:23:49 +0900 Subject: [PATCH 02/13] Address review comments --- .../spark/sql/kafka010/FetchedDataPool.scala | 64 ++++++++----------- .../kafka010/InternalKafkaConsumerPool.scala | 25 ++++---- .../sql/kafka010/KafkaDataConsumer.scala | 6 +- .../sql/kafka010/FetchedDataPoolSuite.scala | 10 +-- .../InternalKafkaConsumerPoolSuite.scala | 30 ++++----- .../sql/kafka010/KafkaDataConsumerSuite.scala | 12 ++-- 6 files changed, 70 insertions(+), 77 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/FetchedDataPool.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/FetchedDataPool.scala index a408c27d21f8..d58dfc222525 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/FetchedDataPool.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/FetchedDataPool.scala @@ -22,14 +22,13 @@ import java.util.concurrent.{ScheduledFuture, TimeUnit} import java.util.concurrent.atomic.LongAdder import scala.collection.mutable -import scala.util.control.NonFatal import org.apache.kafka.clients.consumer.ConsumerRecord import 
org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.sql.kafka010.KafkaDataConsumer.{CacheKey, UNKNOWN_OFFSET} -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{ThreadUtils, Utils} /** * Provides object pool for [[FetchedData]] which is grouped by [[CacheKey]]. @@ -46,12 +45,8 @@ private[kafka010] class FetchedDataPool extends Logging { private val (minEvictableIdleTimeMillis, evictorThreadRunIntervalMillis): (Long, Long) = { val conf = SparkEnv.get.conf - val minEvictIdleTime = conf.getLong(CONFIG_NAME_MIN_EVICTABLE_IDLE_TIME_MILLIS, - DEFAULT_VALUE_MIN_EVICTABLE_IDLE_TIME_MILLIS) - - val evictorThreadInterval = conf.getLong( - CONFIG_NAME_EVICTOR_THREAD_RUN_INTERVAL_MILLIS, - DEFAULT_VALUE_EVICTOR_THREAD_RUN_INTERVAL_MILLIS) + val minEvictIdleTime = conf.get(CONSUMER_CACHE_MIN_EVICTABLE_IDLE_TIME_MILLIS) + val evictorThreadInterval = conf.get(CONSUMER_CACHE_EVICTOR_THREAD_RUN_INTERVAL_MILLIS) (minEvictIdleTime, evictorThreadInterval) } @@ -62,12 +57,7 @@ private[kafka010] class FetchedDataPool extends Logging { private def startEvictorThread(): ScheduledFuture[_] = { executorService.scheduleAtFixedRate(new Runnable { override def run(): Unit = { - try { - removeIdleFetchedData() - } catch { - case NonFatal(e) => - logWarning("Exception occurred while removing idle fetched data.", e) - } + Utils.tryLogNonFatalError(removeIdleFetchedData()) } }, 0, evictorThreadRunIntervalMillis, TimeUnit.MILLISECONDS) } @@ -77,8 +67,8 @@ private[kafka010] class FetchedDataPool extends Logging { private val numCreatedFetchedData = new LongAdder() private val numTotalElements = new LongAdder() - def getNumCreated: Long = numCreatedFetchedData.sum() - def getNumTotal: Long = numTotalElements.sum() + def numCreated: Long = numCreatedFetchedData.sum() + def numTotal: Long = numTotalElements.sum() def acquire(key: CacheKey, desiredStartOffset: Long): FetchedData = synchronized { val fetchedDataList = cache.getOrElseUpdate(key, new CachedFetchedDataList()) @@ -112,25 +102,32 @@ private[kafka010] class FetchedDataPool extends Logging { } def release(key: CacheKey, fetchedData: FetchedData): Unit = synchronized { + def warnReleasedDataNotInPool(key: CacheKey, fetchedData: FetchedData): Unit = { + logWarning(s"No matching data in pool for $fetchedData in key $key. " + + "It might be released before, or it was not a part of pool.") + } + cache.get(key) match { case Some(fetchedDataList) => val cachedFetchedDataOption = fetchedDataList.find { p => p.inUse && p.getObject == fetchedData } - if (cachedFetchedDataOption.isDefined) { + if (cachedFetchedDataOption.isEmpty) { + warnReleasedDataNotInPool(key, fetchedData) + } else { val cachedFetchedData = cachedFetchedDataOption.get cachedFetchedData.inUse = false - cachedFetchedData.lastReleasedTimestamp = System.currentTimeMillis() + cachedFetchedData.lastReleasedTimestamp = System.nanoTime() } - case None => logWarning(s"No matching data in pool for $fetchedData in key $key. 
" + - "It might be released before, or it was not a part of pool.") + case None => + warnReleasedDataNotInPool(key, fetchedData) } } def shutdown(): Unit = { - executorService.shutdownNow() + ThreadUtils.shutdown(executorService) } def reset(): Unit = synchronized { @@ -144,13 +141,17 @@ private[kafka010] class FetchedDataPool extends Logging { } private def removeIdleFetchedData(): Unit = synchronized { - val timestamp = System.currentTimeMillis() - val maxAllowedIdleTimestamp = timestamp - minEvictableIdleTimeMillis + val timestamp = System.nanoTime() + val minEvictableIdleTimeNanos = TimeUnit.MILLISECONDS.toNanos(minEvictableIdleTimeMillis) + val maxAllowedIdleTimestamp = timestamp - minEvictableIdleTimeNanos cache.values.foreach { p: CachedFetchedDataList => - val idles = p.filter(q => !q.inUse && q.lastReleasedTimestamp < maxAllowedIdleTimestamp) - val lstSize = p.size - idles.foreach(idle => p -= idle) - numTotalElements.add(-1 * (lstSize - p.size)) + val expired = p.filter { + q => !q.inUse && q.lastReleasedTimestamp < maxAllowedIdleTimestamp + } + expired.foreach { + idle => p -= idle + } + numTotalElements.add(-1 * expired.size) } } } @@ -177,14 +178,5 @@ private[kafka010] object FetchedDataPool { private[kafka010] type CachedFetchedDataList = mutable.ListBuffer[CachedFetchedData] - val CONFIG_NAME_PREFIX = "spark.sql.kafkaFetchedDataCache." - val CONFIG_NAME_MIN_EVICTABLE_IDLE_TIME_MILLIS = CONFIG_NAME_PREFIX + - "minEvictableIdleTimeMillis" - val CONFIG_NAME_EVICTOR_THREAD_RUN_INTERVAL_MILLIS = CONFIG_NAME_PREFIX + - "evictorThreadRunIntervalMillis" - - val DEFAULT_VALUE_MIN_EVICTABLE_IDLE_TIME_MILLIS = 10 * 60 * 1000 // 10 minutes - val DEFAULT_VALUE_EVICTOR_THREAD_RUN_INTERVAL_MILLIS = 5 * 60 * 1000 // 3 minutes - def build: FetchedDataPool = new FetchedDataPool() } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPool.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPool.scala index f268508a7c61..fea4831333b7 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPool.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPool.scala @@ -52,7 +52,7 @@ private[kafka010] class InternalKafkaConsumerPool( // the class is intended to have only soft capacity assert(poolConfig.getMaxTotal < 0) - private lazy val pool = { + private val pool = { val internalPool = new GenericKeyedObjectPool[CacheKey, InternalKafkaConsumer]( objectFactory, poolConfig) internalPool.setSwallowedExceptionListener(CustomSwallowedExceptionListener) @@ -72,7 +72,7 @@ private[kafka010] class InternalKafkaConsumerPool( def borrowObject(key: CacheKey, kafkaParams: ju.Map[String, Object]): InternalKafkaConsumer = { updateKafkaParamForKey(key, kafkaParams) - if (getTotal == poolConfig.getSoftMaxTotal()) { + if (size == poolConfig.softMaxSize()) { pool.clearOldest() } @@ -111,17 +111,17 @@ private[kafka010] class InternalKafkaConsumerPool( pool.clear() } - def getNumIdle: Int = pool.getNumIdle + def numIdle: Int = pool.getNumIdle - def getNumIdle(key: CacheKey): Int = pool.getNumIdle(key) + def numIdle(key: CacheKey): Int = pool.getNumIdle(key) - def getNumActive: Int = pool.getNumActive + def numActive: Int = pool.getNumActive - def getNumActive(key: CacheKey): Int = pool.getNumActive(key) + def numActive(key: CacheKey): Int = pool.getNumActive(key) - def getTotal: Int = getNumIdle + getNumActive + def size: Int = 
numIdle + numActive - def getTotal(key: CacheKey): Int = getNumIdle(key) + getNumActive(key) + def size(key: CacheKey): Int = numIdle(key) + numActive(key) private def updateKafkaParamForKey(key: CacheKey, kafkaParams: ju.Map[String, Object]): Unit = { // We can assume that kafkaParam should not be different for same cache key, @@ -155,16 +155,16 @@ private[kafka010] object InternalKafkaConsumerPool { } class PoolConfig extends GenericKeyedObjectPoolConfig[InternalKafkaConsumer] { - private var softMaxTotal = Int.MaxValue + private var _softMaxSize = Int.MaxValue - def getSoftMaxTotal(): Int = softMaxTotal + def softMaxSize(): Int = _softMaxSize init() def init(): Unit = { val conf = SparkEnv.get.conf - softMaxTotal = conf.get(CONSUMER_CACHE_CAPACITY) + _softMaxSize = conf.get(CONSUMER_CACHE_CAPACITY) val jmxEnabled = conf.get(CONSUMER_CACHE_JMX_ENABLED) val minEvictableIdleTimeMillis = conf.get(CONSUMER_CACHE_MIN_EVICTABLE_IDLE_TIME_MILLIS) @@ -203,8 +203,7 @@ private[kafka010] object InternalKafkaConsumerPool { class ObjectFactory extends BaseKeyedPooledObjectFactory[CacheKey, InternalKafkaConsumer] with Logging { - val keyToKafkaParams: ConcurrentHashMap[CacheKey, ju.Map[String, Object]] = - new ConcurrentHashMap[CacheKey, ju.Map[String, Object]]() + val keyToKafkaParams = new ConcurrentHashMap[CacheKey, ju.Map[String, Object]]() override def create(key: CacheKey): InternalKafkaConsumer = { Option(keyToKafkaParams.get(key)) match { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index 9e98ae562937..0cba9f7e59eb 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -46,7 +46,7 @@ private[kafka010] class InternalKafkaConsumer( val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - private val consumer = createConsumer + private val consumer = createConsumer() /** * Poll messages from Kafka starting from `offset` and returns a pair of "list of consumer record" @@ -59,7 +59,7 @@ private[kafka010] class InternalKafkaConsumer( * consumer polls nothing before timeout. */ def fetch(offset: Long, pollTimeoutMs: Long): - (ju.List[ConsumerRecord[Array[Byte], Array[Byte]]], Long) = { + (ju.List[ConsumerRecord[Array[Byte], Array[Byte]]], Long) = { // Seek to the offset because we may call seekToBeginning or seekToEnd before this. 
seek(offset) @@ -105,7 +105,7 @@ private[kafka010] class InternalKafkaConsumer( } /** Create a KafkaConsumer to fetch records for `topicPartition` */ - private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = { + private def createConsumer(): KafkaConsumer[Array[Byte], Array[Byte]] = { val updatedKafkaParams = KafkaConfigUpdater("executor", kafkaParams.asScala.toMap) .setAuthenticationConfigIfNeeded() .build() diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/FetchedDataPoolSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/FetchedDataPoolSuite.scala index ad3975d673f3..c80b0adb1378 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/FetchedDataPoolSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/FetchedDataPoolSuite.scala @@ -203,15 +203,15 @@ class FetchedDataPoolSuite extends SharedSparkSession with PrivateMethodTester { } test("evict idle fetched data") { - import FetchedDataPool._ import org.scalatest.time.SpanSugar._ val minEvictableIdleTimeMillis = 1000 val evictorThreadRunIntervalMillis = 500 val newConf = Seq( - CONFIG_NAME_MIN_EVICTABLE_IDLE_TIME_MILLIS -> minEvictableIdleTimeMillis.toString, - CONFIG_NAME_EVICTOR_THREAD_RUN_INTERVAL_MILLIS -> evictorThreadRunIntervalMillis.toString) + CONSUMER_CACHE_MIN_EVICTABLE_IDLE_TIME_MILLIS.key -> minEvictableIdleTimeMillis.toString, + CONSUMER_CACHE_EVICTOR_THREAD_RUN_INTERVAL_MILLIS.key -> + evictorThreadRunIntervalMillis.toString) withSparkConf(newConf: _*) { val dataPool = FetchedDataPool.build @@ -331,7 +331,7 @@ class FetchedDataPoolSuite extends SharedSparkSession with PrivateMethodTester { fetchedDataPool: FetchedDataPool, expectedNumCreated: Long, expectedNumTotal: Long): Unit = { - assert(fetchedDataPool.getNumCreated === expectedNumCreated) - assert(fetchedDataPool.getNumTotal === expectedNumTotal) + assert(fetchedDataPool.numCreated === expectedNumCreated) + assert(fetchedDataPool.numTotal === expectedNumTotal) } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPoolSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPoolSuite.scala index 7aa13b7042e3..497c784c238e 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPoolSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPoolSuite.scala @@ -125,11 +125,11 @@ class InternalKafkaConsumerPoolSuite extends SharedSparkSession { assertPoolStateForKey(pool, newCacheKey, numIdle = 0, numActive = 1, numTotal = 1) // at least one of idle object should be freed up - assert(pool.getNumIdle < numToReturn) + assert(pool.numIdle < numToReturn) // we can determine number of active objects correctly - assert(pool.getNumActive === keyToPooledObjectPairs.length - numToReturn + 1) + assert(pool.numActive === keyToPooledObjectPairs.length - numToReturn + 1) // total objects should be more than number of active + 1 but can't expect exact number - assert(pool.getTotal > keyToPooledObjectPairs.length - numToReturn + 1) + assert(pool.size > keyToPooledObjectPairs.length - numToReturn + 1) } } @@ -231,9 +231,9 @@ class InternalKafkaConsumerPoolSuite extends SharedSparkSession { numIdle: Int, numActive: Int, numTotal: Int): Unit = { - assert(pool.getNumIdle === numIdle) - assert(pool.getNumActive === numActive) - assert(pool.getTotal === numTotal) + 
assert(pool.numIdle === numIdle) + assert(pool.numActive === numActive) + assert(pool.size === numTotal) } private def assertPoolStateForKey( @@ -242,9 +242,9 @@ class InternalKafkaConsumerPoolSuite extends SharedSparkSession { numIdle: Int, numActive: Int, numTotal: Int): Unit = { - assert(pool.getNumIdle(key) === numIdle) - assert(pool.getNumActive(key) === numActive) - assert(pool.getTotal(key) === numTotal) + assert(pool.numIdle(key) === numIdle) + assert(pool.numActive(key) === numActive) + assert(pool.size(key) === numTotal) } private def getTestKafkaParams: ju.Map[String, Object] = Map[String, Object]( @@ -261,9 +261,9 @@ class InternalKafkaConsumerPoolSuite extends SharedSparkSession { kafkaParams: ju.Map[String, Object], keys: List[CacheKey]): Seq[(CacheKey, InternalKafkaConsumer)] = { keys.map { key => - val numActiveBeforeBorrowing = pool.getNumActive - val numIdleBeforeBorrowing = pool.getNumIdle - val numTotalBeforeBorrowing = pool.getTotal + val numActiveBeforeBorrowing = pool.numActive + val numIdleBeforeBorrowing = pool.numIdle + val numTotalBeforeBorrowing = pool.size val pooledObj = pool.borrowObject(key, kafkaParams) @@ -279,9 +279,9 @@ class InternalKafkaConsumerPoolSuite extends SharedSparkSession { pool: InternalKafkaConsumerPool, objects: Seq[(CacheKey, InternalKafkaConsumer)]): Unit = { objects.foreach { case (key, pooledObj) => - val numActiveBeforeReturning = pool.getNumActive - val numIdleBeforeReturning = pool.getNumIdle - val numTotalBeforeReturning = pool.getTotal + val numActiveBeforeReturning = pool.numActive + val numIdleBeforeReturning = pool.numIdle + val numTotalBeforeReturning = pool.size pool.returnObject(pooledObj) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala index 3c89f5f7efd6..80f9a1b410d2 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala @@ -105,7 +105,8 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester consumer1.release() - assert(consumerPool.getTotal(key) === 1) + assert(consumerPool.size(key) === 1) + // check whether acquired object is available in pool val pooledObj = consumerPool.borrowObject(key, kafkaParams) assert(consumer1Underlying.get.eq(pooledObj)) consumerPool.returnObject(pooledObj) @@ -124,8 +125,9 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester consumer2.release() - // The first consumer should be removed from cache, but second consumer should be cached. - assert(consumerPool.getTotal(key) === 1) + // The first consumer should be removed from cache, but the consumer after invalidate + // should be cached. 
+ assert(consumerPool.size(key) === 1) val pooledObj2 = consumerPool.borrowObject(key, kafkaParams) assert(consumer2Underlying.get.eq(pooledObj2)) consumerPool.returnObject(pooledObj2) @@ -286,8 +288,8 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester fetchedDataPool: FetchedDataPool, expectedNumCreated: Long, expectedNumTotal: Long): Unit = { - assert(fetchedDataPool.getNumCreated === expectedNumCreated) - assert(fetchedDataPool.getNumTotal === expectedNumTotal) + assert(fetchedDataPool.numCreated === expectedNumCreated) + assert(fetchedDataPool.numTotal === expectedNumTotal) } private def readAndGetLastOffset( From 9543745bc41b2adf258a8a91e3c5aab59c5d0cbd Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Tue, 27 Aug 2019 14:19:18 +0900 Subject: [PATCH 03/13] Modify eviction UT to leverage manual scheduler and clock --- external/kafka-0-10-sql/pom.xml | 5 ++ .../spark/sql/kafka010/FetchedDataPool.scala | 36 ++++++----- .../sql/kafka010/FetchedDataPoolSuite.scala | 60 ++++++++++++++----- 3 files changed, 72 insertions(+), 29 deletions(-) diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index feba787e9901..5b8738263a60 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -130,6 +130,11 @@ org.apache.spark spark-tags_${scala.binary.version} + + org.jmock + jmock-junit4 + test +
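For readers skimming this series, here is a minimal sketch of the executor-side usage pattern after these patches: acquire no longer takes a useCache flag, reuse is decided entirely by the consumer and fetched-data pools, and release() hands both the consumer and its fetched data back to the pools. The group id, bootstrap servers, topic name, and the small read loop below are illustrative values only, not taken from the patch.

    import java.{util => ju}

    import scala.collection.JavaConverters._

    import org.apache.kafka.clients.consumer.ConsumerConfig._
    import org.apache.kafka.common.TopicPartition
    import org.apache.kafka.common.serialization.ByteArrayDeserializer

    // Illustrative executor-side Kafka params; a real query generates its own group id.
    val kafkaParams: ju.Map[String, Object] = Map[String, Object](
      GROUP_ID_CONFIG -> "example-group",
      BOOTSTRAP_SERVERS_CONFIG -> "localhost:9092",
      KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
      VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName).asJava

    // Borrow a pooled consumer for the partition; the pool creates one if none is cached.
    val consumer = KafkaDataConsumer.acquire(new TopicPartition("topic", 0), kafkaParams)
    try {
      val range = consumer.getAvailableOffsetRange()
      var offset = range.earliest
      // Read a handful of records; get() reuses cached fetched data when offsets line up.
      (0 until 5).foreach { _ =>
        val record = consumer.get(offset, range.latest, 10000, failOnDataLoss = false)
        offset = record.offset() + 1
      }
    } finally {
      consumer.release() // returns the consumer and its fetched data to the pools
    }

The jmock-junit4 test dependency added just above presumably provides org.jmock.lib.concurrent.DeterministicScheduler, so the reworked eviction test can drive the evictor thread and clock manually instead of relying on real sleeps.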