sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst
 
+import scala.reflect.ClassTag
+
 import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
@@ -162,7 +164,7 @@ object ScalaReflection extends ScalaReflection {
 
     /** Returns the current path or `GetColumnByOrdinal`. */
     def getPath: Expression = {
-      val dataType = schemaFor(tpe).dataType
+      val dataType = schemaForDefaultBinaryType(tpe).dataType
       if (path.isDefined) {
         path.get
       } else {
@@ -407,7 +409,8 @@ object ScalaReflection extends ScalaReflection {
         val cls = getClassFromType(tpe)
 
         val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
-          val Schema(dataType, nullable) = schemaFor(fieldType)
+          val Schema(dataType, nullable) = schemaForDefaultBinaryType(fieldType)
+
           val clsName = getClassNameFromType(fieldType)
           val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
           // For tuples, we grab the inner fields by ordinal instead of name.
@@ -441,6 +444,10 @@
         } else {
           newInstance
         }
+
+      case _ =>
+        // Fall back to the default Kryo deserializer.
+        DecodeUsingSerializer(getPath, ClassTag(getClassFromType(tpe)), true)
     }
   }
 
@@ -639,9 +646,9 @@ object ScalaReflection extends ScalaReflection {
         val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
         expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
 
-      case other =>
-        throw new UnsupportedOperationException(
-          s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
+      case _ =>
+        // Fall back to the default Kryo serializer.
+        EncodeUsingSerializer(inputObject, true)
     }
 
   }
@@ -708,6 +715,13 @@ object ScalaReflection extends ScalaReflection {
     s.toAttributes
   }
 
+  /**
+   * Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
+   * If `schemaFor` fails to match the given type, Schema(BinaryType, nullable = true) is returned.
+   */
+  def schemaForDefaultBinaryType(tpe: `Type`): Schema = scala.util.Try(schemaFor(tpe)).toOption
+    .getOrElse(Schema(BinaryType, nullable = true))
+
   /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
   def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T])
 
@@ -723,20 +737,20 @@ object ScalaReflection extends ScalaReflection {
         Schema(udt, nullable = true)
       case t if t <:< localTypeOf[Option[_]] =>
         val TypeRef(_, _, Seq(optType)) = t
-        Schema(schemaFor(optType).dataType, nullable = true)
+        Schema(schemaForDefaultBinaryType(optType).dataType, nullable = true)
       case t if t <:< localTypeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
-        val Schema(dataType, nullable) = schemaFor(elementType)
+        val Schema(dataType, nullable) = schemaForDefaultBinaryType(elementType)
         Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
       case t if t <:< localTypeOf[Seq[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
-        val Schema(dataType, nullable) = schemaFor(elementType)
+        val Schema(dataType, nullable) = schemaForDefaultBinaryType(elementType)
         Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
       case t if t <:< localTypeOf[Map[_, _]] =>
         val TypeRef(_, _, Seq(keyType, valueType)) = t
-        val Schema(valueDataType, valueNullable) = schemaFor(valueType)
-        Schema(MapType(schemaFor(keyType).dataType,
+        val Schema(valueDataType, valueNullable) = schemaForDefaultBinaryType(valueType)
+        Schema(MapType(schemaForDefaultBinaryType(keyType).dataType,
           valueDataType, valueContainsNull = valueNullable), nullable = true)
       case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
       case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
@@ -767,7 +781,7 @@ object ScalaReflection extends ScalaReflection {
         val params = getConstructorParameters(t)
         Schema(StructType(
           params.map { case (fieldName, fieldType) =>
-            val Schema(dataType, nullable) = schemaFor(fieldType)
+            val Schema(dataType, nullable) = schemaForDefaultBinaryType(fieldType)
             StructField(fieldName, dataType, nullable)
           }), nullable = true)
       case other =>
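
The heart of this patch is the `Try`-based fallback in `schemaForDefaultBinaryType`: whenever `schemaFor` throws because a type has no Catalyst mapping, the field is typed as `BinaryType` and its values are round-tripped through Kryo by `EncodeUsingSerializer`/`DecodeUsingSerializer`. A minimal, self-contained sketch of the same idiom; the stand-in `schemaFor` and the string-based type tags are illustrative, not part of the patch:

import scala.util.Try

object FallbackSketch extends App {
  sealed trait DataType
  case object StringType extends DataType
  case object BinaryType extends DataType
  final case class Schema(dataType: DataType, nullable: Boolean)

  // Stand-in for ScalaReflection.schemaFor, which throws for types it cannot map.
  def schemaFor(tpe: String): Schema = tpe match {
    case "String" => Schema(StringType, nullable = true)
    case other => throw new UnsupportedOperationException(s"No Encoder found for $other")
  }

  // The fallback this patch introduces: swallow the failure and default to
  // BinaryType, which the encoders then handle via Kryo.
  def schemaForDefaultBinaryType(tpe: String): Schema =
    Try(schemaFor(tpe)).toOption.getOrElse(Schema(BinaryType, nullable = true))

  assert(schemaForDefaultBinaryType("String") == Schema(StringType, nullable = true))
  assert(schemaForDefaultBinaryType("java.util.Locale") == Schema(BinaryType, nullable = true))
}

One design consequence worth noting: because `Try(...).toOption` swallows any non-fatal failure from `schemaFor`, a genuine reflection bug would now surface as a silent Kryo-encoded binary column rather than as an error.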
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala
@@ -22,18 +22,6 @@ import scala.reflect.ClassTag
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Encoders
 
-class NonEncodable(i: Int)
-
-case class ComplexNonEncodable1(name1: NonEncodable)
-
-case class ComplexNonEncodable2(name2: ComplexNonEncodable1)
-
-case class ComplexNonEncodable3(name3: Option[NonEncodable])
-
-case class ComplexNonEncodable4(name4: Array[NonEncodable])
-
-case class ComplexNonEncodable5(name5: Option[Array[NonEncodable]])
-
 class EncoderErrorMessageSuite extends SparkFunSuite {
 
   // Note: we also test error messages for encoders for private classes in JavaDatasetSuite.
@@ -51,52 +39,5 @@ class EncoderErrorMessageSuite extends SparkFunSuite {
     intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] }
   }
 
-  test("nice error message for missing encoder") {
-    val errorMsg1 =
-      intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable1]).getMessage
-    assert(errorMsg1.contains(
-      s"""root class: "${clsName[ComplexNonEncodable1]}""""))
-    assert(errorMsg1.contains(
-      s"""field (class: "${clsName[NonEncodable]}", name: "name1")"""))
-
-    val errorMsg2 =
-      intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable2]).getMessage
-    assert(errorMsg2.contains(
-      s"""root class: "${clsName[ComplexNonEncodable2]}""""))
-    assert(errorMsg2.contains(
-      s"""field (class: "${clsName[ComplexNonEncodable1]}", name: "name2")"""))
-    assert(errorMsg1.contains(
-      s"""field (class: "${clsName[NonEncodable]}", name: "name1")"""))
-
-    val errorMsg3 =
-      intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable3]).getMessage
-    assert(errorMsg3.contains(
-      s"""root class: "${clsName[ComplexNonEncodable3]}""""))
-    assert(errorMsg3.contains(
-      s"""field (class: "scala.Option", name: "name3")"""))
-    assert(errorMsg3.contains(
-      s"""option value class: "${clsName[NonEncodable]}""""))
-
-    val errorMsg4 =
-      intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable4]).getMessage
-    assert(errorMsg4.contains(
-      s"""root class: "${clsName[ComplexNonEncodable4]}""""))
-    assert(errorMsg4.contains(
-      s"""field (class: "scala.Array", name: "name4")"""))
-    assert(errorMsg4.contains(
-      s"""array element class: "${clsName[NonEncodable]}""""))
-
-    val errorMsg5 =
-      intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable5]).getMessage
-    assert(errorMsg5.contains(
-      s"""root class: "${clsName[ComplexNonEncodable5]}""""))
-    assert(errorMsg5.contains(
-      s"""field (class: "scala.Option", name: "name5")"""))
-    assert(errorMsg5.contains(
-      s"""option value class: "scala.Array""""))
-    assert(errorMsg5.contains(
-      s"""array element class: "${clsName[NonEncodable]}""""))
-  }
-
   private def clsName[T : ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName
 }
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -50,6 +50,11 @@ case class NestedArray(a: Array[Array[Int]]) {
   }
 }
 
+case class KryoUnsupportedEncoderForSubField(
+    a: String,
+    b: Seq[Int],
+    c: Option[Set[Int]])
+
 case class BoxedData(
     intField: java.lang.Integer,
     longField: java.lang.Long,
@@ -183,6 +188,10 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   encodeDecodeTest(new KryoSerializable(15), "kryo object")(
     encoderFor(Encoders.kryo[KryoSerializable]))
 
+  // Use Kryo to serialize and deserialize a type that has no supported Encoder.
+  encodeDecodeTest(Seq(KryoUnsupportedEncoderForSubField("a", Seq(1), Some(Set(2))),
+    KryoUnsupportedEncoderForSubField("b", Seq(3), None)), "type with unsupported encoder, use kryo")
+
   // Java encoders
   encodeDecodeTest("hello", "java string")(encoderFor(Encoders.javaSerialization[String]))
   encodeDecodeTest(new JavaSerializable(15), "java object")(
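
For contrast, the existing "kryo object" test above exercises the explicit whole-object Kryo encoder, which was the pre-patch workaround for un-encodable types: it collapses the entire object into a single binary column instead of keeping a per-field schema. A sketch of that workaround, assuming a local SparkSession; the class name and session setup are illustrative:

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}

case class HasOpaqueField(name: String, locale: java.util.Locale)

object KryoWholeObjectDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("kryo-demo").getOrCreate()

    // Pre-patch workaround: an explicit Kryo encoder for the whole class.
    implicit val enc: Encoder[HasOpaqueField] = Encoders.kryo[HasOpaqueField]
    val ds = spark.createDataset(Seq(HasOpaqueField("a", java.util.Locale.US)))

    // The whole object is stored as one binary column ("value: binary"),
    // so even the String field is no longer individually queryable.
    ds.schema.printTreeString()
    spark.stop()
  }
}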
14 changes: 14 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1136,8 +1136,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(spark.range(1).map { x => new java.sql.Timestamp(100000) }.head ==
       new java.sql.Timestamp(100000))
   }
+
+  test("fallback to kryo for unknown classes in ExpressionEncoder") {
+    val ds = Seq(DefaultKryoEncoderForSubField("a", Seq(1), Some(Set(2))),
+      DefaultKryoEncoderForSubField("b", Seq(3), None)).toDS()
+    checkDataset(ds, DefaultKryoEncoderForSubField("a", Seq(1), Some(Set(2))),
+      DefaultKryoEncoderForSubField("b", Seq(3), None))
+
+    val df = ds.toDF()
+    assert(df.schema(0).dataType == StringType)
+    assert(df.schema(1).dataType == ArrayType(IntegerType, containsNull = false))
+    assert(df.schema(2).dataType == BinaryType)
+  }
 }
 
+case class DefaultKryoEncoderForSubField(a: String, b: Seq[Int], c: Option[Set[Int]])
+
 case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])
 case class WithMap(id: String, map_test: scala.collection.Map[Long, String])
 
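
With this patch applied, no explicit encoder is needed: fields Catalyst understands keep their native types, and only the un-encodable field degrades to a Kryo-backed binary column. A hypothetical end-to-end run of the behavior the new DatasetSuite test asserts; the class and object names are illustrative, and it only works on a build containing this change:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType, StringType}

case class Mixed(a: String, b: Seq[Int], c: Option[Set[Int]])

object KryoFallbackDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("fallback-demo").getOrCreate()
    import spark.implicits._

    val df = Seq(Mixed("a", Seq(1), Some(Set(2))), Mixed("b", Seq(3), None)).toDS().toDF()

    // a and b map to native Catalyst types; Set[Int] has no built-in encoder,
    // so c falls back to a Kryo-encoded binary column.
    assert(df.schema("a").dataType == StringType)
    assert(df.schema("b").dataType == ArrayType(IntegerType, containsNull = false))
    assert(df.schema("c").dataType == BinaryType)
    spark.stop()
  }
}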