diff --git a/src/flint/test/test.py b/src/flint/test/test.py index 9519d715..71bcea10 100644 --- a/src/flint/test/test.py +++ b/src/flint/test/test.py @@ -742,9 +742,9 @@ def test_fmpq(): assert raises(lambda: Q("1.0"), ValueError) assert raises(lambda: Q("1.5"), ValueError) assert raises(lambda: Q("1/2/3"), ValueError) - assert raises(lambda: Q([]), ValueError) - assert raises(lambda: Q(1, []), ValueError) - assert raises(lambda: Q([], 1), ValueError) + assert raises(lambda: Q([]), TypeError) + assert raises(lambda: Q(1, []), TypeError) + assert raises(lambda: Q([], 1), TypeError) assert bool(Q(0)) == False assert bool(Q(1)) == True assert Q(1,3) + Q(2,3) == 1 @@ -1049,9 +1049,8 @@ def test_fmpq_mat(): assert raises(lambda: Q(None), TypeError) assert Q([[1,2,3],[4,5,6]]) == Q(2,3,[1,2,3,4,5,6]) assert raises(lambda: Q(2,3,[1,2,3,4,5]), ValueError) - # XXX: Should be TypeError not ValueError: - assert raises(lambda: Q([[1,2,3],[4,[],6]]), ValueError) - assert raises(lambda: Q(2,3,[1,2,3,4,[],6]), ValueError) + assert raises(lambda: Q([[1,2,3],[4,[],6]]), TypeError) + assert raises(lambda: Q(2,3,[1,2,3,4,[],6]), TypeError) assert raises(lambda: Q(2,3,[1,2],[3,4]), ValueError) assert bool(Q([[1]])) is True assert bool(Q([[0]])) is False @@ -1815,6 +1814,204 @@ def test_fmpz_mod_dlog(): assert g**x == a +def _all_polys(): + return [ + # (poly_type, scalar_type, is_field) + (flint.fmpz_poly, flint.fmpz, False), + (flint.fmpq_poly, flint.fmpq, True), + (lambda *a: flint.nmod_poly(*a, 17), lambda x: flint.nmod(x, 17), True), + ] + + +def test_polys(): + for P, S, is_field in _all_polys(): + + assert P([S(1)]) == P([1]) == P(P([1])) == P(1) + + assert raises(lambda: P([None]), TypeError) + assert raises(lambda: P(object()), TypeError) + assert raises(lambda: P(None), TypeError) + assert raises(lambda: P(None, None), TypeError) + assert raises(lambda: P([1,2], None), TypeError) + assert raises(lambda: P(1, None), TypeError) + + assert len(P([])) == P([]).length() == 0 + assert len(P([1])) == P([1]).length() == 1 + assert len(P([1,2])) == P([1,2]).length() == 2 + assert len(P([1,2,3])) == P([1,2,3]).length() == 3 + + assert P([]).degree() == -1 + assert P([1]).degree() == 0 + assert P([1,2]).degree() == 1 + assert P([1,2,3]).degree() == 2 + + assert (P([1]) == P([1])) is True + assert (P([1]) != P([1])) is False + assert (P([1]) == P([2])) is False + assert (P([1]) != P([2])) is True + + assert (P([1]) == None) is False + assert (P([1]) != None) is True + assert (None == P([1])) is False + assert (None != P([1])) is True + + assert raises(lambda: P([1]) < P([1]), TypeError) + assert raises(lambda: P([1]) <= P([1]), TypeError) + assert raises(lambda: P([1]) > P([1]), TypeError) + assert raises(lambda: P([1]) >= P([1]), TypeError) + assert raises(lambda: P([1]) < None, TypeError) + assert raises(lambda: P([1]) <= None, TypeError) + assert raises(lambda: P([1]) > None, TypeError) + assert raises(lambda: P([1]) >= None, TypeError) + assert raises(lambda: None < P([1]), TypeError) + assert raises(lambda: None <= P([1]), TypeError) + assert raises(lambda: None > P([1]), TypeError) + assert raises(lambda: None >= P([1]), TypeError) + + assert P([1, 2, 3])[1] == S(2) + assert P([1, 2, 3])[-1] == S(0) + assert P([1, 2, 3])[3] == S(0) + + p = P([1, 2, 3]) + p[1] = S(4) + assert p == P([1, 4, 3]) + + def setbad(obj, i, val): + obj[i] = val + + assert raises(lambda: setbad(p, 2, None), TypeError) + assert raises(lambda: setbad(p, -1, 1), ValueError) + + for v in [], [1], [1, 2]: + if P == flint.fmpz_poly: + assert P(v).repr() == f'fmpz_poly({v!r})' + elif P == flint.fmpq_poly: + assert P(v).repr() == f'fmpq_poly({v!r})' + else: + assert P(v).repr() == f'nmod_poly({v!r}, 17)' + + assert repr(P([])) == '0' + assert repr(P([1])) == '1' + assert repr(P([1, 2])) == '2*x + 1' + assert repr(P([1, 2, 3])) == '3*x^2 + 2*x + 1' + + p = P([1, 2, 3]) + assert p(0) == p(S(0)) == S(1) == 1 + assert p(1) == p(S(1)) == S(6) == 6 + assert p(p) == P([6, 16, 36, 36, 27]) + assert raises(lambda: p(None), TypeError) + + assert bool(P([])) is False + assert bool(P([1])) is True + + assert +P([1, 2, 3]) == P([1, 2, 3]) + assert -P([1, 2, 3]) == P([-1, -2, -3]) + + assert P([1, 2, 3]) + P([4, 5, 6]) == P([5, 7, 9]) + + for T in [int, S, flint.fmpz]: + assert P([1, 2, 3]) + T(1) == P([2, 2, 3]) + assert T(1) + P([1, 2, 3]) == P([2, 2, 3]) + + assert raises(lambda: P([1, 2, 3]) + None, TypeError) + assert raises(lambda: None + P([1, 2, 3]), TypeError) + + assert P([1, 2, 3]) - P([4, 5, 6]) == P([-3, -3, -3]) + + for T in [int, S, flint.fmpz]: + assert P([1, 2, 3]) - T(1) == P([0, 2, 3]) + assert T(1) - P([1, 2, 3]) == P([0, -2, -3]) + + assert raises(lambda: P([1, 2, 3]) - None, TypeError) + assert raises(lambda: None - P([1, 2, 3]), TypeError) + + assert P([1, 2, 3]) * P([4, 5, 6]) == P([4, 13, 28, 27, 18]) + + for T in [int, S, flint.fmpz]: + assert P([1, 2, 3]) * T(2) == P([2, 4, 6]) + assert T(2) * P([1, 2, 3]) == P([2, 4, 6]) + + assert raises(lambda: P([1, 2, 3]) * None, TypeError) + assert raises(lambda: None * P([1, 2, 3]), TypeError) + + assert P([1, 2, 1]) // P([1, 1]) == P([1, 1]) + assert P([1, 2, 1]) % P([1, 1]) == P([0]) + assert divmod(P([1, 2, 1]), P([1, 1])) == (P([1, 1]), P([0])) + + if is_field: + assert P([1, 1]) // 2 == P([S(1)/2, S(1)/2]) + assert P([1, 1]) % 2 == P([0]) + else: + assert P([1, 1]) // 2 == P([0, 0]) + assert P([1, 1]) % 2 == P([1, 1]) + + assert 1 // P([1, 1]) == P([0]) + assert 1 % P([1, 1]) == P([1]) + assert divmod(1, P([1, 1])) == (P([0]), P([1])) + + assert raises(lambda: P([1, 2, 1]) // None, TypeError) + assert raises(lambda: P([1, 2, 1]) % None, TypeError) + assert raises(lambda: divmod(P([1, 2, 1]), None), TypeError) + + assert raises(lambda: None // P([1, 1]), TypeError) + assert raises(lambda: None % P([1, 1]), TypeError) + assert raises(lambda: divmod(None, P([1, 1])), TypeError) + + assert raises(lambda: P([1, 2, 1]) // 0, ZeroDivisionError) + assert raises(lambda: P([1, 2, 1]) % 0, ZeroDivisionError) + assert raises(lambda: divmod(P([1, 2, 1]), 0), ZeroDivisionError) + + assert raises(lambda: P([1, 2, 1]) // P([0]), ZeroDivisionError) + assert raises(lambda: P([1, 2, 1]) % P([0]), ZeroDivisionError) + assert raises(lambda: divmod(P([1, 2, 1]), P([0])), ZeroDivisionError) + + if is_field: + assert P([2, 2]) / 2 == P([1, 1]) + assert P([1, 2]) / 2 == P([S(1)/2, 1]) + assert raises(lambda: P([1, 2]) / 0, ZeroDivisionError) + else: + assert raises(lambda: P([2, 2]) / 2, TypeError) + + assert raises(lambda: 1 / P([1, 1]), TypeError) + assert raises(lambda: P([1, 2, 1]) / P([1, 1]), TypeError) + assert raises(lambda: P([1, 2, 1]) / P([1, 2]), TypeError) + + assert P([1, 1]) ** 0 == P([1]) + assert P([1, 1]) ** 1 == P([1, 1]) + assert P([1, 1]) ** 2 == P([1, 2, 1]) + assert raises(lambda: P([1, 1]) ** -1, ValueError) + assert raises(lambda: P([1, 1]) ** None, TypeError) + # XXX: Not sure what this should do in general: + assert raises(lambda: pow(P([1, 1]), 2, 3), NotImplementedError) + + assert P([1, 2, 1]).gcd(P([1, 1])) == P([1, 1]) + assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError) + + if is_field: + p1 = P([1, 0, 1]) + p2 = P([2, 1]) + g, s, t = P([1]), P([1])/5, P([2, -1])/5 + assert p1.xgcd(p2) == (g, s, t) + assert raises(lambda: p1.xgcd(None), TypeError) + + assert P([1, 2, 1]).factor() == (S(1), [(P([1, 1]), 2)]) + + assert P([1, 2, 1]).sqrt() == P([1, 1]) + assert P([1, 2, 2]).sqrt() is None + if P == flint.fmpq_poly: + assert P([1, 2, 1], 3).sqrt() is None + assert P([1, 2, 1], 4).sqrt() == P([1, 1], 2) + + assert P([]).deflation() == (P([]), 1) + assert P([1, 2]).deflation() == (P([1, 2]), 1) + assert P([1, 0, 2]).deflation() == (P([1, 2]), 2) + + assert P([1, 2, 1]).derivative() == P([2, 2]) + + if is_field: + assert P([1, 2, 1]).integral() == P([0, 1, 1, S(1)/3]) + + all_tests = [ test_pyflint, @@ -1835,5 +2032,6 @@ def test_fmpz_mod_dlog(): test_nmod_mat, test_arb, test_fmpz_mod, - test_fmpz_mod_dlog + test_fmpz_mod_dlog, + test_polys, ] diff --git a/src/flint/types/fmpq.pyx b/src/flint/types/fmpq.pyx index 06e99957..cc25fc7b 100644 --- a/src/flint/types/fmpq.pyx +++ b/src/flint/types/fmpq.pyx @@ -71,12 +71,14 @@ cdef class fmpq(flint_scalar): def __dealloc__(self): fmpq_clear(self.val) - def __init__(self, p=None, q=None): - cdef long x - if q is None: - if p is None: - return # zero - elif typecheck(p, fmpq): + def __init__(self, *args): + if not args: + return # zero + elif len(args) == 2: + p, q = args + elif len(args) == 1: + p = args[0] + if typecheck(p, fmpq): fmpq_set(self.val, (p).val) return elif typecheck(p, str): @@ -90,17 +92,21 @@ cdef class fmpq(flint_scalar): else: p = any_as_fmpq(p) if p is NotImplemented: - raise ValueError("cannot create fmpq from object of type %s" % type(p)) + raise TypeError("cannot create fmpq from object of type %s" % type(p)) fmpq_set(self.val, (p).val) return + else: + raise TypeError("fmpq() takes at most 2 arguments (%d given)" % len(args)) + p = any_as_fmpz(p) if p is NotImplemented: - raise ValueError("cannot create fmpq from object of type %s" % type(p)) + raise TypeError("cannot create fmpq from object of type %s" % type(p)) q = any_as_fmpz(q) if q is NotImplemented: - raise ValueError("cannot create fmpq from object of type %s" % type(q)) + raise TypeError("cannot create fmpq from object of type %s" % type(q)) if fmpz_is_zero((q).val): raise ZeroDivisionError("cannot create rational number with zero denominator") + fmpz_set(fmpq_numref(self.val), (p).val) fmpz_set(fmpq_denref(self.val), (q).val) fmpq_canonicalise(self.val) diff --git a/src/flint/types/fmpq_poly.pyx b/src/flint/types/fmpq_poly.pyx index 8c85b28d..eec0bd58 100644 --- a/src/flint/types/fmpq_poly.pyx +++ b/src/flint/types/fmpq_poly.pyx @@ -78,18 +78,26 @@ cdef class fmpq_poly(flint_poly): def __dealloc__(self): fmpq_poly_clear(self.val) - def __init__(self, p=None, q=None): - if p is not None: - if typecheck(p, fmpq_poly): - fmpq_poly_set(self.val, (p).val) - elif typecheck(p, fmpz_poly): - fmpq_poly_set_fmpz_poly(self.val, (p).val) - elif isinstance(p, list): - fmpq_poly_set_list(self.val, p) - else: - raise TypeError("cannot create fmpq_poly from input of type %s", type(p)) - if q is not None: - q = any_as_fmpz(q) + def __init__(self, *args): + if len(args) == 0: + return + elif len(args) > 2: + raise TypeError("fmpq_poly() takes 0, 1 or 2 arguments (%d given)" % len(args)) + + p = args[0] + if typecheck(p, fmpq_poly): + fmpq_poly_set(self.val, (p).val) + elif typecheck(p, fmpz_poly): + fmpq_poly_set_fmpz_poly(self.val, (p).val) + elif isinstance(p, list): + fmpq_poly_set_list(self.val, p) + elif (v := any_as_fmpq(p)) is not NotImplemented: + fmpq_poly_set_fmpq(self.val, (v).val) + else: + raise TypeError("cannot create fmpq_poly from input of type %s", type(p)) + + if len(args) == 2: + q = any_as_fmpz(args[1]) if q is NotImplemented: raise TypeError("denominator must be an integer, got %s", type(q)) if fmpz_is_zero((q).val): @@ -326,12 +334,14 @@ cdef class fmpq_poly(flint_poly): return t return t._divmod_(s) - def __pow__(fmpq_poly self, ulong exp, mod): + def __pow__(fmpq_poly self, exp, mod): cdef fmpq_poly res if mod is not None: raise NotImplementedError("fmpz_poly modular exponentiation") + if exp < 0: + raise ValueError("fmpq_poly negative exponent") res = fmpq_poly.__new__(fmpq_poly) - fmpq_poly_pow(res.val, self.val, exp) + fmpq_poly_pow(res.val, self.val, exp) return res def gcd(self, other): @@ -384,6 +394,32 @@ cdef class fmpq_poly(flint_poly): fac[i] = (base, exp) return c / self.denom(), fac + def sqrt(self): + """ + Return the exact square root of this polynomial or ``None``. + + >>> p = fmpq_poly([1,2,1],4) + >>> p + 1/4*x^2 + 1/2*x + 1/4 + >>> p.sqrt() + 1/2*x + 1/2 + + """ + d = self.denom() + n = self.numer() + d, r = d.sqrtrem() + if r != 0: + return None + n = n.sqrt() + if n is None: + return None + return fmpq_poly(n, d) + + def deflation(self): + num, n = self.numer().deflation() + num = fmpq_poly(num, self.denom()) + return num, n + def complex_roots(self, **kwargs): """ Computes the complex roots of this polynomial. See diff --git a/src/flint/types/fmpz.pyx b/src/flint/types/fmpz.pyx index 51d46ba1..31659c4f 100644 --- a/src/flint/types/fmpz.pyx +++ b/src/flint/types/fmpz.pyx @@ -71,18 +71,22 @@ cdef class fmpz(flint_scalar): def __dealloc__(self): fmpz_clear(self.val) - def __init__(self, val=None): + def __init__(self, *args): cdef long x - if val is not None: - if typecheck(val, fmpz): - fmpz_set(self.val, (val).val) - else: - if fmpz_set_any_ref(self.val, val) == FMPZ_UNKNOWN: # XXX - if typecheck(val, str): - if fmpz_set_str(self.val, chars_from_str(val), 10) != 0: - raise ValueError("invalid string for fmpz") - return - raise TypeError("cannot create fmpz from type %s" % type(val)) + if not args: + return + elif len(args) != 1: + raise TypeError("fmpz takes zero or one arguments.") + val = args[0] + if typecheck(val, fmpz): + fmpz_set(self.val, (val).val) + else: + if fmpz_set_any_ref(self.val, val) == FMPZ_UNKNOWN: # XXX + if typecheck(val, str): + if fmpz_set_str(self.val, chars_from_str(val), 10) != 0: + raise ValueError("invalid string for fmpz") + return + raise TypeError("cannot create fmpz from type %s" % type(val)) @property def numerator(self): diff --git a/src/flint/types/fmpz_poly.pyx b/src/flint/types/fmpz_poly.pyx index 8a461755..f56c4957 100644 --- a/src/flint/types/fmpz_poly.pyx +++ b/src/flint/types/fmpz_poly.pyx @@ -1,4 +1,3 @@ -from cpython.version cimport PY_MAJOR_VERSION from cpython.int cimport PyInt_AS_LONG from cpython.list cimport PyList_GET_SIZE from cpython.long cimport PyLong_Check @@ -38,10 +37,6 @@ cdef any_as_fmpz_poly(x): res = fmpz_poly.__new__(fmpz_poly) fmpz_poly_set_fmpz(res.val, (x).val) return res - elif PY_MAJOR_VERSION < 3 and PyInt_Check(x): - res = fmpz_poly.__new__(fmpz_poly) - fmpz_poly_set_si(res.val, PyInt_AS_LONG(x)) - return res elif PyLong_Check(x): res = fmpz_poly.__new__(fmpz_poly) t = fmpz(x) @@ -82,14 +77,21 @@ cdef class fmpz_poly(flint_poly): def __dealloc__(self): fmpz_poly_clear(self.val) - def __init__(self, val=None): - if val is not None: - if typecheck(val, fmpz_poly): - fmpz_poly_set(self.val, (val).val) - elif isinstance(val, list): - fmpz_poly_set_list(self.val, val) - else: - raise TypeError("cannot create fmpz_poly from input of type %s", type(val)) + def __init__(self, *args): + if not args: + return + elif len(args) == 1: + val = args[0] + else: + raise TypeError("fmpz_poly() takes 0 or 1 arguments") + if typecheck(val, fmpz_poly): + fmpz_poly_set(self.val, (val).val) + elif isinstance(val, list): + fmpz_poly_set_list(self.val, val) + elif (v := any_as_fmpz(val)) is not NotImplemented: + fmpz_poly_set_fmpz(self.val, (v).val) + else: + raise TypeError("cannot create fmpz_poly from input of type %s", type(val)) def __len__(self): return fmpz_poly_length(self.val) @@ -286,12 +288,14 @@ cdef class fmpz_poly(flint_poly): return other return other._divmod_(self) - def __pow__(fmpz_poly self, ulong exp, mod): + def __pow__(fmpz_poly self, exp, mod): cdef fmpz_poly res if mod is not None: raise NotImplementedError("fmpz_poly modular exponentiation") + if exp < 0: + raise ValueError("fmpz_poly negative exponent") res = fmpz_poly.__new__(fmpz_poly) - fmpz_poly_pow(res.val, self.val, exp) + fmpz_poly_pow(res.val, self.val, exp) return res def gcd(self, other): diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index 8c7b297b..73c02643 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -10,6 +10,7 @@ from flint.flintlib.nmod_vec cimport * from flint.flintlib.nmod_poly cimport * from flint.flintlib.nmod_poly_factor cimport * from flint.flintlib.fmpz_poly cimport fmpz_poly_get_nmod_poly +from flint.flintlib.ulong_extras cimport n_gcdinv cdef any_as_nmod_poly(obj, nmod_t mod): cdef nmod_poly r @@ -68,13 +69,15 @@ cdef class nmod_poly(flint_poly): # cdef nmod_poly_t val - #def __cinit__(self): + def __cinit__(self): + nmod_poly_init(self.val, 1) def __dealloc__(self): nmod_poly_clear(self.val) def __init__(self, val=None, ulong mod=0): cdef ulong m2 + cdef mp_limb_t v if typecheck(val, nmod_poly): m2 = nmod_poly_modulus((val).val) if m2 != mod: @@ -89,6 +92,9 @@ cdef class nmod_poly(flint_poly): fmpz_poly_get_nmod_poly(self.val, (val).val) elif typecheck(val, list): nmod_poly_set_list(self.val, val) + elif any_as_nmod(&v, val, self.val.mod): + nmod_poly_fit_length(self.val, 1) + nmod_poly_set_coeff_ui(self.val, 0, v) else: raise TypeError("cannot create nmod_poly from input of type %s", type(val)) @@ -174,6 +180,18 @@ cdef class nmod_poly(flint_poly): return r raise TypeError("cannot call nmod_poly with input of type %s", type(other)) + def derivative(self): + cdef nmod_poly res = nmod_poly.__new__(nmod_poly) + nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + nmod_poly_derivative(res.val, self.val) + return res + + def integral(self): + cdef nmod_poly res = nmod_poly.__new__(nmod_poly) + nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + nmod_poly_integral(res.val, self.val) + return res + def __pos__(self): return self @@ -290,19 +308,28 @@ cdef class nmod_poly(flint_poly): return t return t._divmod_(s) + def __truediv__(s, t): + try: + t = nmod(t, (s).val.mod.n) + except TypeError: + return NotImplemented + return s * t ** -1 + def __mod__(s, t): return divmod(s, t)[1] # XXX def __rmod__(s, t): return divmod(t, s)[1] # XXX - def __pow__(nmod_poly self, ulong exp, mod): + def __pow__(nmod_poly self, exp, mod): cdef nmod_poly res if mod is not None: raise NotImplementedError("nmod_poly modular exponentiation") + if exp < 0: + raise ValueError("negative exponent") res = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv(res.val, (self).val.mod.n, (self).val.mod.ninv) - nmod_poly_pow(res.val, self.val, exp) + nmod_poly_pow(res.val, self.val, exp) return res def gcd(self, other): @@ -325,6 +352,20 @@ cdef class nmod_poly(flint_poly): nmod_poly_gcd(res.val, self.val, (other).val) return res + def xgcd(self, other): + cdef nmod_poly res1, res2, res3 + other = any_as_nmod_poly(other, (self).val.mod) + if other is NotImplemented: + raise TypeError("cannot convert input to fmpq_poly") + res1 = nmod_poly.__new__(nmod_poly) + res2 = nmod_poly.__new__(nmod_poly) + res3 = nmod_poly.__new__(nmod_poly) + nmod_poly_init(res1.val, (self).val.mod.n) + nmod_poly_init(res2.val, (self).val.mod.n) + nmod_poly_init(res3.val, (self).val.mod.n) + nmod_poly_xgcd(res1.val, res2.val, res3.val, self.val, (other).val) + return (res1, res2, res3) + def factor(self, algorithm=None): """ Factors self into irreducible factors, returning a tuple @@ -372,3 +413,26 @@ cdef class nmod_poly(flint_poly): nmod_poly_factor_clear(fac) return c, res + def sqrt(nmod_poly self): + """Return exact square root or ``None``. """ + cdef nmod_poly res + res = nmod_poly.__new__(nmod_poly) + nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + if nmod_poly_sqrt(res.val, self.val): + return res + else: + return None + + def deflation(self): + cdef nmod_poly v + cdef ulong n + if nmod_poly_is_zero(self.val): + return self, 1 + n = nmod_poly_deflation(self.val) + if n == 1: + return self, int(n) + else: + v = nmod_poly.__new__(nmod_poly) + nmod_poly_init(v.val, self.val.mod.n) + nmod_poly_deflate(v.val, self.val, n) + return v, int(n)