sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects

import java.lang.reflect.Modifier

import scala.collection.JavaConverters._
import scala.collection.mutable.Builder
import scala.language.existentials
import scala.reflect.ClassTag
@@ -501,12 +502,22 @@ case class LambdaVariable(
value: String,
isNull: String,
dataType: DataType,
nullable: Boolean = true) extends LeafExpression
with Unevaluable with NonSQLExpression {
nullable: Boolean = true) extends LeafExpression with NonSQLExpression {

// Interpreted execution of `LambdaVariable` always gets the 0-index element from the input row.
override def eval(input: InternalRow): Any = {
assert(input.numFields == 1,
"The input row of interpreted LambdaVariable should have only 1 field.")
input.get(0, dataType)
Contributor:
Not a change for this PR, but maybe we should use accessors here? This does a pattern match under the hood, which is slower than virtual function dispatch. Implementing this would also be useful for `BoundReference`, for example.

Member Author (@viirya, Mar 9, 2018):
You mean something like this?

lazy val accessor: InternalRow => Any = dataType match {
  case IntegerType => (inputRow) => inputRow.getInt(0)
  case LongType => (inputRow) => inputRow.getLong(0)
  // ... other types
}

override def eval(input: InternalRow): Any = accessor(input)

Contributor:
Yeah I do.

Contributor:
Let's spin that off into a different ticket if we want to work on it.

Member Author:
Ok. After this is merged, I will create another PR for it.
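
(For reference, a rough sketch of the accessor idea applied to `BoundReference`; the names and placement are hypothetical, and the actual follow-up may differ.)

// Hypothetical: build the field accessor once instead of matching on dataType for every row.
private lazy val accessor: (InternalRow, Int) => Any = dataType match {
  case IntegerType => (row, i) => row.getInt(i)
  case LongType => (row, i) => row.getLong(i)
  // ... other types
  case _ => (row, i) => row.get(i, dataType)
}

override def eval(input: InternalRow): Any = {
  if (input.isNullAt(ordinal)) null else accessor(input, ordinal)
}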

}

override def genCode(ctx: CodegenContext): ExprCode = {
ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false")
}

// This won't be called since `genCode` is overridden; we override it only to make
// `LambdaVariable` non-abstract.
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev
}

/**
@@ -599,8 +610,92 @@ case class MapObjects private(

override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
// Data with a UserDefinedType is actually stored with the data type of its sqlType.
// When we want to apply MapObjects on it, we have to use the sqlType instead.
private lazy val inputDataType = inputData.dataType match {
case u: UserDefinedType[_] => u.sqlType
case _ => inputData.dataType
}
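
(Aside, not part of the patch: a minimal hypothetical UDT showing why the sqlType substitution above is needed. `Point` and `PointUDT` are assumed names for illustration; the user-facing class is Point, but the data is physically stored as ArrayType(DoubleType).)

class Point(val x: Double, val y: Double)

class PointUDT extends UserDefinedType[Point] {
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
  override def serialize(p: Point): Any = new GenericArrayData(Array(p.x, p.y))
  override def deserialize(datum: Any): Point = {
    val a = datum.asInstanceOf[ArrayData]
    new Point(a.getDouble(0), a.getDouble(1))
  }
  override def userClass: Class[Point] = classOf[Point]
}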

private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = {
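// Wrap each element in a single-field row before calling the lambda: the
// interpreted `LambdaVariable` reads its value from index 0 of its input row.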
val row = new GenericInternalRow(1)
inputCollection.toIterator.map { element =>
row.update(0, element)
lambdaFunction.eval(row)
}
}

private lazy val convertToSeq: Any => Seq[_] = inputDataType match {
case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
_.asInstanceOf[Seq[_]]
case ObjectType(cls) if cls.isArray =>
_.asInstanceOf[Array[_]].toSeq
case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
_.asInstanceOf[java.util.List[_]].asScala
case ObjectType(cls) if cls == classOf[Object] =>
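// The static type is Object, so check at runtime whether the input is an array or a Seq.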
(inputCollection) => {
if (inputCollection.getClass.isArray) {
inputCollection.asInstanceOf[Array[_]].toSeq
} else {
inputCollection.asInstanceOf[Seq[_]]
}
}
case ArrayType(et, _) =>
_.asInstanceOf[ArrayData].array
}

private lazy val mapElements: Seq[_] => Any = customCollectionCls match {
case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
// Scala sequence
executeFuncOnCollection(_).toSeq
case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
// Scala set
executeFuncOnCollection(_).toSet
case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
// Java list
if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] ||
cls == classOf[java.util.AbstractSequentialList[_]]) {
// When a non-concrete implementation of `java.util.List` is specified
executeFuncOnCollection(_).toSeq.asJava
} else {
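// Reflectively pick a constructor for the concrete list class: prefer one taking
// an initial capacity (Int), otherwise fall back to the no-arg constructor.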
val constructors = cls.getConstructors()
val intParamConstructor = constructors.find { constructor =>
constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int]
}
val noParamConstructor = constructors.find { constructor =>
constructor.getParameterCount == 0
}

val constructor = intParamConstructor.map { intConstructor =>
(len: Int) => intConstructor.newInstance(len.asInstanceOf[Object])
}.getOrElse {
(_: Int) => noParamConstructor.get.newInstance()
}

// When a concrete implementation of `java.util.List` is specified
(inputs) => {
val results = executeFuncOnCollection(inputs)
val builder = constructor(inputs.length).asInstanceOf[java.util.List[Any]]
results.foreach(builder.add(_))
builder
}
}
case None =>
// No custom collection class was specified: return a GenericArrayData
x => new GenericArrayData(executeFuncOnCollection(x).toArray)
case Some(cls) =>
throw new RuntimeException(s"class `${cls.getName}` is not supported by `MapObjects` as " +
"resulting collection.")
}

override def eval(input: InternalRow): Any = {
val inputCollection = inputData.eval(input)

if (inputCollection == null) {
return null
}
mapElements(convertToSeq(inputCollection))
}

override def dataType: DataType =
customCollectionCls.map(ObjectType.apply).getOrElse(
@@ -647,13 +742,6 @@ case class MapObjects private(
case _ => ""
}

// The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
// When we want to apply MapObjects on it, we have to use it.
val inputDataType = inputData.dataType match {
case p: PythonUserDefinedType => p.sqlType
case _ => inputData.dataType
}

// `MapObjects` generates a while loop to traverse the elements of the input collection. We
// need to take care of Seq and List because they may have O(n) complexity for indexed access
// like `list.get(1)`. Here we use Iterator to traverse Seq and List.
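
(To illustrate the complexity concern, a minimal standalone sketch, not part of the patch:)

val xs: List[Int] = List.fill(100000)(1)
val n = xs.length

// Indexed access walks the list from the head on every call, so this loop is O(n^2) overall.
var i = 0
var sum1 = 0
while (i < n) { sum1 += xs(i); i += 1 }

// Iterator access is O(1) per element, so this traversal is O(n) overall.
val it = xs.iterator
var sum2 = 0
while (it.hasNext) { sum2 += it.next() }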

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala

@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import org.apache.spark.{SparkConf, SparkFunSuite}
@@ -25,7 +26,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
import org.apache.spark.sql.types._


@@ -126,6 +127,70 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("SPARK-23587: MapObjects should support interpreted execution") {
def testMapObjects(collection: Any, collectionCls: Class[_], inputType: DataType): Unit = {
val function = (lambda: Expression) => Add(lambda, Literal(1))
val elementType = IntegerType
val expected = Seq(2, 3, 4)

val inputObject = BoundReference(0, inputType, nullable = true)
val optClass = Option(collectionCls)
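// The `true` argument is assumed to be `elementNullable` in MapObjects.apply.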
val mapObj = MapObjects(function, inputObject, elementType, true, optClass)
val row = InternalRow.fromSeq(Seq(collection))
val result = mapObj.eval(row)

collectionCls match {
case null =>
assert(result.asInstanceOf[ArrayData].array.toSeq == expected)
case l if classOf[java.util.List[_]].isAssignableFrom(l) =>
assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected)
case s if classOf[Seq[_]].isAssignableFrom(s) =>
assert(result.asInstanceOf[Seq[_]].toSeq == expected)
case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) =>
assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet)
}
}

val customCollectionClasses = Seq(classOf[Seq[Int]], classOf[scala.collection.Set[Int]],
classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]],
classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]],
classOf[java.util.Stack[Int]], null)

val list = new java.util.ArrayList[Int]()
list.add(1)
list.add(2)
list.add(3)
val arrayData = new GenericArrayData(Array(1, 2, 3))
val vector = new java.util.Vector[Int]()
vector.add(1)
vector.add(2)
vector.add(3)
val stack = new java.util.Stack[Int]()
stack.add(1)
stack.add(2)
stack.add(3)

Seq(
(Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])),
(Array(1, 2, 3), ObjectType(classOf[Array[Int]])),
(Seq(1, 2, 3), ObjectType(classOf[Object])),
(Array(1, 2, 3), ObjectType(classOf[Object])),
(list, ObjectType(classOf[java.util.List[Int]])),
(vector, ObjectType(classOf[java.util.Vector[Int]])),
(stack, ObjectType(classOf[java.util.Stack[Int]])),
(arrayData, ArrayType(IntegerType))
).foreach { case (collection, inputType) =>
customCollectionClasses.foreach(testMapObjects(collection, _, inputType))

// Unsupported custom collection class
val errMsg = intercept[RuntimeException] {
testMapObjects(collection, classOf[scala.collection.Map[Int, Int]], inputType)
}.getMessage()
assert(errMsg.contains("`scala.collection.Map` is not supported by `MapObjects` " +
"as resulting collection."))
}
}

test("SPARK-23592: DecodeUsingSerializer should support interpreted execution") {
val cls = classOf[java.lang.Integer]
val inputObject = BoundReference(0, ObjectType(classOf[Array[Byte]]), nullable = true)