diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f12e2dfafa19..d779ce76d434 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -71,6 +71,14 @@ private[spark] class PythonRDD( } } +private[spark] case class PythonFunction( + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + pythonVer: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: Accumulator[JList[Array[Byte]]]) /** * A helper class to run Python UDFs in Spark. diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index fe2264a63cf3..8db38bcf0c7f 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2330,6 +2330,15 @@ def _prepare_for_python_RDD(sc, command, obj=None): return pickled_command, broadcast_vars, env, includes +def _wrap_function(sc, func, deserializer, serializer, profiler=None): + assert deserializer, "deserializer should not be empty" + assert serializer, "serializer should not be empty" + command = (func, profiler, deserializer, serializer) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, + sc.pythonVer, broadcast_vars, sc._javaAccumulator) + + class PipelinedRDD(RDD): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 83b034fe7743..b5260224c2ee 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -28,7 +28,8 @@ from pyspark import since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer +from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, PickleSerializer, \ + UTF8Deserializer, PairDeserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string @@ -236,9 +237,14 @@ def collect(self): >>> df.collect() [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ + + if self._jdf.isPickled(): + deserializer = PickleSerializer() + else: + deserializer = BatchedSerializer(PickleSerializer()) with SCCallSiteSync(self._sc) as css: port = self._jdf.collectToPython() - return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + return list(_load_from_socket(port, deserializer)) @ignore_unicode_prefix @since(1.3) @@ -278,6 +284,25 @@ def map(self, f): """ return self.rdd.map(f) + @ignore_unicode_prefix + @since(2.0) + def applySchema(self, schema=None): + """ TODO """ + # TODO: should we throw exception instead? 
+ return self + + @ignore_unicode_prefix + @since(2.0) + def mapPartitions2(self, func): + """ TODO """ + return PipelinedDataFrame(self, func) + + @ignore_unicode_prefix + @since(2.0) + def map2(self, func): + """ TODO """ + return self.mapPartitions2(lambda iterator: map(func, iterator)) + @ignore_unicode_prefix @since(1.3) def flatMap(self, f): @@ -890,10 +915,20 @@ def groupBy(self, *cols): >>> sorted(df.groupBy(['name', df.age]).count().collect()) [Row(name=u'Alice', age=2, count=1), Row(name=u'Bob', age=5, count=1)] """ - jgd = self._jdf.groupBy(self._jcols(*cols)) + jgd = self._jdf.pythonGroupBy(self._jcols(*cols)) from pyspark.sql.group import GroupedData return GroupedData(jgd, self.sql_ctx) + @ignore_unicode_prefix + @since(2.0) + def groupByKey(self, key_func, key_type): + """ TODO """ + f = lambda iterator: map(key_func, iterator) + wraped_func = _wrap_func(self._sc, self._jdf, f, False) + jgd = self._jdf.pythonGroupBy(wraped_func, key_type.json()) + from pyspark.sql.group import GroupedData + return GroupedData(jgd, self.sql_ctx, not isinstance(key_type, StructType)) + @since(1.4) def rollup(self, *cols): """ @@ -1354,6 +1389,99 @@ def toPandas(self): drop_duplicates = dropDuplicates +class PipelinedDataFrame(DataFrame): + + """ TODO """ + + def __init__(self, prev, func): + from pyspark.sql.group import GroupedData + + self._jdf_val = None + self.is_cached = False + self.sql_ctx = prev.sql_ctx + self._sc = self.sql_ctx and self.sql_ctx._sc + self._lazy_rdd = None + + if isinstance(prev, GroupedData): + # prev is GroupedData, set the grouped flag to true and use jgd as jdf. + self._grouped = True + self._func = func + self._prev_jdf = prev._jgd + elif not isinstance(prev, PipelinedDataFrame) or prev.is_cached: + # This transformation is the first in its stage: + self._grouped = False + self._func = func + self._prev_jdf = prev._jdf + else: + self._grouped = prev._grouped + self._func = _pipeline_func(prev._func, func) + # maintain the pipeline. + self._prev_jdf = prev._prev_jdf + + def applySchema(self, schema=None): + if schema is None: + from pyspark.sql.types import _infer_type, _merge_type + # If no schema is specified, infer it from the whole data set. 
+ jrdd = self._prev_jdf.javaToPython() + rdd = RDD(jrdd, self._sc, BatchedSerializer(PickleSerializer())) + schema = rdd.mapPartitions(self._func).map(_infer_type).reduce(_merge_type) + + if isinstance(schema, StructType): + to_rows = lambda iterator: map(schema.toInternal, iterator) + else: + data_type = schema + schema = StructType().add("value", data_type) + to_row = lambda obj: (data_type.toInternal(obj), ) + to_rows = lambda iterator: map(to_row, iterator) + + jdf = self._create_jdf(_pipeline_func(self._func, to_rows), schema) + return DataFrame(jdf, self.sql_ctx) + + @property + def _jdf(self): + if self._jdf_val is None: + self._jdf_val = self._create_jdf(self._func) + return self._jdf_val + + def _create_jdf(self, func, schema=None): + wrapped_func = _wrap_func(self._sc, self._prev_jdf, func, schema is None, self._grouped) + if schema is None: + if self._grouped: + return self._prev_jdf.flatMapGroups(wrapped_func) + else: + return self._prev_jdf.pythonMapPartitions(wrapped_func) + else: + schema_string = schema.json() + if self._grouped: + return self._prev_jdf.flatMapGroups(wrapped_func, schema_string) + else: + return self._prev_jdf.pythonMapPartitions(wrapped_func, schema_string) + + +def _wrap_func(sc, jdf, func, output_binary, input_grouped=False): + if input_grouped: + deserializer = PairDeserializer(PickleSerializer(), PickleSerializer()) + elif jdf.isPickled(): + deserializer = PickleSerializer() + else: + deserializer = AutoBatchedSerializer(PickleSerializer()) + + if output_binary: + serializer = PickleSerializer() + else: + serializer = AutoBatchedSerializer(PickleSerializer()) + + from pyspark.rdd import _wrap_function + return _wrap_function(sc, lambda _, iterator: func(iterator), deserializer, serializer) + + +def _pipeline_func(prev_func, next_func): + if prev_func is None: + return next_func + else: + return lambda iterator: next_func(prev_func(iterator)) + + def _to_scala_map(sc, jm): """ Convert a dict into a JVM Map. diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index ee734cb43928..7e344d962437 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -15,10 +15,19 @@ # limitations under the License. # +import sys + +if sys.version >= '3': + basestring = unicode = str + long = int + from functools import reduce +else: + from itertools import imap as map + from pyspark import since from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal -from pyspark.sql.dataframe import DataFrame +from pyspark.sql.dataframe import DataFrame, PipelinedDataFrame from pyspark.sql.types import * __all__ = ["GroupedData"] @@ -27,7 +36,7 @@ def dfapi(f): def _api(self): name = f.__name__ - jdf = getattr(self._jdf, name)() + jdf = getattr(self._jgd, name)() return DataFrame(jdf, self.sql_ctx) _api.__name__ = f.__name__ _api.__doc__ = f.__doc__ @@ -37,7 +46,7 @@ def _api(self): def df_varargs_api(f): def _api(self, *args): name = f.__name__ - jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) + jdf = getattr(self._jgd, name)(_to_seq(self.sql_ctx._sc, args)) return DataFrame(jdf, self.sql_ctx) _api.__name__ = f.__name__ _api.__doc__ = f.__doc__ @@ -54,9 +63,33 @@ class GroupedData(object): .. 
versionadded:: 1.3 """ - def __init__(self, jdf, sql_ctx): - self._jdf = jdf + def __init__(self, jgd, sql_ctx, flat_key=False): + self._jgd = jgd self.sql_ctx = sql_ctx + if flat_key: + self._key_converter = lambda key: key[0] + else: + self._key_converter = lambda key: key + + @ignore_unicode_prefix + @since(2.0) + def flatMapGroups(self, func): + """ TODO """ + key_converter = self._key_converter + + def process(inputs): + record_converter = lambda record: (key_converter(record[0]), record[1]) + for key, values in GroupedIterator(map(record_converter, inputs)): + for output in func(key, values): + yield output + + return PipelinedDataFrame(self, process) + + @ignore_unicode_prefix + @since(2.0) + def mapGroups(self, func): + """ TODO """ + return self.flatMapGroups(lambda key, values: iter([func(key, values)])) @ignore_unicode_prefix @since(1.3) @@ -83,11 +116,11 @@ def agg(self, *exprs): """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): - jdf = self._jdf.agg(exprs[0]) + jdf = self._jgd.agg(exprs[0]) else: # Columns assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" - jdf = self._jdf.agg(exprs[0]._jc, + jdf = self._jgd.agg(exprs[0]._jc, _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) return DataFrame(jdf, self.sql_ctx) @@ -187,12 +220,101 @@ def pivot(self, pivot_col, values=None): [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] """ if values is None: - jgd = self._jdf.pivot(pivot_col) + jgd = self._jgd.pivot(pivot_col) else: - jgd = self._jdf.pivot(pivot_col, values) + jgd = self._jgd.pivot(pivot_col, values) return GroupedData(jgd, self.sql_ctx) +class GroupedIterator(object): + """ TODO """ + + def __init__(self, inputs): + self.inputs = BufferedIterator(inputs) + self.current_input = next(inputs) + self.current_key = self.current_input[0] + self.current_values = GroupValuesIterator(self) + + def __iter__(self): + return self + + def __next__(self): + return self.next() + + def next(self): + if self.current_values is None: + self._fetch_next_group() + + ret = (self.current_key, self.current_values) + self.current_values = None + return ret + + def _fetch_next_group(self): + if self.current_input is None: + self.current_input = next(self.inputs) + + # Skip to next group, or consume all inputs and throw StopIteration exception. 
+ while self.current_input[0] == self.current_key: + self.current_input = next(self.inputs) + + self.current_key = self.current_input[0] + self.current_values = GroupValuesIterator(self) + + +class GroupValuesIterator(object): + """ TODO """ + + def __init__(self, outter): + self.outter = outter + + def __iter__(self): + return self + + def __next__(self): + return self.next() + + def next(self): + if self.outter.current_input is None: + self._fetch_next_value() + + value = self.outter.current_input[1] + self.outter.current_input = None + return value + + def _fetch_next_value(self): + if self.outter.inputs.head()[0] == self.outter.current_key: + self.outter.current_input = next(self.outter.inputs) + else: + raise StopIteration + + +class BufferedIterator(object): + """ TODO """ + + def __init__(self, iterator): + self.iterator = iterator + self.buffered = None + + def __iter__(self): + return self + + def __next__(self): + return self.next() + + def next(self): + if self.buffered is None: + return next(self.iterator) + else: + item = self.buffered + self.buffered = None + return item + + def head(self): + if self.buffered is None: + self.buffered = next(self.iterator) + return self.buffered + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e30aa0a79692..314683404b33 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -30,6 +30,13 @@ import time import datetime +if sys.version >= '3': + basestring = unicode = str + long = int + from functools import reduce +else: + from itertools import imap as map + import py4j try: import xmlrunner @@ -1153,6 +1160,97 @@ def test_functions_broadcast(self): # planner should not crash without a join broadcast(df1)._jdf.queryExecution().executedPlan() + def test_dataset(self): + data = [(i, str(i)) for i in range(100)] + ds = self.sqlCtx.createDataFrame(data, ("key", "value")) + + def check_result(result, f): + expected_result = [] + for k, v in data: + expected_result.append(f(k, v)) + self.assertEqual(result, expected_result) + + # convert row to python dict + ds2 = ds.map2(lambda row: {"key": row.key + 1, "value": row.value}) + schema = StructType().add("key", IntegerType()).add("value", StringType()) + ds3 = ds2.applySchema(schema) + result = ds3.select("key").collect() + check_result(result, lambda k, v: Row(key=k + 1)) + + # use a different but compatible schema + schema = StructType().add("value", StringType()) + result = ds2.applySchema(schema).collect() + check_result(result, lambda k, v: Row(value=v)) + + # use a flat schema + ds2 = ds.map2(lambda row: row.key * 3) + result = ds2.applySchema(IntegerType()).collect() + check_result(result, lambda k, v: Row(value=k * 3)) + + # schema can be inferred automatically + result = ds.map2(lambda row: row.key + 10).applySchema().collect() + check_result(result, lambda k, v: Row(value=k + 10)) + + # If no schema is given, by default it's a single binary field struct type. + from pyspark.sql.functions import length + result = ds2.select(length("value")).collect() + self.assertEqual(len(result), 100) + + # If no schema is given, collect will return custom objects instead of rows. + result = ds.map2(lambda row: row.value + "#").collect() + check_result(result, lambda k, v: v + "#") + + # row count should be corrected even no schema is specified. + self.assertEqual(ds.map2(lambda row: row.key + 1).count(), 100) + + # call cache() in the middle of 2 typed operations. 
+ ds3 = ds.map2(lambda row: row.key * 2).cache().map2(lambda key: key + 1) + self.assertEqual(ds3.count(), 100) + result = ds3.collect() + check_result(result, lambda k, v: k * 2 + 1) + + def test_typed_aggregate(self): + data = [(i, i * 2) for i in range(100)] + ds = self.sqlCtx.createDataFrame(data, ("i", "j")) + sum_tuple = lambda values: sum(map(lambda value: value[0] * value[1], values)) + + def get_python_result(data, key_func, agg_func): + data = sorted(data, key=key_func) + expected_result = [] + import itertools + for key, values in itertools.groupby(data, key_func): + expected_result.append(agg_func(key, values)) + return expected_result + + grouped = ds.groupByKey(lambda row: row.i % 5, IntegerType()) + agg_func = lambda key, values: str(key) + ":" + str(sum_tuple(values)) + result = sorted(grouped.mapGroups(agg_func).collect()) + expected_result = get_python_result(data, lambda i: i[0] % 5, agg_func) + self.assertEqual(result, expected_result) + + # We can also call groupByKey on a Dataset of custom objects. + ds2 = ds.map2(lambda row: row.i) + grouped = ds2.groupByKey(lambda i: i % 5, IntegerType()) + agg_func = lambda key, values: str(key) + ":" + str(sum(values)) + result = sorted(grouped.mapGroups(agg_func).collect()) + expected_result = get_python_result(range(100), lambda i: i % 5, agg_func) + self.assertEqual(result, expected_result) + + # We can also apply typed aggregate after structured groupBy, the key is row object. + grouped = ds.groupBy(ds.i % 2, ds.i % 3) + agg_func = lambda key, values: str(key[0]) + str(key[1]) + ":" + str(sum_tuple(values)) + result = sorted(grouped.mapGroups(agg_func).collect()) + expected_result = get_python_result(data, lambda i: (i[0] % 2, i[0] % 3), agg_func) + self.assertEqual(result, expected_result) + + # We can also apply structured aggregate after groupByKey + grouped = ds.groupByKey(lambda row: row.i % 5, IntegerType()) + result = sorted(grouped.sum("j").collect()) + get_sum = lambda key: sum(filter(lambda i: i % 5 == key, range(100))) * 2 + result_row = Row("key", "sum(j)") + expected_result = [result_row(i, get_sum(i)) for i in range(5)] + self.assertEqual(result, expected_result) + class HiveContextSQLTests(ReusedPySparkTestCase): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index f1fa13daa77e..83389ae2ef7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -88,8 +88,6 @@ package object expressions { */ implicit class AttributeSeq(attrs: Seq[Attribute]) { /** Creates a StructType with a schema matching this `Seq[Attribute]`. 
*/ - def toStructType: StructType = { - StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable))) - } + def toStructType: StructType = StructType.fromAttributes(attrs) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 3f97662957b8..83cc20f5640d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.api.python.PythonFunction import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{ObjectType, StructType} +import org.apache.spark.sql.types.{DataType, ObjectType, StructType} /** * A trait for logical operators that apply user defined functions to domain objects. @@ -91,6 +92,13 @@ case class MapPartitions( override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output) } +case class PythonMapPartitions( + func: PythonFunction, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def expressions: Seq[Expression] = Nil +} + /** Factory for constructing new `AppendColumn` nodes. */ object AppendColumns { def apply[T : Encoder, U : Encoder]( @@ -124,6 +132,17 @@ case class AppendColumns( override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output) } +case class PythonAppendColumns( + func: PythonFunction, + newColumns: Seq[Attribute], + isFlat: Boolean, + child: LogicalPlan) extends UnaryNode { + + override def output: Seq[Attribute] = child.output ++ newColumns + + override def expressions: Seq[Expression] = Nil +} + /** Factory for constructing new `MapGroups` nodes. */ object MapGroups { def apply[K : Encoder, T : Encoder, U : Encoder]( @@ -164,6 +183,15 @@ case class MapGroups( Seq(keyDeserializer -> groupingAttributes, valueDeserializer -> dataAttributes) } +case class PythonMapGroups( + func: PythonFunction, + groupingExprs: Seq[Expression], + dataAttributes: Seq[Attribute], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def expressions: Seq[Expression] = groupingExprs +} + /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { def apply[Key : Encoder, Left : Encoder, Right : Encoder, Result : Encoder]( @@ -208,8 +236,6 @@ case class CoGroup( left: LogicalPlan, right: LogicalPlan) extends BinaryNode with ObjectOperator { - override def producedAttributes: AttributeSet = outputSet - override def deserializers: Seq[(Expression, Seq[Attribute])] = // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to resolve // the `keyDeserializer` based on either of them, here we pick the left one. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 76c09a285dc4..154a60b0bf97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -27,7 +27,7 @@ import com.fasterxml.jackson.core.JsonFactory import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.python.PythonRDD +import org.apache.spark.api.python.{PythonFunction, PythonRDD} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ @@ -1750,9 +1750,13 @@ class DataFrame private[sql]( * Converts a JavaRDD to a PythonRDD. */ protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val structType = schema // capture it for closure - val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)) - EvaluatePython.javaToPython(rdd) + if (EvaluatePython.isPickled(schema)) { + queryExecution.toRdd.map(_.getBinary(0)) + } else { + val structType = schema // capture it for closure + val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)) + EvaluatePython.javaToPython(rdd) + } } protected[sql] def collectToPython(): Int = { @@ -1761,6 +1765,49 @@ class DataFrame private[sql]( } } + protected[sql] def pythonMapPartitions(func: PythonFunction): DataFrame = withPlan { + PythonMapPartitions(func, EvaluatePython.schemaOfPickled.toAttributes, logicalPlan) + } + + protected[sql] def pythonMapPartitions( + func: PythonFunction, + schemaJson: String): DataFrame = withPlan { + val schema = DataType.fromJson(schemaJson).asInstanceOf[StructType] + PythonMapPartitions(func, schema.toAttributes, logicalPlan) + } + + protected[sql] def pythonGroupBy( + func: PythonFunction, + keyTypeJson: String): GroupedPythonDataset = { + val keyType = DataType.fromJson(keyTypeJson) + val isFlat = !keyType.isInstanceOf[StructType] + val keyAttributes = if (isFlat) { + Seq(AttributeReference("key", keyType)()) + } else { + keyType.asInstanceOf[StructType].toAttributes + } + + val inputPlan = queryExecution.analyzed + val withGroupingKey = PythonAppendColumns(func, keyAttributes, isFlat, inputPlan) + val executed = sqlContext.executePlan(withGroupingKey) + + new GroupedPythonDataset( + executed, + withGroupingKey.newColumns, + inputPlan.output, + GroupedData.GroupByType) + } + + protected[sql] def pythonGroupBy(cols: Column*): GroupedPythonDataset = { + new GroupedPythonDataset( + queryExecution, + cols.map(_.expr), + queryExecution.analyzed.output, + GroupedData.GroupByType) + } + + protected[sql] def isPickled(): Boolean = EvaluatePython.isPickled(schema) + /** * Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with * an execution. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedPythonDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedPythonDataset.scala new file mode 100644 index 000000000000..adc47e14b5cf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedPythonDataset.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.api.python.PythonFunction +import org.apache.spark.sql.GroupedData.GroupType +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.plans.logical.PythonMapGroups +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.python.EvaluatePython +import org.apache.spark.sql.types.{DataType, StructType} + +class GroupedPythonDataset private[sql]( + queryExecution: QueryExecution, + groupingExprs: Seq[Expression], + dataAttributes: Seq[Attribute], + groupType: GroupType) { + + private def sqlContext = queryExecution.sqlContext + + protected[sql] def isPickled(): Boolean = EvaluatePython.isPickled(dataAttributes.toStructType) + + private def groupedData = + new GroupedData( + new DataFrame(sqlContext, queryExecution), groupingExprs, GroupedData.GroupByType) + + @scala.annotation.varargs + def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { + groupedData.agg(aggExpr, aggExprs: _*) + } + + def agg(exprs: Map[String, String]): DataFrame = groupedData.agg(exprs) + + def agg(exprs: java.util.Map[String, String]): DataFrame = groupedData.agg(exprs) + + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame = groupedData.agg(expr, exprs: _*) + + def count(): DataFrame = groupedData.count() + + @scala.annotation.varargs + def mean(colNames: String*): DataFrame = groupedData.mean(colNames: _*) + + @scala.annotation.varargs + def max(colNames: String*): DataFrame = groupedData.max(colNames: _*) + + @scala.annotation.varargs + def avg(colNames: String*): DataFrame = groupedData.avg(colNames: _*) + + @scala.annotation.varargs + def min(colNames: String*): DataFrame = groupedData.min(colNames: _*) + + @scala.annotation.varargs + def sum(colNames: String*): DataFrame = groupedData.sum(colNames: _*) + + def pivot(pivotColumn: String): GroupedData = groupedData.pivot(pivotColumn) + + def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = + groupedData.pivot(pivotColumn, values) + + def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = + groupedData.pivot(pivotColumn, values) + + def flatMapGroups(f: PythonFunction, schemaJson: String): DataFrame = { + val schema = DataType.fromJson(schemaJson).asInstanceOf[StructType] + internalFlatMapGroups(f, schema) + } + + def flatMapGroups(f: PythonFunction): DataFrame = { + internalFlatMapGroups(f, EvaluatePython.schemaOfPickled) + } + + private def internalFlatMapGroups(f: PythonFunction, schema: StructType): DataFrame = { + new DataFrame( + sqlContext, + PythonMapGroups( + f, + groupingExprs, + dataAttributes, + schema.toAttributes, + queryExecution.logical)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 73fd22b38e1d..470dcaa0f6a3 
100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -371,6 +371,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil case e @ python.EvaluatePython(udf, child, _) => python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil + case logical.PythonMapPartitions(func, output, child) => + execution.PythonMapPartitions(func, output, planLater(child)) :: Nil + case logical.PythonAppendColumns(func, newColumns, isFlat, child) => + execution.PythonAppendColumns(func, newColumns, isFlat, planLater(child)) :: Nil + case logical.PythonMapGroups(func, grouping, data, output, child) => + execution.PythonMapGroups(func, grouping, data, output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil case BroadcastHint(child) => planLater(child) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 582dda8603f4..509140ca2648 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -17,12 +17,19 @@ package org.apache.spark.sql.execution +import scala.collection.JavaConverters._ + +import net.razorvine.pickle.{Pickler, Unpickler} + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{PythonFunction, PythonRunner} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.types.ObjectType +import org.apache.spark.sql.execution.python.EvaluatePython +import org.apache.spark.sql.types.{ObjectType, StructField, StructType} /** * Helper functions for physical operators that work with user defined objects. @@ -67,6 +74,73 @@ case class MapPartitions( } } +case class PythonMapPartitions( + func: PythonFunction, + output: Seq[Attribute], + child: SparkPlan) extends UnaryNode { + + override def expressions: Seq[Expression] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val isChildPickled = EvaluatePython.isPickled(child.schema) + val isOutputPickled = EvaluatePython.isPickled(schema) + + inputRDD.mapPartitions { iter => + val inputIterator = if (isChildPickled) { + iter.map(_.getBinary(0)) + } else { + EvaluatePython.registerPicklers() // register pickler for Row + val pickle = new Pickler + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + iter.grouped(100).map { inputRows => + val toBePickled = inputRows.map { row => + EvaluatePython.toJava(row, child.schema) + }.toArray + pickle.dumps(toBePickled) + } + } + + val context = TaskContext.get() + + // Output iterator for results from Python. 
+ val outputIterator = + new PythonRunner( + func.command, + func.envVars, + func.pythonIncludes, + func.pythonExec, + func.pythonVer, + func.broadcastVars, + func.accumulator, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + val toUnsafe = UnsafeProjection.create(output, output) + + if (isOutputPickled) { + val row = new GenericMutableRow(1) + outputIterator.map { bytes => + row(0) = bytes + toUnsafe(row) + } + } else { + val unpickle = new Unpickler + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + toUnsafe(EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow]) + } + } + } + } +} + /** * Applies the given function to each input row, appending the encoded result at the end of the row. */ @@ -96,6 +170,92 @@ case class AppendColumns( } } +case class PythonAppendColumns( + func: PythonFunction, + newColumns: Seq[Attribute], + isFlat: Boolean, + child: SparkPlan) extends UnaryNode { + + override def output: Seq[Attribute] = child.output ++ newColumns + + override def expressions: Seq[Expression] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val newColumnSchema = newColumns.toStructType + val isChildPickled = EvaluatePython.isPickled(child.schema) + + inputRDD.mapPartitionsInternal { iter => + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = new java.util.LinkedList[InternalRow]() + + val inputIterator = if (isChildPickled) { + iter.map { row => + queue.add(row) + row.getBinary(0) + } + } else { + EvaluatePython.registerPicklers() // register pickler for Row + val pickle = new Pickler + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + // For each row, add it to the queue. + iter.grouped(100).map { inputRows => + val toBePickled = inputRows.map { row => + queue.add(row) + EvaluatePython.toJava(row, child.schema) + }.toArray + pickle.dumps(toBePickled) + } + } + + val context = TaskContext.get() + + // Output iterator for results from Python. 
+ val outputIterator = + new PythonRunner( + func.command, + func.envVars, + func.pythonIncludes, + func.pythonExec, + func.pythonVer, + func.broadcastVars, + func.accumulator, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + val unpickle = new Unpickler + val toUnsafe = UnsafeProjection.create(newColumns, newColumns) + val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema) + + val newData = outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + } + + val newRows = if (isFlat) { + val row = new GenericMutableRow(1) + newData.map { key => + row(0) = EvaluatePython.fromJava(key, newColumns.head.dataType) + toUnsafe(row) + } + } else { + newData.map { key => + toUnsafe(EvaluatePython.fromJava(key, newColumnSchema).asInstanceOf[InternalRow]) + } + } + + newRows.map { newRow => + combiner.join(queue.poll().asInstanceOf[UnsafeRow], newRow) + } + } + } +} + /** * Groups the input rows together and calls the function with each group and an iterator containing * all elements in the group. The result of this function is encoded and flattened before @@ -136,6 +296,89 @@ case class MapGroups( } } +case class PythonMapGroups( + func: PythonFunction, + groupingExprs: Seq[Expression], + dataAttributes: Seq[Attribute], + output: Seq[Attribute], + child: SparkPlan) extends UnaryNode { + + override def expressions: Seq[Expression] = groupingExprs + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(groupingExprs) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingExprs.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + + val keySchema = StructType(groupingExprs.map(_.dataType).map(dt => StructField("k", dt))) + val valueSchema = dataAttributes.toStructType + val isValuePickled = EvaluatePython.isPickled(valueSchema) + val isOutputPickled = EvaluatePython.isPickled(schema) + + inputRDD.mapPartitionsInternal { iter => + EvaluatePython.registerPicklers() // register pickler for Row + val pickle = new Pickler + + val getKey = UnsafeProjection.create(groupingExprs, child.output) + val getValue: InternalRow => InternalRow = if (dataAttributes == child.output) { + identity + } else { + UnsafeProjection.create(dataAttributes, child.output) + } + + val inputIterator = iter.map { input => + val keyBytes = pickle.dumps(EvaluatePython.toJava(getKey(input), keySchema)) + val valueBytes = if (isValuePickled) { + input.getBinary(0) + } else { + pickle.dumps(EvaluatePython.toJava(getValue(input), valueSchema)) + } + keyBytes -> valueBytes + } + + val context = TaskContext.get() + + // Output iterator for results from Python. 
+ val outputIterator = + new PythonRunner( + func.command, + func.envVars, + func.pythonIncludes, + func.pythonExec, + func.pythonVer, + func.broadcastVars, + func.accumulator, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + val toUnsafe = UnsafeProjection.create(output, output) + + if (isOutputPickled) { + val row = new GenericMutableRow(1) + outputIterator.map { bytes => + row(0) = bytes + toUnsafe(row) + } + } else { + val unpickle = new Unpickler + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + toUnsafe(EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow]) + } + } + } + } +} + /** * Co-groups the data from left and right children, and calls the function with each group and 2 * iterators containing all elements in the group from left and right side. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 8c46516594a2..78aa0cc33fd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -248,6 +248,17 @@ object EvaluatePython { } } + val schemaOfPickled = { + val metaPickled = new MetadataBuilder().putBoolean("pickled", true).build() + new StructType().add("value", BinaryType, nullable = false, metadata = metaPickled) + } + + def isPickled(schema: StructType): Boolean = schema.length == 1 && { + val field = schema.head + field.dataType == BinaryType && + field.metadata.contains("pickled") && field.metadata.getBoolean("pickled") + } + /** * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by * PySpark.
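
Not part of the patch itself — an illustrative sketch of the typed-map API added to dataframe.py above, based on `test_dataset` in this patch; `sqlContext` and the sample data are placeholders:

```python
from pyspark.sql.types import IntegerType, StringType, StructType

# assumes an existing `sqlContext`
df = sqlContext.createDataFrame([(i, str(i)) for i in range(100)], ("key", "value"))

# map2 passes Row objects to the lambda, which may return arbitrary Python objects.
ds = df.map2(lambda row: {"key": row.key + 1, "value": row.value})

# applySchema(schema) turns the typed result back into a relational DataFrame ...
schema = StructType().add("key", IntegerType()).add("value", StringType())
ds.applySchema(schema).select("key").collect()

# ... or, called with no argument, infers the schema by scanning the whole dataset.
df.map2(lambda row: row.key + 10).applySchema().collect()

# Without applySchema(), collect() returns the custom objects themselves, not Rows.
df.map2(lambda row: row.value + "#").collect()
```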
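
Similarly, a sketch of the typed grouping APIs (`groupByKey`, `mapGroups`, `flatMapGroups`), based on `test_typed_aggregate`; again `sqlContext` and the data are placeholders:

```python
from pyspark.sql.types import IntegerType

df = sqlContext.createDataFrame([(i, i * 2) for i in range(100)], ("i", "j"))

# groupByKey takes a key function plus the key's data type, which becomes the
# grouping column's type on the JVM side.
grouped = df.groupByKey(lambda row: row.i % 5, IntegerType())

# mapGroups / flatMapGroups receive a (key, iterator of values) pair per group.
grouped.mapGroups(lambda key, rows: str(key) + ":" + str(sum(r[1] for r in rows))).collect()

# Typed and relational operations compose: a structured aggregate after groupByKey ...
grouped.sum("j").collect()

# ... or a typed aggregate after a structured groupBy, where the key is tuple-like.
df.groupBy(df.i % 2, df.i % 3) \
    .mapGroups(lambda key, rows: (key[0], key[1], sum(r[0] * r[1] for r in rows))) \
    .collect()
```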
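
The grouping helpers added to group.py (`GroupedIterator`, `GroupValuesIterator`, `BufferedIterator`) stream over input that `PythonMapGroups` has already clustered and sorted by key, so each group is produced in a single pass without buffering all of its values. A rough, simplified equivalent of what `flatMapGroups` does with that stream (the names below are hypothetical, not from the patch):

```python
import itertools

def flat_map_groups(sorted_pairs, func):
    """Apply func(key, values) to each run of equal keys in a key-sorted stream."""
    for key, pairs in itertools.groupby(sorted_pairs, key=lambda kv: kv[0]):
        for output in func(key, (v for _, v in pairs)):
            yield output

# list(flat_map_groups([(0, 1), (0, 2), (1, 3)], lambda k, vs: [(k, sum(vs))]))
# returns [(0, 3), (1, 3)]
```

As with `itertools.groupby`, the per-group values iterator is only valid until the next group is requested, which is why `GroupedIterator` drops `current_values` as soon as it advances to the next key.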