From c9e5c99a59a525d9edbe97de6d2e24660485e105 Mon Sep 17 00:00:00 2001 From: user202729 <25191436+user202729@users.noreply.github.com> Date: Tue, 19 Nov 2024 09:55:20 +0700 Subject: [PATCH] Allow construct RIF element from question-style string --- src/sage/libs/mpfi/__init__.pxd | 4 +- src/sage/libs/mpfr/__init__.pxd | 4 +- src/sage/rings/convert/mpfi.pyx | 247 ++++++++++++++++++++++++++++++++ 3 files changed, 251 insertions(+), 4 deletions(-) diff --git a/src/sage/libs/mpfi/__init__.pxd b/src/sage/libs/mpfi/__init__.pxd index b55a5129d18..6b81af16593 100644 --- a/src/sage/libs/mpfi/__init__.pxd +++ b/src/sage/libs/mpfi/__init__.pxd @@ -26,7 +26,7 @@ cdef extern from "mpfi.h": int mpfi_set_z(mpfi_ptr, mpz_t) int mpfi_set_q(mpfi_ptr, mpq_t) int mpfi_set_fr(mpfi_ptr, mpfr_srcptr) - int mpfi_set_str(mpfi_ptr, char *, int) + int mpfi_set_str(mpfi_ptr, const char *, int) # combined initialization and assignment functions int mpfi_init_set(mpfi_ptr, mpfi_srcptr) @@ -36,7 +36,7 @@ cdef extern from "mpfi.h": int mpfi_init_set_z(mpfi_ptr, mpz_srcptr) int mpfi_init_set_q(mpfi_ptr, mpq_srcptr) int mpfi_init_set_fr(mpfi_ptr, mpfr_srcptr) - int mpfi_init_set_str(mpfi_ptr, char *, int) + int mpfi_init_set_str(mpfi_ptr, const char *, int) # swapping two intervals void mpfi_swap(mpfi_ptr, mpfi_ptr) diff --git a/src/sage/libs/mpfr/__init__.pxd b/src/sage/libs/mpfr/__init__.pxd index facac9aa6c7..8bb85ff3c1a 100644 --- a/src/sage/libs/mpfr/__init__.pxd +++ b/src/sage/libs/mpfr/__init__.pxd @@ -27,7 +27,7 @@ cdef extern from "mpfr.h": # int mpfr_set_f(mpfr_t rop, mpf_t op, mpfr_rnd_t rnd) int mpfr_set_ui_2exp(mpfr_t rop, unsigned long int op, mp_exp_t e, mpfr_rnd_t rnd) int mpfr_set_si_2exp(mpfr_t rop, long int op, mp_exp_t e, mpfr_rnd_t rnd) - int mpfr_set_str(mpfr_t rop, char *s, int base, mpfr_rnd_t rnd) + int mpfr_set_str(mpfr_t rop, const char *s, int base, mpfr_rnd_t rnd) int mpfr_strtofr(mpfr_t rop, char *nptr, char **endptr, int base, mpfr_rnd_t rnd) void mpfr_set_inf(mpfr_t x, int sign) void mpfr_set_nan(mpfr_t x) @@ -43,7 +43,7 @@ cdef extern from "mpfr.h": int mpfr_init_set_z(mpfr_t rop, mpz_t op, mpfr_rnd_t rnd) int mpfr_init_set_q(mpfr_t rop, mpq_t op, mpfr_rnd_t rnd) # int mpfr_init_set_f(mpfr_t rop, mpf_t op, mpfr_rnd_t rnd) - int mpfr_init_set_str(mpfr_t x, char *s, int base, mpfr_rnd_t rnd) + int mpfr_init_set_str(mpfr_t x, const char *s, int base, mpfr_rnd_t rnd) # Conversion Functions double mpfr_get_d(mpfr_t op, mpfr_rnd_t rnd) diff --git a/src/sage/rings/convert/mpfi.pyx b/src/sage/rings/convert/mpfi.pyx index 106a94c5aca..4007455bcfd 100644 --- a/src/sage/rings/convert/mpfi.pyx +++ b/src/sage/rings/convert/mpfi.pyx @@ -11,11 +11,16 @@ Convert Sage/Python objects to real/complex intervals # http://www.gnu.org/licenses/ #***************************************************************************** +import re + from cpython.float cimport PyFloat_AS_DOUBLE from cpython.complex cimport PyComplex_RealAsDouble, PyComplex_ImagAsDouble +from libc.stdio cimport printf + from sage.libs.mpfr cimport * from sage.libs.mpfi cimport * +from sage.libs.gmp.mpz cimport * from sage.libs.gsl.complex cimport * from sage.arith.long cimport integer_check_long @@ -45,6 +50,243 @@ cdef inline int return_real(mpfi_ptr im) noexcept: return 0 +NUMBER = re.compile(rb'([+-]?(0[XxBb])?[0-9A-Za-z]+)\.([0-9A-Za-z]*)\?([0-9]*)(?:([EePp@])([+-]?[0-9]+))?') +# example: -0xABC.DEF?12@5 +# match groups: (-0xABC) (0x) (DEF) (12) (@) (5) + +cdef int _from_str_question_style(mpfi_ptr x, bytes s, int base) except -1: + """ + Convert a string in question style to an MPFI interval. + + INPUT: + + - ``x`` -- a pre-initialized MPFI interval + + - ``s`` -- the string to convert + + - ``base`` -- base to use for string conversion + + OUTPUT: + + - if conversion is possible: set ``x`` and return 0. + + - in all other cases: return some nonzero value, or raise an exception. + + TESTS: + + Double check that ``ZZ``, ``RR`` and ``RIF`` follows the string + conversion rule for base different from `10` (except ``ZZ`` + which only allows base up to `36`):: + + sage: ZZ("0x123", base=0) + 291 + sage: RR("0x123.e1", base=0) # rel tol 1e-12 + 291.878906250000 + sage: RR("0x123.@1", base=0) # rel tol 1e-12 + 4656.00000000000 + sage: RIF("0x123.4@1", base=0) + 4660 + sage: ZZ("1Xx", base=36) # case insensitive + 2517 + sage: ZZ("1Xx", base=62) + Traceback (most recent call last): + ... + ValueError: base (=62) must be 0 or between 2 and 36 + sage: RR("1Xx", base=36) # rel tol 1e-12 + 2517.00000000000 + sage: RR("0x123", base=36) # rel tol 1e-12 + 1.54101900000000e6 + sage: RR("-1Xx@-1", base=62) # rel tol 1e-12 + -95.9516129032258 + sage: RIF("1Xx@-1", base=62) # rel tol 1e-12 + 95.95161290322580? + sage: RIF("1aE1", base=11) + Traceback (most recent call last): + ... + TypeError: unable to convert '1aE1' to real interval + sage: RIF("1aE1", base=11) + Traceback (most recent call last): + ... + TypeError: unable to convert '1aE1' to real interval + + General checks:: + + sage: RIF("123456.?2").endpoints() # rel tol 1e-12 + (123454.0, 123458.0) + sage: RIF("1234.56?2").endpoints() # rel tol 1e-12 + (1234.54, 1234.58) + sage: RIF("1234.56?2e2").endpoints() # rel tol 1e-12 + (123454.0, 123458.0) + sage: x = RIF("-1234.56?2e2"); x.endpoints() # rel tol 1e-12 + (-123458.0, -123454.0) + sage: x + -1.2346?e5 + sage: x.str(style="question", error_digits=1) + '-123456.?2' + sage: RIF("1.?100").endpoints() # rel tol 1e-12 + (-99.0, 101.0) + sage: RIF("1.?100").str(style="question", error_digits=3) + '1.?100' + + Large exponent (ensure precision is not lost):: + + sage: x = RIF("1.123456?2e1000000000"); x + 1.12346?e1000000000 + sage: x.str(style="question", error_digits=3) + '1.12345600?201e1000000000' + + Large precision:: + + sage: F = RealIntervalField(1000) + sage: x = F(sqrt(2)); x.endpoints() # rel tol 1e-290 + (1.41421356237309504880168872420969807856967187537694807317667973799073247846210703885038753432764157273501384623091229702492483605585073721264412149709993583141322266592750559275579995050115278206057147010955997160597027453459686201472851741864088919860955232923048430871432145083976260362799525140798, + 1.41421356237309504880168872420969807856967187537694807317667973799073247846210703885038753432764157273501384623091229702492483605585073721264412149709993583141322266592750559275579995050115278206057147010955997160597027453459686201472851741864088919860955232923048430871432145083976260362799525140799) + sage: x in F(x.str(style="question", error_digits=3)) + True + sage: x in F(x.str(style="question", error_digits=0)) + True + sage: F("1.123456789123456789123456789123456789123456789123456789123456789123456789?987654321987654321987654321e500").endpoints() # rel tol 1e-290 + (1.123456789123456789123456789123456789123456788135802467135802467135802468e500, + 1.12345678912345678912345678912345678912345679011111111111111111111111111e500) + + Stress test:: + + sage: for F in [RealIntervalField(15), RIF, RealIntervalField(100), RealIntervalField(1000)]: + ....: for i in range(1000): + ....: a, b = randint(-10^9, 10^9), randint(0, 50) + ....: c, d = randint(-2^b, 2^b), randint(2, 5) + ....: x = a * F(d)^c + ....: assert x in F(x.str(style="question", error_digits=3)), (x, a, c, d) + ....: assert x in F(x.str(style="question", error_digits=0)), (x, a, c, d) + + Base different from `10` (note that the error and exponent are specified in decimal):: + + sage: RIF("10000.?0", base=2).endpoints() # rel tol 1e-12 + (16.0, 16.0) + sage: RIF("10000.?0e10", base=2).endpoints() # rel tol 1e-12 + (16384.0, 16384.0) + sage: x = RIF("10000.?10", base=2); x.endpoints() # rel tol 1e-12 + (6.0, 26.0) + sage: x.str(base=2, style="question", error_digits=2) + '10000.000?80' + sage: x = RIF("10000.000?80", base=2); x.endpoints() # rel tol 1e-12 + (6.0, 26.0) + sage: x = RIF("12a.?", base=16); x.endpoints() # rel tol 1e-12 + (297.0, 299.0) + sage: x = RIF("12a.BcDeF?", base=16); x.endpoints() # rel tol 1e-12 + (298.737775802611, 298.737777709962) + sage: x = RIF("12a.BcDeF?@10", base=16); x.endpoints() # rel tol 1e-12 + (3.28465658150911e14, 3.28465660248065e14) + sage: x = RIF("12a.BcDeF?p10", base=16); x.endpoints() # rel tol 1e-12 + (305907.482421875, 305907.484375000) + sage: x = RIF("0x12a.BcDeF?p10", base=0); x.endpoints() # rel tol 1e-12 + (305907.482421875, 305907.484375000) + + Space is allowed:: + + sage: RIF("-1234.56?2").endpoints() # rel tol 1e-12 + (-1234.58, -1234.54) + sage: RIF("- 1234.56 ?2").endpoints() # rel tol 1e-12 + (-1234.58, -1234.54) + + Erroneous input:: + + sage: RIF("1234.56?2e2.3") + Traceback (most recent call last): + ... + TypeError: unable to convert '1234.56?2e2.3' to real interval + sage: RIF("1234?2") # decimal point required + Traceback (most recent call last): + ... + TypeError: unable to convert '1234?2' to real interval + sage: RIF("1234.?2e") + Traceback (most recent call last): + ... + TypeError: unable to convert '1234.?2e' to real interval + sage: RIF("1.?e999999999999999999999999") + [-infinity .. +infinity] + sage: RIF("0X1.?", base=33) # X is not valid digit in base 33 + Traceback (most recent call last): + ... + TypeError: unable to convert '0X1.?' to real interval + sage: RIF("1.a?1e10", base=12) + Traceback (most recent call last): + ... + TypeError: unable to convert '1.a?1e10' to real interval + sage: RIF("1.1?a@10", base=12) + Traceback (most recent call last): + ... + TypeError: unable to convert '1.1?a@10' to real interval + sage: RIF("0x1?2e1", base=0) # e is not allowed in base > 10, use @ instead + Traceback (most recent call last): + ... + TypeError: unable to convert '0x1?2e1' to real interval + sage: RIF("0x1?2p1", base=36) + Traceback (most recent call last): + ... + TypeError: unable to convert '0x1?2p1' to real interval + """ + cdef mpz_t error_part + cdef mpfi_t error + cdef mpfr_t radius, neg_radius + cdef bytes int_part_string, base_prefix, frac_part_string, error_string, e, sci_expo_string, optional_expo, tmp + + match = NUMBER.fullmatch(s) + if match is None: + return 1 + int_part_string, base_prefix, frac_part_string, error_string, e, sci_expo_string = match.groups() + + if (base > 10 or (base == 0 and base_prefix in (b'0X', b'0X'))) and e in (b'e', b'E'): + return 1 + if base > 16 and e in (b'p', b'P'): + return 1 + if base > 16 or not base_prefix: + base_prefix = b'' + + if error_string: + if mpz_init_set_str(error_part, error_string, 10): + mpz_clear(error_part) + return 1 + else: + mpz_init_set_ui(error_part, 1) + + optional_expo = e + sci_expo_string if e else b'' + if mpfi_set_str(x, int_part_string + b'.' + frac_part_string + optional_expo, base): + mpz_clear(error_part) + return 1 + + mpfr_init2(radius, mpfi_get_prec(x)) + tmp = base_prefix + ( + b'0.' + b'0'*(len(frac_part_string)-1) + b'1' + optional_expo + if frac_part_string else + b'1.' + optional_expo) + # if base = 0: + # when s = '-0x123.456@7', tmp = '0x0.001@7' + # when s = '-0x123.@7', tmp = '0x1.@7' + # if base = 36: + # when s = '-0x123.456@7', tmp = '0.001@7' + if mpfr_set_str(radius, tmp, base, MPFR_RNDU): + mpfr_clear(radius) + mpz_clear(error_part) + return 1 + + mpfr_mul_z(radius, radius, error_part, MPFR_RNDU) + mpz_clear(error_part) + + mpfr_init2(neg_radius, mpfi_get_prec(x)) + mpfr_neg(neg_radius, radius, MPFR_RNDD) + + mpfi_init2(error, mpfi_get_prec(x)) + mpfi_interv_fr(error, neg_radius, radius) + mpfr_clear(radius) + mpfr_clear(neg_radius) + + mpfi_add(x, x, error) + mpfi_clear(error) + + return 0 + + cdef int mpfi_set_sage(mpfi_ptr re, mpfi_ptr im, x, field, int base) except -1: """ Convert any object ``x`` to an MPFI interval or a pair of @@ -186,6 +428,11 @@ cdef int mpfi_set_sage(mpfi_ptr re, mpfi_ptr im, x, field, int base) except -1: if isinstance(x, unicode): x = x.encode("ascii") if isinstance(x, bytes): + if b"?" in x: + if _from_str_question_style(re, (x).replace(b' ', b''), base): + x = bytes_to_str(x) + raise TypeError(f"unable to convert {x!r} to real interval") + return return_real(im) s = (x).replace(b'..', b',').replace(b' ', b'').replace(b'+infinity', b'@inf@').replace(b'-infinity', b'-@inf@') if mpfi_set_str(re, s, base): x = bytes_to_str(x)