Skip to content

Commit

Permalink
Merge pull request #168 from oscarbenjamin/pr_comparisons
Browse files Browse the repository at this point in the history
Make comparisons is_one, and is_zero consistent for polys
  • Loading branch information
oscarbenjamin authored Jul 22, 2024
2 parents 3be51d9 + f24f9aa commit 2e4b0ab
Show file tree
Hide file tree
Showing 15 changed files with 120 additions and 41 deletions.
83 changes: 61 additions & 22 deletions src/flint/test/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -1604,7 +1604,7 @@ def test_pickling():

def test_fmpz_mod():
from flint import fmpz_mod_ctx, fmpz, fmpz_mod

p_sml = 163
p_med = 2**127 - 1
p_big = 2**255 - 19
Expand Down Expand Up @@ -1754,7 +1754,7 @@ def test_fmpz_mod():
assert raises(lambda: F_test(test_x) * "AAA", TypeError)
assert raises(lambda: F_test(test_x) * F_other(test_x), ValueError)

# Exponentiation
# Exponentiation

assert F_test(0)**0 == pow(0, 0, test_mod)
assert F_test(0)**1 == pow(0, 1, test_mod)
Expand Down Expand Up @@ -1804,7 +1804,7 @@ def test_fmpz_mod():

assert fmpz(test_y) / F_test(test_x) == (test_y * pow(test_x, -1, test_mod)) % test_mod
assert test_y / F_test(test_x) == (test_y * pow(test_x, -1, test_mod)) % test_mod

def test_fmpz_mod_dlog():
from flint import fmpz, fmpz_mod_ctx

Expand All @@ -1826,7 +1826,7 @@ def test_fmpz_mod_dlog():
F = fmpz_mod_ctx(163)
g = F(2)
a = g**123

assert 123 == g.discrete_log(a)

a_int = pow(2, 123, 163)
Expand Down Expand Up @@ -1877,7 +1877,7 @@ def test_fmpz_mod_poly():
assert repr(R3) == "fmpz_mod_poly_ctx(13)"

assert R1.modulus() == 11

assert R1.is_prime()
assert R1.zero() == 0
assert R1.one() == 1
Expand Down Expand Up @@ -1946,7 +1946,7 @@ def test_fmpz_mod_poly():
assert str(f) == "8*x^3 + 7*x^2 + 6*x + 7"

# TODO: currently repr does pretty printing
# just like str, we should address this. Mainly,
# just like str, we should address this. Mainly,
# the issue is we want nice `repr` behaviour in
# interactive shells, which currently is why this
# choice has been made
Expand Down Expand Up @@ -1992,7 +1992,7 @@ def test_fmpz_mod_poly():
F_sml = fmpz_mod_ctx(p_sml)
F_med = fmpz_mod_ctx(p_med)
F_big = fmpz_mod_ctx(p_big)

R_sml = fmpz_mod_poly_ctx(F_sml)
R_med = fmpz_mod_poly_ctx(F_med)
R_big = fmpz_mod_poly_ctx(F_big)
Expand All @@ -2003,14 +2003,14 @@ def test_fmpz_mod_poly():
f_bad = R_cmp([2,2,2,2,2])

for (F_test, R_test) in [(F_sml, R_sml), (F_med, R_med), (F_big, R_big)]:

f = R_test([-1,-2])
g = R_test([-3,-4])

# pos, neg
assert f is +f
assert -f == R_test([1,2])

# add
assert raises(lambda: f + f_cmp, ValueError)
assert raises(lambda: f + "AAA", TypeError)
Expand Down Expand Up @@ -2063,7 +2063,7 @@ def test_fmpz_mod_poly():
assert raises(lambda: f / "AAA", TypeError)
assert raises(lambda: f / 0, ZeroDivisionError)
assert raises(lambda: f_cmp / 2, ZeroDivisionError)

assert (f + f) / 2 == f
assert (f + f) / fmpz(2) == f
assert (f + f) / F_test(2) == f
Expand All @@ -2077,7 +2077,7 @@ def test_fmpz_mod_poly():
assert (f + f) // 2 == f
assert (f + f) // fmpz(2) == f
assert (f + f) // F_test(2) == f
assert 2 // R_test(2) == 1
assert 2 // R_test(2) == 1
assert (f + 1) // f == 1

# pow
Expand Down Expand Up @@ -2171,7 +2171,7 @@ def test_fmpz_mod_poly():
f1 = R_test([-3, 1])
f2 = R_test([-5, 1])
assert f1.resultant(f2) == (3 - 5)
assert raises(lambda: f.resultant("AAA"), TypeError)
assert raises(lambda: f.resultant("AAA"), TypeError)

# sqrt
f1 = R_test.random_element(irreducible=True)
Expand Down Expand Up @@ -2428,14 +2428,14 @@ def _all_polys():
(flint.fmpz_poly, flint.fmpz, False),
(flint.fmpq_poly, flint.fmpq, True),
(lambda *a: flint.nmod_poly(*a, 17), lambda x: flint.nmod(x, 17), True),
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(163)),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(163)),
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(163)),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(163)),
True),
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**127 - 1)),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**127 - 1)),
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**127 - 1)),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**127 - 1)),
True),
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**255 - 19)),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**255 - 19)),
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**255 - 19)),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**255 - 19)),
True),
]

Expand Down Expand Up @@ -2467,6 +2467,28 @@ def test_polys():
assert (P([1]) == P([2])) is False
assert (P([1]) != P([2])) is True

assert (P([1]) == 1) is True
assert (P([1]) != 1) is False
assert (P([1]) == 2) is False
assert (P([1]) != 2) is True

assert (1 == P([1])) is True
assert (1 != P([1])) is False
assert (2 == P([1])) is False
assert (2 != P([1])) is True

s1, s2 = S(1), S(2)

assert (P([s1]) == s1) is True
assert (P([s1]) != s1) is False
assert (P([s1]) == s2) is False
assert (P([s1]) != s2) is True

assert (s1 == P([s1])) is True
assert (s1 != P([s1])) is False
assert (s1 == P([s2])) is False
assert (s1 != P([s2])) is True

assert (P([1]) == None) is False
assert (P([1]) != None) is True
assert (None == P([1])) is False
Expand Down Expand Up @@ -2500,12 +2522,17 @@ def setbad(obj, i, val):
assert raises(lambda: setbad(p, -1, 1), ValueError)

for v in [], [1], [1, 2]:
if P == flint.fmpz_poly:
p = P(v)
if type(p) == flint.fmpz_poly:
assert P(v).repr() == f'fmpz_poly({v!r})'
elif P == flint.fmpq_poly:
elif type(p) == flint.fmpq_poly:
assert P(v).repr() == f'fmpq_poly({v!r})'
elif P == flint.nmod_poly:
elif type(p) == flint.nmod_poly:
assert P(v).repr() == f'nmod_poly({v!r}, 17)'
elif type(p) == flint.fmpz_mod_poly:
pass # fmpz_mod_poly does not have .repr() ...
else:
assert False

assert repr(P([])) == '0'
assert repr(P([1])) == '1'
Expand All @@ -2521,6 +2548,12 @@ def setbad(obj, i, val):
assert bool(P([])) is False
assert bool(P([1])) is True

assert P([]).is_zero() is True
assert P([1]).is_zero() is False

assert P([]).is_one() is False
assert P([1]).is_one() is True

assert +P([1, 2, 3]) == P([1, 2, 3])
assert -P([1, 2, 3]) == P([-1, -2, -3])

Expand Down Expand Up @@ -2600,7 +2633,7 @@ def setbad(obj, i, val):
assert P([1, 1]) ** 2 == P([1, 2, 1])
assert raises(lambda: P([1, 1]) ** -1, ValueError)
assert raises(lambda: P([1, 1]) ** None, TypeError)

# # XXX: Not sure what this should do in general:
assert raises(lambda: pow(P([1, 1]), 2, 3), NotImplementedError)

Expand Down Expand Up @@ -2825,6 +2858,12 @@ def quick_poly():
assert bool(P(ctx=ctx)) is False
assert bool(P(1, ctx=ctx)) is True

assert P(ctx=ctx).is_zero() is True
assert P(1, ctx=ctx).is_zero() is False

assert P(ctx=ctx).is_one() is False
assert P(1, ctx=ctx).is_one() is True

assert +quick_poly() \
== quick_poly()
assert -quick_poly() \
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/acb_mat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ cdef class acb_mat(flint_mat):
else:
raise ValueError("acb_mat: expected 1-3 arguments")

def __nonzero__(self):
def __bool__(self):
raise NotImplementedError

cpdef long nrows(s):
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/arb_mat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ cdef class arb_mat(flint_mat):
else:
raise ValueError("arb_mat: expected 1-3 arguments")

def __nonzero__(self):
def __bool__(self):
raise NotImplementedError

cpdef long nrows(s):
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/fmpq.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ cdef class fmpq(flint_scalar):
def __trunc__(self):
return self.trunc()

def __nonzero__(self):
def __bool__(self):
return not fmpq_is_zero(self.val)

def __round__(self, ndigits=None):
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/fmpq_mat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ cdef class fmpq_mat(flint_mat):
else:
raise TypeError("fmpq_mat: expected 1-3 arguments")

def __nonzero__(self):
def __bool__(self):
return not fmpq_mat_is_zero(self.val)

def __richcmp__(s, t, int op):
Expand Down
9 changes: 6 additions & 3 deletions src/flint/types/fmpq_mpoly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,12 @@ cdef class fmpq_mpoly(flint_mpoly):
def __bool__(self):
return not fmpq_mpoly_is_zero(self.val, self.ctx.val)

def is_zero(self):
return <bint>fmpq_mpoly_is_zero(self.val, self.ctx.val)

def is_one(self):
return <bint>fmpq_mpoly_is_one(self.val, self.ctx.val)

def __richcmp__(self, other, int op):
if not (op == Py_EQ or op == Py_NE):
return NotImplemented
Expand Down Expand Up @@ -782,9 +788,6 @@ cdef class fmpq_mpoly(flint_mpoly):
"""
return self.ctx

def is_one(self):
return fmpq_mpoly_is_one(self.val, self.ctx.val)

def coefficient(self, slong i):
"""
Return the coefficient at index `i`.
Expand Down
8 changes: 7 additions & 1 deletion src/flint/types/fmpq_poly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,15 @@ cdef class fmpq_poly(flint_poly):
else:
return "fmpq_poly(%s, %s)" % ([int(c) for c in n.coeffs()], d)

def __nonzero__(self):
def __bool__(self):
return not fmpq_poly_is_zero(self.val)

def is_zero(self):
return <bint>fmpq_poly_is_zero(self.val)

def is_one(self):
return <bint>fmpq_poly_is_one(self.val)

def __call__(self, other):
t = any_as_fmpz(other)
if t is not NotImplemented:
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/fmpz.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ cdef class fmpz(flint_scalar):
def repr(self):
return "fmpz(%s)" % self.str()

def __nonzero__(self):
def __bool__(self):
return not fmpz_is_zero(self.val)

def __pos__(self):
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/fmpz_mat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ cdef class fmpz_mat(flint_mat):
else:
raise TypeError("fmpz_mat: expected 1-3 arguments")

def __nonzero__(self):
def __bool__(self):
return not fmpz_mat_is_zero(self.val)

def __richcmp__(fmpz_mat s, t, int op):
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/fmpz_mod_mat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ cdef class fmpz_mod_mat(flint_mat):
e = self.ctx.any_as_fmpz_mod(value)
self._setitem(i, j, e.val)

def __nonzero__(self):
def __bool__(self):
"""Return ``True`` if the matrix has any nonzero entries."""
cdef bint zero
zero = compat_fmpz_mod_mat_is_zero(self.val, self.ctx.val)
Expand Down
9 changes: 6 additions & 3 deletions src/flint/types/fmpz_mpoly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,12 @@ cdef class fmpz_mpoly(flint_mpoly):
def __bool__(self):
return not fmpz_mpoly_is_zero(self.val, self.ctx.val)

def is_zero(self):
return <bint>fmpz_mpoly_is_zero(self.val, self.ctx.val)

def is_one(self):
return <bint>fmpz_mpoly_is_one(self.val, self.ctx.val)

def __richcmp__(self, other, int op):
if not (op == Py_EQ or op == Py_NE):
return NotImplemented
Expand Down Expand Up @@ -764,9 +770,6 @@ cdef class fmpz_mpoly(flint_mpoly):
"""
return self.ctx

def is_one(self):
return fmpz_mpoly_is_one(self.val, self.ctx.val)

def coefficient(self, slong i):
"""
Return the coefficient at index `i`.
Expand Down
8 changes: 7 additions & 1 deletion src/flint/types/fmpz_poly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,15 @@ cdef class fmpz_poly(flint_poly):
def repr(self):
return "fmpz_poly([%s])" % (", ".join(map(str, self.coeffs())))

def __nonzero__(self):
def __bool__(self):
return not fmpz_poly_is_zero(self.val)

def is_zero(self):
return <bint>fmpz_poly_is_zero(self.val)

def is_one(self):
return <bint>fmpz_poly_is_one(self.val)

def __call__(self, other):
t = any_as_fmpz(other)
if t is not NotImplemented:
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/nmod.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ cdef class nmod(flint_scalar):
def __hash__(self):
return hash((int(self.val), self.modulus))

def __nonzero__(self):
def __bool__(self):
return self.val != 0

def __pos__(self):
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/nmod_mat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ cdef class nmod_mat(flint_mat):
else:
raise TypeError("nmod_mat: expected 1-3 arguments plus modulus")

def __nonzero__(self):
def __bool__(self):
return not nmod_mat_is_zero(self.val)

def __richcmp__(s, t, int op):
Expand Down
Loading

0 comments on commit 2e4b0ab

Please sign in to comment.