From 2ca4d16879a6339e0c19f05a0d3c405f1889470c Mon Sep 17 00:00:00 2001 From: mhostetter Date: Tue, 3 May 2022 15:13:11 -0400 Subject: [PATCH] Rework FEC code JIT function structure --- galois/_codes/_bch.py | 201 ++++----- galois/_codes/_reed_solomon.py | 268 ++++++------ galois/_domains/_function.py | 58 +++ galois/_lfsr.py | 734 +++++++++++++++++---------------- 4 files changed, 646 insertions(+), 615 deletions(-) diff --git a/galois/_codes/_bch.py b/galois/_codes/_bch.py index 2d7597297..6dd0b7691 100644 --- a/galois/_codes/_bch.py +++ b/galois/_codes/_bch.py @@ -4,15 +4,16 @@ from __future__ import annotations import math -from typing import Tuple, List, Optional, Union, Type, Any, overload +from typing import Tuple, List, Optional, Union, Type, overload from typing_extensions import Literal import numba from numba import int64 import numpy as np -from .. import _lfsr +from .._domains._function import JITFunction from .._fields import Field, FieldArray, GF2 # pylint: disable=unused-import +from .._lfsr import berlekamp_massey_jit from .._overrides import set_module from .._polys import Poly, matlab_primitive_poly from .._prime import factors @@ -764,7 +765,6 @@ def decode(self, codeword, errors=False): raise ValueError(f"For a non-systematic code, argument `codeword` must be a 1-D or 2-D array with last dimension equal to {self.n}, not shape {codeword.shape}.") codeword_1d = codeword.ndim == 1 - dtype = codeword.dtype ns = codeword.shape[-1] # The number of input codeword bits (could be less than self.n for shortened codes) ks = self.k - (self.n - ns) # The equivalent number of input message bits (could be less than self.k for shortened codes) @@ -774,21 +774,14 @@ def decode(self, codeword, errors=False): # Compute the syndrome by matrix multiplying with the parity-check matrix syndrome = codeword.view(self.field) @ self.H[:,-ns:].T - if self.field.ufunc_mode != "python-calculate": - codeword_ = codeword.astype(np.int64) - syndrome_ = syndrome.astype(np.int64) - y = function("decode", self.field)(codeword_, syndrome_, self.t, int(self.field.primitive_element)) - else: - codeword_ = codeword.view(np.ndarray) - syndrome_ = syndrome.view(np.ndarray) - y = function("decode", self.field)(codeword_, syndrome_, self.t, int(self.field.primitive_element)) + # Invoke the JIT compiled function + dec_codeword, N_errors = decode_jit.call(self.field, codeword, syndrome, self.t, int(self.field.primitive_element)) if self.systematic: - message = y[:, 0:ks] + message = dec_codeword[:, 0:ks] else: - message, _ = GF2._poly_divmod(y[:, 0:ns].view(GF2), self.generator_poly.coeffs) - message = message.astype(dtype).view(type(codeword)) - N_errors = y[:, -1] + message, _ = GF2._poly_divmod(dec_codeword[:, 0:ns].view(GF2), self.generator_poly.coeffs) + message = message.view(type(codeword)) # TODO: Remove this if codeword_1d: message, N_errors = message[0,:], N_errors[0] @@ -977,111 +970,91 @@ def is_narrow_sense(self) -> bool: return self._is_narrow_sense -############################################################################### -# JIT functions -############################################################################### - -POLY_ROOTS: Any -BERLEKAMP_MASSEY: Any - - -def function(name: str, field: Type[FieldArray]): +class decode_jit(JITFunction): """ - Returns a function implemented over the given field and ufunc mode. - """ - if field.ufunc_mode != "python-calculate": - return function_jit(name, field) - else: - return function_python(name, field) - - -def function_jit(name: str, field: Type[FieldArray]): - """ - Returns a JIT-compiled function implemented over the given field. - """ - key = (name, field.characteristic, field.degree, int(field.irreducible_poly), int(field.primitive_element)) - if key not in function_jit.cache: - # Set the globals once before JIT compiling the function - eval(f"set_{name}_globals")(field) - sig = eval(f"{name.upper()}_SIG") - function_jit.cache[key] = numba.jit(sig.signature, nopython=True)(eval(f"{name}_jit")) - - return function_jit.cache[key] - -function_jit.cache = {} + Performs BCH decoding. - -def function_python(name: str, field: Type[FieldArray]): - """ - Returns a pure-Python function. + References + ---------- + * Lin, S. and Costello, D. Error Control Coding. Section 7.4. """ - # Set the globals each time before invoking the pure-Python ufunc - eval(f"set_{name}_globals")(field) - return eval(f"{name}_jit") + _CACHE = {} + @classmethod + def call(cls, field, codeword, syndrome, t, primitive_element): + if field.ufunc_mode != "python-calculate": + codeword_ = codeword.astype(np.int64) + syndrome_ = syndrome.astype(np.int64) + y = cls.jit(field)(codeword_, syndrome_, t, primitive_element) + else: + codeword_ = codeword.view(np.ndarray) + syndrome_ = syndrome.view(np.ndarray) + y = cls.python(field)(codeword_, syndrome_, t, primitive_element) -def set_decode_globals(field: Type[FieldArray]): - global POLY_ROOTS, BERLEKAMP_MASSEY - POLY_ROOTS = field._function("poly_roots") - BERLEKAMP_MASSEY = _lfsr.function("berlekamp_massey", field) + dec_codeword, N_errors = y[:,0:-1], y[:,-1] + dec_codeword = dec_codeword.astype(codeword.dtype) + dec_codeword = dec_codeword.view(field) + return dec_codeword, N_errors -DECODE_SIG = numba.types.FunctionType(int64[:,:](int64[:,:], int64[:,:], int64, int64)) + @classmethod + def set_globals(cls, field: Type[FieldArray]): + # pylint: disable=global-variable-undefined + global POLY_ROOTS, BERLEKAMP_MASSEY + POLY_ROOTS = field._function("poly_roots") + BERLEKAMP_MASSEY = berlekamp_massey_jit.function(field) + _SIGNATURE = numba.types.FunctionType(int64[:,:](int64[:,:], int64[:,:], int64, int64)) -def decode_jit(codeword, syndrome, t, primitive_element): # pragma: no cover - """ - References - ---------- - * Lin, S. and Costello, D. Error Control Coding. Section 7.4. - """ - dtype = codeword.dtype - N = codeword.shape[0] # The number of codewords - n = codeword.shape[1] # The codeword size (could be less than the design n for shortened codes) - - # The last column of the returned decoded codeword is the number of corrected errors - dec_codeword = np.zeros((N, n + 1), dtype=dtype) - dec_codeword[:, 0:n] = codeword[:,:] - - for i in range(N): - if not np.all(syndrome[i,:] == 0): - # The syndrome vector is S = [S0, S1, ..., S2t-1] - - # The error pattern is defined as the polynomial e(x) = e_j1*x^j1 + e_j2*x^j2 + ... for j1 to jv, - # implying there are v errors. And δi = e_ji is the i-th error value and βi = α^ji is the i-th error-locator - # value and ji is the error location. - - # The error-locator polynomial σ(x) = (1 - β1*x)(1 - β2*x)...(1 - βv*x) where βi are the inverse of the roots - # of σ(x). - - # Compute the error-locator polynomial's v-reversal σ(x^-v), since the syndrome is passed in backwards - # TODO: Re-evaluate these equations since changing BMA to return characteristic polynomial, not feedback polynomial - sigma_rev = BERLEKAMP_MASSEY(syndrome[i,::-1])[::-1] - v = sigma_rev.size - 1 # The number of errors - - if v > t: - dec_codeword[i, -1] = -1 - continue - - # Compute βi, the roots of σ(x^-v) which are the inverse roots of σ(x) - degrees = np.arange(sigma_rev.size - 1, -1, -1) - results = POLY_ROOTS(degrees, sigma_rev, primitive_element) - beta = results[0,:] # The roots of σ(x^-v) - error_locations = results[1,:] # The roots as powers of the primitive element α - - if np.any(error_locations > n - 1): - # Indicates there are "errors" in the zero-ed portion of a shortened code, which indicates there are actually - # more errors than alleged. Return failure to decode. - dec_codeword[i, -1] = -1 - continue - - if beta.size != v: - dec_codeword[i, -1] = -1 - continue - - for j in range(v): - # δi can only be 1 - dec_codeword[i, n - 1 - error_locations[j]] ^= 1 - dec_codeword[i, -1] = v # The number of corrected errors - - return dec_codeword + @staticmethod + def implementation(codeword, syndrome, t, primitive_element): # pragma: no cover + dtype = codeword.dtype + N = codeword.shape[0] # The number of codewords + n = codeword.shape[1] # The codeword size (could be less than the design n for shortened codes) + + # The last column of the returned decoded codeword is the number of corrected errors + dec_codeword = np.zeros((N, n + 1), dtype=dtype) + dec_codeword[:, 0:n] = codeword[:,:] + + for i in range(N): + if not np.all(syndrome[i,:] == 0): + # The syndrome vector is S = [S0, S1, ..., S2t-1] + + # The error pattern is defined as the polynomial e(x) = e_j1*x^j1 + e_j2*x^j2 + ... for j1 to jv, + # implying there are v errors. And δi = e_ji is the i-th error value and βi = α^ji is the i-th error-locator + # value and ji is the error location. + + # The error-locator polynomial σ(x) = (1 - β1*x)(1 - β2*x)...(1 - βv*x) where βi are the inverse of the roots + # of σ(x). + + # Compute the error-locator polynomial's v-reversal σ(x^-v), since the syndrome is passed in backwards + # TODO: Re-evaluate these equations since changing BMA to return characteristic polynomial, not feedback polynomial + sigma_rev = BERLEKAMP_MASSEY(syndrome[i,::-1])[::-1] + v = sigma_rev.size - 1 # The number of errors + + if v > t: + dec_codeword[i, -1] = -1 + continue + + # Compute βi, the roots of σ(x^-v) which are the inverse roots of σ(x) + degrees = np.arange(sigma_rev.size - 1, -1, -1) + results = POLY_ROOTS(degrees, sigma_rev, primitive_element) + beta = results[0,:] # The roots of σ(x^-v) + error_locations = results[1,:] # The roots as powers of the primitive element α + + if np.any(error_locations > n - 1): + # Indicates there are "errors" in the zero-ed portion of a shortened code, which indicates there are actually + # more errors than alleged. Return failure to decode. + dec_codeword[i, -1] = -1 + continue + + if beta.size != v: + dec_codeword[i, -1] = -1 + continue + + for j in range(v): + # δi can only be 1 + dec_codeword[i, n - 1 - error_locations[j]] ^= 1 + dec_codeword[i, -1] = v # The number of corrected errors + + return dec_codeword diff --git a/galois/_codes/_reed_solomon.py b/galois/_codes/_reed_solomon.py index 303ff377d..1fc64a862 100644 --- a/galois/_codes/_reed_solomon.py +++ b/galois/_codes/_reed_solomon.py @@ -3,15 +3,16 @@ """ from __future__ import annotations -from typing import Tuple, Optional, Union, Type, Any, overload +from typing import Tuple, Optional, Union, Type, overload from typing_extensions import Literal import numba from numba import int64 import numpy as np -from .. import _lfsr +from .._domains._function import JITFunction from .._fields import Field, FieldArray +from .._lfsr import berlekamp_massey_jit from .._overrides import set_module from .._polys import Poly, matlab_primitive_poly from .._prime import factors @@ -656,7 +657,6 @@ def decode(self, codeword, errors=False): raise ValueError(f"For a non-systematic code, argument `codeword` must be a 1-D or 2-D array with last dimension equal to {self.n}, not shape {codeword.shape}.") codeword_1d = codeword.ndim == 1 - dtype = codeword.dtype ns = codeword.shape[-1] # The number of input codeword symbols (could be less than self.n for shortened codes) ks = self.k - (self.n - ns) # The equivalent number of input message symbols (could be less than self.k for shortened codes) @@ -666,21 +666,14 @@ def decode(self, codeword, errors=False): # Compute the syndrome by matrix multiplying with the parity-check matrix syndrome = codeword.view(self.field) @ self.H[:,-ns:].T - if self.field.ufunc_mode != "python-calculate": - codeword_ = codeword.astype(np.int64) - syndrome_ = syndrome.astype(np.int64) - y = function("decode", self.field)(codeword_, syndrome_, self.c, self.t, int(self.field.primitive_element)) - else: - codeword_ = codeword.view(np.ndarray) - syndrome_ = syndrome.view(np.ndarray) - y = function("decode", self.field)(codeword_, syndrome_, self.c, self.t, int(self.field.primitive_element)) + # Invoke the JIT compiled function + dec_codeword, N_errors = decode_jit.call(self.field, codeword, syndrome, self.c, self.t, int(self.field.primitive_element)) if self.systematic: - message = y[:, 0:ks] + message = dec_codeword[:, 0:ks] else: - message, _ = self.field._poly_divmod(y[:, 0:ns].view(self.field), self.generator_poly.coeffs) - message = message.astype(dtype).view(type(codeword)) - N_errors = y[:, -1] + message, _ = self.field._poly_divmod(dec_codeword[:, 0:ns].view(self.field), self.generator_poly.coeffs) + message = message.view(type(codeword)) # TODO: Remove this if codeword_1d: message, N_errors = message[0,:], N_errors[0] @@ -877,145 +870,118 @@ def is_narrow_sense(self) -> bool: return self._is_narrow_sense -############################################################################### -# JIT functions -############################################################################### - -CHARACTERISTIC: int -ORDER: int -SUBTRACT = np.subtract -MULTIPLY = np.multiply -RECIPROCAL = np.reciprocal -POWER = np.power -CONVOLVE = np.convolve -POLY_ROOTS: Any -POLY_EVALUATE: Any -BERLEKAMP_MASSEY: Any - - -def function(name: str, field: Type[FieldArray]): - """ - Returns a function implemented over the given field and ufunc mode. - """ - if field.ufunc_mode != "python-calculate": - return function_jit(name, field) - else: - return function_python(name, field) - - -def function_jit(name: str, field: Type[FieldArray]): - """ - Returns a JIT-compiled function implemented over the given field. - """ - key = (name, field.characteristic, field.degree, int(field.irreducible_poly), int(field.primitive_element)) - if key not in function_jit.cache: - # Set the globals once before JIT compiling the function - eval(f"set_{name}_globals")(field) - sig = eval(f"{name.upper()}_SIG") - function_jit.cache[key] = numba.jit(sig.signature, nopython=True)(eval(f"{name}_jit")) - - return function_jit.cache[key] - -function_jit.cache = {} - - -def function_python(name: str, field: Type[FieldArray]): - """ - Returns a pure-Python function. +class decode_jit(JITFunction): """ - # Set the globals each time before invoking the pure-Python ufunc - eval(f"set_{name}_globals")(field) - return eval(f"{name}_jit") - - -def set_decode_globals(field: Type[FieldArray]): - global CHARACTERISTIC, ORDER, SUBTRACT, MULTIPLY, RECIPROCAL, POWER, CONVOLVE, POLY_ROOTS, POLY_EVALUATE, BERLEKAMP_MASSEY - CHARACTERISTIC = field.characteristic - ORDER = field.order - SUBTRACT = field._ufunc("subtract") - MULTIPLY = field._ufunc("multiply") - RECIPROCAL = field._ufunc("reciprocal") - POWER = field._ufunc("power") - CONVOLVE = field._function("convolve") - POLY_ROOTS = field._function("poly_roots") - POLY_EVALUATE = field._function("poly_evaluate") - BERLEKAMP_MASSEY = _lfsr.function("berlekamp_massey", field) - - -DECODE_SIG = numba.types.FunctionType(int64[:,:](int64[:,:], int64[:,:], int64, int64, int64)) + Performs Reed-Solomon decoding. -def decode_jit(codeword, syndrome, c, t, primitive_element): # pragma: no cover - """ References ---------- * Lin, S. and Costello, D. Error Control Coding. Section 7.4. """ - dtype = codeword.dtype - N = codeword.shape[0] # The number of codewords - n = codeword.shape[1] # The codeword size (could be less than the design n for shortened codes) - design_n = ORDER - 1 # The designed codeword size - - # The last column of the returned decoded codeword is the number of corrected errors - dec_codeword = np.zeros((N, n + 1), dtype=dtype) - dec_codeword[:, 0:n] = codeword[:,:] - - for i in range(N): - if not np.all(syndrome[i,:] == 0): - # The syndrome vector is S = [S0, S1, ..., S2t-1] - - # The error pattern is defined as the polynomial e(x) = e_j1*x^j1 + e_j2*x^j2 + ... for j1 to jv, - # implying there are v errors. And δi = e_ji is the i-th error value and βi = α^ji is the i-th error-locator - # value and ji is the error location. - - # The error-locator polynomial σ(x) = (1 - β1*x)(1 - β2*x)...(1 - βv*x) where βi are the inverse of the roots - # of σ(x). - - # Compute the error-locator polynomial σ(x) - # TODO: Re-evaluate these equations since changing BMA to return characteristic polynomial, not feedback polynomial - sigma = BERLEKAMP_MASSEY(syndrome[i,:])[::-1] - v = sigma.size - 1 # The number of errors, which is the degree of the error-locator polynomial - - if v > t: - dec_codeword[i,-1] = -1 - continue - - # Compute βi^-1, the roots of σ(x) - degrees = np.arange(sigma.size - 1, -1, -1) - results = POLY_ROOTS(degrees, sigma, primitive_element) - beta_inv = results[0,:] # The roots βi^-1 of σ(x) - error_locations_inv = results[1,:] # The roots βi^-1 as powers of the primitive element α - error_locations = -error_locations_inv % design_n # The error locations as degrees of c(x) - - if np.any(error_locations > n - 1): - # Indicates there are "errors" in the zero-ed portion of a shortened code, which indicates there are actually - # more errors than alleged. Return failure to decode. - dec_codeword[i,-1] = -1 - continue - - if beta_inv.size != v: - dec_codeword[i,-1] = -1 - continue - - # Compute σ'(x) - sigma_prime = np.zeros(v, dtype=dtype) - for j in range(v): - degree = v - j - sigma_prime[j] = MULTIPLY(degree % CHARACTERISTIC, sigma[j]) # Scalar multiplication - - # The error-value evaluator polynomial Z0(x) = S0*σ0 + (S1*σ0 + S0*σ1)*x + (S2*σ0 + S1*σ1 + S0*σ2)*x^2 + ... - # with degree v-1 - Z0 = CONVOLVE(sigma[-v:], syndrome[i,0:v][::-1])[-v:] - - # The error value δi = -1 * βi^(1-c) * Z0(βi^-1) / σ'(βi^-1) - for j in range(v): - beta_i = POWER(beta_inv[j], c - 1) - Z0_i = POLY_EVALUATE(Z0, np.array([beta_inv[j]], dtype=dtype))[0] # NOTE: poly_eval() expects a 1-D array of values - sigma_prime_i = POLY_EVALUATE(sigma_prime, np.array([beta_inv[j]], dtype=dtype))[0] # NOTE: poly_eval() expects a 1-D array of values - delta_i = MULTIPLY(beta_i, Z0_i) - delta_i = MULTIPLY(delta_i, RECIPROCAL(sigma_prime_i)) - delta_i = SUBTRACT(0, delta_i) - dec_codeword[i, n - 1 - error_locations[j]] = SUBTRACT(dec_codeword[i, n - 1 - error_locations[j]], delta_i) - - dec_codeword[i,-1] = v # The number of corrected errors - - return dec_codeword + _CACHE = {} + + @classmethod + def call(cls, field, codeword, syndrome, c, t, primitive_element): + if field.ufunc_mode != "python-calculate": + codeword_ = codeword.astype(np.int64) + syndrome_ = syndrome.astype(np.int64) + y = cls.jit(field)(codeword_, syndrome_, c, t, primitive_element) + else: + codeword_ = codeword.view(np.ndarray) + syndrome_ = syndrome.view(np.ndarray) + y = cls.python(field)(codeword_, syndrome_, c, t, primitive_element) + + dec_codeword, N_errors = y[:,0:-1], y[:,-1] + dec_codeword = dec_codeword.astype(codeword.dtype) + dec_codeword = dec_codeword.view(field) + + return dec_codeword, N_errors + + @classmethod + def set_globals(cls, field: Type[FieldArray]): + # pylint: disable=global-variable-undefined + global CHARACTERISTIC, ORDER, SUBTRACT, MULTIPLY, RECIPROCAL, POWER, CONVOLVE, POLY_ROOTS, POLY_EVALUATE, BERLEKAMP_MASSEY + CHARACTERISTIC = field.characteristic + ORDER = field.order + SUBTRACT = field._ufunc("subtract") + MULTIPLY = field._ufunc("multiply") + RECIPROCAL = field._ufunc("reciprocal") + POWER = field._ufunc("power") + CONVOLVE = field._function("convolve") + POLY_ROOTS = field._function("poly_roots") + POLY_EVALUATE = field._function("poly_evaluate") + BERLEKAMP_MASSEY = berlekamp_massey_jit.function(field) + + _SIGNATURE = numba.types.FunctionType(int64[:,:](int64[:,:], int64[:,:], int64, int64, int64)) + + @staticmethod + def implementation(codeword, syndrome, c, t, primitive_element): # pragma: no cover + dtype = codeword.dtype + N = codeword.shape[0] # The number of codewords + n = codeword.shape[1] # The codeword size (could be less than the design n for shortened codes) + design_n = ORDER - 1 # The designed codeword size + + # The last column of the returned decoded codeword is the number of corrected errors + dec_codeword = np.zeros((N, n + 1), dtype=dtype) + dec_codeword[:, 0:n] = codeword[:,:] + + for i in range(N): + if not np.all(syndrome[i,:] == 0): + # The syndrome vector is S = [S0, S1, ..., S2t-1] + + # The error pattern is defined as the polynomial e(x) = e_j1*x^j1 + e_j2*x^j2 + ... for j1 to jv, + # implying there are v errors. And δi = e_ji is the i-th error value and βi = α^ji is the i-th error-locator + # value and ji is the error location. + + # The error-locator polynomial σ(x) = (1 - β1*x)(1 - β2*x)...(1 - βv*x) where βi are the inverse of the roots + # of σ(x). + + # Compute the error-locator polynomial σ(x) + # TODO: Re-evaluate these equations since changing BMA to return characteristic polynomial, not feedback polynomial + sigma = BERLEKAMP_MASSEY(syndrome[i,:])[::-1] + v = sigma.size - 1 # The number of errors, which is the degree of the error-locator polynomial + + if v > t: + dec_codeword[i,-1] = -1 + continue + + # Compute βi^-1, the roots of σ(x) + degrees = np.arange(sigma.size - 1, -1, -1) + results = POLY_ROOTS(degrees, sigma, primitive_element) + beta_inv = results[0,:] # The roots βi^-1 of σ(x) + error_locations_inv = results[1,:] # The roots βi^-1 as powers of the primitive element α + error_locations = -error_locations_inv % design_n # The error locations as degrees of c(x) + + if np.any(error_locations > n - 1): + # Indicates there are "errors" in the zero-ed portion of a shortened code, which indicates there are actually + # more errors than alleged. Return failure to decode. + dec_codeword[i,-1] = -1 + continue + + if beta_inv.size != v: + dec_codeword[i,-1] = -1 + continue + + # Compute σ'(x) + sigma_prime = np.zeros(v, dtype=dtype) + for j in range(v): + degree = v - j + sigma_prime[j] = MULTIPLY(degree % CHARACTERISTIC, sigma[j]) # Scalar multiplication + + # The error-value evaluator polynomial Z0(x) = S0*σ0 + (S1*σ0 + S0*σ1)*x + (S2*σ0 + S1*σ1 + S0*σ2)*x^2 + ... + # with degree v-1 + Z0 = CONVOLVE(sigma[-v:], syndrome[i,0:v][::-1])[-v:] + + # The error value δi = -1 * βi^(1-c) * Z0(βi^-1) / σ'(βi^-1) + for j in range(v): + beta_i = POWER(beta_inv[j], c - 1) + Z0_i = POLY_EVALUATE(Z0, np.array([beta_inv[j]], dtype=dtype))[0] # NOTE: poly_eval() expects a 1-D array of values + sigma_prime_i = POLY_EVALUATE(sigma_prime, np.array([beta_inv[j]], dtype=dtype))[0] # NOTE: poly_eval() expects a 1-D array of values + delta_i = MULTIPLY(beta_i, Z0_i) + delta_i = MULTIPLY(delta_i, RECIPROCAL(sigma_prime_i)) + delta_i = SUBTRACT(0, delta_i) + dec_codeword[i, n - 1 - error_locations[j]] = SUBTRACT(dec_codeword[i, n - 1 - error_locations[j]], delta_i) + + dec_codeword[i,-1] = v # The number of corrected errors + + return dec_codeword diff --git a/galois/_domains/_function.py b/galois/_domains/_function.py index 4523cd4df..34645dce8 100644 --- a/galois/_domains/_function.py +++ b/galois/_domains/_function.py @@ -2,11 +2,13 @@ A module that contains Array mixin classes that override NumPy functions. """ import abc +from typing import Type, Callable import numba from numba import int64 import numpy as np +from ._array import Array from ._ufunc import RingUFuncs, FieldUFuncs ADD = np.add @@ -14,6 +16,62 @@ MULTIPLY = np.multiply +class JITFunction: + """ + Wrapper class for optionally JIT-compiled functions. + """ + _CACHE = {} # A cache of compiled functions. Should be cleared for each derived class. + + call: Callable + """Call the function, invoking either the JIT-compiled or pure-Python version.""" + + @classmethod + def set_globals(cls, field: Type[Array]): + """ + Set the global variables used in `implementation()` before JIT compiling it or before invoking it in pure Python. + """ + # pylint: disable=unused-argument + return + + _SIGNATURE: numba.types.FunctionType + """The function's Numba signature.""" + + implementation: Callable + """The function implementation in Python.""" + + @classmethod + def function(cls, field: Type[Array]): + """ + Returns a JIT-compiled or pure-Python function based on field size. + """ + if field.ufunc_mode != "python-calculate": + return cls.jit(field) + else: + return cls.python(field) + + @classmethod + def jit(cls, field: Type[Array]) -> numba.types.FunctionType: + """ + Returns a JIT-compiled function implemented over the given field. + """ + key = (field.characteristic, field.degree, int(field.irreducible_poly), int(field.primitive_element)) + if key not in cls._CACHE: + # Set the globals once before JIT compiling the function + cls.set_globals(field) + cls._CACHE[key] = numba.jit(cls._SIGNATURE.signature, nopython=True)(cls.implementation) + + return cls._CACHE[key] + + @classmethod + def python(cls, field: Type[Array]) -> Callable: + """ + Returns the pure-Python function implemented over the given field. + """ + # Set the globals each time before invoking the pure-Python function + cls.set_globals(field) + return cls.implementation + + class RingFunctions(RingUFuncs, abc.ABC): """ A mixin base class that overrides NumPy functions to perform ring arithmetic (+, -, *), using *only* explicit diff --git a/galois/_lfsr.py b/galois/_lfsr.py index 8195e48d8..ea3789820 100644 --- a/galois/_lfsr.py +++ b/galois/_lfsr.py @@ -10,6 +10,7 @@ import numpy as np from numba import int64 +from ._domains._function import JITFunction from ._fields import FieldArray from ._overrides import set_module from ._polys import Poly @@ -18,6 +19,10 @@ __all__ = ["FLFSR", "GLFSR", "berlekamp_massey"] +############################################################################### +# LFSR base class +############################################################################### + class _LFSR: r""" A linear-feedback shift register base object. @@ -105,18 +110,12 @@ def step(self, steps: int = 1) -> FieldArray: def _step_forward(self, steps): assert steps > 0 - if self.field.ufunc_mode != "python-calculate": - taps = self.taps.astype(np.int64) - state = self.state.astype(np.int64) - y = function(f"{self._type}_lfsr_step_forward", self.field)(taps, state, steps) - y = y.astype(self.state.dtype) + if self._type == "fibonacci": + y, state = fibonacci_lfsr_step_forward_jit.call(self.field, self.taps, self.state, steps) else: - taps = self.taps.view(np.ndarray) - state = self.state.view(np.ndarray) - y = function(f"{self._type}_lfsr_step_forward", self.field)(taps, state, steps) + y, state = galois_lfsr_step_forward_jit.call(self.field, self.taps, self.state, steps) self._state[:] = state[:] - y = self.field._view(y) if y.size == 1: y = y[0] @@ -128,18 +127,12 @@ def _step_backward(self, steps): if not self.characteristic_poly.coeffs[-1] > 0: raise ValueError(f"Can only step the shift register backwards if the c_0 tap is non-zero, not c(x) = {self.characteristic_poly}.") - if self.field.ufunc_mode != "python-calculate": - taps = self.taps.astype(np.int64) - state = self.state.astype(np.int64) - y = function(f"{self._type}_lfsr_step_backward", self.field)(taps, state, steps) - y = y.astype(self.state.dtype) + if self._type == "fibonacci": + y, state = fibonacci_lfsr_step_backward_jit.call(self.field, self.taps, self.state, steps) else: - taps = self.taps.view(np.ndarray) - state = self.state.view(np.ndarray) - y = function(f"{self._type}_lfsr_step_backward", self.field)(taps, state, steps) + y, state = galois_lfsr_step_backward_jit.call(self.field, self.taps, self.state, steps) self._state[:] = state[:] - y = self.field._view(y) if y.size == 1: y = y[0] @@ -174,6 +167,10 @@ def state(self) -> FieldArray: return self._state.copy() +############################################################################### +# Fibonacci LFSR +############################################################################### + @set_module("galois") class FLFSR(_LFSR): r""" @@ -694,6 +691,162 @@ def state(self) -> "FieldArray": return super().state +class fibonacci_lfsr_step_forward_jit(JITFunction): + """ + Steps the Fibonacci LFSR `steps` forward. + + .. code-block:: text + :caption: Fibonacci LFSR Configuration + + +--------------+<-------------+<-------------+<-------------+ + | ^ ^ ^ | + | | c_n-1 | c_n-2 | c_1 | c_0 + | | T[0] | T[1] | T[n-2] | T[n-1] + | +--------+ | +--------+ | | +--------+ | + +->| S[0] |--+->| S[1] |--+--- ... ---+->| S[n-1] |--+--> y[t] + +--------+ +--------+ +--------+ + y[t+n-1] y[t+n-2] y[t+1] + + Parameters + ---------- + taps + The set of taps T = [c_n-1, c_n-2, ..., c_1, c_0]. + state + The state vector [S_0, S_1, ..., S_n-2, S_n-1]. State will be modified in-place! + steps + The number of output symbols to produce. + feedback + `True` indicates to output the feedback value `y_1[t]` (LRS) and `False` indicates to output the value out of the + shift register `y_2[t]`. + + Returns + ------- + y + The output sequence of size `steps`. + """ + _CACHE = {} + + @classmethod + def call(cls, field, taps, state, steps): + if field.ufunc_mode != "python-calculate": + taps_ = taps.astype(np.int64) + state_ = state.astype(np.int64) + y = cls.jit(field)(taps_, state_, steps) + y = y.astype(state.dtype) + else: + taps_ = taps.view(np.ndarray) + state_ = state.view(np.ndarray) + y = cls.python(field)(taps_, state_, steps) + y = field._view(y) + + return y, state_ + + @staticmethod + def set_globals(field: Type[FieldArray]): + # pylint: disable=global-variable-undefined + global ADD, MULTIPLY + ADD = field._ufunc("add") + MULTIPLY = field._ufunc("multiply") + + _SIGNATURE = numba.types.FunctionType(int64[:](int64[:], int64[:], int64)) + + @staticmethod + def implementation(taps, state, steps): # pragma: no cover + n = taps.size + y = np.zeros(steps, dtype=state.dtype) # The output array + + for i in range(steps): + f = 0 # The feedback value + for j in range(n): + f = ADD(f, MULTIPLY(state[j], taps[j])) + + y[i] = state[-1] # Output is popped off the shift register + state[1:] = state[0:-1] # Shift state rightward + state[0] = f # Insert feedback value at leftmost position + + return y + + +class fibonacci_lfsr_step_backward_jit(JITFunction): + """ + Steps the Fibonacci LFSR `steps` backward. + + .. code-block:: text + :caption: Fibonacci LFSR Configuration + + +--------------+<-------------+<-------------+<-------------+ + | ^ ^ ^ | + | | c_n-1 | c_n-2 | c_1 | c_0 + | | T[0] | T[1] | T[n-2] | T[n-1] + | +--------+ | +--------+ | | +--------+ | + +->| S[0] |--+->| S[1] |--+--- ... ---+->| S[n-1] |--+--> y[t] + +--------+ +--------+ +--------+ + y[t+n-1] y[t+n-2] y[t+1] + + Parameters + ---------- + taps + The set of taps T = [c_n-1, c_n-2, ..., c_1, c_0]. + state + The state vector [S_0, S_1, ..., S_n-2, S_n-1]. State will be modified in-place! + steps + The number of output symbols to produce. + + Returns + ------- + y + The output sequence of size `steps`. + """ + _CACHE = {} + + @classmethod + def call(cls, field, taps, state, steps): + if field.ufunc_mode != "python-calculate": + taps_ = taps.astype(np.int64) + state_ = state.astype(np.int64) + y = cls.jit(field)(taps_, state_, steps) + y = y.astype(state.dtype) + else: + taps_ = taps.view(np.ndarray) + state_ = state.view(np.ndarray) + y = cls.python(field)(taps_, state_, steps) + y = field._view(y) + + return y, state_ + + @staticmethod + def set_globals(field: Type[FieldArray]): + global SUBTRACT, MULTIPLY, DIVIDE + SUBTRACT = field._ufunc("subtract") + MULTIPLY = field._ufunc("multiply") + DIVIDE = field._ufunc("divide") + + _SIGNATURE = numba.types.FunctionType(int64[:](int64[:], int64[:], int64)) + + @staticmethod + def implementation(taps, state, steps): # pragma: no cover + n = taps.size + y = np.zeros(steps, dtype=state.dtype) # The output array + + for i in range(steps): + f = state[0] # The feedback value + state[0:-1] = state[1:] # Shift state leftward + + s = f # The unknown previous state value + for j in range(n - 1): + s = SUBTRACT(s, MULTIPLY(state[j], taps[j])) + s = DIVIDE(s, taps[n - 1]) + + y[i] = s # The previous output was the last value in the shift register + state[-1] = s # Assign recovered state to the leftmost position + + return y + + +############################################################################### +# Galois LFSR +############################################################################### + @set_module("galois") class GLFSR(_LFSR): r""" @@ -1200,6 +1353,158 @@ def state(self) -> "FieldArray": return super().state +class galois_lfsr_step_forward_jit(JITFunction): + """ + Steps the Galois LFSR `steps` forward. + + .. code-block:: text + :caption: Galois LFSR Configuration + + +--------------+<-------------+<-------------+<-------------+ + | | | | | + | c_0 | c_1 | c_2 | c_n-1 | + | T[0] | T[1] | T[2] | T[n-1] | + | +--------+ v +--------+ v v +--------+ | + +->| S[0] |--+->| S[1] |--+--- ... ---+->| S[n-1] |--+--> y[t] + +--------+ +--------+ +--------+ + y[t+1] + + Parameters + ---------- + taps + The set of taps T = [c_0, c_1, ..., c_n-2, c_n-2]. + state + The state vector [S_0, S_1, ..., S_n-2, S_n-1]. State will be modified in-place! + steps + The number of output symbols to produce. + + Returns + ------- + y + The output sequence of size `steps`. + """ + _CACHE = {} + + @classmethod + def call(cls, field, taps, state, steps): + if field.ufunc_mode != "python-calculate": + taps_ = taps.astype(np.int64) + state_ = state.astype(np.int64) + y = cls.jit(field)(taps_, state_, steps) + y = y.astype(state.dtype) + else: + taps_ = taps.view(np.ndarray) + state_ = state.view(np.ndarray) + y = cls.python(field)(taps_, state_, steps) + y = field._view(y) + + return y, state_ + + @staticmethod + def set_globals(field: Type[FieldArray]): + global ADD, MULTIPLY + ADD = field._ufunc("add") + MULTIPLY = field._ufunc("multiply") + + _SIGNATURE = numba.types.FunctionType(int64[:](int64[:], int64[:], int64)) + + @staticmethod + def implementation(taps, state, steps): # pragma: no cover + n = taps.size + y = np.zeros(steps, dtype=state.dtype) # The output array + + for i in range(steps): + f = state[n - 1] # The feedback value + y[i] = f # The output + + if f == 0: + state[1:] = state[0:-1] + state[0] = 0 + else: + for j in range(n - 1, 0, -1): + state[j] = ADD(state[j - 1], MULTIPLY(f, taps[j])) + state[0] = MULTIPLY(f, taps[0]) + + return y + + +class galois_lfsr_step_backward_jit(JITFunction): + """ + Steps the Galois LFSR `steps` backward. + + .. code-block:: text + :caption: Galois LFSR Configuration + + +--------------+<-------------+<-------------+<-------------+ + | | | | | + | c_0 | c_1 | c_2 | c_n-1 | + | T[0] | T[1] | T[2] | T[n-1] | + | +--------+ v +--------+ v v +--------+ | + +->| S[0] |--+->| S[1] |--+--- ... ---+->| S[n-1] |--+--> y[t] + +--------+ +--------+ +--------+ + y[t+1] + + Parameters + ---------- + taps + The set of taps T = [c_0, c_1, ..., c_n-2, c_n-2]. + state + The state vector [S_0, S_1, ..., S_n-2, S_n-1]. State will be modified in-place! + steps + The number of output symbols to produce. + + Returns + ------- + y + The output sequence of size `steps`. + """ + _CACHE = {} + + @classmethod + def call(cls, field, taps, state, steps): + if field.ufunc_mode != "python-calculate": + taps_ = taps.astype(np.int64) + state_ = state.astype(np.int64) + y = cls.jit(field)(taps_, state_, steps) + y = y.astype(state.dtype) + else: + taps_ = taps.view(np.ndarray) + state_ = state.view(np.ndarray) + y = cls.python(field)(taps_, state_, steps) + y = field._view(y) + + return y, state_ + + @staticmethod + def set_globals(field: Type[FieldArray]): + global SUBTRACT, MULTIPLY, DIVIDE + SUBTRACT = field._ufunc("subtract") + MULTIPLY = field._ufunc("multiply") + DIVIDE = field._ufunc("divide") + + _SIGNATURE = numba.types.FunctionType(int64[:](int64[:], int64[:], int64)) + + @staticmethod + def implementation(taps, state, steps): # pragma: no cover + n = taps.size + y = np.zeros(steps, dtype=state.dtype) # The output array + + for i in range(steps): + f = DIVIDE(state[0], taps[0]) # The feedback value + + for j in range(0, n - 1): + state[j] = SUBTRACT(state[j + 1], MULTIPLY(f, taps[j + 1])) + + state[n - 1] = f + y[i] = f # The output + + return y + + +############################################################################### +# Berlekamp-Massey algorithm +############################################################################### + @overload def berlekamp_massey(sequence: "FieldArray", output: Literal["minimal"] = "minimal") -> Poly: ... @@ -1299,17 +1604,7 @@ def berlekamp_massey(sequence, output="minimal"): raise ValueError(f"Argument `output` must be in ['minimal', 'fibonacci', 'galois'], not {output!r}.") field = type(sequence) - dtype = sequence.dtype - - if field.ufunc_mode != "python-calculate": - sequence = sequence.astype(np.int64) - coeffs = function("berlekamp_massey", field)(sequence) - coeffs = coeffs.astype(dtype) - else: - sequence = sequence.view(np.ndarray) - coeffs = function("berlekamp_massey", field)(sequence) - coeffs = field._view(coeffs) - + coeffs = berlekamp_massey_jit.call(field, sequence) characteristic_poly = Poly(coeffs, field=field) if output == "minimal": @@ -1326,328 +1621,67 @@ def berlekamp_massey(sequence, output="minimal"): return fibonacci_lfsr.to_galois_lfsr() -############################################################################### -# JIT functions -############################################################################### - -ADD = np.add -SUBTRACT = np.subtract -MULTIPLY = np.multiply -DIVIDE = np.divide -RECIPROCAL = np.reciprocal - - -def function(name: str, field: Type[FieldArray]): - """ - Returns a function implemented over the given field and ufunc mode. - """ - if field.ufunc_mode != "python-calculate": - return function_jit(name, field) - else: - return function_python(name, field) - - -def function_jit(name: str, field: Type[FieldArray]): - """ - Returns a JIT-compiled function implemented over the given field. - """ - key = (name, field.characteristic, field.degree, int(field.irreducible_poly), int(field.primitive_element)) - if key not in function_jit.cache: - # Set the globals once before JIT compiling the function - eval(f"set_{name}_globals")(field) - sig = eval(f"{name.upper()}_SIG") - function_jit.cache[key] = numba.jit(sig.signature, nopython=True)(eval(f"{name}_jit")) - - return function_jit.cache[key] - -function_jit.cache = {} - - -def function_python(name: str, field: Type[FieldArray]): - """ - Returns a pure-Python function. - """ - # Set the globals each time before invoking the pure-Python ufunc - eval(f"set_{name}_globals")(field) - return eval(f"{name}_jit") - - -############################################################################### -# Fibonacci LFSR JIT functions -############################################################################### - -def set_fibonacci_lfsr_step_forward_globals(field: Type[FieldArray]): - global ADD, MULTIPLY - ADD = field._ufunc("add") - MULTIPLY = field._ufunc("multiply") - - -FIBONACCI_LFSR_STEP_FORWARD_SIG = numba.types.FunctionType(int64[:](int64[:], int64[:], int64)) - -def fibonacci_lfsr_step_forward_jit(taps, state, steps): # pragma: no cover - """ - Steps the Fibonacci LFSR `steps` forward. - - .. code-block:: text - :caption: Fibonacci LFSR Configuration - - +--------------+<-------------+<-------------+<-------------+ - | ^ ^ ^ | - | | c_n-1 | c_n-2 | c_1 | c_0 - | | T[0] | T[1] | T[n-2] | T[n-1] - | +--------+ | +--------+ | | +--------+ | - +->| S[0] |--+->| S[1] |--+--- ... ---+->| S[n-1] |--+--> y[t] - +--------+ +--------+ +--------+ - y[t+n-1] y[t+n-2] y[t+1] - - Parameters - ---------- - taps - The set of taps T = [c_n-1, c_n-2, ..., c_1, c_0]. - state - The state vector [S_0, S_1, ..., S_n-2, S_n-1]. State will be modified in-place! - steps - The number of output symbols to produce. - feedback - `True` indicates to output the feedback value `y_1[t]` (LRS) and `False` indicates to output the value out of the - shift register `y_2[t]`. - - Returns - ------- - y - The output sequence of size `steps`. +class berlekamp_massey_jit(JITFunction): """ - n = taps.size - y = np.zeros(steps, dtype=state.dtype) # The output array - - for i in range(steps): - f = 0 # The feedback value - for j in range(n): - f = ADD(f, MULTIPLY(state[j], taps[j])) - - y[i] = state[-1] # Output is popped off the shift register - state[1:] = state[0:-1] # Shift state rightward - state[0] = f # Insert feedback value at leftmost position - - return y - - -def set_fibonacci_lfsr_step_backward_globals(field: Type[FieldArray]): - global SUBTRACT, MULTIPLY, DIVIDE - SUBTRACT = field._ufunc("subtract") - MULTIPLY = field._ufunc("multiply") - DIVIDE = field._ufunc("divide") - - -FIBONACCI_LFSR_STEP_BACKWARD_SIG = numba.types.FunctionType(int64[:](int64[:], int64[:], int64)) - -def fibonacci_lfsr_step_backward_jit(taps, state, steps): # pragma: no cover + Finds the minimal polynomial c(x) of the input sequence. """ - Steps the Fibonacci LFSR `steps` backward. - - .. code-block:: text - :caption: Fibonacci LFSR Configuration - - +--------------+<-------------+<-------------+<-------------+ - | ^ ^ ^ | - | | c_n-1 | c_n-2 | c_1 | c_0 - | | T[0] | T[1] | T[n-2] | T[n-1] - | +--------+ | +--------+ | | +--------+ | - +->| S[0] |--+->| S[1] |--+--- ... ---+->| S[n-1] |--+--> y[t] - +--------+ +--------+ +--------+ - y[t+n-1] y[t+n-2] y[t+1] + _CACHE = {} - Parameters - ---------- - taps - The set of taps T = [c_n-1, c_n-2, ..., c_1, c_0]. - state - The state vector [S_0, S_1, ..., S_n-2, S_n-1]. State will be modified in-place! - steps - The number of output symbols to produce. - - Returns - ------- - y - The output sequence of size `steps`. - """ - n = taps.size - y = np.zeros(steps, dtype=state.dtype) # The output array - - for i in range(steps): - f = state[0] # The feedback value - state[0:-1] = state[1:] # Shift state leftward - - s = f # The unknown previous state value - for j in range(n - 1): - s = SUBTRACT(s, MULTIPLY(state[j], taps[j])) - s = DIVIDE(s, taps[n - 1]) - - y[i] = s # The previous output was the last value in the shift register - state[-1] = s # Assign recovered state to the leftmost position - - return y - - -############################################################################### -# Galois LFSR JIT functions -############################################################################### - -def set_galois_lfsr_step_forward_globals(field: Type[FieldArray]): - global ADD, MULTIPLY - ADD = field._ufunc("add") - MULTIPLY = field._ufunc("multiply") - - -GALOIS_LFSR_STEP_FORWARD_SIG = numba.types.FunctionType(int64[:](int64[:], int64[:], int64)) - -def galois_lfsr_step_forward_jit(taps, state, steps): # pragma: no cover - """ - Steps the Galois LFSR `steps` forward. - - .. code-block:: text - :caption: Galois LFSR Configuration - - +--------------+<-------------+<-------------+<-------------+ - | | | | | - | c_0 | c_1 | c_2 | c_n-1 | - | T[0] | T[1] | T[2] | T[n-1] | - | +--------+ v +--------+ v v +--------+ | - +->| S[0] |--+->| S[1] |--+--- ... ---+->| S[n-1] |--+--> y[t] - +--------+ +--------+ +--------+ - y[t+1] - - Parameters - ---------- - taps - The set of taps T = [c_0, c_1, ..., c_n-2, c_n-2]. - state - The state vector [S_0, S_1, ..., S_n-2, S_n-1]. State will be modified in-place! - steps - The number of output symbols to produce. - - Returns - ------- - y - The output sequence of size `steps`. - """ - n = taps.size - y = np.zeros(steps, dtype=state.dtype) # The output array - - for i in range(steps): - f = state[n - 1] # The feedback value - y[i] = f # The output - - if f == 0: - state[1:] = state[0:-1] - state[0] = 0 - else: - for j in range(n - 1, 0, -1): - state[j] = ADD(state[j - 1], MULTIPLY(f, taps[j])) - state[0] = MULTIPLY(f, taps[0]) - - return y - - -def set_galois_lfsr_step_backward_globals(field: Type[FieldArray]): - global SUBTRACT, MULTIPLY, DIVIDE - SUBTRACT = field._ufunc("subtract") - MULTIPLY = field._ufunc("multiply") - DIVIDE = field._ufunc("divide") - - -GALOIS_LFSR_STEP_BACKWARD_SIG = numba.types.FunctionType(int64[:](int64[:], int64[:], int64)) - -def galois_lfsr_step_backward_jit(taps, state, steps): # pragma: no cover - """ - Steps the Galois LFSR `steps` backward. - - .. code-block:: text - :caption: Galois LFSR Configuration - - +--------------+<-------------+<-------------+<-------------+ - | | | | | - | c_0 | c_1 | c_2 | c_n-1 | - | T[0] | T[1] | T[2] | T[n-1] | - | +--------+ v +--------+ v v +--------+ | - +->| S[0] |--+->| S[1] |--+--- ... ---+->| S[n-1] |--+--> y[t] - +--------+ +--------+ +--------+ - y[t+1] - - Parameters - ---------- - taps - The set of taps T = [c_0, c_1, ..., c_n-2, c_n-2]. - state - The state vector [S_0, S_1, ..., S_n-2, S_n-1]. State will be modified in-place! - steps - The number of output symbols to produce. - - Returns - ------- - y - The output sequence of size `steps`. - """ - n = taps.size - y = np.zeros(steps, dtype=state.dtype) # The output array - - for i in range(steps): - f = DIVIDE(state[0], taps[0]) # The feedback value - - for j in range(0, n - 1): - state[j] = SUBTRACT(state[j + 1], MULTIPLY(f, taps[j + 1])) - - state[n - 1] = f - y[i] = f # The output - - return y - - -############################################################################### -# Berlekamp-Massey JIT functions -############################################################################### - -def set_berlekamp_massey_globals(field: Type[FieldArray]): - global ADD, SUBTRACT, MULTIPLY, RECIPROCAL - ADD = field._ufunc("add") - SUBTRACT = field._ufunc("subtract") - MULTIPLY = field._ufunc("multiply") - RECIPROCAL = field._ufunc("reciprocal") - - -BERLEKAMP_MASSEY_SIG = numba.types.FunctionType(int64[:](int64[:])) - -def berlekamp_massey_jit(sequence): # pragma: no cover - N = sequence.size - s = sequence - c = np.zeros(N, dtype=sequence.dtype) - b = np.zeros(N, dtype=sequence.dtype) - c[0] = 1 # The polynomial c(x) = 1 - b[0] = 1 # The polynomial b(x) = 1 - L = 0 - m = 1 - bb = 1 - - for n in range(0, N): - d = 0 - for i in range(0, L + 1): - d = ADD(d, MULTIPLY(s[n - i], c[i])) - - if d == 0: - m += 1 - elif 2*L <= n: - t = c.copy() - d_bb = MULTIPLY(d, RECIPROCAL(bb)) - for i in range(m, N): - c[i] = SUBTRACT(c[i], MULTIPLY(d_bb, b[i - m])) - L = n + 1 - L - b = t.copy() - bb = d - m = 1 + @classmethod + def call(cls, field, sequence): + if field.ufunc_mode != "python-calculate": + sequence = sequence.astype(np.int64) + coeffs = cls.jit(field)(sequence) + coeffs = coeffs.astype(sequence.dtype) else: - d_bb = MULTIPLY(d, RECIPROCAL(bb)) - for i in range(m, N): - c[i] = SUBTRACT(c[i], MULTIPLY(d_bb, b[i - m])) - m += 1 - - return c[0:L + 1] + sequence = sequence.view(np.ndarray) + coeffs = cls.python(field)(sequence) + coeffs = field._view(coeffs) + + return coeffs + + @staticmethod + def set_globals(field: Type[FieldArray]): + global ADD, SUBTRACT, MULTIPLY, RECIPROCAL + ADD = field._ufunc("add") + SUBTRACT = field._ufunc("subtract") + MULTIPLY = field._ufunc("multiply") + RECIPROCAL = field._ufunc("reciprocal") + + _SIGNATURE = numba.types.FunctionType(int64[:](int64[:])) + + @staticmethod + def implementation(sequence): # pragma: no cover + N = sequence.size + s = sequence + c = np.zeros(N, dtype=sequence.dtype) + b = np.zeros(N, dtype=sequence.dtype) + c[0] = 1 # The polynomial c(x) = 1 + b[0] = 1 # The polynomial b(x) = 1 + L = 0 + m = 1 + bb = 1 + + for n in range(0, N): + d = 0 + for i in range(0, L + 1): + d = ADD(d, MULTIPLY(s[n - i], c[i])) + + if d == 0: + m += 1 + elif 2*L <= n: + t = c.copy() + d_bb = MULTIPLY(d, RECIPROCAL(bb)) + for i in range(m, N): + c[i] = SUBTRACT(c[i], MULTIPLY(d_bb, b[i - m])) + L = n + 1 - L + b = t.copy() + bb = d + m = 1 + else: + d_bb = MULTIPLY(d, RECIPROCAL(bb)) + for i in range(m, N): + c[i] = SUBTRACT(c[i], MULTIPLY(d_bb, b[i - m])) + m += 1 + + return c[0:L + 1]