python/pyspark/rdd.py (24 additions, 23 deletions)
@@ -30,6 +30,7 @@
 from threading import Thread
 import warnings
 import heapq
+import bisect
 from random import Random
 from math import sqrt, log
 
@@ -574,6 +575,8 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         # noqa
 
         >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+        >>> sc.parallelize(tmp).sortByKey(True, 1).collect()
+        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
         >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
         [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
         >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
@@ -584,42 +587,40 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()
 
-        bounds = list()
+        if numPartitions == 1:
+            if self.getNumPartitions() > 1:
+                self = self.coalesce(1)
+
+            def sort(iterator):
+                return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+
+            return self.mapPartitions(sort)
 
         # first compute the boundary of each part via sampling: we want to partition
         # the key-space into bins such that the bins have roughly the same
         # number of (key, value) pairs falling into them
-        if numPartitions > 1:
-            rddSize = self.count()
-            # constant from Spark's RangePartitioner
-            maxSampleSize = numPartitions * 20.0
-            fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
-
-            samples = self.sample(False, fraction, 1).map(
-                lambda (k, v): k).collect()
-            samples = sorted(samples, reverse=(not ascending), key=keyfunc)
-
-            # we have numPartitions many parts but one of the them has
-            # an implicit boundary
-            for i in range(0, numPartitions - 1):
-                index = (len(samples) - 1) * (i + 1) / numPartitions
-                bounds.append(samples[index])
+        rddSize = self.count()
+        maxSampleSize = numPartitions * 20.0  # constant from Spark's RangePartitioner
+        fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
+        samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
+        samples = sorted(samples, reverse=(not ascending), key=keyfunc)
+
+        # we have numPartitions many parts but one of the them has
+        # an implicit boundary
+        bounds = [samples[len(samples) * (i + 1) / numPartitions]
+                  for i in range(0, numPartitions - 1)]
 
         def rangePartitionFunc(k):
-            p = 0
-            while p < len(bounds) and keyfunc(k) > bounds[p]:
-                p += 1
+            p = bisect.bisect_left(bounds, keyfunc(k))
             if ascending:
                 return p
             else:
                 return numPartitions - 1 - p
 
         def mapFunc(iterator):
-            yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+            return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
 
-        return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc)
-                .mapPartitions(mapFunc, preservesPartitioning=True)
-                .flatMap(lambda x: x, preservesPartitioning=True))
+        return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True)
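For reference, a minimal standalone sketch of the range-partitioning idea the new code relies on: sort a sample of keys, pick numPartitions - 1 boundary keys, and locate each key's partition with bisect. The sample keys and partition count below are invented for illustration; only the bounds/bisect logic mirrors the diff (which, being Python 2, gets integer division from / where the sketch uses //).

import bisect

samples = sorted(['d', 'a', '2', 'b', '1'])        # hypothetical sampled keys -> ['1', '2', 'a', 'b', 'd']
numPartitions = 2

# numPartitions - 1 boundaries; the last partition's upper bound stays implicit
bounds = [samples[len(samples) * (i + 1) // numPartitions]
          for i in range(numPartitions - 1)]       # ['a']

def rangePartition(k):
    # index of the first boundary >= k, i.e. the bin this key falls into
    return bisect.bisect_left(bounds, k)

print([(k, rangePartition(k)) for k in ['d', 'a', '2', 'b', '1']])
# [('d', 1), ('a', 0), ('2', 0), ('b', 1), ('1', 0)]

A key equal to a boundary gets the boundary's own index from bisect_left, which matches the strict > comparison in the while loop being replaced.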
Contributor:
Why did we have the flatMap(lambda x: x) before? Just want to make sure we're not removing something useful.

Contributor:
Ah I guess it's due to the yield -> return above?

Contributor (Author):
Yes, but I have no idea why it was done this way. I think it's not necessary.

Contributor:
Got it, yeah. It seems unnecessary.
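To illustrate the point with a made-up two-pair partition (plain Python, no Spark): the yield version turns each partition into a single element, the whole sorted list, which is what the flatMap(lambda x: x) undid; the return version hands mapPartitions the list itself, so its elements are already the (k, v) pairs.

def old_mapFunc(iterator):
    yield sorted(iterator)            # one yielded value: the entire sorted list

def new_mapFunc(iterator):
    return sorted(iterator)           # the list is iterated, element by element

part = [('b', 2), ('a', 1)]           # hypothetical contents of one partition

print(list(old_mapFunc(iter(part))))  # [[('a', 1), ('b', 2)]]  -> needed flattening
print(list(new_mapFunc(iter(part))))  # [('a', 1), ('b', 2)]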

Contributor:
I think there might be two unintended side effects of this change. This code used to work in pyspark:

sc.parallelize([5,3,4,2,1]).map(lambda x: (x,x)).sortByKey().take(1)

Now it fails with the error:

File "<...>/spark/python/pyspark/rdd.py", line 1023, in takeUpToNumLeft
    yield next(iterator)
TypeError: list object is not an iterator

Changing mapFunc and sort back to generators rather than regular functions fixes that problem.

After making that change, there is a second side effect from the removal of flatMap: the above code returns the following unexpected result due to the default partitioning scheme:

[[(1, 1), (2, 2)]]

Removing sortByKey, e.g.:

sc.parallelize([5,3,4,2,1]).map(lambda x: (x,x)).take(1)

returns the expected result [(5, 5)]. Restoring the call to flatMap resolves this as well.
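A plain-Python sketch of the first side effect, under the assumption (suggested by the traceback) that take()'s takeUpToNumLeft calls next() directly on whatever the partition function produced: next() accepts a generator but rejects a plain list, even though a list is a perfectly fine iterable for mapPartitions itself.

def gen_mapFunc(iterator):            # generator, as in the old code
    yield sorted(iterator)

def list_mapFunc(iterator):           # regular function returning a list, as in this PR
    return sorted(iterator)

part = [(5, 5), (3, 3)]               # hypothetical partition contents

print(next(gen_mapFunc(iter(part))))  # [(3, 3), (5, 5)]
try:
    next(list_mapFunc(iter(part)))
except TypeError as e:
    print(e)                          # e.g. "list object is not an iterator"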

Contributor:
Thanks for pointing this out; sounds like we should look into this.

Contributor (Author):
@str-janus thanks, this will be fixed in #2045


def sortBy(self, keyfunc, ascending=True, numPartitions=None):
"""