KafkaWriteTask.scala
@@ -25,9 +25,9 @@ import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord}
 import org.apache.kafka.common.header.Header
 import org.apache.kafka.common.header.internals.RecordHeader
 
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
-import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, UnsafeProjection}
+import org.apache.spark.sql.types.BinaryType
 
 /**
  * Writes out data in a single Spark task, without any concerns about how
@@ -116,66 +116,13 @@ private[kafka010] abstract class KafkaRowWriter(
   }
 
   private def createProjection = {
-    val topicExpression = topic.map(Literal(_)).orElse {
-      inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME)
-    }.getOrElse {
-      throw new IllegalStateException(s"topic option required when no " +
-        s"'${KafkaWriter.TOPIC_ATTRIBUTE_NAME}' attribute is present")
-    }
-    topicExpression.dataType match {
-      case StringType => // good
-      case t =>
-        throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " +
-          s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " +
-          s"must be a ${StringType.catalogString}")
-    }
-    val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME)
-      .getOrElse(Literal(null, BinaryType))
-    keyExpression.dataType match {
-      case StringType | BinaryType => // good
-      case t =>
-        throw new IllegalStateException(s"${KafkaWriter.KEY_ATTRIBUTE_NAME} " +
-          s"attribute unsupported type ${t.catalogString}")
-    }
-    val valueExpression = inputSchema
-      .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse(
-        throw new IllegalStateException("Required attribute " +
-          s"'${KafkaWriter.VALUE_ATTRIBUTE_NAME}' not found")
-      )
-    valueExpression.dataType match {
-      case StringType | BinaryType => // good
-      case t =>
-        throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " +
-          s"attribute unsupported type ${t.catalogString}")
-    }
-    val headersExpression = inputSchema
-      .find(_.name == KafkaWriter.HEADERS_ATTRIBUTE_NAME).getOrElse(
-        Literal(CatalystTypeConverters.convertToCatalyst(null),
-          KafkaRecordToRowConverter.headersType)
-      )
-    headersExpression.dataType match {
-      case KafkaRecordToRowConverter.headersType => // good
-      case t =>
-        throw new IllegalStateException(s"${KafkaWriter.HEADERS_ATTRIBUTE_NAME} " +
-          s"attribute unsupported type ${t.catalogString}")
-    }
-    val partitionExpression =
-      inputSchema.find(_.name == KafkaWriter.PARTITION_ATTRIBUTE_NAME)
-        .getOrElse(Literal(null, IntegerType))
-    partitionExpression.dataType match {
-      case IntegerType => // good
-      case t =>
-        throw new IllegalStateException(s"${KafkaWriter.PARTITION_ATTRIBUTE_NAME} " +
-          s"attribute unsupported type $t. ${KafkaWriter.PARTITION_ATTRIBUTE_NAME} " +
-          s"must be a ${IntegerType.catalogString}")
-    }
     UnsafeProjection.create(
       Seq(
-        topicExpression,
-        Cast(keyExpression, BinaryType),
-        Cast(valueExpression, BinaryType),
-        headersExpression,
-        partitionExpression
+        KafkaWriter.topicExpression(inputSchema, topic),
+        Cast(KafkaWriter.keyExpression(inputSchema), BinaryType),
+        Cast(KafkaWriter.valueExpression(inputSchema), BinaryType),
+        KafkaWriter.headersExpression(inputSchema),
+        KafkaWriter.partitionExpression(inputSchema)
       ),
       inputSchema
     )
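
For context (commentary, not part of the patch): UnsafeProjection.create compiles these five expressions into one projection that rewrites each incoming row into the (topic, key, value, headers, partition) layout the producer path consumes, casting key and value to binary. A minimal stand-alone sketch of that mechanism, with an invented one-column schema:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, UnsafeProjection}
import org.apache.spark.sql.types.{BinaryType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// One string-typed "value" column, as the writer might see it.
val value = AttributeReference("value", StringType)()
val inputSchema = Seq(value)

// Mirror of the value leg above: cast the string column to binary.
val projection = UnsafeProjection.create(Seq(Cast(value, BinaryType)), inputSchema)
val projected = projection(InternalRow(UTF8String.fromString("hello")))
projected.getBinary(0) // the UTF-8 bytes of "hello"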
KafkaWriter.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.types.{BinaryType, IntegerType, MapType, StringType}
+import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, StringType}
 import org.apache.spark.util.Utils
 
 /**
@@ -49,51 +49,14 @@ private[kafka010] object KafkaWriter extends Logging {
       schema: Seq[Attribute],
       kafkaParameters: ju.Map[String, Object],
       topic: Option[String] = None): Unit = {
-    schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse(
-      if (topic.isEmpty) {
-        throw new AnalysisException(s"topic option required when no " +
-          s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " +
-          s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.")
-      } else {
-        Literal.create(topic.get, StringType)
-      }
-    ).dataType match {
-      case StringType => // good
-      case _ =>
-        throw new AnalysisException(s"Topic type must be a ${StringType.catalogString}")
-    }
-    schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse(
-      Literal(null, StringType)
-    ).dataType match {
-      case StringType | BinaryType => // good
-      case _ =>
-        throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " +
-          s"must be a ${StringType.catalogString} or ${BinaryType.catalogString}")
-    }
-    schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse(
-      throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found")
-    ).dataType match {
-      case StringType | BinaryType => // good
-      case _ =>
-        throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " +
-          s"must be a ${StringType.catalogString} or ${BinaryType.catalogString}")
-    }
-    schema.find(_.name == HEADERS_ATTRIBUTE_NAME).getOrElse(
-      Literal(CatalystTypeConverters.convertToCatalyst(null),
-        KafkaRecordToRowConverter.headersType)
-    ).dataType match {
-      case KafkaRecordToRowConverter.headersType => // good
-      case _ =>
-        throw new AnalysisException(s"$HEADERS_ATTRIBUTE_NAME attribute type " +
-          s"must be a ${KafkaRecordToRowConverter.headersType.catalogString}")
-    }
-    schema.find(_.name == PARTITION_ATTRIBUTE_NAME).getOrElse(
-      Literal(null, IntegerType)
-    ).dataType match {
-      case IntegerType => // good
-      case _ =>
-        throw new AnalysisException(s"$PARTITION_ATTRIBUTE_NAME attribute type " +
-          s"must be an ${IntegerType.catalogString}")
+    try {
+      topicExpression(schema, topic)
+      keyExpression(schema)
+      valueExpression(schema)
+      headersExpression(schema)
+      partitionExpression(schema)
+    } catch {
+      case e: IllegalStateException => throw new AnalysisException(e.getMessage)
     }
   }
 
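
A note on the rewrite above (commentary, not part of the patch): the shared helpers throw IllegalStateException, and the analysis-time entry point rethrows it as the AnalysisException users expect, preserving the message. A minimal sketch of that pattern follows; the ValidationSketch object and asAnalysisError name are invented here, and since AnalysisException's constructor is protected[sql], the sketch must live in a package under org.apache.spark.sql, as KafkaWriter itself does:

package org.apache.spark.sql.kafka010

import org.apache.spark.sql.AnalysisException

object ValidationSketch {
  // Run helper-level validation, resurfacing failures as analysis errors
  // with the same message, as the try/catch above does.
  def asAnalysisError[T](block: => T): T =
    try block
    catch {
      case e: IllegalStateException => throw new AnalysisException(e.getMessage)
    }
}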
@@ -110,4 +73,53 @@
         finallyBlock = writeTask.close())
     }
   }
+
+  def topicExpression(schema: Seq[Attribute], topic: Option[String] = None): Expression = {
+    topic.map(Literal(_)).getOrElse(
+      expression(schema, TOPIC_ATTRIBUTE_NAME, Seq(StringType)) {
+        throw new IllegalStateException(s"topic option required when no " +
+          s"'${TOPIC_ATTRIBUTE_NAME}' attribute is present. Use the " +
+          s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.")
+      }
+    )
+  }
+
+  def keyExpression(schema: Seq[Attribute]): Expression = {
+    expression(schema, KEY_ATTRIBUTE_NAME, Seq(StringType, BinaryType)) {
+      Literal(null, BinaryType)
+    }
+  }
+
+  def valueExpression(schema: Seq[Attribute]): Expression = {
+    expression(schema, VALUE_ATTRIBUTE_NAME, Seq(StringType, BinaryType)) {
+      throw new IllegalStateException(s"Required attribute '${VALUE_ATTRIBUTE_NAME}' not found")
+    }
+  }
+
+  def headersExpression(schema: Seq[Attribute]): Expression = {
+    expression(schema, HEADERS_ATTRIBUTE_NAME, Seq(KafkaRecordToRowConverter.headersType)) {
+      Literal(CatalystTypeConverters.convertToCatalyst(null),
+        KafkaRecordToRowConverter.headersType)
+    }
+  }
+
+  def partitionExpression(schema: Seq[Attribute]): Expression = {
+    expression(schema, PARTITION_ATTRIBUTE_NAME, Seq(IntegerType)) {
+      Literal(null, IntegerType)
+    }
+  }
+
+  private def expression(
+      schema: Seq[Attribute],
+      attrName: String,
+      desired: Seq[DataType])(
+      default: => Expression): Expression = {
+    val expr = schema.find(_.name == attrName).getOrElse(default)
+    if (!desired.exists(_.sameType(expr.dataType))) {
+      throw new IllegalStateException(s"$attrName attribute unsupported type " +
+        s"${expr.dataType.catalogString}. $attrName must be a(n) " +
+        s"${desired.map(_.catalogString).mkString(" or ")}")
+    }
+    expr
+  }
 }
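
Taken together (commentary, not part of the patch), each new helper is one lookup-or-default plus type check over the input schema. A hedged usage sketch against an invented two-column schema; the result comments follow from the definitions above:

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.{BinaryType, StringType}

// A frame with key and value columns but no topic, partition, or headers column.
val schema = Seq(
  AttributeReference("key", StringType)(),
  AttributeReference("value", BinaryType)())

KafkaWriter.topicExpression(schema, Some("events")) // Literal("events") from the topic option
KafkaWriter.keyExpression(schema)                   // the string-typed key attribute
KafkaWriter.partitionExpression(schema)             // Literal(null, IntegerType) default
KafkaWriter.topicExpression(schema)                 // throws IllegalStateException: no topic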