Skip to content

Commit

Permalink
Rework FEC code JIT function structure
Browse files Browse the repository at this point in the history
  • Loading branch information
mhostetter committed May 9, 2022
1 parent 220a669 commit 2ca4d16
Show file tree
Hide file tree
Showing 4 changed files with 646 additions and 615 deletions.
201 changes: 87 additions & 114 deletions galois/_codes/_bch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
from __future__ import annotations

import math
from typing import Tuple, List, Optional, Union, Type, Any, overload
from typing import Tuple, List, Optional, Union, Type, overload
from typing_extensions import Literal

import numba
from numba import int64
import numpy as np

from .. import _lfsr
from .._domains._function import JITFunction
from .._fields import Field, FieldArray, GF2 # pylint: disable=unused-import
from .._lfsr import berlekamp_massey_jit
from .._overrides import set_module
from .._polys import Poly, matlab_primitive_poly
from .._prime import factors
Expand Down Expand Up @@ -764,7 +765,6 @@ def decode(self, codeword, errors=False):
raise ValueError(f"For a non-systematic code, argument `codeword` must be a 1-D or 2-D array with last dimension equal to {self.n}, not shape {codeword.shape}.")

codeword_1d = codeword.ndim == 1
dtype = codeword.dtype
ns = codeword.shape[-1] # The number of input codeword bits (could be less than self.n for shortened codes)
ks = self.k - (self.n - ns) # The equivalent number of input message bits (could be less than self.k for shortened codes)

Expand All @@ -774,21 +774,14 @@ def decode(self, codeword, errors=False):
# Compute the syndrome by matrix multiplying with the parity-check matrix
syndrome = codeword.view(self.field) @ self.H[:,-ns:].T

if self.field.ufunc_mode != "python-calculate":
codeword_ = codeword.astype(np.int64)
syndrome_ = syndrome.astype(np.int64)
y = function("decode", self.field)(codeword_, syndrome_, self.t, int(self.field.primitive_element))
else:
codeword_ = codeword.view(np.ndarray)
syndrome_ = syndrome.view(np.ndarray)
y = function("decode", self.field)(codeword_, syndrome_, self.t, int(self.field.primitive_element))
# Invoke the JIT compiled function
dec_codeword, N_errors = decode_jit.call(self.field, codeword, syndrome, self.t, int(self.field.primitive_element))

if self.systematic:
message = y[:, 0:ks]
message = dec_codeword[:, 0:ks]
else:
message, _ = GF2._poly_divmod(y[:, 0:ns].view(GF2), self.generator_poly.coeffs)
message = message.astype(dtype).view(type(codeword))
N_errors = y[:, -1]
message, _ = GF2._poly_divmod(dec_codeword[:, 0:ns].view(GF2), self.generator_poly.coeffs)
message = message.view(type(codeword)) # TODO: Remove this

if codeword_1d:
message, N_errors = message[0,:], N_errors[0]
Expand Down Expand Up @@ -977,111 +970,91 @@ def is_narrow_sense(self) -> bool:
return self._is_narrow_sense


###############################################################################
# JIT functions
###############################################################################

POLY_ROOTS: Any
BERLEKAMP_MASSEY: Any


def function(name: str, field: Type[FieldArray]):
class decode_jit(JITFunction):
"""
Returns a function implemented over the given field and ufunc mode.
"""
if field.ufunc_mode != "python-calculate":
return function_jit(name, field)
else:
return function_python(name, field)


def function_jit(name: str, field: Type[FieldArray]):
"""
Returns a JIT-compiled function implemented over the given field.
"""
key = (name, field.characteristic, field.degree, int(field.irreducible_poly), int(field.primitive_element))
if key not in function_jit.cache:
# Set the globals once before JIT compiling the function
eval(f"set_{name}_globals")(field)
sig = eval(f"{name.upper()}_SIG")
function_jit.cache[key] = numba.jit(sig.signature, nopython=True)(eval(f"{name}_jit"))

return function_jit.cache[key]

function_jit.cache = {}
Performs BCH decoding.

def function_python(name: str, field: Type[FieldArray]):
"""
Returns a pure-Python function.
References
----------
* Lin, S. and Costello, D. Error Control Coding. Section 7.4.
"""
# Set the globals each time before invoking the pure-Python ufunc
eval(f"set_{name}_globals")(field)
return eval(f"{name}_jit")
_CACHE = {}

@classmethod
def call(cls, field, codeword, syndrome, t, primitive_element):
if field.ufunc_mode != "python-calculate":
codeword_ = codeword.astype(np.int64)
syndrome_ = syndrome.astype(np.int64)
y = cls.jit(field)(codeword_, syndrome_, t, primitive_element)
else:
codeword_ = codeword.view(np.ndarray)
syndrome_ = syndrome.view(np.ndarray)
y = cls.python(field)(codeword_, syndrome_, t, primitive_element)

def set_decode_globals(field: Type[FieldArray]):
global POLY_ROOTS, BERLEKAMP_MASSEY
POLY_ROOTS = field._function("poly_roots")
BERLEKAMP_MASSEY = _lfsr.function("berlekamp_massey", field)
dec_codeword, N_errors = y[:,0:-1], y[:,-1]
dec_codeword = dec_codeword.astype(codeword.dtype)
dec_codeword = dec_codeword.view(field)

return dec_codeword, N_errors

DECODE_SIG = numba.types.FunctionType(int64[:,:](int64[:,:], int64[:,:], int64, int64))
@classmethod
def set_globals(cls, field: Type[FieldArray]):
# pylint: disable=global-variable-undefined
global POLY_ROOTS, BERLEKAMP_MASSEY
POLY_ROOTS = field._function("poly_roots")
BERLEKAMP_MASSEY = berlekamp_massey_jit.function(field)

_SIGNATURE = numba.types.FunctionType(int64[:,:](int64[:,:], int64[:,:], int64, int64))

def decode_jit(codeword, syndrome, t, primitive_element): # pragma: no cover
"""
References
----------
* Lin, S. and Costello, D. Error Control Coding. Section 7.4.
"""
dtype = codeword.dtype
N = codeword.shape[0] # The number of codewords
n = codeword.shape[1] # The codeword size (could be less than the design n for shortened codes)

# The last column of the returned decoded codeword is the number of corrected errors
dec_codeword = np.zeros((N, n + 1), dtype=dtype)
dec_codeword[:, 0:n] = codeword[:,:]

for i in range(N):
if not np.all(syndrome[i,:] == 0):
# The syndrome vector is S = [S0, S1, ..., S2t-1]

# The error pattern is defined as the polynomial e(x) = e_j1*x^j1 + e_j2*x^j2 + ... for j1 to jv,
# implying there are v errors. And δi = e_ji is the i-th error value and βi = α^ji is the i-th error-locator
# value and ji is the error location.

# The error-locator polynomial σ(x) = (1 - β1*x)(1 - β2*x)...(1 - βv*x) where βi are the inverse of the roots
# of σ(x).

# Compute the error-locator polynomial's v-reversal σ(x^-v), since the syndrome is passed in backwards
# TODO: Re-evaluate these equations since changing BMA to return characteristic polynomial, not feedback polynomial
sigma_rev = BERLEKAMP_MASSEY(syndrome[i,::-1])[::-1]
v = sigma_rev.size - 1 # The number of errors

if v > t:
dec_codeword[i, -1] = -1
continue

# Compute βi, the roots of σ(x^-v) which are the inverse roots of σ(x)
degrees = np.arange(sigma_rev.size - 1, -1, -1)
results = POLY_ROOTS(degrees, sigma_rev, primitive_element)
beta = results[0,:] # The roots of σ(x^-v)
error_locations = results[1,:] # The roots as powers of the primitive element α

if np.any(error_locations > n - 1):
# Indicates there are "errors" in the zero-ed portion of a shortened code, which indicates there are actually
# more errors than alleged. Return failure to decode.
dec_codeword[i, -1] = -1
continue

if beta.size != v:
dec_codeword[i, -1] = -1
continue

for j in range(v):
# δi can only be 1
dec_codeword[i, n - 1 - error_locations[j]] ^= 1
dec_codeword[i, -1] = v # The number of corrected errors

return dec_codeword
@staticmethod
def implementation(codeword, syndrome, t, primitive_element): # pragma: no cover
dtype = codeword.dtype
N = codeword.shape[0] # The number of codewords
n = codeword.shape[1] # The codeword size (could be less than the design n for shortened codes)

# The last column of the returned decoded codeword is the number of corrected errors
dec_codeword = np.zeros((N, n + 1), dtype=dtype)
dec_codeword[:, 0:n] = codeword[:,:]

for i in range(N):
if not np.all(syndrome[i,:] == 0):
# The syndrome vector is S = [S0, S1, ..., S2t-1]

# The error pattern is defined as the polynomial e(x) = e_j1*x^j1 + e_j2*x^j2 + ... for j1 to jv,
# implying there are v errors. And δi = e_ji is the i-th error value and βi = α^ji is the i-th error-locator
# value and ji is the error location.

# The error-locator polynomial σ(x) = (1 - β1*x)(1 - β2*x)...(1 - βv*x) where βi are the inverse of the roots
# of σ(x).

# Compute the error-locator polynomial's v-reversal σ(x^-v), since the syndrome is passed in backwards
# TODO: Re-evaluate these equations since changing BMA to return characteristic polynomial, not feedback polynomial
sigma_rev = BERLEKAMP_MASSEY(syndrome[i,::-1])[::-1]
v = sigma_rev.size - 1 # The number of errors

if v > t:
dec_codeword[i, -1] = -1
continue

# Compute βi, the roots of σ(x^-v) which are the inverse roots of σ(x)
degrees = np.arange(sigma_rev.size - 1, -1, -1)
results = POLY_ROOTS(degrees, sigma_rev, primitive_element)
beta = results[0,:] # The roots of σ(x^-v)
error_locations = results[1,:] # The roots as powers of the primitive element α

if np.any(error_locations > n - 1):
# Indicates there are "errors" in the zero-ed portion of a shortened code, which indicates there are actually
# more errors than alleged. Return failure to decode.
dec_codeword[i, -1] = -1
continue

if beta.size != v:
dec_codeword[i, -1] = -1
continue

for j in range(v):
# δi can only be 1
dec_codeword[i, n - 1 - error_locations[j]] ^= 1
dec_codeword[i, -1] = v # The number of corrected errors

return dec_codeword
Loading

0 comments on commit 2ca4d16

Please sign in to comment.