diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 7f7dd51aa265..f44cedd734de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst
 
+import scala.reflect.ClassTag
+
 import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
@@ -162,7 +164,7 @@ object ScalaReflection extends ScalaReflection {
 
     /** Returns the current path or `GetColumnByOrdinal`. */
     def getPath: Expression = {
-      val dataType = schemaFor(tpe).dataType
+      val dataType = schemaForDefaultBinaryType(tpe).dataType
       if (path.isDefined) {
         path.get
       } else {
@@ -407,7 +409,8 @@ object ScalaReflection extends ScalaReflection {
         val cls = getClassFromType(tpe)
 
         val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
-          val Schema(dataType, nullable) = schemaFor(fieldType)
+          val Schema(dataType, nullable) = schemaForDefaultBinaryType(fieldType)
+
           val clsName = getClassNameFromType(fieldType)
           val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
           // For tuples, we based grab the inner fields by ordinal instead of name.
@@ -441,6 +444,10 @@ object ScalaReflection extends ScalaReflection {
         } else {
           newInstance
         }
+
+      case _ =>
+        // Fall back to the default Kryo deserializer for unsupported types.
+        DecodeUsingSerializer(getPath, ClassTag(getClassFromType(tpe)), kryo = true)
     }
   }
@@ -639,9 +646,9 @@ object ScalaReflection extends ScalaReflection {
         val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
         expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
 
-      case other =>
-        throw new UnsupportedOperationException(
-          s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
+      case _ =>
+        // Fall back to the default Kryo serializer for unsupported types.
+        EncodeUsingSerializer(inputObject, kryo = true)
     }
   }
@@ -708,6 +715,13 @@ object ScalaReflection extends ScalaReflection {
     s.toAttributes
   }
 
+  /**
+   * Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
+   * If the type is not supported by `schemaFor`, a default of `BinaryType` is returned.
+   */
+  def schemaForDefaultBinaryType(tpe: `Type`): Schema = scala.util.Try(schemaFor(tpe)).toOption
+    .getOrElse(Schema(BinaryType, nullable = true))
+
   /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
   def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T])
@@ -723,20 +737,20 @@ object ScalaReflection extends ScalaReflection {
         Schema(udt, nullable = true)
       case t if t <:< localTypeOf[Option[_]] =>
         val TypeRef(_, _, Seq(optType)) = t
-        Schema(schemaFor(optType).dataType, nullable = true)
+        Schema(schemaForDefaultBinaryType(optType).dataType, nullable = true)
       case t if t <:< localTypeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
-        val Schema(dataType, nullable) = schemaFor(elementType)
+        val Schema(dataType, nullable) = schemaForDefaultBinaryType(elementType)
         Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
       case t if t <:< localTypeOf[Seq[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
-        val Schema(dataType, nullable) = schemaFor(elementType)
+        val Schema(dataType, nullable) = schemaForDefaultBinaryType(elementType)
         Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
       case t if t <:< localTypeOf[Map[_, _]] =>
         val TypeRef(_, _, Seq(keyType, valueType)) = t
-        val Schema(valueDataType, valueNullable) = schemaFor(valueType)
-        Schema(MapType(schemaFor(keyType).dataType,
+        val Schema(valueDataType, valueNullable) = schemaForDefaultBinaryType(valueType)
+        Schema(MapType(schemaForDefaultBinaryType(keyType).dataType,
           valueDataType, valueContainsNull = valueNullable), nullable = true)
       case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
       case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
@@ -767,7 +781,7 @@ object ScalaReflection extends ScalaReflection {
        val params = getConstructorParameters(t)
        Schema(StructType(
          params.map { case (fieldName, fieldType) =>
-           val Schema(dataType, nullable) = schemaFor(fieldType)
+           val Schema(dataType, nullable) = schemaForDefaultBinaryType(fieldType)
            StructField(fieldName, dataType, nullable)
          }), nullable = true)
      case other =>
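
To make the fallback concrete, here is a rough spark-shell sketch (a hypothetical session, assuming the patch above is applied). Types that `schemaFor` understands keep their native Catalyst mapping; a type it rejects (here `Set[Int]`, which has no built-in encoder in this version) now degrades to `BinaryType` instead of throwing:

    import org.apache.spark.sql.catalyst.ScalaReflection._

    // Supported element type: the Catalyst mapping is unchanged.
    schemaForDefaultBinaryType(localTypeOf[Seq[Int]]).dataType
    // => ArrayType(IntegerType, containsNull = false)

    // Unsupported type: schemaFor throws, so the BinaryType default kicks in.
    schemaForDefaultBinaryType(localTypeOf[Set[Int]]).dataType
    // => BinaryType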
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala
index 8c766ef82992..8b8709035005 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala
@@ -22,18 +22,6 @@ import scala.reflect.ClassTag
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Encoders
 
-class NonEncodable(i: Int)
-
-case class ComplexNonEncodable1(name1: NonEncodable)
-
-case class ComplexNonEncodable2(name2: ComplexNonEncodable1)
-
-case class ComplexNonEncodable3(name3: Option[NonEncodable])
-
-case class ComplexNonEncodable4(name4: Array[NonEncodable])
-
-case class ComplexNonEncodable5(name5: Option[Array[NonEncodable]])
-
 class EncoderErrorMessageSuite extends SparkFunSuite {
 
   // Note: we also test error messages for encoders for private classes in JavaDatasetSuite.
@@ -51,52 +39,5 @@ class EncoderErrorMessageSuite extends SparkFunSuite {
     intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] }
   }
 
-  test("nice error message for missing encoder") {
-    val errorMsg1 =
-      intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable1]).getMessage
-    assert(errorMsg1.contains(
-      s"""root class: "${clsName[ComplexNonEncodable1]}""""))
-    assert(errorMsg1.contains(
-      s"""field (class: "${clsName[NonEncodable]}", name: "name1")"""))
-
-    val errorMsg2 =
-      intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable2]).getMessage
-    assert(errorMsg2.contains(
-      s"""root class: "${clsName[ComplexNonEncodable2]}""""))
-    assert(errorMsg2.contains(
-      s"""field (class: "${clsName[ComplexNonEncodable1]}", name: "name2")"""))
-    assert(errorMsg1.contains(
-      s"""field (class: "${clsName[NonEncodable]}", name: "name1")"""))
-
-    val errorMsg3 =
-      intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable3]).getMessage
-    assert(errorMsg3.contains(
-      s"""root class: "${clsName[ComplexNonEncodable3]}""""))
-    assert(errorMsg3.contains(
-      s"""field (class: "scala.Option", name: "name3")"""))
-    assert(errorMsg3.contains(
-      s"""option value class: "${clsName[NonEncodable]}""""))
-
-    val errorMsg4 =
-      intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable4]).getMessage
-    assert(errorMsg4.contains(
-      s"""root class: "${clsName[ComplexNonEncodable4]}""""))
-    assert(errorMsg4.contains(
-      s"""field (class: "scala.Array", name: "name4")"""))
-    assert(errorMsg4.contains(
-      s"""array element class: "${clsName[NonEncodable]}""""))
-
-    val errorMsg5 =
-      intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable5]).getMessage
-    assert(errorMsg5.contains(
-      s"""root class: "${clsName[ComplexNonEncodable5]}""""))
-    assert(errorMsg5.contains(
-      s"""field (class: "scala.Option", name: "name5")"""))
-    assert(errorMsg5.contains(
-      s"""option value class: "scala.Array""""))
-    assert(errorMsg5.contains(
-      s"""array element class: "${clsName[NonEncodable]}""""))
-  }
-
   private def clsName[T : ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 080f11b76938..b6c6dc350c1c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -50,6 +50,11 @@ case class NestedArray(a: Array[Array[Int]]) {
   }
 }
 
+case class KryoUnsupportedEncoderForSubField(
+    a: String,
+    b: Seq[Int],
+    c: Option[Set[Int]])
+
 case class BoxedData(
     intField: java.lang.Integer,
     longField: java.lang.Long,
@@ -183,6 +188,10 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   encodeDecodeTest(new KryoSerializable(15), "kryo object")(
     encoderFor(Encoders.kryo[KryoSerializable]))
 
+  // Use Kryo to serialize and deserialize a type that has no supported encoder.
+  encodeDecodeTest(Seq(KryoUnsupportedEncoderForSubField("a", Seq(1), Some(Set(2))),
+    KryoUnsupportedEncoderForSubField("b", Seq(3), None)), "type with unsupported encoder, use kryo")
+
   // Java encoders
   encodeDecodeTest("hello", "java string")(encoderFor(Encoders.javaSerialization[String]))
   encodeDecodeTest(new JavaSerializable(15), "java object")(
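
At the encoder level, the fallback means derivation itself no longer throws for such classes. A hedged sketch of the round trip the test above exercises, using the Spark 2.x ExpressionEncoder API (toRow/fromRow/resolveAndBind) and a hypothetical WithSetField class:

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    case class WithSetField(a: String, c: Option[Set[Int]])

    // Deriving this encoder used to throw UnsupportedOperationException
    // because of the Set field; with the fallback it succeeds and the
    // field is Kryo-serialized into a binary column.
    val enc = ExpressionEncoder[WithSetField]()
    val row = enc.toRow(WithSetField("a", Some(Set(2))))
    val back = enc.resolveAndBind().fromRow(row)
    assert(back == WithSetField("a", Some(Set(2))))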
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index b37bf131e8dc..758b4a21afb5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1136,8 +1136,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(spark.range(1).map { x => new java.sql.Timestamp(100000) }.head ==
       new java.sql.Timestamp(100000))
   }
+
+  test("fallback to kryo for unknown classes in ExpressionEncoder") {
+    val ds = Seq(DefaultKryoEncoderForSubField("a", Seq(1), Some(Set(2))),
+      DefaultKryoEncoderForSubField("b", Seq(3), None)).toDS()
+    checkDataset(ds, DefaultKryoEncoderForSubField("a", Seq(1), Some(Set(2))),
+      DefaultKryoEncoderForSubField("b", Seq(3), None))
+
+    val df = ds.toDF()
+    assert(df.schema(0).dataType == StringType)
+    assert(df.schema(1).dataType == ArrayType(IntegerType, containsNull = false))
+    assert(df.schema(2).dataType == BinaryType)
+  }
 }
 
+case class DefaultKryoEncoderForSubField(a: String, b: Seq[Int], c: Option[Set[Int]])
+
 case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])
 
 case class WithMap(id: String, map_test: scala.collection.Map[Long, String])
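
End to end, only the field with no supported encoder degrades to binary; the sibling fields keep their native Catalyst types, which is exactly what the schema assertions above check. A sketch of how this might look in spark-shell (hypothetical session, assuming the patch is applied):

    import spark.implicits._

    case class DefaultKryoEncoderForSubField(a: String, b: Seq[Int], c: Option[Set[Int]])

    val df = Seq(DefaultKryoEncoderForSubField("a", Seq(1), Some(Set(2)))).toDF()
    df.printSchema()
    // root
    //  |-- a: string (nullable = true)
    //  |-- b: array (nullable = true)
    //  |    |-- element: integer (containsNull = false)
    //  |-- c: binary (nullable = true)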