From c55a634e0ae88b2be9a043bc6dc43057ddc8de73 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 7 Mar 2018 06:28:56 +0000
Subject: [PATCH 1/6] Add interpreted execution for MapObjects expression.

---
 .../expressions/objects/objects.scala         | 89 ++++++++++++++++---
 .../expressions/ObjectExpressionsSuite.scala  | 50 +++++++++++
 2 files changed, 128 insertions(+), 11 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 721d58970913..2d9a8ffa72dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects
 
 import java.lang.reflect.Modifier
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable.Builder
 import scala.language.existentials
 import scala.reflect.ClassTag
@@ -446,12 +447,22 @@ case class LambdaVariable(
     value: String,
     isNull: String,
     dataType: DataType,
-    nullable: Boolean = true) extends LeafExpression
-  with Unevaluable with NonSQLExpression {
+    nullable: Boolean = true) extends LeafExpression with NonSQLExpression {
+
+  // Interpreted execution of `LambdaVariable` always gets the 0-index element from the input row.
+  override def eval(input: InternalRow): Any = {
+    assert(input.numFields == 1,
+      "The input row of interpreted LambdaVariable should have only 1 field.")
+    input.get(0, dataType)
+  }
 
   override def genCode(ctx: CodegenContext): ExprCode = {
     ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false")
   }
+
+  // This won't be called as `genCode` is overridden; we override it just to make
+  // `LambdaVariable` non-abstract.
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev
 }
 
 /**
@@ -544,8 +555,71 @@ case class MapObjects private(
 
   override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil
 
-  override def eval(input: InternalRow): Any =
-    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+  // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
+  // When we want to apply MapObjects on it, we have to use it.
+  private val inputDataType = inputData.dataType match {
+    case p: PythonUserDefinedType => p.sqlType
+    case _ => inputData.dataType
+  }
+
+  private def executeFuncOnCollection(inputCollection: Seq[_]): Seq[_] = {
+    inputCollection.map { element =>
+      val row = InternalRow.fromSeq(Seq(element))
+      lambdaFunction.eval(row)
+    }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val inputCollection = inputData.eval(input)
+
+    if (inputCollection == null) {
+      return inputCollection
+    }
+
+    val results = inputDataType match {
+      case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
+        executeFuncOnCollection(inputCollection.asInstanceOf[Seq[_]])
+      case ObjectType(cls) if cls.isArray =>
+        executeFuncOnCollection(inputCollection.asInstanceOf[Array[_]].toSeq)
+      case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
+        executeFuncOnCollection(inputCollection.asInstanceOf[java.util.List[_]].asScala)
+      case ObjectType(cls) if cls == classOf[Object] =>
+        if (inputCollection.getClass.isArray) {
+          executeFuncOnCollection(inputCollection.asInstanceOf[Array[_]].toSeq)
+        } else {
+          executeFuncOnCollection(inputCollection.asInstanceOf[Seq[_]])
+        }
+      case ArrayType(et, _) =>
+        executeFuncOnCollection(inputCollection.asInstanceOf[ArrayData].array)
+    }
+
+    customCollectionCls match {
+      case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
+        // Scala sequence
+        results.toSeq
+      case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
+        // Scala set
+        results.toSet
+      case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
+        // Java list
+        if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] ||
+            cls == classOf[java.util.AbstractSequentialList[_]]) {
+          results.asJava
+        } else {
+          val builder = Try(cls.getConstructor(Integer.TYPE)).map { constructor =>
+            constructor.newInstance()
+          }.getOrElse {
+            cls.getConstructor().newInstance(results.length.asInstanceOf[Object])
+          }.asInstanceOf[java.util.List[Any]]
+
+          results.foreach(builder.add(_))
+          builder
+        }
+      case None =>
+        // array
+        new GenericArrayData(results.toArray)
+    }
+  }
 
   override def dataType: DataType =
     customCollectionCls.map(ObjectType.apply).getOrElse(
@@ -592,13 +666,6 @@ case class MapObjects private(
       case _ => ""
     }
 
-    // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
-    // When we want to apply MapObjects on it, we have to use it.
-    val inputDataType = inputData.dataType match {
-      case p: PythonUserDefinedType => p.sqlType
-      case _ => inputData.dataType
-    }
-
     // `MapObjects` generates a while loop to traverse the elements of the input collection. We
    // need to take care of Seq and List because they may have O(n) complexity for indexed accessing
     // like `list.get(1)`. Here we use Iterator to traverse Seq and List.
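The interpreted path added above evaluates the lambda once per element by wrapping each element in
a one-field row; the interpreted `LambdaVariable.eval` then reads it back as field 0. A minimal
sketch of that contract, using `BoundReference` as a stand-in for the resolved loop variable
(Spark catalyst classes assumed on the classpath):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, Literal}
    import org.apache.spark.sql.types.IntegerType

    // The loop variable reads ordinal 0 of the per-element row, which is exactly
    // the row that `executeFuncOnCollection` builds via InternalRow.fromSeq(Seq(element)).
    val loopVar = BoundReference(0, IntegerType, nullable = false)
    val body = Add(loopVar, Literal(1))
    assert(body.eval(InternalRow.fromSeq(Seq(41))) == 42)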
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 50e57737a461..9c6eb7be0098 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
@@ -110,4 +112,52 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     }.getMessage
     assert(errMsg2 === "The 0th field 'c0' of input row cannot be null.")
   }
+
+  test("SPARK-23587: MapObjects should support interpreted execution") {
+    val customCollectionClasses = Seq(classOf[Seq[_]], classOf[scala.collection.Set[_]],
+      classOf[java.util.List[_]], classOf[java.util.AbstractList[_]],
+      classOf[java.util.AbstractSequentialList[_]], null)
+    val function = (lambda: Expression) => Add(lambda, Literal(1))
+    val elementType = IntegerType
+    val expected = Seq(2, 3, 4)
+
+    val list = new java.util.ArrayList[Int]()
+    list.add(1)
+    list.add(2)
+    list.add(3)
+    val arrayData = new GenericArrayData(Array(1, 2, 3))
+    val vector = new java.util.Vector[Int]()
+    vector.add(1)
+    vector.add(2)
+    vector.add(3)
+
+    Seq(
+      (Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])),
+      (list, ObjectType(classOf[java.util.List[Int]])),
+      (vector, ObjectType(classOf[java.util.Vector[Int]])),
+      (arrayData, ArrayType(IntegerType))
+    ).foreach { case (collection, inputType) =>
+      val inputObject = BoundReference(0, inputType, nullable = true)
+
+      customCollectionClasses.foreach { customCollectionCls =>
+        val optClass = Option(customCollectionCls)
+        val mapObj = MapObjects(function, inputObject, elementType, true, optClass)
+        val row = InternalRow.fromSeq(Seq(collection))
+        val result = mapObj.eval(row)
+
+        customCollectionCls match {
+          case null =>
+          case l if l.isAssignableFrom(classOf[java.util.AbstractList[_]]) =>
+            assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected.toSeq)
+          case l if l.isAssignableFrom(classOf[java.util.AbstractSequentialList[_]]) =>
+            assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected.toSeq)
+          case s if s.isAssignableFrom(classOf[Seq[_]]) =>
+            assert(result.asInstanceOf[Seq[_]].toSeq == expected.toSeq)
+          case s if s.isAssignableFrom(classOf[scala.collection.Set[_]]) =>
+            assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet)
+        }
+        optClass.foreach(cls => assert(cls.isAssignableFrom(result.getClass)))
+      }
+    }
+  }
 }

From 3627dc3a71878c743435969ee43c5cf1c33dbbb8 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 8 Mar 2018 08:37:15 +0000
Subject: [PATCH 2/6] Fix bug.
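Patch 1 inverted the two reflection branches: when the `(int)` sizing constructor was found it was
invoked with no arguments, and the no-arg fallback was handed a size argument. A standalone sketch
of the corrected lookup, using plain JDK reflection (the helper name is illustrative):

    import scala.util.Try

    // Prefer a sizing constructor such as java.util.Vector(int); fall back to
    // the no-arg constructor for classes like java.util.Stack that lack one.
    def newJavaList(cls: Class[_], sizeHint: Int): java.util.List[Any] =
      Try(cls.getConstructor(Integer.TYPE)).map { ctor =>
        ctor.newInstance(sizeHint.asInstanceOf[Object])
      }.getOrElse {
        cls.getConstructor().newInstance()
      }.asInstanceOf[java.util.List[Any]]

    newJavaList(classOf[java.util.Vector[_]], 3)  // uses Vector(int capacity)
    newJavaList(classOf[java.util.Stack[_]], 3)   // uses Stack()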
---
 .../sql/catalyst/expressions/objects/objects.scala    |  4 ++--
 .../catalyst/expressions/ObjectExpressionsSuite.scala | 11 ++++++++---
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 5ae12c2ca812..93f3623b08d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -662,9 +662,9 @@ case class MapObjects private(
           results.asJava
         } else {
           val builder = Try(cls.getConstructor(Integer.TYPE)).map { constructor =>
-            constructor.newInstance()
+            constructor.newInstance(results.length.asInstanceOf[Object])
           }.getOrElse {
-            cls.getConstructor().newInstance(results.length.asInstanceOf[Object])
+            cls.getConstructor().newInstance()
           }.asInstanceOf[java.util.List[Any]]
 
           results.foreach(builder.add(_))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 18c031bf094d..c4a8eeda3093 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -127,9 +127,10 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("SPARK-23587: MapObjects should support interpreted execution") {
-    val customCollectionClasses = Seq(classOf[Seq[_]], classOf[scala.collection.Set[_]],
-      classOf[java.util.List[_]], classOf[java.util.AbstractList[_]],
-      classOf[java.util.AbstractSequentialList[_]], null)
+    val customCollectionClasses = Seq(classOf[Seq[Int]], classOf[scala.collection.Set[Int]],
+      classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]],
+      classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]],
+      classOf[java.util.Stack[Int]], null)
     val function = (lambda: Expression) => Add(lambda, Literal(1))
     val elementType = IntegerType
     val expected = Seq(2, 3, 4)
@@ -143,6 +144,10 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     vector.add(1)
     vector.add(2)
     vector.add(3)
+    val stack = new java.util.Stack[Int]()
+    stack.add(1)
+    stack.add(2)
+    stack.add(3)
 
     Seq(
       (Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])),

From 07f8143c72ab8c0b2f6ae15016c263f64ef18f36 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 8 Mar 2018 12:23:17 +0000
Subject: [PATCH 3/6] Use lazy val to avoid calling dataType on UnresolvedAttribute.

---
 .../apache/spark/sql/catalyst/expressions/objects/objects.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 93f3623b08d1..bb5727327dae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -612,7 +612,7 @@ case class MapObjects private(
 
   // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
   // When we want to apply MapObjects on it, we have to use it.
-  private val inputDataType = inputData.dataType match {
+  lazy private val inputDataType = inputData.dataType match {
     case p: PythonUserDefinedType => p.sqlType
     case _ => inputData.dataType
   }

From e725608d1b38a7a2b1a0677afca947cec6a12801 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sat, 10 Mar 2018 08:55:16 +0000
Subject: [PATCH 4/6] Address comments.

---
 .../expressions/objects/objects.scala         | 25 ++++++++++++-------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index ef99b51ed79e..c4b593651f9e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -633,12 +633,10 @@ case class MapObjects private(
     case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
       x => executeFuncOnCollection(x.asInstanceOf[java.util.List[_]].asScala)
     case ObjectType(cls) if cls == classOf[Object] =>
-      (inputCollection) => {
-        if (inputCollection.getClass.isArray) {
-          executeFuncOnCollection(inputCollection.asInstanceOf[Array[_]].toSeq)
-        } else {
-          executeFuncOnCollection(inputCollection.asInstanceOf[Seq[_]])
-        }
+      if (cls.isArray) {
+        x => executeFuncOnCollection(x.asInstanceOf[Array[_]].toSeq)
+      } else {
+        x => executeFuncOnCollection(x.asInstanceOf[Seq[_]])
       }
     case ArrayType(et, _) =>
       x => executeFuncOnCollection(x.asInstanceOf[ArrayData].array)
@@ -648,7 +646,7 @@ case class MapObjects private(
   private lazy val getResults: Seq[_] => Any = customCollectionCls match {
     case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
       // Scala sequence
-      _.toSeq
+      identity _
     case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
       // Scala set
       _.toSet
@@ -656,13 +654,22 @@ case class MapObjects private(
       // Java list
       if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] ||
           cls == classOf[java.util.AbstractSequentialList[_]]) {
+        // Specifying non concrete implementations of `java.util.List`
         _.asJava
       } else {
+        // Specifying concrete implementations of `java.util.List`
         (results) => {
-          val builder = Try(cls.getConstructor(Integer.TYPE)).map { constructor =>
+          val constructors = cls.getConstructors()
+          val intParamConstructor = constructors.find { constructor =>
+            constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int]
+          }
+          val noParamConstructor = constructors.find { constructor =>
+            constructor.getParameterCount == 0
+          }
+          val builder = intParamConstructor.map { constructor =>
             constructor.newInstance(results.length.asInstanceOf[Object])
           }.getOrElse {
-            cls.getConstructor().newInstance()
+            noParamConstructor.get.newInstance()
          }.asInstanceOf[java.util.List[Any]]
 
           results.foreach(builder.add(_))

From f0ba6147162510ffae6f98e1b3877ff948d9b1eb Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 29 Mar 2018 07:06:06 +0000
Subject: [PATCH 5/6] Address comments.
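Two per-row costs move out of the hot path in this revision: the per-element
`InternalRow.fromSeq` allocation is replaced by a single reused `GenericInternalRow(1)`, and the
`java.util.List` constructor lookup is resolved once when the lazy `mapElements` is first forced
rather than on every call. A Spark-free sketch of the row-reuse pattern (names are illustrative,
not Catalyst's):

    // One mutable single-slot "row" shared across all elements.
    final class Slot(var value: Any)

    def mapWithReusedSlot(input: Seq[Any], f: Slot => Any): Iterator[Any] = {
      val slot = new Slot(null)
      input.iterator.map { element =>
        slot.value = element  // overwrite in place instead of allocating per element
        f(slot)
      }
    }

    // mapWithReusedSlot(Seq(1, 2, 3), s => s.value.asInstanceOf[Int] + 1).toSeq
    // returns Seq(2, 3, 4)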
---
 .../expressions/objects/objects.scala         | 76 ++++++++++---------
 .../expressions/ObjectExpressionsSuite.scala  | 50 +++++++-----
 2 files changed, 70 insertions(+), 56 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index c4b593651f9e..baba4e77bd7d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -617,68 +617,75 @@ case class MapObjects private(
     case _ => inputData.dataType
   }
 
-  private def executeFuncOnCollection(inputCollection: Seq[_]): Seq[_] = {
-    inputCollection.map { element =>
-      val row = InternalRow.fromSeq(Seq(element))
+  private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = {
+    val row = new GenericInternalRow(1)
+    inputCollection.toIterator.map { element =>
+      row.update(0, element)
       lambdaFunction.eval(row)
     }
   }
 
-  // Executes lambda function on input collection.
-  private lazy val executeFunc: Any => Seq[_] = inputDataType match {
+  private lazy val convertToSeq: Any => Seq[_] = inputDataType match {
     case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
-      x => executeFuncOnCollection(x.asInstanceOf[Seq[_]])
+      _.asInstanceOf[Seq[_]]
     case ObjectType(cls) if cls.isArray =>
-      x => executeFuncOnCollection(x.asInstanceOf[Array[_]].toSeq)
+      _.asInstanceOf[Array[_]].toSeq
     case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
-      x => executeFuncOnCollection(x.asInstanceOf[java.util.List[_]].asScala)
+      _.asInstanceOf[java.util.List[_]].asScala
     case ObjectType(cls) if cls == classOf[Object] =>
-      if (cls.isArray) {
-        x => executeFuncOnCollection(x.asInstanceOf[Array[_]].toSeq)
-      } else {
-        x => executeFuncOnCollection(x.asInstanceOf[Seq[_]])
+      (inputCollection) => {
+        if (inputCollection.getClass.isArray) {
+          inputCollection.asInstanceOf[Array[_]].toSeq
+        } else {
+          inputCollection.asInstanceOf[Seq[_]]
+        }
       }
     case ArrayType(et, _) =>
-      x => executeFuncOnCollection(x.asInstanceOf[ArrayData].array)
+      _.asInstanceOf[ArrayData].array
   }
 
-  // Converts the processed collection to custom collection class if any.
-  private lazy val getResults: Seq[_] => Any = customCollectionCls match {
+  private lazy val mapElements: Seq[_] => Any = customCollectionCls match {
     case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
       // Scala sequence
-      identity _
+      executeFuncOnCollection(_).toSeq
     case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
       // Scala set
-      _.toSet
+      executeFuncOnCollection(_).toSet
     case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
       // Java list
       if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] ||
           cls == classOf[java.util.AbstractSequentialList[_]]) {
         // Specifying non concrete implementations of `java.util.List`
-        _.asJava
+        executeFuncOnCollection(_).toSeq.asJava
       } else {
-        // Specifying concrete implementations of `java.util.List`
-        (results) => {
-          val constructors = cls.getConstructors()
-          val intParamConstructor = constructors.find { constructor =>
-            constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int]
-          }
-          val noParamConstructor = constructors.find { constructor =>
-            constructor.getParameterCount == 0
-          }
-          val builder = intParamConstructor.map { constructor =>
-            constructor.newInstance(results.length.asInstanceOf[Object])
-          }.getOrElse {
-            noParamConstructor.get.newInstance()
-          }.asInstanceOf[java.util.List[Any]]
+        val constructors = cls.getConstructors()
+        val intParamConstructor = constructors.find { constructor =>
+          constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int]
+        }
+        val noParamConstructor = constructors.find { constructor =>
+          constructor.getParameterCount == 0
+        }
+
+        val constructor = intParamConstructor.map { intConstructor =>
+          (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object])
+        }.getOrElse {
+          (_: Int) => noParamConstructor.get.newInstance()
+        }
 
+        // Specifying concrete implementations of `java.util.List`
+        (inputs) => {
+          val results = executeFuncOnCollection(inputs)
+          val builder = constructor(inputs.length).asInstanceOf[java.util.List[Any]]
           results.foreach(builder.add(_))
           builder
         }
       }
     case None =>
       // array
-      x => new GenericArrayData(x.toArray)
+      x => new GenericArrayData(executeFuncOnCollection(x).toArray)
+    case Some(cls) =>
+      throw new RuntimeException(s"class `$cls` is not supported by `MapObjects` as " +
+        "resulting collection.")
   }
 
   override def eval(input: InternalRow): Any = {
@@ -687,8 +694,7 @@ case class MapObjects private(
     if (inputCollection == null) {
       return null
     }
-
-    getResults(executeFunc(inputCollection))
+    mapElements(convertToSeq(inputCollection))
   }
 
   override def dataType: DataType =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 86c725eace92..13468d69842c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -128,13 +128,33 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("SPARK-23587: MapObjects should support interpreted execution") {
+    def testMapObjects(collection: Any, collectionCls: Class[_], inputType: DataType): Unit = {
+      val function = (lambda: Expression) => Add(lambda, Literal(1))
+      val elementType = IntegerType
+      val expected = Seq(2, 3, 4)
+
+      val inputObject = BoundReference(0, inputType, nullable = true)
+      val optClass = Option(collectionCls)
+      val mapObj = MapObjects(function, inputObject, elementType, true, optClass)
+      val row = InternalRow.fromSeq(Seq(collection))
+      val result = mapObj.eval(row)
+
+      collectionCls match {
+        case null =>
+          assert(result.asInstanceOf[ArrayData].array.toSeq == expected)
+        case l if classOf[java.util.List[_]].isAssignableFrom(l) =>
+          assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected)
+        case s if classOf[Seq[_]].isAssignableFrom(s) =>
+          assert(result.asInstanceOf[Seq[_]].toSeq == expected)
+        case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) =>
+          assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet)
+      }
+    }
+
     val customCollectionClasses = Seq(classOf[Seq[Int]], classOf[scala.collection.Set[Int]],
       classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]],
       classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]],
       classOf[java.util.Stack[Int]], null)
-    val function = (lambda: Expression) => Add(lambda, Literal(1))
-    val elementType = IntegerType
-    val expected = Seq(2, 3, 4)
 
     val list = new java.util.ArrayList[Int]()
     list.add(1)
     list.add(2)
     list.add(3)
@@ -154,27 +174,15 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       (Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])),
       (list, ObjectType(classOf[java.util.List[Int]])),
       (vector, ObjectType(classOf[java.util.Vector[Int]])),
+      (stack, ObjectType(classOf[java.util.Stack[Int]])),
       (arrayData, ArrayType(IntegerType))
     ).foreach { case (collection, inputType) =>
-      val inputObject = BoundReference(0, inputType, nullable = true)
+      customCollectionClasses.foreach(testMapObjects(collection, _, inputType))
 
-      customCollectionClasses.foreach { customCollectionCls =>
-        val optClass = Option(customCollectionCls)
-        val mapObj = MapObjects(function, inputObject, elementType, true, optClass)
-        val row = InternalRow.fromSeq(Seq(collection))
-        val result = mapObj.eval(row)
-
-        customCollectionCls match {
-          case null =>
-            assert(result.asInstanceOf[ArrayData].array.toSeq == expected)
-          case l if classOf[java.util.List[_]].isAssignableFrom(l) =>
-            assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected)
-          case s if classOf[Seq[_]].isAssignableFrom(s) =>
-            assert(result.asInstanceOf[Seq[_]].toSeq == expected)
-          case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) =>
-            assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet)
-        }
-      }
+      // Unsupported custom collection class
+      assert(intercept[RuntimeException] {
+        testMapObjects(collection, classOf[scala.collection.Map[Int, Int]], inputType)
+      }.getMessage().contains("not supported by `MapObjects` as resulting collection."))
     }
   }

From d4f0ecb6d62bdc31a7bab3d842c5aebb62903a88 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 29 Mar 2018 13:23:46 +0000
Subject: [PATCH 6/6] Improve test case.
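The error message now interpolates `cls.getName` because interpolating the `Class` object itself
prints its kind as a prefix, which made the asserted substring awkward:

    val cls = classOf[scala.collection.Map[Int, Int]]
    println(s"$cls")           // interface scala.collection.Map
    println(s"${cls.getName}") // scala.collection.Map

The test also gains Array and ObjectType(classOf[Object]) inputs and asserts the exact message.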
---
 .../spark/sql/catalyst/expressions/objects/objects.scala | 2 +-
 .../catalyst/expressions/ObjectExpressionsSuite.scala    | 9 +++++++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index baba4e77bd7d..0e9d357c19c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -684,7 +684,7 @@ case class MapObjects private(
       // array
       x => new GenericArrayData(executeFuncOnCollection(x).toArray)
     case Some(cls) =>
-      throw new RuntimeException(s"class `$cls` is not supported by `MapObjects` as " +
+      throw new RuntimeException(s"class `${cls.getName}` is not supported by `MapObjects` as " +
         "resulting collection.")
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 13468d69842c..c7af8a26464f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -172,6 +172,9 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
     Seq(
       (Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])),
+      (Array(1, 2, 3), ObjectType(classOf[Array[Int]])),
+      (Seq(1, 2, 3), ObjectType(classOf[Object])),
+      (Array(1, 2, 3), ObjectType(classOf[Object])),
       (list, ObjectType(classOf[java.util.List[Int]])),
       (vector, ObjectType(classOf[java.util.Vector[Int]])),
       (stack, ObjectType(classOf[java.util.Stack[Int]])),
@@ -180,9 +183,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       customCollectionClasses.foreach(testMapObjects(collection, _, inputType))
 
       // Unsupported custom collection class
-      assert(intercept[RuntimeException] {
+      val errMsg = intercept[RuntimeException] {
         testMapObjects(collection, classOf[scala.collection.Map[Int, Int]], inputType)
-      }.getMessage().contains("not supported by `MapObjects` as resulting collection."))
+      }.getMessage()
+      assert(errMsg.contains("`scala.collection.Map` is not supported by `MapObjects` " +
+        "as resulting collection."))
     }
   }
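Taken together, the series lets `MapObjects` round-trip in interpreted mode without codegen. A
minimal usage sketch modeled on the final test (catalyst internals assumed on the classpath):

    import scala.collection.JavaConverters._
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, Expression, Literal}
    import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
    import org.apache.spark.sql.types.{IntegerType, ObjectType}

    // Map `x => x + 1` over a Scala Seq, collecting into a java.util.List.
    val function = (lambda: Expression) => Add(lambda, Literal(1))
    val inputObject = BoundReference(0, ObjectType(classOf[Seq[Int]]), nullable = true)
    val mapObj = MapObjects(
      function, inputObject, IntegerType, true, Some(classOf[java.util.List[Int]]))

    val result = mapObj.eval(InternalRow.fromSeq(Seq(Seq(1, 2, 3))))
    assert(result.asInstanceOf[java.util.List[_]].asScala == Seq(2, 3, 4))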