[SPARK-23097][SQL][SS] Migrate text socket source to V2 #20382
Changes from all commits: 32b49d5, 39978c4, 40df1c8, 5ce6648, 50c53e3, a224a1b, 153cd43, 70b2b48, f69f490, 323e853, d0b1d8b, 1073be4, 6d38bed, 762f1da
The first file registers the relocated provider in `DataSource`'s backward-compatibility map, so the old fully-qualified name keeps resolving:

```diff
@@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.streaming.OutputMode
@@ -563,6 +564,7 @@ object DataSource extends Logging {
     val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat"
     val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat"
     val nativeOrc = classOf[OrcFileFormat].getCanonicalName
+    val socket = classOf[TextSocketSourceProvider].getCanonicalName

     Map(
       "org.apache.spark.sql.jdbc" -> jdbc,
@@ -583,7 +585,8 @@ object DataSource extends Logging {
       "org.apache.spark.sql.execution.datasources.orc" -> nativeOrc,
       "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm,
       "org.apache.spark.ml.source.libsvm" -> libsvm,
-      "com.databricks.spark.csv" -> csv
+      "com.databricks.spark.csv" -> csv,
+      "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket
     )
   }
```
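The map is consulted at lookup time, so checkpoints or user code that still reference the old class name are transparently redirected to the relocated provider. A minimal sketch of that rewrite step, assuming the `getOrElse(provider, provider)` pattern used in `DataSource.lookupDataSource` (the map itself is private to `object DataSource`, so this models the behavior rather than calling it directly):

```scala
// Sketch only: a stand-in for the private backwardCompatibilityMap entry added above.
val backwardCompatibilityMap = Map(
  "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" ->
    "org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider")

val requested = "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider"
// Unknown names pass through unchanged; the old socket provider name is redirected.
val resolved = backwardCompatibilityMap.getOrElse(requested, requested)
assert(resolved == "org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider")
```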
The second file moves the socket source from package `org.apache.spark.sql.execution.streaming` to `org.apache.spark.sql.execution.streaming.sources` and rewrites it against the V2 `MicroBatchReader` API:
```diff
@@ -15,40 +15,47 @@
  * limitations under the License.
  */

-package org.apache.spark.sql.execution.streaming
+package org.apache.spark.sql.execution.streaming.sources

 import java.io.{BufferedReader, InputStreamReader, IOException}
 import java.net.Socket
 import java.sql.Timestamp
 import java.text.SimpleDateFormat
-import java.util.{Calendar, Locale}
+import java.util.{Calendar, List => JList, Locale, Optional}
 import javax.annotation.concurrent.GuardedBy

+import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
 import scala.util.{Failure, Success, Try}

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
+import org.apache.spark.sql.execution.streaming.LongOffset
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
 import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
-import org.apache.spark.unsafe.types.UTF8String

-object TextSocketSource {
+object TextSocketMicroBatchReader {
   val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil)
   val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) ::
     StructField("timestamp", TimestampType) :: Nil)
   val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
 }

 /**
- * A source that reads text lines through a TCP socket, designed only for tutorials and debugging.
- * This source will *not* work in production applications due to multiple reasons, including no
- * support for fault recovery and keeping all of the text read in memory forever.
+ * A MicroBatchReader that reads text lines through a TCP socket, designed only for tutorials and
+ * debugging. This MicroBatchReader will *not* work in production applications due to multiple
+ * reasons, including no support for fault recovery.
  */
-class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlContext: SQLContext)
-  extends Source with Logging {
+class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging {
+
+  private var startOffset: Offset = _
+  private var endOffset: Offset = _
+
+  private val host: String = options.get("host").get()
+  private val port: Int = options.get("port").get().toInt

   @GuardedBy("this")
   private var socket: Socket = null
```
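The constructor now pulls `host` and `port` straight from `DataSourceOptions`, the V2 replacement for the V1 string `parameters` map. A small illustration of how those lookups behave (not part of the diff; `DataSourceOptions` keys are case-insensitive and `get` returns a Java `Optional`):

```scala
import scala.collection.JavaConverters._
import org.apache.spark.sql.sources.v2.DataSourceOptions

// Keys are normalized internally, so option("HOST", ...) and option("host", ...)
// are equivalent from the reader's point of view.
val opts = new DataSourceOptions(Map("HOST" -> "localhost", "Port" -> "9999").asJava)
assert(opts.get("host").get() == "localhost")
assert(opts.get("port").get().toInt == 9999)
assert(!opts.get("includeTimestamp").isPresent)  // absent keys yield an empty Optional
```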
```diff
@@ -61,16 +68,21 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo
    * Stored in a ListBuffer to facilitate removing committed batches.
    */
   @GuardedBy("this")
-  protected val batches = new ListBuffer[(String, Timestamp)]
+  private val batches = new ListBuffer[(String, Timestamp)]

   @GuardedBy("this")
-  protected var currentOffset: LongOffset = new LongOffset(-1)
+  private var currentOffset: LongOffset = LongOffset(-1L)

   @GuardedBy("this")
-  protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)
+  private var lastOffsetCommitted: LongOffset = LongOffset(-1L)

   initialize()

+  /** This method is only used for unit test */
+  private[sources] def getCurrentOffset(): LongOffset = synchronized {
+    currentOffset.copy()
+  }
+
   private def initialize(): Unit = synchronized {
     socket = new Socket(host, port)
     val reader = new BufferedReader(new InputStreamReader(socket.getInputStream))
```
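The fields become `private` because subclassing is no longer the extension point; the new package-private `getCurrentOffset()` exists purely so tests in the same `sources` package can observe how far the background read thread has progressed. A hypothetical polling helper built on it (names and timeout are illustrative, not from the PR):

```scala
// Hypothetical test helper: spin until the reader has buffered enough lines.
// Offsets start at -1, so offset >= n means n + 1 lines have been read.
def waitForOffset(reader: TextSocketMicroBatchReader, target: Long): Unit = {
  val deadline = System.nanoTime() + 10L * 1000 * 1000 * 1000  // 10 second timeout
  while (reader.getCurrentOffset().offset < target) {
    assert(System.nanoTime() < deadline, s"timed out waiting for offset $target")
    Thread.sleep(10)
  }
}
```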
```diff
@@ -86,12 +98,12 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo
               logWarning(s"Stream closed by $host:$port")
               return
             }
-            TextSocketSource.this.synchronized {
+            TextSocketMicroBatchReader.this.synchronized {
               val newData = (line,
                 Timestamp.valueOf(
-                  TextSocketSource.DATE_FORMAT.format(Calendar.getInstance().getTime()))
-              )
-              currentOffset = currentOffset + 1
+                  TextSocketMicroBatchReader.DATE_FORMAT.format(Calendar.getInstance().getTime()))
+                )
+              currentOffset += 1
               batches.append(newData)
             }
           }
@@ -103,23 +115,37 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo
     readThread.start()
   }

-  /** Returns the schema of the data from this source */
-  override def schema: StructType = if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP
-  else TextSocketSource.SCHEMA_REGULAR
+  override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = synchronized {
+    startOffset = start.orElse(LongOffset(-1L))
+    endOffset = end.orElse(currentOffset)
+  }

-  override def getOffset: Option[Offset] = synchronized {
-    if (currentOffset.offset == -1) {
-      None
-    } else {
-      Some(currentOffset)
-    }
-  }
+  override def getStartOffset(): Offset = {
+    Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set"))
+  }
+
+  override def getEndOffset(): Offset = {
+    Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set"))
+  }
+
+  override def deserializeOffset(json: String): Offset = {
+    LongOffset(json.toLong)
+  }
+
+  override def readSchema(): StructType = {
+    if (options.getBoolean("includeTimestamp", false)) {
+      TextSocketMicroBatchReader.SCHEMA_TIMESTAMP
+    } else {
+      TextSocketMicroBatchReader.SCHEMA_REGULAR
+    }
+  }

-  /** Returns the data that is between the offsets (`start`, `end`]. */
-  override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized {
-    val startOrdinal =
-      start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1
-    val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1
+  override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
+    assert(startOffset != null && endOffset != null,
+      "start offset and end offset should already be set before create read tasks.")
+
+    val startOrdinal = LongOffset.convert(startOffset).get.offset.toInt + 1
+    val endOrdinal = LongOffset.convert(endOffset).get.offset.toInt + 1

     // Internal buffer only holds the batches after lastOffsetCommitted
     val rawList = synchronized {
```
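Taken together, these methods replace the V1 `getOffset`/`getBatch` pair. A simplified sketch of the order in which the micro-batch engine drives them on each trigger (illustrative only; the real `MicroBatchExecution` also persists offsets to the checkpoint log before processing and calls `commit` afterwards):

```scala
import java.util.Optional
import org.apache.spark.sql.execution.streaming.sources.TextSocketMicroBatchReader

// Illustrative body of one trigger, not the engine's actual code.
def runOneBatch(reader: TextSocketMicroBatchReader): Unit = {
  // Empty optionals mean "from the beginning" / "up to whatever has arrived".
  reader.setOffsetRange(Optional.empty(), Optional.empty())
  val start = reader.getStartOffset()    // LongOffset(-1) on the first batch
  val end = reader.getEndOffset()        // snapshot of the reader's currentOffset
  val schema = reader.readSchema()       // value [, timestamp]
  val factories = reader.createDataReaderFactories()  // one factory per partition
  // ... ship each factory to an executor, read its rows, then reader.commit(end)
}
```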
```diff
@@ -128,10 +154,34 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo
       batches.slice(sliceStart, sliceEnd)
     }

-    val rdd = sqlContext.sparkContext
-      .parallelize(rawList)
-      .map { case (v, ts) => InternalRow(UTF8String.fromString(v), ts.getTime) }
-    sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
+    assert(SparkSession.getActiveSession.isDefined)
+    val spark = SparkSession.getActiveSession.get
+    val numPartitions = spark.sparkContext.defaultParallelism
+
+    val slices = Array.fill(numPartitions)(new ListBuffer[(String, Timestamp)])
+    rawList.zipWithIndex.foreach { case (r, idx) =>
+      slices(idx % numPartitions).append(r)
+    }
+
+    (0 until numPartitions).map { i =>
+      val slice = slices(i)
+      new DataReaderFactory[Row] {
+        override def createDataReader(): DataReader[Row] = new DataReader[Row] {
+          private var currentIdx = -1
+
+          override def next(): Boolean = {
+            currentIdx += 1
+            currentIdx < slice.size
+          }
+
+          override def get(): Row = {
+            Row(slice(currentIdx)._1, slice(currentIdx)._2)
+          }
+
+          override def close(): Unit = {}
+        }
+      }
+    }.toList.asJava
   }

   override def commit(end: Offset): Unit = synchronized {
```
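The V1 code handed the buffered lines to `sparkContext.parallelize`; here the reader does the partitioning itself, dealing rows round-robin across `defaultParallelism` slices. A standalone sketch of that distribution with illustrative values:

```scala
import scala.collection.mutable.ListBuffer

val numPartitions = 3
val rawList = (1 to 7).map(i => s"line-$i")

// Deal elements round-robin, as createDataReaderFactories does above; partition
// sizes stay within one element of each other.
val slices = Array.fill(numPartitions)(ListBuffer.empty[String])
rawList.zipWithIndex.foreach { case (r, idx) => slices(idx % numPartitions) += r }

// slices(0) == ListBuffer(line-1, line-4, line-7)
// slices(1) == ListBuffer(line-2, line-5)
// slices(2) == ListBuffer(line-3, line-6)
```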
```diff
@@ -164,54 +214,40 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo
     }
   }

-  override def toString: String = s"TextSocketSource[host: $host, port: $port]"
+  override def toString: String = s"TextSocket[host: $host, port: $port]"
 }

-class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging {
-  private def parseIncludeTimestamp(params: Map[String, String]): Boolean = {
-    Try(params.getOrElse("includeTimestamp", "false").toBoolean) match {
-      case Success(bool) => bool
-      case Failure(_) =>
-        throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"")
-    }
-  }
+class TextSocketSourceProvider extends DataSourceV2
+  with MicroBatchReadSupport with DataSourceRegister with Logging {

-  /** Returns the name and schema of the source that can be used to continually read data. */
-  override def sourceSchema(
-      sqlContext: SQLContext,
-      schema: Option[StructType],
-      providerName: String,
-      parameters: Map[String, String]): (String, StructType) = {
+  private def checkParameters(params: DataSourceOptions): Unit = {
     logWarning("The socket source should not be used for production applications! " +
       "It does not support recovery.")
-    if (!parameters.contains("host")) {
+    if (!params.get("host").isPresent) {
       throw new AnalysisException("Set a host to read from with option(\"host\", ...).")
     }
-    if (!parameters.contains("port")) {
+    if (!params.get("port").isPresent) {
       throw new AnalysisException("Set a port to read from with option(\"port\", ...).")
     }
-    if (schema.nonEmpty) {
-      throw new AnalysisException("The socket source does not support a user-specified schema.")
+    Try {
+      params.get("includeTimestamp").orElse("false").toBoolean
+    } match {
+      case Success(_) =>
+      case Failure(_) =>
+        throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"")
     }
-
-    val sourceSchema =
-      if (parseIncludeTimestamp(parameters)) {
-        TextSocketSource.SCHEMA_TIMESTAMP
-      } else {
-        TextSocketSource.SCHEMA_REGULAR
-      }
-    ("textSocket", sourceSchema)
   }

-  override def createSource(
-      sqlContext: SQLContext,
-      metadataPath: String,
-      schema: Option[StructType],
-      providerName: String,
-      parameters: Map[String, String]): Source = {
-    val host = parameters("host")
-    val port = parameters("port").toInt
-    new TextSocketSource(host, port, parseIncludeTimestamp(parameters), sqlContext)
+  override def createMicroBatchReader(
+      schema: Optional[StructType],
+      checkpointLocation: String,
+      options: DataSourceOptions): MicroBatchReader = {
+    checkParameters(options)
+    if (schema.isPresent) {
+      throw new AnalysisException("The socket source does not support a user-specified schema.")
+    }
+
+    new TextSocketMicroBatchReader(options)
   }

   /** String that represents the format that this data source provider uses. */
```
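The user-facing API is unchanged by the migration: the short name registered via `DataSourceRegister` still resolves the source. A typical invocation, assuming a `SparkSession` named `spark` and a line producer such as `nc -lk 9999`:

```scala
val lines = spark.readStream
  .format("socket")                    // resolved through DataSourceRegister
  .option("host", "localhost")
  .option("port", "9999")
  .option("includeTimestamp", "true")  // adds the timestamp column to the schema
  .load()
```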
Review comment: can you add a redirection in the `DataSource.backwardCompatibilityMap` for this?