@@ -21,8 +21,6 @@ import java.sql.{Date, Timestamp}

import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, TestData}
import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
@@ -31,25 +29,29 @@ class InMemoryColumnarQuerySuite extends QueryTest {
// Make sure the tables are loaded.
TestData

private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import ctx.{logicalPlanToSparkQuery, sql}

test("simple columnar query") {
val plan = executePlan(testData.logicalPlan).executedPlan
val plan = ctx.executePlan(testData.logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)

checkAnswer(scan, testData.collect().toSeq)
}

test("default size avoids broadcast") {
// TODO: Improve this test when we have better statistics
sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
.toDF().registerTempTable("sizeTst")
cacheTable("sizeTst")
ctx.cacheTable("sizeTst")
assert(
table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes >
conf.autoBroadcastJoinThreshold)
ctx.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes >
ctx.conf.autoBroadcastJoinThreshold)
}

test("projection") {
val plan = executePlan(testData.select('value, 'key).logicalPlan).executedPlan
val plan = ctx.executePlan(testData.select('value, 'key).logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)

checkAnswer(scan, testData.collect().map {
@@ -58,7 +60,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
}

test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
val plan = executePlan(testData.logicalPlan).executedPlan
val plan = ctx.executePlan(testData.logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)

checkAnswer(scan, testData.collect().toSeq)
@@ -70,7 +72,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
sql("SELECT * FROM repeatedData"),
repeatedData.collect().toSeq.map(Row.fromTuple))

cacheTable("repeatedData")
ctx.cacheTable("repeatedData")

checkAnswer(
sql("SELECT * FROM repeatedData"),
@@ -82,7 +84,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
sql("SELECT * FROM nullableRepeatedData"),
nullableRepeatedData.collect().toSeq.map(Row.fromTuple))

cacheTable("nullableRepeatedData")
ctx.cacheTable("nullableRepeatedData")

checkAnswer(
sql("SELECT * FROM nullableRepeatedData"),
@@ -94,7 +96,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
sql("SELECT time FROM timestamps"),
timestamps.collect().toSeq.map(Row.fromTuple))

cacheTable("timestamps")
ctx.cacheTable("timestamps")

checkAnswer(
sql("SELECT time FROM timestamps"),
@@ -106,7 +108,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
sql("SELECT * FROM withEmptyParts"),
withEmptyParts.collect().toSeq.map(Row.fromTuple))

cacheTable("withEmptyParts")
ctx.cacheTable("withEmptyParts")

checkAnswer(
sql("SELECT * FROM withEmptyParts"),
@@ -155,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {

// Create a RDD for the schema
val rdd =
sparkContext.parallelize((1 to 100), 10).map { i =>
ctx.sparkContext.parallelize((1 to 100), 10).map { i =>
Row(
s"str${i}: test cache.",
s"binary${i}: test cache.".getBytes("UTF-8"),
@@ -175,18 +177,18 @@ class InMemoryColumnarQuerySuite extends QueryTest {
(0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
Row((i - 0.25).toFloat, Seq(true, false, null)))
}
createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
// Cache the table.
sql("cache table InMemoryCache_different_data_types")
// Make sure the table is indeed cached.
val tableScan = table("InMemoryCache_different_data_types").queryExecution.executedPlan
val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan
assert(
isCached("InMemoryCache_different_data_types"),
ctx.isCached("InMemoryCache_different_data_types"),
"InMemoryCache_different_data_types should be cached.")
// Issue a query and check the results.
checkAnswer(
sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"),
table("InMemoryCache_different_data_types").collect())
dropTempTable("InMemoryCache_different_data_types")
ctx.table("InMemoryCache_different_data_types").collect())
ctx.dropTempTable("InMemoryCache_different_data_types")
}
}
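The change above follows a single pattern: the top-level wildcard imports of the `TestSQLContext` singleton are dropped, and every call that used to rely on them now goes through a lazily resolved `ctx`. A minimal sketch of that shape, assuming the Spark 1.4-era test APIs visible in this diff and assuming the sketch lives in the same `org.apache.spark.sql.columnar` test package so that `SparkFunSuite` is accessible (the suite name and temp table name below are hypothetical):

```scala
package org.apache.spark.sql.columnar

import org.apache.spark.SparkFunSuite

class CtxPatternExampleSuite extends SparkFunSuite {
  // Resolve the shared test context lazily; nothing is initialized until first use.
  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
  import ctx.implicits._   // rdd.toDF(), Symbol-to-Column conversions, etc.
  import ctx.sql           // lets tests call sql("...") without the ctx. prefix

  test("context members are reached through ctx") {
    ctx.sparkContext
      .parallelize(Seq((1, "a"), (2, "b"), (3, "c")))
      .toDF("key", "value")
      .registerTempTable("ctx_pattern_example")
    assert(sql("SELECT COUNT(*) FROM ctx_pattern_example").collect().head.getLong(0) === 3)
  }
}
```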
@@ -21,40 +21,42 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._

class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
val originalColumnBatchSize = conf.columnBatchSize
val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning

private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._

private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize
private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning

override protected def beforeAll(): Unit = {
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, "10")

val pruningData = sparkContext.makeRDD((1 to 100).map { key =>
val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key =>
val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
TestData(key, string)
}, 5).toDF()
pruningData.registerTempTable("pruningData")

// Enable in-memory partition pruning
setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
// Enable in-memory table scan accumulators
setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
}

override protected def afterAll(): Unit = {
setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
}

before {
cacheTable("pruningData")
ctx.cacheTable("pruningData")
}

after {
uncacheTable("pruningData")
ctx.uncacheTable("pruningData")
}

// Comparisons
@@ -108,7 +110,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
expectedQueryResult: => Seq[Int]): Unit = {

test(query) {
val df = sql(query)
val df = ctx.sql(query)
val queryExecution = df.queryExecution

assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") {
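One detail worth noting in the suite above: the saved `originalColumnBatchSize` and `originalInMemoryPartitionPruning` values become `lazy`, so the shared context is not touched until the test lifecycle actually starts, and `afterAll` restores them. A minimal sketch of that save-and-restore shape, assuming the same Spark 1.4-era string-keyed `setConf` API used above (suite name is hypothetical); in this sketch the saved value is forced in `beforeAll`, before the override, so the restore reflects the pre-test setting:

```scala
package org.apache.spark.sql.columnar

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.SQLConf

class ConfRestoreExampleSuite extends SparkFunSuite with BeforeAndAfterAll {
  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext

  // Evaluated the first time it is read, not when the suite object is constructed.
  private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize

  override protected def beforeAll(): Unit = {
    // Force the lazy capture before overriding, so afterAll restores the real original.
    require(originalColumnBatchSize > 0)
    ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, "10") // small batches for the tests
  }

  override protected def afterAll(): Unit = {
    // Leave the shared context the way other suites expect to find it.
    ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
  }

  test("column batch size is overridden while this suite runs") {
    assert(ctx.conf.columnBatchSize === 10)
  }
}
```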
45 changes: 22 additions & 23 deletions sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -21,13 +21,11 @@ import java.math.BigDecimal
import java.sql.DriverManager
import java.util.{Calendar, GregorianCalendar, Properties}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test._
import org.apache.spark.sql.types._
import org.h2.jdbc.JdbcSQLException
import org.scalatest.BeforeAndAfter
import TestSQLContext._
import TestSQLContext.implicits._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._

class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
val url = "jdbc:h2:mem:testdb0"
@@ -37,12 +35,16 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte)

val testH2Dialect = new JdbcDialect {
def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2")
override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
Some(StringType)
}

private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import ctx.sql

before {
Class.forName("org.h2.Driver")
// Extra properties that will be specified for our database. We need these to test
@@ -253,26 +255,26 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
}

test("Basic API") {
assert(TestSQLContext.read.jdbc(
assert(ctx.read.jdbc(
urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3)
}

test("Basic API with FetchSize") {
val properties = new Properties
properties.setProperty("fetchSize", "2")
assert(TestSQLContext.read.jdbc(
assert(ctx.read.jdbc(
urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3)
}

test("Partitioning via JDBCPartitioningInfo API") {
assert(
TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties)
ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties)
.collect().length === 3)
}

test("Partitioning via list-of-where-clauses API") {
val parts = Array[String]("THEID < 2", "THEID >= 2")
assert(TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties)
assert(ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties)
.collect().length === 3)
}

@@ -328,27 +330,26 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
}

test("test DATE types") {
val rows = TestSQLContext.read.jdbc(
val rows = ctx.read.jdbc(
urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
val cachedRows = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
val cachedRows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
.cache().collect()
assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
assert(rows(1).getAs[java.sql.Date](1) === null)
assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
}

test("test DATE types in cache") {
val rows =
TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
.cache().registerTempTable("mycached_date")
val cachedRows = sql("select * from mycached_date").collect()
assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
}

test("test types for null value") {
val rows = TestSQLContext.read.jdbc(
val rows = ctx.read.jdbc(
urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect()
assert((0 to 14).forall(i => rows(0).isNullAt(i)))
}
@@ -395,10 +396,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {

test("Remap types via JdbcDialects") {
JdbcDialects.registerDialect(testH2Dialect)
val df = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
assert(df.schema.filter(
_.dataType != org.apache.spark.sql.types.StringType
).isEmpty)
val df = ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty)
val rows = df.collect()
assert(rows(0).get(0).isInstanceOf[String])
assert(rows(0).get(1).isInstanceOf[String])
@@ -419,7 +418,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {

test("Aggregated dialects") {
val agg = new AggregatedDialect(List(new JdbcDialect {
def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:")
override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
if (sqlType % 2 == 0) {
@@ -430,8 +429,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
}, testH2Dialect))
assert(agg.canHandle("jdbc:h2:xxx"))
assert(!agg.canHandle("jdbc:h2"))
assert(agg.getCatalystType(0, "", 1, null) == Some(LongType))
assert(agg.getCatalystType(1, "", 1, null) == Some(StringType))
assert(agg.getCatalystType(0, "", 1, null) === Some(LongType))
assert(agg.getCatalystType(1, "", 1, null) === Some(StringType))
}

}
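The dialect-related edits above (adding `override` to `canHandle` and switching the assertions to `===`) leave behaviour unchanged. A minimal sketch of the custom-dialect shape this suite exercises, using only public types named in the diff (the object name is hypothetical):

```scala
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.types.{DataType, MetadataBuilder, StringType}

// Treat every column coming from an h2 URL as a string, like the testH2Dialect above.
object ExampleH2Dialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2")

  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
    Some(StringType)
}
```

Registering it with `JdbcDialects.registerDialect(ExampleH2Dialect)` makes it visible to `ctx.read.jdbc(...)` for matching URLs, which is what the "Remap types via JdbcDialects" test above relies on.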