diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 019c249699c2..9e60a822e859 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -913,7 +913,7 @@ def rightOuterJoin(self, other, numPartitions=None): return python_right_outer_join(self, other, numPartitions) # TODO: add option to control map-side combining - def partitionBy(self, numPartitions, partitionFunc=hash): + def partitionBy(self, numPartitions, partitionFunc=None): """ Return a copy of the RDD partitioned using the specified partitioner. @@ -924,6 +924,9 @@ def partitionBy(self, numPartitions, partitionFunc=hash): """ if numPartitions is None: numPartitions = self.ctx.defaultParallelism + + if partitionFunc is None: + partitionFunc = lambda x: 0 if x is None else hash(x) # Transferring O(n) objects to Java is too expensive. Instead, we'll # form the hash buckets in Python, transferring O(numPartitions) objects # to Java. Each object is a (splitNumber, [objects]) pair.