@@ -28,12 +28,13 @@ import scala.collection.JavaConverters._
 import scala.io.Source
 import scala.util.Random
 
+import org.apache.commons.io.FileUtils
 import org.apache.kafka.clients.producer.{ProducerRecord, RecordMetadata}
 import org.apache.kafka.common.TopicPartition
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
 import org.scalatest.time.SpanSugar._
 
-import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession}
+import org.apache.spark.sql.{Dataset, ForeachWriter, Row, SparkSession}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.connector.read.streaming.SparkDataStream
 import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
@@ -47,6 +48,7 @@ import org.apache.spark.sql.streaming.{StreamTest, Trigger}
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.util.Utils
 
 abstract class KafkaSourceTest extends StreamTest with SharedSparkSession with KafkaTest {
 
@@ -1162,6 +1164,64 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
     intercept[IllegalArgumentException] { test(minPartitions = "-1", 1, true) }
   }
 
+  test("default config of includeHeaders doesn't break existing query from Spark 2.4") {
+    import testImplicits._
+
+    // This topic name was migrated from a Spark 2.4.3 test run.
+    val topic = "spark-test-topic-2b8619f5-d3c4-4c2d-b5d1-8d9d9458aa62"
+    // Create the same topic and messages as the original test run.
+    testUtils.createTopic(topic, partitions = 5, overwrite = true)
+    testUtils.sendMessages(topic, Array(-20, -21, -22).map(_.toString), Some(0))
+    testUtils.sendMessages(topic, Array(-10, -11, -12).map(_.toString), Some(1))
+    testUtils.sendMessages(topic, Array(0, 1, 2).map(_.toString), Some(2))
+    testUtils.sendMessages(topic, Array(10, 11, 12).map(_.toString), Some(3))
+    testUtils.sendMessages(topic, Array(20, 21, 22).map(_.toString), Some(4))
+    require(testUtils.getLatestOffsets(Set(topic)).size === 5)
+    (31 to 35).map { num =>
+      (num - 31, (num.toString, Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))))
+    }.foreach { rec => testUtils.sendMessage(topic, rec._2, Some(rec._1)) }
+
+    val kafka = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("kafka.metadata.max.age.ms", "1")
+      .option("subscribePattern", topic)
+      .option("startingOffsets", "earliest")
+      .load()
+
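+    // The query shape must match the one that wrote the 2.4.3 checkpoint loaded below.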
+    val query = kafka.dropDuplicates()
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+      .map(kv => kv._2.toInt + 1)
+
+    val resourceUri = this.getClass.getResource(
+      "/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/").toURI
+
+    val checkpointDir = Utils.createTempDir().getCanonicalFile
+    // Copy the checkpoint to a temp dir to prevent changes to the original.
+    // Without the copy, the test would pass on the first run but fail on subsequent runs.
+    FileUtils.copyDirectory(new File(resourceUri), checkpointDir)
+
+    testStream(query)(
+      StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+      /*
+        Note: The checkpoint was generated using the following input in Spark version 2.4.3:
+        testUtils.createTopic(topic, partitions = 5, overwrite = true)
+
+        testUtils.sendMessages(topic, Array(-20, -21, -22).map(_.toString), Some(0))
+        testUtils.sendMessages(topic, Array(-10, -11, -12).map(_.toString), Some(1))
+        testUtils.sendMessages(topic, Array(0, 1, 2).map(_.toString), Some(2))
+        testUtils.sendMessages(topic, Array(10, 11, 12).map(_.toString), Some(3))
+        testUtils.sendMessages(topic, Array(20, 21, 22).map(_.toString), Some(4))
+      */
+      makeSureGetOffsetCalled,
+      CheckNewAnswer(32, 33, 34, 35, 36)
+    )
+  }
 }
 
 abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
@@ -1414,7 +1474,10 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
     val now = System.currentTimeMillis()
     val topic = newTopic()
     testUtils.createTopic(topic, partitions = 1)
-    testUtils.sendMessages(topic, Array(1).map(_.toString))
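+    // Give the record two headers so the header assertions below have data to check.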
+    testUtils.sendMessage(
+      topic, ("1", Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))), None
+    )
 
     val kafka = spark
       .readStream
@@ -1423,6 +1486,8 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
       .option("kafka.metadata.max.age.ms", "1")
       .option("startingOffsets", "earliest")
       .option("subscribe", topic)
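+      // Ask the source to include the Kafka record headers column.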
+      .option("includeHeaders", "true")
       .load()
 
     val query = kafka
@@ -1445,6 +1510,21 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
     // producer. So here we just use a low bound to make sure the internal conversion works.
     assert(row.getAs[java.sql.Timestamp]("timestamp").getTime >= now, s"Unexpected results: $row")
     assert(row.getAs[Int]("timestampType") === 0, s"Unexpected results: $row")
+
+    def checkHeader(row: Row, expected: Seq[(String, Array[Byte])]): Unit = {
+      // Headers are exposed as array<struct<key:string,value:binary>>.
+      val headers = row.getList[Row](row.fieldIndex("headers")).asScala
+      assert(headers.length === expected.length)
+
+      expected.indices.foreach { idx =>
+        val key = headers(idx).getAs[String]("key")
+        val value = headers(idx).getAs[Array[Byte]]("value")
+        assert(key === expected(idx)._1)
+        assert(value === expected(idx)._2)
+      }
+    }
+
+    checkHeader(row, Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8))))
     query.stop()
   }
 