33 changes: 18 additions & 15 deletions python/pyspark/rdd.py
@@ -1319,7 +1319,7 @@ def values(self):
         """
         return self.map(lambda (k, v): v)
 
-    def reduceByKey(self, func, numPartitions=None):
+    def reduceByKey(self, func, numPartitions=None, mapSideCombine=True):
         """
         Merge the values for each key using an associative reduce function.
 
@@ -1334,7 +1334,7 @@ def reduceByKey(self, func, numPartitions=None):
         >>> sorted(rdd.reduceByKey(add).collect())
         [('a', 2), ('b', 1)]
         """
-        return self.combineByKey(lambda x: x, func, func, numPartitions)
+        return self.combineByKey(lambda x: x, func, func, numPartitions, mapSideCombine)
 
     def reduceByKeyLocally(self, func):
         """
@@ -1516,9 +1516,8 @@ def add_shuffle_key(split, iterator):
         rdd._partitionFunc = partitionFunc
         return rdd
 
-    # TODO: add control over map-side aggregation
     def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
-                     numPartitions=None):
+                     numPartitions=None, mapSideCombine=True):
         """
         Generic function to combine the elements for each key using a custom
         set of aggregation functions.
@@ -1559,18 +1558,21 @@ def combineLocally(iterator):
             merger.mergeValues(iterator)
             return merger.iteritems()
 
-        locally_combined = self.mapPartitions(combineLocally)
-        shuffled = locally_combined.partitionBy(numPartitions)
-
         def _mergeCombiners(iterator):
             merger = ExternalMerger(agg, memory, serializer) \
                 if spill else InMemoryMerger(agg)
             merger.mergeCombiners(iterator)
             return merger.iteritems()
 
-        return shuffled.mapPartitions(_mergeCombiners, True)
+        if mapSideCombine:
+            locally_combined = self.mapPartitions(combineLocally)
+            shuffled = locally_combined.partitionBy(numPartitions)
+            return shuffled.mapPartitions(_mergeCombiners)
+        else:
+            shuffled = self.partitionBy(numPartitions)
+            return shuffled.mapPartitions(combineLocally)
 
-    def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
+    def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None, mapSideCombine=True):
         """
         Aggregate the values of each key, using given combine functions and a neutral
         "zero value". This function can return a different result type, U, than the type
@@ -1584,9 +1586,9 @@ def createZero():
             return copy.deepcopy(zeroValue)
 
         return self.combineByKey(
-            lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions)
+            lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions, mapSideCombine)
 
-    def foldByKey(self, zeroValue, func, numPartitions=None):
+    def foldByKey(self, zeroValue, func, numPartitions=None, mapSideCombine=True):
         """
         Merge the values for each key using an associative function "func"
         and a neutral "zeroValue" which may be added to the result an
@@ -1601,10 +1603,11 @@ def foldByKey(self, zeroValue, func, numPartitions=None):
         def createZero():
             return copy.deepcopy(zeroValue)
 
-        return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions)
+        return self.combineByKey(
+            lambda v: func(createZero(), v), func, func, numPartitions, mapSideCombine)
 
     # TODO: support variant with custom partitioner
-    def groupByKey(self, numPartitions=None):
+    def groupByKey(self, numPartitions=None, mapSideCombine=True):
         """
         Group the values for each key in the RDD into a single sequence.
         Hash-partitions the resulting RDD into numPartitions partitions.
@@ -1629,8 +1632,8 @@ def mergeCombiners(a, b):
             a.extend(b)
             return a
 
-        return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
-                                 numPartitions).mapValues(lambda x: ResultIterable(x))
+        return self.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions,
+                                 mapSideCombine).mapValues(lambda x: ResultIterable(x))
 
     def flatMapValues(self, f):
         """
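The trade-off is easiest to see without Spark at all. Below is a hypothetical pure-Python model of the two code paths; combine_locally and the record list are made up for illustration and are not pyspark API:

from collections import defaultdict

records = [("a", 1), ("b", 1), ("a", 1)]

def combine_locally(partition, func):
    # Stands in for combineLocally: merge values per key before the
    # shuffle, so a mapper emits at most one record per key.
    combined = {}
    for k, v in partition:
        combined[k] = func(combined[k], v) if k in combined else v
    return list(combined.items())

# mapSideCombine=True: two pre-merged records cross the shuffle.
pre_merged = combine_locally(records, lambda a, b: a + b)
assert sorted(pre_merged) == [("a", 2), ("b", 1)]

# mapSideCombine=False: all three raw records cross the shuffle and are
# merged once, entirely on the reduce side.
shuffled_raw = defaultdict(list)
for k, v in records:
    shuffled_raw[k].append(v)
assert dict(shuffled_raw) == {"a": [1, 1], "b": [1]}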