Skip to content

Commit

Permalink
Reorganize default ufunc dispatcher definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
mhostetter committed Jul 25, 2022
1 parent 182a493 commit b3fb252
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 111 deletions.
3 changes: 2 additions & 1 deletion galois/_domains/_calculate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""
A module containing various ufunc dispatchers using explicit calculation.
A module containing various ufunc dispatchers with explicit calculation arithmetic added. Various algorithms for
each type of arithmetic are implemented here.
"""
from typing import Any, Type

Expand Down
131 changes: 22 additions & 109 deletions galois/_domains/_lookup.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,22 @@
"""
A module containing various ufunc dispatchers using lookup tables.
A module containing various ufunc dispatchers with lookup table arithmetic added. These "lookup" implementations use
exponential, logarithm (base primitive element), and Zech logarithm (base primitive element) lookup tables to reduce
the complex finite field arithmetic to a few table lookups and an integer addition/subtraction.
"""
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from ._ufunc import UFunc
from . import _ufunc

if TYPE_CHECKING:
from ._array import Array


class add_ufunc(UFunc):
class add_ufunc(_ufunc.add_ufunc):
"""
Default addition ufunc dispatcher.
Addition ufunc dispatcher with lookup table arithmetic added.
"""
type = "binary"

def __call__(self, ufunc, method, inputs, kwargs, meta):
self._verify_operands_in_same_field(ufunc, inputs, meta)
inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs)
output = getattr(self.ufunc, method)(*inputs, **kwargs)
output = self._view_output_as_field(output, self.field, meta["dtype"])
return output

def set_lookup_globals(self):
# pylint: disable=global-variable-undefined
global EXP, LOG, ZECH_LOG, ZECH_E
Expand Down Expand Up @@ -66,19 +57,10 @@ def lookup(a: int, b: int) -> int: # pragma: no cover
return EXP[m + ZECH_LOG[n - m]]


class negative_ufunc(UFunc):
class negative_ufunc(_ufunc.negative_ufunc):
"""
Default additive inverse ufunc dispatcher.
Additive inverse ufunc dispatcher with lookup table arithmetic added.
"""
type = "unary"

def __call__(self, ufunc, method, inputs, kwargs, meta):
self._verify_unary_method_not_reduction(ufunc, method)
inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs)
output = getattr(self.ufunc, method)(*inputs, **kwargs)
output = self._view_output_as_field(output, self.field, meta["dtype"])
return output

def set_lookup_globals(self):
# pylint: disable=global-variable-undefined
global EXP, LOG, ZECH_E
Expand All @@ -104,19 +86,10 @@ def lookup(a: int) -> int: # pragma: no cover
return EXP[ZECH_E + m]


class subtract_ufunc(UFunc):
class subtract_ufunc(_ufunc.subtract_ufunc):
"""
Default subtraction ufunc dispatcher.
Subtraction ufunc dispatcher with lookup table arithmetic added.
"""
type = "binary"

def __call__(self, ufunc, method, inputs, kwargs, meta):
self._verify_operands_in_same_field(ufunc, inputs, meta)
inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs)
output = getattr(self.ufunc, method)(*inputs, **kwargs)
output = self._view_output_as_field(output, self.field, meta["dtype"])
return output

def set_lookup_globals(self):
# pylint: disable=global-variable-undefined
global ORDER, EXP, LOG, ZECH_LOG, ZECH_E
Expand Down Expand Up @@ -164,23 +137,10 @@ def lookup(a: int, b: int) -> int: # pragma: no cover
return EXP[m + ZECH_LOG[z]]


class multiply_ufunc(UFunc):
class multiply_ufunc(_ufunc.multiply_ufunc):
"""
Default multiplication ufunc dispatcher.
Multiplication ufunc dispatcher with lookup table arithmetic added.
"""
type = "binary"

def __call__(self, ufunc, method, inputs, kwargs, meta):
if len(meta["non_field_operands"]) > 0:
# Scalar multiplication
self._verify_operands_in_field_or_int(ufunc, inputs, meta)
inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs)
inputs[meta["non_field_operands"][0]] = np.mod(inputs[meta["non_field_operands"][0]], self.field.characteristic)
inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs)
output = getattr(self.ufunc, method)(*inputs, **kwargs)
output = self._view_output_as_field(output, self.field, meta["dtype"])
return output

def set_lookup_globals(self):
# pylint: disable=global-variable-undefined
global EXP, LOG
Expand All @@ -205,19 +165,10 @@ def lookup(a: int, b: int) -> int: # pragma: no cover
return EXP[m + n]


class reciprocal_ufunc(UFunc):
class reciprocal_ufunc(_ufunc.reciprocal_ufunc):
"""
Default multiplicative inverse ufunc dispatcher.
Multiplicative inverse ufunc dispatcher with lookup table arithmetic added.
"""
type = "unary"

def __call__(self, ufunc, method, inputs, kwargs, meta):
self._verify_unary_method_not_reduction(ufunc, method)
inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs)
output = getattr(self.ufunc, method)(*inputs, **kwargs)
output = self._view_output_as_field(output, self.field, meta["dtype"])
return output

def set_lookup_globals(self):
# pylint: disable=global-variable-undefined
global ORDER, EXP, LOG
Expand All @@ -244,19 +195,10 @@ def lookup(a: int) -> int: # pragma: no cover
return EXP[(ORDER - 1) - m]


class divide_ufunc(UFunc):
class divide_ufunc(_ufunc.divide_ufunc):
"""
Default division ufunc dispatcher.
Division ufunc dispatcher with lookup table arithmetic added.
"""
type = "binary"

def __call__(self, ufunc, method, inputs, kwargs, meta):
self._verify_operands_in_same_field(ufunc, inputs, meta)
inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs)
output = getattr(self.ufunc, method)(*inputs, **kwargs)
output = self._view_output_as_field(output, self.field, meta["dtype"])
return output

def set_lookup_globals(self):
# pylint: disable=global-variable-undefined
global ORDER, EXP, LOG
Expand Down Expand Up @@ -288,20 +230,10 @@ def lookup(a: int, b: int) -> int: # pragma: no cover
return EXP[(ORDER - 1) + m - n] # We add `ORDER - 1` to guarantee the index is non-negative


class power_ufunc(UFunc):
class power_ufunc(_ufunc.power_ufunc):
"""
Default exponentiation ufunc dispatcher.
Exponentiation ufunc dispatcher with lookup table arithmetic added.
"""
type = "binary"

def __call__(self, ufunc, method, inputs, kwargs, meta):
self._verify_binary_method_not_reduction(ufunc, method)
self._verify_operands_first_field_second_int(ufunc, inputs, meta)
inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs)
output = getattr(self.ufunc, method)(*inputs, **kwargs)
output = self._view_output_as_field(output, self.field, meta["dtype"])
return output

def set_lookup_globals(self):
# pylint: disable=global-variable-undefined
global ORDER, EXP, LOG
Expand Down Expand Up @@ -335,19 +267,10 @@ def lookup(a: int, b: int) -> int: # pragma: no cover
return EXP[(m * b) % (ORDER - 1)] # TODO: Do b % (ORDER - 1) first? b could be very large and overflow int64


class log_ufunc(UFunc):
class log_ufunc(_ufunc.log_ufunc):
"""
Default logarithm ufunc dispatcher.
Logarithm ufunc dispatcher with lookup table arithmetic added.
"""
type = "binary"

def __call__(self, ufunc, method, inputs, kwargs, meta): # pylint: disable=unused-argument
self._verify_method_only_call(ufunc, method)
inputs = list(inputs) + [int(self.field.primitive_element)]
inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs)
output = getattr(self.ufunc, method)(*inputs, **kwargs)
return output

def set_lookup_globals(self):
# pylint: disable=global-variable-undefined
global LOG
Expand All @@ -369,20 +292,10 @@ def lookup(a: int, b: int) -> int: # pragma: no cover
return LOG[a]


class sqrt_ufunc(UFunc):
class sqrt_ufunc(_ufunc.sqrt_ufunc):
"""
Default square root ufunc dispatcher.
Square root ufunc dispatcher with lookup table arithmetic added.
"""
type = "unary"

def __call__(self, ufunc, method, inputs, kwargs, meta): # pylint: disable=unused-argument
self._verify_method_only_call(ufunc, method)
x = inputs[0]
b = x.is_quadratic_residue() # Boolean indicating if the inputs are quadratic residues
if not np.all(b):
raise ArithmeticError(f"Input array has elements that are quadratic non-residues (do not have a square root). Use `x.is_quadratic_residue()` to determine if elements have square roots in {self.field.name}.\n{x[~b]}")
return self.implementation(*inputs)

def implementation(self, a: Array) -> Array:
"""
Computes the square root of an element in a Galois field or Galois ring.
Expand Down
Loading

0 comments on commit b3fb252

Please sign in to comment.