Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow construct RIF element from question-style string #38998

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/sage/libs/mpfi/__init__.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ cdef extern from "mpfi.h":
int mpfi_set_z(mpfi_ptr, mpz_t)
int mpfi_set_q(mpfi_ptr, mpq_t)
int mpfi_set_fr(mpfi_ptr, mpfr_srcptr)
int mpfi_set_str(mpfi_ptr, char *, int)
int mpfi_set_str(mpfi_ptr, const char *, int)

# combined initialization and assignment functions
int mpfi_init_set(mpfi_ptr, mpfi_srcptr)
Expand All @@ -36,7 +36,7 @@ cdef extern from "mpfi.h":
int mpfi_init_set_z(mpfi_ptr, mpz_srcptr)
int mpfi_init_set_q(mpfi_ptr, mpq_srcptr)
int mpfi_init_set_fr(mpfi_ptr, mpfr_srcptr)
int mpfi_init_set_str(mpfi_ptr, char *, int)
int mpfi_init_set_str(mpfi_ptr, const char *, int)

# swapping two intervals
void mpfi_swap(mpfi_ptr, mpfi_ptr)
Expand Down
4 changes: 2 additions & 2 deletions src/sage/libs/mpfr/__init__.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ cdef extern from "mpfr.h":
# int mpfr_set_f(mpfr_t rop, mpf_t op, mpfr_rnd_t rnd)
int mpfr_set_ui_2exp(mpfr_t rop, unsigned long int op, mp_exp_t e, mpfr_rnd_t rnd)
int mpfr_set_si_2exp(mpfr_t rop, long int op, mp_exp_t e, mpfr_rnd_t rnd)
int mpfr_set_str(mpfr_t rop, char *s, int base, mpfr_rnd_t rnd)
int mpfr_set_str(mpfr_t rop, const char *s, int base, mpfr_rnd_t rnd)
int mpfr_strtofr(mpfr_t rop, char *nptr, char **endptr, int base, mpfr_rnd_t rnd)
void mpfr_set_inf(mpfr_t x, int sign)
void mpfr_set_nan(mpfr_t x)
Expand All @@ -43,7 +43,7 @@ cdef extern from "mpfr.h":
int mpfr_init_set_z(mpfr_t rop, mpz_t op, mpfr_rnd_t rnd)
int mpfr_init_set_q(mpfr_t rop, mpq_t op, mpfr_rnd_t rnd)
# int mpfr_init_set_f(mpfr_t rop, mpf_t op, mpfr_rnd_t rnd)
int mpfr_init_set_str(mpfr_t x, char *s, int base, mpfr_rnd_t rnd)
int mpfr_init_set_str(mpfr_t x, const char *s, int base, mpfr_rnd_t rnd)

# Conversion Functions
double mpfr_get_d(mpfr_t op, mpfr_rnd_t rnd)
Expand Down
276 changes: 276 additions & 0 deletions src/sage/rings/convert/mpfi.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@ Convert Sage/Python objects to real/complex intervals
# http://www.gnu.org/licenses/
#*****************************************************************************

import re

from cpython.float cimport PyFloat_AS_DOUBLE
from cpython.complex cimport PyComplex_RealAsDouble, PyComplex_ImagAsDouble

from libc.stdio cimport printf

from sage.libs.mpfr cimport *
from sage.libs.mpfi cimport *
from sage.libs.gmp.mpz cimport *
from sage.libs.gsl.complex cimport *

from sage.arith.long cimport integer_check_long
Expand Down Expand Up @@ -45,6 +50,243 @@ cdef inline int return_real(mpfi_ptr im) noexcept:
return 0


NUMBER = re.compile(rb'([+-]?(0[XxBb])?[0-9A-Za-z]+)\.([0-9A-Za-z]*)\?([0-9]*)(?:([EePp@])([+-]?[0-9]+))?')
# example: -0xABC.DEF?12@5
# match groups: (-0xABC) (0x) (DEF) (12) (@) (5)

cdef int _from_str_question_style(mpfi_ptr x, bytes s, int base) except -1:
"""
Convert a string in question style to an MPFI interval.

INPUT:

- ``x`` -- a pre-initialized MPFI interval

- ``s`` -- the string to convert

- ``base`` -- base to use for string conversion

OUTPUT:

- if conversion is possible: set ``x`` and return 0.

- in all other cases: return some nonzero value, or raise an exception.

TESTS:

Double check that ``ZZ``, ``RR`` and ``RIF`` follows the string
conversion rule for base different from `10` (except ``ZZ``
which only allows base up to `36`)::

sage: ZZ("0x123", base=0)
291
sage: RR("0x123.e1", base=0) # rel tol 1e-12
291.878906250000
sage: RR("0x123.@1", base=0) # rel tol 1e-12
4656.00000000000
sage: RIF("0x123.4@1", base=0)
4660
sage: ZZ("1Xx", base=36) # case insensitive
2517
sage: ZZ("1Xx", base=62)
Traceback (most recent call last):
...
ValueError: base (=62) must be 0 or between 2 and 36
sage: RR("1Xx", base=36) # rel tol 1e-12
2517.00000000000
sage: RR("0x123", base=36) # rel tol 1e-12
1.54101900000000e6
sage: RR("-1Xx@-1", base=62) # rel tol 1e-12
-95.9516129032258
sage: RIF("1Xx@-1", base=62) # rel tol 1e-12
95.95161290322580?
sage: RIF("1aE1", base=11)
Traceback (most recent call last):
...
TypeError: unable to convert '1aE1' to real interval
sage: RIF("1aE1", base=11)
Traceback (most recent call last):
...
TypeError: unable to convert '1aE1' to real interval

General checks::

sage: RIF("123456.?2").endpoints() # rel tol 1e-12
(123454.0, 123458.0)
sage: RIF("1234.56?2").endpoints() # rel tol 1e-12
(1234.54, 1234.58)
sage: RIF("1234.56?2e2").endpoints() # rel tol 1e-12
(123454.0, 123458.0)
sage: x = RIF("-1234.56?2e2"); x.endpoints() # rel tol 1e-12
(-123458.0, -123454.0)
sage: x
-1.2346?e5
sage: x.str(style="question", error_digits=1)
'-123456.?2'
sage: RIF("1.?100").endpoints() # rel tol 1e-12
(-99.0, 101.0)
sage: RIF("1.?100").str(style="question", error_digits=3)
'1.?100'

Large exponent (ensure precision is not lost)::

sage: x = RIF("1.123456?2e100000000"); x
1.12346?e100000000
sage: x.str(style="question", error_digits=3)
'1.12345600?201e100000000'

Large precision::

sage: F = RealIntervalField(1000)
sage: x = F(sqrt(2)); x.endpoints() # rel tol 1e-290
(1.41421356237309504880168872420969807856967187537694807317667973799073247846210703885038753432764157273501384623091229702492483605585073721264412149709993583141322266592750559275579995050115278206057147010955997160597027453459686201472851741864088919860955232923048430871432145083976260362799525140798,
1.41421356237309504880168872420969807856967187537694807317667973799073247846210703885038753432764157273501384623091229702492483605585073721264412149709993583141322266592750559275579995050115278206057147010955997160597027453459686201472851741864088919860955232923048430871432145083976260362799525140799)
sage: x in F(x.str(style="question", error_digits=3))
True
sage: x in F(x.str(style="question", error_digits=0))
True
sage: F("1.123456789123456789123456789123456789123456789123456789123456789123456789?987654321987654321987654321e500").endpoints() # rel tol 1e-290
(1.123456789123456789123456789123456789123456788135802467135802467135802468e500,
1.12345678912345678912345678912345678912345679011111111111111111111111111e500)

Stress test::

sage: for F in [RealIntervalField(15), RIF, RealIntervalField(100), RealIntervalField(1000)]:
....: for i in range(1000):
....: a, b = randint(-10^9, 10^9), randint(0, 50)
....: c, d = randint(-2^b, 2^b), randint(2, 5)
....: x = a * F(d)^c
....: assert x in F(x.str(style="question", error_digits=3)), (x, a, c, d)
....: assert x in F(x.str(style="question", error_digits=0)), (x, a, c, d)

Base different from `10` (note that the error and exponent are specified in decimal)::

sage: RIF("10000.?0", base=2).endpoints() # rel tol 1e-12
(16.0, 16.0)
sage: RIF("10000.?0e10", base=2).endpoints() # rel tol 1e-12
(16384.0, 16384.0)
sage: x = RIF("10000.?10", base=2); x.endpoints() # rel tol 1e-12
(6.0, 26.0)
sage: x.str(base=2, style="question", error_digits=2)
'10000.000?80'
sage: x = RIF("10000.000?80", base=2); x.endpoints() # rel tol 1e-12
(6.0, 26.0)
sage: x = RIF("12a.?", base=16); x.endpoints() # rel tol 1e-12
(297.0, 299.0)
sage: x = RIF("12a.BcDeF?", base=16); x.endpoints() # rel tol 1e-12
(298.737775802611, 298.737777709962)
sage: x = RIF("12a.BcDeF?@10", base=16); x.endpoints() # rel tol 1e-12
(3.28465658150911e14, 3.28465660248065e14)
sage: x = RIF("12a.BcDeF?p10", base=16); x.endpoints() # rel tol 1e-12
(305907.482421875, 305907.484375000)
sage: x = RIF("0x12a.BcDeF?p10", base=0); x.endpoints() # rel tol 1e-12
(305907.482421875, 305907.484375000)

Space is allowed::

sage: RIF("-1234.56?2").endpoints() # rel tol 1e-12
(-1234.58, -1234.54)
sage: RIF("- 1234.56 ?2").endpoints() # rel tol 1e-12
(-1234.58, -1234.54)

Erroneous input::

sage: RIF("1234.56?2e2.3")
Traceback (most recent call last):
...
TypeError: unable to convert '1234.56?2e2.3' to real interval
sage: RIF("1234?2") # decimal point required
Traceback (most recent call last):
...
TypeError: unable to convert '1234?2' to real interval
sage: RIF("1234.?2e")
Traceback (most recent call last):
...
TypeError: unable to convert '1234.?2e' to real interval
sage: RIF("1.?e999999999999999999999999")
[-infinity .. +infinity]
sage: RIF("0X1.?", base=33) # X is not valid digit in base 33
Traceback (most recent call last):
...
TypeError: unable to convert '0X1.?' to real interval
sage: RIF("1.a?1e10", base=12)
Traceback (most recent call last):
...
TypeError: unable to convert '1.a?1e10' to real interval
sage: RIF("1.1?a@10", base=12)
Traceback (most recent call last):
...
TypeError: unable to convert '1.1?a@10' to real interval
sage: RIF("0x1?2e1", base=0) # e is not allowed in base > 10, use @ instead
Traceback (most recent call last):
...
TypeError: unable to convert '0x1?2e1' to real interval
sage: RIF("0x1?2p1", base=36)
Traceback (most recent call last):
...
TypeError: unable to convert '0x1?2p1' to real interval
"""
cdef mpz_t error_part
cdef mpfi_t error
cdef mpfr_t radius, neg_radius
cdef bytes int_part_string, base_prefix, frac_part_string, error_string, e, sci_expo_string, optional_expo, tmp

match = NUMBER.fullmatch(s)
if match is None:
return 1
int_part_string, base_prefix, frac_part_string, error_string, e, sci_expo_string = match.groups()

if (base > 10 or (base == 0 and base_prefix in (b'0X', b'0X'))) and e in (b'e', b'E'):
return 1
if base > 16 and e in (b'p', b'P'):
return 1
if base > 16 or not base_prefix:
base_prefix = b''

if error_string:
if mpz_init_set_str(error_part, error_string, 10):
mpz_clear(error_part)
return 1
else:
mpz_init_set_ui(error_part, 1)

optional_expo = e + sci_expo_string if e else b''
if mpfi_set_str(x, int_part_string + b'.' + frac_part_string + optional_expo, base):
mpz_clear(error_part)
return 1

mpfr_init2(radius, mpfi_get_prec(x))
tmp = base_prefix + (
b'0.' + b'0'*(len(frac_part_string)-1) + b'1' + optional_expo
if frac_part_string else
b'1.' + optional_expo)
# if base = 0:
# when s = '-0x123.456@7', tmp = '0x0.001@7'
# when s = '-0x123.@7', tmp = '0x1.@7'
# if base = 36:
# when s = '-0x123.456@7', tmp = '0.001@7'
if mpfr_set_str(radius, tmp, base, MPFR_RNDU):
mpfr_clear(radius)
mpz_clear(error_part)
return 1

mpfr_mul_z(radius, radius, error_part, MPFR_RNDU)
mpz_clear(error_part)

mpfr_init2(neg_radius, mpfi_get_prec(x))
mpfr_neg(neg_radius, radius, MPFR_RNDD)

mpfi_init2(error, mpfi_get_prec(x))
mpfi_interv_fr(error, neg_radius, radius)
mpfr_clear(radius)
mpfr_clear(neg_radius)

mpfi_add(x, x, error)
mpfi_clear(error)

return 0


cdef int mpfi_set_sage(mpfi_ptr re, mpfi_ptr im, x, field, int base) except -1:
"""
Convert any object ``x`` to an MPFI interval or a pair of
Expand Down Expand Up @@ -72,13 +314,42 @@ cdef int mpfi_set_sage(mpfi_ptr re, mpfi_ptr im, x, field, int base) except -1:
imaginary component is 0.

- in all other cases: raise an exception.

TESTS::

sage: RIF('0xabc')
Traceback (most recent call last):
...
TypeError: unable to convert '0xabc' to real interval
sage: RIF("0x123.e1", base=0) # rel tol 1e-12
291.87890625000000?
sage: RIF("0x123.@1", base=0) # rel tol 1e-12
4656
sage: RIF("1Xx", base=36) # rel tol 1e-12
2517
sage: RIF("-1Xx@-10", base=62) # rel tol 1e-12
-7.088054920481391?e-15
sage: RIF("1", base=1)
Traceback (most recent call last):
...
ValueError: base (=1) must be 0 or between 2 and 62
sage: RIF("1", base=-1)
Traceback (most recent call last):
...
ValueError: base (=-1) must be 0 or between 2 and 62
sage: RIF("1", base=63)
Traceback (most recent call last):
...
ValueError: base (=63) must be 0 or between 2 and 62
"""
cdef RealIntervalFieldElement ri
cdef ComplexIntervalFieldElement zi
cdef ComplexNumber zn
cdef ComplexDoubleElement zd
cdef bytes s

if base != 0 and (base < 2 or base > 62):
raise ValueError(f"base (={base}) must be 0 or between 2 and 62")
if im is not NULL and isinstance(x, tuple):
# For complex numbers, interpret tuples as real/imag parts
if len(x) != 2:
Expand Down Expand Up @@ -157,6 +428,11 @@ cdef int mpfi_set_sage(mpfi_ptr re, mpfi_ptr im, x, field, int base) except -1:
if isinstance(x, unicode):
x = x.encode("ascii")
if isinstance(x, bytes):
if b"?" in x:
if _from_str_question_style(re, (<bytes>x).replace(b' ', b''), base):
x = bytes_to_str(x)
raise TypeError(f"unable to convert {x!r} to real interval")
return return_real(im)
s = (<bytes>x).replace(b'..', b',').replace(b' ', b'').replace(b'+infinity', b'@inf@').replace(b'-infinity', b'-@inf@')
if mpfi_set_str(re, s, base):
x = bytes_to_str(x)
Expand Down
29 changes: 29 additions & 0 deletions src/sage/rings/real_mpfr.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,33 @@ cdef class RealNumber(sage.structure.element.RingElement):

sage: RealNumber('1_3.1e-32_45')
1.31000000000000e-3244

Test conversion from base different from `10`::

sage: RR('0xabc')
Traceback (most recent call last):
...
TypeError: unable to convert '0xabc' to a real number
sage: RR("0x123.e1", base=0) # rel tol 1e-12
291.878906250000
sage: RR("0x123.@1", base=0) # rel tol 1e-12
4656.00000000000
sage: RR("1Xx", base=36) # rel tol 1e-12
2517.00000000000
sage: RR("-1Xx@-10", base=62) # rel tol 1e-12
-7.08805492048139e-15
sage: RR("1", base=1)
Traceback (most recent call last):
...
ValueError: base (=1) must be 0 or between 2 and 62
sage: RR("1", base=-1)
Traceback (most recent call last):
...
ValueError: base (=-1) must be 0 or between 2 and 62
sage: RR("1", base=63)
Traceback (most recent call last):
...
ValueError: base (=63) must be 0 or between 2 and 62
"""
if x is not None:
self._set(x, base)
Expand Down Expand Up @@ -1485,6 +1512,8 @@ cdef class RealNumber(sage.structure.element.RingElement):
# Real Numbers are supposed to be immutable.
cdef RealField_class parent
parent = self._parent
if base != 0 and (base < 2 or base > 62):
raise ValueError(f"base (={base}) must be 0 or between 2 and 62")
if isinstance(x, RealNumber):
if isinstance(x, RealLiteral):
s = (<RealLiteral>x).literal
Expand Down
Loading