Skip to content

Commit b1a31a0

Browse files
author
Prabin Banka
committed
Added Python RDD.zip function
1 parent 181ec50 commit b1a31a0

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

python/pyspark/rdd.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import warnings
3131

3232
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
33-
BatchedSerializer, CloudPickleSerializer, pack_long
33+
BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
3434
from pyspark.join import python_join, python_left_outer_join, \
3535
python_right_outer_join, python_cogroup
3636
from pyspark.statcounter import StatCounter
@@ -1057,6 +1057,24 @@ def coalesce(self, numPartitions, shuffle=False):
10571057
jrdd = self._jrdd.coalesce(numPartitions)
10581058
return RDD(jrdd, self.ctx, self._jrdd_deserializer)
10591059

1060+
def zip(self, other):
    """
    Zips this RDD with another one, returning key-value pairs made of the
    first element in each RDD, the second element in each RDD, etc.
    Assumes that the two RDDs have the same number of partitions and the
    same number of elements in each partition (e.g. one was made through
    a map on the other).

    >>> x = sc.parallelize(range(0,5))
    >>> y = sc.parallelize(range(1000, 1005))
    >>> x.zip(y).collect()
    [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
    """
    # The zip itself is delegated to the underlying Java RDDs.
    zipped_jrdd = self._jrdd.zip(other._jrdd)
    # Each side keeps its own deserializer: keys are decoded the way this
    # RDD's elements are, values the way the other RDD's elements are.
    pair_deserializer = PairDeserializer(self._jrdd_deserializer,
                                         other._jrdd_deserializer)
    return RDD(zipped_jrdd, self.ctx, pair_deserializer)
1076+
1077+
10601078
# TODO: `lookup` is disabled because we can't make direct comparisons based
10611079
# on the key; we need to compare the hash of the key to the hash of the
10621080
# keys in the pairs. This could be an expensive operation, since those

python/pyspark/serializers.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,14 +204,18 @@ def __init__(self, key_ser, val_ser):
204204
self.key_ser = key_ser
205205
self.val_ser = val_ser
206206

207-
def load_stream(self, stream):
207+
def prepare_keys_values(self, stream):
    """
    Read key items and value items in lockstep from the same stream and
    yield them as (keys, vals) pairs of batches.

    An item produced by a BatchedSerializer is passed through unchanged;
    an item from any other serializer is wrapped in a one-element list, so
    callers can always iterate over both halves of each yielded pair.
    """
    def as_batch(item, already_batched):
        # Normalise a single object into a one-element batch.
        return item if already_batched else [item]

    key_batches = self.key_ser._load_stream_without_unbatching(stream)
    val_batches = self.val_ser._load_stream_without_unbatching(stream)
    # The serializer types do not change while streaming, so test once.
    keys_batched = isinstance(self.key_ser, BatchedSerializer)
    vals_batched = isinstance(self.val_ser, BatchedSerializer)
    for key_item, val_item in izip(key_batches, val_batches):
        yield (as_batch(key_item, keys_batched),
               as_batch(val_item, vals_batched))
216+
217+
def load_stream(self, stream):
    """
    Yield every (key, value) combination within each pair of batches
    produced by prepare_keys_values(), i.e. the per-batch cartesian
    product, with keys varying slowest.
    """
    for key_batch, val_batch in self.prepare_keys_values(stream):
        for combo in product(key_batch, val_batch):
            yield combo
217221

@@ -224,6 +228,29 @@ def __str__(self):
224228
(str(self.key_ser), str(self.val_ser))
225229

226230

231+
class PairDeserializer(CartesianDeserializer):
    """
    Deserializes the JavaRDD zip() of two PythonRDDs.

    Key batches and value batches are obtained from the inherited
    prepare_keys_values() and paired positionally: the i-th key of a batch
    goes with the i-th value of the matching batch.
    """

    def __init__(self, key_ser, val_ser):
        # key_ser decodes the left-hand RDD's elements, val_ser the
        # right-hand RDD's elements.
        self.key_ser = key_ser
        self.val_ser = val_ser

    def load_stream(self, stream):
        """
        Yield (key, value) tuples read from `stream`.

        Raises ValueError if a key batch and its matching value batch hold
        a different number of items, since they cannot be paired one-to-one.
        """
        for (keys, vals) in self.prepare_keys_values(stream):
            # Batches are expected to be sized containers; materialise
            # anything else so the lengths can be compared.
            keys = keys if hasattr(keys, '__len__') else list(keys)
            vals = vals if hasattr(vals, '__len__') else list(vals)
            # izip() stops at the shorter input, which would silently drop
            # elements when the two sides are batched differently or hold
            # different numbers of elements. Fail loudly instead.
            if len(keys) != len(vals):
                raise ValueError("Can not deserialize RDD with different number of items"
                                 " in pair: (%d, %d)" % (len(keys), len(vals)))
            for pair in izip(keys, vals):
                yield pair

    def __eq__(self, other):
        # Equal only to another PairDeserializer wrapping equal serializers.
        return isinstance(other, PairDeserializer) and \
            self.key_ser == other.key_ser and self.val_ser == other.val_ser

    def __str__(self):
        return "PairDeserializer<%s, %s>" % \
            (str(self.key_ser), str(self.val_ser))
252+
253+
227254
class NoOpSerializer(FramedSerializer):
228255

229256
def loads(self, obj): return obj

0 commit comments

Comments
 (0)