Skip to content

Commit 296fb85

Browse files
authored
Fix code caching for OpenMP regions (CPU and GPU)
1) Fix dispatcher pickling to return output_types only if it is an object-mode lifted `with`. (#29) 2) Make openmp_tag picklable by not storing the arg if it is an llvmlite AllocaInstr. 3) Don't pickle anything for LiftedCode's function IR except the arg count and the func_id. 4) Return all of openmp_region_start's state for pickling. 5) Comment on why we need the new class.
1 parent 55bd60e commit 296fb85

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

Diff for: numba/core/consts.py

+3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ def __init__(self, func_ir):
2121
self._func_ir = weakref.proxy(func_ir)
2222
self._cache = {}
2323

24+
def __repr__(self):
25+
return str(self._cache)
26+
2427
def infer_constant(self, name, loc=None):
2528
"""
2629
Infer a constant value for the given variable *name*.

Diff for: numba/core/dispatcher.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,16 @@ def __init__(self, func_ir, typingctx, targetctx, flags, locals):
11011101
can_fallback=True,
11021102
exact_match_required=False)
11031103

1104+
class LCFuncIR:
1105+
""" This is a minimalistic reimplementation of Numba's regular FuncIR
1106+
class. We only need these two parts for caching. We either don't
1107+
have the information to fill in the regular FuncIR or there are
1108+
problems if we do.
1109+
"""
1110+
def __init__(self, arg_count, func_id):
1111+
self.arg_count = arg_count
1112+
self.func_id = func_id
1113+
11041114
def _reduce_states(self):
11051115
"""
11061116
Reduce the instance for pickling. This will serialize
@@ -1110,8 +1120,8 @@ def _reduce_states(self):
11101120
NOTE: part of ReduceMixin protocol
11111121
"""
11121122
return dict(
1113-
uuid=self._uuid, func_ir=self.func_ir, flags=self.flags,
1114-
locals=self.locals, extras=self._reduce_extras(),
1123+
uuid=self._uuid, func_ir=self.LCFuncIR(self.func_ir.arg_count, self.func_ir.func_id),
1124+
flags=self.flags, locals=self.locals, extras=self._reduce_extras(),
11151125
)
11161126

11171127
def _reduce_extras(self):
@@ -1227,7 +1237,10 @@ class LiftedWith(LiftedCode):
12271237
can_cache = True
12281238

12291239
def _reduce_extras(self):
1230-
return dict(output_types=self.output_types)
1240+
if hasattr(self, "output_types"):
1241+
return dict(output_types=self.output_types)
1242+
else:
1243+
return dict()
12311244

12321245
@property
12331246
def _numba_type_(self):

Diff for: numba/openmp.py

+20
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252

5353
llvm_binpath=None
5454
llvm_libpath=None
55+
5556
def _init():
5657
global llvm_binpath
5758
global llvm_libpath
@@ -80,6 +81,7 @@ def _init():
8081

8182
_init()
8283

84+
8385
#----------------------------------------------------------------------------------------------
8486

8587
class NameSlice:
@@ -166,6 +168,17 @@ def __init__(self, name, arg=None, load=False, non_arg=False, omp_slice=None):
166168
self.non_arg = non_arg
167169
self.omp_slice = omp_slice
168170

171+
def __getstate__(self):
172+
state = self.__dict__.copy()
173+
if isinstance(self.arg, lir.instructions.AllocaInstr):
174+
del state['arg']
175+
return state
176+
177+
def __setstate__(self, state):
178+
self.__dict__.update(state)
179+
if not hasattr(self, "arg"):
180+
self.arg = None
181+
169182
def var_in(self, var):
170183
assert isinstance(var, str)
171184

@@ -1120,6 +1133,13 @@ def __init__(self, tags, region_number, loc, firstprivate_dead_after=None):
11201133
self.alloca_queue = []
11211134
self.end_region = None
11221135

1136+
def __getstate__(self):
1137+
state = self.__dict__.copy()
1138+
return state
1139+
1140+
def __setstate__(self, state):
1141+
self.__dict__.update(state)
1142+
11231143
def replace_var_names(self, namedict):
11241144
for i in range(len(self.tags)):
11251145
if isinstance(self.tags[i].arg, ir.Var):

0 commit comments

Comments
 (0)