Skip to content

Commit

Permalink
Restructure linear algebra code
Browse files Browse the repository at this point in the history
  • Loading branch information
mhostetter committed Apr 27, 2022
1 parent 124aeb1 commit 6d5b872
Show file tree
Hide file tree
Showing 5 changed files with 503 additions and 467 deletions.
1 change: 1 addition & 0 deletions galois/_domains/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(cls, name, bases, namespace, **kwargs):
cls._ufunc_mode = None # This is set in the first call to compile

cls._name = "Undefined" # Needs overridden
cls._is_prime_field = False # Defaults to False for Galois rings

# A dictionary of ufuncs and LUTs
cls._ufuncs = {}
Expand Down
115 changes: 0 additions & 115 deletions galois/_domains/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from numba import int64, uint64
import numpy as np

from . import _linalg
from ._ufunc import RingUfunc, FieldUfunc


Expand Down Expand Up @@ -47,18 +46,6 @@ class RingFunction(RingUfunc, abc.ABC):
np.fft.ifft: "_ifft",
}

_OVERRIDDEN_LINALG_FUNCTIONS = {
np.dot: _linalg.dot,
np.vdot: _linalg.vdot,
np.inner: _linalg.inner,
np.outer: _linalg.outer,
# np.tensordot: _linalg."tensordot",
np.linalg.det: _linalg.det,
np.linalg.matrix_rank: _linalg.matrix_rank,
np.linalg.solve: _linalg.solve,
np.linalg.inv: _linalg.inv,
}

_FUNCTION_CACHE_CALCULATE = {}

def __array_function__(self, func, types, args, kwargs):
Expand All @@ -70,9 +57,6 @@ def __array_function__(self, func, types, args, kwargs):
if func in field._OVERRIDDEN_FUNCTIONS:
output = getattr(field, field._OVERRIDDEN_FUNCTIONS[func])(*args, **kwargs)

elif func in field._OVERRIDDEN_LINALG_FUNCTIONS:
output = field._OVERRIDDEN_LINALG_FUNCTIONS[func](*args, **kwargs)

elif func in field._UNSUPPORTED_FUNCTIONS:
raise NotImplementedError(f"The NumPy function {func.__name__!r} is not supported on FieldArray. If you believe this function should be supported, please submit a GitHub issue at https://github.com/mhostetter/galois/issues.\n\nIf you'd like to perform this operation on the data, you should first call `array = array.view(np.ndarray)` and then call the function.")

Expand All @@ -89,10 +73,6 @@ def __array_function__(self, func, types, args, kwargs):

return output

def dot(self, b, out=None):
# The `np.dot(a, b)` ufunc is also available as `a.dot(b)`. Need to override this method for consistent results.
return _linalg.dot(self, b, out=out)

###############################################################################
# Individual functions, pre-compiled (cached)
###############################################################################
Expand Down Expand Up @@ -197,101 +177,6 @@ def _convolve_calculate(a, b, ADD, MULTIPLY, CHARACTERISTIC, DEGREE, IRREDUCIBLE

return c

###############################################################################
# Matrix multiplication
###############################################################################

@classmethod
def _matmul(cls, A, B, out=None, **kwargs): # pylint: disable=unused-argument
if not type(A) is type(B):
raise TypeError(f"Operation 'matmul' requires both arrays be in the same Galois field, not {type(A)} and {type(B)}.")
if not (A.ndim >= 1 and B.ndim >= 1):
raise ValueError(f"Operation 'matmul' requires both arrays have dimension at least 1, not {A.ndim}-D and {B.ndim}-D.")
if not (A.ndim <= 2 and B.ndim <= 2):
raise ValueError("Operation 'matmul' currently only supports matrix multiplication up to 2-D. If you would like matrix multiplication of N-D arrays, please submit a GitHub issue at https://github.com/mhostetter/galois/issues.")
field = type(A)
dtype = A.dtype

if field.is_prime_field:
return _linalg._lapack_linalg(A, B, np.matmul, out=out)

prepend, append = False, False
if A.ndim == 1:
A = A.reshape((1,A.size))
prepend = True
if B.ndim == 1:
B = B.reshape((B.size,1))
append = True

if not A.shape[-1] == B.shape[-2]:
raise ValueError(f"Operation 'matmul' requires the last dimension of A to match the second-to-last dimension of B, not {A.shape} and {B.shape}.")

# if A.ndim > 2 and B.ndim == 2:
# new_shape = list(A.shape[:-2]) + list(B.shape)
# B = np.broadcast_to(B, new_shape)
# if B.ndim > 2 and A.ndim == 2:
# new_shape = list(B.shape[:-2]) + list(A.shape)
# A = np.broadcast_to(A, new_shape)

if cls.ufunc_mode != "python-calculate":
A = A.astype(np.int64)
B = B.astype(np.int64)
add = cls._func_calculate("add")
multiply = cls._func_calculate("multiply")
C = cls._function("matmul")(A, B, add, multiply, cls.characteristic, cls.degree, int(cls.irreducible_poly))
C = C.astype(dtype)
else:
A = A.view(np.ndarray)
B = B.view(np.ndarray)
add = cls._func_python("add")
multiply = cls._func_python("multiply")
C = cls._function("matmul")(A, B, add, multiply, cls.characteristic, cls.degree, int(cls.irreducible_poly))
C = field._view(C)

shape = list(C.shape)
if prepend:
shape = shape[1:]
if append:
shape = shape[:-1]
C = C.reshape(shape)

# TODO: Determine a better way to do this
if out is not None:
assert isinstance(out, tuple) and len(out) == 1 # TODO: Why is `out` getting populated as tuple?
out = out[0]
out[:] = C[:]

return C

_MATMUL_CALCULATE_SIG = numba.types.FunctionType(int64[:,:](
int64[:,:],
int64[:,:],
RingUfunc._BINARY_CALCULATE_SIG,
RingUfunc._BINARY_CALCULATE_SIG,
int64,
int64,
int64
))

@staticmethod
@numba.extending.register_jitable
def _matmul_calculate(A, B, ADD, MULTIPLY, CHARACTERISTIC, DEGREE, IRREDUCIBLE_POLY):
args = CHARACTERISTIC, DEGREE, IRREDUCIBLE_POLY
dtype = A.dtype

assert A.ndim == 2 and B.ndim == 2
assert A.shape[-1] == B.shape[-2]

M, K = A.shape
K, N = B.shape
C = np.zeros((M, N), dtype=dtype)
for i in range(M):
for j in range(N):
for k in range(K):
C[i,j] = ADD(C[i,j], MULTIPLY(A[i,k], B[k,j], *args), *args)

return C

###############################################################################
# FFT and IFFT
# TODO: Determine how to handle recursion with a single JIT-compiled or
Expand Down
Loading

0 comments on commit 6d5b872

Please sign in to comment.