Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -338,11 +338,11 @@ object ScalaReflection extends ScalaReflection {
Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)

case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)
val unwrappedParams = getConstructorParameters(t).map(unwrapValueClassParam)

val cls = getClassFromType(tpe)

val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
val arguments = unwrappedParams.zipWithIndex.map { case ((fieldName, fieldType), i) =>
val Schema(dataType, nullable) = schemaFor(fieldType)
val clsName = getClassNameFromType(fieldType)
val newTypePath = walkedTypePath.recordField(clsName, fieldName)
Expand Down Expand Up @@ -537,8 +537,8 @@ object ScalaReflection extends ScalaReflection {
s"cannot have circular references in class, but got the circular reference of class $t")
}

val params = getConstructorParameters(t)
val fields = params.map { case (fieldName, fieldType) =>
val unwrappedParams = getConstructorParameters(t).map(unwrapValueClassParam)
val fields = unwrappedParams.map { case (fieldName, fieldType) =>
if (javaKeywords.contains(fieldName)) {
throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
"cannot be used as field name\n" + walkedTypePath)
Expand Down Expand Up @@ -696,9 +696,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, definitions.ByteTpe) => Schema(ByteType, nullable = false)
case t if isSubtype(t, definitions.BooleanTpe) => Schema(BooleanType, nullable = false)
case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)
val unwrappedParams = getConstructorParameters(t).map(unwrapValueClassParam)
Schema(StructType(
params.map { case (fieldName, fieldType) =>
unwrappedParams.map { case (fieldName, fieldType) =>
val Schema(dataType, nullable) = schemaFor(fieldType)
StructField(fieldName, dataType, nullable)
}), nullable = true)
Expand Down Expand Up @@ -753,6 +753,24 @@ object ScalaReflection extends ScalaReflection {
}
}

/**
* [SPARK-20384] Create an underlying param for a given parameter of value class.
* When a member of case class is value class `extends AnyVal`, the member's parameter type
* for encoder should be the underlying type. This is to be consistent with the generated
* type of that member, and avoid compile error.
* @param param param (type is consistent with [[ScalaReflection.getConstructorParameters]])
* @return unwrapped param
*/
private def unwrapValueClassParam(param: (String, `Type`)): (String, `Type`) = {
val (name, tpe) = param
val unwrappedTpe = if (tpe.typeSymbol.asClass.isDerivedValueClass) {
getConstructorParameters(tpe.dealias).head._2
} else {
tpe
}
(name, unwrappedTpe)
}

private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch",
"char", "class", "const", "continue", "default", "do", "double", "else", "extends", "false",
"final", "finally", "float", "for", "goto", "if", "implements", "import", "instanceof", "int",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,19 @@ object TraitProductWithNoConstructorCompanion {}

trait TraitProductWithNoConstructorCompanion extends Product1[Int] {}

object TestingValueClass {
case class IntWrapper(val i: Int) extends AnyVal
case class StrWrapper(s: String) extends AnyVal

case class ValueClassData(intField: Int,
wrappedInt: IntWrapper, // an int column
strField: String,
wrappedStr: StrWrapper) // a string column
}

class ScalaReflectionSuite extends SparkFunSuite {
import org.apache.spark.sql.catalyst.ScalaReflection._
import TestingValueClass._

// A helper method used to test `ScalaReflection.serializerForType`.
private def serializerFor[T: TypeTag]: Expression =
Expand Down Expand Up @@ -432,4 +443,42 @@ class ScalaReflectionSuite extends SparkFunSuite {
StructField("f2", StringType))))
assert(deserializerFor[FooWithAnnotation].dataType == ObjectType(classOf[FooWithAnnotation]))
}

test("schema for case class that is a value class") {
val schema = schemaFor[IntWrapper]
assert(
schema === Schema(StructType(Seq(StructField("i", IntegerType, false))), nullable = true))
}

test("schema for case class that contains value class fields") {
val schema = schemaFor[ValueClassData]
assert(
schema === Schema(
StructType(Seq(
StructField("intField", IntegerType, nullable = false),
StructField("wrappedInt", IntegerType, nullable = false),
StructField("strField", StringType),
StructField("wrappedStr", StringType)
)),
nullable = true))
}

test("schema for array of value class") {
val schema = schemaFor[Array[IntWrapper]]
assert(
schema === Schema(
ArrayType(StructType(Seq(StructField("i", IntegerType, false))), containsNull = true),
nullable = true))
}

test("schema for map of value class") {
val schema = schemaFor[Map[IntWrapper, StrWrapper]]
assert(
schema === Schema(
MapType(
StructType(Seq(StructField("i", IntegerType, false))),
StructType(Seq(StructField("s", StringType))),
valueContainsNull = true),
nullable = true))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,20 @@ object ReferenceValueClass {
case class Container(data: Int)
}

case class StringWrapper(s: String) extends AnyVal
case class ValueContainer(
a: Int,
b: StringWrapper) // a string column
case class IntWrapper(i: Int) extends AnyVal
case class ComplexValueClassContainer(
a: Int,
b: ValueContainer,
c: IntWrapper)
case class SeqOfValueClass(s: Seq[StringWrapper])
case class MapOfValueClassKey(m: Map[IntWrapper, String])
case class MapOfValueClassValue(m: Map[String, StringWrapper])
case class OptionOfValueClassValue(o: Option[StringWrapper])

class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest {
OuterScopes.addOuterScope(this)

Expand Down Expand Up @@ -298,12 +312,44 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
}

// test for value classes
encodeDecodeTest(
PrimitiveValueClass(42), "primitive value class")

encodeDecodeTest(
ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class")

encodeDecodeTest(StringWrapper("a"), "string value class")
encodeDecodeTest(ValueContainer(1, StringWrapper("b")), "nested value class")
encodeDecodeTest(ValueContainer(1, StringWrapper(null)), "nested value class with null")
encodeDecodeTest(ComplexValueClassContainer(1, ValueContainer(2, StringWrapper("b")),
IntWrapper(3)), "complex value class")
encodeDecodeTest(
Array(IntWrapper(1), IntWrapper(2), IntWrapper(3)),
"array of value class")
encodeDecodeTest(Array.empty[IntWrapper], "empty array of value class")
encodeDecodeTest(
Seq(IntWrapper(1), IntWrapper(2), IntWrapper(3)),
"seq of value class")
encodeDecodeTest(Seq.empty[IntWrapper], "empty seq of value class")
encodeDecodeTest(
Map(IntWrapper(1) -> StringWrapper("a"), IntWrapper(2) -> StringWrapper("b")),
"map with value class")

// test for nested value class collections
encodeDecodeTest(
MapOfValueClassKey(Map(IntWrapper(1)-> "a")),
"case class with map of value class key")
encodeDecodeTest(
MapOfValueClassValue(Map("a"-> StringWrapper("b"))),
"case class with map of value class value")
encodeDecodeTest(
SeqOfValueClass(Seq(StringWrapper("a"))),
"case class with seq of class value")
encodeDecodeTest(
OptionOfValueClassValue(Some(StringWrapper("a"))),
"case class with option of class value")

encodeDecodeTest(Option(31), "option of int")
encodeDecodeTest(Option.empty[Int], "empty option of int")
encodeDecodeTest(Option("abc"), "option of string")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
import org.apache.spark.sql.test.SQLTestData.{DecimalData, NullStrings, TestData2}
import org.apache.spark.sql.test.SQLTestData.{ArrayStringWrapper, ContainerStringWrapper, DecimalData, StringWrapper, TestData2}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
Expand Down Expand Up @@ -679,6 +679,33 @@ class DataFrameSuite extends QueryTest with SharedSparkSession {
assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
}

test("Value class filter") {
val df = spark.sparkContext
.parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c")))
.toDF()
val filtered = df.where("s = \"a\"")
checkAnswer(filtered, spark.sparkContext.parallelize(Seq(StringWrapper("a"))).toDF)
}

test("Array value class filter") {
val ab = ArrayStringWrapper(Seq(StringWrapper("a"), StringWrapper("b")))
val cd = ArrayStringWrapper(Seq(StringWrapper("c"), StringWrapper("d")))

val df = spark.sparkContext.parallelize(Seq(ab, cd)).toDF
val filtered = df.where(array_contains(col("wrappers.s"), "b"))
checkAnswer(filtered, spark.sparkContext.parallelize(Seq(ab)).toDF)
}

test("Nested value class filter") {
val a = ContainerStringWrapper(StringWrapper("a"))
val b = ContainerStringWrapper(StringWrapper("b"))

val df = spark.sparkContext.parallelize(Seq(a, b)).toDF
// flat value class, `s` field is not in schema
val filtered = df.where("wrapper = \"a\"")
checkAnswer(filtered, spark.sparkContext.parallelize(Seq(a)).toDF)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before this change we never support nested value class:

  • Filter with wrapper would break with
org.apache.spark.sql.AnalysisException: cannot resolve '(`wrapper` = 'a')' due to data type mismatch: differing types in '(`wrapper` = 'a')' (struct<s:string> and string).; line 1 pos 0;
  • Filter with wrapper.s would break with:
java.lang.ClassCastException: java.lang.String cannot be cast to org.apache.spark.sql.test.SQLTestData$StringWrapper

}

private lazy val person2: DataFrame = Seq(
("Bob", 16, 176),
("Alice", 32, 164),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,4 +344,7 @@ private[sql] object SQLTestData {
case class CourseSales(course: String, year: Int, earnings: Double)
case class TrainingSales(training: String, sales: CourseSales)
case class IntervalData(data: CalendarInterval)
case class StringWrapper(s: String) extends AnyVal
case class ArrayStringWrapper(wrappers: Seq[StringWrapper])
case class ContainerStringWrapper(wrapper: StringWrapper)
}