Skip to content

Add generic tests for polynomial types #90

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 205 additions & 7 deletions src/flint/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,9 +742,9 @@ def test_fmpq():
assert raises(lambda: Q("1.0"), ValueError)
assert raises(lambda: Q("1.5"), ValueError)
assert raises(lambda: Q("1/2/3"), ValueError)
assert raises(lambda: Q([]), ValueError)
assert raises(lambda: Q(1, []), ValueError)
assert raises(lambda: Q([], 1), ValueError)
assert raises(lambda: Q([]), TypeError)
assert raises(lambda: Q(1, []), TypeError)
assert raises(lambda: Q([], 1), TypeError)
assert bool(Q(0)) == False
assert bool(Q(1)) == True
assert Q(1,3) + Q(2,3) == 1
Expand Down Expand Up @@ -1049,9 +1049,8 @@ def test_fmpq_mat():
assert raises(lambda: Q(None), TypeError)
assert Q([[1,2,3],[4,5,6]]) == Q(2,3,[1,2,3,4,5,6])
assert raises(lambda: Q(2,3,[1,2,3,4,5]), ValueError)
# XXX: Should be TypeError not ValueError:
assert raises(lambda: Q([[1,2,3],[4,[],6]]), ValueError)
assert raises(lambda: Q(2,3,[1,2,3,4,[],6]), ValueError)
assert raises(lambda: Q([[1,2,3],[4,[],6]]), TypeError)
assert raises(lambda: Q(2,3,[1,2,3,4,[],6]), TypeError)
assert raises(lambda: Q(2,3,[1,2],[3,4]), ValueError)
assert bool(Q([[1]])) is True
assert bool(Q([[0]])) is False
Expand Down Expand Up @@ -1815,6 +1814,204 @@ def test_fmpz_mod_dlog():
assert g**x == a


def _all_polys():
return [
# (poly_type, scalar_type, is_field)
(flint.fmpz_poly, flint.fmpz, False),
(flint.fmpq_poly, flint.fmpq, True),
(lambda *a: flint.nmod_poly(*a, 17), lambda x: flint.nmod(x, 17), True),
]


def test_polys():
for P, S, is_field in _all_polys():

assert P([S(1)]) == P([1]) == P(P([1])) == P(1)

assert raises(lambda: P([None]), TypeError)
Comment on lines +1817 to +1831
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@GiacomoPope if we write the tests more like this then we don't need to duplicate the test code for different types and it also means that we can test that the different types have consistent method names and signatures.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I agree with this. I think once #87 is merged I will take the time to address the tests before working on fq types.

assert raises(lambda: P(object()), TypeError)
assert raises(lambda: P(None), TypeError)
assert raises(lambda: P(None, None), TypeError)
assert raises(lambda: P([1,2], None), TypeError)
assert raises(lambda: P(1, None), TypeError)

assert len(P([])) == P([]).length() == 0
assert len(P([1])) == P([1]).length() == 1
assert len(P([1,2])) == P([1,2]).length() == 2
assert len(P([1,2,3])) == P([1,2,3]).length() == 3

assert P([]).degree() == -1
assert P([1]).degree() == 0
assert P([1,2]).degree() == 1
assert P([1,2,3]).degree() == 2

assert (P([1]) == P([1])) is True
assert (P([1]) != P([1])) is False
assert (P([1]) == P([2])) is False
assert (P([1]) != P([2])) is True

assert (P([1]) == None) is False
assert (P([1]) != None) is True
assert (None == P([1])) is False
assert (None != P([1])) is True

assert raises(lambda: P([1]) < P([1]), TypeError)
assert raises(lambda: P([1]) <= P([1]), TypeError)
assert raises(lambda: P([1]) > P([1]), TypeError)
assert raises(lambda: P([1]) >= P([1]), TypeError)
assert raises(lambda: P([1]) < None, TypeError)
assert raises(lambda: P([1]) <= None, TypeError)
assert raises(lambda: P([1]) > None, TypeError)
assert raises(lambda: P([1]) >= None, TypeError)
assert raises(lambda: None < P([1]), TypeError)
assert raises(lambda: None <= P([1]), TypeError)
assert raises(lambda: None > P([1]), TypeError)
assert raises(lambda: None >= P([1]), TypeError)

assert P([1, 2, 3])[1] == S(2)
assert P([1, 2, 3])[-1] == S(0)
assert P([1, 2, 3])[3] == S(0)

p = P([1, 2, 3])
p[1] = S(4)
assert p == P([1, 4, 3])

def setbad(obj, i, val):
obj[i] = val

assert raises(lambda: setbad(p, 2, None), TypeError)
assert raises(lambda: setbad(p, -1, 1), ValueError)

for v in [], [1], [1, 2]:
if P == flint.fmpz_poly:
assert P(v).repr() == f'fmpz_poly({v!r})'
elif P == flint.fmpq_poly:
assert P(v).repr() == f'fmpq_poly({v!r})'
else:
assert P(v).repr() == f'nmod_poly({v!r}, 17)'

assert repr(P([])) == '0'
assert repr(P([1])) == '1'
assert repr(P([1, 2])) == '2*x + 1'
assert repr(P([1, 2, 3])) == '3*x^2 + 2*x + 1'

p = P([1, 2, 3])
assert p(0) == p(S(0)) == S(1) == 1
assert p(1) == p(S(1)) == S(6) == 6
assert p(p) == P([6, 16, 36, 36, 27])
assert raises(lambda: p(None), TypeError)

assert bool(P([])) is False
assert bool(P([1])) is True

assert +P([1, 2, 3]) == P([1, 2, 3])
assert -P([1, 2, 3]) == P([-1, -2, -3])

assert P([1, 2, 3]) + P([4, 5, 6]) == P([5, 7, 9])

for T in [int, S, flint.fmpz]:
assert P([1, 2, 3]) + T(1) == P([2, 2, 3])
assert T(1) + P([1, 2, 3]) == P([2, 2, 3])

assert raises(lambda: P([1, 2, 3]) + None, TypeError)
assert raises(lambda: None + P([1, 2, 3]), TypeError)

assert P([1, 2, 3]) - P([4, 5, 6]) == P([-3, -3, -3])

for T in [int, S, flint.fmpz]:
assert P([1, 2, 3]) - T(1) == P([0, 2, 3])
assert T(1) - P([1, 2, 3]) == P([0, -2, -3])

assert raises(lambda: P([1, 2, 3]) - None, TypeError)
assert raises(lambda: None - P([1, 2, 3]), TypeError)

assert P([1, 2, 3]) * P([4, 5, 6]) == P([4, 13, 28, 27, 18])

for T in [int, S, flint.fmpz]:
assert P([1, 2, 3]) * T(2) == P([2, 4, 6])
assert T(2) * P([1, 2, 3]) == P([2, 4, 6])

assert raises(lambda: P([1, 2, 3]) * None, TypeError)
assert raises(lambda: None * P([1, 2, 3]), TypeError)

assert P([1, 2, 1]) // P([1, 1]) == P([1, 1])
assert P([1, 2, 1]) % P([1, 1]) == P([0])
assert divmod(P([1, 2, 1]), P([1, 1])) == (P([1, 1]), P([0]))

if is_field:
assert P([1, 1]) // 2 == P([S(1)/2, S(1)/2])
assert P([1, 1]) % 2 == P([0])
else:
assert P([1, 1]) // 2 == P([0, 0])
assert P([1, 1]) % 2 == P([1, 1])

assert 1 // P([1, 1]) == P([0])
assert 1 % P([1, 1]) == P([1])
assert divmod(1, P([1, 1])) == (P([0]), P([1]))

assert raises(lambda: P([1, 2, 1]) // None, TypeError)
assert raises(lambda: P([1, 2, 1]) % None, TypeError)
assert raises(lambda: divmod(P([1, 2, 1]), None), TypeError)

assert raises(lambda: None // P([1, 1]), TypeError)
assert raises(lambda: None % P([1, 1]), TypeError)
assert raises(lambda: divmod(None, P([1, 1])), TypeError)

assert raises(lambda: P([1, 2, 1]) // 0, ZeroDivisionError)
assert raises(lambda: P([1, 2, 1]) % 0, ZeroDivisionError)
assert raises(lambda: divmod(P([1, 2, 1]), 0), ZeroDivisionError)

assert raises(lambda: P([1, 2, 1]) // P([0]), ZeroDivisionError)
assert raises(lambda: P([1, 2, 1]) % P([0]), ZeroDivisionError)
assert raises(lambda: divmod(P([1, 2, 1]), P([0])), ZeroDivisionError)

if is_field:
assert P([2, 2]) / 2 == P([1, 1])
assert P([1, 2]) / 2 == P([S(1)/2, 1])
assert raises(lambda: P([1, 2]) / 0, ZeroDivisionError)
else:
assert raises(lambda: P([2, 2]) / 2, TypeError)

assert raises(lambda: 1 / P([1, 1]), TypeError)
assert raises(lambda: P([1, 2, 1]) / P([1, 1]), TypeError)
assert raises(lambda: P([1, 2, 1]) / P([1, 2]), TypeError)

assert P([1, 1]) ** 0 == P([1])
assert P([1, 1]) ** 1 == P([1, 1])
assert P([1, 1]) ** 2 == P([1, 2, 1])
assert raises(lambda: P([1, 1]) ** -1, ValueError)
assert raises(lambda: P([1, 1]) ** None, TypeError)
# XXX: Not sure what this should do in general:
assert raises(lambda: pow(P([1, 1]), 2, 3), NotImplementedError)

assert P([1, 2, 1]).gcd(P([1, 1])) == P([1, 1])
assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError)

if is_field:
p1 = P([1, 0, 1])
p2 = P([2, 1])
g, s, t = P([1]), P([1])/5, P([2, -1])/5
assert p1.xgcd(p2) == (g, s, t)
assert raises(lambda: p1.xgcd(None), TypeError)

assert P([1, 2, 1]).factor() == (S(1), [(P([1, 1]), 2)])

assert P([1, 2, 1]).sqrt() == P([1, 1])
assert P([1, 2, 2]).sqrt() is None
if P == flint.fmpq_poly:
assert P([1, 2, 1], 3).sqrt() is None
assert P([1, 2, 1], 4).sqrt() == P([1, 1], 2)

assert P([]).deflation() == (P([]), 1)
assert P([1, 2]).deflation() == (P([1, 2]), 1)
assert P([1, 0, 2]).deflation() == (P([1, 2]), 2)

assert P([1, 2, 1]).derivative() == P([2, 2])

if is_field:
assert P([1, 2, 1]).integral() == P([0, 1, 1, S(1)/3])



all_tests = [
test_pyflint,
Expand All @@ -1835,5 +2032,6 @@ def test_fmpz_mod_dlog():
test_nmod_mat,
test_arb,
test_fmpz_mod,
test_fmpz_mod_dlog
test_fmpz_mod_dlog,
test_polys,
]
24 changes: 15 additions & 9 deletions src/flint/types/fmpq.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,14 @@ cdef class fmpq(flint_scalar):
def __dealloc__(self):
fmpq_clear(self.val)

def __init__(self, p=None, q=None):
cdef long x
if q is None:
if p is None:
return # zero
elif typecheck(p, fmpq):
def __init__(self, *args):
if not args:
return # zero
elif len(args) == 2:
p, q = args
elif len(args) == 1:
p = args[0]
if typecheck(p, fmpq):
fmpq_set(self.val, (<fmpq>p).val)
return
elif typecheck(p, str):
Expand All @@ -90,17 +92,21 @@ cdef class fmpq(flint_scalar):
else:
p = any_as_fmpq(p)
if p is NotImplemented:
raise ValueError("cannot create fmpq from object of type %s" % type(p))
raise TypeError("cannot create fmpq from object of type %s" % type(p))
fmpq_set(self.val, (<fmpq>p).val)
return
else:
raise TypeError("fmpq() takes at most 2 arguments (%d given)" % len(args))

p = any_as_fmpz(p)
if p is NotImplemented:
raise ValueError("cannot create fmpq from object of type %s" % type(p))
raise TypeError("cannot create fmpq from object of type %s" % type(p))
q = any_as_fmpz(q)
if q is NotImplemented:
raise ValueError("cannot create fmpq from object of type %s" % type(q))
raise TypeError("cannot create fmpq from object of type %s" % type(q))
if fmpz_is_zero((<fmpz>q).val):
raise ZeroDivisionError("cannot create rational number with zero denominator")

fmpz_set(fmpq_numref(self.val), (<fmpz>p).val)
fmpz_set(fmpq_denref(self.val), (<fmpz>q).val)
fmpq_canonicalise(self.val)
Expand Down
64 changes: 50 additions & 14 deletions src/flint/types/fmpq_poly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,26 @@ cdef class fmpq_poly(flint_poly):
def __dealloc__(self):
fmpq_poly_clear(self.val)

def __init__(self, p=None, q=None):
if p is not None:
if typecheck(p, fmpq_poly):
fmpq_poly_set(self.val, (<fmpq_poly>p).val)
elif typecheck(p, fmpz_poly):
fmpq_poly_set_fmpz_poly(self.val, (<fmpz_poly>p).val)
elif isinstance(p, list):
fmpq_poly_set_list(self.val, p)
else:
raise TypeError("cannot create fmpq_poly from input of type %s", type(p))
if q is not None:
q = any_as_fmpz(q)
def __init__(self, *args):
if len(args) == 0:
return
elif len(args) > 2:
raise TypeError("fmpq_poly() takes 0, 1 or 2 arguments (%d given)" % len(args))

p = args[0]
if typecheck(p, fmpq_poly):
fmpq_poly_set(self.val, (<fmpq_poly>p).val)
elif typecheck(p, fmpz_poly):
fmpq_poly_set_fmpz_poly(self.val, (<fmpz_poly>p).val)
elif isinstance(p, list):
fmpq_poly_set_list(self.val, p)
elif (v := any_as_fmpq(p)) is not NotImplemented:
fmpq_poly_set_fmpq(self.val, (<fmpq>v).val)
else:
raise TypeError("cannot create fmpq_poly from input of type %s", type(p))

if len(args) == 2:
q = any_as_fmpz(args[1])
if q is NotImplemented:
raise TypeError("denominator must be an integer, got %s", type(q))
if fmpz_is_zero((<fmpz>q).val):
Expand Down Expand Up @@ -326,12 +334,14 @@ cdef class fmpq_poly(flint_poly):
return t
return t._divmod_(s)

def __pow__(fmpq_poly self, ulong exp, mod):
def __pow__(fmpq_poly self, exp, mod):
cdef fmpq_poly res
if mod is not None:
raise NotImplementedError("fmpz_poly modular exponentiation")
if exp < 0:
raise ValueError("fmpq_poly negative exponent")
res = fmpq_poly.__new__(fmpq_poly)
fmpq_poly_pow(res.val, self.val, exp)
fmpq_poly_pow(res.val, self.val, <ulong>exp)
return res

def gcd(self, other):
Expand Down Expand Up @@ -384,6 +394,32 @@ cdef class fmpq_poly(flint_poly):
fac[i] = (base, exp)
return c / self.denom(), fac

def sqrt(self):
"""
Return the exact square root of this polynomial or ``None``.

>>> p = fmpq_poly([1,2,1],4)
>>> p
1/4*x^2 + 1/2*x + 1/4
>>> p.sqrt()
1/2*x + 1/2

"""
d = self.denom()
n = self.numer()
d, r = d.sqrtrem()
if r != 0:
return None
n = n.sqrt()
if n is None:
return None
return fmpq_poly(n, d)

def deflation(self):
num, n = self.numer().deflation()
num = fmpq_poly(num, self.denom())
return num, n

def complex_roots(self, **kwargs):
"""
Computes the complex roots of this polynomial. See
Expand Down
Loading