diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index cb15b4b91f91..7c98959e4075 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -32,7 +32,7 @@ from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
-    PairDeserializer, AutoBatchedSerializer, NoOpSerializer
+    PairDeserializer, AutoBatchedSerializer, NoOpSerializer, CloudPickleSerializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.traceback_utils import CallSite, first_spark_call
@@ -73,7 +73,8 @@ class SparkContext(object):
 
     def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
                  environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
-                 gateway=None, jsc=None, profiler_cls=BasicProfiler):
+                 gateway=None, jsc=None, profiler_cls=BasicProfiler,
+                 function_serializer=CloudPickleSerializer()):
         """
         Create a new SparkContext. At least the master and app name should be set,
         either through the named parameters here or through C{conf}.
@@ -98,6 +99,8 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         :param jsc: The JavaSparkContext instance (optional).
         :param profiler_cls: A class of custom Profiler used to do profiling
                (default is pyspark.profiler.BasicProfiler).
+        :param function_serializer: The serializer for functions used in RDD
+               transformations.
 
         >>> from pyspark.context import SparkContext
@@ -112,14 +115,14 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         SparkContext._ensure_initialized(self, gateway=gateway)
         try:
             self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                          conf, jsc, profiler_cls)
+                          conf, jsc, profiler_cls, function_serializer)
         except:
             # If an error occurs, clean up in order to allow future SparkContext creation:
             self.stop()
             raise
 
     def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                 conf, jsc, profiler_cls):
+                 conf, jsc, profiler_cls, function_serializer):
         self.environment = environment or {}
         self._conf = conf or SparkConf(_jvm=self._jvm)
         self._batchSize = batchSize  # -1 represents an unlimited batch size
@@ -129,6 +132,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
         else:
             self.serializer = BatchedSerializer(self._unbatched_serializer,
                                                 batchSize)
+        self._function_serializer = function_serializer
 
         # Set any parameters passed directly to us on the conf
         if master:
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 8978f028c592..937d3a530336 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -40,8 +40,8 @@
     from itertools import imap as map, ifilter as filter
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
-    BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
-    PickleSerializer, pack_long, AutoBatchedSerializer
+    BatchedSerializer, PairDeserializer, PickleSerializer, pack_long, \
+    AutoBatchedSerializer
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_full_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
@@ -2311,12 +2311,11 @@ def toLocalIterator(self):
 
 def _prepare_for_python_RDD(sc, command):
     # the serialized command will be compressed by broadcast
-    ser = CloudPickleSerializer()
-    pickled_command = ser.dumps(command)
+    pickled_command = sc._function_serializer.dumps(command)
     if len(pickled_command) > (1 << 20):  # 1M
         # The broadcast will have same life cycle as created PythonRDD
         broadcast = sc.broadcast(pickled_command)
-        pickled_command = ser.dumps(broadcast)
+        pickled_command = sc._function_serializer.dumps(broadcast)
     # There is a bug in py4j.java_gateway.JavaClass with auto_convert
     # https://github.com/bartdag/py4j/issues/161
     # TODO: use auto_convert once py4j fix the bug
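
A minimal usage sketch of the new function_serializer parameter, assuming the patch above is applied. SizeLoggingSerializer is a hypothetical example class (not part of PySpark) that wraps the default CloudPickleSerializer and reports the size of each pickled closure.

    from pyspark import SparkContext
    from pyspark.serializers import CloudPickleSerializer


    class SizeLoggingSerializer(CloudPickleSerializer):
        """Hypothetical serializer: delegates to CloudPickleSerializer and logs sizes."""

        def dumps(self, obj):
            data = CloudPickleSerializer.dumps(self, obj)
            # Runs on the driver whenever a command (closure) is serialized.
            print("pickled %d bytes" % len(data))
            return data


    # With the patch, functions shipped to executors (e.g. the lambda below)
    # are pickled by the custom serializer instead of CloudPickleSerializer.
    sc = SparkContext("local", "function-serializer-demo",
                      function_serializer=SizeLoggingSerializer())
    print(sc.parallelize(range(10)).map(lambda x: x * x).collect())
    sc.stop()

Because _prepare_for_python_RDD now reads sc._function_serializer, any serializer passed here sees every command before it is (optionally) broadcast.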