From 089ab18dd03dde6a4fd6bd787b1bd3674f394ca7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 19 Dec 2015 16:27:35 +0800 Subject: [PATCH 1/7] Add SQLUserDefinedType support for encoder. --- .../spark/sql/catalyst/ScalaReflection.scala | 24 +++++++++++++++++++ .../spark/sql/UserDefinedTypeSuite.scala | 21 +++++++++++++++- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index c1b1d5cd2dee..1cef0db04e8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -177,6 +177,7 @@ object ScalaReflection extends ScalaReflection { case _ => UpCast(expr, expected, walkedTypePath) } + val className = getClassNameFromType(tpe) tpe match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath @@ -372,6 +373,17 @@ object ScalaReflection extends ScalaReflection { } else { newInstance } + + case t if Utils.classIsLoadable(className) && + Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => + val udt = Utils.classForName(className) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + false, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) } } @@ -421,6 +433,7 @@ object ScalaReflection extends ScalaReflection { if (!inputObject.dataType.isInstanceOf[ObjectType]) { inputObject } else { + val className = getClassNameFromType(tpe) tpe match { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t @@ -589,6 +602,17 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) + case t if Utils.classIsLoadable(className) && + Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => + val udt = Utils.classForName(className) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + false, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) + case other => throw new UnsupportedOperationException( s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index f602f2fb89ca..88b1477d08f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -20,9 +20,14 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} import scala.beans.{BeanInfo, BeanProperty} +import scala.reflect.runtime.universe.TypeTag import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection} +import org.apache.spark.sql.catalyst.encoders._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import 
org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -89,6 +94,20 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) } + test("user type with ScalaReflection") { + val points = Seq( + MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) + + val schema = ScalaReflection.schemaFor[MyLabeledPoint].dataType.asInstanceOf[StructType] + val attributeSeq = schema.toAttributes + + val pointEncoder = encoderFor[MyLabeledPoint] + val unsafeRows = points.map(pointEncoder.toRow(_).copy()) + val df = DataFrame(sqlContext, LocalRelation(attributeSeq, unsafeRows)) + val decodedPoints = df.collect() + } + test("UDTs and UDFs") { sqlContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") From 51b76b163a95de571d01473a24fda7773db57498 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 19 Dec 2015 16:39:19 +0800 Subject: [PATCH 2/7] Add assert. --- .../test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 88b1477d08f4..d2982c2e13a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -106,6 +106,9 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT val unsafeRows = points.map(pointEncoder.toRow(_).copy()) val df = DataFrame(sqlContext, LocalRelation(attributeSeq, unsafeRows)) val decodedPoints = df.collect() + points.zip(decodedPoints).foreach { case (p, p2) => + assert(p.label == p2(0) && p.features == p2(1)) + } } test("UDTs and UDFs") { From d01c661c0cf2157a90b6c732b70309f1aa2f3ccd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 21 Dec 2015 15:39:09 +0800 Subject: [PATCH 3/7] Add UserDefinedType to Cast. 
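
Casting between two user-defined types is permitted only when both UDTs
wrap the same user class; the cast then simply forwards the value, since
the underlying sqlType representation is identical. A rough sketch of the
rule being added (the helper name is illustrative only; the real logic
lives in Cast.canCast and the codegen branch in castToCode below):

    import org.apache.spark.sql.types.{DataType, UserDefinedType}

    def canCastBetweenUDTs(from: DataType, to: DataType): Boolean =
      (from, to) match {
        // Same user class: the internal representations match, so the
        // cast can pass the value through unchanged.
        case (f: UserDefinedType[_], t: UserDefinedType[_]) =>
          f.userClass == t.userClass
        case _ => false
      }
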
--- .../apache/spark/sql/catalyst/expressions/Cast.scala | 6 ++++++ .../org/apache/spark/sql/UserDefinedTypeSuite.scala | 11 +++++++++++ 2 files changed, 17 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index b18f49f3203f..02e5d707cdb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -81,6 +81,9 @@ object Cast { toField.nullable) } + case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt1.userClass == udt2.userClass => + true + case _ => false } @@ -473,6 +476,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx) case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx) case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) + case udt: UserDefinedType[_] + if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => + (c, evPrim, evNull) => s"$evPrim = $c;" } // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index d2982c2e13a5..4947e84eef97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql +import java.util.concurrent.ConcurrentMap + import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} import scala.beans.{BeanInfo, BeanProperty} import scala.reflect.runtime.universe.TypeTag +import com.google.common.collect.MapMaker + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection} import org.apache.spark.sql.catalyst.encoders._ @@ -94,6 +98,9 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) } + private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() + outers.put(getClass.getName, this) + test("user type with ScalaReflection") { val points = Seq( MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), @@ -109,6 +116,10 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT points.zip(decodedPoints).foreach { case (p, p2) => assert(p.label == p2(0) && p.features == p2(1)) } + + val boundEncoder = pointEncoder.resolve(attributeSeq, outers).bind(attributeSeq) + val point = MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))) + assert(boundEncoder.fromRow(boundEncoder.toRow(point)) === point) } test("UDTs and UDFs") { From 42303d25fe6e85d65644e45e7d4b600f0041a1f2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 30 Dec 2015 16:40:20 +0800 Subject: [PATCH 4/7] Complete the update to Cast for UDT. 
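
This mirrors the codegen branch from the previous commit on the
interpreted evaluation path: a UDT-to-UDT cast over the same user class
is the identity function, and any other cast involving a UDT fails fast
rather than silently producing garbage. A rough sketch of the dispatch
(the helper name is illustrative; the real code is the private
Cast.cast factory):

    import org.apache.spark.SparkException
    import org.apache.spark.sql.types.{DataType, UserDefinedType}

    def castFunction(from: DataType, to: DataType): Any => Any =
      (from, to) match {
        // Identical user classes share the same internal representation.
        case (f: UserDefinedType[_], t: UserDefinedType[_])
            if f.userClass == t.userClass =>
          identity[Any]
        // Every other cast targeting a UDT is unsupported.
        case (f, t: UserDefinedType[_]) =>
          throw new SparkException(s"Cannot cast $f to $t.")
        case _ =>
          sys.error("non-UDT cast rules elided from this sketch")
      }
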
--- .../org/apache/spark/sql/catalyst/expressions/Cast.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 02e5d707cdb6..d82d3edae4e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.math.{BigDecimal => JavaBigDecimal} +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -434,6 +435,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType) case map: MapType => castMap(from.asInstanceOf[MapType], map) case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) + case udt: UserDefinedType[_] + if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => + identity[Any] + case _: UserDefinedType[_] => + throw new SparkException(s"Cannot cast $from to $to.") } private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) @@ -479,6 +485,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case udt: UserDefinedType[_] if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => (c, evPrim, evNull) => s"$evPrim = $c;" + case _: UserDefinedType[_] => + throw new SparkException(s"Cannot cast $from to $to.") } // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's From 72446f15ca26f2ea0570d01368756fed759a89fb Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 31 Dec 2015 16:15:23 +0800 Subject: [PATCH 5/7] Move tests. 
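
The round-trip test moves from UserDefinedTypeSuite (sql/core) into
ExpressionEncoderSuite (sql/catalyst), next to the encoder machinery it
exercises. One prerequisite for such tests is the `outers` map: case
classes defined inside a test class capture an `$outer` reference, and
the encoder's deserializer needs a live outer instance to construct
them. A sketch of the registration idiom the suites use, mirroring the
code added to UserDefinedTypeSuite earlier in this series (this lives
inside the suite class body; the weak-valued concurrent map comes from
Guava's MapMaker):

    import java.util.concurrent.ConcurrentMap
    import com.google.common.collect.MapMaker

    // Weak values, so registered outer instances can be collected once
    // the owning test class goes away.
    val outers: ConcurrentMap[String, AnyRef] =
      new MapMaker().weakValues().makeMap()
    outers.put(getClass.getName, this)
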
--- .../encoders/ExpressionEncoderSuite.scala | 12 ++++++++- .../spark/sql/UserDefinedTypeSuite.scala | 27 +------------------ 2 files changed, 12 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 7233e0f1b5ba..2453f1b0637e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} +import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData, ScalaReflection} import org.apache.spark.sql.types.{StructType, ArrayType} case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -239,6 +239,16 @@ class ExpressionEncoderSuite extends SparkFunSuite { ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) } + test("user type with ScalaReflection") { + val point = (new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4)) + val schema = ScalaReflection.schemaFor[Tuple2[ExamplePoint, ExamplePoint]] + .dataType.asInstanceOf[StructType] + val attributeSeq = schema.toAttributes + val boundEncoder = encoderFor[Tuple2[ExamplePoint, ExamplePoint]] + .resolve(attributeSeq, outers).bind(attributeSeq) + assert(boundEncoder.fromRow(boundEncoder.toRow(point)) === point) + } + test("nullable of encoder schema") { def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = { assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 4947e84eef97..caba21a52350 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -27,11 +27,7 @@ import scala.reflect.runtime.universe.TypeTag import com.google.common.collect.MapMaker import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection} -import org.apache.spark.sql.catalyst.encoders._ -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -101,27 +97,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() outers.put(getClass.getName, this) - test("user type with ScalaReflection") { - val points = Seq( - MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) - - val schema = ScalaReflection.schemaFor[MyLabeledPoint].dataType.asInstanceOf[StructType] - val attributeSeq = schema.toAttributes - - val pointEncoder = encoderFor[MyLabeledPoint] - val unsafeRows = 
points.map(pointEncoder.toRow(_).copy()) - val df = DataFrame(sqlContext, LocalRelation(attributeSeq, unsafeRows)) - val decodedPoints = df.collect() - points.zip(decodedPoints).foreach { case (p, p2) => - assert(p.label == p2(0) && p.features == p2(1)) - } - - val boundEncoder = pointEncoder.resolve(attributeSeq, outers).bind(attributeSeq) - val point = MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))) - assert(boundEncoder.fromRow(boundEncoder.toRow(point)) === point) - } - test("UDTs and UDFs") { sqlContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") From 62fa738dca9e7848156772ec99d5f956622321d9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 31 Dec 2015 16:55:00 +0800 Subject: [PATCH 6/7] Update for new NewInstance. --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index f4dd86871d5e..c6aa60b0b4d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -369,7 +369,6 @@ object ScalaReflection extends ScalaReflection { val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, - false, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) } @@ -579,7 +578,6 @@ object ScalaReflection extends ScalaReflection { val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, - false, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) From 80a3f7b56f14eeb1b7c3d84cc2544458d9de13cd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 5 Jan 2016 10:49:43 +0800 Subject: [PATCH 7/7] For comments. 
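
Review feedback folds the hand-written round-trip test into the suite's
existing `productTest` helper, which derives an encoder for a product
type and asserts that encoding followed by decoding is the identity.
Conceptually the helper does something like the following (a sketch of
its contract assembled from the resolve/bind calls used elsewhere in
this series, not the exact implementation in ExpressionEncoderSuite):

    def productTest[T <: Product : ExpressionEncoder](input: T): Unit =
      test(s"encode/decode: $input") {
        val encoder = implicitly[ExpressionEncoder[T]]
        val attrs = encoder.schema.toAttributes
        // resolve() ties the deserializer to the schema's attributes;
        // bind() compiles it against their ordinals.
        val bound = encoder.resolve(attrs, outers).bind(attrs)
        assert(bound.fromRow(bound.toRow(input)) === input)
      }
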
--- .../catalyst/encoders/ExpressionEncoderSuite.scala | 12 ++---------- .../org/apache/spark/sql/UserDefinedTypeSuite.scala | 8 -------- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index eb1442b7551c..f0f52213dded 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData, ScalaReflection} +import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.types.{StructType, ArrayType} case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -242,15 +242,7 @@ class ExpressionEncoderSuite extends SparkFunSuite { ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) } - test("user type with ScalaReflection") { - val point = (new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4)) - val schema = ScalaReflection.schemaFor[Tuple2[ExamplePoint, ExamplePoint]] - .dataType.asInstanceOf[StructType] - val attributeSeq = schema.toAttributes - val boundEncoder = encoderFor[Tuple2[ExamplePoint, ExamplePoint]] - .resolve(attributeSeq, outers).bind(attributeSeq) - assert(boundEncoder.fromRow(boundEncoder.toRow(point)) === point) - } + productTest(("UDT", new ExamplePoint(0.1, 0.2))) test("nullable of encoder schema") { def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index dc4db7c99d9b..2a1117318ad1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,14 +17,9 @@ package org.apache.spark.sql -import java.util.concurrent.ConcurrentMap - import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} import scala.beans.{BeanInfo, BeanProperty} -import scala.reflect.runtime.universe.TypeTag - -import com.google.common.collect.MapMaker import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.CatalystTypeConverters @@ -99,9 +94,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) } - private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() - outers.put(getClass.getName, this) - test("UDTs and UDFs") { sqlContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points")
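
Net effect of the series: classes annotated with @SQLUserDefinedType now
work with the encoder framework end to end. A rough usage sketch
(MyLabeledPoint and MyDenseVector are the UDT fixtures defined in
UserDefinedTypeSuite; the implicit product encoder is assumed to come
from sqlContext.implicits):

    import sqlContext.implicits._

    val points = Seq(
      MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
      MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0))))
    // createDataset exercises the serializer path added to
    // ScalaReflection; collect() exercises the deserializer path.
    val ds = sqlContext.createDataset(points)
    assert(ds.collect().toSeq === points)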