Skip to content

Commit

Permalink
Accept ArrayLike to FEC encode/decode methods
Browse files Browse the repository at this point in the history
Fixes #395
  • Loading branch information
mhostetter committed Jul 29, 2022
1 parent d0495cf commit 0ddcbd7
Show file tree
Hide file tree
Showing 9 changed files with 221 additions and 442 deletions.
25 changes: 11 additions & 14 deletions galois/_codes/_bch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .._polys import Poly, matlab_primitive_poly
from .._polys._dense import roots_jit, divmod_jit
from .._prime import factors
from ..typing import PolyLike
from ..typing import ArrayLike, PolyLike

from ._cyclic import poly_to_generator_matrix, roots_to_parity_check_matrix

Expand Down Expand Up @@ -297,7 +297,7 @@ def __str__(self) -> str:

return string

def encode(self, message: Union[np.ndarray, GF2], parity_only: bool = False) -> GF2:
def encode(self, message: ArrayLike, parity_only: bool = False) -> GF2:
r"""
Encodes the message :math:`\mathbf{m}` into the BCH codeword :math:`\mathbf{c}`.
Expand Down Expand Up @@ -405,8 +405,7 @@ def encode(self, message: Union[np.ndarray, GF2], parity_only: bool = False) ->
p = bch.encode(m, parity_only=True); p
"""
if not isinstance(message, (np.ndarray, GF2)):
raise TypeError(f"Argument `message` must be a subclass of np.ndarray (or a galois.GF2 array), not {type(message)}.")
message = GF2(message) # This performs type/value checking
if parity_only and not self.is_systematic:
raise ValueError("Argument `parity_only=True` only applies to systematic codes.")
if self.is_systematic:
Expand All @@ -419,17 +418,17 @@ def encode(self, message: Union[np.ndarray, GF2], parity_only: bool = False) ->
ks = message.shape[-1] # The number of input message bits (could be less than self.k for shortened codes)

if parity_only:
parity = message.view(GF2) @ self.G[-ks:, self.k:]
parity = message @ self.G[-ks:, self.k:]
return parity
elif self.is_systematic:
parity = message.view(GF2) @ self.G[-ks:, self.k:]
parity = message @ self.G[-ks:, self.k:]
codeword = np.hstack((message, parity))
return codeword
else:
codeword = message.view(GF2) @ self.G
codeword = message @ self.G
return codeword

def detect(self, codeword: Union[np.ndarray, GF2]) -> Union[np.bool_, np.ndarray]:
def detect(self, codeword: ArrayLike) -> Union[np.bool_, np.ndarray]:
r"""
Detects if errors are present in the BCH codeword :math:`\mathbf{c}`.
Expand Down Expand Up @@ -558,8 +557,7 @@ def detect(self, codeword: Union[np.ndarray, GF2]) -> Union[np.bool_, np.ndarray
c
bch.detect(c)
"""
if not isinstance(codeword, np.ndarray):
raise TypeError(f"Argument `codeword` must be a subclass of np.ndarray (or a galois.GF2 array), not {type(codeword)}.")
codeword = GF2(codeword) # This performs type/value checking
if self.is_systematic:
if not codeword.shape[-1] <= self.n:
raise ValueError(f"For a systematic code, argument `codeword` must be a 1-D or 2-D array with last dimension less than or equal to {self.n}, not shape {codeword.shape}.")
Expand All @@ -584,10 +582,10 @@ def detect(self, codeword: Union[np.ndarray, GF2]) -> Union[np.bool_, np.ndarray
return detected

@overload
def decode(self, codeword: Union[np.ndarray, GF2], errors: Literal[False] = False) -> GF2:
def decode(self, codeword: ArrayLike, errors: Literal[False] = False) -> GF2:
...
@overload
def decode(self, codeword: Union[np.ndarray, GF2], errors: Literal[True]) -> Tuple[GF2, Union[np.integer, np.ndarray]]:
def decode(self, codeword: ArrayLike, errors: Literal[True]) -> Tuple[GF2, Union[np.integer, np.ndarray]]:
...
def decode(self, codeword, errors=False):
r"""
Expand Down Expand Up @@ -760,8 +758,7 @@ def decode(self, codeword, errors=False):
d, e = bch.decode(c, errors=True); d, e
np.array_equal(d, m)
"""
if not isinstance(codeword, (np.ndarray, GF2)):
raise TypeError(f"Argument `codeword` must be a subclass of np.ndarray (or a galois.GF2 array), not {type(codeword)}.")
codeword = GF2(codeword) # This performs type/value checking
if self.is_systematic:
if not codeword.shape[-1] <= self.n:
raise ValueError(f"For a systematic code, argument `codeword` must be a 1-D or 2-D array with last dimension less than or equal to {self.n}, not shape {codeword.shape}.")
Expand Down
25 changes: 11 additions & 14 deletions galois/_codes/_reed_solomon.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .._polys import Poly, matlab_primitive_poly
from .._polys._dense import divmod_jit, roots_jit, evaluate_elementwise_jit
from .._prime import factors
from ..typing import PolyLike
from ..typing import ArrayLike, PolyLike

from ._cyclic import poly_to_generator_matrix, roots_to_parity_check_matrix

Expand Down Expand Up @@ -182,7 +182,7 @@ def __str__(self) -> str:

return string

def encode(self, message: Union[np.ndarray, FieldArray], parity_only: bool = False) -> FieldArray:
def encode(self, message: ArrayLike, parity_only: bool = False) -> FieldArray:
r"""
Encodes the message :math:`\mathbf{m}` into the Reed-Solomon codeword :math:`\mathbf{c}`.
Expand Down Expand Up @@ -290,8 +290,7 @@ def encode(self, message: Union[np.ndarray, FieldArray], parity_only: bool = Fal
p = rs.encode(m, parity_only=True); p
"""
if not isinstance(message, np.ndarray):
raise TypeError(f"Argument `message` must be a subclass of np.ndarray (or a galois.GF2 array), not {type(message)}.")
message = self.field(message) # This performs type/value checking
if parity_only and not self.is_systematic:
raise ValueError("Argument `parity_only=True` only applies to systematic codes.")
if self.is_systematic:
Expand All @@ -304,17 +303,17 @@ def encode(self, message: Union[np.ndarray, FieldArray], parity_only: bool = Fal
ks = message.shape[-1] # The number of input message symbols (could be less than self.k for shortened codes)

if parity_only:
parity = message.view(self.field) @ self.G[-ks:, self.k:]
parity = message @ self.G[-ks:, self.k:]
return parity
elif self.is_systematic:
parity = message.view(self.field) @ self.G[-ks:, self.k:]
parity = message @ self.G[-ks:, self.k:]
codeword = np.hstack((message, parity))
return codeword
else:
codeword = message.view(self.field) @ self.G
codeword = message @ self.G
return codeword

def detect(self, codeword: Union[np.ndarray, FieldArray]) -> Union[np.bool_, np.ndarray]:
def detect(self, codeword: ArrayLike) -> Union[np.bool_, np.ndarray]:
r"""
Detects if errors are present in the Reed-Solomon codeword :math:`\mathbf{c}`.
Expand Down Expand Up @@ -445,8 +444,7 @@ def detect(self, codeword: Union[np.ndarray, FieldArray]) -> Union[np.bool_, np.
c
rs.detect(c)
"""
if not isinstance(codeword, np.ndarray):
raise TypeError(f"Argument `codeword` must be a subclass of np.ndarray (or a galois.GF2 array), not {type(codeword)}.")
codeword = self.field(codeword) # This performs type/value checking
if self.is_systematic:
if not codeword.shape[-1] <= self.n:
raise ValueError(f"For a systematic code, argument `codeword` must be a 1-D or 2-D array with last dimension less than or equal to {self.n}, not shape {codeword.shape}.")
Expand All @@ -471,10 +469,10 @@ def detect(self, codeword: Union[np.ndarray, FieldArray]) -> Union[np.bool_, np.
return detected

@overload
def decode(self, codeword: Union[np.ndarray, FieldArray], errors: Literal[False] = False) -> FieldArray:
def decode(self, codeword: ArrayLike, errors: Literal[False] = False) -> FieldArray:
...
@overload
def decode(self, codeword: Union[np.ndarray, FieldArray], errors: Literal[True]) -> Tuple[FieldArray, Union[np.integer, np.ndarray]]:
def decode(self, codeword: ArrayLike, errors: Literal[True]) -> Tuple[FieldArray, Union[np.integer, np.ndarray]]:
...
def decode(self, codeword, errors=False):
r"""
Expand Down Expand Up @@ -650,8 +648,7 @@ def decode(self, codeword, errors=False):
d, e = rs.decode(c, errors=True); d, e
np.array_equal(d, m)
"""
if not isinstance(codeword, np.ndarray):
raise TypeError(f"Argument `codeword` must be a subclass of np.ndarray (or a galois.FieldArray), not {type(codeword)}.")
codeword = self.field(codeword) # This performs type/value checking
if self.is_systematic:
if not codeword.shape[-1] <= self.n:
raise ValueError(f"For a systematic code, argument `codeword` must be a 1-D or 2-D array with last dimension less than or equal to {self.n}, not shape {codeword.shape}.")
Expand Down
15 changes: 15 additions & 0 deletions tests/codes/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,18 @@ def random_errors(GF, N, n, max_errors):
E[i, random.sample(list(range(n)), N_errors[i])] = GF.Random(N_errors[i], low=1)

return E, N_errors


def random_type(array):
"""
Randomly vary the input type to encode()/decode() across various ArrayLike inputs.
"""
x = random.randint(0, 2)
if x == 0:
# A FieldArray instance
return array
elif x == 1:
# A np.ndarray instance
return array.view(np.ndarray)
else:
return array.tolist()
96 changes: 25 additions & 71 deletions tests/codes/test_bch_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import galois

from .helper import random_errors
from .helper import random_errors, random_type

CODES = [
(15, 11), # GF(2^4) with t=1
Expand All @@ -32,17 +32,13 @@ def test_exceptions():
n, k = 15, 7
bch = galois.BCH(n, k)
GF = galois.GF2
with pytest.raises(TypeError):
bch.decode(GF.Random(n).tolist())
with pytest.raises(ValueError):
bch.decode(GF.Random(n + 1))

# Non-systematic
n, k = 15, 7
bch = galois.BCH(n, k, systematic=False)
GF = galois.GF2
with pytest.raises(TypeError):
bch.decode(GF.Random(n).tolist())
with pytest.raises(ValueError):
bch.decode(GF.Random(n - 1))

Expand All @@ -58,20 +54,13 @@ def test_all_correctable(self, size):
E, N_errors = random_errors(galois.GF2, N, n, bch.t)
R = C + E

DEC_M = bch.decode(R)
RR = random_type(R)
DEC_M = bch.decode(RR)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)

DEC_M, N_corr = bch.decode(R, errors=True)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)
assert np.array_equal(N_corr, N_errors)

DEC_M = bch.decode(R.view(np.ndarray))
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)

DEC_M, N_corr = bch.decode(R.view(np.ndarray), errors=True)
RR = random_type(R)
DEC_M, N_corr = bch.decode(RR, errors=True)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)
assert np.array_equal(N_corr, N_errors)
Expand All @@ -88,20 +77,13 @@ def test_some_uncorrectable(self, size):

corr_idxs = np.where(N_errors <= bch.t)[0]

DEC_M = bch.decode(R)
RR = random_type(R)
DEC_M = bch.decode(RR)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])

DEC_M, N_corr = bch.decode(R, errors=True)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])
assert np.array_equal(N_corr[corr_idxs], N_errors[corr_idxs])

DEC_M = bch.decode(R.view(np.ndarray))
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])

DEC_M, N_corr = bch.decode(R.view(np.ndarray), errors=True)
RR = random_type(R)
DEC_M, N_corr = bch.decode(RR, errors=True)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])
assert np.array_equal(N_corr[corr_idxs], N_errors[corr_idxs])
Expand All @@ -120,20 +102,13 @@ def test_all_correctable(self, size):
E, N_errors = random_errors(galois.GF2, N, ns, bch.t)
R = C + E

DEC_M = bch.decode(R)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)

DEC_M, N_corr = bch.decode(R, errors=True)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)
assert np.array_equal(N_corr, N_errors)

DEC_M = bch.decode(R.view(np.ndarray))
RR = random_type(R)
DEC_M = bch.decode(RR)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)

DEC_M, N_corr = bch.decode(R.view(np.ndarray), errors=True)
RR = random_type(R)
DEC_M, N_corr = bch.decode(RR, errors=True)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)
assert np.array_equal(N_corr, N_errors)
Expand All @@ -152,20 +127,13 @@ def test_some_uncorrectable(self, size):

corr_idxs = np.where(N_errors <= bch.t)[0]

DEC_M = bch.decode(R)
RR = random_type(R)
DEC_M = bch.decode(RR)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])

DEC_M, N_corr = bch.decode(R, errors=True)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])
assert np.array_equal(N_corr[corr_idxs], N_errors[corr_idxs])

DEC_M = bch.decode(R.view(np.ndarray))
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])

DEC_M, N_corr = bch.decode(R.view(np.ndarray), errors=True)
RR = random_type(R)
DEC_M, N_corr = bch.decode(RR, errors=True)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])
assert np.array_equal(N_corr[corr_idxs], N_errors[corr_idxs])
Expand All @@ -182,20 +150,13 @@ def test_all_correctable(self, size):
E, N_errors = random_errors(galois.GF2, N, n, bch.t)
R = C + E

DEC_M = bch.decode(R)
RR = random_type(R)
DEC_M = bch.decode(RR)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)

DEC_M, N_corr = bch.decode(R, errors=True)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)
assert np.array_equal(N_corr, N_errors)

DEC_M = bch.decode(R.view(np.ndarray))
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)

DEC_M, N_corr = bch.decode(R.view(np.ndarray), errors=True)
RR = random_type(R)
DEC_M, N_corr = bch.decode(RR, errors=True)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)
assert np.array_equal(N_corr, N_errors)
Expand All @@ -212,20 +173,13 @@ def test_some_uncorrectable(self, size):

corr_idxs = np.where(N_errors <= bch.t)[0]

DEC_M = bch.decode(R)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])

DEC_M, N_corr = bch.decode(R, errors=True)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])
assert np.array_equal(N_corr[corr_idxs], N_errors[corr_idxs])

DEC_M = bch.decode(R.view(np.ndarray))
RR = random_type(R)
DEC_M = bch.decode(RR)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])

DEC_M, N_corr = bch.decode(R.view(np.ndarray), errors=True)
RR = random_type(R)
DEC_M, N_corr = bch.decode(RR, errors=True)
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])
assert np.array_equal(N_corr[corr_idxs], N_errors[corr_idxs])
Loading

0 comments on commit 0ddcbd7

Please sign in to comment.