diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 2d05611321ed..7561eeb2fc7f 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2209,7 +2209,9 @@ def toLocalIterator(self):
 def _prepare_for_python_RDD(sc, command, obj=None):
     # the serialized command will be compressed by broadcast
     ser = CloudPickleSerializer()
-    pickled_command = ser.dumps(command)
+    # ship the driver's (major, minor) Python version together with the
+    # command so that the worker can compare it with its own version
+    pickled_command = ser.dumps((command, sys.version_info[:2]))
     if len(pickled_command) > (1 << 20):  # 1M
         broadcast = sc.broadcast(pickled_command)
         pickled_command = ser.dumps(broadcast)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index dd8d3b1c5373..92a20e77d5e1 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -34,6 +34,8 @@
 import threading
 import hashlib
 
+from py4j.protocol import Py4JJavaError
+
 if sys.version_info[:2] <= (2, 6):
     try:
         import unittest2 as unittest
@@ -1441,6 +1443,26 @@ def count():
         self.assertTrue(not t.isAlive())
         self.assertEqual(100000, rdd.count())
 
+    def test_with_different_versions_of_python(self):
+        # a job must fail once the driver reports another Python version
+        rdd = self.sc.parallelize(range(10))
+        rdd.count()
+        log4j = self.sc._jvm.org.apache.log4j
+        root_logger = log4j.LogManager.getRootLogger()
+        old_level = root_logger.getLevel()
+        version = sys.version_info
+        # mutate global state only inside the try block, so that a failing
+        # JVM call can never leave sys.version_info patched for later tests
+        try:
+            # raise the JVM log threshold while the expected failure occurs
+            root_logger.setLevel(log4j.Level.FATAL)
+            # make the driver report a version the worker does not run
+            sys.version_info = (2, 0, 0)
+            self.assertRaises(Py4JJavaError, rdd.count)
+        finally:
+            sys.version_info = version
+            root_logger.setLevel(old_level)
+
 
 class SparkSubmitTests(unittest.TestCase):
 
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 8a93c320ec5d..452d6fabdcc1 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -88,7 +88,13 @@ def main(infile, outfile):
         command = pickleSer._read_with_length(infile)
         if isinstance(command, Broadcast):
             command = pickleSer.loads(command.value)
-        (func, profiler, deserializer, serializer) = command
+        # the driver sends (command, (major, minor)); refuse to run when the
+        # worker interpreter does not have the same (major, minor) version
+        (func, profiler, deserializer, serializer), version = command
+        if version != sys.version_info[:2]:
+            raise Exception(("Python in worker has different version %s than that in " +
+                             "driver %s, PySpark cannot run with different minor versions") %
+                            (sys.version_info[:2], version))
         init_time = time.time()
 
         def process():