diff --git a/src/galois/_codes/_bch.py b/src/galois/_codes/_bch.py index 49357b135..2675eda4a 100644 --- a/src/galois/_codes/_bch.py +++ b/src/galois/_codes/_bch.py @@ -661,7 +661,7 @@ def decode(self, codeword, output="message", errors=False): return super().decode(codeword, output=output, errors=errors) def _decode_codeword(self, codeword: FieldArray) -> tuple[FieldArray, np.ndarray]: - func = decode_jit(self.field, self.extension_field) + func = bch_decode_jit(self.field, self.extension_field) dec_codeword, N_errors = func(codeword, self.n, int(self.alpha), self.c, self.roots) dec_codeword = dec_codeword.view(self.field) return dec_codeword, N_errors @@ -1192,7 +1192,7 @@ def _generator_poly_from_k( return best_generator_poly, best_roots -class decode_jit(Function): +class bch_decode_jit(Function): """ Performs general BCH and Reed-Solomon decoding. @@ -1204,6 +1204,18 @@ def __init__(self, field: Type[FieldArray], extension_field: Type[FieldArray]): super().__init__(field) self.extension_field = extension_field + @property + def key_1(self): + # Make the key in the cache lookup table specific to both the base field and extension field + return ( + self.field.characteristic, + self.field.degree, + int(self.field.irreducible_poly), + self.extension_field.characteristic, + self.extension_field.degree, + int(self.extension_field.irreducible_poly), + ) + def __call__(self, codeword, design_n, alpha, c, roots): if self.extension_field.ufunc_mode != "python-calculate": output = self.jit(codeword.astype(np.int64), design_n, alpha, c, roots.astype(np.int64)) diff --git a/src/galois/_codes/_reed_solomon.py b/src/galois/_codes/_reed_solomon.py index b54d43d32..fbc2210d2 100644 --- a/src/galois/_codes/_reed_solomon.py +++ b/src/galois/_codes/_reed_solomon.py @@ -13,7 +13,7 @@ from .._math import ilog from .._polys import Poly, matlab_primitive_poly from ..typing import ArrayLike, ElementLike -from ._bch import decode_jit +from ._bch import bch_decode_jit from ._cyclic import _CyclicCode @@ -613,7 +613,7 @@ def decode(self, codeword, output="message", errors=False): return super().decode(codeword, output=output, errors=errors) def _decode_codeword(self, codeword: FieldArray) -> tuple[FieldArray, np.ndarray]: - func = decode_jit(self.field, self.field) + func = reed_solomon_decode_jit(self.field, self.field) dec_codeword, N_errors = func(codeword, self.n, int(self.alpha), self.c, self.roots) dec_codeword = dec_codeword.view(self.field) return dec_codeword, N_errors @@ -1033,3 +1033,14 @@ def is_narrow_sense(self) -> bool: @property def is_systematic(self) -> bool: return super().is_systematic + + +class reed_solomon_decode_jit(bch_decode_jit): + """ + Performs general BCH and Reed-Solomon decoding. + + References: + - Lin, S. and Costello, D. Error Control Coding. Section 7.4. + """ + + # NOTE: Making a subclass so that these compiled functions are stored in a new namespace diff --git a/src/galois/_domains/_function.py b/src/galois/_domains/_function.py index a44b0903f..29137591b 100644 --- a/src/galois/_domains/_function.py +++ b/src/galois/_domains/_function.py @@ -56,6 +56,18 @@ def set_globals(self): # Various ufuncs based on implementation and compilation ############################################################################### + @property + def key_1(self): + return (self.field.characteristic, self.field.degree, int(self.field.irreducible_poly)) + + @property + def key_2(self): + if self.field.ufunc_mode == "jit-lookup": + key = (str(self.__class__), self.field.ufunc_mode, int(self.field.primitive_element)) + else: + key = (str(self.__class__), self.field.ufunc_mode) + return key + @property def function(self): """ @@ -72,19 +84,13 @@ def jit(self) -> numba.types.FunctionType: """ assert self.field.ufunc_mode in ["jit-lookup", "jit-calculate"] - key_1 = (self.field.characteristic, self.field.degree, int(self.field.irreducible_poly)) - if self.field.ufunc_mode == "jit-lookup": - key_2 = (str(self.__class__), self.field.ufunc_mode, int(self.field.primitive_element)) - else: - key_2 = (str(self.__class__), self.field.ufunc_mode) - self._CACHE.setdefault(key_1, {}) - - if key_2 not in self._CACHE[key_1]: + self._CACHE.setdefault(self.key_1, {}) + if self.key_2 not in self._CACHE[self.key_1]: self.set_globals() # Set the globals once before JIT compiling the function func = numba.jit(self._SIGNATURE.signature, parallel=self._PARALLEL, nopython=True)(self.implementation) - self._CACHE[key_1][key_2] = func + self._CACHE[self.key_1][self.key_2] = func - return self._CACHE[key_1][key_2] + return self._CACHE[self.key_1][self.key_2] @property def python(self) -> Callable: diff --git a/tests/codes/test_bch.py b/tests/codes/test_bch.py index cbb726283..1d8f378bd 100644 --- a/tests/codes/test_bch.py +++ b/tests/codes/test_bch.py @@ -378,3 +378,17 @@ def test_bch_valid_codes_511(): code = random.choice(codes) bch = galois.BCH(code[0], code[1]) assert (bch.n, bch.k, bch.t) == code + + +def test_bug_483(): + """ + See https://github.com/mhostetter/galois/issues/483. + """ + bch_1 = galois.BCH(15, 11) + verify_decode(bch_1, 1) + + bch_2 = galois.BCH(7, 4) + verify_decode(bch_2, 1) + + bch_3 = galois.BCH(31, 26) + verify_decode(bch_3, 1)