34 changes: 18 additions & 16 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -26,7 +26,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource}
import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation}
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
import org.apache.spark.sql.execution.datasources.json.{InferSchema, JacksonParser, JSONOptions}
import org.apache.spark.sql.execution.streaming.StreamingRelation
@@ -122,12 +122,13 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
* @since 1.4.0
*/
def load(): DataFrame = {
val resolved = ResolvedDataSource(
sqlContext,
userSpecifiedSchema = userSpecifiedSchema,
provider = source,
options = extraOptions.toMap)
DataFrame(sqlContext, LogicalRelation(resolved.relation))
val dataSource =
DataSource(
sqlContext,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap)
DataFrame(sqlContext, LogicalRelation(dataSource.resolveRelation()))
}

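From the end user's side the entry point is unchanged; a call like the one below (format name, option, and path are illustrative) now resolves its relation through DataSource.resolveRelation() rather than ResolvedDataSource:

```scala
// Hypothetical usage; the format, option, and path are illustrative.
val events = sqlContext.read
  .format("json")
  .option("samplingRatio", "1.0") // collected into extraOptions and passed as `options`
  .load("/tmp/events")            // load(path) sets the "path" option and calls load()
```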
/**
@@ -152,12 +153,12 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
sqlContext.emptyDataFrame
} else {
sqlContext.baseRelationToDataFrame(
ResolvedDataSource.apply(
DataSource.apply(
sqlContext,
paths = paths,
userSpecifiedSchema = userSpecifiedSchema,
provider = source,
options = extraOptions.toMap).relation)
className = source,
options = extraOptions.toMap).resolveRelation())
}
}

@@ -168,12 +169,13 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
* @since 2.0.0
*/
def stream(): DataFrame = {
val resolved = ResolvedDataSource.createSource(
sqlContext,
userSpecifiedSchema = userSpecifiedSchema,
providerName = source,
options = extraOptions.toMap)
DataFrame(sqlContext, StreamingRelation(resolved))
val dataSource =
DataSource(
sqlContext,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap)
DataFrame(sqlContext, StreamingRelation(dataSource.createSource()))
}

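The streaming read path follows the same pattern; a sketch (format and input path are hypothetical) of a call that ends up wrapped in StreamingRelation(dataSource.createSource()):

```scala
// Hypothetical usage; the format and input path are illustrative.
val streamingEvents = sqlContext.read
  .format("text")
  .option("path", "/tmp/stream-input")
  .stream()
```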
/**
29 changes: 15 additions & 14 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -25,7 +25,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, ResolvedDataSource}
import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.sources.HadoopFsRelation
@@ -195,14 +195,14 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*/
def save(): Unit = {
assertNotBucketed()
ResolvedDataSource(
val dataSource = DataSource(
df.sqlContext,
source,
partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]),
getBucketSpec,
mode,
extraOptions.toMap,
df)
className = source,
partitionColumns = partitioningColumns.getOrElse(Nil),
bucketSpec = getBucketSpec,
options = extraOptions.toMap)

dataSource.write(mode, df)
}

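For the batch write path, a minimal sketch (format, mode, partition column, and output path are illustrative) of a call that now flows through dataSource.write(mode, df):

```scala
// Hypothetical usage; "date" and the output path are illustrative.
df.write
  .format("parquet")
  .mode("append")
  .partitionBy("date")
  .save("/tmp/events-output")
```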
/**
@@ -235,14 +235,15 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
def stream(): ContinuousQuery = {
val sink = ResolvedDataSource.createSink(
df.sqlContext,
source,
extraOptions.toMap,
normalizedParCols.getOrElse(Nil))
val dataSource =
DataSource(
df.sqlContext,
className = source,
options = extraOptions.toMap,
partitionColumns = normalizedParCols.getOrElse(Nil))

df.sqlContext.continuousQueryManager.startQuery(
extraOptions.getOrElse("queryName", StreamExecution.nextName), df, sink)
extraOptions.getOrElse("queryName", StreamExecution.nextName), df, dataSource.createSink())
}

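For the streaming write path, a hedged sketch; the provider name below is hypothetical and stands in for any data source that implements StreamSinkProvider, since createSink() rejects everything else:

```scala
// Hypothetical usage; com.example.streaming is an illustrative StreamSinkProvider.
val query = df.write
  .format("com.example.streaming")
  .option("queryName", "eventsQuery") // picked up via extraOptions.getOrElse("queryName", ...)
  .stream()
```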
/**
@@ -34,14 +34,39 @@ import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{CalendarIntervalType, StructType}
import org.apache.spark.util.Utils

case class ResolvedDataSource(provider: Class[_], relation: BaseRelation)

/**
* Responsible for taking a description of a datasource (either from
* [[org.apache.spark.sql.DataFrameReader]], or a metastore) and converting it into a logical
* relation that can be used in a query plan.
* The main class responsible for representing a pluggable Data Source in Spark SQL. In addition to
* acting as the canonical set of parameters that can describe a Data Source, this class is used to
* resolve a description to a concrete implementation that can be used in a query plan
* (either batch or streaming) or to write out data using an external library.
*
 * From an end user's perspective, a DataSource description can be created explicitly using
* [[org.apache.spark.sql.DataFrameReader]] or CREATE TABLE USING DDL. Additionally, this class is
* used when resolving a description from a metastore to a concrete implementation.
*
 * Many of the arguments to this class are optional. Depending on the specific API being used, these
 * optional arguments might be filled in during resolution using either inference or external
* metadata. For example, when reading a partitioned table from a file system, partition columns
* will be inferred from the directory layout even if they are not specified.
*
 * @param paths A list of file system paths that hold data. These will be globbed and
 * qualified before use. This option only works when reading from a [[FileFormat]].
* @param userSpecifiedSchema An optional specification of the schema of the data. When present
* we skip attempting to infer the schema.
* @param partitionColumns A list of column names that the relation is partitioned by. When this
* list is empty, the relation is unpartitioned.
* @param bucketSpec An optional specification for bucketing (hash-partitioning) of the data.
*/
object ResolvedDataSource extends Logging {
case class DataSource(
sqlContext: SQLContext,
className: String,
paths: Seq[String] = Nil,
userSpecifiedSchema: Option[StructType] = None,
partitionColumns: Seq[String] = Seq.empty,
bucketSpec: Option[BucketSpec] = None,
options: Map[String, String] = Map.empty) extends Logging {

lazy val providingClass: Class[_] = lookupDataSource(className)

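For internal callers, a minimal sketch of constructing the case class directly and resolving it (the provider name, path, and option are illustrative); this mirrors the DataFrameReader.load() change earlier in this diff:

```scala
// Hypothetical internal usage; the format, path, and option are illustrative.
val dataSource =
  DataSource(
    sqlContext,
    className = "parquet",
    paths = Seq("/tmp/events"),
    options = Map("mergeSchema" -> "true"))

val relation: BaseRelation = dataSource.resolveRelation()
```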
/** A map to maintain backward compatibility in case we move data sources around. */
private val backwardCompatibilityMap = Map(
@@ -54,7 +79,7 @@ object ResolvedDataSource extends Logging {
)

/** Given a provider name, look up the data source class definition. */
def lookupDataSource(provider0: String): Class[_] = {
private def lookupDataSource(provider0: String): Class[_] = {
val provider = backwardCompatibilityMap.getOrElse(provider0, provider0)
val provider2 = s"$provider.DefaultSource"
val loader = Utils.getContextOrSparkClassLoader
@@ -96,15 +121,11 @@ object ResolvedDataSource extends Logging {
}
}
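Judging from the $provider.DefaultSource fallback constructed above, a hedged sketch of what the lookup buys callers; the package name is illustrative and assumes com.example.mysource is a package whose DefaultSource class implements a data source:

```scala
// Both resolve to the same provider class, assuming com.example.mysource is a
// package whose DefaultSource implements the data source (illustrative names).
val byPackage = DataSource(sqlContext, className = "com.example.mysource")
val byClass   = DataSource(sqlContext, className = "com.example.mysource.DefaultSource")
// providingClass is lazy, so the actual class lookup happens on first use.
```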

// TODO: Combine with apply?
def createSource(
sqlContext: SQLContext,
userSpecifiedSchema: Option[StructType],
providerName: String,
options: Map[String, String]): Source = {
val provider = lookupDataSource(providerName).newInstance() match {
/** Returns a source that can be used to continually read data. */
def createSource(): Source = {
providingClass.newInstance() match {
case s: StreamSourceProvider =>
s.createSource(sqlContext, userSpecifiedSchema, providerName, options)
s.createSource(sqlContext, userSpecifiedSchema, className, options)

case format: FileFormat =>
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
@@ -135,53 +156,38 @@ object ResolvedDataSource extends Logging {
new DataFrame(
sqlContext,
LogicalRelation(
apply(
DataSource(
sqlContext,
paths = files,
userSpecifiedSchema = Some(dataSchema),
provider = providerName,
options = options.filterKeys(_ != "path")).relation))
className = className,
options = options.filterKeys(_ != "path")).resolveRelation()))
}

new FileStreamSource(
sqlContext, metadataPath, path, Some(dataSchema), providerName, dataFrameBuilder)
sqlContext, metadataPath, path, Some(dataSchema), className, dataFrameBuilder)
case _ =>
throw new UnsupportedOperationException(
s"Data source $providerName does not support streamed reading")
s"Data source $className does not support streamed reading")
}

provider
}

def createSink(
sqlContext: SQLContext,
providerName: String,
options: Map[String, String],
partitionColumns: Seq[String]): Sink = {
val provider = lookupDataSource(providerName).newInstance() match {
/** Returns a sink that can be used to continually write data. */
def createSink(): Sink = {
val datasourceClass = providingClass.newInstance() match {
case s: StreamSinkProvider => s
case _ =>
throw new UnsupportedOperationException(
s"Data source $providerName does not support streamed writing")
s"Data source $className does not support streamed writing")
}

provider.createSink(sqlContext, options, partitionColumns)
datasourceClass.createSink(sqlContext, options, partitionColumns)
}

/** Create a [[ResolvedDataSource]] for reading data in. */
def apply(
sqlContext: SQLContext,
paths: Seq[String] = Nil,
userSpecifiedSchema: Option[StructType] = None,
partitionColumns: Array[String] = Array.empty,
bucketSpec: Option[BucketSpec] = None,
provider: String,
options: Map[String, String]): ResolvedDataSource = {
val clazz: Class[_] = lookupDataSource(provider)
def className: String = clazz.getCanonicalName

  /** Create a resolved [[BaseRelation]] that can be used to read data from this [[DataSource]]. */
def resolveRelation(): BaseRelation = {
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
val relation = (clazz.newInstance(), userSpecifiedSchema) match {
val relation = (providingClass.newInstance(), userSpecifiedSchema) match {
// TODO: Throw when too much is given.
case (dataSource: SchemaRelationProvider, Some(schema)) =>
dataSource.createRelation(sqlContext, caseInsensitiveOptions, schema)
@@ -238,43 +244,19 @@ object ResolvedDataSource extends Logging {
throw new AnalysisException(
s"$className is not a valid Spark SQL Data Source.")
}
new ResolvedDataSource(clazz, relation)
}

def partitionColumnsSchema(
schema: StructType,
partitionColumns: Array[String],
caseSensitive: Boolean): StructType = {
val equality = columnNameEquality(caseSensitive)
StructType(partitionColumns.map { col =>
schema.find(f => equality(f.name, col)).getOrElse {
throw new RuntimeException(s"Partition column $col not found in schema $schema")
}
}).asNullable
relation
}

private def columnNameEquality(caseSensitive: Boolean): (String, String) => Boolean = {
if (caseSensitive) {
org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
} else {
org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
}
}

/** Create a [[ResolvedDataSource]] for saving the content of the given DataFrame. */
def apply(
sqlContext: SQLContext,
provider: String,
partitionColumns: Array[String],
bucketSpec: Option[BucketSpec],
  /** Writes the given [[DataFrame]] out to this [[DataSource]]. */
def write(
mode: SaveMode,
options: Map[String, String],
data: DataFrame): ResolvedDataSource = {
data: DataFrame): BaseRelation = {
if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
throw new AnalysisException("Cannot save interval data type into external storage.")
}
val clazz: Class[_] = lookupDataSource(provider)
clazz.newInstance() match {

providingClass.newInstance() match {
case dataSource: CreatableRelationProvider =>
dataSource.createRelation(sqlContext, mode, options, data)
case format: FileFormat =>
@@ -295,27 +277,28 @@ object ResolvedDataSource extends Logging {
PartitioningUtils.validatePartitionColumnDataTypes(
data.schema, partitionColumns, caseSensitive)

val equality = columnNameEquality(caseSensitive)
val equality =
if (sqlContext.conf.caseSensitiveAnalysis) {
org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
} else {
org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
}

val dataSchema = StructType(
data.schema.filterNot(f => partitionColumns.exists(equality(_, f.name))))

// If we are appending to a table that already exists, make sure the partitioning matches
// up. If we fail to load the table for whatever reason, ignore the check.
if (mode == SaveMode.Append) {
val existingPartitionColumnSet = try {
val resolved = apply(
sqlContext,
userSpecifiedSchema = Some(data.schema.asNullable),
provider = provider,
options = options)

Some(resolved.relation
.asInstanceOf[HadoopFsRelation]
.location
.partitionSpec(None)
.partitionColumns
.fieldNames
.toSet)
Some(
resolveRelation()
.asInstanceOf[HadoopFsRelation]
.location
.partitionSpec(None)
.partitionColumns
.fieldNames
.toSet)
} catch {
case e: Exception =>
None
@@ -346,15 +329,10 @@ object ResolvedDataSource extends Logging {
sqlContext.executePlan(plan).toRdd

case _ =>
sys.error(s"${clazz.getCanonicalName} does not allow create table as select.")
sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.")
}

apply(
sqlContext,
userSpecifiedSchema = Some(data.schema.asNullable),
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
provider = provider,
options = options)
// We replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it.
copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation()
}
}