@@ -91,8 +91,8 @@ private[sql] trait CacheManager {
      CachedData(
        planToCache,
        InMemoryRelation(
-         useCompression,
-         columnBatchSize,
+         conf.useCompression,
+         conf.columnBatchSize,
          storageLevel,
          query.queryExecution.executedPlan,
          tableName))
2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -61,7 +61,7 @@ private[spark] object SQLConf {
  *
  * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads).
  */
-private[sql] trait SQLConf {
+private[sql] class SQLConf {
   import SQLConf._

   /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
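A minimal usage sketch, not part of the diff, of what the trait-to-class change buys: a SQLConf can now be instantiated and exercised on its own (or subclassed with different defaults) instead of existing only as a mixin of SQLContext. The snippet assumes it lives in the org.apache.spark.sql package, since the class is private[sql], and uses only the setConf/getConf methods that appear elsewhere in this diff; the unset key name is made up.

// Hypothetical sketch, not part of this PR.
package org.apache.spark.sql

object SQLConfSketch {
  def demo(): Unit = {
    val conf = new SQLConf                          // possible now that SQLConf is a concrete class
    conf.setConf(SQLConf.SHUFFLE_PARTITIONS, "10")
    assert(conf.getConf(SQLConf.SHUFFLE_PARTITIONS) == "10")
    // Falls back to the supplied default for a key that was never set (hypothetical key name).
    assert(conf.getConf("spark.sql.sketch.unsetKey", "fallback") == "fallback")
  }
}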
40 changes: 33 additions & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -17,6 +17,9 @@

package org.apache.spark.sql

+import java.util.Properties
+
+import scala.collection.immutable
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag

@@ -49,14 +52,37 @@ import org.apache.spark.sql.sources.{DataSourceStrategy, BaseRelation, DDLParser
@AlphaComponent
class SQLContext(@transient val sparkContext: SparkContext)
  extends org.apache.spark.Logging
-  with SQLConf
  with CacheManager
  with ExpressionConversions
  with UDFRegistration
  with Serializable {

  self =>

+  // Note that this is a lazy val so we can override the default value in subclasses.
+  private[sql] lazy val conf: SQLConf = new SQLConf
+
+  /** Set Spark SQL configuration properties. */
+  def setConf(props: Properties): Unit = conf.setConf(props)
+
+  /** Set the given Spark SQL configuration property. */
+  def setConf(key: String, value: String): Unit = conf.setConf(key, value)
+
+  /** Return the value of Spark SQL configuration property for the given key. */
+  def getConf(key: String): String = conf.getConf(key)
+
+  /**
+   * Return the value of Spark SQL configuration property for the given key. If the key is not set
+   * yet, return `defaultValue`.
+   */
+  def getConf(key: String, defaultValue: String): String = conf.getConf(key, defaultValue)
+
+  /**
+   * Return all the configuration properties that have been set (i.e. not the default).
+   * This creates a new copy of the config properties in the form of a Map.
+   */
+  def getAllConfs: immutable.Map[String, String] = conf.getAllConfs
+
  @transient
  protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true)
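As a usage sketch (hypothetical snippet, not part of the diff; it assumes an existing SparkContext named sc), callers should be unaffected by the refactor: the public setConf/getConf/getAllConfs methods remain on SQLContext and simply forward to the new conf field.

// Hypothetical usage sketch, assuming a SparkContext `sc` is already available.
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
sqlContext.setConf("spark.sql.shuffle.partitions", "8")            // forwards to conf.setConf
assert(sqlContext.getConf("spark.sql.shuffle.partitions") == "8")  // forwards to conf.getConf
println(sqlContext.getAllConfs)                                    // only the explicitly-set properties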

@@ -212,7 +238,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
   */
  @Experimental
  def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = {
-    val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord
+    val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
    val appliedSchema =
      Option(schema).getOrElse(
        JsonRDD.nullTypeToStringType(
@@ -226,7 +252,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
   */
  @Experimental
  def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = {
-    val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord
+    val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
    val appliedSchema =
      JsonRDD.nullTypeToStringType(
        JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
@@ -299,10 +325,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
   * @group userf
   */
  def sql(sqlText: String): SchemaRDD = {
-    if (dialect == "sql") {
+    if (conf.dialect == "sql") {
      new SchemaRDD(this, parseSql(sqlText))
    } else {
-      sys.error(s"Unsupported SQL dialect: $dialect")
+      sys.error(s"Unsupported SQL dialect: ${conf.dialect}")
    }
  }

@@ -323,9 +349,9 @@ class SQLContext(@transient val sparkContext: SparkContext)

    val sqlContext: SQLContext = self

-    def codegenEnabled = self.codegenEnabled
+    def codegenEnabled = self.conf.codegenEnabled

-    def numPartitions = self.numShufflePartitions
+    def numPartitions = self.conf.numShufflePartitions
Review comment (Contributor): This is a nit, but you probably don't need the self anymore.

Reply (Contributor, PR author): Maybe more future proof and slightly more clear? Since it is a small addition (just "self.").
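To illustrate the point under discussion, here is a toy sketch (simplified names, not the actual SQLContext or planner code): before this refactor, the self. qualifier was what kept the planner's definition from referring to itself, because the enclosing object declares a member with the same name; after the refactor the value is reached through conf, so the qualifier is no longer required and is kept purely for clarity.

// Illustrative only: simplified stand-ins for SQLContext and its planner.
class Context { self =>
  val codegenEnabled = true

  object planner {
    // `self.` is required here: a bare `codegenEnabled` would resolve to this very
    // definition and recurse, since the enclosing object declares the same name.
    def codegenEnabled: Boolean = self.codegenEnabled
  }
}
// After this PR the planner reads `conf.codegenEnabled`, which does not collide with
// its own member name, so writing `self.conf.codegenEnabled` is a style choice.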


    def strategies: Seq[Strategy] =
      extraStrategies ++ (
@@ -52,7 +52,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
   * @group userf
   */
  def sql(sqlText: String): JavaSchemaRDD = {
-    if (sqlContext.dialect == "sql") {
+    if (sqlContext.conf.dialect == "sql") {
      new JavaSchemaRDD(sqlContext, sqlContext.parseSql(sqlText))
    } else {
      sys.error(s"Unsupported SQL dialect: $sqlContext.dialect")
@@ -164,7 +164,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
   * It goes through the entire dataset once to determine the schema.
   */
  def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = {
-    val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord
+    val columnNameOfCorruptJsonRecord = sqlContext.conf.columnNameOfCorruptRecord
    val appliedScalaSchema =
      JsonRDD.nullTypeToStringType(
        JsonRDD.inferSchema(json.rdd, 1.0, columnNameOfCorruptJsonRecord))
@@ -182,7 +182,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
   */
  @Experimental
  def jsonRDD(json: JavaRDD[String], schema: StructType): JavaSchemaRDD = {
-    val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord
+    val columnNameOfCorruptJsonRecord = sqlContext.conf.columnNameOfCorruptRecord
    val appliedScalaSchema =
      Option(asScalaDataType(schema)).getOrElse(
        JsonRDD.nullTypeToStringType(
@@ -82,7 +82,7 @@ private[sql] case class InMemoryRelation(
    if (batchStats.value.isEmpty) {
      // Underlying columnar RDD hasn't been materialized, no useful statistics information
      // available, return the default statistics.
-      Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes)
+      Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes)
    } else {
      // Underlying columnar RDD has been materialized, required information has also been collected
      // via the `batchStats` accumulator, compute the final statistics, and update `_statistics`.
@@ -233,7 +233,7 @@ private[sql] case class InMemoryColumnarTableScan(
  val readPartitions = sparkContext.accumulator(0)
  val readBatches = sparkContext.accumulator(0)

-  private val inMemoryPartitionPruningEnabled = sqlContext.inMemoryPartitionPruning
+  private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning

  override def execute() = {
    readPartitions.setValue(0)
@@ -123,7 +123,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
 */
private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
  // TODO: Determine the number of partitions.
-  def numPartitions = sqlContext.numShufflePartitions
+  def numPartitions = sqlContext.conf.numShufflePartitions

  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
    case operator: SparkPlan =>
@@ -69,7 +69,7 @@ case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLCont
  @transient override lazy val statistics = Statistics(
    // TODO: Instead of returning a default value here, find a way to return a meaningful size
    // estimate for RDDs. See PR 1238 for more discussions.
-    sizeInBytes = BigInt(sqlContext.defaultSizeInBytes)
+    sizeInBytes = BigInt(sqlContext.conf.defaultSizeInBytes)
  )
}

@@ -106,6 +106,6 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQ
  @transient override lazy val statistics = Statistics(
    // TODO: Instead of returning a default value here, find a way to return a meaningful size
    // estimate for RDDs. See PR 1238 for more discussions.
-    sizeInBytes = BigInt(sqlContext.defaultSizeInBytes)
+    sizeInBytes = BigInt(sqlContext.conf.defaultSizeInBytes)
  )
}
@@ -51,7 +51,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
  // sqlContext will be null when we are being deserialized on the slaves. In this instance
  // the value of codegenEnabled will be set by the deserializer after the constructor has run.
  val codegenEnabled: Boolean = if (sqlContext != null) {
-    sqlContext.codegenEnabled
+    sqlContext.conf.codegenEnabled
  } else {
    false
  }
@@ -34,8 +34,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
  object LeftSemiJoin extends Strategy with PredicateHelper {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
-        if sqlContext.autoBroadcastJoinThreshold > 0 &&
-          right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
+        if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
+          right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
        val semiJoin = joins.BroadcastLeftSemiJoinHash(
          leftKeys, rightKeys, planLater(left), planLater(right))
        condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
@@ -80,13 +80,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
-        if sqlContext.autoBroadcastJoinThreshold > 0 &&
-          right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
+        if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
+          right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
        makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)

      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
-        if sqlContext.autoBroadcastJoinThreshold > 0 &&
-          left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
+        if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
+          left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
        makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)

      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
@@ -214,7 +214,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
        InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil
      case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
        val prunePushedDownFilters =
-          if (sqlContext.parquetFilterPushDown) {
+          if (sqlContext.conf.parquetFilterPushDown) {
            (predicates: Seq[Expression]) => {
              // Note: filters cannot be pushed down to Parquet if they contain more complex
              // expressions than simple "Attribute cmp Literal" comparisons. Here we remove all
@@ -236,7 +236,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
          ParquetTableScan(
            _,
            relation,
-            if (sqlContext.parquetFilterPushDown) filters else Nil)) :: Nil
+            if (sqlContext.conf.parquetFilterPushDown) filters else Nil)) :: Nil

      case _ => Nil
    }
@@ -269,7 +269,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
        // This sort only sorts tuples within a partition. Its requiredDistribution will be
        // an UnspecifiedDistribution.
        execution.Sort(sortExprs, global = false, planLater(child)) :: Nil
-      case logical.Sort(sortExprs, global, child) if sqlContext.externalSortEnabled =>
+      case logical.Sort(sortExprs, global, child) if sqlContext.conf.externalSortEnabled =>
        execution.ExternalSort(sortExprs, global, planLater(child)):: Nil
      case logical.Sort(sortExprs, global, child) =>
        execution.Sort(sortExprs, global, planLater(child)):: Nil
@@ -94,7 +94,7 @@ case class SetCommand(
      logWarning(
        s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
          s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.")
-      Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${sqlContext.numShufflePartitions}"))
+      Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${sqlContext.conf.numShufflePartitions}"))

    // Queries a single property.
    case Some((key, None)) =>
@@ -43,7 +43,7 @@ case class BroadcastHashJoin(
  extends BinaryNode with HashJoin {

  val timeout = {
-    val timeoutValue = sqlContext.broadcastTimeout
+    val timeoutValue = sqlContext.conf.broadcastTimeout
    if (timeoutValue < 0) {
      Duration.Inf
    } else {
@@ -59,8 +59,8 @@ private[sql] case class JSONRelation(
      JsonRDD.inferSchema(
        baseRDD,
        samplingRatio,
-        sqlContext.columnNameOfCorruptRecord)))
+        sqlContext.conf.columnNameOfCorruptRecord)))

  override def buildScan() =
-    JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord)
+    JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.conf.columnNameOfCorruptRecord)
}
@@ -65,7 +65,7 @@ private[sql] case class ParquetRelation(
    ParquetTypesConverter.readSchemaFromFile(
      new Path(path.split(",").head),
      conf,
-      sqlContext.isParquetBinaryAsString)
+      sqlContext.conf.isParquetBinaryAsString)

  override def newInstance() = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type]

@@ -78,7 +78,7 @@ private[sql] case class ParquetRelation(
  }

  // TODO: Use data from the footers.
-  override lazy val statistics = Statistics(sizeInBytes = sqlContext.defaultSizeInBytes)
+  override lazy val statistics = Statistics(sizeInBytes = sqlContext.conf.defaultSizeInBytes)
}

private[sql] object ParquetRelation {
@@ -161,7 +161,8 @@ private[sql] object ParquetRelation {
      sqlContext: SQLContext): ParquetRelation = {
    val path = checkPath(pathString, allowExisting, conf)
    conf.set(ParquetOutputFormat.COMPRESSION, shortParquetCompressionCodecNames.getOrElse(
-      sqlContext.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED).name())
+      sqlContext.conf.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED)
+      .name())
    ParquetRelation.enableLogForwarding()
    ParquetTypesConverter.writeMetaData(attributes, path, conf)
    new ParquetRelation(path.toString, Some(conf), sqlContext) {
@@ -54,7 +54,7 @@ trait ParquetTest {
    try f finally {
      keys.zip(currentValues).foreach {
        case (key, Some(value)) => setConf(key, value)
-        case (key, None) => unsetConf(key)
+        case (key, None) => conf.unsetConf(key)
      }
    }
  }
@@ -137,7 +137,7 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext)
      ParquetTypesConverter.readSchemaFromFile(
        partitions.head.files.head.getPath,
        Some(sparkContext.hadoopConfiguration),
-        sqlContext.isParquetBinaryAsString))
+        sqlContext.conf.isParquetBinaryAsString))

    val dataIncludesKey =
      partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true)
@@ -198,7 +198,7 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext)
    predicates
      .reduceOption(And)
      .flatMap(ParquetFilters.createFilter)
-      .filter(_ => sqlContext.parquetFilterPushDown)
+      .filter(_ => sqlContext.conf.parquetFilterPushDown)
      .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _))

    def percentRead = selectedPartitions.size.toDouble / partitions.size.toDouble * 100
@@ -94,7 +94,7 @@ abstract class BaseRelation {
   * large to broadcast. This method will be called multiple times during query planning
   * and thus should not perform expensive operations for each invocation.
   */
-  def sizeInBytes = sqlContext.defaultSizeInBytes
+  def sizeInBytes = sqlContext.conf.defaultSizeInBytes
}

/**
@@ -29,6 +29,7 @@ object TestSQLContext
    new SparkConf().set("spark.sql.testkey", "true"))) {

  /** Fewer partitions to speed up testing. */
-  override private[spark] def numShufflePartitions: Int =
-    getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt
+  private[sql] override lazy val conf: SQLConf = new SQLConf {
+    override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt
+  }
}
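A toy illustration (plain classes standing in for SQLConf and SQLContext, not the real ones) of why the conf field added in this PR is a lazy val, which is the property TestSQLContext relies on above: a strict val is initialized while the superclass constructor runs, before any subclass override takes effect, whereas a lazy val is only evaluated on first access, so the overriding definition is the one that counts.

// Illustrative only: toy stand-ins for SQLConf, SQLContext, and TestSQLContext.
class Conf { def shufflePartitions: Int = 200 }

class Context { lazy val conf: Conf = new Conf }

object TestContext extends Context {
  // Same pattern as TestSQLContext above: substitute a Conf with test-friendly defaults.
  override lazy val conf: Conf = new Conf {
    override def shufflePartitions: Int = 5
  }
}
// TestContext.conf.shufflePartitions evaluates to 5.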
@@ -387,7 +387,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
  test("broadcasted left semi join operator selection") {
    clearCache()
    sql("CACHE TABLE testData")
-    val tmp = autoBroadcastJoinThreshold
+    val tmp = conf.autoBroadcastJoinThreshold

    sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000")
    Seq(
10 changes: 5 additions & 5 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -37,7 +37,7 @@ class SQLConfSuite extends QueryTest with FunSuiteLike {
  }

  test("programmatic ways of basic setting and getting") {
-    clear()
+    conf.clear()
    assert(getAllConfs.size === 0)

    setConf(testKey, testVal)
@@ -51,11 +51,11 @@ class SQLConfSuite extends QueryTest with FunSuiteLike {
    assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal)
    assert(TestSQLContext.getAllConfs.contains(testKey))

-    clear()
+    conf.clear()
  }

  test("parse SQL set commands") {
-    clear()
+    conf.clear()
    sql(s"set $testKey=$testVal")
    assert(getConf(testKey, testVal + "_") == testVal)
    assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal)
@@ -73,11 +73,11 @@ class SQLConfSuite extends QueryTest with FunSuiteLike {
    sql(s"set $key=")
    assert(getConf(key, "0") == "")

-    clear()
+    conf.clear()
  }

  test("deprecated property") {
-    clear()
+    conf.clear()
    sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
    assert(getConf(SQLConf.SHUFFLE_PARTITIONS) == "10")
  }