diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 1f7634bafa420..8e4013231376a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -338,11 +338,11 @@ object ScalaReflection extends ScalaReflection {
         Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
 
       case t if definedByConstructorParams(t) =>
-        val params = getConstructorParameters(t)
+        val unwrappedParams = getConstructorParameters(t).map(unwrapValueClassParam)
 
         val cls = getClassFromType(tpe)
 
-        val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
+        val arguments = unwrappedParams.zipWithIndex.map { case ((fieldName, fieldType), i) =>
           val Schema(dataType, nullable) = schemaFor(fieldType)
           val clsName = getClassNameFromType(fieldType)
           val newTypePath = walkedTypePath.recordField(clsName, fieldName)
@@ -537,8 +537,8 @@ object ScalaReflection extends ScalaReflection {
             s"cannot have circular references in class, but got the circular reference of class $t")
         }
 
-        val params = getConstructorParameters(t)
-        val fields = params.map { case (fieldName, fieldType) =>
+        val unwrappedParams = getConstructorParameters(t).map(unwrapValueClassParam)
+        val fields = unwrappedParams.map { case (fieldName, fieldType) =>
           if (javaKeywords.contains(fieldName)) {
             throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
               "cannot be used as field name\n" + walkedTypePath)
@@ -696,9 +696,9 @@ object ScalaReflection extends ScalaReflection {
     case t if isSubtype(t, definitions.ByteTpe) => Schema(ByteType, nullable = false)
     case t if isSubtype(t, definitions.BooleanTpe) => Schema(BooleanType, nullable = false)
     case t if definedByConstructorParams(t) =>
-      val params = getConstructorParameters(t)
+      val unwrappedParams = getConstructorParameters(t).map(unwrapValueClassParam)
       Schema(StructType(
-        params.map { case (fieldName, fieldType) =>
+        unwrappedParams.map { case (fieldName, fieldType) =>
           val Schema(dataType, nullable) = schemaFor(fieldType)
           StructField(fieldName, dataType, nullable)
         }), nullable = true)
@@ -753,6 +753,31 @@ object ScalaReflection extends ScalaReflection {
     }
   }
 
+  /**
+   * [SPARK-20384] Unwrap a constructor parameter whose type is a value class.
+   * When a member of a case class is a value class (i.e. it `extends AnyVal`), the parameter
+   * type seen by the encoder should be the underlying type, so that it stays consistent with
+   * the generated type of that member and avoids compile errors.
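+   *
+   * For illustration, a hypothetical example (not from this patch):
+   * {{{
+   *   case class Meters(m: Double) extends AnyVal
+   *   // ("dist", typeOf[Meters]) unwraps to ("dist", typeOf[Double]);
+   *   // a regular param such as ("n", typeOf[Int]) is returned unchanged.
+   * }}}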
+   * @param param a (name, type) pair, as returned by [[ScalaReflection.getConstructorParameters]]
+   * @return the pair with a value-class type replaced by its underlying type
+   */
+  private def unwrapValueClassParam(param: (String, Type)): (String, Type) = {
+    val (name, tpe) = param
+    val unwrappedTpe = if (tpe.typeSymbol.asClass.isDerivedValueClass) {
+      getConstructorParameters(tpe.dealias).head._2
+    } else {
+      tpe
+    }
+    (name, unwrappedTpe)
+  }
+
   private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch",
     "char", "class", "const", "continue", "default", "do", "double", "else", "extends", "false",
     "final", "finally", "float", "for", "goto", "if", "implements", "import", "instanceof", "int",
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index f9cd9c3c398f6..70989aed1b6d1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -146,8 +146,19 @@ object TraitProductWithNoConstructorCompanion {}
 
 trait TraitProductWithNoConstructorCompanion extends Product1[Int] {}
 
+object TestingValueClass {
+  case class IntWrapper(val i: Int) extends AnyVal
+  case class StrWrapper(s: String) extends AnyVal
+
+  case class ValueClassData(intField: Int,
+    wrappedInt: IntWrapper, // an int column
+    strField: String,
+    wrappedStr: StrWrapper) // a string column
+}
+
 class ScalaReflectionSuite extends SparkFunSuite {
   import org.apache.spark.sql.catalyst.ScalaReflection._
+  import TestingValueClass._
 
   // A helper method used to test `ScalaReflection.serializerForType`.
   private def serializerFor[T: TypeTag]: Expression =
@@ -432,4 +443,42 @@ class ScalaReflectionSuite extends SparkFunSuite {
         StructField("f2", StringType))))
     assert(deserializerFor[FooWithAnnotation].dataType == ObjectType(classOf[FooWithAnnotation]))
   }
+
+  test("schema for case class that is a value class") {
+    val schema = schemaFor[IntWrapper]
+    assert(
+      schema === Schema(StructType(Seq(StructField("i", IntegerType, false))), nullable = true))
+  }
+
+  test("schema for case class that contains value class fields") {
+    val schema = schemaFor[ValueClassData]
+    assert(
+      schema === Schema(
+        StructType(Seq(
+          StructField("intField", IntegerType, nullable = false),
+          StructField("wrappedInt", IntegerType, nullable = false),
+          StructField("strField", StringType),
+          StructField("wrappedStr", StringType)
+        )),
+        nullable = true))
+  }
+
+  test("schema for array of value class") {
+    val schema = schemaFor[Array[IntWrapper]]
+    assert(
+      schema === Schema(
+        ArrayType(StructType(Seq(StructField("i", IntegerType, false))), containsNull = true),
+        nullable = true))
+  }
+
+  test("schema for map of value class") {
+    val schema = schemaFor[Map[IntWrapper, StrWrapper]]
+    assert(
+      schema === Schema(
+        MapType(
+          StructType(Seq(StructField("i", IntegerType, false))),
+          StructType(Seq(StructField("s", StringType))),
+          valueContainsNull = true),
+        nullable = true))
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index c1f1be3b30e4b..2fe574debacc4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -113,6 +113,20 @@ object ReferenceValueClass {
   case class Container(data: Int)
 }
 
+case class StringWrapper(s: String) extends AnyVal
+case class ValueContainer(
+    a: Int,
+    b: StringWrapper) // a string column
+case class IntWrapper(i: Int) extends AnyVal
+case class ComplexValueClassContainer(
+    a: Int,
+    b: ValueContainer,
+    c: IntWrapper)
+case class SeqOfValueClass(s: Seq[StringWrapper])
+case class MapOfValueClassKey(m: Map[IntWrapper, String])
+case class MapOfValueClassValue(m: Map[String, StringWrapper])
+case class OptionOfValueClassValue(o: Option[StringWrapper])
+
 class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest {
   OuterScopes.addOuterScope(this)
 
@@ -298,12 +312,44 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
     ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
   }
 
+  // test for value classes
   encodeDecodeTest(
     PrimitiveValueClass(42), "primitive value class")
   encodeDecodeTest(
     ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class")
+  encodeDecodeTest(StringWrapper("a"), "string value class")
+  encodeDecodeTest(ValueContainer(1, StringWrapper("b")), "nested value class")
+  encodeDecodeTest(ValueContainer(1, StringWrapper(null)), "nested value class with null")
+  encodeDecodeTest(ComplexValueClassContainer(1, ValueContainer(2, StringWrapper("b")),
+    IntWrapper(3)), "complex value class")
+  encodeDecodeTest(
+    Array(IntWrapper(1), IntWrapper(2), IntWrapper(3)),
+    "array of value class")
+  encodeDecodeTest(Array.empty[IntWrapper], "empty array of value class")
+  encodeDecodeTest(
+    Seq(IntWrapper(1), IntWrapper(2), IntWrapper(3)),
+    "seq of value class")
+  encodeDecodeTest(Seq.empty[IntWrapper], "empty seq of value class")
+  encodeDecodeTest(
+    Map(IntWrapper(1) -> StringWrapper("a"), IntWrapper(2) -> StringWrapper("b")),
+    "map with value class")
+
+  // test for nested value class collections
+  encodeDecodeTest(
+    MapOfValueClassKey(Map(IntWrapper(1) -> "a")),
+    "case class with map of value class key")
+  encodeDecodeTest(
+    MapOfValueClassValue(Map("a" -> StringWrapper("b"))),
+    "case class with map of value class value")
+  encodeDecodeTest(
+    SeqOfValueClass(Seq(StringWrapper("a"))),
+    "case class with seq of value class")
+  encodeDecodeTest(
+    OptionOfValueClassValue(Some(StringWrapper("a"))),
+    "case class with option of value class")
+
   encodeDecodeTest(Option(31), "option of int")
   encodeDecodeTest(Option.empty[Int], "empty option of int")
   encodeDecodeTest(Option("abc"), "option of string")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 8a9b923e284f3..22493c0932d3c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
-import org.apache.spark.sql.test.SQLTestData.{DecimalData, NullStrings, TestData2}
+import org.apache.spark.sql.test.SQLTestData.{ArrayStringWrapper, ContainerStringWrapper, DecimalData, StringWrapper, TestData2}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
@@ -679,6 +679,33 @@ class DataFrameSuite extends QueryTest with SharedSparkSession {
     assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
   }
 
+  test("Value class filter") {
+    val df = spark.sparkContext
+      .parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c")))
+      .toDF()
+    val filtered = df.where("s = \"a\"")
+    checkAnswer(filtered, spark.sparkContext.parallelize(Seq(StringWrapper("a"))).toDF())
+  }
+
+  test("Array value class filter") {
+    val ab = ArrayStringWrapper(Seq(StringWrapper("a"), StringWrapper("b")))
+    val cd = ArrayStringWrapper(Seq(StringWrapper("c"), StringWrapper("d")))
+
+    val df = spark.sparkContext.parallelize(Seq(ab, cd)).toDF()
+    val filtered = df.where(array_contains(col("wrappers.s"), "b"))
+    checkAnswer(filtered, spark.sparkContext.parallelize(Seq(ab)).toDF())
+  }
+
+  test("Nested value class filter") {
+    val a = ContainerStringWrapper(StringWrapper("a"))
+    val b = ContainerStringWrapper(StringWrapper("b"))
+
+    val df = spark.sparkContext.parallelize(Seq(a, b)).toDF()
+    // the value class is flattened, so the schema has a string `wrapper` column, not `wrapper.s`
+    val filtered = df.where("wrapper = \"a\"")
+    checkAnswer(filtered, spark.sparkContext.parallelize(Seq(a)).toDF())
+  }
+
   private lazy val person2: DataFrame = Seq(
     ("Bob", 16, 176),
     ("Alice", 32, 164),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index c51faaf10f5dd..c8669079a2ed5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -344,4 +344,7 @@ private[sql] object SQLTestData {
   case class CourseSales(course: String, year: Int, earnings: Double)
   case class TrainingSales(training: String, sales: CourseSales)
   case class IntervalData(data: CalendarInterval)
+  case class StringWrapper(s: String) extends AnyVal
+  case class ArrayStringWrapper(wrappers: Seq[StringWrapper])
+  case class ContainerStringWrapper(wrapper: StringWrapper)
 }
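
A minimal end-to-end sketch of the behavior this patch enables (illustrative only; `Meters`, `Reading`, and the session setup below are assumptions for the example, not part of the patch):

    // Hypothetical demo, not part of the patch: with value-class unwrapping,
    // a value-class member is encoded as a flat column of its underlying type.
    import org.apache.spark.sql.SparkSession

    case class Meters(m: Double) extends AnyVal
    case class Reading(sensor: String, dist: Meters)

    object ValueClassDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
        import spark.implicits._

        val ds = Seq(Reading("a", Meters(1.5)), Reading("b", Meters(2.0))).toDS()
        ds.printSchema()
        // root
        //  |-- sensor: string (nullable = true)
        //  |-- dist: double (nullable = false)

        // The unwrapped column can be filtered by its underlying value.
        ds.where("dist > 1.8").show()
        spark.stop()
      }
    }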