Skip to content

Commit

Permalink
Allow caching of constants in lambda functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mhostetter committed Nov 9, 2021
1 parent f1d2f6e commit 976d2a4
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
20 changes: 13 additions & 7 deletions galois/_fields/_calculate.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ def _ufunc_calculate(cls, name):
function = getattr(cls, f"_{name}_calculate")

# These variables must be locals and not class properties for Numba to compile them as literals
CHARACTERISTIC = cls.characteristic
DEGREE = cls.degree
IRREDUCIBLE_POLY = cls._irreducible_poly_int
characteristic = cls.characteristic
degree = cls.degree
irreducible_poly = cls._irreducible_poly_int

if cls._UFUNC_TYPE[name] == "unary":
cls._UFUNC_CACHE_CALCULATE[key] = numba.vectorize(["int64(int64)"], nopython=True)(lambda a: function(a, CHARACTERISTIC, DEGREE, IRREDUCIBLE_POLY))
cls._UFUNC_CACHE_CALCULATE[key] = numba.vectorize(["int64(int64)"], nopython=True)(lambda a: function(a, characteristic, degree, irreducible_poly))
else:
cls._UFUNC_CACHE_CALCULATE[key] = numba.vectorize(["int64(int64, int64)"], nopython=True)(lambda a, b: function(a, b, CHARACTERISTIC, DEGREE, IRREDUCIBLE_POLY))
cls._UFUNC_CACHE_CALCULATE[key] = numba.vectorize(["int64(int64, int64)"], nopython=True)(lambda a, b: function(a, b, characteristic, degree, irreducible_poly))

cls._reset_globals()

Expand All @@ -112,10 +112,16 @@ def _ufunc_python(cls, name):
Returns a pure-python arithmetic ufunc using explicit calculation.
"""
function = getattr(cls, f"_{name}_calculate")

# Pre-fetching these values into local variables allows Python to cache them as constants in the lambda function
characteristic = cls.characteristic
degree = cls.degree
irreducible_poly = cls._irreducible_poly_int

if cls._UFUNC_TYPE[name] == "unary":
return np.frompyfunc(lambda a: function(a, cls.characteristic, cls.degree, cls._irreducible_poly_int), 1, 1)
return np.frompyfunc(lambda a: function(a, characteristic, degree, irreducible_poly), 1, 1)
else:
return np.frompyfunc(lambda a, b: function(a, b, cls.characteristic, cls.degree, cls._irreducible_poly_int), 2, 1)
return np.frompyfunc(lambda a, b: function(a, b, characteristic, degree, irreducible_poly), 2, 1)

###############################################################################
# Arithmetic functions using explicit calculation
Expand Down
5 changes: 3 additions & 2 deletions galois/_fields/_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def _build_lookup_tables(cls):

order = cls.order
primitive_element = int(cls.primitive_element)
add = lambda a, b: cls._func_python("add")(a, b, cls.characteristic, cls.degree, cls._irreducible_poly_int)
multiply = lambda a, b: cls._func_python("multiply")(a, b, cls.characteristic, cls.degree, cls._irreducible_poly_int)
add = cls._ufunc_python("add")
multiply = cls._ufunc_python("multiply")

cls._EXP = np.zeros(2*order, dtype=np.int64)
cls._LOG = np.zeros(order, dtype=np.int64)
Expand Down Expand Up @@ -101,6 +101,7 @@ def _ufunc_lookup(cls, name):
key = (name, cls.characteristic, cls.degree, cls._irreducible_poly_int)

if key not in cls._UFUNC_CACHE_LOOKUP:
# These variables must be locals for Numba to compile them as literals
EXP = cls._EXP
LOG = cls._LOG
ZECH_LOG = cls._ZECH_LOG
Expand Down

0 comments on commit 976d2a4

Please sign in to comment.