@@ -71,6 +71,14 @@ private[spark] class PythonRDD(
}
}

private[spark] case class PythonFunction(
command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]])

/**
* A helper class to run Python UDFs in Spark.
9 changes: 9 additions & 0 deletions python/pyspark/rdd.py
@@ -2330,6 +2330,15 @@ def _prepare_for_python_RDD(sc, command, obj=None):
return pickled_command, broadcast_vars, env, includes


def _wrap_function(sc, func, deserializer, serializer, profiler=None):
assert deserializer, "deserializer should not be empty"
assert serializer, "serializer should not be empty"
command = (func, profiler, deserializer, serializer)
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
sc.pythonVer, broadcast_vars, sc._javaAccumulator)
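
For orientation (this example is not part of the change): given an active SparkContext ``sc``, the new helper wraps a per-partition Python function, together with its serializers, into a JVM ``PythonFunction``. The function and serializer choices below are illustrative only.

from pyspark.serializers import AutoBatchedSerializer, PickleSerializer

# Hypothetical per-partition function: double every element.
double_all = lambda split, iterator: (x * 2 for x in iterator)

wrapped = _wrap_function(sc, double_all,
                         AutoBatchedSerializer(PickleSerializer()),  # deserializer for input
                         AutoBatchedSerializer(PickleSerializer()))  # serializer for output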


class PipelinedRDD(RDD):

"""
134 changes: 131 additions & 3 deletions python/pyspark/sql/dataframe.py
@@ -28,7 +28,8 @@

from pyspark import since
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, PickleSerializer, \
UTF8Deserializer, PairDeserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import _parse_datatype_json_string, StructType
@@ -236,9 +237,14 @@ def collect(self):
>>> df.collect()
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""

if self._jdf.isPickled():
deserializer = PickleSerializer()
else:
deserializer = BatchedSerializer(PickleSerializer())
with SCCallSiteSync(self._sc) as css:
port = self._jdf.collectToPython()
return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
return list(_load_from_socket(port, deserializer))

@ignore_unicode_prefix
@since(1.3)
@@ -278,6 +284,25 @@ def map(self, f):
"""
return self.rdd.map(f)

@ignore_unicode_prefix
@since(2.0)
def applySchema(self, schema=None):
""" TODO """
# TODO: should we throw exception instead?
return self

@ignore_unicode_prefix
@since(2.0)
def mapPartitions2(self, func):
""" TODO """
return PipelinedDataFrame(self, func)

@ignore_unicode_prefix
@since(2.0)
def map2(self, func):
""" TODO """
return self.mapPartitions2(lambda iterator: map(func, iterator))
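
A hedged usage sketch (not part of the diff): assuming ``df`` is the two-row doctest DataFrame of (age, name) used elsewhere in this file, the new methods could be chained like this, with ``applySchema`` turning the untyped pipelined result back into a regular DataFrame.

from pyspark.sql.types import IntegerType

ages = df.map2(lambda row: row.age + 1)       # PipelinedDataFrame of plain ints
ages.applySchema(IntegerType()).collect()     # e.g. [Row(value=3), Row(value=6)]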

@ignore_unicode_prefix
@since(1.3)
def flatMap(self, f):
@@ -890,10 +915,20 @@ def groupBy(self, *cols):
>>> sorted(df.groupBy(['name', df.age]).count().collect())
[Row(name=u'Alice', age=2, count=1), Row(name=u'Bob', age=5, count=1)]
"""
jgd = self._jdf.groupBy(self._jcols(*cols))
jgd = self._jdf.pythonGroupBy(self._jcols(*cols))
from pyspark.sql.group import GroupedData
return GroupedData(jgd, self.sql_ctx)

@ignore_unicode_prefix
@since(2.0)
def groupByKey(self, key_func, key_type):
""" TODO """
f = lambda iterator: map(key_func, iterator)
wrapped_func = _wrap_func(self._sc, self._jdf, f, False)
jgd = self._jdf.pythonGroupBy(wrapped_func, key_type.json())
from pyspark.sql.group import GroupedData
return GroupedData(jgd, self.sql_ctx, not isinstance(key_type, StructType))
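
Another hedged sketch (assumed, not from the diff): ``groupByKey`` is meant to compose with the ``mapGroups`` method added to ``pyspark/sql/group.py`` below, e.g. to count rows per name in the doctest DataFrame.

from pyspark.sql.types import StringType

grouped = df.groupByKey(lambda row: row.name, StringType())
counts = grouped.mapGroups(lambda name, rows: (name, sum(1 for _ in rows)))
counts.collect()   # pickled records such as (u'Alice', 1) and (u'Bob', 1)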

@since(1.4)
def rollup(self, *cols):
"""
@@ -1354,6 +1389,99 @@ def toPandas(self):
drop_duplicates = dropDuplicates


class PipelinedDataFrame(DataFrame):

""" TODO """

def __init__(self, prev, func):
from pyspark.sql.group import GroupedData

self._jdf_val = None
self.is_cached = False
self.sql_ctx = prev.sql_ctx
self._sc = self.sql_ctx and self.sql_ctx._sc
self._lazy_rdd = None

if isinstance(prev, GroupedData):
# prev is GroupedData, set the grouped flag to true and use jgd as jdf.
self._grouped = True
self._func = func
self._prev_jdf = prev._jgd
elif not isinstance(prev, PipelinedDataFrame) or prev.is_cached:
# This transformation is the first in its stage:
self._grouped = False
self._func = func
self._prev_jdf = prev._jdf
else:
self._grouped = prev._grouped
self._func = _pipeline_func(prev._func, func)
# maintain the pipeline.
self._prev_jdf = prev._prev_jdf

def applySchema(self, schema=None):
if schema is None:
from pyspark.sql.types import _infer_type, _merge_type
# If no schema is specified, infer it from the whole data set.
jrdd = self._prev_jdf.javaToPython()
rdd = RDD(jrdd, self._sc, BatchedSerializer(PickleSerializer()))
schema = rdd.mapPartitions(self._func).map(_infer_type).reduce(_merge_type)

if isinstance(schema, StructType):
to_rows = lambda iterator: map(schema.toInternal, iterator)
else:
data_type = schema
schema = StructType().add("value", data_type)
to_row = lambda obj: (data_type.toInternal(obj), )
to_rows = lambda iterator: map(to_row, iterator)

jdf = self._create_jdf(_pipeline_func(self._func, to_rows), schema)
return DataFrame(jdf, self.sql_ctx)

@property
def _jdf(self):
if self._jdf_val is None:
self._jdf_val = self._create_jdf(self._func)
return self._jdf_val

def _create_jdf(self, func, schema=None):
wrapped_func = _wrap_func(self._sc, self._prev_jdf, func, schema is None, self._grouped)
if schema is None:
if self._grouped:
return self._prev_jdf.flatMapGroups(wrapped_func)
else:
return self._prev_jdf.pythonMapPartitions(wrapped_func)
else:
schema_string = schema.json()
if self._grouped:
return self._prev_jdf.flatMapGroups(wrapped_func, schema_string)
else:
return self._prev_jdf.pythonMapPartitions(wrapped_func, schema_string)


def _wrap_func(sc, jdf, func, output_binary, input_grouped=False):
if input_grouped:
deserializer = PairDeserializer(PickleSerializer(), PickleSerializer())
elif jdf.isPickled():
deserializer = PickleSerializer()
else:
deserializer = AutoBatchedSerializer(PickleSerializer())

if output_binary:
serializer = PickleSerializer()
else:
serializer = AutoBatchedSerializer(PickleSerializer())

from pyspark.rdd import _wrap_function
return _wrap_function(sc, lambda _, iterator: func(iterator), deserializer, serializer)


def _pipeline_func(prev_func, next_func):
if prev_func is None:
return next_func
else:
return lambda iterator: next_func(prev_func(iterator))
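
A minimal, Spark-independent illustration of the pipelining helper: the previous function runs first and its output feeds the next one.

inc = lambda iterator: (x + 1 for x in iterator)
dbl = lambda iterator: (x * 2 for x in iterator)
piped = _pipeline_func(inc, dbl)
list(piped(iter([1, 2, 3])))   # [4, 6, 8]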


def _to_scala_map(sc, jm):
"""
Convert a dict into a JVM Map.
140 changes: 131 additions & 9 deletions python/pyspark/sql/group.py
@@ -15,10 +15,19 @@
# limitations under the License.
#

import sys

if sys.version >= '3':
basestring = unicode = str
long = int
from functools import reduce
else:
from itertools import imap as map

from pyspark import since
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.dataframe import DataFrame, PipelinedDataFrame
from pyspark.sql.types import *

__all__ = ["GroupedData"]
@@ -27,7 +36,7 @@
def dfapi(f):
def _api(self):
name = f.__name__
jdf = getattr(self._jdf, name)()
jdf = getattr(self._jgd, name)()
return DataFrame(jdf, self.sql_ctx)
_api.__name__ = f.__name__
_api.__doc__ = f.__doc__
@@ -37,7 +46,7 @@ def _api(self):
def df_varargs_api(f):
def _api(self, *args):
name = f.__name__
jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args))
jdf = getattr(self._jgd, name)(_to_seq(self.sql_ctx._sc, args))
return DataFrame(jdf, self.sql_ctx)
_api.__name__ = f.__name__
_api.__doc__ = f.__doc__
@@ -54,9 +63,33 @@ class GroupedData(object):
.. versionadded:: 1.3
"""

def __init__(self, jdf, sql_ctx):
self._jdf = jdf
def __init__(self, jgd, sql_ctx, flat_key=False):
self._jgd = jgd
self.sql_ctx = sql_ctx
if flat_key:
self._key_converter = lambda key: key[0]
else:
self._key_converter = lambda key: key

@ignore_unicode_prefix
@since(2.0)
def flatMapGroups(self, func):
""" TODO """
key_converter = self._key_converter

def process(inputs):
record_converter = lambda record: (key_converter(record[0]), record[1])
for key, values in GroupedIterator(map(record_converter, inputs)):
for output in func(key, values):
yield output

return PipelinedDataFrame(self, process)

@ignore_unicode_prefix
@since(2.0)
def mapGroups(self, func):
""" TODO """
return self.flatMapGroups(lambda key, values: iter([func(key, values)]))
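
An assumed usage sketch (not part of the diff), showing ``flatMapGroups`` emitting a variable number of records per group; ``df`` is taken to be a DataFrame grouped with the ``groupByKey`` method added in ``dataframe.py``.

from pyspark.sql.types import StringType

grouped = df.groupByKey(lambda row: row.name, StringType())
# Re-emit every row's age, tagged with its group key.
tagged = grouped.flatMapGroups(lambda name, rows: ((name, r.age) for r in rows))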

@ignore_unicode_prefix
@since(1.3)
@@ -83,11 +116,11 @@ def agg(self, *exprs):
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
jdf = self._jdf.agg(exprs[0])
jdf = self._jgd.agg(exprs[0])
else:
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
jdf = self._jdf.agg(exprs[0]._jc,
jdf = self._jgd.agg(exprs[0]._jc,
_to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
return DataFrame(jdf, self.sql_ctx)

@@ -187,12 +220,101 @@ def pivot(self, pivot_col, values=None):
[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
"""
if values is None:
jgd = self._jdf.pivot(pivot_col)
jgd = self._jgd.pivot(pivot_col)
else:
jgd = self._jdf.pivot(pivot_col, values)
jgd = self._jgd.pivot(pivot_col, values)
return GroupedData(jgd, self.sql_ctx)


class GroupedIterator(object):
""" TODO """

def __init__(self, inputs):
self.inputs = BufferedIterator(inputs)
self.current_input = next(self.inputs)
self.current_key = self.current_input[0]
self.current_values = GroupValuesIterator(self)

def __iter__(self):
return self

def __next__(self):
return self.next()

def next(self):
if self.current_values is None:
self._fetch_next_group()

ret = (self.current_key, self.current_values)
self.current_values = None
return ret

def _fetch_next_group(self):
if self.current_input is None:
self.current_input = next(self.inputs)

# Skip to next group, or consume all inputs and throw StopIteration exception.
while self.current_input[0] == self.current_key:
self.current_input = next(self.inputs)

self.current_key = self.current_input[0]
self.current_values = GroupValuesIterator(self)


class GroupValuesIterator(object):
"""Iterates over the values belonging to the current group of an outer :class:`GroupedIterator`."""

def __init__(self, outer):
self.outer = outer

def __iter__(self):
return self

def __next__(self):
return self.next()

def next(self):
if self.outer.current_input is None:
self._fetch_next_value()

value = self.outer.current_input[1]
self.outer.current_input = None
return value

def _fetch_next_value(self):
if self.outer.inputs.head()[0] == self.outer.current_key:
self.outer.current_input = next(self.outer.inputs)
else:
raise StopIteration


class BufferedIterator(object):
""" TODO """

def __init__(self, iterator):
self.iterator = iterator
self.buffered = None

def __iter__(self):
return self

def __next__(self):
return self.next()

def next(self):
if self.buffered is None:
return next(self.iterator)
else:
item = self.buffered
self.buffered = None
return item

def head(self):
if self.buffered is None:
self.buffered = next(self.iterator)
return self.buffered
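
A small standalone illustration (assumed, not from the diff) of how these iterator helpers behave on a stream of records that is already clustered by key:

pairs = iter([("a", 1), ("a", 2), ("b", 3)])
for key, values in GroupedIterator(pairs):
    print(key, list(values))
# a [1, 2]
# b [3]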


def _test():
import doctest
from pyspark.context import SparkContext