diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 4dcc5eb5fbfcd..fe28ec6fc1c29 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -46,6 +46,7 @@ private[spark] object PythonEvalType { val SQL_GROUPED_MAP_PANDAS_UDF = 201 val SQL_GROUPED_AGG_PANDAS_UDF = 202 val SQL_WINDOW_AGG_PANDAS_UDF = 203 + val SQL_COGROUPED_MAP_PANDAS_UDF = 204 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -54,6 +55,7 @@ private[spark] object PythonEvalType { case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF" case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF" case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF" + case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF" } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index f0682e71a1780..f198316d0c4a5 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -73,6 +73,7 @@ class PythonEvalType(object): SQL_GROUPED_MAP_PANDAS_UDF = 201 SQL_GROUPED_AGG_PANDAS_UDF = 202 SQL_WINDOW_AGG_PANDAS_UDF = 203 + SQL_COGROUPED_MAP_PANDAS_UDF = 204 def portable_hash(x): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 516ee7e7b3084..e818241aaeebf 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -359,6 +359,24 @@ def __repr__(self): return "ArrowStreamPandasSerializer" +class InterleavedArrowReader(object): + + def __init__(self, stream): + import pyarrow as pa + self._schema1 = pa.read_schema(stream) + self._schema2 = pa.read_schema(stream) + self._reader = pa.MessageReader.open_stream(stream) + + def __iter__(self): + return self + + def __next__(self): + import pyarrow as pa + batch1 = pa.read_record_batch(self._reader.read_next_message(), self._schema1) + batch2 = pa.read_record_batch(self._reader.read_next_message(), self._schema2) + return batch1, batch2 + + class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): """ Serializer used by Python worker to evaluate Pandas UDFs @@ -404,6 +422,22 @@ def __repr__(self): return "ArrowStreamPandasUDFSerializer" +class InterleavedArrowStreamPandasSerializer(ArrowStreamPandasUDFSerializer): + + def __init__(self, timezone, safecheck, assign_cols_by_name): + super(InterleavedArrowStreamPandasSerializer, self).__init__(timezone, safecheck, assign_cols_by_name) + + def load_stream(self, stream): + """ + Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. + """ + import pyarrow as pa + reader = InterleavedArrowReader(pa.input_stream(stream)) + for batch1, batch2 in reader: + yield ( [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch1]).itercolumns()], + [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch2]).itercolumns()]) + + class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/cogroup.py b/python/pyspark/sql/cogroup.py new file mode 100644 index 0000000000000..18dc397c8e348 --- /dev/null +++ b/python/pyspark/sql/cogroup.py @@ -0,0 +1,21 @@ +from pyspark.sql.dataframe import DataFrame + + +class CoGroupedData(object): + + def __init__(self, gd1, gd2): + self._gd1 = gd1 + self._gd2 = gd2 + self.sql_ctx = gd1.sql_ctx + + def apply(self, udf): + all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2) + udf_column = udf(*all_cols) + jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr()) + return DataFrame(jdf, self.sql_ctx) + + @staticmethod + def _extract_cols(gd): + df = gd._df + return [df[col] for col in df.columns] + diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 613822b7edf2d..f156b9e9bd984 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2798,6 +2798,8 @@ class PandasUDFType(object): GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF + COGROUPED_MAP = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF + GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF @@ -3179,6 +3181,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]: raise ValueError("Invalid functionType: " "functionType must be one the values from PandasUDFType") diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index cc1da8e7c1f72..04f42b1598376 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -22,6 +22,7 @@ from pyspark.sql.column import Column, _to_seq from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import * +from pyspark.sql.cogroup import CoGroupedData __all__ = ["GroupedData"] @@ -220,6 +221,9 @@ def pivot(self, pivot_col, values=None): jgd = self._jgd.pivot(pivot_col, values) return GroupedData(jgd, self._df) + def cogroup(self, other): + return CoGroupedData(self, other) + @since(2.3) def apply(self, udf): """ diff --git a/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py new file mode 100644 index 0000000000000..d74f9b10325ed --- /dev/null +++ b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import datetime +import unittest +import sys + +from collections import OrderedDict +from decimal import Decimal + +from pyspark.sql import Row +from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType +from pyspark.sql.types import * +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ + pandas_requirement_message, pyarrow_requirement_message +from pyspark.testing.utils import QuietTest + +if have_pandas: + import pandas as pd + from pandas.util.testing import assert_frame_equal + +if have_pyarrow: + import pyarrow as pa + + +""" +Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names +from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check +""" +if sys.version < '3': + _check_column_type = False +else: + _check_column_type = True + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message) +class CoGroupedMapPandasUDFTests(ReusedSQLTestCase): + + @property + def data1(self): + return self.spark.range(10).toDF('id') \ + .withColumn("ks", array([lit(i) for i in range(20, 30)])) \ + .withColumn("k", explode(col('ks')))\ + .withColumn("v", col('k') * 10)\ + .drop('ks') + + @property + def data2(self): + return self.spark.range(10).toDF('id') \ + .withColumn("ks", array([lit(i) for i in range(20, 30)])) \ + .withColumn("k", explode(col('ks'))) \ + .withColumn("v2", col('k') * 100) \ + .drop('ks') + + def test_simple(self): + import pandas as pd + + l = self.data1 + r = self.data2 + + @pandas_udf('id long, k int, v int, v2 int', PandasUDFType.COGROUPED_MAP) + def merge_pandas(left, right): + return pd.merge(left, right, how='outer', on=['k', 'id']) + + # TODO: Grouping by a string fails to resolve here as analyzer cannot determine side + result = l\ + .groupby(l.id)\ + .cogroup(r.groupby(r.id))\ + .apply(merge_pandas)\ + .sort(['id', 'k'])\ + .toPandas() + + expected = pd\ + .merge(l.toPandas(), r.toPandas(), how='outer', on=['k', 'id']) + + assert_frame_equal(expected, result, check_column_type=_check_column_type) + diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 16257bef6b320..1a423920b77f5 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -38,7 +38,7 @@ from pyspark.rdd import PythonEvalType from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ - BatchedSerializer, ArrowStreamPandasUDFSerializer + BatchedSerializer, ArrowStreamPandasUDFSerializer, InterleavedArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type, StructType from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle @@ -103,8 +103,25 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_grouped_map_pandas_udf(f, return_type, argspec): +def wrap_cogrouped_map_pandas_udf(f, return_type): + def wrapped(left, right): + import pandas as pd + result = f(pd.concat(left, axis=1), pd.concat(right, axis=1)) + if not isinstance(result, pd.DataFrame): + raise TypeError("Return type of the user-defined function should be " + "pandas.DataFrame, but is {}".format(type(result))) + if not len(result.columns) == len(return_type): + raise RuntimeError( + "Number of columns of the returned pandas.DataFrame " + "doesn't match specified schema. " + "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) + return result + + return lambda v: [(wrapped(v[0], v[1]), to_arrow_type(return_type))] + + +def wrap_grouped_map_pandas_udf(f, return_type, argspec): def wrapped(key_series, value_series): import pandas as pd @@ -219,6 +236,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = _get_argspec(row_func) # signature was lost when wrapping it return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: @@ -233,6 +252,7 @@ def read_udfs(pickleSer, infile, eval_type): runner_conf = {} if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF): @@ -255,9 +275,12 @@ def read_udfs(pickleSer, infile, eval_type): # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of # pandas Series. See SPARK-27240. - df_for_struct = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF - ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name, - df_for_struct) + if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + ser = InterleavedArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name) + else: + df_for_struct = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF + ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name, + df_for_struct) else: ser = BatchedSerializer(PickleSerializer(), 100) @@ -282,6 +305,14 @@ def read_udfs(pickleSer, infile, eval_type): arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]] mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1)) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + # We assume there is only one UDF here because cogrouped map doesn't + # support combining multiple UDFs. + assert num_udfs == 1 + arg_offsets, udf = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0) + udfs['f'] = udf + mapper_str = "lambda a: f(a)" else: # Create function like this: # lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3])) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 2df30a1a53ad7..35adbae423f25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -39,6 +39,18 @@ case class FlatMapGroupsInPandas( override val producedAttributes = AttributeSet(output) } + +case class FlatMapCoGroupsInPandas( + leftAttributes: Seq[Attribute], + rightAttributes: Seq[Attribute], + functionExpr: Expression, + output: Seq[Attribute], + left: LogicalPlan, + right: LogicalPlan) extends BinaryNode { + override val producedAttributes = AttributeSet(output) +} + + trait BaseEvalPython extends UnaryNode { def udfs: Seq[PythonUDF] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index e85636d82a62c..147cc00c0ba91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -47,8 +47,8 @@ import org.apache.spark.sql.types.{NumericType, StructType} */ @Stable class RelationalGroupedDataset protected[sql]( - df: DataFrame, - groupingExprs: Seq[Expression], + private val df: DataFrame, + private val groupingExprs: Seq[Expression], groupType: RelationalGroupedDataset.GroupType) { private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { @@ -523,6 +523,33 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } + private[sql] def flatMapCoGroupsInPandas + (r: RelationalGroupedDataset, expr: PythonUDF): DataFrame = { + require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + "Must pass a cogrouped map udf") + require(expr.dataType.isInstanceOf[StructType], + s"The returnType of the udf must be a ${StructType.simpleString}") + + val leftGroupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + + val rightGroupingNamedExpressions = r.groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + + val leftAttributes = leftGroupingNamedExpressions.map(_.toAttribute) + val rightAttributes = rightGroupingNamedExpressions.map(_.toAttribute) + val left = df.logicalPlan + val right = r.df.logicalPlan + val output = expr.dataType.asInstanceOf[StructType].toAttributes + val plan = FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, expr, output, left, right) + + Dataset.ofRows(df.sparkSession, plan) + } + override def toString: String = { val builder = new StringBuilder builder.append("RelationalGroupedDataset: [grouping expressions: [") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c4031496f610f..965f04c058966 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -679,6 +679,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { f, p, b, is, ot, planLater(child)) :: Nil case logical.FlatMapGroupsInPandas(grouping, func, output, child) => execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil + case logical.FlatMapCoGroupsInPandas(leftGroup, rightGroup, func, output, left, right) => + execution.python.FlatMapCoGroupsInPandasExec( + leftGroup, rightGroup, func, output, planLater(left), planLater(right)) :: Nil case logical.MapElements(f, _, _, objAttr, child) => execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil case logical.AppendColumns(f, _, _, in, out, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 3710218b2af5f..5e00eecf1b230 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -45,7 +45,7 @@ class ArrowPythonRunner( schema: StructType, timeZoneId: String, conf: Map[String, String]) - extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( + extends BaseArrowPythonRunner[Iterator[InternalRow]]( funcs, evalType, argOffsets) { protected override def newWriterThread( @@ -112,72 +112,4 @@ class ArrowPythonRunner( } } - protected override def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[ColumnarBatch] = { - new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) { - - private val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdin reader for $pythonExec", 0, Long.MaxValue) - - private var reader: ArrowStreamReader = _ - private var root: VectorSchemaRoot = _ - private var schema: StructType = _ - private var vectors: Array[ColumnVector] = _ - - context.addTaskCompletionListener[Unit] { _ => - if (reader != null) { - reader.close(false) - } - allocator.close() - } - - private var batchLoaded = true - - protected override def read(): ColumnarBatch = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - if (reader != null && batchLoaded) { - batchLoaded = reader.loadNextBatch() - if (batchLoaded) { - val batch = new ColumnarBatch(vectors) - batch.setNumRows(root.getRowCount) - batch - } else { - reader.close(false) - allocator.close() - // Reach end of stream. Call `read()` again to read control data. - read() - } - } else { - stream.readInt() match { - case SpecialLengths.START_ARROW_STREAM => - reader = new ArrowStreamReader(stream, allocator) - root = reader.getVectorSchemaRoot() - schema = ArrowUtils.fromArrowSchema(root.getSchema()) - vectors = root.getFieldVectors().asScala.map { vector => - new ArrowColumnVector(vector) - }.toArray[ColumnVector] - read() - case SpecialLengths.TIMING_DATA => - handleTimingData() - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - throw handlePythonException() - case SpecialLengths.END_OF_DATA_SECTION => - handleEndOfDataSection() - null - } - } - } catch handleException - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseArrowPythonRunner.scala new file mode 100644 index 0000000000000..3cba06dcf7d52 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseArrowPythonRunner.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.python + +import java.io._ +import java.net._ +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamReader + +import org.apache.spark._ +import org.apache.spark.api.python._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} + + +abstract class BaseArrowPythonRunner[T]( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]]) + extends BasePythonRunner[T, ColumnarBatch]( + funcs, evalType, argOffsets) { + + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[ColumnarBatch] = { + + new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) { + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for $pythonExec", 0, Long.MaxValue) + + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var schema: StructType = _ + private var vectors: Array[ColumnVector] = _ + + context.addTaskCompletionListener[Unit] { _ => + if (reader != null) { + reader.close(false) + } + allocator.close() + } + + private var batchLoaded = true + + protected override def read(): ColumnarBatch = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + if (reader != null && batchLoaded) { + batchLoaded = reader.loadNextBatch() + if (batchLoaded) { + val batch = new ColumnarBatch(vectors) + batch.setNumRows(root.getRowCount) + batch + } else { + reader.close(false) + allocator.close() + // Reach end of stream. Call `read()` again to read control data. + read() + } + } else { + stream.readInt() match { + case SpecialLengths.START_ARROW_STREAM => + reader = new ArrowStreamReader(stream, allocator) + root = reader.getVectorSchemaRoot() + schema = ArrowUtils.fromArrowSchema(root.getSchema()) + vectors = root.getFieldVectors().asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } + } catch handleException + } + } + } +} + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala new file mode 100644 index 0000000000000..12620264de087 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, GroupedIterator, SparkPlan} +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} + +case class FlatMapCoGroupsInPandasExec( + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + func: Expression, + output: Seq[Attribute], + left: SparkPlan, + right: SparkPlan) + extends BinaryExecNode { + + private val pandasFunction = func.asInstanceOf[PythonUDF].func + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def requiredChildDistribution: Seq[Distribution] = { + ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + leftGroup + .map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil + } + + + override protected def doExecute(): RDD[InternalRow] = { + + val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + + + left.execute().zipPartitions(right.execute()) { (leftData, rightData) => + val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) + val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) + val cogroup = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup) + .map{case (k, l, r) => (l, r)} + val context = TaskContext.get() + + val columnarBatchIter = new InterleavedArrowPythonRunner( + chainedFunc, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + Array(Array.empty), + left.schema, + right.schema, + sessionLocalTimeZone, + pythonRunnerConf).compute(cogroup, context.partitionId(), context) + + + val unsafeProj = UnsafeProjection.create(output, output) + + columnarBatchIter.flatMap { batch => + // UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = output.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + flattenedBatch.rowIterator.asScala + }.map(unsafeProj) + } + + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowPythonRunner.scala new file mode 100644 index 0000000000000..b39885ee47a2d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowPythonRunner.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io._ +import java.net._ + +import org.apache.arrow.vector.VectorSchemaRoot + +import org.apache.spark._ +import org.apache.spark.api.python._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils + + +class InterleavedArrowPythonRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + leftSchema: StructType, + rightSchema: StructType, + timeZoneId: String, + conf: Map[String, String]) + extends BaseArrowPythonRunner[(Iterator[InternalRow], Iterator[InternalRow])]( + funcs, evalType, argOffsets) { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[(Iterator[InternalRow], Iterator[InternalRow])], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + + // Write config for the worker as a number of key -> value pairs of strings + dataOut.writeInt(conf.size) + for ((k, v) <- conf) { + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) + } + + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + val leftArrowSchema = ArrowUtils.toArrowSchema(leftSchema, timeZoneId) + val rightArrowSchema = ArrowUtils.toArrowSchema(rightSchema, timeZoneId) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for $pythonExec", 0, Long.MaxValue) + val leftRoot = VectorSchemaRoot.create(leftArrowSchema, allocator) + val rightRoot = VectorSchemaRoot.create(rightArrowSchema, allocator) + + Utils.tryWithSafeFinally { + val leftArrowWriter = ArrowWriter.create(leftRoot) + val rightArrowWriter = ArrowWriter.create(rightRoot) + val writer = InterleavedArrowWriter(leftRoot, rightRoot, dataOut) + writer.start() + + while (inputIterator.hasNext) { + + val (nextLeft, nextRight) = inputIterator.next() + + while (nextLeft.hasNext) { + leftArrowWriter.write(nextLeft.next()) + } + while (nextRight.hasNext) { + rightArrowWriter.write(nextRight.next()) + } + leftArrowWriter.finish() + rightArrowWriter.finish() + writer.writeBatch() + leftArrowWriter.reset() + rightArrowWriter.reset() + } + // end writes footer to the output stream and doesn't clean any resources. + // It could throw exception if the output stream is closed, so it should be + // in the try block. + writer.end() + } { + // If we close root and allocator in TaskCompletionListener, there could be a race + // condition where the writer thread keeps writing to the VectorSchemaRoot while + // it's being closed by the TaskCompletion listener. + // Closing root and allocator here is cleaner because root and allocator is owned + // by the writer thread and is only visible to the writer thread. + // + // If the writer thread is interrupted by TaskCompletionListener, it should either + // (1) in the try block, in which case it will get an InterruptedException when + // performing io, and goes into the finally block or (2) in the finally block, + // in which case it will ignore the interruption and close the resources. + leftRoot.close() + rightRoot.close() + allocator.close() + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowWriter.scala new file mode 100644 index 0000000000000..eb9f1d4494b91 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowWriter.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.OutputStream +import java.nio.channels.Channels + +import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.ipc.WriteChannel +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} + + +class InterleavedArrowWriter( leftRoot: VectorSchemaRoot, + rightRoot: VectorSchemaRoot, + out: WriteChannel) extends AutoCloseable{ + + + private var started = false + private val leftUnloader = new VectorUnloader(leftRoot) + private val rightUnloader = new VectorUnloader(rightRoot) + + def start(): Unit = { + this.ensureStarted() + } + + def writeBatch(): Unit = { + this.ensureStarted() + writeRecordBatch(leftUnloader.getRecordBatch) + writeRecordBatch(rightUnloader.getRecordBatch) + } + + private def writeRecordBatch(b: ArrowRecordBatch): Unit = { + try + MessageSerializer.serialize(out, b) + finally + b.close() + } + + private def ensureStarted(): Unit = { + if (!started) { + started = true + MessageSerializer.serialize(out, leftRoot.getSchema) + MessageSerializer.serialize(out, rightRoot.getSchema) + } + } + + def end(): Unit = { + ensureStarted() + ensureEnded() + } + + def ensureEnded(): Unit = { + out.writeIntLittleEndian(0) + } + + def close(): Unit = { + out.close() + } + +} + +object InterleavedArrowWriter{ + + def apply(leftRoot: VectorSchemaRoot, + rightRoot: VectorSchemaRoot, + out: OutputStream): InterleavedArrowWriter = { + new InterleavedArrowWriter(leftRoot, rightRoot, new WriteChannel(Channels.newChannel(out))) + } + +}