19 changes: 19 additions & 0 deletions docs/configuration.md
@@ -206,6 +206,25 @@ Apart from these, the following properties are also available, and may be useful
used during aggregation goes above this amount, it will spill the data into disks.
</td>
</tr>
<tr>
<td><code>spark.python.profile</code></td>
<td>false</td>
<td>
Enable profiling in the Python worker. The profile results can be shown with `sc.show_profiles()`,
or they will be displayed automatically before the driver exits. They can also be dumped to disk
with `sc.dump_profiles(path)`. If some of the profile results have already been displayed manually,
they will not be displayed automatically again before the driver exits.
</td>
</tr>
<tr>
<td><code>spark.python.profile.dump</code></td>
<td>(none)</td>
<td>
The directory used to dump the profile results before the driver exits. The results are
dumped as a separate file for each RDD and can be loaded with `pstats.Stats()`. If this is
specified, the profile results will not be displayed automatically.
</td>
</tr>
<tr>
<td><code>spark.python.worker.reuse</code></td>
<td>true</td>
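For reference, a minimal usage sketch of these two properties (the paths and the `local[2]` master below are illustrative assumptions, not part of this change):

    # Hypothetical end-to-end sketch: enable Python-side profiling via SparkConf,
    # run a job, then inspect or dump the collected pstats results.
    from pyspark import SparkConf, SparkContext

    conf = SparkConf().set("spark.python.profile", "true")
    # Optionally have results written to a directory at exit instead of printed:
    # conf.set("spark.python.profile.dump", "/tmp/pyspark_profile")
    sc = SparkContext("local[2]", "profile-demo", conf=conf)

    sc.parallelize(range(1000)).map(lambda x: x * x).count()

    sc.show_profiles()                         # print per-RDD profile stats now
    sc.dump_profiles("/tmp/pyspark_profile")   # or write rdd_<id>.pstats files

    # Dumped files can later be loaded with the standard pstats module, e.g.
    # pstats.Stats("/tmp/pyspark_profile/rdd_1.pstats").sort_stats("time").print_stats()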
15 changes: 15 additions & 0 deletions python/pyspark/accumulators.py
@@ -215,6 +215,21 @@ def addInPlace(self, value1, value2):
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)


class PStatsParam(AccumulatorParam):
"""PStatsParam is used to merge pstats.Stats"""

@staticmethod
def zero(value):
return None

@staticmethod
def addInPlace(value1, value2):
if value1 is None:
return value2
value1.add(value2)
return value1


class _UpdateRequestHandler(SocketServer.StreamRequestHandler):

"""
39 changes: 38 additions & 1 deletion python/pyspark/context.py
@@ -20,6 +20,7 @@
import sys
from threading import Lock
from tempfile import NamedTemporaryFile
import atexit

from pyspark import accumulators
from pyspark.accumulators import Accumulator
@@ -30,7 +31,6 @@
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer, CompressedSerializer
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call

@@ -192,6 +192,9 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
self._temp_dir = \
self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()

# profiling stats collected for each PythonRDD
self._profile_stats = []

def _initialize_context(self, jconf):
"""
Initialize SparkContext in function to allow subclass specific initialization
@@ -792,6 +795,40 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
return list(mappedRDD._collect_iterator_through_file(it))

def _add_profile(self, id, profileAcc):
if not self._profile_stats:
dump_path = self._conf.get("spark.python.profile.dump")
if dump_path:
atexit.register(self.dump_profiles, dump_path)
else:
atexit.register(self.show_profiles)

self._profile_stats.append([id, profileAcc, False])

def show_profiles(self):
""" Print the profile stats to stdout """
for i, (id, acc, showed) in enumerate(self._profile_stats):
stats = acc.value
if not showed and stats:
print "=" * 60
print "Profile of RDD<id=%d>" % id
print "=" * 60
stats.sort_stats("time", "cumulative").print_stats()
# mark it as showed
self._profile_stats[i][2] = True

def dump_profiles(self, path):
""" Dump the profile stats into directory `path`
"""
if not os.path.exists(path):
os.makedirs(path)
for id, acc, _ in self._profile_stats:
stats = acc.value
if stats:
p = os.path.join(path, "rdd_%d.pstats" % id)
stats.dump_stats(p)
self._profile_stats = []


def _test():
import atexit
10 changes: 8 additions & 2 deletions python/pyspark/rdd.py
@@ -15,7 +15,6 @@
# limitations under the License.
#

from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
from itertools import chain, ifilter, imap
@@ -32,6 +31,7 @@
from random import Random
from math import sqrt, log, isinf, isnan

from pyspark.accumulators import PStatsParam
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
PickleSerializer, pack_long, AutoBatchedSerializer
@@ -2080,7 +2080,9 @@ def _jrdd(self):
return self._jrdd_val
if self._bypass_serializer:
self._jrdd_deserializer = NoOpSerializer()
command = (self.func, self._prev_jrdd_deserializer,
enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
command = (self.func, profileStats, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
# the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
@@ -2102,6 +2104,10 @@ def _jrdd(self):
self.ctx.pythonExec,
broadcast_vars, self.ctx._javaAccumulator)
self._jrdd_val = python_rdd.asJavaRDD()

if enable_profile:
self._id = self._jrdd_val.id()
self.ctx._add_profile(self._id, profileStats)
return self._jrdd_val

def id(self):
2 changes: 1 addition & 1 deletion python/pyspark/sql.py
@@ -974,7 +974,7 @@ def registerFunction(self, name, f, returnType=StringType()):
[Row(c0=4)]
"""
func = lambda _, it: imap(lambda x: f(*x), it)
command = (func,
command = (func, None,
BatchedSerializer(PickleSerializer(), 1024),
BatchedSerializer(PickleSerializer(), 1024))
ser = CloudPickleSerializer()
30 changes: 30 additions & 0 deletions python/pyspark/tests.py
@@ -632,6 +632,36 @@ def test_distinct(self):
self.assertEquals(result.count(), 3)


class TestProfiler(PySparkTestCase):

def setUp(self):
self._old_sys_path = list(sys.path)
class_name = self.__class__.__name__
conf = SparkConf().set("spark.python.profile", "true")
self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf)

def test_profiler(self):

def heavy_foo(x):
for i in range(1 << 20):
x = 1
rdd = self.sc.parallelize(range(100))
rdd.foreach(heavy_foo)
profiles = self.sc._profile_stats
self.assertEqual(1, len(profiles))
id, acc, _ = profiles[0]
stats = acc.value
self.assertTrue(stats is not None)
width, stat_list = stats.get_print_list([])
func_names = [func_name for fname, n, func_name in stat_list]
self.assertTrue("heavy_foo" in func_names)

self.sc.show_profiles()
d = tempfile.gettempdir()
self.sc.dump_profiles(d)
self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))


class TestSQL(PySparkTestCase):

def setUp(self):
19 changes: 16 additions & 3 deletions python/pyspark/worker.py
@@ -23,6 +23,8 @@
import time
import socket
import traceback
import cProfile
import pstats

from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
@@ -90,10 +92,21 @@ def main(infile, outfile):
command = pickleSer._read_with_length(infile)
if isinstance(command, Broadcast):
command = pickleSer.loads(command.value)
(func, deserializer, serializer) = command
(func, stats, deserializer, serializer) = command
init_time = time.time()
iterator = deserializer.load_stream(infile)
serializer.dump_stream(func(split_index, iterator), outfile)

def process():
iterator = deserializer.load_stream(infile)
serializer.dump_stream(func(split_index, iterator), outfile)

if stats:
p = cProfile.Profile()
p.runcall(process)
st = pstats.Stats(p)
st.stream = None # make it picklable
stats.add(st.strip_dirs())
else:
process()
except Exception:
try:
write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
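For reference, the profile-and-merge pattern that worker.py and PStatsParam implement can be exercised with the standard library alone; this standalone sketch (function names here are illustrative, not PySpark API) mirrors it:

    # Profile a callable with cProfile, turn the result into a picklable
    # pstats.Stats, and merge several runs the way PStatsParam.addInPlace does.
    import cProfile
    import pstats

    def profile_call(fn, *args):
        p = cProfile.Profile()
        p.runcall(fn, *args)          # run fn under the profiler
        st = pstats.Stats(p)
        st.stream = None              # drop the stream reference so it can be pickled
        return st.strip_dirs()

    def work(n):
        return sum(i * i for i in range(n))

    merged = None
    for n in (10000, 20000):
        st = profile_call(work, n)
        if merged is None:
            merged = st               # first result becomes the accumulated value
        else:
            merged.add(st)            # later results are merged in place
    merged.sort_stats("time", "cumulative").print_stats(5)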