diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index 98b505c9046b..de4b6af23666 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -48,10 +48,6 @@ def __new__(cls):
         cls._taskContext = taskContext = object.__new__(cls)
         return taskContext
 
-    def __init__(self):
-        """Construct a TaskContext, use get instead"""
-        pass
-
     @classmethod
     def _getOrCreate(cls):
         """Internal function to get or create global TaskContext."""
@@ -140,13 +136,13 @@ class BarrierTaskContext(TaskContext):
     _port = None
     _secret = None
 
-    def __init__(self):
-        """Construct a BarrierTaskContext, use get instead"""
-        pass
-
     @classmethod
     def _getOrCreate(cls):
-        """Internal function to get or create global BarrierTaskContext."""
+        """
+        Internal function to get or create a global BarrierTaskContext. We need to make
+        sure a BarrierTaskContext is returned from here, because it is needed in the
+        python worker reuse scenario; see SPARK-25921 for more details.
+        """
         if not isinstance(cls._taskContext, BarrierTaskContext):
             cls._taskContext = object.__new__(cls)
         return cls._taskContext
diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py
index b3a967440a9b..fdb5c40b78a4 100644
--- a/python/pyspark/tests/test_taskcontext.py
+++ b/python/pyspark/tests/test_taskcontext.py
@@ -14,11 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 import random
 import sys
 import time
+import unittest
 
-from pyspark import SparkContext, TaskContext, BarrierTaskContext
+from pyspark import SparkConf, SparkContext, TaskContext, BarrierTaskContext
 from pyspark.testing.utils import PySparkTestCase
 
 
@@ -118,21 +120,6 @@ def context_barrier(x):
         times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
         self.assertTrue(max(times) - min(times) < 1)
 
-    def test_barrier_with_python_worker_reuse(self):
-        """
-        Verify that BarrierTaskContext.barrier() with reused python worker.
-        """
-        self.sc._conf.set("spark.python.work.reuse", "true")
-        rdd = self.sc.parallelize(range(4), 4)
-        # start a normal job first to start all worker
-        result = rdd.map(lambda x: x ** 2).collect()
-        self.assertEqual([0, 1, 4, 9], result)
-        # make sure `spark.python.work.reuse=true`
-        self.assertEqual(self.sc._conf.get("spark.python.work.reuse"), "true")
-
-        # worker will be reused in this barrier job
-        self.test_barrier()
-
     def test_barrier_infos(self):
         """
         Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
@@ -149,6 +136,44 @@ def f(iterator):
         self.assertTrue(len(taskInfos[0]) == 4)
 
 
+class TaskContextTestsWithWorkerReuse(unittest.TestCase):
+
+    def setUp(self):
+        class_name = self.__class__.__name__
+        conf = SparkConf().set("spark.python.worker.reuse", "true")
+        self.sc = SparkContext('local[2]', class_name, conf=conf)
+
+    def test_barrier_with_python_worker_reuse(self):
+        """
+        Regression test for SPARK-25921: verify that BarrierTaskContext.barrier() works
+        with reused python workers.
+        """
+        # start a normal job first to start all workers and collect their pids
+        worker_pids = self.sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect()
+        # the workers will be reused in this barrier job
+        rdd = self.sc.parallelize(range(10), 2)
+
+        def f(iterator):
+            yield sum(iterator)
+
+        def context_barrier(x):
+            tc = BarrierTaskContext.get()
+            time.sleep(random.randint(1, 10))
+            tc.barrier()
+            return (time.time(), os.getpid())
+
+        result = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
+        times = list(map(lambda x: x[0], result))
+        pids = list(map(lambda x: x[1], result))
+        # check both the barrier effect and worker reuse
+        self.assertTrue(max(times) - min(times) < 1)
+        for pid in pids:
+            self.assertTrue(pid in worker_pids)
+
+    def tearDown(self):
+        self.sc.stop()
+
+
 if __name__ == "__main__":
     import unittest
     from pyspark.tests.test_taskcontext import *