@@ -19,14 +19,14 @@ package org.apache.spark.sql.execution

import java.sql.{Timestamp, Date}

import org.apache.spark.sql.test.TestSQLContext
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.{ShuffleDependency, SparkFunSuite}
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest}

class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite {
@@ -74,11 +74,13 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
var numShufflePartitions: Int = _
var useSerializer2: Boolean = _

protected lazy val ctx = TestSQLContext

override def beforeAll(): Unit = {
numShufflePartitions = conf.numShufflePartitions
useSerializer2 = conf.useSqlSerializer2
numShufflePartitions = ctx.conf.numShufflePartitions
useSerializer2 = ctx.conf.useSqlSerializer2

sql("set spark.sql.useSerializer2=true")
ctx.sql("set spark.sql.useSerializer2=true")

val supportedTypes =
Seq(StringType, BinaryType, NullType, BooleanType,
@@ -94,7 +96,7 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll

// Create a RDD with all data types supported by SparkSqlSerializer2.
val rdd =
sparkContext.parallelize((1 to 1000), 10).map { i =>
ctx.sparkContext.parallelize((1 to 1000), 10).map { i =>
Row(
s"str${i}: test serializer2.",
s"binary${i}: test serializer2.".getBytes("UTF-8"),
@@ -112,15 +114,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
new Timestamp(i))
}

createDataFrame(rdd, schema).registerTempTable("shuffle")
ctx.createDataFrame(rdd, schema).registerTempTable("shuffle")

super.beforeAll()
}

override def afterAll(): Unit = {
dropTempTable("shuffle")
sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions")
sql(s"set spark.sql.useSerializer2=$useSerializer2")
ctx.dropTempTable("shuffle")
ctx.sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions")
ctx.sql(s"set spark.sql.useSerializer2=$useSerializer2")
super.afterAll()
}

@@ -141,32 +143,31 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
}

test("key schema and value schema are not nulls") {
val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
val df = ctx.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
checkAnswer(
df,
table("shuffle").collect())
ctx.table("shuffle").collect())
}

test("key schema is null") {
val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
val df = sql(s"SELECT $aggregations FROM shuffle")
val df = ctx.sql(s"SELECT $aggregations FROM shuffle")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
checkAnswer(
df,
Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
}

test("value schema is null") {
val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
val df = ctx.sql(s"SELECT col0 FROM shuffle ORDER BY col0")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
assert(
df.map(r => r.getString(0)).collect().toSeq ===
table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
assert(df.map(r => r.getString(0)).collect().toSeq ===
ctx.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
}

test("no map output field") {
val df = sql(s"SELECT 1 + 1 FROM shuffle")
val df = ctx.sql(s"SELECT 1 + 1 FROM shuffle")
checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer])
}
}
@@ -177,8 +178,8 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite
super.beforeAll()
// Sort merge will not be triggered.
val bypassMergeThreshold =
sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
}
}

@@ -189,7 +190,7 @@ class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite
super.beforeAll()
// To trigger the sort merge.
val bypassMergeThreshold =
sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
}
}
@@ -26,18 +26,20 @@ import org.apache.spark.util.Utils

class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {

import caseInsensitiveContext._
import caseInsensitiveContext.sql

private lazy val sparkContext = caseInsensitiveContext.sparkContext

var path: File = null

override def beforeAll(): Unit = {
path = Utils.createTempDir()
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
read.json(rdd).registerTempTable("jt")
caseInsensitiveContext.read.json(rdd).registerTempTable("jt")
}

override def afterAll(): Unit = {
dropTempTable("jt")
caseInsensitiveContext.dropTempTable("jt")
}

after {
@@ -59,7 +61,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
sql("SELECT a, b FROM jsonTable"),
sql("SELECT a, b FROM jt").collect())

dropTempTable("jsonTable")
caseInsensitiveContext.dropTempTable("jsonTable")
}

test("CREATE TEMPORARY TABLE AS SELECT based on the file without write permission") {
@@ -129,7 +131,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
sql("SELECT * FROM jsonTable"),
sql("SELECT a * 4 FROM jt").collect())

dropTempTable("jsonTable")
caseInsensitiveContext.dropTempTable("jsonTable")
// Explicitly delete the data.
if (path.exists()) Utils.deleteRecursively(path)

@@ -147,7 +149,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
sql("SELECT * FROM jsonTable"),
sql("SELECT b FROM jt").collect())

dropTempTable("jsonTable")
caseInsensitiveContext.dropTempTable("jsonTable")
}

test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") {
@@ -63,19 +63,18 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
}

class DDLTestSuite extends DataSourceTest {
import caseInsensitiveContext._

before {
sql(
"""
|CREATE TEMPORARY TABLE ddlPeople
|USING org.apache.spark.sql.sources.DDLScanSource
|OPTIONS (
| From '1',
| To '10',
| Table 'test1'
|)
""".stripMargin)
caseInsensitiveContext.sql(
"""
|CREATE TEMPORARY TABLE ddlPeople
|USING org.apache.spark.sql.sources.DDLScanSource
|OPTIONS (
| From '1',
| To '10',
| Table 'test1'
|)
""".stripMargin)
}

sqlTest(
@@ -100,7 +99,8 @@ class DDLTestSuite extends DataSourceTest {
))

test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") {
val attributes = sql("describe ddlPeople").queryExecution.executedPlan.output
val attributes = caseInsensitiveContext.sql("describe ddlPeople")
.queryExecution.executedPlan.output
assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment"))
assert(attributes.map(_.dataType).toSet === Set(StringType))
}
@@ -17,14 +17,18 @@

package org.apache.spark.sql.sources

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.test.TestSQLContext
import org.scalatest.BeforeAndAfter


abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
// We want to test some edge cases.
implicit val caseInsensitiveContext = new SQLContext(TestSQLContext.sparkContext)
protected implicit lazy val caseInsensitiveContext = {
val ctx = new SQLContext(TestSQLContext.sparkContext)
ctx.setConf(SQLConf.CASE_SENSITIVE, "false")
ctx
}

caseInsensitiveContext.setConf(SQLConf.CASE_SENSITIVE, "false")
}
@@ -97,7 +97,7 @@ object FiltersPushed {

class FilteredScanSuite extends DataSourceTest {

import caseInsensitiveContext._
import caseInsensitiveContext.sql

before {
sql(
@@ -26,14 +26,16 @@ import org.apache.spark.util.Utils

class InsertSuite extends DataSourceTest with BeforeAndAfterAll {

import caseInsensitiveContext._
import caseInsensitiveContext.sql

private lazy val sparkContext = caseInsensitiveContext.sparkContext

var path: File = null

override def beforeAll: Unit = {
path = Utils.createTempDir()
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
read.json(rdd).registerTempTable("jt")
caseInsensitiveContext.read.json(rdd).registerTempTable("jt")
sql(
s"""
|CREATE TEMPORARY TABLE jsonTable (a int, b string)
@@ -45,8 +47,8 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
}

override def afterAll: Unit = {
dropTempTable("jsonTable")
dropTempTable("jt")
caseInsensitiveContext.dropTempTable("jsonTable")
caseInsensitiveContext.dropTempTable("jt")
Utils.deleteRecursively(path)
}

@@ -109,7 +111,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {

// Writing the table to less part files.
val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5)
read.json(rdd1).registerTempTable("jt1")
caseInsensitiveContext.read.json(rdd1).registerTempTable("jt1")
sql(
s"""
|INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1
@@ -121,7 +123,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {

// Writing the table to more part files.
val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10)
read.json(rdd2).registerTempTable("jt2")
caseInsensitiveContext.read.json(rdd2).registerTempTable("jt2")
sql(
s"""
|INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2
@@ -140,8 +142,8 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
(1 to 10).map(i => Row(i * 10, s"str$i"))
)

dropTempTable("jt1")
dropTempTable("jt2")
caseInsensitiveContext.dropTempTable("jt1")
caseInsensitiveContext.dropTempTable("jt2")
}

test("INSERT INTO not supported for JSONRelation for now") {
@@ -154,13 +156,14 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
}

test("save directly to the path of a JSON table") {
table("jt").selectExpr("a * 5 as a", "b").write.mode(SaveMode.Overwrite).json(path.toString)
caseInsensitiveContext.table("jt").selectExpr("a * 5 as a", "b")
.write.mode(SaveMode.Overwrite).json(path.toString)
checkAnswer(
sql("SELECT a, b FROM jsonTable"),
(1 to 10).map(i => Row(i * 5, s"str$i"))
)

table("jt").write.mode(SaveMode.Overwrite).json(path.toString)
caseInsensitiveContext.table("jt").write.mode(SaveMode.Overwrite).json(path.toString)
checkAnswer(
sql("SELECT a, b FROM jsonTable"),
(1 to 10).map(i => Row(i, s"str$i"))
@@ -181,7 +184,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {

test("Caching") {
// Cached Query Execution
cacheTable("jsonTable")
caseInsensitiveContext.cacheTable("jsonTable")
assertCached(sql("SELECT * FROM jsonTable"))
checkAnswer(
sql("SELECT * FROM jsonTable"),
@@ -220,7 +223,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
sql("SELECT a * 2, b FROM jt").collect())

// Verify uncaching
uncacheTable("jsonTable")
caseInsensitiveContext.uncacheTable("jsonTable")
assertCached(sql("SELECT * FROM jsonTable"), 0)
}

Expand Down Expand Up @@ -251,6 +254,6 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
"It is not allowed to insert into a table that is not an InsertableRelation."
)

dropTempTable("oneToTen")
caseInsensitiveContext.dropTempTable("oneToTen")
}
}
@@ -52,10 +52,9 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo
}

class PrunedScanSuite extends DataSourceTest {
import caseInsensitiveContext._

before {
sql(
caseInsensitiveContext.sql(
"""
|CREATE TEMPORARY TABLE oneToTenPruned
|USING org.apache.spark.sql.sources.PrunedScanSource
@@ -115,7 +114,7 @@ class PrunedScanSuite extends DataSourceTest {

def testPruning(sqlString: String, expectedColumns: String*): Unit = {
test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") {
val queryExecution = sql(sqlString).queryExecution
val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution
val rawPlan = queryExecution.executedPlan.collect {
case p: execution.PhysicalRDD => p
} match {