diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index f524de68fbce0..1492b340b9b06 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -39,6 +39,7 @@ private[spark] object PythonEvalType { val SQL_PANDAS_SCALAR_UDF = 200 val SQL_PANDAS_GROUP_MAP_UDF = 201 + val SQL_PANDAS_GROUP_AGGREGATE_UDF = 202 } /** diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 340bc3a6b7470..48a1a1b785c23 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -70,6 +70,7 @@ class PythonEvalType(object): SQL_PANDAS_SCALAR_UDF = 200 SQL_PANDAS_GROUP_MAP_UDF = 201 + SQL_PANDAS_GROUP_AGGREGATE_UDF = 202 def portable_hash(x): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b631e2041706f..8d7eb78b4f2b3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -32,7 +32,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import StringType, DataType -from pyspark.sql.udf import UserDefinedFunction, _create_udf +from pyspark.sql.udf import UserDefinedFunction, UserDefinedAggregateFunction, _create_udf def _create_function(name, doc=""): @@ -2241,6 +2241,30 @@ def pandas_udf(f=None, returnType=None, functionType=None): return _create_udf(f=f, returnType=return_type, evalType=eval_type) +# ---------------------------- User Defined Aggregate Function ---------------------------------- + +def pandas_udaf(f=None, returnType=StringType(), supportsPartial=False): + """ + Creates a :class:`Column` expression representing a vectorized user defined aggregate + function (UDAF). 
+ """ + def _udaf(f, returnType, supportsPartial): + udaf_obj = UserDefinedAggregateFunction(f, returnType, supportsPartial) + return udaf_obj._wrapped() + + # decorator @pandas_udaf, @pandas_udaf() or @pandas_udaf(dataType()) + if f is None or isinstance(f, (str, DataType)): + # If DataType has been passed as a positional argument + # for decorator use it as a returnType + if isinstance(returnType, bool): + supportsPartial = returnType + returnType = StringType() + return_type = f or returnType + return functools.partial(_udaf, returnType=return_type, supportsPartial=supportsPartial) + else: + return _udaf(f=f, returnType=returnType, supportsPartial=supportsPartial) + + blacklist = ['map', 'since', 'ignore_unicode_prefix'] __all__ = [k for k, v in globals().items() if not k.startswith('_') and k[0].islower() and callable(v) and k not in blacklist] diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 762afe0d730f3..e5ca9c99f3f74 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3849,6 +3849,26 @@ def test_unsupported_types(self): df.groupby('id').apply(f).collect() +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +class VectorizedUDAFTests(ReusedSQLTestCase): + + def test_vectorized_udaf_basic(self): + from pyspark.sql.functions import pandas_udaf, col, expr + df = self.spark.range(100).select(col('id').alias('n'), (col('id') % 2 == 0).alias('g')) + + @pandas_udaf(LongType(), supportsPartial=True) + def p_sum(v): + return v.sum() + + @pandas_udaf(DoubleType(), supportsPartial=False) + def p_avg(v): + return v.mean() + + res = df.groupBy(col('g')).agg(p_sum(col('n')), expr('count(n)'), p_avg(col('n'))) + expected = df.groupBy(col('g')).agg(expr('sum(n)'), expr('count(n)'), expr('avg(n)')) + self.assertEquals(expected.collect(), res.collect()) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index c3301a41ccd5a..cff8de9cf84ab 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -159,3 +159,90 @@ def wrapper(*args): wrapper.evalType = self.evalType return wrapper + + +class UserDefinedAggregateFunction(object): + """ + User defined aggregate function in Python + + .. versionadded:: 2.3 + """ + def __init__(self, func, returnType, supportsPartial, name=None): + if not callable(func): + raise TypeError( + "Not a function or callable (__call__ is not defined): " + "{0}".format(type(func))) + + self.func = func + self._returnType = returnType + self.supportsPartial = supportsPartial + # Stores UserDefinedPythonFunctions jobj, once initialized + self._returnType_placeholder = None + self._judaf_placeholder = None + self._name = name or ( + func.__name__ if hasattr(func, '__name__') + else func.__class__.__name__) + + @property + def returnType(self): + # This makes sure this is called after SparkContext is initialized. + # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string. + if self._returnType_placeholder is None: + if isinstance(self._returnType, DataType): + self._returnType_placeholder = self._returnType + else: + self._returnType_placeholder = _parse_datatype_string(self._returnType) + return self._returnType_placeholder + + @property + def _judaf(self): + # It is possible that concurrent access, to newly created UDF, + # will initialize multiple UserDefinedPythonFunctions. 
+ # This is unlikely, doesn't affect correctness, + # and should have a minimal performance impact. + if self._judaf_placeholder is None: + self._judaf_placeholder = self._create_judaf() + return self._judaf_placeholder + + def _create_judaf(self): + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + sc = spark.sparkContext + + wrapped_func = _wrap_function(sc, self.func, self.returnType) + jdt = spark._jsparkSession.parseDataType(self.returnType.json()) + judaf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedAggregatePythonFunction( + self._name, wrapped_func, jdt, self.supportsPartial) + return judaf + + def __call__(self, *cols): + judaf = self._judaf + sc = SparkContext._active_spark_context + return Column(judaf.apply(_to_seq(sc, cols, _to_java_column))) + + def _wrapped(self): + """ + Wrap this udf with a function and attach docstring from func + """ + + # It is possible for a callable instance without __name__ attribute or/and + # __module__ attribute to be wrapped here. For example, functools.partial. In this case, + # we should avoid wrapping the attributes from the wrapped function to the wrapper + # function. So, we take out these attribute names from the default names to set and + # then manually assign it after being wrapped. + assignments = tuple( + a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__') + + @functools.wraps(self.func, assigned=assignments) + def wrapper(*args): + return self(*args) + + wrapper.__name__ = self._name + wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__') + else self.func.__class__.__module__) + wrapper.func = self.func + wrapper.returnType = self.returnType + wrapper.supportsPartial = self.supportsPartial + + return wrapper diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 939643071943a..d97df2a7d3d93 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -33,7 +33,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer -from pyspark.sql.types import to_arrow_type +from pyspark.sql.types import to_arrow_type, StructType from pyspark import shuffle pickleSer = PickleSerializer() @@ -110,6 +110,24 @@ def wrapped(*series): return wrapped +def wrap_pandas_group_aggregate_udf(f, return_type): + import pandas as pd + if isinstance(return_type, StructType): + arrow_return_types = [to_arrow_type(field.dataType) for field in return_type] + else: + arrow_return_types = [to_arrow_type(return_type)] + + def fn(*args): + out = f(*[pd.Series(arg[0]) for arg in args]) + if not isinstance(out, (tuple, list)): + out = (out,) + assert len(out) == len(arrow_return_types), \ + 'Columns of tuple don\'t match return schema' + + return [(pd.Series(v), t) for v, t in zip(out, arrow_return_types)] + return fn + + def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] @@ -126,6 +144,8 @@ def read_single_udf(pickleSer, infile, eval_type): return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type) elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGGREGATE_UDF: + return arg_offsets, wrap_pandas_group_aggregate_udf(row_func, return_type) else: return arg_offsets, 
wrap_udf(row_func, return_type) @@ -143,13 +163,17 @@ def read_udfs(pickleSer, infile, eval_type): # lambda a: (f0(a0), f1(a1, a2), f2(a3)) # In the special case of a single UDF this will return a single result rather # than a tuple of results; this is the format that the JVM side expects. - mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) + if eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGGREGATE_UDF: + mapper_str = "lambda a: sum([%s], [])" % (", ".join(call_udf)) + else: + mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) mapper = eval(mapper_str, udfs) func = lambda _, it: map(mapper, it) if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \ - or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF \ + or eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGGREGATE_UDF: ser = ArrowStreamPandasSerializer() else: ser = BatchedSerializer(PickleSerializer(), 100) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayDataBuffer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayDataBuffer.scala new file mode 100644 index 0000000000000..6e4781637ae71 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayDataBuffer.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.util + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.{DataType, Decimal} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +class ArrayDataBuffer(val buffer: ArrayBuffer[Any]) extends ArrayData { + + def this(initialCapacity: Int) = this(new ArrayBuffer[Any](initialCapacity)) + def this() = this(new ArrayBuffer[Any]()) + + override def copy(): ArrayData = { + val newValues = new ArrayBuffer[Any](buffer.length) + var i = 0 + while (i < buffer.length) { + newValues(i) = InternalRow.copyValue(buffer(i)) + i += 1 + } + new ArrayDataBuffer(newValues) + } + + override def array: Array[Any] = { + val newValues = new Array[Any](buffer.length) + var i = 0 + while (i < buffer.length) { + newValues(i) = InternalRow.copyValue(buffer(i)) + i += 1 + } + newValues + } + + override def numElements(): Int = buffer.length + + private def getAs[T](ordinal: Int) = buffer(ordinal).asInstanceOf[T] + override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null + override def get(ordinal: Int, elementType: DataType): AnyRef = getAs(ordinal) + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + override def getByte(ordinal: Int): Byte = getAs(ordinal) + override def getShort(ordinal: Int): Short = getAs(ordinal) + override def getInt(ordinal: Int): Int = getAs(ordinal) + override def getLong(ordinal: Int): Long = getAs(ordinal) + override def getFloat(ordinal: Int): Float = getAs(ordinal) + override def getDouble(ordinal: Int): Double = getAs(ordinal) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + override def getMap(ordinal: Int): MapData = getAs(ordinal) + + override def setNullAt(ordinal: Int): Unit = buffer(ordinal) = null + + override def update(ordinal: Int, value: Any): Unit = buffer(ordinal) = value + + def +=(value: Any): this.type = { + buffer += value + this + } + + def ++=(values: TraversableOnce[Any]): this.type = { + buffer ++= values + this + } + + def ++=(values: ArrayData): this.type = { + values match { + case buff: ArrayDataBuffer => buffer ++= buff.buffer + case _ => buffer ++= values.array + } + this + } + + def clear(): Unit = { + buffer.clear() + } + + override def toString(): String = buffer.mkString("[", ",", "]") + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[ArrayDataBuffer]) { + return false + } + + val other = o.asInstanceOf[ArrayDataBuffer] + if (other eq null) { + return false + } + + val len = numElements() + if (len != other.numElements()) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = buffer(i) + val o2 = other.buffer(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! 
java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + val len = numElements() + while (i < len) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + buffer(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index f404621399cea..5f2cf7d3cd189 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -103,6 +103,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { /** A sequence of rules that will be applied in order to the physical plan before execution. */ protected def preparations: Seq[Rule[SparkPlan]] = Seq( python.ExtractPythonUDFs, + python.ExtractPythonUDAFs, PlanSubqueries(sparkSession), new ReorderJoinPredicates, EnsureRequirements(sparkSession.sessionState.conf), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index ebbdf1aaa024d..34a259f0579a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.internal.SQLConf * Utility functions used by the query planner to convert our plan to new aggregation code path. */ object AggUtils { - private def createAggregate( + private[sql] def createAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]] = None, groupingExpressions: Seq[NamedExpression] = Nil, aggregateExpressions: Seq[AggregateExpression] = Nil, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala new file mode 100644 index 0000000000000..a94ebef04a992 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.execution.UnaryExecNode + +trait AggregateExec extends UnaryExecNode { + + def requiredChildDistributionExpressions: Option[Seq[Expression]] + + def groupingExpressions: Seq[NamedExpression] + + def aggregateExpressions: Seq[AggregateExpression] + + def aggregateAttributes: Seq[Attribute] + + def initialInputBufferOffset: Int + + def resultExpressions: Seq[NamedExpression] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 51f7c9e22b902..9d3042f808e8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -43,7 +43,7 @@ case class HashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with CodegenSupport { + extends UnaryExecNode with AggregateExec with CodegenSupport { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 66955b8ef723c..c95577795bd87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -65,7 +65,7 @@ case class ObjectHashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode { + extends UnaryExecNode with AggregateExec { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index fc87de2c52e41..76cb3de412de2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -38,7 +38,7 @@ case class SortAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode { + extends UnaryExecNode with AggregateExec { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index e27210117a1e7..2466de45b92f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -58,10 +58,15 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int) /** * A physical plan that evaluates a [[PythonUDF]], */ -case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) +case class ArrowEvalPythonExec( + udfs: Seq[PythonUDF], + output: Seq[Attribute], + child: SparkPlan, + evalType: Int = PythonEvalType.SQL_PANDAS_SCALAR_UDF, + _batchSize: Option[Int] = None) extends EvalPythonExec(udfs, output, child) { - private val batchSize = conf.arrowMaxRecordsPerBatch + private val batchSize = _batchSize.getOrElse(conf.arrowMaxRecordsPerBatch) private val sessionLocalTimeZone = conf.sessionLocalTimeZone protected override def evaluate( @@ -80,8 +85,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) val columnarBatchIter = new ArrowPythonRunner( - funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema, sessionLocalTimeZone) + funcs, bufferSize, reuseWorker, evalType, argOffsets, schema, sessionLocalTimeZone) .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDAFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDAFs.scala new file mode 100644 index 0000000000000..e27c6497a5f91 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDAFs.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import scala.collection.mutable.{ArrayBuffer, Map} + +import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.aggregate.{AggregateExec, AggUtils} +import org.apache.spark.sql.types.{ArrayType, StructType} + +object ExtractPythonUDAFs extends Rule[SparkPlan] { + + private def isPythonUDAF(aggregateExpression: AggregateExpression): Boolean = { + aggregateExpression.aggregateFunction.isInstanceOf[PythonUDAF] + } + + private def hasPythonUDAF(aggregateExpressions: Seq[AggregateExpression]): Boolean = { + aggregateExpressions.exists(isPythonUDAF) + } + + private def hasDistinct(aggregateExpressions: Seq[AggregateExpression]): Boolean = { + aggregateExpressions.exists(_.isDistinct) + } + + override def apply(plan: SparkPlan): SparkPlan = plan transformUp { + case agg: AggregateExec if !hasPythonUDAF(agg.aggregateExpressions) => agg + case agg: AggregateExec if hasDistinct(agg.aggregateExpressions) => + throw new AnalysisException("Vectorized UDAF with distinct is not supported.") + case agg: AggregateExec => + + val newAggExprs = ArrayBuffer.empty[AggregateExpression] ++ agg.aggregateExpressions + val newAggAttrs = ArrayBuffer.empty[Attribute] ++ agg.aggregateAttributes + + val buffers = ArrayBuffer.empty[BufferInputs] + val udafs = ArrayBuffer.empty[PythonUDF] + val udafResultAttrs = ArrayBuffer.empty[AttributeReference] + + val replacingReslutExprs = Map.empty[Expression, NamedExpression] ++ + agg.groupingExpressions.map(expr => expr -> expr.toAttribute) + + agg.aggregateExpressions.foreach { + case aggExpr if isPythonUDAF(aggExpr) => + val pythonUDAF = aggExpr.aggregateFunction.asInstanceOf[PythonUDAF] + + aggExpr.mode match { + case Partial => + val buffer = buffers.find { buf => + buf.children.length == pythonUDAF.children.length && + buf.children.zip(pythonUDAF.children).forall { case (c, child) => + c.semanticEquals(child) + } + } match { + case Some(buf) => + newAggExprs -= aggExpr + newAggAttrs --= pythonUDAF.aggBufferAttributes + + buf + case None => + val buf = BufferInputs(pythonUDAF.children) + buffers += buf + + newAggExprs.update( + newAggExprs.indexOf(aggExpr), aggExpr.copy(aggregateFunction = buf)) + + val index = newAggAttrs.indexOfSlice(pythonUDAF.aggBufferAttributes) + newAggAttrs --= pythonUDAF.aggBufferAttributes + newAggAttrs.insertAll(index, buf.aggBufferAttributes) + + buf + } + + if (pythonUDAF.udaf.supportsPartial) { + val udaf = PythonUDF(pythonUDAF.udaf.name, pythonUDAF.udaf.func, + pythonUDAF.udaf.returnType, buffer.aggBufferAttributes, + PythonEvalType.SQL_PANDAS_GROUP_AGGREGATE_UDF) + udafs += udaf + + val (resultAttrs, replacingExprs) = pythonUDAF.inputAggBufferAttributes.map { + attr => + val arrayType = attr.dataType.asInstanceOf[ArrayType] + val resultAttr = AttributeReference( + attr.name, arrayType.elementType, arrayType.containsNull)() + (resultAttr, attr -> Alias(CreateArray(Seq(resultAttr)), attr.name)()) + }.unzip + udafResultAttrs ++= resultAttrs + replacingReslutExprs ++= replacingExprs + } else { + replacingReslutExprs ++= + pythonUDAF.inputAggBufferAttributes.zip( + buffer.inputAggBufferAttributes.zip(buffer.aggBufferAttributes)).map { + case (attr, (newAttr, buffer)) => + attr -> 
Alias(buffer, newAttr.name)( + newAttr.exprId, newAttr.qualifier, Option(newAttr.metadata)) + } + } + + case Final => + val buffer = BufferInputs(pythonUDAF.inputAggBufferAttributes.map { attr => + val arrayType = attr.dataType.asInstanceOf[ArrayType] + AttributeReference(attr.name, arrayType.elementType, arrayType.containsNull)() + }) + + newAggExprs.update( + newAggExprs.indexOf(aggExpr), aggExpr.copy(aggregateFunction = buffer)) + + val bufferOut = AttributeReference("buffer", buffer.dataType, buffer.nullable)() + newAggAttrs.update(newAggAttrs.indexOf(aggExpr.resultAttribute), bufferOut) + + val udafInputs = buffer.dataType.asInstanceOf[StructType].zipWithIndex.map { + case (field, idx) => + GetStructField(bufferOut, idx, Option(field.name)) + } + val udaf = PythonUDF(pythonUDAF.udaf.name, pythonUDAF.udaf.func, + pythonUDAF.udaf.returnType, udafInputs, + PythonEvalType.SQL_PANDAS_GROUP_AGGREGATE_UDF) + udafs += udaf + + val resultAttr = AttributeReference(udaf.name, udaf.dataType, udaf.nullable)() + udafResultAttrs += resultAttr + replacingReslutExprs += aggExpr.resultAttribute -> resultAttr + + case _ => + throw new AnalysisException(s"Unsupported aggregate mode: ${aggExpr.mode}.") + } + case aggExpr => + aggExpr.mode match { + case Partial => + val af = aggExpr.aggregateFunction + replacingReslutExprs ++= + af.inputAggBufferAttributes.zip(af.aggBufferAttributes).map { + case (attr, buffer) => + attr -> Alias(buffer, attr.name)( + attr.exprId, attr.qualifier, Option(attr.metadata)) + } + case _ => + } + } + + val newAgg = AggUtils.createAggregate( + requiredChildDistributionExpressions = agg.requiredChildDistributionExpressions, + groupingExpressions = agg.groupingExpressions, + aggregateExpressions = newAggExprs, + aggregateAttributes = newAggAttrs, + initialInputBufferOffset = agg.initialInputBufferOffset, + resultExpressions = agg.groupingExpressions ++ newAggAttrs, + child = agg.child) + + val exec = if (udafs.size > 0) { + ArrowEvalPythonExec( + udafs, + newAgg.output ++ udafResultAttrs, + newAgg, + PythonEvalType.SQL_PANDAS_GROUP_AGGREGATE_UDF, + Some(1)) + } else { + newAgg + } + + val project = agg.resultExpressions.map { expr => + expr.transformUp { + case expr if replacingReslutExprs.contains(expr) => replacingReslutExprs(expr) + }.asInstanceOf[NamedExpression] + } + + ProjectExec(project, exec) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/udaf.scala new file mode 100644 index 0000000000000..1992ef47e760d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/udaf.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.api.python.PythonFunction +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.catalyst.util.{ArrayData, ArrayDataBuffer} +import org.apache.spark.sql.types._ + +/** + * A user-defined aggregate Python function. This is used by the Python API. + */ +case class UserDefinedAggregatePythonFunction( + name: String, + func: PythonFunction, + returnType: DataType, + supportsPartial: Boolean) { + + /** + * Creates a `Column` for this UDAF using given `Column`s as input arguments. + */ + @scala.annotation.varargs + def apply(exprs: Column*): Column = { + val aggregateExpression = + AggregateExpression( + PythonUDAF(exprs.map(_.expr), this), + Complete, + isDistinct = false) + Column(aggregateExpression) + } +} + +/** + * The internal wrapper used to hook a [[UserDefinedAggregatePythonFunction]] `udaf` in the + * internal aggregation code path. + */ +case class PythonUDAF( + children: Seq[Expression], + udaf: UserDefinedAggregatePythonFunction) + extends AggregateFunction + with Unevaluable + with NonSQLExpression + with UserDefinedExpression { + + override def nullable: Boolean = true + + override def dataType: DataType = udaf.returnType + + override lazy val aggBufferSchema: StructType = if (udaf.supportsPartial) { + udaf.returnType match { + case StructType(fields) => StructType(fields.map { field => + StructField(field.name, + ArrayType(field.dataType, containsNull = field.nullable), nullable = false) + }) + case dt => + new StructType().add(udaf.name, ArrayType(dt, containsNull = true), nullable = false) + } + } else { + StructType(children.zipWithIndex.map { case (child, i) => + StructField(s"_$i", ArrayType(child.dataType, containsNull = true), nullable = false) + }) + } + + override lazy val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. 
+ override lazy val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + + override def toString: String = { + s"${udaf.name}(${children.mkString(",")})" + } + + override def nodeName: String = udaf.name +} + +case class BufferInputs( + children: Seq[Expression], + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends ImperativeAggregate + with NonSQLExpression + with Logging { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def nullable: Boolean = true + + override def dataType: DataType = aggBufferSchema + + override val aggBufferSchema: StructType = + StructType(children.zipWithIndex.map { + case (child, i) => + StructField(s"_$i", ArrayType(child.dataType, child.nullable), nullable = false) + }) + + override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + + private[this] lazy val childrenSchema: StructType = { + val inputFields = children.zipWithIndex.map { + case (child, index) => + StructField(s"input$index", child.dataType, child.nullable, Metadata.empty) + } + StructType(inputFields) + } + + private lazy val inputProjection = { + val inputAttributes = childrenSchema.toAttributes + log.debug( + s"Creating MutableProj: $children, inputSchema: $inputAttributes.") + GenerateMutableProjection.generate(children, inputAttributes) + } + + override def initialize(buffer: InternalRow): Unit = { + aggBufferSchema.zipWithIndex.foreach { case (_, i) => + buffer.update(i + mutableAggBufferOffset, new ArrayDataBuffer()) + } + } + + override def update(buffer: InternalRow, input: InternalRow): Unit = { + val projected = inputProjection(input) + aggBufferSchema.zip(childrenSchema).zipWithIndex.foreach { + case ((StructField(_, dt @ ArrayType(_, _), _, _), childSchema), i) => + val bufferOffset = i + mutableAggBufferOffset + val arrayDataBuffer = + buffer.get(bufferOffset, dt).asInstanceOf[ArrayDataBuffer] + if (projected.isNullAt(i)) { + arrayDataBuffer += null + } else { + arrayDataBuffer += InternalRow.copyValue(projected.get(i, childSchema.dataType)) + } + } + } + + override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = { + aggBufferSchema.zipWithIndex.foreach { + case (StructField(_, dt @ ArrayType(elementType, _), _, _), i) => + val bufferOffset = i + mutableAggBufferOffset + val inputOffset = i + inputAggBufferOffset + val arrayDataBuffer1 = buffer1.get(bufferOffset, dt).asInstanceOf[ArrayDataBuffer] + buffer2.get(inputOffset, dt) match { + case arrayDataBuffer2: UnsafeArrayData => + elementType match { + case BooleanType => arrayDataBuffer1 ++= arrayDataBuffer2.toBooleanArray() + case ByteType => arrayDataBuffer1 ++= arrayDataBuffer2.toByteArray() + case ShortType => arrayDataBuffer1 ++= arrayDataBuffer2.toShortArray() + case IntegerType => arrayDataBuffer1 ++= arrayDataBuffer2.toIntArray() + case LongType => arrayDataBuffer1 ++= arrayDataBuffer2.toLongArray() + case FloatType => arrayDataBuffer1 ++= 
arrayDataBuffer2.toFloatArray() + case DoubleType => arrayDataBuffer1 ++= arrayDataBuffer2.toDoubleArray() + } + case arrayDataBuffer2: ArrayData => + arrayDataBuffer1 ++= arrayDataBuffer2 + } + } + } + + private val row = new GenericInternalRow(aggBufferSchema.size) + + override def eval(buffer: InternalRow): Any = { + aggBufferSchema.zipWithIndex.foreach { case (buffSchema, i) => + val bufferOffset = i + mutableAggBufferOffset + row.update(i, buffer.get(bufferOffset, buffSchema.dataType)) + } + row + } +}
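
Usage sketch (not part of the patch): a minimal end-to-end example of the new pandas_udaf API, mirroring VectorizedUDAFTests.test_vectorized_udaf_basic above. It assumes the patch is applied and that pandas and PyArrow are installed; the column names (n, g) and the two aggregates are taken from the test, and the comments on supportsPartial reflect my reading of ExtractPythonUDAFs rather than documented semantics.

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import pandas_udaf, col, expr
    from pyspark.sql.types import LongType, DoubleType

    spark = SparkSession.builder.getOrCreate()

    # 100 rows, grouped into even/odd ids.
    df = spark.range(100).select(col('id').alias('n'), (col('id') % 2 == 0).alias('g'))

    # supportsPartial=True: sum can be re-applied to partial results,
    # so map-side (partial) aggregation is allowed.
    @pandas_udaf(LongType(), supportsPartial=True)
    def p_sum(v):
        # v is a pandas.Series holding the buffered values of `n` for one group
        return v.sum()

    # supportsPartial=False: avg is not decomposable this way, so all of a
    # group's values are buffered and aggregated in a single call.
    @pandas_udaf(DoubleType(), supportsPartial=False)
    def p_avg(v):
        return v.mean()

    # Vectorized UDAFs can be mixed with built-in aggregates in the same agg().
    df.groupBy('g').agg(p_sum(col('n')), expr('count(n)'), p_avg(col('n'))).show()

On the worker side (wrap_pandas_group_aggregate_udf), the function receives one pandas.Series per input column and may return either a single scalar or a tuple/list whose length matches the declared return schema.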