ArrayType.scala
@@ -75,6 +75,10 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
 
   override def simpleString: String = s"array<${elementType.simpleString}>"
 
-  private[spark] override def asNullable: ArrayType =
+  override private[spark] def asNullable: ArrayType =
     ArrayType(elementType.asNullable, containsNull = true)
+
+  override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
+    f(this) || elementType.existsRecursively(f)
+  }
 }

DataType.scala
@@ -77,6 +77,11 @@ abstract class DataType extends AbstractDataType {
    */
   private[spark] def asNullable: DataType
 
+  /**
+   * Returns true if any `DataType` of this DataType tree satisfies the given function `f`.
+   */
+  private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this)
+
   override private[sql] def defaultConcreteType: DataType = this
 
   override private[sql] def acceptsType(other: DataType): Boolean = sameType(other)
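
The new existsRecursively hook is a pre-order check over the type tree: the predicate is applied to the type itself, and the composite types (array, map, struct) additionally recurse into their element, key/value, and field types. A minimal usage sketch, not part of this patch; the object name is invented, and it sits under org.apache.spark.sql.types only because existsRecursively is private[spark]:

```scala
package org.apache.spark.sql.types

// Illustrative only: exercises the predicate-based traversal added in this PR.
object ExistsRecursivelyExample {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("id", LongType),
      StructField("amount", DecimalType(10, 3)),
      StructField("tags", ArrayType(StringType))))

    // true: a DecimalType node exists somewhere in the tree rooted at this StructType.
    println(schema.existsRecursively(_.isInstanceOf[DecimalType]))

    // false: the tree contains no MapType.
    println(schema.existsRecursively(_.isInstanceOf[MapType]))
  }
}
```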

MapType.scala
@@ -62,8 +62,12 @@ case class MapType(
 
   override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>"
 
-  private[spark] override def asNullable: MapType =
+  override private[spark] def asNullable: MapType =
     MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
+
+  override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
+    f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f)
+  }
 }
 
 

StructType.scala
@@ -24,7 +24,7 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, AttributeReference, Attribute, InterpretedOrdering$}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
 
 
 /**
@@ -292,7 +292,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
   private[sql] def merge(that: StructType): StructType =
     StructType.merge(this, that).asInstanceOf[StructType]
 
-  private[spark] override def asNullable: StructType = {
+  override private[spark] def asNullable: StructType = {
     val newFields = fields.map {
       case StructField(name, dataType, nullable, metadata) =>
         StructField(name, dataType.asNullable, nullable = true, metadata)
@@ -301,6 +301,10 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
     StructType(newFields)
   }
 
+  override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
+    f(this) || fields.exists(field => field.dataType.existsRecursively(f))
+  }
+
   private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType))
 }
 

DataTypeSuite.scala
@@ -170,6 +170,30 @@ class DataTypeSuite extends SparkFunSuite {
     }
   }
 
+  test("existsRecursively") {
+    val struct = StructType(
+      StructField("a", LongType) ::
+      StructField("b", FloatType) :: Nil)
+    assert(struct.existsRecursively(_.isInstanceOf[LongType]))
+    assert(struct.existsRecursively(_.isInstanceOf[StructType]))
+    assert(!struct.existsRecursively(_.isInstanceOf[IntegerType]))
+
+    val mapType = MapType(struct, StringType)
+    assert(mapType.existsRecursively(_.isInstanceOf[LongType]))
+    assert(mapType.existsRecursively(_.isInstanceOf[StructType]))
+    assert(mapType.existsRecursively(_.isInstanceOf[StringType]))
+    assert(mapType.existsRecursively(_.isInstanceOf[MapType]))
+    assert(!mapType.existsRecursively(_.isInstanceOf[IntegerType]))
+
+    val arrayType = ArrayType(mapType)
+    assert(arrayType.existsRecursively(_.isInstanceOf[LongType]))
+    assert(arrayType.existsRecursively(_.isInstanceOf[StructType]))
+    assert(arrayType.existsRecursively(_.isInstanceOf[StringType]))
+    assert(arrayType.existsRecursively(_.isInstanceOf[MapType]))
+    assert(arrayType.existsRecursively(_.isInstanceOf[ArrayType]))
+    assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType]))
+  }
+
   def checkDataTypeJsonRepr(dataType: DataType): Unit = {
     test(s"JSON - $dataType") {
       assert(DataType.fromJson(dataType.json) === dataType)

HiveMetastoreCatalog.scala
@@ -33,15 +33,14 @@ import org.apache.hadoop.hive.ql.plan.TableDesc
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier}
+import org.apache.spark.sql.execution.{FileRelation, datasources}
+import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
 import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource}
-import org.apache.spark.sql.execution.{FileRelation, datasources}
 import org.apache.spark.sql.hive.client._
-import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode}
@@ -86,9 +85,9 @@ private[hive] object HiveSerDe {
         serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")))
 
     val key = source.toLowerCase match {
-      case _ if source.startsWith("org.apache.spark.sql.parquet") => "parquet"
-      case _ if source.startsWith("org.apache.spark.sql.orc") => "orc"
-      case _ => source.toLowerCase
+      case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet"
+      case s if s.startsWith("org.apache.spark.sql.orc") => "orc"
+      case s => s
     }
 
     serdeMap.get(key)
@@ -309,11 +308,31 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
     val hiveTable = (maybeSerDe, dataSource.relation) match {
       case (Some(serde), relation: HadoopFsRelation)
           if relation.paths.length == 1 && relation.partitionColumns.isEmpty =>
-        logInfo {
-          "Persisting data source relation with a single input path into Hive metastore in Hive " +
-            s"compatible format. Input path: ${relation.paths.head}"
+        // Hive ParquetSerDe doesn't support decimal type until 1.2.0.
+        val isParquetSerDe = serde.inputFormat.exists(_.toLowerCase.contains("parquet"))
+        val hasDecimalFields = relation.schema.existsRecursively(_.isInstanceOf[DecimalType])
+
+        val hiveParquetSupportsDecimal = client.version match {
+          case org.apache.spark.sql.hive.client.hive.v1_2 => true
+          case _ => false
+        }
+
+        if (isParquetSerDe && !hiveParquetSupportsDecimal && hasDecimalFields) {
+          // If Hive version is below 1.2.0, we cannot save Hive compatible schema to
+          // metastore when the file format is Parquet and the schema has DecimalType.
+          logWarning {
+            "Persisting Parquet relation with decimal field(s) into Hive metastore in Spark SQL " +
+              "specific format, which is NOT compatible with Hive. Because ParquetHiveSerDe in " +
+              s"Hive ${client.version.fullVersion} doesn't support decimal type. See HIVE-6384."
+          }
+          newSparkSQLSpecificMetastoreTable()
+        } else {
+          logInfo {
+            "Persisting data source relation with a single input path into Hive metastore in " +
+              s"Hive compatible format. Input path: ${relation.paths.head}"
+          }
+          newHiveCompatibleMetastoreTable(relation, serde)
         }
-        newHiveCompatibleMetastoreTable(relation, serde)
 
       case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty =>
         logWarning {
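
The inline comments and the logWarning message above describe the whole decision: the Hive-compatible metastore layout is skipped only when all three conditions hold at once, i.e. the SerDe is Parquet, the schema contains a decimal field somewhere (checked with the new existsRecursively), and the connected Hive client predates 1.2.0 (HIVE-6384). A hypothetical distillation of that predicate, with an invented helper name and signature; the real logic stays inline in HiveMetastoreCatalog and derives its inputs from HiveSerDe and ClientInterface:

```scala
package org.apache.spark.sql.hive

import org.apache.spark.sql.types.{DecimalType, StructType}

// Illustrative helper only; not part of the patch.
object HiveCompatibilitySketch {
  /** True when the relation can be registered with a Hive-compatible schema. */
  def canPersistHiveCompatible(
      serdeInputFormat: Option[String],
      schema: StructType,
      hiveParquetSupportsDecimal: Boolean): Boolean = {
    val isParquetSerDe = serdeInputFormat.exists(_.toLowerCase.contains("parquet"))
    val hasDecimalFields = schema.existsRecursively(_.isInstanceOf[DecimalType])
    // Only the combination "Parquet SerDe + decimal column + pre-1.2.0 Hive" forces the
    // Spark SQL specific (Hive-incompatible) metastore table.
    !(isParquetSerDe && hasDecimalFields && !hiveParquetSupportsDecimal)
  }
}
```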

ClientInterface.scala
@@ -88,6 +88,9 @@ private[hive] case class HiveTable(
  */
 private[hive] trait ClientInterface {
 
+  /** Returns the Hive Version of this client. */
+  def version: HiveVersion
+
   /** Returns the configuration for the given key in the current session. */
   def getConf(key: String, defaultValue: String): String
 

ClientWrapper.scala
@@ -58,7 +58,7 @@ import org.apache.spark.util.{CircularBuffer, Utils}
  * this ClientWrapper.
  */
 private[hive] class ClientWrapper(
-    version: HiveVersion,
+    override val version: HiveVersion,
     config: Map[String, String],
     initClassLoader: ClassLoader)
   extends ClientInterface

package.scala (package object client)
@@ -25,7 +25,7 @@ package object client {
     val exclusions: Seq[String] = Nil)
 
   // scalastyle:off
-  private[client] object hive {
+  private[hive] object hive {
     case object v12 extends HiveVersion("0.12.0")
     case object v13 extends HiveVersion("0.13.1")
 

HiveMetastoreCatalogSuite.scala
@@ -19,13 +19,13 @@ package org.apache.spark.sql.hive
 
 import java.io.File
 
-import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, ManagedTable}
+import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable}
 import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.hive.test.TestHive._
 import org.apache.spark.sql.hive.test.TestHive.implicits._
 import org.apache.spark.sql.sources.DataSourceTest
 import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
 import org.apache.spark.sql.{Row, SaveMode}
 import org.apache.spark.{Logging, SparkFunSuite}
 
@@ -55,7 +55,10 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging {
 class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils {
   override val sqlContext = TestHive
 
-  private val testDF = (1 to 2).map(i => (i, s"val_$i")).toDF("d1", "d2").coalesce(1)
+  private val testDF = range(1, 3).select(
+    ('id + 0.1) cast DecimalType(10, 3) as 'd1,
+    'id cast StringType as 'd2
+  ).coalesce(1)
 
   Seq(
     "parquet" -> (
@@ -88,10 +91,10 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTes
 
         val columns = hiveTable.schema
         assert(columns.map(_.name) === Seq("d1", "d2"))
-        assert(columns.map(_.hiveType) === Seq("int", "string"))
+        assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string"))
 
         checkAnswer(table("t"), testDF)
-        assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2"))
+        assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2"))
       }
     }
 
@@ -117,10 +120,10 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTes
 
         val columns = hiveTable.schema
         assert(columns.map(_.name) === Seq("d1", "d2"))
-        assert(columns.map(_.hiveType) === Seq("int", "string"))
+        assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string"))
 
         checkAnswer(table("t"), testDF)
-        assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2"))
+        assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2"))
       }
     }
   }

HiveSparkSubmitSuite.scala
@@ -20,16 +20,18 @@ package org.apache.spark.sql.hive
 import java.io.File
 
 import scala.collection.mutable.ArrayBuffer
-import scala.sys.process.{ProcessLogger, Process}
+import scala.sys.process.{Process, ProcessLogger}
 
+import org.scalatest.Matchers
+import org.scalatest.concurrent.Timeouts
 import org.scalatest.exceptions.TestFailedDueToTimeoutException
+import org.scalatest.time.SpanSugar._
 
 import org.apache.spark._
+import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
+import org.apache.spark.sql.types.DecimalType
 import org.apache.spark.util.{ResetSystemProperties, Utils}
-import org.scalatest.Matchers
-import org.scalatest.concurrent.Timeouts
-import org.scalatest.time.SpanSugar._
 
 /**
  * This suite tests spark-submit with applications using HiveContext.
@@ -50,8 +52,8 @@ class HiveSparkSubmitSuite
     val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
     val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA"))
     val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB"))
-    val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()
-    val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath()
+    val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath
+    val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath
     val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",")
     val args = Seq(
       "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"),
@@ -83,6 +85,16 @@ class HiveSparkSubmitSuite
     runSparkSubmit(args)
   }
 
+  test("SPARK-9757 Persist Parquet relation with decimal column") {
+    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+    val args = Seq(
+      "--class", SPARK_9757.getClass.getName.stripSuffix("$"),
+      "--name", "SparkSQLConfTest",
+      "--master", "local-cluster[2,1,1024]",
+      unusedJar.toString)
+    runSparkSubmit(args)
+  }
+
   // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
   // This is copied from org.apache.spark.deploy.SparkSubmitSuite
   private def runSparkSubmit(args: Seq[String]): Unit = {
@@ -205,7 +217,7 @@ object SparkSQLConfTest extends Logging {
     // before spark.sql.hive.metastore.jars get set, we will see the following exception:
     // Exception in thread "main" java.lang.IllegalArgumentException: Builtin jars can only
    // be used when hive execution version == hive metastore version.
-    // Execution: 0.13.1 != Metastore: 0.12. Specify a vaild path to the correct hive jars
+    // Execution: 0.13.1 != Metastore: 0.12. Specify a valid path to the correct hive jars
     // using $HIVE_METASTORE_JARS or change spark.sql.hive.metastore.version to 0.13.1.
     val conf = new SparkConf() {
       override def getAll: Array[(String, String)] = {
@@ -231,3 +243,45 @@ object SparkSQLConfTest extends Logging {
     sc.stop()
   }
 }
+
+object SPARK_9757 extends QueryTest with Logging {
+  def main(args: Array[String]): Unit = {
+    Utils.configTestLog4j("INFO")
+
+    val sparkContext = new SparkContext(
+      new SparkConf()
+        .set("spark.sql.hive.metastore.version", "0.13.1")
+        .set("spark.sql.hive.metastore.jars", "maven"))
+
+    val hiveContext = new TestHiveContext(sparkContext)
+    import hiveContext.implicits._
+    import org.apache.spark.sql.functions._
+
+    val dir = Utils.createTempDir()
+    dir.delete()
+
+    try {
+      {
+        val df =
+          hiveContext
+            .range(10)
+            .select(('id + 0.1) cast DecimalType(10, 3) as 'dec)
+        df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t")
+        checkAnswer(hiveContext.table("t"), df)
+      }
+
+      {
+        val df =
+          hiveContext
+            .range(10)
+            .select(callUDF("struct", ('id + 0.2) cast DecimalType(10, 3)) as 'dec_struct)
+        df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t")
+        checkAnswer(hiveContext.table("t"), df)
+      }
+    } finally {
+      dir.delete()
+      hiveContext.sql("DROP TABLE t")
+      sparkContext.stop()
+    }
+  }
+}
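
Both SparkSQLConfTest and the new SPARK_9757 object hinge on spark.sql.hive.metastore.version and spark.sql.hive.metastore.jars being set as a consistent pair, which is what the "Builtin jars can only be used when hive execution version == hive metastore version" comment above is about. A minimal sketch of that pairing, using the same 0.13.1/maven combination as the tests (the object name is invented; "maven" makes Spark resolve metastore client jars of the requested version instead of using the builtin ones):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Illustrative only: shows the configuration pairing, not a test from this PR.
object MetastoreVersionConfSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("MetastoreVersionConfSketch")
      .set("spark.sql.hive.metastore.version", "0.13.1")
      .set("spark.sql.hive.metastore.jars", "maven")
    val sc = new SparkContext(conf)
    // A HiveContext built on this SparkContext talks to a 0.13.1 metastore client.
    sc.stop()
  }
}
```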