@@ -5,6 +5,6 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
org.apache.spark.sql.execution.datasources.text.TextFileFormat
org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
org.apache.spark.sql.execution.streaming.TextSocketSourceProvider
org.apache.spark.sql.execution.streaming.RateSourceProvider
org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
Review comment (Contributor): can you add a redirection in the DataSource.backwardCompatibilityMap for this?

org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2
@@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode
@@ -563,6 +564,7 @@ object DataSource extends Logging {
val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat"
val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat"
val nativeOrc = classOf[OrcFileFormat].getCanonicalName
val socket = classOf[TextSocketSourceProvider].getCanonicalName

Map(
"org.apache.spark.sql.jdbc" -> jdbc,
@@ -583,7 +585,8 @@
"org.apache.spark.sql.execution.datasources.orc" -> nativeOrc,
"org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm,
"org.apache.spark.ml.source.libsvm" -> libsvm,
"com.databricks.spark.csv" -> csv
"com.databricks.spark.csv" -> csv,
"org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket
Review comment (Contributor): please add a test for this!

)
}

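In response to the review comment above asking for a test, here is a minimal sketch of what such a test could look like. It assumes a ScalaTest-based suite with a SparkSession in scope as spark (for example one mixing in SharedSQLContext), and it assumes DataSource.lookupDataSource(provider, conf) has the two-argument signature used in this Spark version; the test name and assertion style are illustrative, not the test that was ultimately added in the PR.

import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider

test("backward compatibility: old text socket provider name resolves to the relocated class") {
  // The deprecated fully-qualified name should be redirected via DataSource.backwardCompatibilityMap.
  val resolved = DataSource.lookupDataSource(
    "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider",
    spark.sessionState.conf)
  assert(resolved === classOf[TextSocketSourceProvider])
}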
@@ -15,40 +15,47 @@
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming
package org.apache.spark.sql.execution.streaming.sources

import java.io.{BufferedReader, InputStreamReader, IOException}
import java.net.Socket
import java.sql.Timestamp
import java.text.SimpleDateFormat
import java.util.{Calendar, Locale}
import java.util.{Calendar, List => JList, Locale, Optional}
import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
import scala.util.{Failure, Success, Try}

import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
import org.apache.spark.sql.execution.streaming.LongOffset
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport}
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
import org.apache.spark.unsafe.types.UTF8String


object TextSocketSource {
object TextSocketMicroBatchReader {
val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil)
val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) ::
StructField("timestamp", TimestampType) :: Nil)
val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
}

/**
* A source that reads text lines through a TCP socket, designed only for tutorials and debugging.
* This source will *not* work in production applications due to multiple reasons, including no
* support for fault recovery and keeping all of the text read in memory forever.
* A MicroBatchReader that reads text lines through a TCP socket, designed only for tutorials and
Review comment (Contributor): nit: tutorials -> testing (I know it was like that, but let's fix it since we are changing it anyway)

Review comment (Contributor): Tutorials is correct here; see e.g. StructuredSessionization.scala

* debugging. This MicroBatchReader will *not* work in production applications due to multiple
* reasons, including no support for fault recovery.
*/
class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlContext: SQLContext)
extends Source with Logging {
class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging {

private var startOffset: Offset = _
private var endOffset: Offset = _

private val host: String = options.get("host").get()
private val port: Int = options.get("port").get().toInt

@GuardedBy("this")
private var socket: Socket = null
@@ -61,16 +68,21 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo
* Stored in a ListBuffer to facilitate removing committed batches.
*/
@GuardedBy("this")
protected val batches = new ListBuffer[(String, Timestamp)]
private val batches = new ListBuffer[(String, Timestamp)]

@GuardedBy("this")
protected var currentOffset: LongOffset = new LongOffset(-1)
private var currentOffset: LongOffset = LongOffset(-1L)

@GuardedBy("this")
protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)
private var lastOffsetCommitted: LongOffset = LongOffset(-1L)

initialize()

/** This method is only used for unit test */
private[sources] def getCurrentOffset(): LongOffset = synchronized {
currentOffset.copy()
}

private def initialize(): Unit = synchronized {
socket = new Socket(host, port)
val reader = new BufferedReader(new InputStreamReader(socket.getInputStream))
@@ -86,12 +98,12 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo
logWarning(s"Stream closed by $host:$port")
return
}
TextSocketSource.this.synchronized {
TextSocketMicroBatchReader.this.synchronized {
val newData = (line,
Timestamp.valueOf(
TextSocketSource.DATE_FORMAT.format(Calendar.getInstance().getTime()))
)
currentOffset = currentOffset + 1
TextSocketMicroBatchReader.DATE_FORMAT.format(Calendar.getInstance().getTime()))
)
currentOffset += 1
batches.append(newData)
}
}
@@ -103,23 +115,37 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo
readThread.start()
}

/** Returns the schema of the data from this source */
override def schema: StructType = if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP
else TextSocketSource.SCHEMA_REGULAR
override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = synchronized {
startOffset = start.orElse(LongOffset(-1L))
endOffset = end.orElse(currentOffset)
}

override def getStartOffset(): Offset = {
Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set"))
}

override def getEndOffset(): Offset = {
Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set"))
}

override def deserializeOffset(json: String): Offset = {
LongOffset(json.toLong)
}

override def getOffset: Option[Offset] = synchronized {
if (currentOffset.offset == -1) {
None
override def readSchema(): StructType = {
if (options.getBoolean("includeTimestamp", false)) {
TextSocketMicroBatchReader.SCHEMA_TIMESTAMP
} else {
Some(currentOffset)
TextSocketMicroBatchReader.SCHEMA_REGULAR
}
}

/** Returns the data that is between the offsets (`start`, `end`]. */
override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized {
val startOrdinal =
start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1
val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1
override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
assert(startOffset != null && endOffset != null,
"start offset and end offset should already be set before create read tasks.")

val startOrdinal = LongOffset.convert(startOffset).get.offset.toInt + 1
val endOrdinal = LongOffset.convert(endOffset).get.offset.toInt + 1

// Internal buffer only holds the batches after lastOffsetCommitted
val rawList = synchronized {
@@ -128,10 +154,34 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo
batches.slice(sliceStart, sliceEnd)
}

val rdd = sqlContext.sparkContext
.parallelize(rawList)
.map { case (v, ts) => InternalRow(UTF8String.fromString(v), ts.getTime) }
sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
assert(SparkSession.getActiveSession.isDefined)
val spark = SparkSession.getActiveSession.get
val numPartitions = spark.sparkContext.defaultParallelism

val slices = Array.fill(numPartitions)(new ListBuffer[(String, Timestamp)])
rawList.zipWithIndex.foreach { case (r, idx) =>
slices(idx % numPartitions).append(r)
}

(0 until numPartitions).map { i =>
val slice = slices(i)
new DataReaderFactory[Row] {
override def createDataReader(): DataReader[Row] = new DataReader[Row] {
private var currentIdx = -1

override def next(): Boolean = {
currentIdx += 1
currentIdx < slice.size
}

override def get(): Row = {
Row(slice(currentIdx)._1, slice(currentIdx)._2)
}

override def close(): Unit = {}
}
}
}.toList.asJava
}

override def commit(end: Offset): Unit = synchronized {
@@ -164,54 +214,40 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo
}
}

override def toString: String = s"TextSocketSource[host: $host, port: $port]"
override def toString: String = s"TextSocket[host: $host, port: $port]"
}

class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging {
private def parseIncludeTimestamp(params: Map[String, String]): Boolean = {
Try(params.getOrElse("includeTimestamp", "false").toBoolean) match {
case Success(bool) => bool
case Failure(_) =>
throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"")
}
}
class TextSocketSourceProvider extends DataSourceV2
with MicroBatchReadSupport with DataSourceRegister with Logging {

/** Returns the name and schema of the source that can be used to continually read data. */
override def sourceSchema(
sqlContext: SQLContext,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): (String, StructType) = {
private def checkParameters(params: DataSourceOptions): Unit = {
logWarning("The socket source should not be used for production applications! " +
"It does not support recovery.")
if (!parameters.contains("host")) {
if (!params.get("host").isPresent) {
throw new AnalysisException("Set a host to read from with option(\"host\", ...).")
}
if (!parameters.contains("port")) {
if (!params.get("port").isPresent) {
throw new AnalysisException("Set a port to read from with option(\"port\", ...).")
}
if (schema.nonEmpty) {
throw new AnalysisException("The socket source does not support a user-specified schema.")
Try {
params.get("includeTimestamp").orElse("false").toBoolean
} match {
case Success(_) =>
case Failure(_) =>
throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"")
}

val sourceSchema =
if (parseIncludeTimestamp(parameters)) {
TextSocketSource.SCHEMA_TIMESTAMP
} else {
TextSocketSource.SCHEMA_REGULAR
}
("textSocket", sourceSchema)
}

override def createSource(
sqlContext: SQLContext,
metadataPath: String,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): Source = {
val host = parameters("host")
val port = parameters("port").toInt
new TextSocketSource(host, port, parseIncludeTimestamp(parameters), sqlContext)
override def createMicroBatchReader(
schema: Optional[StructType],
checkpointLocation: String,
options: DataSourceOptions): MicroBatchReader = {
checkParameters(options)
if (schema.isPresent) {
throw new AnalysisException("The socket source does not support a user-specified schema.")
}

new TextSocketMicroBatchReader(options)
}

/** String that represents the format that this data source provider uses. */
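For context, the user-facing API of the socket source is unchanged by this migration; it is still registered under the short name "socket". A minimal usage sketch follows, with the host, port, and console sink chosen purely for illustration.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("socket-example").getOrCreate()

// Read lines from a TCP socket; includeTimestamp adds a timestamp column next to value.
val lines = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", "9999")
  .option("includeTimestamp", "true")
  .load()

// Echo each micro-batch to the console; like the source itself, this is for debugging only.
val query = lines.writeStream
  .format("console")
  .start()

query.awaitTermination()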
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2}
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport}
import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

@@ -172,15 +173,25 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
}
ds match {
case s: MicroBatchReadSupport =>
val tempReader = s.createMicroBatchReader(
Optional.ofNullable(userSpecifiedSchema.orNull),
Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath,
options)
var tempReader: MicroBatchReader = null
val schema = try {
tempReader = s.createMicroBatchReader(
Optional.ofNullable(userSpecifiedSchema.orNull),
Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath,
options)
tempReader.readSchema()
} finally {
// Stop tempReader to avoid leaking its side effects (e.g. the socket source opens a connection)
if (tempReader != null) {
tempReader.stop()
tempReader = null
}
}
Dataset.ofRows(
sparkSession,
StreamingRelationV2(
s, source, extraOptions.toMap,
tempReader.readSchema().toAttributes, v1Relation)(sparkSession))
schema.toAttributes, v1Relation)(sparkSession))
case s: ContinuousReadSupport =>
val tempReader = s.createContinuousReader(
Optional.ofNullable(userSpecifiedSchema.orNull),
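Since createMicroBatchReader is invoked eagerly inside load(), the provider's validation (checkParameters and the user-specified-schema check above) surfaces at load() time. Below is a hedged test sketch for that path, again assuming a ScalaTest suite with a SparkSession in scope as spark; the test name is illustrative.

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.types.{StringType, StructField, StructType}

test("socket source rejects a user-specified schema at load() time") {
  val userSchema = StructType(StructField("value", StringType) :: Nil)
  val e = intercept[AnalysisException] {
    spark.readStream
      .format("socket")
      .option("host", "localhost")
      .option("port", "9999")
      .schema(userSchema)
      .load()
  }
  assert(e.getMessage.contains("does not support a user-specified schema"))
}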