Skip to content

Commit

Permalink
sagemathgh-38998: Allow construct RIF element from question-style string
Browse files Browse the repository at this point in the history
    
As in the title. It allows you to say e.g. `RIF("1.23?2e-5")`.

Partially handles sagemath#36797. (only
for real case. Complex case is not handled yet, but in principle it
should not be too difficult.)

(I don't see any disadvantage of allowing this, it's backwards
compatible)

Issue: currently

```
sage: RIF("10", base=37)
37
sage: ZZ("10", base=37)
[error]
```

should this inconsistency be fixed? If so how?

[Edit: actually the rule of conversion should probably follow [string to
mpfr conversion rule](https://www.mpfr.org/mpfr-current/mpfr.html#index-
mpfr_005fstrtofr) or [string to mpz conversion
rule](https://gmplib.org/manual/Assigning-Integers) ]

### 📝 Checklist

- [x] The title is concise and informative.
- [x] The description explains in detail what this PR is about.
- [x] I have linked a relevant issue or discussion.
- [x] I have created tests covering the changes.
- [x] I have updated the documentation and checked the documentation
preview. (there's no documentation change, but should we explicitly
mention the feature? I think the feature to construct from `[a..b]`
isn't explicitly mentioned either…?)

### ⌛ Dependencies

<!-- List all open PRs that this PR logically depends on. For example,
-->
<!-- - sagemath#12345: short description why this is a dependency -->
<!-- - sagemath#34567: ... -->


sagemath#39001
    
URL: sagemath#38998
Reported by: user202729
Reviewer(s): Travis Scrimshaw
  • Loading branch information
Release Manager committed Dec 12, 2024
2 parents efe5f64 + 805ecee commit 13951f2
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/sage/libs/mpfi/__init__.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ cdef extern from "mpfi.h":
int mpfi_set_z(mpfi_ptr, mpz_t)
int mpfi_set_q(mpfi_ptr, mpq_t)
int mpfi_set_fr(mpfi_ptr, mpfr_srcptr)
int mpfi_set_str(mpfi_ptr, char *, int)
int mpfi_set_str(mpfi_ptr, const char *, int)

# combined initialization and assignment functions
int mpfi_init_set(mpfi_ptr, mpfi_srcptr)
Expand All @@ -36,7 +36,7 @@ cdef extern from "mpfi.h":
int mpfi_init_set_z(mpfi_ptr, mpz_srcptr)
int mpfi_init_set_q(mpfi_ptr, mpq_srcptr)
int mpfi_init_set_fr(mpfi_ptr, mpfr_srcptr)
int mpfi_init_set_str(mpfi_ptr, char *, int)
int mpfi_init_set_str(mpfi_ptr, const char *, int)

# swapping two intervals
void mpfi_swap(mpfi_ptr, mpfi_ptr)
Expand Down
4 changes: 2 additions & 2 deletions src/sage/libs/mpfr/__init__.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ cdef extern from "mpfr.h":
# int mpfr_set_f(mpfr_t rop, mpf_t op, mpfr_rnd_t rnd)
int mpfr_set_ui_2exp(mpfr_t rop, unsigned long int op, mp_exp_t e, mpfr_rnd_t rnd)
int mpfr_set_si_2exp(mpfr_t rop, long int op, mp_exp_t e, mpfr_rnd_t rnd)
int mpfr_set_str(mpfr_t rop, char *s, int base, mpfr_rnd_t rnd)
int mpfr_set_str(mpfr_t rop, const char *s, int base, mpfr_rnd_t rnd)
int mpfr_strtofr(mpfr_t rop, char *nptr, char **endptr, int base, mpfr_rnd_t rnd)
void mpfr_set_inf(mpfr_t x, int sign)
void mpfr_set_nan(mpfr_t x)
Expand All @@ -43,7 +43,7 @@ cdef extern from "mpfr.h":
int mpfr_init_set_z(mpfr_t rop, mpz_t op, mpfr_rnd_t rnd)
int mpfr_init_set_q(mpfr_t rop, mpq_t op, mpfr_rnd_t rnd)
# int mpfr_init_set_f(mpfr_t rop, mpf_t op, mpfr_rnd_t rnd)
int mpfr_init_set_str(mpfr_t x, char *s, int base, mpfr_rnd_t rnd)
int mpfr_init_set_str(mpfr_t x, const char *s, int base, mpfr_rnd_t rnd)

# Conversion Functions
double mpfr_get_d(mpfr_t op, mpfr_rnd_t rnd)
Expand Down
247 changes: 247 additions & 0 deletions src/sage/rings/convert/mpfi.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@ Convert Sage/Python objects to real/complex intervals
# http://www.gnu.org/licenses/
#*****************************************************************************

import re

from cpython.float cimport PyFloat_AS_DOUBLE
from cpython.complex cimport PyComplex_RealAsDouble, PyComplex_ImagAsDouble

from libc.stdio cimport printf

from sage.libs.mpfr cimport *
from sage.libs.mpfi cimport *
from sage.libs.gmp.mpz cimport *
from sage.libs.gsl.complex cimport *

from sage.arith.long cimport integer_check_long
Expand Down Expand Up @@ -45,6 +50,243 @@ cdef inline int return_real(mpfi_ptr im) noexcept:
return 0


NUMBER = re.compile(rb'([+-]?(0[XxBb])?[0-9A-Za-z]+)\.([0-9A-Za-z]*)\?([0-9]*)(?:([EePp@])([+-]?[0-9]+))?')
# example: -0xABC.DEF?12@5
# match groups: (-0xABC) (0x) (DEF) (12) (@) (5)

cdef int _from_str_question_style(mpfi_ptr x, bytes s, int base) except -1:
"""
Convert a string in question style to an MPFI interval.
INPUT:
- ``x`` -- a pre-initialized MPFI interval
- ``s`` -- the string to convert
- ``base`` -- base to use for string conversion
OUTPUT:
- if conversion is possible: set ``x`` and return 0.
- in all other cases: return some nonzero value, or raise an exception.
TESTS:
Double check that ``ZZ``, ``RR`` and ``RIF`` follows the string
conversion rule for base different from `10` (except ``ZZ``
which only allows base up to `36`)::
sage: ZZ("0x123", base=0)
291
sage: RR("0x123.e1", base=0) # rel tol 1e-12
291.878906250000
sage: RR("0x123.@1", base=0) # rel tol 1e-12
4656.00000000000
sage: RIF("0x123.4@1", base=0)
4660
sage: ZZ("1Xx", base=36) # case insensitive
2517
sage: ZZ("1Xx", base=62)
Traceback (most recent call last):
...
ValueError: base (=62) must be 0 or between 2 and 36
sage: RR("1Xx", base=36) # rel tol 1e-12
2517.00000000000
sage: RR("0x123", base=36) # rel tol 1e-12
1.54101900000000e6
sage: RR("-1Xx@-1", base=62) # rel tol 1e-12
-95.9516129032258
sage: RIF("1Xx@-1", base=62) # rel tol 1e-12
95.95161290322580?
sage: RIF("1aE1", base=11)
Traceback (most recent call last):
...
TypeError: unable to convert '1aE1' to real interval
sage: RIF("1aE1", base=11)
Traceback (most recent call last):
...
TypeError: unable to convert '1aE1' to real interval
General checks::
sage: RIF("123456.?2").endpoints() # rel tol 1e-12
(123454.0, 123458.0)
sage: RIF("1234.56?2").endpoints() # rel tol 1e-12
(1234.54, 1234.58)
sage: RIF("1234.56?2e2").endpoints() # rel tol 1e-12
(123454.0, 123458.0)
sage: x = RIF("-1234.56?2e2"); x.endpoints() # rel tol 1e-12
(-123458.0, -123454.0)
sage: x
-1.2346?e5
sage: x.str(style="question", error_digits=1)
'-123456.?2'
sage: RIF("1.?100").endpoints() # rel tol 1e-12
(-99.0, 101.0)
sage: RIF("1.?100").str(style="question", error_digits=3)
'1.?100'
Large exponent (ensure precision is not lost)::
sage: x = RIF("1.123456?2e100000000"); x
1.12346?e100000000
sage: x.str(style="question", error_digits=3)
'1.12345600?201e100000000'
Large precision::
sage: F = RealIntervalField(1000)
sage: x = F(sqrt(2)); x.endpoints() # rel tol 1e-290
(1.41421356237309504880168872420969807856967187537694807317667973799073247846210703885038753432764157273501384623091229702492483605585073721264412149709993583141322266592750559275579995050115278206057147010955997160597027453459686201472851741864088919860955232923048430871432145083976260362799525140798,
1.41421356237309504880168872420969807856967187537694807317667973799073247846210703885038753432764157273501384623091229702492483605585073721264412149709993583141322266592750559275579995050115278206057147010955997160597027453459686201472851741864088919860955232923048430871432145083976260362799525140799)
sage: x in F(x.str(style="question", error_digits=3))
True
sage: x in F(x.str(style="question", error_digits=0))
True
sage: F("1.123456789123456789123456789123456789123456789123456789123456789123456789?987654321987654321987654321e500").endpoints() # rel tol 1e-290
(1.123456789123456789123456789123456789123456788135802467135802467135802468e500,
1.12345678912345678912345678912345678912345679011111111111111111111111111e500)
Stress test::
sage: for F in [RealIntervalField(15), RIF, RealIntervalField(100), RealIntervalField(1000)]:
....: for i in range(1000):
....: a, b = randint(-10^9, 10^9), randint(0, 50)
....: c, d = randint(-2^b, 2^b), randint(2, 5)
....: x = a * F(d)^c
....: assert x in F(x.str(style="question", error_digits=3)), (x, a, c, d)
....: assert x in F(x.str(style="question", error_digits=0)), (x, a, c, d)
Base different from `10` (note that the error and exponent are specified in decimal)::
sage: RIF("10000.?0", base=2).endpoints() # rel tol 1e-12
(16.0, 16.0)
sage: RIF("10000.?0e10", base=2).endpoints() # rel tol 1e-12
(16384.0, 16384.0)
sage: x = RIF("10000.?10", base=2); x.endpoints() # rel tol 1e-12
(6.0, 26.0)
sage: x.str(base=2, style="question", error_digits=2)
'10000.000?80'
sage: x = RIF("10000.000?80", base=2); x.endpoints() # rel tol 1e-12
(6.0, 26.0)
sage: x = RIF("12a.?", base=16); x.endpoints() # rel tol 1e-12
(297.0, 299.0)
sage: x = RIF("12a.BcDeF?", base=16); x.endpoints() # rel tol 1e-12
(298.737775802611, 298.737777709962)
sage: x = RIF("12a.BcDeF?@10", base=16); x.endpoints() # rel tol 1e-12
(3.28465658150911e14, 3.28465660248065e14)
sage: x = RIF("12a.BcDeF?p10", base=16); x.endpoints() # rel tol 1e-12
(305907.482421875, 305907.484375000)
sage: x = RIF("0x12a.BcDeF?p10", base=0); x.endpoints() # rel tol 1e-12
(305907.482421875, 305907.484375000)
Space is allowed::
sage: RIF("-1234.56?2").endpoints() # rel tol 1e-12
(-1234.58, -1234.54)
sage: RIF("- 1234.56 ?2").endpoints() # rel tol 1e-12
(-1234.58, -1234.54)
Erroneous input::
sage: RIF("1234.56?2e2.3")
Traceback (most recent call last):
...
TypeError: unable to convert '1234.56?2e2.3' to real interval
sage: RIF("1234?2") # decimal point required
Traceback (most recent call last):
...
TypeError: unable to convert '1234?2' to real interval
sage: RIF("1234.?2e")
Traceback (most recent call last):
...
TypeError: unable to convert '1234.?2e' to real interval
sage: RIF("1.?e999999999999999999999999")
[-infinity .. +infinity]
sage: RIF("0X1.?", base=33) # X is not valid digit in base 33
Traceback (most recent call last):
...
TypeError: unable to convert '0X1.?' to real interval
sage: RIF("1.a?1e10", base=12)
Traceback (most recent call last):
...
TypeError: unable to convert '1.a?1e10' to real interval
sage: RIF("1.1?a@10", base=12)
Traceback (most recent call last):
...
TypeError: unable to convert '1.1?a@10' to real interval
sage: RIF("0x1?2e1", base=0) # e is not allowed in base > 10, use @ instead
Traceback (most recent call last):
...
TypeError: unable to convert '0x1?2e1' to real interval
sage: RIF("0x1?2p1", base=36)
Traceback (most recent call last):
...
TypeError: unable to convert '0x1?2p1' to real interval
"""
cdef mpz_t error_part
cdef mpfi_t error
cdef mpfr_t radius, neg_radius
cdef bytes int_part_string, base_prefix, frac_part_string, error_string, e, sci_expo_string, optional_expo, tmp

match = NUMBER.fullmatch(s)
if match is None:
return 1
int_part_string, base_prefix, frac_part_string, error_string, e, sci_expo_string = match.groups()

if (base > 10 or (base == 0 and base_prefix in (b'0X', b'0X'))) and e in (b'e', b'E'):
return 1
if base > 16 and e in (b'p', b'P'):
return 1
if base > 16 or not base_prefix:
base_prefix = b''

if error_string:
if mpz_init_set_str(error_part, error_string, 10):
mpz_clear(error_part)
return 1
else:
mpz_init_set_ui(error_part, 1)

optional_expo = e + sci_expo_string if e else b''
if mpfi_set_str(x, int_part_string + b'.' + frac_part_string + optional_expo, base):
mpz_clear(error_part)
return 1

mpfr_init2(radius, mpfi_get_prec(x))
tmp = base_prefix + (
b'0.' + b'0'*(len(frac_part_string)-1) + b'1' + optional_expo
if frac_part_string else
b'1.' + optional_expo)
# if base = 0:
# when s = '-0x123.456@7', tmp = '0x0.001@7'
# when s = '-0x123.@7', tmp = '0x1.@7'
# if base = 36:
# when s = '-0x123.456@7', tmp = '0.001@7'
if mpfr_set_str(radius, tmp, base, MPFR_RNDU):
mpfr_clear(radius)
mpz_clear(error_part)
return 1

mpfr_mul_z(radius, radius, error_part, MPFR_RNDU)
mpz_clear(error_part)

mpfr_init2(neg_radius, mpfi_get_prec(x))
mpfr_neg(neg_radius, radius, MPFR_RNDD)

mpfi_init2(error, mpfi_get_prec(x))
mpfi_interv_fr(error, neg_radius, radius)
mpfr_clear(radius)
mpfr_clear(neg_radius)

mpfi_add(x, x, error)
mpfi_clear(error)

return 0


cdef int mpfi_set_sage(mpfi_ptr re, mpfi_ptr im, x, field, int base) except -1:
"""
Convert any object ``x`` to an MPFI interval or a pair of
Expand Down Expand Up @@ -186,6 +428,11 @@ cdef int mpfi_set_sage(mpfi_ptr re, mpfi_ptr im, x, field, int base) except -1:
if isinstance(x, unicode):
x = x.encode("ascii")
if isinstance(x, bytes):
if b"?" in x:
if _from_str_question_style(re, (<bytes>x).replace(b' ', b''), base):
x = bytes_to_str(x)
raise TypeError(f"unable to convert {x!r} to real interval")
return return_real(im)
s = (<bytes>x).replace(b'..', b',').replace(b' ', b'').replace(b'+infinity', b'@inf@').replace(b'-infinity', b'-@inf@')
if mpfi_set_str(re, s, base):
x = bytes_to_str(x)
Expand Down

0 comments on commit 13951f2

Please sign in to comment.