Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Fix code caching for OpenMP regions (CPU and GPU) #29

Merged
merged 1 commit into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions numba/core/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def __init__(self, func_ir):
self._func_ir = weakref.proxy(func_ir)
self._cache = {}

def __repr__(self):
return str(self._cache)

def infer_constant(self, name, loc=None):
"""
Infer a constant value for the given variable *name*.
Expand Down
19 changes: 16 additions & 3 deletions numba/core/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,16 @@ def __init__(self, func_ir, typingctx, targetctx, flags, locals):
can_fallback=True,
exact_match_required=False)

class LCFuncIR:
""" This is a minimalistic reimplementation of Numba's regular FuncIR
class. We only need these two parts for caching. We either don't
have the information to fill in the regular FuncIR or there are
problems if we do.
"""
def __init__(self, arg_count, func_id):
self.arg_count = arg_count
self.func_id = func_id

def _reduce_states(self):
"""
Reduce the instance for pickling. This will serialize
Expand All @@ -1110,8 +1120,8 @@ def _reduce_states(self):
NOTE: part of ReduceMixin protocol
"""
return dict(
uuid=self._uuid, func_ir=self.func_ir, flags=self.flags,
locals=self.locals, extras=self._reduce_extras(),
uuid=self._uuid, func_ir=self.LCFuncIR(self.func_ir.arg_count, self.func_ir.func_id),
flags=self.flags, locals=self.locals, extras=self._reduce_extras(),
)

def _reduce_extras(self):
Expand Down Expand Up @@ -1227,7 +1237,10 @@ class LiftedWith(LiftedCode):
can_cache = True

def _reduce_extras(self):
return dict(output_types=self.output_types)
if hasattr(self, "output_types"):
return dict(output_types=self.output_types)
else:
return dict()

@property
def _numba_type_(self):
Expand Down
20 changes: 20 additions & 0 deletions numba/openmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@

llvm_binpath=None
llvm_libpath=None

def _init():
global llvm_binpath
global llvm_libpath
Expand Down Expand Up @@ -80,6 +81,7 @@ def _init():

_init()


#----------------------------------------------------------------------------------------------

class NameSlice:
Expand Down Expand Up @@ -166,6 +168,17 @@ def __init__(self, name, arg=None, load=False, non_arg=False, omp_slice=None):
self.non_arg = non_arg
self.omp_slice = omp_slice

def __getstate__(self):
state = self.__dict__.copy()
if isinstance(self.arg, lir.instructions.AllocaInstr):
del state['arg']
return state

def __setstate__(self, state):
self.__dict__.update(state)
if not hasattr(self, "arg"):
self.arg = None

def var_in(self, var):
assert isinstance(var, str)

Expand Down Expand Up @@ -1120,6 +1133,13 @@ def __init__(self, tags, region_number, loc, firstprivate_dead_after=None):
self.alloca_queue = []
self.end_region = None

def __getstate__(self):
state = self.__dict__.copy()
return state

def __setstate__(self, state):
self.__dict__.update(state)

def replace_var_names(self, namedict):
for i in range(len(self.tags)):
if isinstance(self.tags[i].arg, ir.Var):
Expand Down
Loading