diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 21f182b0ff13..2030b8080491 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2061,8 +2061,12 @@ def _jrdd(self):
             self._jrdd_deserializer = NoOpSerializer()
         command = (self.func, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
+        # the serialized command will be compressed by broadcast
         ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = self.ctx.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index fc9310fef318..96ea1309ecda 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -27,7 +27,7 @@
 from array import array
 from operator import itemgetter
 
-from pyspark.rdd import RDD, PipelinedRDD
+from pyspark.rdd import RDD
 from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
 from pyspark.storagelevel import StorageLevel
 
@@ -974,7 +974,11 @@ def registerFunction(self, name, f, returnType=StringType()):
         command = (func,
                    BatchedSerializer(PickleSerializer(), 1024),
                    BatchedSerializer(PickleSerializer(), 1024))
-        pickled_command = CloudPickleSerializer().dumps(command)
+        ser = CloudPickleSerializer()
+        pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = self._sc.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
             self._sc._gateway._gateway_client)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index f255b44359fe..82d183b545f3 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -434,6 +434,12 @@ def test_large_broadcast(self):
         m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
         self.assertEquals(N, m)
 
+    def test_large_closure(self):
+        N = 1000000
+        data = [float(i) for i in xrange(N)]
+        m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum()
+        self.assertEquals(N, m)
+
     def test_zip_with_different_serializers(self):
         a = self.sc.parallelize(range(5))
         b = self.sc.parallelize(range(100, 105))
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 252176ac65fe..d6c06e2dbef6 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -77,10 +77,12 @@ def main(infile, outfile):
                 _broadcastRegistry[bid] = Broadcast(bid, value)
             else:
                 bid = - bid - 1
-                _broadcastRegistry.remove(bid)
+                _broadcastRegistry.pop(bid)
 
         _accumulatorRegistry.clear()
         command = pickleSer._read_with_length(infile)
+        if isinstance(command, Broadcast):
+            command = pickleSer.loads(command.value)
         (func, deserializer, serializer) = command
         init_time = time.time()
         iterator = deserializer.load_stream(infile)
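
For reviewers, the end-to-end pattern this patch implements can be sketched outside of Spark: if the pickled command exceeds 1 MB, the driver wraps it in a broadcast variable and ships only the small pickled handle, and the worker unwraps it after deserializing. The snippet below is a minimal standalone model under stated assumptions, not PySpark itself: `FakeBroadcast`, `THRESHOLD`, `encode_command`, and `decode_command` are hypothetical names invented for illustration; the real code paths are `SparkContext.broadcast` plus `CloudPickleSerializer` on the driver (rdd.py, sql.py) and the `isinstance(command, Broadcast)` check in worker.py.

```python
import pickle

# Hypothetical names for illustration only; none of these exist in PySpark.
THRESHOLD = 1 << 20  # 1 MiB, the same cutoff used in the patch


class FakeBroadcast(object):
    """Toy stand-in for pyspark.broadcast.Broadcast. In real Spark,
    pickling a Broadcast captures only its id and the value travels
    through the broadcast machinery; here we inline it for brevity."""

    def __init__(self, value):
        self.value = value


def encode_command(command):
    # driver side: pickle the command, and if the result is large,
    # wrap it in a broadcast and pickle the (small) handle instead
    pickled = pickle.dumps(command)
    if len(pickled) > THRESHOLD:  # compare the byte size, not the bytes
        pickled = pickle.dumps(FakeBroadcast(pickled))
    return pickled


def decode_command(data):
    # worker side: if we unpickled a broadcast handle rather than the
    # command itself, fetch its value and unpickle the real command
    command = pickle.loads(data)
    if isinstance(command, FakeBroadcast):
        command = pickle.loads(command.value)
    return command


if __name__ == "__main__":
    # a closure whose pickled form is well over 1 MiB, as in test_large_closure
    big_closure = ("func", [float(i) for i in range(300000)])
    assert decode_command(encode_command(big_closure)) == big_closure
```

The cutoff keeps small closures on the fast path (shipped inline with each task) while oversized ones go through broadcast, which, per the patch's own comment, compresses the payload and distributes it to executors through the broadcast mechanism rather than repeating it in every serialized task.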