From dfbccf5c92333da8ab835fc4730aadc844e9f895 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 24 Sep 2014 11:23:10 -0700 Subject: [PATCH] fix bug while pickle globals of function --- python/pyspark/cloudpickle.py | 42 ++++++++++++++++++++++++++++++----- python/pyspark/tests.py | 18 +++++++++++++++ 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 32dda3888c62d..bb0783555aa77 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -52,6 +52,7 @@ import itertools from copy_reg import _extension_registry, _inverted_registry, _extension_cache import new +import dis import traceback import platform @@ -61,6 +62,14 @@ import logging cloudLog = logging.getLogger("Cloud.Transport") +#relevant opcodes +STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) +DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) +LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) +GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] + +HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) +EXTENDED_ARG = chr(dis.EXTENDED_ARG) if PyImp == "PyPy": # register builtin type in `new` @@ -304,16 +313,37 @@ def save_function_tuple(self, func, forced_imports): write(pickle.REDUCE) # applies _fill_function on the tuple @staticmethod - def extract_code_globals(code): + def extract_code_globals(co): """ Find all globals names read or written to by codeblock co """ - names = set(code.co_names) - if code.co_consts: # see if nested function have any global refs - for const in code.co_consts: + code = co.co_code + names = co.co_names + out_names = set() + + n = len(code) + i = 0 + extended_arg = 0 + while i < n: + op = code[i] + + i = i+1 + if op >= HAVE_ARGUMENT: + oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg + extended_arg = 0 + i = i+2 + if op == EXTENDED_ARG: + extended_arg = oparg*65536L + if op in GLOBAL_OPS: + out_names.add(names[oparg]) + #print 'extracted', out_names, ' from ', names + + if co.co_consts: # see if nested function have any global refs + for const in co.co_consts: if type(const) is types.CodeType: - names |= CloudPickler.extract_code_globals(const) - return names + out_names |= CloudPickler.extract_code_globals(const) + + return out_names def extract_func_data(self, func): """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 1b8afb763b26a..4dd9edad69081 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -213,6 +213,24 @@ def test_pickling_file_handles(self): out2 = ser.loads(ser.dumps(out1)) self.assertEquals(out1, out2) + def test_func_globals(self): + + class Unpicklable(object): + def __reduce__(self): + raise Exception("not picklable") + + global exit + exit = Unpicklable() + + ser = CloudPickleSerializer() + self.assertRaises(Exception, lambda: ser.dumps(exit)) + + def foo(): + sys.exit(0) + + self.assertTrue("exit" in foo.func_code.co_names) + ser.dumps(foo) + class PySparkTestCase(unittest.TestCase):