Skip to content

Commit

Permalink
Allow construct RIF element from question-style string
Browse files Browse the repository at this point in the history
  • Loading branch information
user202729 committed Nov 19, 2024
1 parent ff9d834 commit c9e5c99
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/sage/libs/mpfi/__init__.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ cdef extern from "mpfi.h":
int mpfi_set_z(mpfi_ptr, mpz_t)
int mpfi_set_q(mpfi_ptr, mpq_t)
int mpfi_set_fr(mpfi_ptr, mpfr_srcptr)
int mpfi_set_str(mpfi_ptr, char *, int)
int mpfi_set_str(mpfi_ptr, const char *, int)

# combined initialization and assignment functions
int mpfi_init_set(mpfi_ptr, mpfi_srcptr)
Expand All @@ -36,7 +36,7 @@ cdef extern from "mpfi.h":
int mpfi_init_set_z(mpfi_ptr, mpz_srcptr)
int mpfi_init_set_q(mpfi_ptr, mpq_srcptr)
int mpfi_init_set_fr(mpfi_ptr, mpfr_srcptr)
int mpfi_init_set_str(mpfi_ptr, char *, int)
int mpfi_init_set_str(mpfi_ptr, const char *, int)

# swapping two intervals
void mpfi_swap(mpfi_ptr, mpfi_ptr)
Expand Down
4 changes: 2 additions & 2 deletions src/sage/libs/mpfr/__init__.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ cdef extern from "mpfr.h":
# int mpfr_set_f(mpfr_t rop, mpf_t op, mpfr_rnd_t rnd)
int mpfr_set_ui_2exp(mpfr_t rop, unsigned long int op, mp_exp_t e, mpfr_rnd_t rnd)
int mpfr_set_si_2exp(mpfr_t rop, long int op, mp_exp_t e, mpfr_rnd_t rnd)
int mpfr_set_str(mpfr_t rop, char *s, int base, mpfr_rnd_t rnd)
int mpfr_set_str(mpfr_t rop, const char *s, int base, mpfr_rnd_t rnd)
int mpfr_strtofr(mpfr_t rop, char *nptr, char **endptr, int base, mpfr_rnd_t rnd)
void mpfr_set_inf(mpfr_t x, int sign)
void mpfr_set_nan(mpfr_t x)
Expand All @@ -43,7 +43,7 @@ cdef extern from "mpfr.h":
int mpfr_init_set_z(mpfr_t rop, mpz_t op, mpfr_rnd_t rnd)
int mpfr_init_set_q(mpfr_t rop, mpq_t op, mpfr_rnd_t rnd)
# int mpfr_init_set_f(mpfr_t rop, mpf_t op, mpfr_rnd_t rnd)
int mpfr_init_set_str(mpfr_t x, char *s, int base, mpfr_rnd_t rnd)
int mpfr_init_set_str(mpfr_t x, const char *s, int base, mpfr_rnd_t rnd)

# Conversion Functions
double mpfr_get_d(mpfr_t op, mpfr_rnd_t rnd)
Expand Down
247 changes: 247 additions & 0 deletions src/sage/rings/convert/mpfi.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@ Convert Sage/Python objects to real/complex intervals
# http://www.gnu.org/licenses/
#*****************************************************************************

import re

from cpython.float cimport PyFloat_AS_DOUBLE
from cpython.complex cimport PyComplex_RealAsDouble, PyComplex_ImagAsDouble

from libc.stdio cimport printf

from sage.libs.mpfr cimport *
from sage.libs.mpfi cimport *
from sage.libs.gmp.mpz cimport *
from sage.libs.gsl.complex cimport *

from sage.arith.long cimport integer_check_long
Expand Down Expand Up @@ -45,6 +50,243 @@ cdef inline int return_real(mpfi_ptr im) noexcept:
return 0


NUMBER = re.compile(rb'([+-]?(0[XxBb])?[0-9A-Za-z]+)\.([0-9A-Za-z]*)\?([0-9]*)(?:([EePp@])([+-]?[0-9]+))?')
# example: -0xABC.DEF?12@5
# match groups: (-0xABC) (0x) (DEF) (12) (@) (5)

cdef int _from_str_question_style(mpfi_ptr x, bytes s, int base) except -1:
"""
Convert a string in question style to an MPFI interval.
INPUT:
- ``x`` -- a pre-initialized MPFI interval
- ``s`` -- the string to convert
- ``base`` -- base to use for string conversion
OUTPUT:
- if conversion is possible: set ``x`` and return 0.
- in all other cases: return some nonzero value, or raise an exception.
TESTS:
Double check that ``ZZ``, ``RR`` and ``RIF`` follows the string
conversion rule for base different from `10` (except ``ZZ``
which only allows base up to `36`)::
sage: ZZ("0x123", base=0)
291
sage: RR("0x123.e1", base=0) # rel tol 1e-12
291.878906250000
sage: RR("0x123.@1", base=0) # rel tol 1e-12
4656.00000000000
sage: RIF("0x123.4@1", base=0)
4660
sage: ZZ("1Xx", base=36) # case insensitive
2517
sage: ZZ("1Xx", base=62)
Traceback (most recent call last):
...
ValueError: base (=62) must be 0 or between 2 and 36
sage: RR("1Xx", base=36) # rel tol 1e-12
2517.00000000000
sage: RR("0x123", base=36) # rel tol 1e-12
1.54101900000000e6
sage: RR("-1Xx@-1", base=62) # rel tol 1e-12
-95.9516129032258
sage: RIF("1Xx@-1", base=62) # rel tol 1e-12
95.95161290322580?
sage: RIF("1aE1", base=11)
Traceback (most recent call last):
...
TypeError: unable to convert '1aE1' to real interval
sage: RIF("1aE1", base=11)
Traceback (most recent call last):
...
TypeError: unable to convert '1aE1' to real interval
General checks::
sage: RIF("123456.?2").endpoints() # rel tol 1e-12
(123454.0, 123458.0)
sage: RIF("1234.56?2").endpoints() # rel tol 1e-12
(1234.54, 1234.58)
sage: RIF("1234.56?2e2").endpoints() # rel tol 1e-12
(123454.0, 123458.0)
sage: x = RIF("-1234.56?2e2"); x.endpoints() # rel tol 1e-12
(-123458.0, -123454.0)
sage: x
-1.2346?e5
sage: x.str(style="question", error_digits=1)
'-123456.?2'
sage: RIF("1.?100").endpoints() # rel tol 1e-12
(-99.0, 101.0)
sage: RIF("1.?100").str(style="question", error_digits=3)
'1.?100'
Large exponent (ensure precision is not lost)::
sage: x = RIF("1.123456?2e1000000000"); x
1.12346?e1000000000
sage: x.str(style="question", error_digits=3)
'1.12345600?201e1000000000'
Large precision::
sage: F = RealIntervalField(1000)
sage: x = F(sqrt(2)); x.endpoints() # rel tol 1e-290
(1.41421356237309504880168872420969807856967187537694807317667973799073247846210703885038753432764157273501384623091229702492483605585073721264412149709993583141322266592750559275579995050115278206057147010955997160597027453459686201472851741864088919860955232923048430871432145083976260362799525140798,
1.41421356237309504880168872420969807856967187537694807317667973799073247846210703885038753432764157273501384623091229702492483605585073721264412149709993583141322266592750559275579995050115278206057147010955997160597027453459686201472851741864088919860955232923048430871432145083976260362799525140799)
sage: x in F(x.str(style="question", error_digits=3))
True
sage: x in F(x.str(style="question", error_digits=0))
True
sage: F("1.123456789123456789123456789123456789123456789123456789123456789123456789?987654321987654321987654321e500").endpoints() # rel tol 1e-290
(1.123456789123456789123456789123456789123456788135802467135802467135802468e500,
1.12345678912345678912345678912345678912345679011111111111111111111111111e500)
Stress test::
sage: for F in [RealIntervalField(15), RIF, RealIntervalField(100), RealIntervalField(1000)]:
....: for i in range(1000):
....: a, b = randint(-10^9, 10^9), randint(0, 50)
....: c, d = randint(-2^b, 2^b), randint(2, 5)
....: x = a * F(d)^c
....: assert x in F(x.str(style="question", error_digits=3)), (x, a, c, d)
....: assert x in F(x.str(style="question", error_digits=0)), (x, a, c, d)
Base different from `10` (note that the error and exponent are specified in decimal)::
sage: RIF("10000.?0", base=2).endpoints() # rel tol 1e-12
(16.0, 16.0)
sage: RIF("10000.?0e10", base=2).endpoints() # rel tol 1e-12
(16384.0, 16384.0)
sage: x = RIF("10000.?10", base=2); x.endpoints() # rel tol 1e-12
(6.0, 26.0)
sage: x.str(base=2, style="question", error_digits=2)
'10000.000?80'
sage: x = RIF("10000.000?80", base=2); x.endpoints() # rel tol 1e-12
(6.0, 26.0)
sage: x = RIF("12a.?", base=16); x.endpoints() # rel tol 1e-12
(297.0, 299.0)
sage: x = RIF("12a.BcDeF?", base=16); x.endpoints() # rel tol 1e-12
(298.737775802611, 298.737777709962)
sage: x = RIF("12a.BcDeF?@10", base=16); x.endpoints() # rel tol 1e-12
(3.28465658150911e14, 3.28465660248065e14)
sage: x = RIF("12a.BcDeF?p10", base=16); x.endpoints() # rel tol 1e-12
(305907.482421875, 305907.484375000)
sage: x = RIF("0x12a.BcDeF?p10", base=0); x.endpoints() # rel tol 1e-12
(305907.482421875, 305907.484375000)
Space is allowed::
sage: RIF("-1234.56?2").endpoints() # rel tol 1e-12
(-1234.58, -1234.54)
sage: RIF("- 1234.56 ?2").endpoints() # rel tol 1e-12
(-1234.58, -1234.54)
Erroneous input::
sage: RIF("1234.56?2e2.3")
Traceback (most recent call last):
...
TypeError: unable to convert '1234.56?2e2.3' to real interval
sage: RIF("1234?2") # decimal point required
Traceback (most recent call last):
...
TypeError: unable to convert '1234?2' to real interval
sage: RIF("1234.?2e")
Traceback (most recent call last):
...
TypeError: unable to convert '1234.?2e' to real interval
sage: RIF("1.?e999999999999999999999999")
[-infinity .. +infinity]
sage: RIF("0X1.?", base=33) # X is not valid digit in base 33
Traceback (most recent call last):
...
TypeError: unable to convert '0X1.?' to real interval
sage: RIF("1.a?1e10", base=12)
Traceback (most recent call last):
...
TypeError: unable to convert '1.a?1e10' to real interval
sage: RIF("1.1?a@10", base=12)
Traceback (most recent call last):
...
TypeError: unable to convert '1.1?a@10' to real interval
sage: RIF("0x1?2e1", base=0) # e is not allowed in base > 10, use @ instead
Traceback (most recent call last):
...
TypeError: unable to convert '0x1?2e1' to real interval
sage: RIF("0x1?2p1", base=36)
Traceback (most recent call last):
...
TypeError: unable to convert '0x1?2p1' to real interval
"""
cdef mpz_t error_part
cdef mpfi_t error
cdef mpfr_t radius, neg_radius
cdef bytes int_part_string, base_prefix, frac_part_string, error_string, e, sci_expo_string, optional_expo, tmp

match = NUMBER.fullmatch(s)
if match is None:
return 1
int_part_string, base_prefix, frac_part_string, error_string, e, sci_expo_string = match.groups()

if (base > 10 or (base == 0 and base_prefix in (b'0X', b'0X'))) and e in (b'e', b'E'):
return 1
if base > 16 and e in (b'p', b'P'):
return 1
if base > 16 or not base_prefix:
base_prefix = b''

if error_string:
if mpz_init_set_str(error_part, error_string, 10):
mpz_clear(error_part)
return 1
else:
mpz_init_set_ui(error_part, 1)

optional_expo = e + sci_expo_string if e else b''
if mpfi_set_str(x, int_part_string + b'.' + frac_part_string + optional_expo, base):
mpz_clear(error_part)
return 1

mpfr_init2(radius, mpfi_get_prec(x))
tmp = base_prefix + (
b'0.' + b'0'*(len(frac_part_string)-1) + b'1' + optional_expo
if frac_part_string else
b'1.' + optional_expo)
# if base = 0:
# when s = '-0x123.456@7', tmp = '0x0.001@7'
# when s = '-0x123.@7', tmp = '0x1.@7'
# if base = 36:
# when s = '-0x123.456@7', tmp = '0.001@7'
if mpfr_set_str(radius, tmp, base, MPFR_RNDU):
mpfr_clear(radius)
mpz_clear(error_part)
return 1

mpfr_mul_z(radius, radius, error_part, MPFR_RNDU)
mpz_clear(error_part)

mpfr_init2(neg_radius, mpfi_get_prec(x))
mpfr_neg(neg_radius, radius, MPFR_RNDD)

mpfi_init2(error, mpfi_get_prec(x))
mpfi_interv_fr(error, neg_radius, radius)
mpfr_clear(radius)
mpfr_clear(neg_radius)

mpfi_add(x, x, error)
mpfi_clear(error)

return 0


cdef int mpfi_set_sage(mpfi_ptr re, mpfi_ptr im, x, field, int base) except -1:
"""
Convert any object ``x`` to an MPFI interval or a pair of
Expand Down Expand Up @@ -186,6 +428,11 @@ cdef int mpfi_set_sage(mpfi_ptr re, mpfi_ptr im, x, field, int base) except -1:
if isinstance(x, unicode):
x = x.encode("ascii")
if isinstance(x, bytes):
if b"?" in x:
if _from_str_question_style(re, (<bytes>x).replace(b' ', b''), base):
x = bytes_to_str(x)
raise TypeError(f"unable to convert {x!r} to real interval")
return return_real(im)
s = (<bytes>x).replace(b'..', b',').replace(b' ', b'').replace(b'+infinity', b'@inf@').replace(b'-infinity', b'-@inf@')
if mpfi_set_str(re, s, base):
x = bytes_to_str(x)
Expand Down

0 comments on commit c9e5c99

Please sign in to comment.