@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.kafka010

import java.{util => ju}
import java.util.concurrent.{ConcurrentMap, ExecutionException, TimeUnit}

import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import com.google.common.cache._
import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
import org.apache.kafka.clients.producer.KafkaProducer

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging

private[kafka010] object CachedKafkaProducer extends Logging {

  private type Producer = KafkaProducer[Array[Byte], Array[Byte]]

  private lazy val cacheExpireTimeout: Long =
    SparkEnv.get.conf.getTimeAsMs("spark.kafka.producer.cache.timeout", "10m")

  private val cacheLoader = new CacheLoader[Seq[(String, Object)], Producer] {
    override def load(config: Seq[(String, Object)]): Producer = {
      val configMap = config.map(x => x._1 -> x._2).toMap.asJava
      createKafkaProducer(configMap)
    }
  }

  private val removalListener = new RemovalListener[Seq[(String, Object)], Producer]() {
    override def onRemoval(
        notification: RemovalNotification[Seq[(String, Object)], Producer]): Unit = {
      val paramsSeq: Seq[(String, Object)] = notification.getKey
      val producer: Producer = notification.getValue
      logDebug(
        s"Evicting kafka producer $producer params: $paramsSeq, due to ${notification.getCause}")
      close(paramsSeq, producer)
    }
  }

  private lazy val guavaCache: LoadingCache[Seq[(String, Object)], Producer] =
    CacheBuilder.newBuilder().expireAfterAccess(cacheExpireTimeout, TimeUnit.MILLISECONDS)
      .removalListener(removalListener)
      .build[Seq[(String, Object)], Producer](cacheLoader)

  private def createKafkaProducer(producerConfiguration: ju.Map[String, Object]): Producer = {
    val kafkaProducer: Producer = new Producer(producerConfiguration)
    logDebug(s"Created a new instance of KafkaProducer for $producerConfiguration.")
    kafkaProducer
  }

  /**
   * Get a cached KafkaProducer for a given configuration. If a matching KafkaProducer doesn't
   * exist, a new one is created. KafkaProducer is thread-safe, so it is best to keep one
   * instance per set of kafkaParams.
   */
  private[kafka010] def getOrCreate(kafkaParams: ju.Map[String, Object]): Producer = {
    val paramsSeq: Seq[(String, Object)] = paramsToSeq(kafkaParams)
    try {
      guavaCache.get(paramsSeq)
    } catch {
      case e @ (_: ExecutionException | _: UncheckedExecutionException | _: ExecutionError)
        if e.getCause != null =>
        throw e.getCause
    }
  }

  private def paramsToSeq(kafkaParams: ju.Map[String, Object]): Seq[(String, Object)] = {
    val paramsSeq: Seq[(String, Object)] = kafkaParams.asScala.toSeq.sortBy(x => x._1)
    paramsSeq
  }

  /** For explicitly closing a Kafka producer. */
  private[kafka010] def close(kafkaParams: ju.Map[String, Object]): Unit = {
    val paramsSeq = paramsToSeq(kafkaParams)
    guavaCache.invalidate(paramsSeq)
  }

  /** Auto close on cache evict. */
  private def close(paramsSeq: Seq[(String, Object)], producer: Producer): Unit = {
    try {
      logInfo(s"Closing the KafkaProducer with params: ${paramsSeq.mkString("\n")}.")
      producer.close()
    } catch {
      case NonFatal(e) => logWarning("Error while closing kafka producer.", e)
    }
  }

  private def clear(): Unit = {
    logInfo("Cleaning up guava cache.")
    guavaCache.invalidateAll()
  }

  // Intended for testing purpose only.
  private def getAsMap: ConcurrentMap[Seq[(String, Object)], Producer] = guavaCache.asMap()
}
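
A minimal usage sketch (not part of this patch) of the cache above. Because getOrCreate and close are private[kafka010], the snippet assumes it lives in the same package, and it assumes a running SparkEnv since the cache timeout is read from the Spark conf. The object name, broker address, and topic are placeholders.

package org.apache.spark.sql.kafka010

import java.{util => ju}

import org.apache.kafka.clients.producer.ProducerRecord

import org.apache.spark.sql.SparkSession

object CachedKafkaProducerUsageSketch {
  def main(args: Array[String]): Unit = {
    // A local SparkSession guarantees SparkEnv.get is available for the cache timeout conf.
    val spark = SparkSession.builder()
      .master("local[2]").appName("producer-cache-sketch").getOrCreate()

    val kafkaParams = new ju.HashMap[String, Object]()
    kafkaParams.put("bootstrap.servers", "localhost:9092") // placeholder broker address
    kafkaParams.put("key.serializer",
      "org.apache.kafka.common.serialization.ByteArraySerializer")
    kafkaParams.put("value.serializer",
      "org.apache.kafka.common.serialization.ByteArraySerializer")

    // Equal params map to the same cached, thread-safe producer instance.
    val producer = CachedKafkaProducer.getOrCreate(kafkaParams)
    producer.send(new ProducerRecord[Array[Byte], Array[Byte]](
      "some-topic", "hello".getBytes("UTF-8")))
    producer.flush()

    // Explicit invalidation; the removal listener closes the underlying KafkaProducer.
    CachedKafkaProducer.close(kafkaParams)
    spark.stop()
  }
}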
@@ -70,13 +70,13 @@ import org.apache.spark.unsafe.types.UTF8String
* and not use wrong broker addresses.
*/
private[kafka010] class KafkaSource(
    sqlContext: SQLContext,
    kafkaReader: KafkaOffsetReader,
    executorKafkaParams: ju.Map[String, Object],
    sourceOptions: Map[String, String],
    metadataPath: String,
    startingOffsets: KafkaOffsetRangeLimit,
    failOnDataLoss: Boolean)
  extends Source with Logging {

private val sc = sqlContext.sparkContext
@@ -19,8 +19,7 @@ package org.apache.spark.sql.kafka010

import java.{util => ju}

import org.apache.kafka.clients.producer.{KafkaProducer, _}
import org.apache.kafka.common.serialization.ByteArraySerializer
import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
@@ -44,7 +43,7 @@ private[kafka010] class KafkaWriteTask(
* Writes key value data out to topics.
*/
def execute(iterator: Iterator[InternalRow]): Unit = {
producer = new KafkaProducer[Array[Byte], Array[Byte]](producerConfiguration)
producer = CachedKafkaProducer.getOrCreate(producerConfiguration)
while (iterator.hasNext && failedWrite == null) {
val currentRow = iterator.next()
val projectedRow = projection(currentRow)
@@ -68,10 +67,10 @@ }
}

def close(): Unit = {
checkForErrors()
if (producer != null) {
checkForErrors
producer.close()
checkForErrors
producer.flush()
checkForErrors()
producer = null
}
}
@@ -88,7 +87,7 @@ private[kafka010] class KafkaWriteTask(
case t =>
throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " +
s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " +
s"must be a ${StringType}")
"must be a StringType")
}
val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME)
.getOrElse(Literal(null, BinaryType))
@@ -100,7 +99,7 @@ }
}
val valueExpression = inputSchema
.find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse(
throw new IllegalStateException(s"Required attribute " +
throw new IllegalStateException("Required attribute " +
s"'${KafkaWriter.VALUE_ATTRIBUTE_NAME}' not found")
)
valueExpression.dataType match {
@@ -114,7 +113,7 @@
Cast(valueExpression, BinaryType)), inputSchema)
}

private def checkForErrors: Unit = {
private def checkForErrors(): Unit = {
if (failedWrite != null) {
throw failedWrite
}
@@ -21,7 +21,6 @@ import java.{util => ju}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
import org.apache.spark.sql.types.{BinaryType, StringType}
@@ -49,7 +48,7 @@ private[kafka010] object KafkaWriter extends Logging {
topic: Option[String] = None): Unit = {
val schema = queryExecution.analyzed.output
schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse(
if (topic == None) {
if (topic.isEmpty) {
throw new AnalysisException(s"topic option required when no " +
s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " +
s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.")
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.kafka010

import java.{util => ju}
import java.util.concurrent.ConcurrentMap

import org.apache.kafka.clients.producer.KafkaProducer
import org.apache.kafka.common.serialization.ByteArraySerializer
import org.scalatest.PrivateMethodTester

import org.apache.spark.sql.test.SharedSQLContext

class CachedKafkaProducerSuite extends SharedSQLContext with PrivateMethodTester {

  type KP = KafkaProducer[Array[Byte], Array[Byte]]

  protected override def beforeEach(): Unit = {
    super.beforeEach()
    val clear = PrivateMethod[Unit]('clear)
    CachedKafkaProducer.invokePrivate(clear())
  }

  test("Should return the cached instance on calling getOrCreate with same params.") {
    val kafkaParams = new ju.HashMap[String, Object]()
    kafkaParams.put("acks", "0")
    // Only the host needs to be resolvable; a running Kafka broker is not required.
    kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
    kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
    kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
    val producer = CachedKafkaProducer.getOrCreate(kafkaParams)
    val producer2 = CachedKafkaProducer.getOrCreate(kafkaParams)
    assert(producer == producer2)

    val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]]('getAsMap)
    val map = CachedKafkaProducer.invokePrivate(cacheMap())
    assert(map.size == 1)
  }

  test("Should close the correct kafka producer for the given kafkaParams.") {
    val kafkaParams = new ju.HashMap[String, Object]()
    kafkaParams.put("acks", "0")
    kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
    kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
    kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
    val producer: KP = CachedKafkaProducer.getOrCreate(kafkaParams)
    kafkaParams.put("acks", "1")
    val producer2: KP = CachedKafkaProducer.getOrCreate(kafkaParams)
    // With updated conf, a new producer instance should be created.
    assert(producer != producer2)

    val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]]('getAsMap)
    val map = CachedKafkaProducer.invokePrivate(cacheMap())
    assert(map.size == 2)

    CachedKafkaProducer.close(kafkaParams)
    val map2 = CachedKafkaProducer.invokePrivate(cacheMap())
    assert(map2.size == 1)
Reviewer comment (Member): We just know there is one KP by this assert. Seems we should also verify if we close the correct KP?

    // Verify that the remaining cached entry is the first producer, i.e. the right one was closed.
    import scala.collection.JavaConverters._
    val (seq: Seq[(String, Object)], _producer: KP) = map2.asScala.toArray.apply(0)
    assert(_producer == producer)
  }
}