Skip to content

Commit

Permalink
JIT compile polynomial subtraction routine
Browse files Browse the repository at this point in the history
  • Loading branch information
mhostetter committed Nov 10, 2022
1 parent 3d090c8 commit 7eed640
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
38 changes: 38 additions & 0 deletions src/galois/_polys/_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,44 @@ def subtract(a: Array, b: Array) -> Array:
return c


class subtract_jit(Function):
"""
Computes polynomial subtraction of two polynomials.
Algorithm:
c(x) = a(x) - b(x)
"""
def __call__(self, a: Array, b: Array) -> Array:
verify_isinstance(a, self.field)
verify_isinstance(b, self.field)
assert a.ndim == 1 and b.ndim == 1
dtype = a.dtype

if self.field.ufunc_mode != "python-calculate":
r = self.jit(a.astype(np.int64), b.astype(np.int64))
r = r.astype(dtype)
else:
r = self.python(a.view(np.ndarray), b.view(np.ndarray))
r = self.field._view(r)

return r

def set_globals(self):
# pylint: disable=global-variable-undefined
global SUBTRACT
SUBTRACT = self.field._subtract.ufunc

_SIGNATURE = numba.types.FunctionType(int64[:](int64[:], int64[:]))

@staticmethod
def implementation(a, b):
dtype = a.dtype
c = np.zeros(max(a.size, b.size), dtype=dtype)
c[-a.size:] = a
c[-b.size:] = SUBTRACT(c[-b.size:], b)
return c


def multiply(a: Array, b: Array) -> Array:
"""
c(x) = a(x) * b(x)
Expand Down
4 changes: 2 additions & 2 deletions src/galois/_polys/_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,7 +1800,7 @@ def __sub__(self, other: Poly | Array) -> Poly:
else:
a = _convert_to_coeffs(self, self.field)
b = _convert_to_coeffs(other, self.field)
c = _dense.subtract(a, b)
c = _dense.subtract_jit(self.field)(a, b)
return Poly(c, field=self.field)

def __rsub__(self, other: Poly | Array) -> Poly:
Expand All @@ -1820,7 +1820,7 @@ def __rsub__(self, other: Poly | Array) -> Poly:
else:
a = _convert_to_coeffs(other, self.field)
b = _convert_to_coeffs(self, self.field)
c = _dense.subtract(a, b)
c = _dense.subtract_jit(self.field)(a, b)
return Poly(c, field=self.field)

def __mul__(self, other: Poly | Array | int) -> Poly:
Expand Down

0 comments on commit 7eed640

Please sign in to comment.