diff --git a/galois/_domains/_calculate.py b/galois/_domains/_calculate.py index d04bea17e..598c15eda 100644 --- a/galois/_domains/_calculate.py +++ b/galois/_domains/_calculate.py @@ -1,5 +1,6 @@ """ -A module containing various ufunc dispatchers using explicit calculation. +A module containing various ufunc dispatchers with explicit calculation arithmetic added. Various algorithms for +each type of arithmetic are implemented here. """ from typing import Any, Type diff --git a/galois/_domains/_lookup.py b/galois/_domains/_lookup.py index a5796f3d0..c5c34f2a4 100644 --- a/galois/_domains/_lookup.py +++ b/galois/_domains/_lookup.py @@ -1,31 +1,22 @@ """ -A module containing various ufunc dispatchers using lookup tables. +A module containing various ufunc dispatchers with lookup table arithmetic added. These "lookup" implementations use +exponential, logarithm (base primitive element), and Zech logarithm (base primitive element) lookup tables to reduce +the complex finite field arithmetic to a few table lookups and an integer addition/subtraction. """ from __future__ import annotations from typing import TYPE_CHECKING -import numpy as np - -from ._ufunc import UFunc +from . import _ufunc if TYPE_CHECKING: from ._array import Array -class add_ufunc(UFunc): +class add_ufunc(_ufunc.add_ufunc): """ - Default addition ufunc dispatcher. + Addition ufunc dispatcher with lookup table arithmetic added. """ - type = "binary" - - def __call__(self, ufunc, method, inputs, kwargs, meta): - self._verify_operands_in_same_field(ufunc, inputs, meta) - inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) - output = getattr(self.ufunc, method)(*inputs, **kwargs) - output = self._view_output_as_field(output, self.field, meta["dtype"]) - return output - def set_lookup_globals(self): # pylint: disable=global-variable-undefined global EXP, LOG, ZECH_LOG, ZECH_E @@ -66,19 +57,10 @@ def lookup(a: int, b: int) -> int: # pragma: no cover return EXP[m + ZECH_LOG[n - m]] -class negative_ufunc(UFunc): +class negative_ufunc(_ufunc.negative_ufunc): """ - Default additive inverse ufunc dispatcher. + Additive inverse ufunc dispatcher with lookup table arithmetic added. """ - type = "unary" - - def __call__(self, ufunc, method, inputs, kwargs, meta): - self._verify_unary_method_not_reduction(ufunc, method) - inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) - output = getattr(self.ufunc, method)(*inputs, **kwargs) - output = self._view_output_as_field(output, self.field, meta["dtype"]) - return output - def set_lookup_globals(self): # pylint: disable=global-variable-undefined global EXP, LOG, ZECH_E @@ -104,19 +86,10 @@ def lookup(a: int) -> int: # pragma: no cover return EXP[ZECH_E + m] -class subtract_ufunc(UFunc): +class subtract_ufunc(_ufunc.subtract_ufunc): """ - Default subtraction ufunc dispatcher. + Subtraction ufunc dispatcher with lookup table arithmetic added. """ - type = "binary" - - def __call__(self, ufunc, method, inputs, kwargs, meta): - self._verify_operands_in_same_field(ufunc, inputs, meta) - inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) - output = getattr(self.ufunc, method)(*inputs, **kwargs) - output = self._view_output_as_field(output, self.field, meta["dtype"]) - return output - def set_lookup_globals(self): # pylint: disable=global-variable-undefined global ORDER, EXP, LOG, ZECH_LOG, ZECH_E @@ -164,23 +137,10 @@ def lookup(a: int, b: int) -> int: # pragma: no cover return EXP[m + ZECH_LOG[z]] -class multiply_ufunc(UFunc): +class multiply_ufunc(_ufunc.multiply_ufunc): """ - Default multiplication ufunc dispatcher. + Multiplication ufunc dispatcher with lookup table arithmetic added. """ - type = "binary" - - def __call__(self, ufunc, method, inputs, kwargs, meta): - if len(meta["non_field_operands"]) > 0: - # Scalar multiplication - self._verify_operands_in_field_or_int(ufunc, inputs, meta) - inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) - inputs[meta["non_field_operands"][0]] = np.mod(inputs[meta["non_field_operands"][0]], self.field.characteristic) - inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) - output = getattr(self.ufunc, method)(*inputs, **kwargs) - output = self._view_output_as_field(output, self.field, meta["dtype"]) - return output - def set_lookup_globals(self): # pylint: disable=global-variable-undefined global EXP, LOG @@ -205,19 +165,10 @@ def lookup(a: int, b: int) -> int: # pragma: no cover return EXP[m + n] -class reciprocal_ufunc(UFunc): +class reciprocal_ufunc(_ufunc.reciprocal_ufunc): """ - Default multiplicative inverse ufunc dispatcher. + Multiplicative inverse ufunc dispatcher with lookup table arithmetic added. """ - type = "unary" - - def __call__(self, ufunc, method, inputs, kwargs, meta): - self._verify_unary_method_not_reduction(ufunc, method) - inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) - output = getattr(self.ufunc, method)(*inputs, **kwargs) - output = self._view_output_as_field(output, self.field, meta["dtype"]) - return output - def set_lookup_globals(self): # pylint: disable=global-variable-undefined global ORDER, EXP, LOG @@ -244,19 +195,10 @@ def lookup(a: int) -> int: # pragma: no cover return EXP[(ORDER - 1) - m] -class divide_ufunc(UFunc): +class divide_ufunc(_ufunc.divide_ufunc): """ - Default division ufunc dispatcher. + Division ufunc dispatcher with lookup table arithmetic added. """ - type = "binary" - - def __call__(self, ufunc, method, inputs, kwargs, meta): - self._verify_operands_in_same_field(ufunc, inputs, meta) - inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) - output = getattr(self.ufunc, method)(*inputs, **kwargs) - output = self._view_output_as_field(output, self.field, meta["dtype"]) - return output - def set_lookup_globals(self): # pylint: disable=global-variable-undefined global ORDER, EXP, LOG @@ -288,20 +230,10 @@ def lookup(a: int, b: int) -> int: # pragma: no cover return EXP[(ORDER - 1) + m - n] # We add `ORDER - 1` to guarantee the index is non-negative -class power_ufunc(UFunc): +class power_ufunc(_ufunc.power_ufunc): """ - Default exponentiation ufunc dispatcher. + Exponentiation ufunc dispatcher with lookup table arithmetic added. """ - type = "binary" - - def __call__(self, ufunc, method, inputs, kwargs, meta): - self._verify_binary_method_not_reduction(ufunc, method) - self._verify_operands_first_field_second_int(ufunc, inputs, meta) - inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) - output = getattr(self.ufunc, method)(*inputs, **kwargs) - output = self._view_output_as_field(output, self.field, meta["dtype"]) - return output - def set_lookup_globals(self): # pylint: disable=global-variable-undefined global ORDER, EXP, LOG @@ -335,19 +267,10 @@ def lookup(a: int, b: int) -> int: # pragma: no cover return EXP[(m * b) % (ORDER - 1)] # TODO: Do b % (ORDER - 1) first? b could be very large and overflow int64 -class log_ufunc(UFunc): +class log_ufunc(_ufunc.log_ufunc): """ - Default logarithm ufunc dispatcher. + Logarithm ufunc dispatcher with lookup table arithmetic added. """ - type = "binary" - - def __call__(self, ufunc, method, inputs, kwargs, meta): # pylint: disable=unused-argument - self._verify_method_only_call(ufunc, method) - inputs = list(inputs) + [int(self.field.primitive_element)] - inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) - output = getattr(self.ufunc, method)(*inputs, **kwargs) - return output - def set_lookup_globals(self): # pylint: disable=global-variable-undefined global LOG @@ -369,20 +292,10 @@ def lookup(a: int, b: int) -> int: # pragma: no cover return LOG[a] -class sqrt_ufunc(UFunc): +class sqrt_ufunc(_ufunc.sqrt_ufunc): """ - Default square root ufunc dispatcher. + Square root ufunc dispatcher with lookup table arithmetic added. """ - type = "unary" - - def __call__(self, ufunc, method, inputs, kwargs, meta): # pylint: disable=unused-argument - self._verify_method_only_call(ufunc, method) - x = inputs[0] - b = x.is_quadratic_residue() # Boolean indicating if the inputs are quadratic residues - if not np.all(b): - raise ArithmeticError(f"Input array has elements that are quadratic non-residues (do not have a square root). Use `x.is_quadratic_residue()` to determine if elements have square roots in {self.field.name}.\n{x[~b]}") - return self.implementation(*inputs) - def implementation(self, a: Array) -> Array: """ Computes the square root of an element in a Galois field or Galois ring. diff --git a/galois/_domains/_ufunc.py b/galois/_domains/_ufunc.py index b3641d4fa..167e8217f 100644 --- a/galois/_domains/_ufunc.py +++ b/galois/_domains/_ufunc.py @@ -257,12 +257,103 @@ def _view_output_as_field(self, output, field, dtype): ############################################################################### -# Default ufunc dispatchers that simply invoke other ufuncs +# Basic ufunc dispatchers, but they still need need lookup and calculate +# arithmetic implemented ############################################################################### +class add_ufunc(UFunc): + """ + Default addition ufunc dispatcher. + """ + type = "binary" + + def __call__(self, ufunc, method, inputs, kwargs, meta): + self._verify_operands_in_same_field(ufunc, inputs, meta) + inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) + output = getattr(self.ufunc, method)(*inputs, **kwargs) + output = self._view_output_as_field(output, self.field, meta["dtype"]) + return output + + +class negative_ufunc(UFunc): + """ + Default additive inverse ufunc dispatcher. + """ + type = "unary" + + def __call__(self, ufunc, method, inputs, kwargs, meta): + self._verify_unary_method_not_reduction(ufunc, method) + inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) + output = getattr(self.ufunc, method)(*inputs, **kwargs) + output = self._view_output_as_field(output, self.field, meta["dtype"]) + return output + + +class subtract_ufunc(UFunc): + """ + Default subtraction ufunc dispatcher. + """ + type = "binary" + + def __call__(self, ufunc, method, inputs, kwargs, meta): + self._verify_operands_in_same_field(ufunc, inputs, meta) + inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) + output = getattr(self.ufunc, method)(*inputs, **kwargs) + output = self._view_output_as_field(output, self.field, meta["dtype"]) + return output + + +class multiply_ufunc(UFunc): + """ + Default multiplication ufunc dispatcher. + """ + type = "binary" + + def __call__(self, ufunc, method, inputs, kwargs, meta): + if len(meta["non_field_operands"]) > 0: + # Scalar multiplication + self._verify_operands_in_field_or_int(ufunc, inputs, meta) + inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) + inputs[meta["non_field_operands"][0]] = np.mod(inputs[meta["non_field_operands"][0]], self.field.characteristic) + inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) + output = getattr(self.ufunc, method)(*inputs, **kwargs) + output = self._view_output_as_field(output, self.field, meta["dtype"]) + return output + + +class reciprocal_ufunc(UFunc): + """ + Default multiplicative inverse ufunc dispatcher. + """ + type = "unary" + + def __call__(self, ufunc, method, inputs, kwargs, meta): + self._verify_unary_method_not_reduction(ufunc, method) + inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) + output = getattr(self.ufunc, method)(*inputs, **kwargs) + output = self._view_output_as_field(output, self.field, meta["dtype"]) + return output + + +class divide_ufunc(UFunc): + """ + Default division ufunc dispatcher. + """ + type = "binary" + + def __call__(self, ufunc, method, inputs, kwargs, meta): + self._verify_operands_in_same_field(ufunc, inputs, meta) + inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) + output = getattr(self.ufunc, method)(*inputs, **kwargs) + output = self._view_output_as_field(output, self.field, meta["dtype"]) + return output + + class divmod_ufunc(UFunc): """ Default division with remainder ufunc dispatcher. + + NOTE: This does not need its own implementation. Instead, it invokes other ufuncs. """ type = "binary" @@ -277,6 +368,8 @@ def __call__(self, ufunc, method, inputs, kwargs, meta): class remainder_ufunc(UFunc): """ Default remainder ufunc dispatcher. + + NOTE: This does not need its own implementation. Instead, it invokes other ufuncs. """ type = "binary" @@ -287,9 +380,26 @@ def __call__(self, ufunc, method, inputs, kwargs, meta): return output +class power_ufunc(UFunc): + """ + Default exponentiation ufunc dispatcher. + """ + type = "binary" + + def __call__(self, ufunc, method, inputs, kwargs, meta): + self._verify_binary_method_not_reduction(ufunc, method) + self._verify_operands_first_field_second_int(ufunc, inputs, meta) + inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) + output = getattr(self.ufunc, method)(*inputs, **kwargs) + output = self._view_output_as_field(output, self.field, meta["dtype"]) + return output + + class square_ufunc(UFunc): """ Default squaring ufunc dispatcher. + + NOTE: This does not need its own implementation. Instead, it invokes other ufuncs. """ type = "unary" @@ -302,6 +412,41 @@ def __call__(self, ufunc, method, inputs, kwargs, meta): return output +class log_ufunc(UFunc): + """ + Default logarithm ufunc dispatcher. + """ + type = "binary" + + def __call__(self, ufunc, method, inputs, kwargs, meta): # pylint: disable=unused-argument + self._verify_method_only_call(ufunc, method) + inputs = list(inputs) + [int(self.field.primitive_element)] + inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) + output = getattr(self.ufunc, method)(*inputs, **kwargs) + return output + + +class sqrt_ufunc(UFunc): + """ + Default square root ufunc dispatcher. + """ + type = "unary" + + def __call__(self, ufunc, method, inputs, kwargs, meta): # pylint: disable=unused-argument + self._verify_method_only_call(ufunc, method) + x = inputs[0] + b = x.is_quadratic_residue() # Boolean indicating if the inputs are quadratic residues + if not np.all(b): + raise ArithmeticError(f"Input array has elements that are quadratic non-residues (do not have a square root). Use `x.is_quadratic_residue()` to determine if elements have square roots in {self.field.name}.\n{x[~b]}") + return self.implementation(*inputs) + + def implementation(self, a: Array) -> Array: + """ + Computes the square root of an element in a Galois field or Galois ring. + """ + raise NotImplementedError + + class matmul_ufunc(UFunc): """ Default matrix multiplication ufunc dispatcher.