Skip to content

Commit

Permalink
Clean up _convert_codeword_to_message()
Browse files Browse the repository at this point in the history
  • Loading branch information
mhostetter committed Nov 11, 2022
1 parent 1b36c35 commit f8d63ac
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
7 changes: 6 additions & 1 deletion src/galois/_codes/_cyclic.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,22 @@ def decode(self, codeword, output="message", errors=False):
return super().decode(codeword, output=output, errors=errors)

def _convert_codeword_to_message(self, codeword: FieldArray) -> FieldArray:
ns = codeword.shape[-1] # The number of codeword symbols (could be less than self.n for shortened codes)
ks = self.k - (self.n - ns) # The number of message symbols (could be less than self.k for shortened codes)

if self.is_systematic:
message = codeword[..., 0:self.k]
message = codeword[..., 0:ks]
else:
message, _ = divmod_jit(self.field)(codeword, self.generator_poly.coeffs)

return message

def _convert_codeword_to_parity(self, codeword: FieldArray) -> FieldArray:
if self.is_systematic:
parity = codeword[..., -(self.n - self.k):]
else:
_, parity = divmod_jit(self.field)(codeword, self.generator_poly.coeffs)

return parity

@property
Expand Down
7 changes: 2 additions & 5 deletions src/galois/_codes/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def encode(self, message: ArrayLike, output: Literal["codeword", "parity"] = "co

message, is_message_1d = self._check_and_convert_message(message)
codeword = self._encode_message(message)

if is_message_1d:
codeword = codeword[0,:]

Expand Down Expand Up @@ -155,14 +156,10 @@ def decode(self, codeword, output="message", errors=False):
codeword, is_codeword_1d = self._check_and_convert_codeword(codeword)
dec_codeword, N_errors = self._decode_codeword(codeword)

ns = codeword.shape[-1] # The number of codeword symbols (could be less than self.n for shortened codes)
ks = self.k - (self.n - ns) # The number of message symbols (could be less than self.k for shortened codes)

if output == "message":
decoded = self._convert_codeword_to_message(dec_codeword)
decoded = decoded[:, :ks]
else:
decoded = dec_codeword[:, :ns]
decoded = dec_codeword

if is_codeword_1d:
decoded, N_errors = decoded[0,:], int(N_errors[0])
Expand Down

0 comments on commit f8d63ac

Please sign in to comment.